diff --git a/.gitignore b/.gitignore index a1da56aa9..252283a6e 100644 --- a/.gitignore +++ b/.gitignore @@ -58,4 +58,5 @@ artifacts/ **/times.csv transformer_engine/build_info.txt transformer_engine/common/util/hip_nvml.* +transformer_engine/LITE_BUILD *.DS_Store diff --git a/build_tools/utils.py b/build_tools/utils.py index e250238e6..98198b683 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -415,13 +415,14 @@ def get_frameworks() -> List[str]: if framework not in supported_frameworks: raise ValueError(f"Transformer Engine does not support framework={framework}") - if rocm_build(): + if rocm_build() and not bool(int(os.getenv("NVTE_LITE_ONLY", "0"))): _unsupported_frameworks = [] if "pytorch" in _frameworks: try: - from torch.utils.cpp_extension import IS_HIP_EXTENSION + import torch.utils.cpp_extension + IS_HIP_EXTENSION = getattr(torch.utils.cpp_extension, "IS_HIP_EXTENSION", False) except ImportError: - IS_HIP_EXTENSION=False + IS_HIP_EXTENSION = False if not IS_HIP_EXTENSION: if "pytorch" in _requested_frameworks: _unsupported_frameworks.append("pytorch") diff --git a/setup.py b/setup.py index c66af39df..4f32a5581 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ class HipifyMeta(egg_info): """Custom egg_info command to hipify source files before packaging.""" def run(self): - if rocm_build(): + if rocm_build() and not bool(int(os.getenv("NVTE_LITE_ONLY", "0"))): from build_tools.hipify.hipify import do_hipify print("Running hipification of installable headers for ROCm build...") do_hipify(current_file_path, current_file_path / "transformer_engine/common/include") @@ -229,7 +229,8 @@ def git_check_submodules() -> None: if __name__ == "__main__": __version__ = te_version() - git_check_submodules() + if not bool(int(os.getenv("NVTE_LITE_ONLY", "0"))): + git_check_submodules() with open("README.rst", encoding="utf-8") as f: long_description = f.read() @@ -256,6 +257,23 @@ def git_check_submodules() -> None: "rocm_pytorch": [f"transformer_engine_rocm7[pytorch]=={__version__}"], "rocm_jax": [f"transformer_engine_rocm7[jax]=={__version__}"], } + elif bool(int(os.getenv("NVTE_LITE_ONLY", "0"))): + # Lite-only build: no C++ compilation, pure Python + Triton kernels. + # Builds in seconds. NVTE_LITE=1 is forced at import time via marker file. + install_requires, test_requires = setup_requirements() + ext_modules = [] + cmdclass = {"bdist_wheel": TimedBdist} + package_data = { + "": ["VERSION.txt", "LITE_BUILD"], + "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"], + } + include_package_data = True + extras_require = {"test": test_requires} + + # Write marker file so import-time code knows this is a lite-only wheel + marker_path = current_file_path / "transformer_engine" / "LITE_BUILD" + marker_path.write_text("This is a lite-only build. 
NVTE_LITE=1 is forced.\n") + PACKAGE_NAME = "tealite" else: install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] @@ -289,7 +307,8 @@ def git_check_submodules() -> None: ) ) - PACKAGE_NAME="transformer_engine" + if not bool(int(os.getenv("NVTE_LITE_ONLY", "0"))): + PACKAGE_NAME="transformer_engine" if (rocm_build() and bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) and not bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))) ): PACKAGE_NAME=f"transformer_engine_rocm{rocm_version()[0]}" diff --git a/tests/pytorch/attention/run_lite_cp_test.py b/tests/pytorch/attention/run_lite_cp_test.py new file mode 100644 index 000000000..f6e2ecdf3 --- /dev/null +++ b/tests/pytorch/attention/run_lite_cp_test.py @@ -0,0 +1,290 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Multi-process worker for testing context parallelism in lite mode. + +This script is launched via torch.distributed.launch with >= 2 GPUs. +It runs DotProductAttention with and without CP, then compares outputs +and gradients. + +Only BSHD and SBHD formats are tested (THD requires C++ thd_* helpers +that are not yet implemented in lite mode). +""" + +import logging +import os +import pathlib +import sys + +os.environ["NVTE_LITE"] = "1" + +# Ensure repo root is on sys.path for dev-tree runs (no pip install) +_repo_root = str(pathlib.Path(__file__).resolve().parent.parent.parent.parent) +if _repo_root not in sys.path: + sys.path.insert(0, _repo_root) + +import torch +import torch.distributed as dist + +from transformer_engine.pytorch import DotProductAttention + + +logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") + + +# --------------------------------------------------------------------------- +# Configs +# --------------------------------------------------------------------------- + +class CPTestConfig: + """Minimal model config for CP tests.""" + + def __init__( + self, + batch_size, + max_seqlen, + num_heads, + head_dim, + num_gqa_groups=None, + attn_mask_type="causal", + ): + self.batch_size = batch_size + self.max_seqlen = max_seqlen + self.num_heads = num_heads + self.head_dim = head_dim + self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups + + +TEST_CONFIGS = { + "mha_causal": CPTestConfig(2, 1024, 8, 64, attn_mask_type="causal"), + "gqa_causal": CPTestConfig(2, 1024, 8, 64, num_gqa_groups=2, attn_mask_type="causal"), + "mha_no_mask": CPTestConfig(2, 1024, 8, 64, attn_mask_type="no_mask"), + "gqa_no_mask": CPTestConfig(2, 1024, 8, 64, num_gqa_groups=2, attn_mask_type="no_mask"), +} + + +# --------------------------------------------------------------------------- +# DualChunkSwap partitioning for BSHD / SBHD +# --------------------------------------------------------------------------- + +def partition_for_cp(tensor, qkv_format, rank, world_size): + """Partition a tensor along the sequence dimension using DualChunkSwap. + + Each rank gets 2 chunks: [rank] and [2*world_size - rank - 1]. 
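+ For example, with world_size=2, rank 0 gets chunks [0, 3] and rank 1 gets chunks [1, 2].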
+ """ + seq_dim = qkv_format.index("s") + shape = list(tensor.shape) + chunk_size = shape[seq_dim] // (2 * world_size) + new_shape = shape[:seq_dim] + [2 * world_size, chunk_size] + shape[seq_dim + 1:] + tensor = tensor.view(*new_shape) + seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=tensor.device) + tensor = tensor.index_select(seq_dim, seq_idx) + final_shape = shape[:seq_dim] + [2 * chunk_size] + shape[seq_dim + 1:] + return tensor.reshape(*final_shape).contiguous() + + +def partition_dout(dout, qkv_format, rank, world_size): + """Partition dout (output gradient) for CP comparison. + + dout shape from DPA is (b, s, h*d) for bshd or (s, b, h*d) for sbhd. + """ + seq_dim = 0 if qkv_format == "sbhd" else 1 + shape = list(dout.shape) + chunk_size = shape[seq_dim] // (2 * world_size) + new_shape = shape[:seq_dim] + [2 * world_size, chunk_size] + shape[seq_dim + 1:] + dout = dout.view(*new_shape) + seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=dout.device) + dout = dout.index_select(seq_dim, seq_idx) + final_shape = shape[:seq_dim] + [2 * chunk_size] + shape[seq_dim + 1:] + return dout.reshape(*final_shape).contiguous() + + +# --------------------------------------------------------------------------- +# Core test logic +# --------------------------------------------------------------------------- + +def run_test( + config_name, + qkv_format, + cp_comm_type, + attn_mask_type, + dtype_str="bf16", +): + """Run a single CP vs no-CP comparison test using DotProductAttention.""" + # Initialize distributed process group + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl") + + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{local_rank}") + + config = TEST_CONFIGS[config_name] + dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}[dtype_str] + + b = config.batch_size + s = config.max_seqlen + h_q = config.num_heads + h_kv = config.num_gqa_groups + d = config.head_dim + + assert s % (2 * world_size) == 0, ( + f"seqlen ({s}) must be divisible by 2*cp_size ({2 * world_size})" + ) + + # Generate full inputs -- same across all ranks (seeded) + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + if qkv_format == "bshd": + q_shape = (b, s, h_q, d) + k_shape = (b, s, h_kv, d) + v_shape = (b, s, h_kv, d) + elif qkv_format == "sbhd": + q_shape = (s, b, h_q, d) + k_shape = (s, b, h_kv, d) + v_shape = (s, b, h_kv, d) + else: + raise ValueError(f"Unsupported qkv_format: {qkv_format}") + + q_orig = torch.randn(q_shape, dtype=dtype, device=device) + k_orig = torch.randn(k_shape, dtype=dtype, device=device) + v_orig = torch.randn(v_shape, dtype=dtype, device=device) + + # DPA output shape is (b, s, h*d) for bshd, (s, b, h*d) for sbhd + if qkv_format == "bshd": + dout_shape = (b, s, h_q * d) + else: + dout_shape = (s, b, h_q * d) + dout_orig = torch.randn(dout_shape, dtype=dtype, device=device) + + # ============== Run WITHOUT CP ============== + core_attn = DotProductAttention( + h_q, d, num_gqa_groups=h_kv, attention_dropout=0.0, + qkv_format=qkv_format, attn_mask_type=attn_mask_type, + ).cuda() + + q, k, v = [x.clone().detach().requires_grad_(True) for x in [q_orig, k_orig, v_orig]] + dout = dout_orig.clone().detach() + + out = core_attn(q, k, v) + out.backward(dout) + dq, dk, dv = q.grad, k.grad, v.grad + + # ============== Run WITH CP ============== + # Set up communication group + cp_comm_ranks = list(range(world_size)) + cp_group = 
dist.new_group(cp_comm_ranks, backend="nccl") + cp_stream = torch.cuda.Stream(device=device) + + # Partition inputs for this rank using DualChunkSwap + q_, k_, v_ = [ + partition_for_cp(x, qkv_format, rank, world_size).clone().detach().requires_grad_(True) + for x in [q_orig, k_orig, v_orig] + ] + dout_ = partition_dout(dout_orig, qkv_format, rank, world_size) + + # Configure CP on the attention module + core_attn.set_context_parallel_group(cp_group, cp_comm_ranks, cp_stream, cp_comm_type) + + out_ = core_attn(q_, k_, v_) + out_.backward(dout_) + dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad + + # ============== Validate ============== + # Check no NaN/Inf + for name, t in [("out_cp", out_), ("dq_cp", dq_), ("dk_cp", dk_), ("dv_cp", dv_)]: + assert torch.all(torch.isfinite(t)), f"Rank {rank}: {name} contains NaN or Inf!" + + # Slice reference to match this rank's CP partition + seq_dim = qkv_format.index("s") + + # For Q-side tensors (out, dq): partition ref the same as Q was partitioned + # DPA output is (b, s, h*d) / (s, b, h*d) -- seq_dim is 1 / 0 + out_seq_dim = 1 if qkv_format == "bshd" else 0 + + def slice_ref(ref_tensor, local_tensor, s_dim): + """Slice full reference tensor to match this rank's DualChunkSwap partition.""" + shape = list(ref_tensor.shape) + chunk_size = shape[s_dim] // (2 * world_size) + new_shape = shape[:s_dim] + [2 * world_size, chunk_size] + shape[s_dim + 1:] + ref_chunked = ref_tensor.view(*new_shape) + seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=ref_tensor.device) + ref_sliced = ref_chunked.index_select(s_dim, seq_idx) + local_reshaped = local_tensor.view(*ref_sliced.shape) + return ref_sliced, local_reshaped + + # Tolerances + if dtype_str == "bf16": + if h_q == h_kv: + atol, rtol = 2.5e-2, 2.5e-2 + else: + atol, rtol = 3.5e-2, 3.5e-2 + else: + atol, rtol = 5e-3, 5e-3 + + # Compare output and Q-side grads (use output seq_dim since DPA reshapes) + for name, ref_full, cp_local in [("out", out, out_), ("dq", dq, dq_)]: + s_dim = out_seq_dim if name == "out" else seq_dim + ref_s, cp_s = slice_ref(ref_full, cp_local, s_dim) + + for ci in range(2): + if s_dim == 1: # bshd + rc = ref_s[:, ci] + cc = cp_s[:, ci] + else: # sbhd + rc = ref_s[ci] + cc = cp_s[ci] + + try: + torch.testing.assert_close(rc, cc, atol=atol, rtol=rtol) + except AssertionError: + diff = (rc.float() - cc.float()).abs() + rmse = diff.pow(2).mean().sqrt().item() + val_range = max(rc.abs().max().item(), cc.abs().max().item(), 1e-6) + assert rmse < 0.02 * val_range, ( + f"Rank {rank}: {name} chunk {ci} RMSE {rmse:.6f} > " + f"tol {0.02 * val_range:.6f}" + ) + + # Compare K/V-side grads + for name, ref_full, cp_local in [("dk", dk, dk_), ("dv", dv, dv_)]: + ref_s, cp_s = slice_ref(ref_full, cp_local, seq_dim) + + for ci in range(2): + if seq_dim == 1: + rc = ref_s[:, ci] + cc = cp_s[:, ci] + else: + rc = ref_s[ci] + cc = cp_s[ci] + + try: + torch.testing.assert_close(rc, cc, atol=atol, rtol=rtol) + except AssertionError: + diff = (rc.float() - cc.float()).abs() + rmse = diff.pow(2).mean().sqrt().item() + val_range = max(rc.abs().max().item(), cc.abs().max().item(), 1e-6) + assert rmse < 0.02 * val_range, ( + f"Rank {rank}: {name} chunk {ci} RMSE {rmse:.6f} > " + f"tol {0.02 * val_range:.6f}" + ) + + logging.info( + f"Rank {rank}: PASSED -- config={config_name} fmt={qkv_format} " + f"comm={cp_comm_type} mask={attn_mask_type} dtype={dtype_str}" + ) + + dist.destroy_process_group() + + +# --------------------------------------------------------------------------- +# Entry point +# 
--------------------------------------------------------------------------- + +if __name__ == "__main__": + kwargs = dict(arg.split("=") for arg in sys.argv[2:]) + run_test(**kwargs) diff --git a/tests/pytorch/attention/test_lite_cp.py b/tests/pytorch/attention/test_lite_cp.py new file mode 100644 index 000000000..1f6f25871 --- /dev/null +++ b/tests/pytorch/attention/test_lite_cp.py @@ -0,0 +1,192 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Pytest launcher for context parallelism tests in lite mode (BSHD & SBHD). + +These tests verify that CP works in lite mode without the C++ thd_* helpers. +Only BSHD and SBHD formats are tested -- THD requires the thd_* implementations +that are stubbed out in _lite/context_parallel.py. + +Run with: + NVTE_LITE=1 pytest tests/pytorch/attention/test_lite_cp.py -v + +Requires at least 2 GPUs (4 for a2a+p2p). +""" + +import os +import subprocess +import sys +import pathlib +import logging + +import pytest +import torch + +logging.basicConfig(level=logging.INFO) + +_SCRIPT_DIR = pathlib.Path(__file__).resolve().parent +_REPO_ROOT = str(_SCRIPT_DIR.parent.parent.parent) +_WORKER_SCRIPT = str(_SCRIPT_DIR / "run_lite_cp_test.py") + +# --------------------------------------------------------------------------- +# Test matrix +# --------------------------------------------------------------------------- + +# Config names matching CPTestConfig in run_lite_cp_test.py +_CONFIGS_CAUSAL = ["mha_causal", "gqa_causal"] +_CONFIGS_NO_MASK = ["mha_no_mask", "gqa_no_mask"] + +_QKV_FORMATS = ["bshd", "sbhd"] + +# CP comm types that work with BSHD/SBHD (no THD needed) +_CP_COMM_TYPES = ["p2p", "all_gather", "a2a"] + + +def _get_num_gpus(cp_comm_type): + """Return number of GPUs required for a given CP comm type.""" + if cp_comm_type == "a2a+p2p": + return 4 + return 2 + + +def _run_worker(num_gpus, **kwargs): + """Launch the multi-process test worker and check its exit code.""" + args = [ + sys.executable, + "-m", "torch.distributed.launch", + f"--nproc-per-node={num_gpus}", + _WORKER_SCRIPT, + ] + for k, v in kwargs.items(): + args.append(f"{k}={v}") + + env = os.environ.copy() + env["NVTE_LITE"] = "1" + # Ensure repo root is on PYTHONPATH for dev-tree runs + env["PYTHONPATH"] = _REPO_ROOT + os.pathsep + env.get("PYTHONPATH", "") + + result = subprocess.run( + args, + capture_output=True, + text=True, + env=env, + timeout=300, + ) + + if result.returncode != 0: + # Print full output for debugging + logging.error("STDOUT:\n%s", result.stdout) + logging.error("STDERR:\n%s", result.stderr) + pytest.fail( + f"CP test worker failed (exit {result.returncode}). " + f"See log output above for details." 
+ ) + + +def _skip_if_insufficient_gpus(num_gpus): + if torch.cuda.device_count() < num_gpus: + pytest.skip(f"Test requires {num_gpus} GPUs, found {torch.cuda.device_count()}") + + +# --------------------------------------------------------------------------- +# Tests: P2P (ring exchange) +# --------------------------------------------------------------------------- + +class TestCPWithP2P: + """Context parallelism with P2P (ring) KV exchange -- BSHD & SBHD.""" + + @pytest.mark.parametrize("qkv_format", _QKV_FORMATS) + @pytest.mark.parametrize("config_name", _CONFIGS_CAUSAL + _CONFIGS_NO_MASK) + def test_p2p(self, config_name, qkv_format): + num_gpus = _get_num_gpus("p2p") + _skip_if_insufficient_gpus(num_gpus) + attn_mask_type = "causal" if "causal" in config_name else "no_mask" + _run_worker( + num_gpus, + config_name=config_name, + qkv_format=qkv_format, + cp_comm_type="p2p", + attn_mask_type=attn_mask_type, + ) + + +# --------------------------------------------------------------------------- +# Tests: All-Gather +# --------------------------------------------------------------------------- + +class TestCPWithAllGather: + """Context parallelism with KV All-Gather -- BSHD & SBHD.""" + + @pytest.mark.parametrize("qkv_format", _QKV_FORMATS) + @pytest.mark.parametrize("config_name", _CONFIGS_CAUSAL + _CONFIGS_NO_MASK) + def test_all_gather(self, config_name, qkv_format): + num_gpus = _get_num_gpus("all_gather") + _skip_if_insufficient_gpus(num_gpus) + attn_mask_type = "causal" if "causal" in config_name else "no_mask" + _run_worker( + num_gpus, + config_name=config_name, + qkv_format=qkv_format, + cp_comm_type="all_gather", + attn_mask_type=attn_mask_type, + ) + + +# --------------------------------------------------------------------------- +# Tests: A2A (Ulysses) +# --------------------------------------------------------------------------- + +class TestCPWithA2A: + """Context parallelism with All-to-All (Ulysses) -- BSHD & SBHD. + + A2A requires num_heads and num_gqa_groups divisible by cp_size. + The test configs satisfy this for cp_size=2. 
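+ (The configs use 8 heads and either 2 or 8 GQA groups, all divisible by cp_size=2.)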
+ """ + + @pytest.mark.parametrize("qkv_format", _QKV_FORMATS) + @pytest.mark.parametrize("config_name", _CONFIGS_CAUSAL + _CONFIGS_NO_MASK) + def test_a2a(self, config_name, qkv_format): + num_gpus = _get_num_gpus("a2a") + _skip_if_insufficient_gpus(num_gpus) + attn_mask_type = "causal" if "causal" in config_name else "no_mask" + _run_worker( + num_gpus, + config_name=config_name, + qkv_format=qkv_format, + cp_comm_type="a2a", + attn_mask_type=attn_mask_type, + ) + + +# --------------------------------------------------------------------------- +# Tests: dtype coverage +# --------------------------------------------------------------------------- + +class TestCPDtypes: + """Verify CP works with both bf16 and fp16 in lite mode.""" + + @pytest.mark.parametrize("dtype_str", ["bf16", "fp16"]) + def test_p2p_bshd_dtypes(self, dtype_str): + _skip_if_insufficient_gpus(2) + _run_worker( + 2, + config_name="mha_causal", + qkv_format="bshd", + cp_comm_type="p2p", + attn_mask_type="causal", + dtype_str=dtype_str, + ) + + @pytest.mark.parametrize("dtype_str", ["bf16", "fp16"]) + def test_a2a_bshd_dtypes(self, dtype_str): + _skip_if_insufficient_gpus(2) + _run_worker( + 2, + config_name="mha_causal", + qkv_format="bshd", + cp_comm_type="a2a", + attn_mask_type="causal", + dtype_str=dtype_str, + ) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 50624df9e..e9a4eae60 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -3,6 +3,7 @@ # See LICENSE for license information. from typing import Callable, Tuple, Union, List import math +import os import torch import pytest from transformer_engine.pytorch.attention.rope import ( @@ -11,6 +12,8 @@ apply_fused_qkv_rotary_pos_emb, ) +_IS_LITE = os.environ.get("NVTE_LITE", "0") == "1" + # Gradient is a broadcasted scalar def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: @@ -58,6 +61,10 @@ def test_fused_rope( # are with the maximum length of the rope embeddings. pytest.skip("Skipping test with margin=0 and start_positions=True") + if _IS_LITE: + if transpose is not None: + pytest.skip("Lite mode: non-contiguous tensors not supported in fused RoPE kernel") + device = torch.device("cuda:0") batch_size, head_num = 2, 64 t = torch.rand( @@ -143,6 +150,11 @@ def test_fused_rope_thd( start_positions: bool, margin: int, ) -> None: + if _IS_LITE: + if transpose is not None: + pytest.skip("Lite mode: non-contiguous tensors not supported in fused RoPE kernel") + if cp_size > 1: + pytest.skip("Lite mode: THD format with CP not yet supported (thd_* stubs)") device = torch.device("cuda:0") batch_size, head_num = 2, 64 diff --git a/tests/pytorch/test_lite.py b/tests/pytorch/test_lite.py new file mode 100644 index 000000000..363c88b5d --- /dev/null +++ b/tests/pytorch/test_lite.py @@ -0,0 +1,4967 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for TE Lite mode (NVTE_LITE=1). + +These tests verify that the pure-Python _lite backend can replace the +C++ transformer_engine_torch extension for core TE modules. 
+ +Run with: + NVTE_LITE=1 pytest tests/pytorch/test_lite.py -v +""" + +import os +import pytest +import torch +import torch.nn.functional as F + +# Ensure lite mode is active before importing TE +os.environ["NVTE_LITE"] = "1" + +import transformer_engine.pytorch as te # noqa: E402 +import transformer_engine_torch as tex # noqa: E402 +from transformer_engine.common import recipe # noqa: E402 +from transformer_engine.pytorch.quantization import FP8GlobalStateManager # noqa: E402 + + +@pytest.fixture(autouse=True) +def _check_lite_mode(): + """Skip all tests if lite mode is not active.""" + assert tex.__name__ == "transformer_engine.pytorch._lite", ( + "NVTE_LITE=1 must be set before importing transformer_engine" + ) + + +@pytest.fixture +def device(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + return "cuda" + + +# --------------------------------------------------------------------------- +# Import / smoke tests +# --------------------------------------------------------------------------- + +class TestImport: + """Verify that the lite module loads and aliases correctly.""" + + def test_lite_module_loaded(self): + assert "transformer_engine.pytorch._lite" in tex.__name__ + + def test_key_symbols_exist(self): + required = [ + "DType", "FP8TensorMeta", "NVTE_Fused_Attn_Backend", + "generic_gemm", "te_general_grouped_gemm", + "layernorm_fwd", "layernorm_bwd", + "rmsnorm_fwd", "rmsnorm_bwd", "gelu", "silu", "swiglu", + "multi_tensor_adam", "multi_tensor_scale", + ] + for name in required: + assert hasattr(tex, name), f"Missing symbol: {name}" + + +# --------------------------------------------------------------------------- +# Forward tests +# --------------------------------------------------------------------------- + +class TestForward: + """Forward pass for core TE modules under lite mode.""" + + @pytest.mark.parametrize("in_features,out_features", [(1024, 512), (256, 256)]) + def test_linear(self, device, in_features, out_features): + mod = te.Linear(in_features, out_features, bias=True).to( + dtype=torch.bfloat16, device=device + ) + x = torch.randn(4, in_features, device=device, dtype=torch.bfloat16) + y = mod(x) + assert y.shape == (4, out_features) + + def test_layernorm_linear(self, device): + mod = te.LayerNormLinear(1024, 512, bias=True).to( + dtype=torch.bfloat16, device=device + ) + x = torch.randn(4, 1024, device=device, dtype=torch.bfloat16) + y = mod(x) + assert y.shape == (4, 512) + + def test_layernorm_mlp(self, device): + mod = te.LayerNormMLP(1024, 4096).to(dtype=torch.bfloat16, device=device) + x = torch.randn(4, 1024, device=device, dtype=torch.bfloat16) + y = mod(x) + assert y.shape == (4, 1024) + + def test_layernorm(self, device): + mod = te.LayerNorm(1024).to(dtype=torch.bfloat16, device=device) + x = torch.randn(4, 1024, device=device, dtype=torch.bfloat16) + y = mod(x) + assert y.shape == (4, 1024) + + def test_rmsnorm(self, device): + mod = te.RMSNorm(1024).to(dtype=torch.bfloat16, device=device) + x = torch.randn(4, 1024, device=device, dtype=torch.bfloat16) + y = mod(x) + assert y.shape == (4, 1024) + + def test_transformer_layer(self, device): + mod = te.TransformerLayer(1024, 4096, 16).to( + dtype=torch.bfloat16, device=device + ) + x = torch.randn(2, 8, 1024, device=device, dtype=torch.bfloat16) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + y = mod(x) + assert y.shape == (2, 8, 1024) + + +# --------------------------------------------------------------------------- +# Forward + backward tests +# 
--------------------------------------------------------------------------- + +class TestBackward: + """Forward + backward pass for core TE modules under lite mode.""" + + @pytest.mark.parametrize("in_features,out_features", [(1024, 512), (256, 256)]) + def test_linear(self, device, in_features, out_features): + mod = te.Linear(in_features, out_features, bias=True).to( + dtype=torch.bfloat16, device=device + ) + x = torch.randn(4, in_features, device=device, dtype=torch.bfloat16, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None + assert x.grad.shape == x.shape + + def test_layernorm_linear(self, device): + mod = te.LayerNormLinear(1024, 512, bias=True).to( + dtype=torch.bfloat16, device=device + ) + x = torch.randn(4, 1024, device=device, dtype=torch.bfloat16, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None + + def test_layernorm_mlp(self, device): + mod = te.LayerNormMLP(1024, 4096).to(dtype=torch.bfloat16, device=device) + x = torch.randn(4, 1024, device=device, dtype=torch.bfloat16, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None + + def test_layernorm(self, device): + mod = te.LayerNorm(1024).to(dtype=torch.bfloat16, device=device) + x = torch.randn(4, 1024, device=device, dtype=torch.bfloat16, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None + + def test_rmsnorm(self, device): + mod = te.RMSNorm(1024).to(dtype=torch.bfloat16, device=device) + x = torch.randn(4, 1024, device=device, dtype=torch.bfloat16, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None + + def test_transformer_layer(self, device): + mod = te.TransformerLayer(1024, 4096, 16).to( + dtype=torch.bfloat16, device=device + ) + x = torch.randn(2, 8, 1024, device=device, dtype=torch.bfloat16, requires_grad=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + y = mod(x) + y.sum().backward() + assert x.grad is not None + + +# --------------------------------------------------------------------------- +# Numerical correctness +# --------------------------------------------------------------------------- + +class TestNumerical: + """Verify lite-mode results match PyTorch reference implementations.""" + + def test_linear_fp32_exact(self, device): + """te.Linear should match torch.nn.Linear exactly in FP32.""" + te_mod = te.Linear(256, 128, bias=True).to(device=device) + pt_mod = torch.nn.Linear(256, 128, bias=True).to(device=device) + with torch.no_grad(): + pt_mod.weight.copy_(te_mod.weight) + pt_mod.bias.copy_(te_mod.bias) + x = torch.randn(4, 256, device=device) + y_te = te_mod(x) + y_pt = pt_mod(x) + assert torch.allclose(y_te, y_pt, atol=0, rtol=0), ( + f"FP32 max diff: {(y_te - y_pt).abs().max().item():.2e}" + ) + + def test_linear_bf16_close(self, device): + """te.Linear should be close to torch.nn.Linear in BF16.""" + te_mod = te.Linear(256, 128, bias=True).to(dtype=torch.bfloat16, device=device) + pt_mod = torch.nn.Linear(256, 128, bias=True).to(dtype=torch.bfloat16, device=device) + with torch.no_grad(): + pt_mod.weight.copy_(te_mod.weight) + pt_mod.bias.copy_(te_mod.bias) + x = torch.randn(4, 256, device=device, dtype=torch.bfloat16) + y_te = te_mod(x).to(torch.bfloat16) + y_pt = pt_mod(x) + assert torch.allclose(y_te, y_pt, atol=5e-3, rtol=1e-2), ( + f"BF16 max diff: {(y_te - y_pt).abs().max().item():.2e}" + ) + + def test_linear_backward_fp32_exact(self, device): + """Backward gradients should match exactly in FP32.""" + te_mod = 
te.Linear(256, 128, bias=True).to(device=device) + pt_mod = torch.nn.Linear(256, 128, bias=True).to(device=device) + with torch.no_grad(): + pt_mod.weight.copy_(te_mod.weight) + pt_mod.bias.copy_(te_mod.bias) + x_te = torch.randn(4, 256, device=device, requires_grad=True) + x_pt = x_te.detach().clone().requires_grad_(True) + te_mod(x_te).sum().backward() + pt_mod(x_pt).sum().backward() + assert torch.allclose(x_te.grad, x_pt.grad, atol=0, rtol=0), ( + f"Grad max diff: {(x_te.grad - x_pt.grad).abs().max().item():.2e}" + ) + + def test_layernorm_close(self, device): + """te.LayerNorm should be close to torch.nn.LayerNorm.""" + te_mod = te.LayerNorm(512).to(dtype=torch.bfloat16, device=device) + pt_mod = torch.nn.LayerNorm(512).to(dtype=torch.bfloat16, device=device) + with torch.no_grad(): + pt_mod.weight.copy_(te_mod.weight) + pt_mod.bias.copy_(te_mod.bias) + x = torch.randn(4, 512, device=device, dtype=torch.bfloat16) + y_te = te_mod(x) + y_pt = pt_mod(x) + assert torch.allclose(y_te, y_pt, atol=5e-3, rtol=1e-2), ( + f"LayerNorm max diff: {(y_te - y_pt).abs().max().item():.2e}" + ) + + +# --------------------------------------------------------------------------- +# Triton kernel wiring (Phase 2) +# --------------------------------------------------------------------------- + +class TestTritonNorms: + """Verify that Triton norm kernels are wired correctly via _lite.""" + + def test_triton_norms_loadable(self): + """Triton norm kernels should be importable when Triton is installed.""" + try: + import triton # noqa: F401 + has_triton = True + except ImportError: + has_triton = False + + from transformer_engine.pytorch._lite.norms import ( + _try_load_triton_norms, + _triton_ln_fwd, + ) + _try_load_triton_norms() + if has_triton: + from transformer_engine.pytorch._lite import norms as _n + assert _n._triton_ln_fwd is not None, "Triton norms should load when Triton is available" + # If triton is not installed, the fallback path is tested by other tests. 
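+ # The remaining tests in this class compare the tex.* norm entry points
+ # (backed by AITER/Triton kernels when those are importable) against the
+ # pure-PyTorch _*_pytorch reference implementations.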
+ + def test_aiter_norms_active(self, device): + """AITER Triton norm kernels should be the active backend when AITER is available.""" + from transformer_engine.pytorch._lite.aiter_utils import is_aiter_available + from transformer_engine.pytorch._lite import norms as _n + + # Trigger lazy loading by calling a norm function + x = torch.randn(4, 128, device=device, dtype=torch.bfloat16) + w = torch.randn(128, device=device, dtype=torch.bfloat16) + b = torch.randn(128, device=device, dtype=torch.bfloat16) + tex.rmsnorm_fwd(x, w, 1e-5, None, None, None, 0, False) + tex.layernorm_fwd(x, w, b, 1e-5, None, None, None, 0, False) + + if is_aiter_available(): + assert _n._aiter_rms_fwd is not None, "AITER RMSNorm fwd should be loaded" + assert _n._aiter_rms_bwd is not None, "AITER RMSNorm bwd should be loaded" + assert _n._aiter_ln_fwd is not None, "AITER LayerNorm fwd should be loaded" + assert _n._aiter_ln_bwd is not None, "AITER LayerNorm bwd should be loaded" + + def test_aiter_rmsnorm_fwd_bwd(self, device): + """AITER RMSNorm forward and backward produce correct results.""" + from transformer_engine.pytorch._lite.norms import _rmsnorm_fwd_pytorch, _rmsnorm_bwd_pytorch + + hidden = 512 + x = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + w = torch.randn(hidden, device=device, dtype=torch.bfloat16) + g = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + + # PyTorch reference + y_pt, rstd_pt = _rmsnorm_fwd_pytorch(x, w, 1e-5, False) + dx_pt, dw_pt = _rmsnorm_bwd_pytorch(g, x, rstd_pt, w, False) + + # AITER-backed tex path + y_te, _, rstd_te = tex.rmsnorm_fwd(x, w, 1e-5, None, None, None, 0, False) + dx_te, dw_te = tex.rmsnorm_bwd(g, x, rstd_te, w, 0, False) + + assert torch.allclose(y_te, y_pt, atol=5e-1, rtol=5e-2), ( + f"RMSNorm fwd max diff: {(y_te - y_pt).abs().max().item():.2e}" + ) + assert torch.allclose(dx_te.to(torch.bfloat16), dx_pt.to(torch.bfloat16), + atol=5e-1, rtol=5e-2), ( + f"RMSNorm bwd dx max diff: {(dx_te - dx_pt).abs().max().item():.2e}" + ) + + def test_fused_rmsnorm_fp8_quant_active(self, device): + """Fused RMSNorm+FP8 quantize kernel is used for Float8Quantizer.""" + from transformer_engine.pytorch._lite import norms as _n + from transformer_engine.pytorch import Float8Quantizer + + _n._try_load_aiter_norms() + if _n._aiter_fused_rms_fp8_static is None: + pytest.skip("AITER fused RMSNorm+FP8 kernel not available") + + hidden = 256 + x = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + w = torch.randn(hidden, device=device, dtype=torch.bfloat16) + + q = Float8Quantizer( + scale=torch.tensor([4.0], dtype=torch.float32, device=device), + amax=torch.tensor([0.0], dtype=torch.float32, device=device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + out, _, rsigma = tex.rmsnorm_fwd(x, w, 1e-5, None, q, None, 0, False) + + # Verify output is a Float8Tensor + assert type(out).__name__ == "Float8Tensor", ( + f"Expected Float8Tensor, got {type(out).__name__}" + ) + # Verify shape is preserved + assert out.shape == x.shape, f"Shape mismatch: {out.shape} vs {x.shape}" + # Verify amax was updated (non-zero means kernel ran) + assert q.amax.item() > 0, "amax should be updated by fused kernel" + # Verify scale_inv was set + assert hasattr(out, '_scale_inv') + expected_scale_inv = 1.0 / 4.0 + assert abs(out._scale_inv.item() - expected_scale_inv) < 1e-6 + + def test_fused_rmsnorm_fp8_quant_vs_separate(self, device): + """Fused RMSNorm+FP8 path matches separate norm->quantize path.""" + from transformer_engine.pytorch._lite.norms import ( + 
_aiter_fused_rms_fp8_static, _try_load_aiter_norms, + _rmsnorm_fwd_pytorch, + ) + from transformer_engine.pytorch import Float8Quantizer + + _try_load_aiter_norms() + if _aiter_fused_rms_fp8_static is None: + pytest.skip("AITER fused RMSNorm+FP8 kernel not available") + + hidden = 512 + x = torch.randn(16, hidden, device=device, dtype=torch.bfloat16) + w = torch.randn(hidden, device=device, dtype=torch.bfloat16) + scale_val = 6.0 + + # Separate path: norm then quantize manually + normed, _ = _rmsnorm_fwd_pytorch(x, w, 1e-5, False) + dequant_scale = 1.0 / scale_val + fp8_separate = (normed.float() * scale_val).to(torch.float8_e4m3fnuz) + deq_separate = fp8_separate.float() * dequant_scale + + # Fused path via tex + q = Float8Quantizer( + scale=torch.tensor([scale_val], dtype=torch.float32, device=device), + amax=torch.tensor([0.0], dtype=torch.float32, device=device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + out_fused, _, _ = tex.rmsnorm_fwd(x, w, 1e-5, None, q, None, 0, False) + deq_fused = out_fused._data.view(torch.float8_e4m3fnuz).float() * dequant_scale + + # Allow FP8 rounding tolerance — different intermediate precision + # (fused kernel does norm in float32 vs PyTorch fallback in bf16) + diff = (deq_separate - deq_fused).abs().max().item() + assert diff < 0.5, ( + f"Fused vs separate max dequantized diff: {diff:.4f} (expected < 0.5)" + ) + + def test_fused_rmsnorm_fp8_quant_3d_input(self, device): + """Fused path handles 3D input shape correctly.""" + from transformer_engine.pytorch._lite.norms import ( + _aiter_fused_rms_fp8_static, _try_load_aiter_norms, + ) + from transformer_engine.pytorch import Float8Quantizer + + _try_load_aiter_norms() + if _aiter_fused_rms_fp8_static is None: + pytest.skip("AITER fused RMSNorm+FP8 kernel not available") + + x = torch.randn(4, 8, 256, device=device, dtype=torch.bfloat16) + w = torch.randn(256, device=device, dtype=torch.bfloat16) + + q = Float8Quantizer( + scale=torch.tensor([2.0], dtype=torch.float32, device=device), + amax=torch.tensor([0.0], dtype=torch.float32, device=device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + out, _, rsigma = tex.rmsnorm_fwd(x, w, 1e-5, None, q, None, 0, False) + + assert out.shape == (4, 8, 256), f"Expected (4,8,256), got {out.shape}" + assert rsigma.shape == (4, 8), f"Expected (4,8), got {rsigma.shape}" + + # --- Current Scaling: fused RMSNorm + per-row dynamic FP8 quantize --- + + def test_fused_rmsnorm_current_scaling_active(self, device): + """Fused RMSNorm+FP8 per-row dynamic quant kernel is used for CurrentScaling.""" + from transformer_engine.pytorch._lite import norms as _n + from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer + + _n._try_load_aiter_norms() + if _n._aiter_fused_rms_dynamic_quant is None: + pytest.skip("AITER rmsnorm2d_fwd_with_dynamicquant not available") + + hidden = 256 + x = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + w = torch.randn(hidden, device=device, dtype=torch.bfloat16) + + q = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=device, + ) + + out, _, rsigma = tex.rmsnorm_fwd(x, w, 1e-5, None, q, None, 0, False) + + # Verify output is Float8Tensor + assert type(out).__name__ == "Float8Tensor", ( + f"Expected Float8Tensor, got {type(out).__name__}" + ) + # Verify shape is preserved + assert out.shape == x.shape, f"Shape mismatch: {out.shape} vs {x.shape}" + # Verify _scale_inv is per-row (M,) not scalar + assert hasattr(out, '_scale_inv') + assert out._scale_inv.shape == (8,), ( + f"Expected per-row 
scale_inv shape (8,), got {out._scale_inv.shape}" + ) + # Verify all per-row scales are positive (valid dequant scales) + assert (out._scale_inv > 0).all(), "All per-row scales should be positive" + + def test_fused_rmsnorm_current_scaling_vs_separate(self, device): + """Fused per-row path matches separate norm->quantize within FP8 tolerance.""" + from transformer_engine.pytorch._lite.norms import ( + _aiter_fused_rms_dynamic_quant, _try_load_aiter_norms, + _rmsnorm_fwd_pytorch, + ) + + _try_load_aiter_norms() + if _aiter_fused_rms_dynamic_quant is None: + pytest.skip("AITER rmsnorm2d_fwd_with_dynamicquant not available") + + hidden = 512 + x = torch.randn(16, hidden, device=device, dtype=torch.bfloat16) + w = torch.randn(hidden, device=device, dtype=torch.bfloat16) + + # Reference: separate RMSNorm (PyTorch) + normed_ref, _ = _rmsnorm_fwd_pytorch(x, w, 1e-5, False) + + # Fused: RMSNorm + per-row dynamic FP8 quant + fp8_dtype = torch.float8_e4m3fnuz + out_fp8 = torch.empty_like(x, dtype=fp8_dtype) + yscale = torch.empty(x.shape[0], dtype=torch.float32, device=device) + _aiter_fused_rms_dynamic_quant(out_fp8, x, yscale, w, 1e-5) + + # Dequantize: FP8 data * per-row scale + deq_fused = out_fp8.to(torch.float32) * yscale.unsqueeze(1) + + # FP8 E4M3 has ~3.5% relative error budget — use generous tolerance + max_err = (normed_ref.float() - deq_fused).abs().max().item() + rel_err = ( + (normed_ref.float() - deq_fused).abs() + / (normed_ref.float().abs() + 1e-8) + ).mean().item() + assert rel_err < 0.05, ( + f"Mean relative error {rel_err:.4f} exceeds 5% tolerance" + ) + assert max_err < 1.0, ( + f"Max abs error {max_err:.4f} exceeds tolerance" + ) + + def test_fused_rmsnorm_current_scaling_per_row_scales_vary(self, device): + """Per-row scales should differ across rows (not degenerate scalar).""" + from transformer_engine.pytorch._lite.norms import ( + _aiter_fused_rms_dynamic_quant, _try_load_aiter_norms, + ) + + _try_load_aiter_norms() + if _aiter_fused_rms_dynamic_quant is None: + pytest.skip("AITER rmsnorm2d_fwd_with_dynamicquant not available") + + # Use input with varying row magnitudes to ensure different scales + hidden = 256 + x = torch.randn(32, hidden, device=device, dtype=torch.bfloat16) + # Scale rows differently so per-row scales must differ + row_scales = torch.linspace(0.1, 10.0, 32, device=device).unsqueeze(1) + x = x * row_scales.to(x.dtype) + w = torch.ones(hidden, device=device, dtype=torch.bfloat16) + + fp8_dtype = torch.float8_e4m3fnuz + out_fp8 = torch.empty_like(x, dtype=fp8_dtype) + yscale = torch.empty(32, dtype=torch.float32, device=device) + _aiter_fused_rms_dynamic_quant(out_fp8, x, yscale, w, 1e-5) + + # With 32 rows at very different magnitudes, all scales should be unique + unique_scales = yscale.unique().numel() + assert unique_scales == 32, ( + f"Expected 32 unique per-row scales, got {unique_scales}" + ) + + def test_fused_rmsnorm_current_scaling_3d_input(self, device): + """Fused per-row path handles 3D input shape correctly.""" + from transformer_engine.pytorch._lite import norms as _n + from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer + + _n._try_load_aiter_norms() + if _n._aiter_fused_rms_dynamic_quant is None: + pytest.skip("AITER rmsnorm2d_fwd_with_dynamicquant not available") + + x = torch.randn(4, 8, 256, device=device, dtype=torch.bfloat16) + w = torch.randn(256, device=device, dtype=torch.bfloat16) + + q = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=device, + ) + out, _, rsigma = 
tex.rmsnorm_fwd(x, w, 1e-5, None, q, None, 0, False) + + assert out.shape == (4, 8, 256), f"Expected (4,8,256), got {out.shape}" + # rsigma should be flattened to batch dims + assert rsigma.shape == (4, 8), f"Expected (4,8), got {rsigma.shape}" + # scale_inv should be per-row over the flattened batch: 4*8 = 32 rows + assert out._scale_inv.shape == (32,), ( + f"Expected per-row scale_inv shape (32,), got {out._scale_inv.shape}" + ) + + def test_aiter_layernorm_fwd_bwd(self, device): + """AITER LayerNorm forward and backward produce correct results.""" + from transformer_engine.pytorch._lite.norms import ( + _layernorm_fwd_pytorch, _layernorm_bwd_pytorch, + ) + + hidden = 512 + x = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + w = torch.randn(hidden, device=device, dtype=torch.bfloat16) + b = torch.randn(hidden, device=device, dtype=torch.bfloat16) + g = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + + # PyTorch reference + y_pt, mean_pt, rstd_pt = _layernorm_fwd_pytorch(x, w, b, 1e-5, False) + dx_pt, dw_pt, db_pt = _layernorm_bwd_pytorch(g, x, mean_pt, rstd_pt, w, False) + + # AITER-backed tex path + y_te, mean_te, rstd_te = tex.layernorm_fwd(x, w, b, 1e-5, None, None, None, 0, False) + dx_te, dw_te, db_te = tex.layernorm_bwd(g, x, mean_te, rstd_te, w, 0, False) + + assert torch.allclose(y_te, y_pt, atol=8e-2, rtol=2e-2), ( + f"LayerNorm fwd max diff: {(y_te - y_pt).abs().max().item():.2e}" + ) + assert torch.allclose(dx_te.to(torch.bfloat16), dx_pt.to(torch.bfloat16), + atol=8e-2, rtol=2e-2), ( + f"LayerNorm bwd dx max diff: {(dx_te - dx_pt).abs().max().item():.2e}" + ) + + @pytest.mark.parametrize("hidden_size", [256, 512, 1024]) + def test_layernorm_fwd_triton_vs_pytorch(self, device, hidden_size): + """Triton layernorm_fwd should match PyTorch reference.""" + from transformer_engine.pytorch._lite.norms import ( + _layernorm_fwd_pytorch, + ) + weight = torch.randn(hidden_size, device=device, dtype=torch.bfloat16) + bias = torch.randn(hidden_size, device=device, dtype=torch.bfloat16) + x = torch.randn(8, hidden_size, device=device, dtype=torch.bfloat16) + + y_pt, mean_pt, rstd_pt = _layernorm_fwd_pytorch( + x, weight, bias, 1e-5, False, + ) + y_te, mean_te, rstd_te = tex.layernorm_fwd( + x, weight, bias, 1e-5, None, None, None, 0, False, + ) + # Dequantize if needed + if hasattr(y_te, 'dequantize'): + y_te = y_te.dequantize() + # BF16 layernorm: Triton fused kernel vs PyTorch individual ops have + # different rounding, so tolerance must accommodate BF16 ULP differences + assert torch.allclose(y_te.to(torch.bfloat16), y_pt, atol=8e-2, rtol=2e-2), ( + f"LayerNorm fwd max diff: {(y_te.to(torch.bfloat16) - y_pt).abs().max().item():.2e}" + ) + + @pytest.mark.parametrize("hidden_size", [256, 512, 1024]) + def test_rmsnorm_fwd_triton_vs_pytorch(self, device, hidden_size): + """Triton rmsnorm_fwd should match PyTorch reference.""" + from transformer_engine.pytorch._lite.norms import ( + _rmsnorm_fwd_pytorch, + ) + weight = torch.randn(hidden_size, device=device, dtype=torch.bfloat16) + x = torch.randn(8, hidden_size, device=device, dtype=torch.bfloat16) + + y_pt, rstd_pt = _rmsnorm_fwd_pytorch( + x, weight, 1e-5, False, + ) + y_te, _, rstd_te = tex.rmsnorm_fwd( + x, weight, 1e-5, None, None, None, 0, False, + ) + if hasattr(y_te, 'dequantize'): + y_te = y_te.dequantize() + # BF16 RMSNorm: Triton fused kernel vs PyTorch individual ops have + # different rounding; wider tolerance for BF16 comparison + assert torch.allclose(y_te.to(torch.bfloat16), y_pt, atol=5e-1, 
rtol=5e-2), ( + f"RMSNorm fwd max diff: {(y_te.to(torch.bfloat16) - y_pt).abs().max().item():.2e}" + ) + + def test_layernorm_bwd_triton_vs_pytorch(self, device): + """Triton layernorm_bwd should match PyTorch reference.""" + from transformer_engine.pytorch._lite.norms import ( + _layernorm_fwd_pytorch, + _layernorm_bwd_pytorch, + ) + hidden = 512 + weight = torch.randn(hidden, device=device, dtype=torch.bfloat16) + bias = torch.randn(hidden, device=device, dtype=torch.bfloat16) + x = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + grad_out = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + + _, mean, rstd = _layernorm_fwd_pytorch( + x, weight, bias, 1e-5, False, + ) + dx_pt, dw_pt, db_pt = _layernorm_bwd_pytorch( + grad_out, x, mean, rstd, weight, False, + ) + dx_te, dw_te, db_te = tex.layernorm_bwd( + grad_out, x, mean, rstd, weight, 0, False, + ) + assert torch.allclose(dx_te, dx_pt, atol=8e-2, rtol=2e-2), ( + f"LayerNorm bwd dx max diff: {(dx_te - dx_pt).abs().max().item():.2e}" + ) + # Weight grad is reduced over batch -- wider BF16 tolerance + assert torch.allclose(dw_te, dw_pt, atol=5e-2, rtol=5e-2), ( + f"LayerNorm bwd dw max diff: {(dw_te - dw_pt).abs().max().item():.2e}" + ) + + def test_rmsnorm_bwd_triton_vs_pytorch(self, device): + """Triton rmsnorm_bwd should match PyTorch reference.""" + from transformer_engine.pytorch._lite.norms import ( + _rmsnorm_fwd_pytorch, + _rmsnorm_bwd_pytorch, + ) + hidden = 512 + weight = torch.randn(hidden, device=device, dtype=torch.bfloat16) + x = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + grad_out = torch.randn(8, hidden, device=device, dtype=torch.bfloat16) + + _, rstd = _rmsnorm_fwd_pytorch( + x, weight, 1e-5, False, + ) + dx_pt, dw_pt = _rmsnorm_bwd_pytorch( + grad_out, x, rstd, weight, False, + ) + dx_te, dw_te = tex.rmsnorm_bwd( + grad_out, x, rstd, weight, 0, False, + ) + # Cast to common dtype -- PyTorch fallback may promote to fp32 + # while Triton kernel returns in input dtype + dx_te_bf16 = dx_te.to(torch.bfloat16) + dx_pt_bf16 = dx_pt.to(torch.bfloat16) + dw_te_bf16 = dw_te.to(torch.bfloat16) + dw_pt_bf16 = dw_pt.to(torch.bfloat16) + assert torch.allclose(dx_te_bf16, dx_pt_bf16, atol=5e-2, rtol=2e-2), ( + f"RMSNorm bwd dx max diff: {(dx_te_bf16 - dx_pt_bf16).abs().max().item():.2e}" + ) + assert torch.allclose(dw_te_bf16, dw_pt_bf16, atol=5e-2, rtol=5e-2), ( + f"RMSNorm bwd dw max diff: {(dw_te_bf16 - dw_pt_bf16).abs().max().item():.2e}" + ) + + def test_layernorm_3d_input(self, device): + """Norm functions should handle 3D input (batch, seq, hidden).""" + mod = te.LayerNorm(256).to(dtype=torch.bfloat16, device=device) + x = torch.randn(2, 4, 256, device=device, dtype=torch.bfloat16) + y = mod(x) + assert y.shape == (2, 4, 256) + + def test_rmsnorm_3d_input(self, device): + """RMSNorm should handle 3D input (batch, seq, hidden).""" + mod = te.RMSNorm(256).to(dtype=torch.bfloat16, device=device) + x = torch.randn(2, 4, 256, device=device, dtype=torch.bfloat16) + y = mod(x) + assert y.shape == (2, 4, 256) + + +# --------------------------------------------------------------------------- +# FP8 quantize / dequantize (Phase 2) +# --------------------------------------------------------------------------- + +class TestQuantize: + """Verify FP8 quantize/dequantize works in lite mode without recursion.""" + + def test_fp8_quantize_no_recursion(self, device): + """tex.quantize with Float8Quantizer should not recurse.""" + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer 
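+ # Scaling convention used by these tests: scale = fp8_max / amax with
+ # fp8_max = 240.0 (the largest representable float8_e4m3fnuz value); the
+ # quantized data is x * scale and dequantization multiplies by
+ # scale_inv = 1 / scale.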
+ + x = torch.randn(8, 16, device=device, dtype=torch.bfloat16) + amax_val = x.abs().max().item() + fp8_max = 240.0 + scale = torch.tensor([fp8_max / amax_val], device=device, dtype=torch.float32) + amax = torch.tensor([0.0], device=device, dtype=torch.float32) + q = Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) + + result = tex.quantize(x, q) + assert hasattr(result, '_data'), "Quantize should return a Float8Tensor" + assert result._data.shape == (8, 16) + assert result._data.dtype == torch.uint8 + + def test_fp8_dequantize(self, device): + """tex.dequantize should reconstruct values from FP8.""" + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + + x = torch.randn(8, 16, device=device, dtype=torch.bfloat16) + amax_val = x.abs().max().item() + fp8_max = 240.0 + scale = torch.tensor([fp8_max / amax_val], device=device, dtype=torch.float32) + amax = torch.tensor([0.0], device=device, dtype=torch.float32) + q = Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) + + quantized = tex.quantize(x, q) + y = tex.dequantize(quantized, tex.DType.kBFloat16) + + # FP8 quantization error should be small with proper scaling + max_abs_err = (y.to(torch.bfloat16) - x).abs().max().item() + assert max_abs_err < 0.5, f"FP8 roundtrip max error too large: {max_abs_err:.4f}" + + def test_fp8_roundtrip_relative_error(self, device): + """FP8 quantize-dequantize roundtrip should have low relative error.""" + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + + x = torch.randn(32, 64, device=device, dtype=torch.bfloat16) + amax_val = x.abs().max().item() + fp8_max = 240.0 + scale = torch.tensor([fp8_max / amax_val], device=device, dtype=torch.float32) + amax = torch.tensor([0.0], device=device, dtype=torch.float32) + q = Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) + + quantized = tex.quantize(x, q) + y = tex.dequantize(quantized, tex.DType.kBFloat16) + + mean_rel_err = ((y.to(torch.bfloat16) - x).abs() / (x.abs() + 1e-8)).mean().item() + assert mean_rel_err < 0.1, f"FP8 mean relative error too large: {mean_rel_err:.4f}" + + def test_quantize_no_quantizer(self, device): + """quantize with no quantizer should return tensor as-is.""" + x = torch.randn(4, 8, device=device, dtype=torch.bfloat16) + result = tex.quantize(x, None) + assert torch.equal(result, x) + + def test_quantize_with_output(self, device): + """quantize with output tensor should copy into it.""" + x = torch.randn(4, 8, device=device, dtype=torch.bfloat16) + out = torch.empty_like(x) + result = tex.quantize(x, None, output=out) + assert torch.equal(result, x) + + def test_bgrad_quantize(self, device): + """bgrad_quantize should return (bias_grad, quantized).""" + x = torch.randn(4, 8, device=device, dtype=torch.bfloat16) + bgrad, quantized = tex.bgrad_quantize(x, None) + expected_bgrad = x.sum(dim=0) + assert torch.allclose(bgrad, expected_bgrad) + + def test_dequantize_plain_tensor(self, device): + """dequantize on a plain tensor should just cast dtype.""" + x = torch.randn(4, 8, device=device, dtype=torch.float32) + y = tex.dequantize(x, tex.DType.kBFloat16) + assert y.dtype == torch.bfloat16 + assert torch.allclose(y, x.to(torch.bfloat16)) + + # --- CurrentScaling per-row quantize (backward path) --- + + def test_current_scaling_quantize_per_row(self, device): + """CurrentScaling quantizer should produce per-row scales via AITER.""" + import sys + _qmod = sys.modules["transformer_engine.pytorch._lite.quantize"] + from 
transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + + _qmod._try_load_aiter_quant() + if _qmod._aiter_dynamic_per_token_quant is None: + pytest.skip("AITER dynamic_per_token_quant_fp8_i8 not available") + + x = torch.randn(16, 64, device=device, dtype=torch.bfloat16) + q = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=device, + ) + + result = tex.quantize(x, q) + + # Should be Float8Tensor + assert hasattr(result, '_data'), "Expected Float8Tensor output" + assert result._data.shape == (16, 64) + # Per-row: _scale_inv should be (M,) not scalar + assert result._scale_inv.shape == (16,), ( + f"Expected per-row scale_inv (16,), got {result._scale_inv.shape}" + ) + assert (result._scale_inv > 0).all(), "All per-row scales should be positive" + + def test_current_scaling_quantize_roundtrip(self, device): + """CurrentScaling per-row quantize->dequantize roundtrip has low error.""" + import sys + _qmod = sys.modules["transformer_engine.pytorch._lite.quantize"] + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + + _qmod._try_load_aiter_quant() + if _qmod._aiter_dynamic_per_token_quant is None: + pytest.skip("AITER dynamic_per_token_quant_fp8_i8 not available") + + x = torch.randn(32, 128, device=device, dtype=torch.bfloat16) + q = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=device, + ) + + result = tex.quantize(x, q) + + # Manual dequantize: FP8 data * per-row scale + from transformer_engine.pytorch._lite.quantize import _te_dtype_to_torch_fp8 + fp8_dtype = _te_dtype_to_torch_fp8(q.dtype) + fp8_data = result._data.view(fp8_dtype) + deq = fp8_data.to(torch.float32) * result._scale_inv.unsqueeze(1) + + rel_err = ( + (x.float() - deq).abs() / (x.float().abs() + 1e-8) + ).mean().item() + assert rel_err < 0.05, ( + f"Per-row quantize roundtrip mean rel error {rel_err:.4f} > 5%" + ) + + def test_current_scaling_quantize_backward_dgrad_flow(self, device): + """Simulate backward: quantize dY per-row, then per-token GEMM for dgrad.""" + import sys + _qmod = sys.modules["transformer_engine.pytorch._lite.quantize"] + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + + _qmod._try_load_aiter_quant() + if _qmod._aiter_dynamic_per_token_quant is None: + pytest.skip("AITER dynamic_per_token_quant_fp8_i8 not available") + + try: + from aiter.ops.triton.gemm_a8w8_per_token_scale import ( + gemm_a8w8_per_token_scale, + ) + except ImportError: + pytest.skip("AITER gemm_a8w8_per_token_scale not available") + + M, N, K = 32, 64, 128 # dY is [M, N], W is [N, K], dX = dY @ W + fp8_dtype = torch.float8_e4m3fnuz + + # dY: quantize per-row via CurrentScaling + dY = torch.randn(M, N, device=device, dtype=torch.bfloat16) + q = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=device, + ) + dY_quant = tex.quantize(dY, q) + dY_fp8 = dY_quant._data.view(fp8_dtype) + dY_scale = dY_quant._scale_inv # (M,) + + # W: per-tensor quantize + W = torch.randn(N, K, device=device, dtype=torch.bfloat16) + w_amax = W.abs().max() + w_qs = 240.0 / w_amax + W_fp8 = (W.float() * w_qs).to(fp8_dtype) + w_ds = torch.full((K, 1), (1.0 / w_qs).item(), + dtype=torch.float32, device=device) + + # dgrad GEMM: dX = dY @ W (dY is [M,N], W is [N,K]) + # per_token_scale: Y = X @ W^T, so X=dY [M,N], W_t=W^T [K,N] + # We need W transposed: W is [N,K], so W^T is [K,N] + # gemm_a8w8_per_token_scale(x, w, x_scale, w_scale) computes x @ 
w^T + # So: x=dY_fp8 [M,N], w=W_fp8^T [K,N] → result [M,K] + # But kernel takes w in [N_out, K_in] and transposes internally + # Actually: kernel computes Y = X @ W^T where W is [K, N] + # So we pass w=W^T = [K, N], then kernel does dY @ (W^T)^T = dY @ W + W_T_fp8 = W_fp8.t().contiguous() # [K, N] + result = gemm_a8w8_per_token_scale( + dY_fp8, W_T_fp8, + dY_scale.unsqueeze(1), w_ds, + ) + + # Reference: dequant both, matmul + dY_deq = dY_fp8.to(torch.float32) * dY_scale.unsqueeze(1) + W_deq = W_fp8.to(torch.float32) * (1.0 / w_qs.item()) + ref = dY_deq @ W_deq # [M,N] @ [N,K] = [M,K] + + assert result.shape == (M, K), f"Expected ({M},{K}), got {result.shape}" + rel_err = ( + (result.float() - ref).abs() / (ref.abs() + 1e-8) + ).mean().item() + assert rel_err < 0.05, ( + f"Backward dgrad per-row mean rel error {rel_err:.4f} > 5%" + ) + + +# --------------------------------------------------------------------------- +# MXFP8 BlockScaling tests +# --------------------------------------------------------------------------- + +class TestMXFP8: + """Verify MXFP8 BlockScaling detection, quantize, GEMM, and norms in lite mode.""" + + def test_mxfp8_detection_not_fp4(self, device): + """MXFP8 tensor should be detected as MXFP8, not FP4.""" + from transformer_engine.pytorch._lite.gemm import _is_mxfp8, _is_fp4 + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + # Both dims must be divisible by 32 for MXFP8 + q = MXFP8Quantizer(tex.DType.kFloat8E4M3) + t = q.make_empty((32, 64), dtype=torch.bfloat16, device=device) + assert _is_mxfp8(t), "MXFP8 tensor not detected by _is_mxfp8()" + assert not _is_fp4(t), "MXFP8 tensor should NOT match _is_fp4()" + + def test_fp4_detection_not_mxfp8(self, device): + """MXFP4 tensor should be detected as FP4, not MXFP8.""" + from transformer_engine.pytorch._lite.gemm import _is_mxfp8, _is_fp4 + try: + from transformer_engine.pytorch.tensor.mxfp4_tensor import MXFP4Quantizer + except ImportError: + pytest.skip("MXFP4Quantizer not available") + + q = MXFP4Quantizer(tex.DType.kFloat4E2M1) + t = q.make_empty((32, 64), dtype=torch.bfloat16, device=device) + assert _is_fp4(t), "MXFP4 tensor not detected by _is_fp4()" + assert not _is_mxfp8(t), "MXFP4 tensor should NOT match _is_mxfp8()" + + def test_mxfp8_is_quantized(self, device): + """_is_quantized should return True for MXFP8 tensors.""" + from transformer_engine.pytorch._lite.gemm import _is_quantized + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + q = MXFP8Quantizer(tex.DType.kFloat8E4M3) + t = q.make_empty((32, 64), dtype=torch.bfloat16, device=device) + assert _is_quantized(t), "MXFP8 tensor should be detected as quantized" + + def test_linear_scale_to_e8m0(self, device): + """E8M0 conversion should produce correct biased exponents.""" + import sys + _qmod = sys.modules["transformer_engine.pytorch._lite.quantize"] + e8m0_fn = _qmod._linear_scale_to_e8m0 + + scales = torch.tensor([1.0, 2.0, 0.5, 4.0, 0.25], device=device) + e8m0 = e8m0_fn(scales) + # 1.0 = 2^0 → 0 + 127 = 127 + # 2.0 = 2^1 → 1 + 127 = 128 + # 0.5 = 2^-1 → -1 + 127 = 126 + # 4.0 = 2^2 → 2 + 127 = 129 + # 0.25 = 2^-2 → -2 + 127 = 125 + expected = torch.tensor([127, 128, 126, 129, 125], dtype=torch.uint8, device=device) + assert torch.equal(e8m0, expected), ( + f"E8M0 mismatch: {e8m0.tolist()} vs {expected.tolist()}" + ) + + def test_mxfp8_quantize_roundtrip(self, device): + """MXFP8 quantize→dequantize roundtrip should have low error.""" + from transformer_engine.pytorch.tensor.mxfp8_tensor 
import MXFP8Quantizer + + x = torch.randn(32, 128, device=device, dtype=torch.bfloat16) + q = MXFP8Quantizer(tex.DType.kFloat8E4M3) + + quantized = tex.quantize(x, q) + # Verify it's an MXFP8 tensor + assert hasattr(quantized, '_rowwise_data'), "Expected MXFP8Tensor" + assert hasattr(quantized, '_rowwise_scale_inv'), "Expected MXFP8 scales" + + # Dequantize and check error + deq = tex.dequantize(quantized, tex.DType.kBFloat16) + rel_err = ( + (x.float() - deq.float()).abs() / (x.float().abs() + 1e-8) + ).mean().item() + assert rel_err < 0.1, ( + f"MXFP8 roundtrip mean relative error {rel_err:.4f} > 10%" + ) + + def test_mxfp8_quantize_pytorch_fallback(self, device): + """MXFP8 PyTorch fallback should produce correct results without Triton.""" + import sys + _qmod = sys.modules["transformer_engine.pytorch._lite.quantize"] + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + x = torch.randn(32, 128, device=device, dtype=torch.bfloat16) + q = MXFP8Quantizer(tex.DType.kFloat8E4M3) + out = q.make_empty(x.shape, dtype=x.dtype, device=device) + + result = _qmod._quantize_mxfp8_pytorch(x, q, out) + + # Verify data was written + assert result._rowwise_data is not None + assert result._rowwise_data.any(), "FP8 data should be non-zero" + # Verify E8M0 scales were written + assert result._rowwise_scale_inv is not None + assert result._rowwise_scale_inv.any(), "E8M0 scales should be non-zero" + + # Dequantize via Triton/tensor method and check error + deq = result.dequantize(dtype=torch.bfloat16) + rel_err = ( + (x.float() - deq.float()).abs() / (x.float().abs() + 1e-8) + ).mean().item() + assert rel_err < 0.1, ( + f"MXFP8 PyTorch fallback roundtrip mean rel error {rel_err:.4f} > 10%" + ) + + def test_mxfp8_gemm_dequant_path(self, device): + """MXFP8 tensors through generic_gemm should produce correct BF16 via dequant.""" + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + M, N, K = 32, 64, 128 + A_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + B_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + + # Quantize both to MXFP8 + q = MXFP8Quantizer(tex.DType.kFloat8E4M3) + A_mxfp8 = tex.quantize(A_bf16, q) + B_mxfp8 = tex.quantize(B_bf16, q) + + # Reference: dequantize then matmul + A_deq = A_mxfp8.dequantize(dtype=torch.bfloat16) + B_deq = B_mxfp8.dequantize(dtype=torch.bfloat16) + ref = B_deq @ A_deq.t() # TN layout: result = B @ A^T + + ws = torch.empty(1024, device=device, dtype=torch.uint8) + out, _, _, _ = tex.generic_gemm( + A_mxfp8, True, B_mxfp8, False, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + + assert out.shape == (M, N), f"Expected ({M},{N}), got {out.shape}" + # Should match dequant reference closely (same precision path) + max_diff = (out.float() - ref.float()).abs().max().item() + assert max_diff < 0.5, ( + f"MXFP8 GEMM max diff {max_diff:.4f} vs dequant reference" + ) + + def test_mxfp8_fused_rmsnorm(self, device): + """Fused RMSNorm+MXFP8 quant should produce valid MXFP8Tensor.""" + from transformer_engine.pytorch._lite import norms as _n + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + _n._try_load_aiter_norms() + if _n._aiter_fused_rms_fp8_group is None: + pytest.skip("AITER fused_rms_fp8_group_quant not available") + + hidden = 256 + x = torch.randn(32, hidden, device=device, dtype=torch.bfloat16) + w = torch.randn(hidden, device=device, dtype=torch.bfloat16) + + q = MXFP8Quantizer(tex.DType.kFloat8E4M3) + out, _, rsigma = 
tex.rmsnorm_fwd(x, w, 1e-5, None, q, None, 0, False) + + # Verify output is MXFP8 + assert hasattr(out, '_rowwise_data'), ( + f"Expected MXFP8Tensor, got {type(out).__name__}" + ) + assert out._rowwise_data is not None, "MXFP8 rowwise data should be populated" + assert out._rowwise_scale_inv is not None, "MXFP8 scales should be populated" + assert out._rowwise_scale_inv.any(), "E8M0 scales should be non-zero" + + +# --------------------------------------------------------------------------- +# GEMM tests +# --------------------------------------------------------------------------- + +class TestGemm: + """Verify generic_gemm in lite mode (AITER CK/Triton + PyTorch fallback).""" + + DTYPE = torch.bfloat16 + + def _workspace(self, device): + return torch.empty(1024, device=device, dtype=torch.uint8) + + # -- Basic matmul (TN layout: weight[out,in], input[batch,in]) -------- + + def test_gemm_tn_basic(self, device): + """TN GEMM: A[out,in].T @ B[batch,in] -> [batch,out].""" + M, N, K = 128, 64, 256 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) # weight [out, in] + B = torch.randn(M, K, device=device, dtype=self.DTYPE) # input [batch, in] + ws = self._workspace(device) + out, bias_grad, gelu_in, extra = tex.generic_gemm( + A, True, B, False, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + assert out.shape == (M, N), f"Expected ({M},{N}), got {out.shape}" + + def test_gemm_tn_numerical(self, device): + """TN GEMM result should match torch.matmul reference.""" + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, K, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + ref = B @ A.t() # [M,K] @ [K,N] = [M,N] + assert torch.allclose(out.to(self.DTYPE), ref, atol=1e-2, rtol=1e-2), ( + f"GEMM max diff: {(out.to(self.DTYPE) - ref).abs().max().item():.4e}" + ) + + @pytest.mark.parametrize("transA,transB,shapeA,shapeB,expect", [ + (True, False, (64, 128), (32, 128), (32, 64)), # TN + (False, False, (128, 64), (32, 128), (32, 64)), # NN + (True, True, (64, 128), (128, 32), (32, 64)), # TT + (False, True, (128, 64), (128, 32), (32, 64)), # NT + ]) + def test_gemm_transpose_combos(self, device, transA, transB, shapeA, shapeB, expect): + """All transpose combinations should produce correct shapes.""" + A = torch.randn(*shapeA, device=device, dtype=self.DTYPE) + B = torch.randn(*shapeB, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, transA, B, transB, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + assert out.shape == expect, f"Expected {expect}, got {out.shape}" + + # -- Bias epilogue ---------------------------------------------------- + + def test_gemm_with_bias(self, device): + """GEMM + bias addition should work.""" + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, K, device=device, dtype=self.DTYPE) + bias = torch.randn(N, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, None, None, None, + bias, None, False, None, False, ws, ws.shape[0], + False, False, + ) + ref = B @ A.t() + bias + assert torch.allclose(out.to(self.DTYPE), ref, atol=1e-2, rtol=1e-2), ( + f"GEMM+bias max diff: {(out.to(self.DTYPE) - 
ref).abs().max().item():.4e}" + ) + + def test_gemm_bias_grad(self, device): + """GEMM with grad=True should compute bias gradient. + + In the backward pass, B is the grad_output dY. The bias gradient + is dY.reshape(-1, dY.shape[-1]).sum(dim=0). + cuBLAS column-major: result = op(B) @ op(A). + Use transA=False, transB=False so result = B @ A. + A=[N,K], B=[M,N] → result=[M,K], bias_grad=B.sum(0)=[N]. + """ + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, N, device=device, dtype=self.DTYPE) # dY [M, N] + bias = torch.randn(N, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + _, bias_grad, _, _ = tex.generic_gemm( + A, False, B, False, None, None, None, + bias, None, False, None, True, ws, ws.shape[0], + False, False, + ) + ref_bgrad = B.reshape(-1, B.shape[-1]).sum(dim=0) + assert bias_grad.shape == ref_bgrad.shape, ( + f"bias_grad shape {bias_grad.shape} != expected {ref_bgrad.shape}" + ) + assert torch.allclose(bias_grad.to(self.DTYPE), ref_bgrad, atol=1e-2, rtol=1e-2) + + # -- GELU epilogue ---------------------------------------------------- + + def test_gemm_with_gelu(self, device): + """GEMM + GELU activation should work.""" + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, K, device=device, dtype=self.DTYPE) + gelu_in = torch.empty(M, N, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + out, _, gelu_saved, _ = tex.generic_gemm( + A, True, B, False, None, None, None, + None, None, True, gelu_in, False, ws, ws.shape[0], + False, False, + ) + # gelu_saved should hold the pre-GELU values + ref_pre_gelu = B @ A.t() + ref_out = torch.nn.functional.gelu(ref_pre_gelu, approximate='tanh') + assert torch.allclose(out.to(self.DTYPE), ref_out, atol=1e-2, rtol=1e-2), ( + f"GEMM+GELU max diff: {(out.to(self.DTYPE) - ref_out).abs().max().item():.4e}" + ) + + # -- Accumulate ------------------------------------------------------- + + def test_gemm_accumulate(self, device): + """GEMM with accumulate=True should add to existing D.""" + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, K, device=device, dtype=self.DTYPE) + D = torch.randn(M, N, device=device, dtype=self.DTYPE) + D_orig = D.clone() + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, D, None, None, + None, None, False, None, False, ws, ws.shape[0], + True, False, + ) + ref = D_orig + B @ A.t() + assert torch.allclose(D.to(self.DTYPE), ref, atol=1e-2, rtol=1e-2), ( + f"Accumulate max diff: {(D.to(self.DTYPE) - ref).abs().max().item():.4e}" + ) + + # -- Alpha scaling ---------------------------------------------------- + + def test_gemm_alpha(self, device): + """GEMM with alpha != 1.0 should scale the result.""" + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, K, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, alpha=0.5, + ) + ref = 0.5 * (B @ A.t()) + assert torch.allclose(out.to(self.DTYPE), ref, atol=1e-2, rtol=1e-2) + + # -- Output into pre-allocated D (no accumulate) ---------------------- + + def test_gemm_output_into_d(self, device): + """GEMM should write result into D when accumulate=False.""" + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, K, 
device=device, dtype=self.DTYPE) + D = torch.zeros(M, N, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, D, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + ref = B @ A.t() + assert torch.allclose(D, ref, atol=1e-2, rtol=1e-2) + + # -- Return format ---------------------------------------------------- + + def test_gemm_return_format(self, device): + """generic_gemm should return (out, bias_grad, gelu_input, extra_output).""" + M, N, K = 16, 32, 64 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, K, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + result = tex.generic_gemm( + A, True, B, False, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + assert isinstance(result, tuple) and len(result) == 4, ( + f"Expected 4-tuple, got {type(result)} of len {len(result)}" + ) + + # -- FP32 precision --------------------------------------------------- + + def test_gemm_fp32(self, device): + """GEMM should work with FP32 inputs.""" + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=torch.float32) + B = torch.randn(M, K, device=device, dtype=torch.float32) + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + ref = B @ A.t() + assert torch.allclose(out, ref, atol=1e-5, rtol=1e-5) + + # -- output_dtype honored across mixed-precision operands -------------- + + def test_gemm_output_dtype_honored_mixed_operands(self, device): + """output_dtype must be honored even when an operand is fp32. + + Regression: the PyTorch fallback promoted compute to fp32 whenever + either operand was fp32 and ignored `output_dtype`, returning fp32. + The next module then failed `set_activation_dtype` with input=fp32 + against bf16 weights. cuBLAS in the full build always casts to the + caller-requested dtype. + """ + M, N, K = 16, 32, 64 + ws = self._workspace(device) + + # bf16 weight × fp32 activation: naive promotion would yield fp32. + # Caller requests bf16 output. + A = torch.randn(N, K, device=device, dtype=torch.bfloat16) + B_fp32 = torch.randn(M, K, device=device, dtype=torch.float32) + + out, _, _, _ = tex.generic_gemm( + A, True, B_fp32, False, None, None, tex.DType.kBFloat16, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + assert out.dtype == torch.bfloat16, ( + f"Expected bf16 output, got {out.dtype} — output_dtype not honored" + ) + + # Symmetric case: fp32 weight × bf16 activation. + A_fp32 = torch.randn(N, K, device=device, dtype=torch.float32) + B = torch.randn(M, K, device=device, dtype=torch.bfloat16) + out, _, _, _ = tex.generic_gemm( + A_fp32, True, B, False, None, None, tex.DType.kBFloat16, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + assert out.dtype == torch.bfloat16 + + # torch.dtype is also accepted (pass-through in _resolve_output_dtype). 
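+        # Same bf16-weight × fp32-activation operands as the first case above;
+        # only the form of the output_dtype argument differs.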
+ out, _, _, _ = tex.generic_gemm( + A, True, B_fp32, False, None, None, torch.bfloat16, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + assert out.dtype == torch.bfloat16 + + # -- Per-row scaled FP8 GEMM (CurrentScaling) ---------------------------- + + def test_gemm_per_row_scaled_fp8(self, device): + """FP8 GEMM with per-row activation scales dispatches correctly.""" + from transformer_engine.pytorch._lite.gemm import ( + _is_per_row_scaled, _is_block_scaled, is_aiter_available, + ) + if not is_aiter_available(): + pytest.skip("AITER not available") + try: + from aiter.ops.triton.gemm_a8w8_per_token_scale import ( + gemm_a8w8_per_token_scale, + ) + except ImportError: + pytest.skip("AITER gemm_a8w8_per_token_scale not available") + + M, N, K = 32, 64, 128 + fp8_dtype = torch.float8_e4m3fnuz + + # Create per-row-scaled activation (simulates fused norm+quant output) + x_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + from aiter.ops.triton.quant import dynamic_per_token_quant_fp8_i8 + x_fp8 = torch.empty(M, K, dtype=fp8_dtype, device=device) + x_scale = torch.empty(M, dtype=torch.float32, device=device) + dynamic_per_token_quant_fp8_i8(x_fp8, x_bf16, x_scale) + + # Create per-tensor-scaled weight + w_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + w_amax = w_bf16.abs().max() + w_quant_scale = 240.0 / w_amax + w_fp8 = (w_bf16.float() * w_quant_scale).to(fp8_dtype) + w_dequant = torch.tensor([1.0 / w_quant_scale.item()], + dtype=torch.float32, device=device) + + # Verify scale detection + assert _is_per_row_scaled(x_scale), "x_scale should be per-row" + assert not _is_per_row_scaled(w_dequant), "w_dequant should not be per-row" + assert not _is_block_scaled(x_scale), "per-row scale should not be block-scaled" + + # Build mock Float8Tensor-like objects for generic_gemm + class _FP8Wrap: + def __init__(self, data, scale_inv): + self._data = data + self._scale_inv = scale_inv + @property + def dtype(self): + return self._data.dtype + + A = _FP8Wrap(w_fp8, w_dequant) # weight [N, K], transA=True + B = _FP8Wrap(x_fp8, x_scale) # activation [M, K], transB=False + + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + + # Reference: dequantize both and matmul in float32 + x_deq = x_fp8.to(torch.float32) * x_scale.unsqueeze(1) + w_deq = w_fp8.to(torch.float32) * (1.0 / w_quant_scale.item()) + ref = x_deq @ w_deq.t() + + assert out.shape == (M, N), f"Expected ({M},{N}), got {out.shape}" + rel_err = ( + (out.float() - ref).abs() / (ref.abs() + 1e-8) + ).mean().item() + assert rel_err < 0.05, ( + f"Per-row GEMM mean relative error {rel_err:.4f} exceeds 5% tolerance" + ) + + def test_gemm_per_row_scaled_numerical_accuracy(self, device): + """Per-row scaled FP8 GEMM matches dequantized matmul reference.""" + from transformer_engine.pytorch._lite.gemm import is_aiter_available + if not is_aiter_available(): + pytest.skip("AITER not available") + try: + from aiter.ops.triton.gemm_a8w8_per_token_scale import ( + gemm_a8w8_per_token_scale, + ) + except ImportError: + pytest.skip("AITER gemm_a8w8_per_token_scale not available") + + from aiter.ops.triton.rmsnorm import rmsnorm2d_fwd_with_dynamicquant + + M, K, N = 64, 256, 128 + fp8_dtype = torch.float8_e4m3fnuz + + # Full forward path: input → fused RMSNorm+quant → per-token GEMM + inp = torch.randn(M, K, device=device, dtype=torch.bfloat16) + norm_w = torch.randn(K, device=device, 
dtype=torch.bfloat16) + x_fp8 = torch.empty(M, K, dtype=fp8_dtype, device=device) + x_scale = torch.empty(M, dtype=torch.float32, device=device) + rmsnorm2d_fwd_with_dynamicquant(x_fp8, inp, x_scale, norm_w, 1e-5) + + # Weight (per-tensor quantized) + w_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + w_amax = w_bf16.abs().max() + w_qs = 240.0 / w_amax + w_fp8 = (w_bf16.float() * w_qs).to(fp8_dtype) + w_ds = torch.full((N, 1), (1.0 / w_qs).item(), + dtype=torch.float32, device=device) + + result = gemm_a8w8_per_token_scale( + x_fp8, w_fp8, x_scale.unsqueeze(1), w_ds, + ) + + # Reference + x_deq = x_fp8.to(torch.float32) * x_scale.unsqueeze(1) + w_deq = w_fp8.to(torch.float32) * w_ds + ref = x_deq @ w_deq.t() + + rel_err = ( + (result.float() - ref).abs() / (ref.abs() + 1e-8) + ).mean().item() + assert rel_err < 0.02, ( + f"End-to-end fused norm→per-row GEMM mean rel error {rel_err:.4f} > 2%" + ) + + +# --------------------------------------------------------------------------- +# GEMM backend coverage (pytorch / triton / ck parity + dispatch asserts) +# --------------------------------------------------------------------------- + +class _FP8Wrap: + """Minimal Float8Tensor shim for generic_gemm. + + _is_quantized(tensor) returns True for anything with (_data, _scale_inv). + _get_raw_data returns (_data, _scale_inv), which is all downstream + dispatch paths need. + """ + def __init__(self, data, scale_inv): + self._data = data + self._scale_inv = scale_inv + + @property + def dtype(self): + return self._data.dtype + + +def _quant_per_tensor_e4m3(x_bf16, fp8_dtype=torch.float8_e4m3fnuz): + """Quantize a BF16 tensor to FP8 with a single scalar per-tensor scale. + + Returns (fp8_data, scale_inv_scalar) where scale_inv_scalar is a 1-elem + tensor carrying dequant scale (matches Float8Tensor._scale_inv layout). + """ + amax = x_bf16.abs().max().clamp_min(1e-6) + qscale = 240.0 / amax + x_fp8 = (x_bf16.float() * qscale).to(fp8_dtype) + scale_inv = torch.tensor([1.0 / qscale.item()], + dtype=torch.float32, device=x_bf16.device) + return x_fp8, scale_inv + + +class TestGemmBackendMatrix: + """Ensure all three GEMM backends produce correct results and the + `pytorch` backend actually takes the fast `torch._scaled_mm` path. + """ + + DTYPE = torch.bfloat16 + + def _workspace(self, device): + return torch.empty(1024, device=device, dtype=torch.uint8) + + def _set_backend(self, monkeypatch, backend): + """Swap _GEMM_BACKEND in the gemm module for one test.""" + from transformer_engine.pytorch._lite import gemm as lite_gemm + monkeypatch.setattr(lite_gemm, "_GEMM_BACKEND", backend) + + # -- Backend matrix: BF16 --------------------------------------------- + + @pytest.mark.parametrize("backend", ["pytorch", "triton", "ck"]) + def test_bf16_gemm_matches_reference(self, device, monkeypatch, backend): + """BF16 GEMM must agree with torch.matmul on every backend. + + Protects against silent regressions in the ck/triton paths now + that `pytorch` is the default. 
+ """ + self._set_backend(monkeypatch, backend) + M, N, K = 32, 64, 128 + A = torch.randn(N, K, device=device, dtype=self.DTYPE) + B = torch.randn(M, K, device=device, dtype=self.DTYPE) + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, None, None, None, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + ref = B @ A.t() + max_diff = (out.to(self.DTYPE) - ref).abs().max().item() + assert max_diff < 5e-2, ( + f"[backend={backend}] BF16 GEMM max diff {max_diff:.4e}" + ) + + # -- Backend matrix: per-tensor FP8 (DelayedScaling layout) ----------- + + @pytest.mark.parametrize("backend", ["pytorch", "triton", "ck"]) + def test_per_tensor_fp8_gemm_matches_dequant( + self, device, monkeypatch, backend, + ): + """Per-tensor FP8×FP8 GEMM (DelayedScaling shape) must match the + dequantized reference on every backend. + + This is the recipe Megatron hard-codes. Scalar scales should route + to the per-tensor kernel family on all three backends. A regression + here means the production training path is broken. + """ + from transformer_engine.pytorch._lite.gemm import is_aiter_available + if backend in ("triton", "ck") and not is_aiter_available(): + pytest.skip("AITER not available") + + self._set_backend(monkeypatch, backend) + M, N, K = 64, 128, 256 + fp8 = torch.float8_e4m3fnuz + + x_bf16 = torch.randn(M, K, device=device, dtype=self.DTYPE) + w_bf16 = torch.randn(N, K, device=device, dtype=self.DTYPE) + x_fp8, x_scale = _quant_per_tensor_e4m3(x_bf16, fp8) + w_fp8, w_scale = _quant_per_tensor_e4m3(w_bf16, fp8) + + A = _FP8Wrap(w_fp8, w_scale) # weight [N, K] + B = _FP8Wrap(x_fp8, x_scale) # activation [M, K] + + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, None, None, tex.DType.kBFloat16, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + + # Dequantized reference in fp32 + x_deq = x_fp8.float() * x_scale.item() + w_deq = w_fp8.float() * w_scale.item() + ref = x_deq @ w_deq.t() + + rel_err = ( + (out.float() - ref).abs() / (ref.abs() + 1e-3) + ).mean().item() + assert rel_err < 0.05, ( + f"[backend={backend}] per-tensor FP8 GEMM mean rel err " + f"{rel_err:.4f} exceeds 5% tolerance" + ) + assert out.shape == (M, N) + assert out.dtype == torch.bfloat16 + + # -- Dispatch path counters: pytorch backend must take fast path ------ + + def test_pytorch_backend_takes_scaled_mm_path( + self, device, monkeypatch, + ): + """Per-tensor FP8 under backend=pytorch must land on _scaled_mm, + not dequant+matmul. + + The whole point of the pytorch default is hipBLASLt via _scaled_mm. + If a future change silently forces scalar scales into a rejected + layout (e.g. broadcast to (M,1)), this test catches it by reading + the dispatch counters. + """ + from transformer_engine.pytorch._lite import gemm as lite_gemm + if not hasattr(torch, "_scaled_mm"): + pytest.skip("torch._scaled_mm not available") + + self._set_backend(monkeypatch, "pytorch") + # Counters are gated behind _LITE_DIAG; flip on and zero them. 
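+        # (the assertions below inspect _GEMM_CALLS["pytorch_scaled_mm_ok"] and
+        #  _GEMM_CALLS["pytorch_dequant_matmul"] after the GEMM call)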
+ monkeypatch.setattr(lite_gemm, "_LITE_DIAG", True) + lite_gemm._GEMM_CALLS.clear() + + M, N, K = 64, 128, 256 + fp8 = torch.float8_e4m3fnuz + x_fp8, x_scale = _quant_per_tensor_e4m3( + torch.randn(M, K, device=device, dtype=self.DTYPE), fp8, + ) + w_fp8, w_scale = _quant_per_tensor_e4m3( + torch.randn(N, K, device=device, dtype=self.DTYPE), fp8, + ) + + A = _FP8Wrap(w_fp8, w_scale) + B = _FP8Wrap(x_fp8, x_scale) + ws = self._workspace(device) + tex.generic_gemm( + A, True, B, False, None, None, tex.DType.kBFloat16, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + + calls = dict(lite_gemm._GEMM_CALLS) + assert calls.get("pytorch_scaled_mm_ok", 0) >= 1, ( + f"Expected pytorch_scaled_mm_ok>=1 for per-tensor FP8 under " + f"backend=pytorch; got counters {calls}" + ) + assert calls.get("pytorch_dequant_matmul", 0) == 0, ( + f"Per-tensor FP8 under backend=pytorch must not fall back to " + f"dequant+matmul (the 100-1000x slow path); got counters {calls}" + ) + + # -- Pad-M: M not divisible by 16 ------------------------------------ + + def test_scaled_mm_pads_m_when_not_div16( + self, device, monkeypatch, + ): + """hipBLASLt FP8 requires mat1 rows div-by-16. The pad-then-slice + path (commit 3ed9d8ae) must preserve numerical correctness. + + Uses M=100 (pads to 112); no unit test currently exercises this. + """ + from transformer_engine.pytorch._lite import gemm as lite_gemm + if not hasattr(torch, "_scaled_mm"): + pytest.skip("torch._scaled_mm not available") + + self._set_backend(monkeypatch, "pytorch") + monkeypatch.setattr(lite_gemm, "_LITE_DIAG", True) + lite_gemm._GEMM_CALLS.clear() + + M, N, K = 100, 64, 128 # M not div-by-16, K div-by-16 + assert M % 16 != 0 and K % 16 == 0, "Shape preconditions" + + fp8 = torch.float8_e4m3fnuz + x_bf16 = torch.randn(M, K, device=device, dtype=self.DTYPE) + w_bf16 = torch.randn(N, K, device=device, dtype=self.DTYPE) + x_fp8, x_scale = _quant_per_tensor_e4m3(x_bf16, fp8) + w_fp8, w_scale = _quant_per_tensor_e4m3(w_bf16, fp8) + + A = _FP8Wrap(w_fp8, w_scale) + B = _FP8Wrap(x_fp8, x_scale) + ws = self._workspace(device) + out, _, _, _ = tex.generic_gemm( + A, True, B, False, None, None, tex.DType.kBFloat16, + None, None, False, None, False, ws, ws.shape[0], + False, False, + ) + + # The fast path should still fire — pad-and-slice, not dequant. + calls = dict(lite_gemm._GEMM_CALLS) + assert calls.get("pytorch_scaled_mm_ok", 0) >= 1, ( + f"Pad-M case must still land on _scaled_mm; got {calls}" + ) + + # Output must match dequantized reference AND be sliced back to M. 
+ assert out.shape == (M, N), ( + f"Output must be sliced back to M={M}, got {out.shape}" + ) + x_deq = x_fp8.float() * x_scale.item() + w_deq = w_fp8.float() * w_scale.item() + ref = x_deq @ w_deq.t() + rel_err = ( + (out.float() - ref).abs() / (ref.abs() + 1e-3) + ).mean().item() + assert rel_err < 0.05, ( + f"Pad-M FP8 GEMM mean rel err {rel_err:.4f} exceeds 5%" + ) + + +# --------------------------------------------------------------------------- +# Attention tests +# --------------------------------------------------------------------------- + +class TestAttention: + """Verify attention kernels in lite mode (AITER CK + SDPA fallback).""" + + B, S, H, D = 2, 64, 16, 64 + DTYPE = torch.bfloat16 + + def _make_qkv(self, device, fmt="bshd", num_kv_heads=None): + """Create Q, K, V tensors and cu_seqlens for a given format.""" + B, S, H, D = self.B, self.S, self.H, self.D + H_kv = num_kv_heads or H + cu = torch.arange(0, (B + 1) * S, S, device=device, dtype=torch.int32) + if fmt == "bshd": + q = torch.randn(B, S, H, D, device=device, dtype=self.DTYPE) + k = torch.randn(B, S, H_kv, D, device=device, dtype=self.DTYPE) + v = torch.randn(B, S, H_kv, D, device=device, dtype=self.DTYPE) + elif fmt == "sbhd": + q = torch.randn(S, B, H, D, device=device, dtype=self.DTYPE) + k = torch.randn(S, B, H_kv, D, device=device, dtype=self.DTYPE) + v = torch.randn(S, B, H_kv, D, device=device, dtype=self.DTYPE) + elif fmt == "thd": + total = B * S + cu = torch.arange(0, (B + 1) * S, S, device=device, dtype=torch.int32) + q = torch.randn(total, H, D, device=device, dtype=self.DTYPE) + k = torch.randn(total, H_kv, D, device=device, dtype=self.DTYPE) + v = torch.randn(total, H_kv, D, device=device, dtype=self.DTYPE) + else: + raise ValueError(fmt) + return q, k, v, cu + + # -- get_fused_attn_backend ------------------------------------------- + + def test_backend_selection(self): + """get_fused_attn_backend should return a valid backend.""" + backend = tex.get_fused_attn_backend( + True, # is_training + tex.DType.kBFloat16, tex.DType.kBFloat16, # q/kv dtype + tex.NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, # layout + tex.NVTE_Bias_Type.NVTE_NO_BIAS, + tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + 0.0, 16, 16, 64, 64, 64, 64, -1, -1, False, False, + ) + assert backend in ( + tex.NVTE_Fused_Attn_Backend.NVTE_CK, + tex.NVTE_Fused_Attn_Backend.NVTE_SDPA, + ) + + # -- fused_attn_fwd / bwd (C++ binding interface) -------------------- + + @pytest.mark.parametrize("layout_name,layout_enum", [ + ("bshd", tex.NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD), + ("sbhd", tex.NVTE_QKV_Layout.NVTE_SBHD_SBHD_SBHD), + ("thd", tex.NVTE_QKV_Layout.NVTE_THD_THD_THD), + ]) + def test_fused_attn_fwd_shapes(self, device, layout_name, layout_enum): + """fused_attn_fwd should produce correct output shapes for each layout.""" + q, k, v, cu = self._make_qkv(device, layout_name) + result = tex.fused_attn_fwd( + self.S, self.S, True, 1.0 / 8.0, 0.0, True, + layout_enum, + tex.NVTE_Bias_Type.NVTE_NO_BIAS, + tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + (-1, 0), cu, cu, q, k, v, self.DTYPE, + ) + assert isinstance(result, list), "fused_attn_fwd should return a list" + assert len(result) >= 1, "Result must contain at least the output tensor" + assert result[0].shape == q.shape, ( + f"Output shape {result[0].shape} != Q shape {q.shape}" + ) + + @pytest.mark.parametrize("layout_name,layout_enum", [ + ("bshd", tex.NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD), + ("sbhd", 
tex.NVTE_QKV_Layout.NVTE_SBHD_SBHD_SBHD), + ("thd", tex.NVTE_QKV_Layout.NVTE_THD_THD_THD), + ]) + def test_fused_attn_bwd_shapes(self, device, layout_name, layout_enum): + """fused_attn_bwd should produce correct gradient shapes.""" + q, k, v, cu = self._make_qkv(device, layout_name) + result = tex.fused_attn_fwd( + self.S, self.S, True, 1.0 / 8.0, 0.0, True, + layout_enum, + tex.NVTE_Bias_Type.NVTE_NO_BIAS, + tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + (-1, 0), cu, cu, q, k, v, self.DTYPE, + ) + out = result[0] + aux = result[1:] + d_o = torch.randn_like(out) + bwd = tex.fused_attn_bwd( + self.S, self.S, 1.0 / 8.0, 0.0, True, + layout_enum, + tex.NVTE_Bias_Type.NVTE_NO_BIAS, + tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + (-1, 0), False, + cu, cu, q, k, v, out, d_o, self.DTYPE, None, aux, + ) + assert len(bwd) == 5, "fused_attn_bwd should return [dQ, dK, dV, dBias, dSoftmaxOffset]" + assert bwd[0].shape == q.shape, f"dQ shape {bwd[0].shape} != Q shape {q.shape}" + assert bwd[1].shape == k.shape, f"dK shape {bwd[1].shape} != K shape {k.shape}" + assert bwd[2].shape == v.shape, f"dV shape {bwd[2].shape} != V shape {v.shape}" + + def test_fused_attn_aux_ctx_tensors(self, device): + """Forward should return [out, softmax_lse, rng_state] for training.""" + q, k, v, cu = self._make_qkv(device, "bshd") + result = tex.fused_attn_fwd( + self.S, self.S, True, 1.0 / 8.0, 0.0, True, + tex.NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, + tex.NVTE_Bias_Type.NVTE_NO_BIAS, + tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + (-1, 0), cu, cu, q, k, v, self.DTYPE, + ) + assert len(result) >= 3, ( + f"Training fwd should return [out, softmax_lse, rng_state], got {len(result)} tensors" + ) + softmax_lse = result[1] + rng_state = result[2] + assert softmax_lse.dtype == torch.float32, "softmax_lse should be float32" + assert rng_state.shape[0] == 2, "rng_state should have 2 elements [seed, offset]" + + @pytest.mark.parametrize("mask_type", [ + tex.NVTE_Mask_Type.NVTE_NO_MASK, + tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, + tex.NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK, + ]) + def test_fused_attn_mask_types(self, device, mask_type): + """fused_attn_fwd should handle different mask types.""" + q, k, v, cu = self._make_qkv(device, "bshd") + result = tex.fused_attn_fwd( + self.S, self.S, False, 1.0 / 8.0, 0.0, True, + tex.NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, + tex.NVTE_Bias_Type.NVTE_NO_BIAS, + mask_type, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + (-1, -1), cu, cu, q, k, v, self.DTYPE, + ) + assert result[0].shape == q.shape + + # -- GQA (grouped query attention) ------------------------------------ + + def test_fused_attn_gqa(self, device): + """Attention should work with fewer KV heads than Q heads (GQA).""" + q, k, v, cu = self._make_qkv(device, "bshd", num_kv_heads=4) + result = tex.fused_attn_fwd( + self.S, self.S, False, 1.0 / 8.0, 0.0, True, + tex.NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, + tex.NVTE_Bias_Type.NVTE_NO_BIAS, + tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + (-1, 0), cu, cu, q, k, v, self.DTYPE, + ) + assert result[0].shape == q.shape + + # -- Variable-length sequences (thd) ---------------------------------- + + def test_fused_attn_varlen(self, device): + """Attention should handle variable-length sequences in THD format.""" + H, D = self.H, self.D + # 3 sequences of length 20, 30, 14 + cu = torch.tensor([0, 20, 50, 64], device=device, dtype=torch.int32) + total 
= 64 + q = torch.randn(total, H, D, device=device, dtype=self.DTYPE) + k = torch.randn(total, H, D, device=device, dtype=self.DTYPE) + v = torch.randn(total, H, D, device=device, dtype=self.DTYPE) + result = tex.fused_attn_fwd( + 30, 30, True, 1.0 / 8.0, 0.0, True, + tex.NVTE_QKV_Layout.NVTE_THD_THD_THD, + tex.NVTE_Bias_Type.NVTE_NO_BIAS, + tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + (-1, 0), cu, cu, q, k, v, self.DTYPE, + ) + assert result[0].shape == (total, H, D) + + # -- Numerical: AITER vs SDPA ----------------------------------------- + + def test_aiter_vs_sdpa_numerical(self, device): + """AITER CK and SDPA fallback should produce similar results.""" + from transformer_engine.pytorch._lite import attention as attn_mod + + torch.manual_seed(42) + q, k, v, cu = self._make_qkv(device, "bshd") + layout = tex.NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD + mask = tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK + args = ( + self.S, self.S, False, 1.0 / 8.0, 0.0, True, + layout, tex.NVTE_Bias_Type.NVTE_NO_BIAS, mask, + tex.NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + (-1, 0), cu, cu, q, k, v, self.DTYPE, + ) + + # AITER path + result_aiter = attn_mod.fused_attn_fwd(*args) + out_aiter = result_aiter[0] + + # SDPA path (force by temporarily disabling AITER) + saved_fwd = attn_mod._aiter_fwd + saved_varlen = attn_mod._aiter_varlen_fwd + attn_mod._aiter_fwd = None + attn_mod._aiter_varlen_fwd = None + try: + result_sdpa = attn_mod.fused_attn_fwd(*args) + out_sdpa = result_sdpa[0] + finally: + attn_mod._aiter_fwd = saved_fwd + attn_mod._aiter_varlen_fwd = saved_varlen + + max_diff = (out_aiter - out_sdpa).abs().max().item() + assert max_diff < 1e-2, ( + f"AITER vs SDPA max diff {max_diff:.4e} exceeds tolerance" + ) + + # -- Helper functions ------------------------------------------------- + + def test_fa_prepare_fwd(self, device): + """fa_prepare_fwd: [s, b, n, 3*h] -> [3, b, s, n, h].""" + qkvi = torch.randn(32, 2, 16, 192, device=device, dtype=self.DTYPE) + out = tex.fa_prepare_fwd(qkvi) + assert out.shape == (3, 2, 32, 16, 64) + + def test_fa_prepare_bwd(self, device): + """fa_prepare_bwd: 3 x [s, b, n, h] -> [b, s, n, 3*h].""" + q = torch.randn(32, 2, 16, 64, device=device, dtype=self.DTYPE) + k = torch.randn(32, 2, 16, 64, device=device, dtype=self.DTYPE) + v = torch.randn(32, 2, 16, 64, device=device, dtype=self.DTYPE) + out = tex.fa_prepare_bwd(q, k, v) + assert out.shape == (2, 32, 16, 192) + + def test_convert_thd_bshd_roundtrip(self, device): + """THD -> BSHD -> THD should preserve data.""" + cu = torch.tensor([0, 10, 25, 32], device=device, dtype=torch.int32) + thd = torch.randn(32, 16, 64, device=device, dtype=self.DTYPE) + bshd = tex.convert_thd_to_bshd(thd, cu, 3, 15) + assert bshd.shape == (3, 15, 16, 64) + thd2 = tex.convert_bshd_to_thd(bshd, cu, 32) + assert thd2.shape == (32, 16, 64) + # Data for each sequence should survive the roundtrip + assert torch.allclose(thd[:10], thd2[:10]) + assert torch.allclose(thd[10:25], thd2[10:25]) + assert torch.allclose(thd[25:32], thd2[25:32]) + + # -- DotProductAttention module (end-to-end) -------------------------- + + def test_dot_product_attention_fwd(self, device): + """DotProductAttention forward should work in lite mode.""" + dpa = te.DotProductAttention(16, 64, 16, attn_mask_type="causal").to( + dtype=self.DTYPE, device=device, + ) + q = torch.randn(self.B, self.S, self.H, self.D, device=device, dtype=self.DTYPE) + k = torch.randn(self.B, self.S, self.H, self.D, device=device, dtype=self.DTYPE) + v = 
torch.randn(self.B, self.S, self.H, self.D, device=device, dtype=self.DTYPE) + with torch.amp.autocast("cuda", dtype=self.DTYPE): + out = dpa(q, k, v) + # DotProductAttention returns (B, S, H*D) after head projection + assert out.shape == (self.B, self.S, self.H * self.D) + + def test_multihead_attention_fwd(self, device): + """MultiheadAttention forward should work in lite mode.""" + hidden = self.H * self.D # 1024 + mha = te.MultiheadAttention(hidden, self.H, attn_mask_type="causal").to( + dtype=self.DTYPE, device=device, + ) + x = torch.randn(self.B, self.S, hidden, device=device, dtype=self.DTYPE) + with torch.amp.autocast("cuda", dtype=self.DTYPE): + out = mha(x) + assert out.shape == x.shape + + +# --------------------------------------------------------------------------- +# MoE router tests +# --------------------------------------------------------------------------- + +class TestMoERouter: + """Test MoE router operations in lite mode.""" + + SEED = 42 + + def _seed(self): + torch.manual_seed(self.SEED) + torch.cuda.manual_seed(self.SEED) + + # -- fused_topk_with_score_function forward ---------------------------- + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + @pytest.mark.parametrize("topk", [1, 2]) + def test_topk_fwd_shapes(self, device, score_function, topk): + """Forward returns (probs, routing_map, intermediate_output) with correct shapes.""" + self._seed() + N, E = 32, 8 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + expert_bias = torch.randn(E, device=device) if score_function == "sigmoid" else None + + probs, routing_map, intermediate = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, score_function, expert_bias, + ) + assert probs.shape == (N, E) + assert routing_map.shape == (N, E) + assert intermediate.shape == (N, E) + assert routing_map.dtype == torch.bool + # Exactly topk experts selected per token + assert (routing_map.sum(dim=-1) == topk).all() + + def test_softmax_pre_softmax(self, device): + """Pre-softmax mode: softmax applied before topk.""" + self._seed() + N, E, topk = 16, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + probs, routing_map, intermediate = tex.fused_topk_with_score_function_fwd( + logits, topk, True, None, None, 1.0, "softmax", None, + ) + # intermediate should be softmax output (sums to 1 per row) + torch.testing.assert_close( + intermediate.sum(dim=-1), + torch.ones(N, device=device), + atol=1e-5, rtol=1e-5, + ) + + def test_softmax_post_softmax(self, device): + """Post-softmax mode: softmax applied after topk over selected experts.""" + self._seed() + N, E, topk = 16, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + probs, routing_map, _ = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, "softmax", None, + ) + # Selected probs should sum to 1 per token (post-softmax over topk) + selected_probs = probs[routing_map].reshape(N, topk) + torch.testing.assert_close( + selected_probs.sum(dim=-1), + torch.ones(N, device=device), + atol=1e-5, rtol=1e-5, + ) + + def test_sigmoid_values(self, device): + """Sigmoid scores are in (0, 1) range.""" + self._seed() + N, E, topk = 16, 8, 1 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + _, _, intermediate = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, "sigmoid", None, + ) + # intermediate holds sigmoid output + ref_sigmoid = torch.sigmoid(logits) + torch.testing.assert_close(intermediate, ref_sigmoid, 
atol=1e-6, rtol=1e-6) + + def test_sigmoid_expert_bias(self, device): + """Expert bias affects expert selection but is removed from final scores.""" + self._seed() + N, E, topk = 32, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + bias = torch.zeros(E, device=device) + bias[0] = 100.0 # strongly bias towards expert 0 + + probs, routing_map, _ = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, "sigmoid", bias, + ) + # Expert 0 should be selected for (almost) every token + assert routing_map[:, 0].sum() >= N * 0.9 + + def test_sigmoid_normalization_topk_gt1(self, device): + """With topk > 1, sigmoid scores are normalized to sum to ~1.""" + self._seed() + N, E, topk = 16, 8, 3 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + probs, routing_map, _ = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, "sigmoid", None, + ) + selected = probs[routing_map].reshape(N, topk) + # Should be approximately normalized (sum ≈ 1) + torch.testing.assert_close( + selected.sum(dim=-1), + torch.ones(N, device=device), + atol=1e-5, rtol=1e-5, + ) + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_group_topk(self, device, score_function): + """Group top-k selects experts only from winning groups.""" + self._seed() + N, E, topk = 32, 16, 4 + num_groups, group_topk = 4, 2 # 4 groups of 4 experts, pick 2 groups + logits = torch.randn(N, E, device=device, dtype=torch.float32) + expert_bias = torch.randn(E, device=device) if score_function == "sigmoid" else None + + probs, routing_map, intermediate = tex.fused_topk_with_score_function_fwd( + logits, topk, False, num_groups, group_topk, 1.0, + score_function, expert_bias, + ) + assert probs.shape == (N, E) + assert routing_map.dtype == torch.bool + # Exactly topk experts selected per token + assert (routing_map.sum(dim=-1) == topk).all() + + # Selected experts should come from at most group_topk groups + group_size = E // num_groups + for i in range(N): + selected_experts = routing_map[i].nonzero(as_tuple=True)[0] + groups_used = set((idx.item() // group_size) for idx in selected_experts) + assert len(groups_used) <= group_topk, ( + f"Token {i}: experts from {len(groups_used)} groups, expected <= {group_topk}" + ) + + def test_scaling_factor(self, device): + """Scaling factor multiplies the output probs.""" + self._seed() + N, E, topk = 16, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + scale = 2.5 + probs_s1, _, _ = tex.fused_topk_with_score_function_fwd( + logits, topk, True, None, None, 1.0, "softmax", None, + ) + probs_s2, _, _ = tex.fused_topk_with_score_function_fwd( + logits, topk, True, None, None, scale, "softmax", None, + ) + torch.testing.assert_close(probs_s2, probs_s1 * scale, atol=1e-5, rtol=1e-5) + + # -- fused_topk_with_score_function backward --------------------------- + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + @pytest.mark.parametrize("topk", [1, 2]) + def test_topk_bwd_shapes(self, device, score_function, topk): + """Backward returns grad_logits with correct shape.""" + self._seed() + N, E = 32, 8 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + expert_bias = torch.randn(E, device=device) if score_function == "sigmoid" else None + + probs, routing_map, intermediate = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, score_function, expert_bias, + ) + grad_probs = torch.randn_like(probs) + grad_logits = 
tex.fused_topk_with_score_function_bwd( + N, E, routing_map, intermediate, grad_probs, topk, + False, 1.0, score_function, + ) + assert grad_logits.shape == (N, E) + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_topk_bwd_unselected_zero(self, device, score_function): + """Gradients at unselected expert positions should be zero (softmax post) or + propagated through score function (softmax pre / sigmoid).""" + self._seed() + N, E, topk = 16, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + + probs, routing_map, intermediate = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, score_function, None, + ) + grad_probs = torch.randn_like(probs) + grad_logits = tex.fused_topk_with_score_function_bwd( + N, E, routing_map, intermediate, grad_probs, topk, + False, 1.0, score_function, + ) + # grad_logits should not be all zeros + assert not torch.all(grad_logits == 0) + + # -- fused_score_for_moe_aux_loss -------------------------------------- + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_aux_loss_score_fwd(self, device, score_function): + """fused_score_for_moe_aux_loss_fwd returns 3 tensors.""" + self._seed() + N, E, topk = 32, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + + scores, routing_map, intermediate = tex.fused_score_for_moe_aux_loss_fwd( + logits, topk, score_function, + ) + assert scores.shape == (N, E) + assert routing_map.shape == (N, E) + assert intermediate.shape == (N, E) + + if score_function == "sigmoid": + ref = torch.sigmoid(logits) + else: + ref = F.softmax(logits, dim=-1) + torch.testing.assert_close(scores, ref, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_aux_loss_score_bwd(self, device, score_function): + """fused_score_for_moe_aux_loss_bwd returns correct gradient shape.""" + self._seed() + N, E, topk = 32, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + scores, _, intermediate = tex.fused_score_for_moe_aux_loss_fwd( + logits, topk, score_function, + ) + grad_scores = torch.randn_like(scores) + grad_logits = tex.fused_score_for_moe_aux_loss_bwd( + N, E, intermediate, grad_scores, topk, score_function, + ) + assert grad_logits.shape == (N, E) + + # -- Aux loss fwd/bwd return + autograd --------------------------------- + + def test_aux_loss_fwd_returns_tuple(self, device): + """fused_moe_aux_loss_fwd returns (loss, Const_buf) matching C++ interface.""" + self._seed() + N, E, topk, coeff = 32, 8, 2, 0.01 + probs = torch.rand(N, E, device=device, dtype=torch.float32) + tpe = torch.randint(1, N, (E,), device=device, dtype=torch.int32) + total = int(tpe.sum().item()) + + result = tex.fused_moe_aux_loss_fwd( + probs, tpe, total, E, N, E, topk, coeff, + ) + assert isinstance(result, tuple), f"Expected tuple, got {type(result)}" + loss, const_buf = result + assert loss.shape == (), f"loss should be scalar, got {loss.shape}" + assert const_buf.shape == (), f"Const_buf should be scalar, got {const_buf.shape}" + assert const_buf.dtype == torch.float32 + + # Verify Const_buf value + expected_c = (E * coeff) / topk / (total * total) + torch.testing.assert_close( + const_buf, torch.tensor(expected_c, dtype=torch.float32, device=device), + atol=1e-6, rtol=1e-6, + ) + + def test_aux_loss_bwd_uses_const_buf(self, device): + """fused_moe_aux_loss_bwd produces correct gradient using Const_buf.""" + self._seed() + N, E, topk, coeff = 16, 8, 2, 0.05 + probs = 
torch.rand(N, E, device=device, dtype=torch.float32) + tpe = torch.randint(1, N, (E,), device=device, dtype=torch.int32) + total = int(tpe.sum().item()) + + loss, const_buf = tex.fused_moe_aux_loss_fwd( + probs, tpe, total, E, N, E, topk, coeff, + ) + grad_aux = torch.tensor(1.0, device=device, dtype=torch.float32) + grad_probs = tex.fused_moe_aux_loss_bwd(const_buf, tpe, N, E, grad_aux) + + assert grad_probs.shape == (N, E) + # grad_probs[j, i] = C_coeff * tokens_per_expert[i] * grad_aux_loss + for i in range(E): + expected = const_buf.item() * tpe[i].item() * grad_aux.item() + torch.testing.assert_close( + grad_probs[:, i], + torch.full((N,), expected, device=device), + atol=1e-5, rtol=1e-5, + ) + + def test_autograd_aux_loss(self, device): + """High-level fused_moe_aux_loss propagates gradients end-to-end.""" + from transformer_engine.pytorch.router import fused_moe_aux_loss + self._seed() + N, E, topk, coeff = 16, 8, 2, 0.01 + probs = torch.rand(N, E, device=device, dtype=torch.float32, requires_grad=True) + tpe = torch.randint(1, N, (E,), device=device, dtype=torch.int32) + total = int(tpe.sum().item()) + + loss = fused_moe_aux_loss(probs, tpe, total, E, topk, coeff) + loss.backward() + assert probs.grad is not None + assert probs.grad.shape == (N, E) + assert not torch.all(probs.grad == 0) + + # -- High-level autograd integration ----------------------------------- + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_autograd_topk(self, device, score_function): + """High-level fused_topk_with_score_function propagates gradients.""" + from transformer_engine.pytorch.router import fused_topk_with_score_function + self._seed() + N, E, topk = 16, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32, requires_grad=True) + expert_bias = torch.randn(E, device=device) if score_function == "sigmoid" else None + + probs, routing_map = fused_topk_with_score_function( + logits, topk, False, None, None, 1.0, score_function, expert_bias, + ) + probs.sum().backward() + assert logits.grad is not None + assert logits.grad.shape == (N, E) + assert not torch.all(logits.grad == 0) + + # -- Numerical gradient verification ------------------------------------ + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + @pytest.mark.parametrize("topk", [1, 2, 3]) + def test_gradient_vs_finite_diff(self, device, score_function, topk): + """Verify backward matches finite-difference approximation. + + Uses random grad_out (uniform grad is degenerate for normalization) + and tolerances appropriate for float32 finite differences. 
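+        Perturbed points where the top-k routing map changes are skipped,
+        since the output is discontinuous there.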
+ """ + self._seed() + N, E = 4, 8 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + eps = 1e-4 # larger eps for float32 stability + + # Forward + backward + probs, rmap, inter = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, score_function, None, + ) + grad_out = torch.randn_like(probs) + grad_ana = tex.fused_topk_with_score_function_bwd( + N, E, rmap, inter, grad_out, topk, False, 1.0, score_function, + ) + + # Finite-difference per element, only where routing is stable + max_err = 0.0 + n_checked = 0 + for i in range(N): + for j in range(E): + logits_p = logits.clone() + logits_p[i, j] += eps + probs_p, rmap_p, _ = tex.fused_topk_with_score_function_fwd( + logits_p, topk, False, None, None, 1.0, score_function, None, + ) + logits_m = logits.clone() + logits_m[i, j] -= eps + probs_m, rmap_m, _ = tex.fused_topk_with_score_function_fwd( + logits_m, topk, False, None, None, 1.0, score_function, None, + ) + if not (torch.equal(rmap_p, rmap) and torch.equal(rmap_m, rmap)): + continue # skip topk discontinuity + fd = ((probs_p - probs_m) * grad_out).sum() / (2 * eps) + err = abs(grad_ana[i, j].item() - fd.item()) + max_err = max(max_err, err) + n_checked += 1 + + assert n_checked > 0, "No stable routing points found for finite-diff check" + assert max_err < 0.01, ( + f"Max gradient error {max_err:.4e} for {score_function} topk={topk} " + f"({n_checked} points checked)" + ) + + def test_sigmoid_topk1_no_normalization(self, device): + """Sigmoid with topk=1 skips normalization — score is raw sigmoid * scale.""" + self._seed() + N, E = 16, 8 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + probs, rmap, inter = tex.fused_topk_with_score_function_fwd( + logits, 1, False, None, None, 1.0, "sigmoid", None, + ) + # For each token, the selected prob should equal sigmoid(logit) + selected_probs = probs[rmap] + selected_sigmoid = inter[rmap] + torch.testing.assert_close(selected_probs, selected_sigmoid, atol=1e-6, rtol=1e-6) + + def test_pre_softmax_backward(self, device): + """Pre-softmax backward: gradient flows through all experts via softmax Jacobian.""" + from transformer_engine.pytorch.router import fused_topk_with_score_function + self._seed() + N, E, topk = 8, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32, requires_grad=True) + + probs, _ = fused_topk_with_score_function( + logits, topk, True, None, None, 1.0, "softmax", None, + ) + probs.sum().backward() + # Pre-softmax: gradient should be non-zero at ALL expert positions + # (softmax couples all inputs), not just selected ones + assert logits.grad is not None + non_zero_per_token = (logits.grad.abs() > 1e-7).sum(dim=-1) + assert (non_zero_per_token > topk).all(), ( + "Pre-softmax backward should produce non-zero gradients beyond selected experts" + ) + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_triton_vs_pytorch_fallback(self, device, score_function): + """Triton-fused path matches the PyTorch-native fallback.""" + from transformer_engine.pytorch._lite.router import ( + _fused_topk_fwd_pytorch, _fused_topk_bwd_pytorch, + ) + from transformer_engine.pytorch.triton.fused_router import ( + fused_topk_with_score_function_fwd as triton_fwd, + fused_topk_with_score_function_bwd as triton_bwd, + ) + self._seed() + N, E, topk = 32, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + expert_bias = torch.randn(E, device=device) if score_function == "sigmoid" else None + + # Forward + p_tri, r_tri, i_tri = 
triton_fwd( + logits, topk, False, 1.0, score_function, expert_bias, + ) + p_pt, r_pt, i_pt = _fused_topk_fwd_pytorch( + logits, topk, False, None, None, 1.0, score_function, expert_bias, + ) + torch.testing.assert_close(r_tri, r_pt, msg="routing_map mismatch") + torch.testing.assert_close(p_tri, p_pt, atol=1e-5, rtol=1e-5, msg="probs fwd mismatch") + torch.testing.assert_close(i_tri, i_pt, atol=1e-5, rtol=1e-5, msg="intermediate mismatch") + + # Backward + grad_out = torch.randn_like(p_tri) + g_tri = triton_bwd(N, E, r_tri, i_tri, grad_out, topk, False, 1.0, score_function) + g_pt = _fused_topk_bwd_pytorch( + N, E, r_pt, i_pt, grad_out, topk, False, 1.0, score_function, + ) + torch.testing.assert_close(g_tri, g_pt, atol=1e-5, rtol=1e-5, msg="grad_logits mismatch") + + def test_sigmoid_scaling_factor(self, device): + """Scaling factor works correctly with sigmoid scoring.""" + self._seed() + N, E, topk = 16, 8, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + scale = 3.0 + p1, r1, _ = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, 1.0, "sigmoid", None, + ) + p2, r2, _ = tex.fused_topk_with_score_function_fwd( + logits, topk, False, None, None, scale, "sigmoid", None, + ) + # Same routing + torch.testing.assert_close(r1, r2) + # Probs scaled + torch.testing.assert_close(p2, p1 * scale, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_autograd_group_topk(self, device, score_function): + """Autograd works end-to-end with group top-k.""" + from transformer_engine.pytorch.router import fused_topk_with_score_function + self._seed() + N, E, topk = 16, 16, 4 + num_groups, group_topk = 4, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32, requires_grad=True) + expert_bias = torch.randn(E, device=device) if score_function == "sigmoid" else None + + probs, routing_map = fused_topk_with_score_function( + logits, topk, False, num_groups, group_topk, 1.0, + score_function, expert_bias, + ) + probs.sum().backward() + assert logits.grad is not None + assert not torch.all(logits.grad == 0) + + @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + def test_triton_vs_pytorch_group_topk(self, device, score_function): + """Triton group top-k matches PyTorch fallback.""" + from transformer_engine.pytorch._lite.router import _fused_topk_fwd_pytorch + from transformer_engine.pytorch.triton.fused_router import ( + fused_topk_with_score_function_fwd as triton_fwd, + ) + self._seed() + N, E, topk = 32, 16, 4 + num_groups, group_topk = 4, 2 + logits = torch.randn(N, E, device=device, dtype=torch.float32) + expert_bias = torch.randn(E, device=device) if score_function == "sigmoid" else None + + p_tri, r_tri, i_tri = triton_fwd( + logits, topk, False, 1.0, score_function, expert_bias, + num_groups, group_topk, + ) + p_pt, r_pt, i_pt = _fused_topk_fwd_pytorch( + logits, topk, False, num_groups, group_topk, 1.0, + score_function, expert_bias, + ) + torch.testing.assert_close(r_tri, r_pt, msg="group topk routing_map mismatch") + torch.testing.assert_close( + p_tri, p_pt, atol=1e-5, rtol=1e-5, msg="group topk probs mismatch", + ) + + +# --------------------------------------------------------------------------- +# MoE permutation tests +# --------------------------------------------------------------------------- + +def _pytorch_permute_index_map(tokens, indices, num_out_tokens=None): + """Reference implementation for index-map permutation.""" + topk = indices.size(1) if indices.dim() > 1 
else 1 + flat = indices.view(-1) + sorted_indices = torch.argsort(flat, stable=True) + n_out = num_out_tokens if num_out_tokens is not None else flat.size(0) + return tokens.index_select(0, sorted_indices[:n_out] // topk), sorted_indices + + +def _pytorch_unpermute_index_map(permuted, sorted_indices, probs=None): + """Reference implementation for index-map unpermutation.""" + if probs is not None: + n_unp = probs.numel() + topk = probs.size(1) + else: + n_unp = sorted_indices.size(0) + topk = 1 + out = torch.zeros(n_unp, permuted.shape[-1], + dtype=permuted.dtype, device=permuted.device) + out.index_copy_(0, sorted_indices[:permuted.size(0)], permuted) + out = out.reshape(-1, topk, permuted.size(-1)) + if probs is not None: + out = out * probs.unsqueeze(-1) + return out.sum(dim=1) + + +class TestMoEPermutation: + """Test MoE permutation operations in lite mode.""" + + DTYPE = torch.bfloat16 + SEED = 1234 + + def _seed(self): + torch.manual_seed(self.SEED) + torch.cuda.manual_seed(self.SEED) + + # -- Low-level tex interface tests ----------------------------------- + + def test_permute_fwd_shapes(self, device): + """moe_permute_fwd returns correct shapes.""" + self._seed() + N, H, topK, E = 32, 64, 2, 8 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + indices = torch.stack( + [torch.randperm(E, device=device)[:topK] for _ in range(N)] + ).to(torch.int32) + + permuted, row_id_map, ws = tex.moe_permute_fwd( + inp, tex.DType.kBFloat16, indices, -1, [], N * topK, + ) + assert permuted.shape == (N * topK, H) + assert row_id_map.shape == (N * topK,) + assert row_id_map.dtype == torch.int32 + + def test_permute_fwd_with_num_out_tokens(self, device): + """moe_permute_fwd respects num_out_tokens truncation.""" + self._seed() + N, H, topK, E = 32, 64, 2, 8 + num_out = 48 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + indices = torch.stack( + [torch.randperm(E, device=device)[:topK] for _ in range(N)] + ).to(torch.int32) + + permuted, row_id_map, _ = tex.moe_permute_fwd( + inp, tex.DType.kBFloat16, indices, num_out, [], N * topK, + ) + assert permuted.shape == (num_out, H) + + def test_roundtrip_identity(self, device): + """Permute then unpermute with uniform probs recovers the input.""" + self._seed() + N, H, topK, E = 64, 128, 2, 8 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + indices = torch.stack( + [torch.randperm(E, device=device)[:topK] for _ in range(N)] + ).to(torch.int32) + + permuted, row_id_map, _ = tex.moe_permute_fwd( + inp, tex.DType.kBFloat16, indices, -1, [], N * topK, + ) + probs = torch.ones(N, topK, device=device, dtype=torch.float32) + unpermuted = tex.moe_unpermute_fwd( + permuted, tex.DType.kBFloat16, row_id_map, probs, N, topK, + ) + # Each token is gathered topK times and summed with prob=1.0 + torch.testing.assert_close(unpermuted, inp.float() * topK, atol=1e-5, rtol=1e-5) + + def test_permute_bwd_equals_unpermute_fwd(self, device): + """moe_permute_bwd delegates to moe_unpermute_fwd.""" + self._seed() + N, H, topK, E = 32, 64, 2, 8 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + indices = torch.stack( + [torch.randperm(E, device=device)[:topK] for _ in range(N)] + ).to(torch.int32) + + permuted, row_id_map, _ = tex.moe_permute_fwd( + inp, tex.DType.kBFloat16, indices, -1, [], N * topK, + ) + probs = torch.rand(N, topK, device=device, dtype=torch.float32) + + via_bwd = tex.moe_permute_bwd( + permuted, tex.DType.kBFloat16, row_id_map, probs, N, topK, + ) + via_fwd = tex.moe_unpermute_fwd( + permuted, tex.DType.kBFloat16, 
row_id_map, probs, N, topK, + ) + torch.testing.assert_close(via_bwd, via_fwd) + + def test_unpermute_bwd_shapes(self, device): + """moe_unpermute_bwd returns (act_grad, prob_grad) with correct shapes.""" + self._seed() + N, H, topK, E = 32, 64, 2, 8 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + indices = torch.stack( + [torch.randperm(E, device=device)[:topK] for _ in range(N)] + ).to(torch.int32) + + permuted, row_id_map, _ = tex.moe_permute_fwd( + inp, tex.DType.kBFloat16, indices, -1, [], N * topK, + ) + probs = torch.rand(N, topK, device=device, dtype=torch.float32) + grad_out = torch.randn(N, H, device=device, dtype=self.DTYPE) + + act_grad, prob_grad = tex.moe_unpermute_bwd( + grad_out, permuted, tex.DType.kBFloat16, row_id_map, probs, + ) + assert act_grad.shape == (N * topK, H) + assert prob_grad.shape == (N, topK) + assert prob_grad.dtype == torch.float32 + + def test_unpermute_bwd_no_probs(self, device): + """moe_unpermute_bwd works without probs.""" + self._seed() + N, H, topK, E = 32, 64, 1, 8 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + indices = torch.stack( + [torch.randperm(E, device=device)[:topK] for _ in range(N)] + ).to(torch.int32) + + permuted, row_id_map, _ = tex.moe_permute_fwd( + inp, tex.DType.kBFloat16, indices, -1, [], N * topK, + ) + grad_out = torch.randn(N, H, device=device, dtype=self.DTYPE) + empty_prob = torch.empty(0, device=device) + + act_grad, prob_grad = tex.moe_unpermute_bwd( + grad_out, permuted, tex.DType.kBFloat16, row_id_map, empty_prob, + ) + assert act_grad.shape == (N * topK, H) + assert prob_grad.numel() == 0 + + # -- High-level API: forward / backward numerical tests --------------- + + @pytest.mark.parametrize("topK", [1, 2]) + @pytest.mark.parametrize("with_probs", [True, False]) + def test_index_map_vs_reference(self, device, topK, with_probs): + """High-level moe_permute/unpermute matches PyTorch reference (index map).""" + if not with_probs and topK > 1: + pytest.skip("topK>1 without probs not supported for index-map") + self._seed() + N, H, E = 64, 128, 8 + from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute + + inp = torch.randn(N, H, device=device, dtype=self.DTYPE, requires_grad=True) + indices = torch.stack( + [torch.randperm(E, device=device)[:topK] for _ in range(N)] + ).to(torch.int32) + + # Reference + ref_perm, ref_sorted = _pytorch_permute_index_map(inp.detach(), indices) + probs = None + if with_probs: + probs = torch.rand(N, topK, device=device).softmax(dim=-1) + + ref_unperm = _pytorch_unpermute_index_map(ref_perm, ref_sorted, probs) + + # TE lite + te_inp = inp.detach().clone().requires_grad_(True) + te_perm, row_id_map = moe_permute(te_inp, indices, map_type="index") + + te_probs = probs.clone().requires_grad_(True) if probs is not None else None + te_unperm = moe_unpermute( + te_perm.detach().clone().requires_grad_(True), + row_id_map, te_probs, map_type="index", + ) + + # Forward check + torch.testing.assert_close( + ref_perm.float(), te_perm.float(), + msg="permute fwd mismatch", + ) + tols = dict(rtol=2.5e-2, atol=1e-5) + torch.testing.assert_close( + ref_unperm.float(), te_unperm.float(), + msg="unpermute fwd mismatch", **tols, + ) + + # Backward check + grad = torch.randn(N, H, device=device, dtype=self.DTYPE) + te_unperm.backward(grad) + + @pytest.mark.parametrize("topK", [1, 2, 3]) + def test_index_map_empty_input(self, device, topK): + """Empty tensor should pass through without error.""" + from transformer_engine.pytorch.permutation import moe_permute, 
moe_unpermute + inp = torch.empty(0, 64, device=device, dtype=self.DTYPE) + indices = torch.empty(0, topK, device=device, dtype=torch.int32) + perm, rid = moe_permute(inp, indices, map_type="index") + assert perm.numel() == 0 + + # -- Triton kernel integration ---------------------------------------- + + def test_triton_sort_used_in_permute(self, device): + """Verify the Triton sort_chunks_by_map kernel is loaded for permute.""" + from transformer_engine.pytorch._lite.permutation import _try_load_triton_sort + fn = _try_load_triton_sort() + assert fn is not None, "Triton sort_chunks_by_map should be loadable" + + def test_triton_gather_matches_pytorch(self, device): + """Triton sort_chunks_by_map (gather mode) matches PyTorch indexing.""" + from transformer_engine.pytorch.triton.permutation import sort_chunks_by_map + self._seed() + N, H = 128, 256 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + ids = torch.randperm(N, device=device, dtype=torch.int32) + triton_out, _ = sort_chunks_by_map(inp, ids, None, N, H, is_forward=False) + pytorch_out = inp[ids.long()] + torch.testing.assert_close(triton_out, pytorch_out) + + # -- Mask-map path tests ------------------------------------------------ + + @staticmethod + def _make_routing_map(N, E, topK, device): + """Create a mask-format routing map: [N, E] int32 with topK ones per row.""" + routing_map = torch.zeros(N, E, dtype=torch.int32, device=device) + for i in range(N): + experts = torch.randperm(E, device=device)[:topK] + routing_map[i, experts] = 1 + return routing_map + + @pytest.mark.parametrize("topK", [1, 2, 3]) + def test_mask_map_permute_shapes(self, device, topK): + """moe_permute with mask map returns correct shapes.""" + from transformer_engine.pytorch.permutation import moe_permute + self._seed() + N, H, E = 64, 128, 8 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + routing_map = self._make_routing_map(N, E, topK, device) + num_out = int(routing_map.sum().item()) + + perm, row_id_map = moe_permute(inp, routing_map, num_out_tokens=num_out, map_type="mask") + assert perm.shape == (num_out, H) + assert row_id_map.shape[0] == N + + @pytest.mark.parametrize("topK", [1, 2]) + def test_mask_map_roundtrip(self, device, topK): + """Mask-map permute then unpermute with merging_probs recovers weighted sum.""" + from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute + self._seed() + N, H, E = 32, 64, 8 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + routing_map = self._make_routing_map(N, E, topK, device) + num_out = int(routing_map.sum().item()) + + perm, row_id_map = moe_permute(inp, routing_map, num_out_tokens=num_out, map_type="mask") + + # Create merging probs matching the routing map + probs_full = torch.rand(N, E, device=device, dtype=torch.float32) * routing_map.float() + # Normalize per token + probs_full = probs_full / probs_full.sum(dim=-1, keepdim=True).clamp(min=1e-6) + + unperm = moe_unpermute( + perm, row_id_map, merging_probs=probs_full, + restore_shape=torch.Size([N, H]), map_type="mask", + ) + assert unperm.shape == (N, H) + + @pytest.mark.parametrize("topK", [1, 2]) + def test_mask_map_backward(self, device, topK): + """Mask-map path propagates gradients through permute and unpermute.""" + from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute + self._seed() + N, H, E = 32, 64, 8 + inp = torch.randn(N, H, device=device, dtype=torch.float32, requires_grad=True) + routing_map = self._make_routing_map(N, E, topK, device) + num_out = 
int(routing_map.sum().item()) + + perm, row_id_map = moe_permute(inp, routing_map, num_out_tokens=num_out, map_type="mask") + assert perm.requires_grad + + # Backward through permute + perm.sum().backward() + assert inp.grad is not None + assert inp.grad.shape == inp.shape + + @pytest.mark.parametrize("topK", [1, 2]) + def test_mask_map_unpermute_backward_with_probs(self, device, topK): + """Unpermute backward propagates gradients to both act and probs.""" + from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute + self._seed() + N, H, E = 32, 64, 8 + inp = torch.randn(N, H, device=device, dtype=torch.float32) + routing_map = self._make_routing_map(N, E, topK, device) + num_out = int(routing_map.sum().item()) + + perm, row_id_map = moe_permute(inp, routing_map, num_out_tokens=num_out, map_type="mask") + perm_detached = perm.detach().clone().requires_grad_(True) + + probs_full = (torch.rand(N, E, device=device, dtype=torch.float32) + * routing_map.float()).requires_grad_(True) + + unperm = moe_unpermute( + perm_detached, row_id_map, merging_probs=probs_full, + restore_shape=torch.Size([N, H]), map_type="mask", + ) + grad_out = torch.randn_like(unperm) + unperm.backward(grad_out) + + assert perm_detached.grad is not None, "act_grad not propagated" + assert perm_detached.grad.shape == perm_detached.shape + assert probs_full.grad is not None, "probs_grad not propagated" + assert probs_full.grad.shape == probs_full.shape + + # -- moe_permute_with_probs tests --------------------------------------- + + @pytest.mark.parametrize("topK", [1, 2]) + def test_permute_with_probs_forward(self, device, topK): + """moe_permute_with_probs permutes both tokens and probs.""" + from transformer_engine.pytorch.permutation import moe_permute_with_probs + self._seed() + N, H, E = 32, 64, 8 + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + routing_map = self._make_routing_map(N, E, topK, device) + probs = torch.rand(N, E, device=device, dtype=torch.float32) * routing_map.float() + num_out = int(routing_map.sum().item()) + + perm_out, perm_probs, row_id_map = moe_permute_with_probs( + inp, probs, routing_map, num_out_tokens=num_out, + ) + assert perm_out.shape == (num_out, H) + assert perm_probs.shape == (num_out,) + + @pytest.mark.parametrize("topK", [1, 2]) + def test_permute_with_probs_backward(self, device, topK): + """moe_permute_with_probs backward propagates gradients to probs.""" + from transformer_engine.pytorch.permutation import moe_permute_with_probs + self._seed() + N, H, E = 32, 64, 8 + inp = torch.randn(N, H, device=device, dtype=torch.float32, requires_grad=True) + routing_map = self._make_routing_map(N, E, topK, device) + probs = (torch.rand(N, E, device=device, dtype=torch.float32) + * routing_map.float()).requires_grad_(True) + num_out = int(routing_map.sum().item()) + + perm_out, perm_probs, row_id_map = moe_permute_with_probs( + inp, probs, routing_map, num_out_tokens=num_out, + ) + # Backward through probs + perm_probs.sum().backward() + assert probs.grad is not None + assert probs.grad.shape == probs.shape + + # -- Chunk-sort tests --------------------------------------------------- + + def test_sort_chunks_by_index(self, device): + """moe_sort_chunks_by_index reorders chunks correctly.""" + from transformer_engine.pytorch.permutation import moe_sort_chunks_by_index + self._seed() + H = 64 + split_sizes = torch.tensor([10, 20, 15, 5], device=device, dtype=torch.int32) + N = int(split_sizes.sum().item()) + inp = torch.randn(N, H, device=device, 
dtype=self.DTYPE, requires_grad=True) + sorted_indices = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32) + + output = moe_sort_chunks_by_index(inp, split_sizes, sorted_indices) + assert output.shape == (N, H) + + # Verify chunks are reordered: chunk at sorted_indices[i] moves to position i + ref_chunks = torch.split(inp.detach(), split_sizes.tolist(), dim=0) + ref_output = torch.cat([ref_chunks[idx] for idx in sorted_indices.tolist()], dim=0) + torch.testing.assert_close(output.detach(), ref_output) + + def test_sort_chunks_by_index_backward(self, device): + """moe_sort_chunks_by_index backward propagates gradients.""" + from transformer_engine.pytorch.permutation import moe_sort_chunks_by_index + self._seed() + H = 64 + split_sizes = torch.tensor([8, 12, 6], device=device, dtype=torch.int32) + N = int(split_sizes.sum().item()) + inp = torch.randn(N, H, device=device, dtype=torch.float32, requires_grad=True) + sorted_indices = torch.tensor([1, 2, 0], device=device, dtype=torch.int32) + + output = moe_sort_chunks_by_index(inp, split_sizes, sorted_indices) + output.sum().backward() + assert inp.grad is not None + assert inp.grad.shape == inp.shape + + def test_sort_chunks_by_index_with_probs(self, device): + """moe_sort_chunks_by_index_with_probs reorders both tokens and probs.""" + from transformer_engine.pytorch.permutation import moe_sort_chunks_by_index_with_probs + self._seed() + H = 64 + split_sizes = torch.tensor([10, 20, 15], device=device, dtype=torch.int32) + N = int(split_sizes.sum().item()) + inp = torch.randn(N, H, device=device, dtype=self.DTYPE) + probs = torch.rand(N, device=device, dtype=torch.float32) + sorted_indices = torch.tensor([2, 0, 1], device=device, dtype=torch.int32) + + output, perm_probs = moe_sort_chunks_by_index_with_probs( + inp, probs, split_sizes, sorted_indices, + ) + assert output.shape == (N, H) + assert perm_probs.shape == (N,) + + # Verify probs are reordered consistently with tokens + ref_prob_chunks = torch.split(probs, split_sizes.tolist(), dim=0) + ref_probs = torch.cat([ref_prob_chunks[idx] for idx in sorted_indices.tolist()], dim=0) + torch.testing.assert_close(perm_probs, ref_probs) + + # -- Gradient propagation checks for index-map -------------------------- + + @pytest.mark.parametrize("topK", [1, 2]) + def test_index_map_gradient_numerical(self, device, topK): + """Sanity-check that index-map permute/unpermute propagate non-trivial gradients.""" + from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute + self._seed() + N, H, E = 16, 32, 8 + inp = torch.randn(N, H, device=device, dtype=torch.float32, requires_grad=True) + indices = torch.stack( + [torch.randperm(E, device=device)[:topK] for _ in range(N)] + ).to(torch.int32) + probs = torch.rand(N, topK, device=device, dtype=torch.float32).requires_grad_(True) + + # Forward + perm, row_id_map = moe_permute(inp, indices, map_type="index") + perm_detached = perm.detach().clone().requires_grad_(True) + unperm = moe_unpermute(perm_detached, row_id_map, probs, map_type="index") + + # Backward + grad_out = torch.randn(N, H, device=device, dtype=torch.float32) + unperm.backward(grad_out) + + # Check act_grad is populated; for reference, the analytic form is: + # unpermute gathers from permuted[row_id_map], applies probs, sums over topK + # So d(unperm)/d(perm[j]) = probs_flat[inv_map[j]] * (grad_out broadcast) + assert perm_detached.grad is not None + assert not torch.all(perm_detached.grad == 0), "act_grad is all zeros" + + # Check prob_grad; analytically d(loss)/d(prob[i,k]) = sum_h(perm[row_id_map[i*topK+k], h] * grad_out[i, 
h]) + assert probs.grad is not None + assert probs.grad.shape == (N, topK) + assert not torch.all(probs.grad == 0), "prob_grad is all zeros" + + +# --------------------------------------------------------------------------- +# MoE padding tests +# --------------------------------------------------------------------------- + +class TestMoEPadding: + """Test multi-row padding / unpadding in lite mode.""" + + DTYPE = torch.bfloat16 + + def test_padding_basic(self, device): + """Rows are copied and extra rows are zero-padded.""" + src_splits = [3, 5, 2] + dst_splits = [4, 8, 4] + features = 64 + inp = torch.randn(sum(src_splits), features, device=device, dtype=self.DTYPE) + out = torch.full( + (sum(dst_splits), features), float("nan"), device=device, dtype=self.DTYPE, + ) + tex.fused_multi_row_padding(inp, out, src_splits, dst_splits) + + in_off, out_off = 0, 0 + for src, dst in zip(src_splits, dst_splits): + # Copied region matches + torch.testing.assert_close( + out[out_off:out_off + src], inp[in_off:in_off + src], + ) + # Padding region is zero + if dst > src: + assert (out[out_off + src:out_off + dst] == 0).all() + in_off += src + out_off += dst + + def test_unpadding_basic(self, device): + """Unpadding extracts the correct rows.""" + src_splits = [4, 8, 4] + dst_splits = [3, 5, 2] + features = 64 + inp = torch.randn(sum(src_splits), features, device=device, dtype=self.DTYPE) + out = torch.empty(sum(dst_splits), features, device=device, dtype=self.DTYPE) + tex.fused_multi_row_unpadding(inp, out, src_splits, dst_splits) + + in_off, out_off = 0, 0 + for src, dst in zip(src_splits, dst_splits): + torch.testing.assert_close( + out[out_off:out_off + dst], inp[in_off:in_off + dst], + ) + in_off += src + out_off += dst + + def test_roundtrip(self, device): + """Padding then unpadding recovers the original tensor.""" + src_splits = [7, 3, 11, 1] + dst_splits = [8, 8, 16, 8] + features = 128 + inp = torch.randn(sum(src_splits), features, device=device, dtype=self.DTYPE) + padded = torch.empty( + sum(dst_splits), features, device=device, dtype=self.DTYPE, + ) + tex.fused_multi_row_padding(inp, padded, src_splits, dst_splits) + + recovered = torch.empty_like(inp) + tex.fused_multi_row_unpadding(padded, recovered, dst_splits, src_splits) + torch.testing.assert_close(recovered, inp) + + def test_no_padding_needed(self, device): + """When splits are equal, data is just copied.""" + splits = [4, 4, 4] + features = 32 + inp = torch.randn(sum(splits), features, device=device, dtype=self.DTYPE) + out = torch.empty_like(inp) + tex.fused_multi_row_padding(inp, out, splits, splits) + torch.testing.assert_close(out, inp) + + def test_single_group(self, device): + """Works with a single group.""" + inp = torch.randn(5, 16, device=device, dtype=self.DTYPE) + out = torch.empty(8, 16, device=device, dtype=self.DTYPE) + tex.fused_multi_row_padding(inp, out, [5], [8]) + torch.testing.assert_close(out[:5], inp) + assert (out[5:] == 0).all() + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_dtype_preservation(self, device, dtype): + """Padding works across dtypes.""" + inp = torch.randn(10, 32, device=device, dtype=dtype) + out = torch.empty(16, 32, device=device, dtype=dtype) + tex.fused_multi_row_padding(inp, out, [4, 6], [8, 8]) + assert out.dtype == dtype + torch.testing.assert_close(out[:4], inp[:4]) + torch.testing.assert_close(out[8:14], inp[4:10]) + + +# --------------------------------------------------------------------------- +# Fused LayerNormLinear / 
LayerNormMLP (lite-native modules) +# --------------------------------------------------------------------------- + +class TestLiteLayerNormLinear: + """Tests for the lite-native LayerNormLinear module.""" + + DTYPE = torch.bfloat16 + + @pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"]) + def test_forward_shape(self, device, normalization): + mod = te.LayerNormLinear( + 256, 128, bias=True, normalization=normalization, + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + y = mod(x) + assert y.shape == (4, 128) + + def test_forward_3d_input(self, device): + mod = te.LayerNormLinear(256, 128, bias=True).to( + dtype=self.DTYPE, device=device + ) + x = torch.randn(2, 8, 256, device=device, dtype=self.DTYPE) + y = mod(x) + assert y.shape == (2, 8, 128) + + def test_forward_no_bias(self, device): + mod = te.LayerNormLinear(256, 128, bias=False).to( + dtype=self.DTYPE, device=device + ) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + y = mod(x) + assert y.shape == (4, 128) + + def test_return_layernorm_output(self, device): + mod = te.LayerNormLinear( + 256, 128, bias=True, return_layernorm_output=True, + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + out = mod(x) + assert isinstance(out, tuple) and len(out) == 2 + y, ln_out = out + assert y.shape == (4, 128) + assert ln_out.shape == (4, 256) + + @pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"]) + def test_backward_all_grads(self, device, normalization): + mod = te.LayerNormLinear( + 256, 128, bias=True, normalization=normalization, + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None and x.grad.shape == x.shape + assert mod.weight.grad is not None + assert mod.layer_norm_weight.grad is not None + if normalization == "LayerNorm": + assert mod.layer_norm_bias.grad is not None + + def test_backward_no_bias(self, device): + mod = te.LayerNormLinear(256, 128, bias=False).to( + dtype=self.DTYPE, device=device + ) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None + + def test_numerical_vs_manual(self, device): + """Verify output matches manual norm→linear composition.""" + mod = te.LayerNormLinear(128, 64, bias=True, normalization="RMSNorm").to( + dtype=self.DTYPE, device=device + ) + x = torch.randn(4, 128, device=device, dtype=self.DTYPE) + + # Manual reference + w = mod.layer_norm_weight.data + rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(1e-5).rsqrt() + normed = (x.float() * rms * w.float()).to(self.DTYPE) + expected = torch.nn.functional.linear(normed, mod.weight.data, mod.bias.data) + + y = mod(x) + diff = (y - expected).abs().max().item() + assert diff < 0.1, f"Max diff {diff:.4f} too large" + + +class TestLiteLayerNormMLP: + """Tests for the lite-native LayerNormMLP module.""" + + DTYPE = torch.bfloat16 + + @pytest.mark.parametrize("activation", ["gelu", "silu", "relu"]) + def test_forward_non_gated(self, device, activation): + mod = te.LayerNormMLP( + 256, 512, bias=True, activation=activation, + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + y = mod(x) + assert y.shape == (4, 256) + + @pytest.mark.parametrize("activation", ["swiglu", "geglu", "reglu"]) + def test_forward_gated(self, device, activation): + mod = 
te.LayerNormMLP( + 256, 512, bias=True, activation=activation, + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + y = mod(x) + assert y.shape == (4, 256) + # Gated activations should have 2x fc1_weight first dim + assert mod.fc1_weight.shape[0] == 1024 + + def test_forward_3d_input(self, device): + mod = te.LayerNormMLP(256, 512).to(dtype=self.DTYPE, device=device) + x = torch.randn(2, 8, 256, device=device, dtype=self.DTYPE) + y = mod(x) + assert y.shape == (2, 8, 256) + + def test_forward_no_bias(self, device): + mod = te.LayerNormMLP(256, 512, bias=False).to( + dtype=self.DTYPE, device=device + ) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + y = mod(x) + assert y.shape == (4, 256) + + @pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"]) + def test_forward_norm_variants(self, device, normalization): + mod = te.LayerNormMLP( + 256, 512, normalization=normalization, + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + y = mod(x) + assert y.shape == (4, 256) + + def test_return_layernorm_output(self, device): + mod = te.LayerNormMLP( + 256, 512, return_layernorm_output=True, + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + out = mod(x) + assert isinstance(out, tuple) and len(out) == 2 + y, ln_out = out + assert y.shape == (4, 256) + assert ln_out.shape == (4, 256) + + @pytest.mark.parametrize("activation", ["gelu", "silu", "relu", "swiglu"]) + def test_backward_all_grads(self, device, activation): + mod = te.LayerNormMLP( + 256, 512, bias=True, activation=activation, + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None and x.grad.shape == x.shape + assert mod.fc1_weight.grad is not None + assert mod.fc2_weight.grad is not None + assert mod.layer_norm_weight.grad is not None + + def test_backward_no_bias(self, device): + mod = te.LayerNormMLP(256, 512, bias=False, activation="gelu").to( + dtype=self.DTYPE, device=device + ) + x = torch.randn(4, 256, device=device, dtype=self.DTYPE, requires_grad=True) + y = mod(x) + y.sum().backward() + assert x.grad is not None + assert mod.fc1_weight.grad is not None + assert mod.fc2_weight.grad is not None + + def test_numerical_vs_manual(self, device): + """Verify output matches manual norm→fc1→gelu→fc2 composition.""" + mod = te.LayerNormMLP( + 128, 256, bias=True, normalization="RMSNorm", activation="gelu", + ).to(dtype=self.DTYPE, device=device) + x = torch.randn(4, 128, device=device, dtype=self.DTYPE) + + # Manual reference + w = mod.layer_norm_weight.data + rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(1e-5).rsqrt() + normed = (x.float() * rms * w.float()).to(self.DTYPE) + fc1_out = torch.nn.functional.linear(normed, mod.fc1_weight.data, mod.fc1_bias.data) + act_out = torch.nn.functional.gelu(fc1_out, approximate="tanh") + expected = torch.nn.functional.linear(act_out, mod.fc2_weight.data, mod.fc2_bias.data) + + y = mod(x) + diff = (y - expected).abs().max().item() + assert diff < 0.5, f"Max diff {diff:.4f} too large" + + +# --------------------------------------------------------------------------- +# Fused gated activation + FP8 quantize (AITER kernel) +# --------------------------------------------------------------------------- + +class TestFusedGatedActQuant: + """Tests for AITER fused gated activation + block FP8 quantize.""" + + 
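# Verification strategy for the numerical tests below: run the fused AITER + # act+quant kernel, dequantize its FP8 output using the scales it returns, and + # require the result to stay within an FP8-appropriate relative error of a + # separate bf16 activation reference. +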
DTYPE = torch.bfloat16 + + @staticmethod + def _has_fused_kernel(): + """Check if the AITER fused act+quant kernel is available.""" + try: + from aiter.ops.triton.activation import act_mul_and_fp8_group_quant # noqa: F401 + return True + except ImportError: + return False + + @staticmethod + def _has_float8_block(): + """Check if Float8BlockQuantizer is available.""" + try: + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + ) + return True + except ImportError: + return False + + @pytest.mark.parametrize("activation,aiter_act", [ + ("swiglu", "silu"), + ("geglu", "gelu_tanh"), + ("reglu", "relu"), + ]) + def test_fused_path_matches_separate(self, device, activation, aiter_act): + """Fused act+quant output should dequantize close to separate act then quant.""" + if not self._has_fused_kernel() or not self._has_float8_block(): + pytest.skip("AITER fused act+quant or Float8BlockQuantizer not available") + + from aiter.ops.triton.activation import act_mul_and_fp8_group_quant + from transformer_engine.pytorch._lite.activations import ( + _aiter_fused_gated_act_quant, + ) + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + ) + + hidden = 256 + x = torch.randn(8, 2 * hidden, device=device, dtype=self.DTYPE) + + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + ) + + # Fused path + fused_out = _aiter_fused_gated_act_quant(x, activation, quantizer) + assert fused_out is not None, "Fused path should fire for Float8BlockQuantizer" + + # Separate path: manual act then quantize + act_fn = {"swiglu": F.silu, "geglu": lambda t: F.gelu(t, approximate="tanh"), "reglu": F.relu} + chunks = x.chunk(2, dim=-1) + ref_bf16 = act_fn[activation](chunks[0]) * chunks[1] + + # Dequantize fused output + fp8_data = fused_out._rowwise_data.view(torch.float8_e4m3fnuz).float() + scale_inv = fused_out._rowwise_scale_inv + # Expand scales to match data: each scale covers block_len elements + block_len = quantizer.block_len + num_blocks = fp8_data.shape[-1] // block_len + scale_expanded = scale_inv.repeat_interleave(block_len, dim=-1) + if scale_expanded.shape[-1] > fp8_data.shape[-1]: + scale_expanded = scale_expanded[..., :fp8_data.shape[-1]] + dequant = fp8_data * scale_expanded + + diff = (dequant - ref_bf16.float()).abs().max().item() + # FP8 quantization error should be small relative to the values + ref_max = ref_bf16.float().abs().max().item() + rel_err = diff / max(ref_max, 1e-6) + assert rel_err < 0.15, f"Fused act+quant relative error {rel_err:.4f} too large" + + @pytest.mark.parametrize("activation", ["swiglu", "geglu", "reglu"]) + def test_fused_path_not_taken_without_block_quantizer(self, device, activation): + """Fused path should return None when quantizer is not Float8BlockQuantizer.""" + from transformer_engine.pytorch._lite.activations import ( + _aiter_fused_gated_act_quant, + ) + + x = torch.randn(4, 256, device=device, dtype=self.DTYPE) + # None quantizer → no fused path + result = _aiter_fused_gated_act_quant(x, activation, None) + assert result is None + + def test_fused_path_output_shape(self, device): + """Fused output should have half the last dim (gated activation).""" + if not self._has_fused_kernel() or not self._has_float8_block(): + pytest.skip("AITER fused act+quant or Float8BlockQuantizer not available") + + from transformer_engine.pytorch._lite.activations import ( + _aiter_fused_gated_act_quant, + ) + from 
transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + ) + + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + ) + + x = torch.randn(4, 8, 512, device=device, dtype=self.DTYPE) + result = _aiter_fused_gated_act_quant(x, "swiglu", quantizer) + assert result is not None + # Gated activation halves last dim: 512 → 256 + assert result._rowwise_data.shape[-1] == 256 + # Total elements should match batch * 256 + assert result._rowwise_data.numel() == 4 * 8 * 256 + + +class TestFusedGatedActCurrentScaling: + """Tests for AITER fused gated activation + per-row FP8 quantize (CurrentScaling).""" + + DTYPE = torch.bfloat16 + + @staticmethod + def _has_fused_kernel(): + try: + from aiter.ops.triton.activation import act_mul_and_fp8_group_quant # noqa: F401 + return True + except ImportError: + return False + + @staticmethod + def _has_current_scaling(): + try: + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, # noqa: F401 + ) + return True + except ImportError: + return False + + @pytest.mark.parametrize("activation,aiter_act", [ + ("swiglu", "silu"), + ("geglu", "gelu_tanh"), + ("reglu", "relu"), + ]) + def test_current_scaling_fused_matches_separate(self, device, activation, aiter_act): + """Fused act+quant with CurrentScaling should dequantize close to separate act then quant.""" + if not self._has_fused_kernel() or not self._has_current_scaling(): + pytest.skip("AITER fused kernel or Float8CurrentScalingQuantizer not available") + + from transformer_engine.pytorch._lite.activations import ( + _aiter_fused_gated_act_current_scaling, + ) + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + Float8Tensor, + ) + + hidden = 256 + x = torch.randn(8, 2 * hidden, device=device, dtype=self.DTYPE) + + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=x.device, + rowwise=True, + columnwise=False, + ) + + # Fused path + fused_out = _aiter_fused_gated_act_current_scaling(x, activation, quantizer) + assert fused_out is not None, "Fused path should fire for Float8CurrentScalingQuantizer" + assert isinstance(fused_out, Float8Tensor) + + # Per-row scale_inv: shape (M,) + assert fused_out._scale_inv.shape == (8,), ( + f"Expected per-row scale_inv shape (8,), got {fused_out._scale_inv.shape}" + ) + + # Separate path: manual act then per-row quant reference + act_fn = {"swiglu": F.silu, "geglu": lambda t: F.gelu(t, approximate="tanh"), "reglu": F.relu} + chunks = x.chunk(2, dim=-1) + ref_bf16 = act_fn[activation](chunks[0]) * chunks[1] + + # Dequantize fused output + fp8_data = fused_out._data.view(torch.float8_e4m3fnuz).float() + scale_inv = fused_out._scale_inv # shape (M,) + # Expand per-row scales: (M,) → (M, 1) for broadcast + dequant = fp8_data * scale_inv.unsqueeze(-1) + + diff = (dequant - ref_bf16.float()).abs().max().item() + ref_max = ref_bf16.float().abs().max().item() + rel_err = diff / max(ref_max, 1e-6) + assert rel_err < 0.15, f"Fused act+quant (CurrentScaling) relative error {rel_err:.4f} too large" + + @pytest.mark.parametrize("activation", ["swiglu", "geglu", "reglu"]) + def test_current_scaling_not_taken_for_block_quantizer(self, device, activation): + """CurrentScaling fused path should return None for Float8BlockQuantizer.""" + from transformer_engine.pytorch._lite.activations import ( + _aiter_fused_gated_act_current_scaling, + ) + + x = torch.randn(4, 256, device=device, 
dtype=self.DTYPE) + # None quantizer → no fused path + result = _aiter_fused_gated_act_current_scaling(x, activation, None) + assert result is None + + def test_current_scaling_output_shape_3d(self, device): + """Fused output should handle 3D input and have half the last dim.""" + if not self._has_fused_kernel() or not self._has_current_scaling(): + pytest.skip("AITER fused kernel or Float8CurrentScalingQuantizer not available") + + from transformer_engine.pytorch._lite.activations import ( + _aiter_fused_gated_act_current_scaling, + ) + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + Float8Tensor, + ) + + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=device, + rowwise=True, + columnwise=False, + ) + + x = torch.randn(4, 8, 512, device=device, dtype=self.DTYPE) + result = _aiter_fused_gated_act_current_scaling(x, "swiglu", quantizer) + assert result is not None + assert isinstance(result, Float8Tensor) + # Gated activation halves last dim: 512 → 256 + assert result._data.shape == (4, 8, 256) + # Per-row scales: M = 4*8 = 32 rows + assert result._scale_inv.shape == (32,) + + def test_current_scaling_per_row_scales_vary(self, device): + """Per-row scales should differ across rows (not degenerate per-tensor).""" + if not self._has_fused_kernel() or not self._has_current_scaling(): + pytest.skip("AITER fused kernel or Float8CurrentScalingQuantizer not available") + + from transformer_engine.pytorch._lite.activations import ( + _aiter_fused_gated_act_current_scaling, + ) + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + + # Use input with deliberately different magnitudes per row + hidden = 128 + x = torch.randn(16, 2 * hidden, device=device, dtype=self.DTYPE) + # Scale each row differently so per-row scales must differ + row_scales = torch.logspace(-2, 2, 16, device=device, dtype=self.DTYPE).unsqueeze(-1) + x = x * row_scales + + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=device, + rowwise=True, + columnwise=False, + ) + + result = _aiter_fused_gated_act_current_scaling(x, "swiglu", quantizer) + assert result is not None + scales = result._scale_inv + # With 4 orders of magnitude in row scales, per-row scales should vary + assert scales.max() / scales.min() > 2.0, ( + f"Per-row scales should vary, got max/min ratio {scales.max() / scales.min():.2f}" + ) + + +# --------------------------------------------------------------------------- +# Recipe-level FP8 integration tests +# --------------------------------------------------------------------------- + +def _available_recipes(): + """Build list of recipe instances for what this hardware supports.""" + recipes = [] + avail, _ = te.is_fp8_available(return_reason=True) + if avail: + recipes.append(recipe.DelayedScaling()) + recipes.append(recipe.Float8CurrentScaling()) + block_avail, _ = te.is_fp8_block_scaling_available(return_reason=True) + if block_avail: + recipes.append(recipe.Float8BlockScaling()) + mx_avail, _ = te.is_mxfp8_available(return_reason=True) + if mx_avail: + recipes.append(recipe.MXFP8BlockScaling()) + return recipes + + +def _recipe_id(val): + """Short string ID for parametrize labels.""" + # pytest passes the raw parameter value (a recipe instance) to ids= callbacks; + # entries wrapped in pytest.param carry their own explicit id and bypass this. + return type(val).__name__ + + +def _mark_recipes(recipes): + """Wrap recipes in pytest.param with readable ids; xfail markers for any + known lite-mode bugs can be attached here.""" + return [pytest.param(r, 
id=type(r).__name__) for r in recipes] + + +_RECIPES = _available_recipes() +_RECIPES_FWD = _mark_recipes(_RECIPES) +_RECIPES_FWD_BWD = _mark_recipes(_RECIPES) + + +class TestRecipeIntegration: + """Recipe-level FP8 integration through te.autocast for core TE modules. + + Tests the full path: recipe object -> autocast context -> RecipeState -> + quantizer construction -> module forward/backward. + """ + + DTYPE = torch.bfloat16 + HIDDEN = 256 + FFN_HIDDEN = 1024 + BATCH = 8 # Must be divisible by 8 for FP8 alignment + SEQ = 8 + + @pytest.fixture(autouse=True) + def _reset_fp8_state(self): + """Reset global FP8 state between tests.""" + yield + FP8GlobalStateManager.reset() + + def _skip_if_no_recipes(self): + if not _RECIPES: + pytest.skip("No FP8 recipes available on this hardware") + + # --------------------------------------------------------------- + # Linear — simplest module, isolates GEMM + quantize path + # --------------------------------------------------------------- + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_linear_fwd_bwd(self, device, fp8_recipe): + """Linear forward+backward completes without error under each recipe.""" + mod = te.Linear(self.HIDDEN, self.HIDDEN, bias=True, + params_dtype=self.DTYPE).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, + dtype=self.DTYPE, requires_grad=True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + y.sum().backward() + torch.cuda.synchronize() + assert y.shape == (self.BATCH, self.HIDDEN) + assert x.grad is not None + assert x.grad.shape == x.shape + for p in mod.parameters(): + if p.requires_grad: + assert p.grad is not None, f"Missing grad for param shape {p.shape}" + + # --------------------------------------------------------------- + # LayerNormLinear — tests fused norm+quant -> GEMM path + # --------------------------------------------------------------- + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + @pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"]) + def test_layernorm_linear_fwd_bwd(self, device, fp8_recipe, normalization): + """LayerNormLinear forward+backward under each recipe and norm variant.""" + mod = te.LayerNormLinear( + self.HIDDEN, self.HIDDEN, bias=True, + normalization=normalization, + params_dtype=self.DTYPE, + ).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, + dtype=self.DTYPE, requires_grad=True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + y.sum().backward() + torch.cuda.synchronize() + assert y.shape == (self.BATCH, self.HIDDEN) + assert x.grad is not None + + # --------------------------------------------------------------- + # LayerNormMLP — norm+quant -> GEMM -> act+quant -> GEMM + # --------------------------------------------------------------- + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + @pytest.mark.parametrize("activation", ["gelu", "swiglu", "geglu"]) + def test_layernorm_mlp_fwd_bwd(self, device, fp8_recipe, activation): + """LayerNormMLP forward+backward — exercises fused act+quant dispatch.""" + mod = te.LayerNormMLP( + self.HIDDEN, self.FFN_HIDDEN, + activation=activation, + params_dtype=self.DTYPE, + ).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, + dtype=self.DTYPE, requires_grad=True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + y.sum().backward() + torch.cuda.synchronize() + assert y.shape == (self.BATCH, self.HIDDEN) + assert x.grad is not None + + # 
--------------------------------------------------------------- + # Multi-step — catches stale scale / amax history bugs + # --------------------------------------------------------------- + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_linear_multi_step(self, device, fp8_recipe): + """3-step forward+backward — verifies amax history and scale updates.""" + mod = te.Linear(self.HIDDEN, self.HIDDEN, bias=True, + params_dtype=self.DTYPE).to(device) + for step in range(3): + x = torch.randn(self.BATCH, self.HIDDEN, device=device, + dtype=self.DTYPE, requires_grad=True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + y.sum().backward() + torch.cuda.synchronize() + assert y.shape == (self.BATCH, self.HIDDEN), f"Failed at step {step}" + assert x.grad is not None, f"No input grad at step {step}" + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_layernorm_mlp_multi_step(self, device, fp8_recipe): + """3-step LayerNormMLP — full pipeline with scale state evolution.""" + mod = te.LayerNormMLP( + self.HIDDEN, self.FFN_HIDDEN, + activation="swiglu", + params_dtype=self.DTYPE, + ).to(device) + for step in range(3): + x = torch.randn(self.BATCH, self.HIDDEN, device=device, + dtype=self.DTYPE, requires_grad=True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + y.sum().backward() + torch.cuda.synchronize() + assert y.shape == (self.BATCH, self.HIDDEN), f"Failed at step {step}" + + # --------------------------------------------------------------- + # Output sanity — FP8 shouldn't produce garbage + # --------------------------------------------------------------- + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD) + def test_linear_output_finite(self, device, fp8_recipe): + """FP8 output should be finite and not all-zeros.""" + mod = te.Linear(self.HIDDEN, self.HIDDEN, bias=True, + params_dtype=self.DTYPE).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, + dtype=self.DTYPE) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + assert torch.isfinite(y).all(), "Output contains NaN/Inf" + assert y.abs().max() > 0, "Output is all zeros" + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD) + def test_fp8_vs_bf16_correlation(self, device, fp8_recipe): + """FP8 output should correlate with bf16 output (same weights, same input).""" + mod = te.Linear(self.HIDDEN, self.HIDDEN, bias=True, + params_dtype=self.DTYPE).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, + dtype=self.DTYPE) + + # bf16 reference (no FP8) + with torch.no_grad(): + ref = mod(x) + + # FP8 path + with torch.no_grad(): + with te.autocast(enabled=True, recipe=fp8_recipe): + fp8_out = mod(x) + + cos_sim = F.cosine_similarity(ref.flatten().float(), + fp8_out.flatten().float(), dim=0) + assert cos_sim > 0.9, ( + f"FP8 output too far from bf16: cosine_similarity={cos_sim:.4f}" + ) + + # --------------------------------------------------------------- + # TransformerLayer — full attention + MLP stack with FP8 + # --------------------------------------------------------------- + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_transformer_layer_fwd_bwd(self, device, fp8_recipe): + """Full TransformerLayer forward+backward under FP8 recipe.""" + mod = te.TransformerLayer( + self.HIDDEN, self.FFN_HIDDEN, num_attention_heads=4, + params_dtype=self.DTYPE, + ).to(device) + # TransformerLayer expects (seq, batch, hidden) + x = torch.randn( + self.SEQ, 2, self.HIDDEN, device=device, + 
dtype=self.DTYPE, requires_grad=True, + ) + with torch.amp.autocast("cuda", dtype=self.DTYPE): + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + y.sum().backward() + torch.cuda.synchronize() + assert y.shape == x.shape + assert torch.isfinite(y).all() + assert x.grad is not None + assert torch.isfinite(x.grad).all() + + # --------------------------------------------------------------- + # FP8 vs bf16 correlation for fused modules — catches silent + # wrong-dispatch, scale broadcast bugs, and per-row axis misalignment + # --------------------------------------------------------------- + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD) + @pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"]) + def test_layernorm_linear_correlation(self, device, fp8_recipe, normalization): + """LayerNormLinear FP8 output should correlate with bf16 (same weights).""" + mod = te.LayerNormLinear( + self.HIDDEN, self.HIDDEN, bias=True, + normalization=normalization, params_dtype=self.DTYPE, + ).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + with torch.no_grad(): + ref = mod(x) + with te.autocast(enabled=True, recipe=fp8_recipe): + fp8_out = mod(x) + cos = F.cosine_similarity(ref.flatten().float(), + fp8_out.flatten().float(), dim=0).item() + assert cos > 0.9, f"cos_sim={cos:.4f}" + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD) + @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) + def test_layernorm_mlp_correlation(self, device, fp8_recipe, activation): + """LayerNormMLP FP8 output should correlate with bf16 (same weights).""" + mod = te.LayerNormMLP( + self.HIDDEN, self.FFN_HIDDEN, + activation=activation, params_dtype=self.DTYPE, + ).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + with torch.no_grad(): + ref = mod(x) + with te.autocast(enabled=True, recipe=fp8_recipe): + fp8_out = mod(x) + cos = F.cosine_similarity(ref.flatten().float(), + fp8_out.flatten().float(), dim=0).item() + assert cos > 0.9, f"cos_sim={cos:.4f}" + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD) + def test_transformer_layer_correlation(self, device, fp8_recipe): + """TransformerLayer FP8 output should correlate with bf16 (same weights).""" + mod = te.TransformerLayer( + self.HIDDEN, self.FFN_HIDDEN, num_attention_heads=4, + params_dtype=self.DTYPE, + ).to(device) + x = torch.randn(self.SEQ, 2, self.HIDDEN, device=device, dtype=self.DTYPE) + with torch.no_grad(): + with torch.amp.autocast("cuda", dtype=self.DTYPE): + ref = mod(x) + with te.autocast(enabled=True, recipe=fp8_recipe): + fp8_out = mod(x) + cos = F.cosine_similarity(ref.flatten().float(), + fp8_out.flatten().float(), dim=0).item() + # TransformerLayer has more accumulated error than a single Linear, + # so use a looser tolerance (still much better than random ~0). + assert cos > 0.75, f"cos_sim={cos:.4f}" + + +# --------------------------------------------------------------------------- +# Per-row FP8 (CurrentScaling) end-to-end integration +# +# Exercises the per-row dispatch chain for a full Linear/LayerNormLinear under +# Float8CurrentScaling: +# Fwd: fused RMSNorm + dynamic_per_token_quant -> gemm_a8w8_per_token_scale +# Bwd: dynamic_per_token_quant_fp8_i8(dY) -> gemm_a8w8_per_token_scale (dgrad) +# gemm_a8w8 per-tensor (wgrad fallback) +# +# Asserts that the fused/per-row AITER kernels actually fire (catches silent +# fallback to per-tensor) and that fwd+bwd numerics stay close to a BF16 +# reference at FP8-appropriate tolerances. 
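+# +# Schematically, per-row (per-token) FP8 quantization keeps one scale per row: +# x_fp8[i, :] ≈ x[i, :] / scale_inv[i], and dequantization is x_fp8[i, :] * scale_inv[i], +# which is the convention the dequantization checks in these tests assume. 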
+# --------------------------------------------------------------------------- + +def _aiter_per_row_kernels_available(): + """All three AITER kernels needed by the per-row CurrentScaling path.""" + try: + from aiter.ops.triton.rmsnorm import rmsnorm2d_fwd_with_dynamicquant # noqa: F401 + from aiter.ops.triton.quant import dynamic_per_token_quant_fp8_i8 # noqa: F401 + from aiter.ops.triton.gemm_a8w8_per_token_scale import ( # noqa: F401 + gemm_a8w8_per_token_scale, + ) + return True + except (ImportError, AttributeError): + return False + + +class TestLitePerRowFP8: + """End-to-end tests for the per-row FP8 (CurrentScaling) training path.""" + + DTYPE = torch.bfloat16 + M = 64 # tokens (rows) — must be divisible by 8 for FP8 alignment + K = 256 # in_features + N = 128 # out_features + + @pytest.fixture(autouse=True) + def _reset_fp8_state(self): + yield + FP8GlobalStateManager.reset() + + @pytest.fixture + def per_row_counts(self, monkeypatch): + """Wrap the two AITER per-row kernel module-attrs with counters. + + Triggers the lazy loaders first so the originals are present, then + monkeypatches the module globals. The dispatch sites read these as + module-level names, so wrapping the attr is enough to intercept. + """ + if not _aiter_per_row_kernels_available(): + pytest.skip("AITER per-row kernels not available") + + # `_lite.quantize` is shadowed in the package namespace by the + # re-exported `quantize` function, so even `import a.b.c as x` + # resolves to the function via attribute lookup. Pull from sys.modules + # to get the actual submodule object. + import sys + import transformer_engine.pytorch._lite.norms # noqa: F401 + import transformer_engine.pytorch._lite.quantize # noqa: F401 + _ln = sys.modules["transformer_engine.pytorch._lite.norms"] + _qz = sys.modules["transformer_engine.pytorch._lite.quantize"] + + _ln._try_load_aiter_norms() + _qz._try_load_aiter_quant() + + if _ln._aiter_fused_rms_dynamic_quant is None: + pytest.skip("rmsnorm2d_fwd_with_dynamicquant not loaded") + if _qz._aiter_dynamic_per_token_quant is None: + pytest.skip("dynamic_per_token_quant_fp8_i8 not loaded") + + counts = {"fused_rms_quant": 0, "dyn_per_token_quant": 0} + orig_rms = _ln._aiter_fused_rms_dynamic_quant + orig_dyn = _qz._aiter_dynamic_per_token_quant + + def rms_wrap(*args, **kwargs): + counts["fused_rms_quant"] += 1 + return orig_rms(*args, **kwargs) + + def dyn_wrap(*args, **kwargs): + counts["dyn_per_token_quant"] += 1 + return orig_dyn(*args, **kwargs) + + monkeypatch.setattr(_ln, "_aiter_fused_rms_dynamic_quant", rms_wrap) + monkeypatch.setattr(_qz, "_aiter_dynamic_per_token_quant", dyn_wrap) + return counts + + def _make_module(self, device): + return te.LayerNormLinear( + self.K, self.N, bias=True, normalization="RMSNorm", + ).to(dtype=self.DTYPE, device=device) + + def _recipe(self): + return recipe.Float8CurrentScaling() + + def test_fwd_dispatches_per_row(self, device, per_row_counts): + """Fwd through LayerNormLinear under CurrentScaling routes the input + through the fused RMSNorm + per-row dynamic FP8 quantize kernel.""" + torch.manual_seed(0) + mod = self._make_module(device) + x = torch.randn(self.M, self.K, device=device, dtype=self.DTYPE) + + with te.autocast(enabled=True, recipe=self._recipe()): + y = mod(x) + + assert y.shape == (self.M, self.N) + assert per_row_counts["fused_rms_quant"] >= 1, ( + "fused RMSNorm + per-row quant kernel did not fire — " + "dispatch fell back to a non-fused path. 
" + f"counts={per_row_counts}" + ) + + def test_bwd_dispatches_per_row(self, device, per_row_counts): + """Bwd quantizes dY via the per-row dynamic quant kernel.""" + torch.manual_seed(0) + mod = self._make_module(device) + x = torch.randn( + self.M, self.K, device=device, dtype=self.DTYPE, requires_grad=True, + ) + + with te.autocast(enabled=True, recipe=self._recipe()): + y = mod(x) + y.sum().backward() + + assert x.grad is not None and x.grad.shape == x.shape + assert mod.weight.grad is not None + assert mod.layer_norm_weight.grad is not None + assert mod.bias.grad is not None + assert per_row_counts["dyn_per_token_quant"] >= 1, ( + "per-row dY quant kernel did not fire on backward; " + f"counts={per_row_counts}" + ) + + def test_numerics_vs_bf16(self, device, per_row_counts): + """Per-row FP8 fwd+bwd stay within FP8-appropriate tolerance of the + same module run in BF16. Tolerances are relative-to-RMS so they don't + spuriously pass on near-zero outputs.""" + torch.manual_seed(0) + mod_fp8 = self._make_module(device) + mod_bf16 = self._make_module(device) + with torch.no_grad(): + mod_bf16.weight.copy_(mod_fp8.weight) + mod_bf16.layer_norm_weight.copy_(mod_fp8.layer_norm_weight) + if (hasattr(mod_bf16, "layer_norm_bias") + and mod_bf16.layer_norm_bias is not None): + mod_bf16.layer_norm_bias.copy_(mod_fp8.layer_norm_bias) + mod_bf16.bias.copy_(mod_fp8.bias) + + x = torch.randn(self.M, self.K, device=device, dtype=self.DTYPE) + x_fp8 = x.clone().requires_grad_(True) + x_bf16 = x.clone().requires_grad_(True) + + with te.autocast(enabled=True, recipe=self._recipe()): + y_fp8 = mod_fp8(x_fp8) + y_fp8.sum().backward() + + y_bf16 = mod_bf16(x_bf16) + y_bf16.sum().backward() + + def _cos(out, ref): + o = out.float().flatten() + r = ref.float().flatten() + return F.cosine_similarity(o, r, dim=0).item() + + def _rel_rms(out, ref): + ref_rms = ref.float().pow(2).mean().sqrt().item() + 1e-8 + err_rms = (out.float() - ref.float()).pow(2).mean().sqrt().item() + return err_rms / ref_rms + + # Direction (cosine) catches drift; rel-RMS catches scale drift. + # Max-abs would be dominated by individual E4M3 outliers (~6% + # quantization step) that don't reflect tensor-level correctness. + for name, out, ref in ( + ("fwd", y_fp8, y_bf16), + ("x.grad", x_fp8.grad, x_bf16.grad), + ("weight.grad", mod_fp8.weight.grad, mod_bf16.weight.grad), + ): + cos = _cos(out, ref) + err = _rel_rms(out, ref) + # 0.99 cosine, 5% rel-RMS — within FP8/per-row budget for a single + # LayerNormLinear at K=256. weight.grad uses per-tensor fallback + # (per-row scales lie on the reduction axis for dW) so its budget + # is looser. + # Per-tensor wgrad fallback gets the loosest budget; x.grad + # compounds dY quant + weight quant so it's looser than fwd. + if name == "weight.grad": + cos_min, err_max = 0.95, 0.15 + elif name == "x.grad": + cos_min, err_max = 0.99, 0.08 + else: + cos_min, err_max = 0.99, 0.05 + assert cos > cos_min, f"{name}: cos {cos:.4f} < {cos_min}" + assert err < err_max, f"{name}: rel-RMS {err:.4f} > {err_max}" + + # Confirm both per-row paths actually fired (not silent fallback). + assert per_row_counts["fused_rms_quant"] >= 1 + assert per_row_counts["dyn_per_token_quant"] >= 1 + + +# --------------------------------------------------------------------------- +# API contract tests — verify lite exposes the symbols the full TE does, +# with compatible constructor signatures. Catches cases like "accepted kwarg +# but silently dropped it" (the return_bias bug we just fixed). 
+# --------------------------------------------------------------------------- + +class TestLiteAPI: + """Verify the lite backend exposes the API contract of full TE.""" + + # Top-level symbols that test_sanity.py imports from transformer_engine.pytorch + _TE_PUBLIC_SYMBOLS = [ + "Linear", "LayerNormLinear", "LayerNormMLP", "TransformerLayer", + "GroupedLinear", "RMSNorm", "LayerNorm", + "autocast", "fp8_autocast", + "Float8Tensor", "Float8Quantizer", "Float8CurrentScalingQuantizer", + "QuantizedTensor", + "is_fp8_available", "is_mxfp8_available", "is_fp8_block_scaling_available", + "is_bf16_available", + ] + + # Critical tex (transformer_engine_torch) functions used across the codebase + _TEX_FUNCTIONS = [ + "quantize", "dequantize", "bgrad_quantize", "split_quantize", + "generic_gemm", "fp8_transpose", + "layernorm_fwd", "layernorm_bwd", "rmsnorm_fwd", "rmsnorm_bwd", + "fused_amax_and_scale_update_after_reduction", "compute_amax", + "swiglu", "geglu", "reglu", "gelu", "silu", "relu", "srelu", "qgelu", + "dswiglu", "dgeglu", "dreglu", "dgelu", "dsilu", "drelu", + "dbias_dgelu", "dbias_dsilu", "dbias_drelu", + ] + + # Expected DType enum values + _TEX_DTYPES = [ + "kFloat32", "kFloat16", "kBFloat16", + "kFloat8E4M3", "kFloat8E5M2", + ] + + def test_te_public_symbols_exist(self): + """Every symbol test_sanity.py imports from te must exist in lite.""" + missing = [s for s in self._TE_PUBLIC_SYMBOLS if not hasattr(te, s)] + assert not missing, f"Missing from te: {missing}" + + def test_tex_functions_exist(self): + """Every tex function called across the TE codebase must exist in lite.""" + missing = [s for s in self._TEX_FUNCTIONS if not hasattr(tex, s)] + assert not missing, f"Missing from tex: {missing}" + + def test_tex_functions_callable(self): + """tex functions must be callable, not just exist as sentinel values.""" + non_callable = [ + s for s in self._TEX_FUNCTIONS + if hasattr(tex, s) and not callable(getattr(tex, s)) + ] + assert not non_callable, f"Not callable: {non_callable}" + + def test_tex_dtype_enum(self): + """DType enum must expose the standard values.""" + assert hasattr(tex, "DType"), "tex.DType missing" + missing = [v for v in self._TEX_DTYPES if not hasattr(tex.DType, v)] + assert not missing, f"Missing DType values: {missing}" + + @pytest.mark.parametrize("cls_name,required_kwargs", [ + ("Linear", ["bias", "params_dtype"]), + ("LayerNormLinear", ["bias", "params_dtype", "normalization", + "return_bias", "return_layernorm_output", + "zero_centered_gamma"]), + ("LayerNormMLP", ["bias", "params_dtype", "normalization", "activation", + "return_bias", "return_layernorm_output", + "zero_centered_gamma"]), + ("TransformerLayer", ["num_attention_heads", "params_dtype"]), + ("LayerNorm", ["zero_centered_gamma"]), + ("RMSNorm", ["zero_centered_gamma"]), + ]) + def test_module_accepts_expected_kwargs(self, cls_name, required_kwargs): + """Module constructors must accept the documented kwargs, not silently drop them.""" + import inspect + cls = getattr(te, cls_name) + sig = inspect.signature(cls.__init__) + params = sig.parameters + # Either the kwarg is explicitly listed, or **kwargs catches it. 
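+ # Acceptance via **kwargs only proves the constructor call will not raise; whether + # a kwarg is actually honored is covered by behavioral checks such as the + # return_bias regression test below. 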
+ has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + missing = [k for k in required_kwargs if k not in params and not has_kwargs] + assert not missing, f"{cls_name} missing kwargs: {missing}" + + def test_layernorm_mlp_return_bias_returns_tuple(self, device): + """Regression test: LayerNormMLP(return_bias=True) must return (out, bias).""" + mod = te.LayerNormMLP( + 256, 1024, bias=True, return_bias=True, params_dtype=torch.bfloat16, + ).to(device) + x = torch.randn(8, 256, device=device, dtype=torch.bfloat16) + out = mod(x) + assert isinstance(out, tuple), f"Expected tuple, got {type(out).__name__}" + assert len(out) == 2, f"Expected 2-tuple, got {len(out)}-tuple" + main_out, bias = out + assert main_out.shape == (8, 256) + # bias must be 1D with last-dim size — either a real bias or a placeholder + assert bias.ndim == 1 and bias.shape[0] == 256 + + def test_lite_mode_flag(self): + """The tex module's __name__ reliably identifies lite vs full backend.""" + assert tex.__name__ == "transformer_engine.pytorch._lite" + + def test_recipes_available(self): + """Recipe classes must be importable via the common API, regardless of hw support.""" + from transformer_engine.common import recipe as r + assert hasattr(r, "DelayedScaling") + assert hasattr(r, "Float8CurrentScaling") + assert hasattr(r, "Float8BlockScaling") + assert hasattr(r, "MXFP8BlockScaling") + + +# --------------------------------------------------------------------------- +# End-to-end training loop tests — optimizer.step() drives weight updates, +# verifying that FP8 recipes actually converge and the fp8 weight cache +# invalidates correctly between steps. +# --------------------------------------------------------------------------- + +class TestFP8Training: + """Verify FP8 recipes support real training (optimizer.step).""" + + DTYPE = torch.bfloat16 + HIDDEN = 256 + FFN_HIDDEN = 1024 + BATCH = 8 + + @pytest.fixture(autouse=True) + def _reset_fp8_state(self): + yield + FP8GlobalStateManager.reset() + + def _overfit_and_check(self, mod, x, target, fp8_recipe, + steps=50, lr=1e-3, use_amp=False, + loss_drop_ratio=0.8): + """Run N training steps with Adam; assert loss drops by at least + (1-loss_drop_ratio) of the initial value. 
Adam is used (not SGD) + to avoid per-module LR tuning — it adapts to gradient magnitudes.""" + opt = torch.optim.Adam(mod.parameters(), lr=lr) + losses = [] + for _ in range(steps): + opt.zero_grad() + if use_amp: + with torch.amp.autocast("cuda", dtype=self.DTYPE): + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + else: + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + loss = (y.float() - target.float()).pow(2).mean() + loss.backward() + opt.step() + losses.append(loss.item()) + torch.cuda.synchronize() + # No NaN/Inf during the trajectory + assert all(torch.isfinite(torch.tensor(L)) for L in losses), ( + f"Non-finite loss during training: trajectory={losses[::5]}" + ) + # Substantial learning (final < loss_drop_ratio * initial) + assert losses[-1] < losses[0] * loss_drop_ratio, ( + f"Loss didn't drop enough: {losses[0]:.4f} -> {losses[-1]:.4f} " + f"(trajectory every 10: {losses[::10]})" + ) + return losses + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_linear_overfits_batch(self, device, fp8_recipe): + """Linear should overfit a fixed batch under each FP8 recipe.""" + torch.manual_seed(0) + mod = te.Linear(self.HIDDEN, self.HIDDEN, bias=True, + params_dtype=self.DTYPE).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + target = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + self._overfit_and_check(mod, x, target, fp8_recipe) + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_layernorm_mlp_overfits_batch(self, device, fp8_recipe): + """LayerNormMLP should overfit a fixed batch.""" + torch.manual_seed(0) + mod = te.LayerNormMLP(self.HIDDEN, self.FFN_HIDDEN, + activation="swiglu", + params_dtype=self.DTYPE).to(device) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + target = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + self._overfit_and_check(mod, x, target, fp8_recipe) + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_transformer_layer_overfits_batch(self, device, fp8_recipe): + """TransformerLayer should overfit a fixed batch.""" + torch.manual_seed(0) + mod = te.TransformerLayer( + self.HIDDEN, self.FFN_HIDDEN, num_attention_heads=4, + params_dtype=self.DTYPE, + ).to(device) + x = torch.randn(self.BATCH, 2, self.HIDDEN, device=device, dtype=self.DTYPE) + target = torch.randn(self.BATCH, 2, self.HIDDEN, device=device, dtype=self.DTYPE) + self._overfit_and_check(mod, x, target, fp8_recipe, use_amp=True) + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_weights_change_after_step(self, device, fp8_recipe): + """Sanity: optimizer.step() must actually update the weights.""" + torch.manual_seed(0) + mod = te.Linear(self.HIDDEN, self.HIDDEN, bias=True, + params_dtype=self.DTYPE).to(device) + opt = torch.optim.SGD(mod.parameters(), lr=0.1) + + x = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + w_before = mod.weight.detach().clone() + + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x) + y.sum().backward() + opt.step() + torch.cuda.synchronize() + + assert not torch.equal(w_before, mod.weight), ( + "Weights unchanged after optimizer.step()" + ) + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + def test_fp8_training_tracks_bf16(self, device, fp8_recipe): + """After N training steps, FP8 weights should still correlate with bf16 weights. 
+ Catches FP8 weight-cache staleness — if fp8 cache isn't invalidated after + optimizer.step, the two trajectories diverge rapidly.""" + torch.manual_seed(0) + mod_fp8 = te.Linear(self.HIDDEN, self.HIDDEN, bias=True, + params_dtype=self.DTYPE).to(device) + torch.manual_seed(0) + mod_bf = te.Linear(self.HIDDEN, self.HIDDEN, bias=True, + params_dtype=self.DTYPE).to(device) + + opt_fp8 = torch.optim.SGD(mod_fp8.parameters(), lr=0.01) + opt_bf = torch.optim.SGD(mod_bf.parameters(), lr=0.01) + + for step in range(10): + torch.manual_seed(step + 1) + x = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + target = torch.randn(self.BATCH, self.HIDDEN, device=device, dtype=self.DTYPE) + + opt_fp8.zero_grad() + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod_fp8(x) + ((y.float() - target.float()).pow(2).mean()).backward() + opt_fp8.step() + + opt_bf.zero_grad() + y_bf = mod_bf(x) + ((y_bf.float() - target.float()).pow(2).mean()).backward() + opt_bf.step() + + torch.cuda.synchronize() + cos = F.cosine_similarity( + mod_fp8.weight.detach().float().flatten(), + mod_bf.weight.detach().float().flatten(), + dim=0, + ).item() + assert cos > 0.95, f"FP8 weights diverged from bf16: cos={cos:.4f}" + + +# --------------------------------------------------------------------------- +# FP8 attention flags — lite does not implement FP8 attention kernels. +# These tests document the behavior: fp8_dpa=False/fp8_mha=False works, +# setting either to True raises a clear NotImplementedError (not a cryptic +# AttributeError from a missing enum value). +# --------------------------------------------------------------------------- + +class TestFP8AttentionFlags: + """Lite accepts fp8_dpa/fp8_mha recipe flags but rejects them cleanly.""" + + DTYPE = torch.bfloat16 + HIDDEN = 256 + FFN_HIDDEN = 1024 + + @pytest.fixture(autouse=True) + def _reset_fp8_state(self): + yield + FP8GlobalStateManager.reset() + + def _make_model_and_input(self, device): + mod = te.TransformerLayer( + self.HIDDEN, self.FFN_HIDDEN, num_attention_heads=4, + params_dtype=self.DTYPE, + ).to(device) + x = torch.randn(8, 2, self.HIDDEN, device=device, dtype=self.DTYPE) + return mod, x + + def test_fp8_dpa_raises_not_implemented(self, device): + """fp8_dpa=True must raise a clear NotImplementedError, not AttributeError.""" + if not _RECIPES: + pytest.skip("No FP8 recipes available on this hardware") + mod, x = self._make_model_and_input(device) + r = recipe.DelayedScaling(fp8_dpa=True) + with pytest.raises(NotImplementedError, match="FP8 attention"): + with torch.amp.autocast("cuda", dtype=self.DTYPE): + with te.autocast(enabled=True, recipe=r): + mod(x) + + def test_fp8_mha_raises_not_implemented(self, device): + """fp8_mha=True must raise NotImplementedError (q_type becomes FP8 upstream).""" + if not _RECIPES: + pytest.skip("No FP8 recipes available on this hardware") + mod, x = self._make_model_and_input(device) + r = recipe.DelayedScaling(fp8_mha=True) + with pytest.raises(NotImplementedError, match="FP8 attention"): + with torch.amp.autocast("cuda", dtype=self.DTYPE): + with te.autocast(enabled=True, recipe=r): + mod(x) + + def test_fp8_dpa_and_mha_raises(self, device): + """Both flags set together still raises cleanly.""" + if not _RECIPES: + pytest.skip("No FP8 recipes available on this hardware") + mod, x = self._make_model_and_input(device) + r = recipe.DelayedScaling(fp8_dpa=True, fp8_mha=True) + with pytest.raises(NotImplementedError, match="FP8 attention"): + with torch.amp.autocast("cuda", dtype=self.DTYPE): + 
with te.autocast(enabled=True, recipe=r): + mod(x) + + def test_default_flags_work(self, device): + """Recipe without the FP8 attention flags (default) must work end-to-end.""" + if not _RECIPES: + pytest.skip("No FP8 recipes available on this hardware") + mod, x = self._make_model_and_input(device) + # Explicitly False to pin the contract + r = recipe.DelayedScaling(fp8_dpa=False, fp8_mha=False) + with torch.amp.autocast("cuda", dtype=self.DTYPE): + with te.autocast(enabled=True, recipe=r): + y = mod(x) + assert y.shape == x.shape + assert torch.isfinite(y).all() + + def test_current_scaling_fp8_dpa_raises(self, device): + """CurrentScaling recipe with fp8_dpa=True also rejected cleanly.""" + if not _RECIPES: + pytest.skip("No FP8 recipes available on this hardware") + mod, x = self._make_model_and_input(device) + r = recipe.Float8CurrentScaling(fp8_dpa=True) + with pytest.raises(NotImplementedError, match="FP8 attention"): + with torch.amp.autocast("cuda", dtype=self.DTYPE): + with te.autocast(enabled=True, recipe=r): + mod(x) + + def test_enum_has_nvte_fp8(self): + """NVTE_FP8 enum value must exist for API compatibility.""" + assert hasattr(tex.NVTE_Fused_Attn_Backend, "NVTE_FP8"), ( + "NVTE_FP8 must exist in the enum for framework compat, even if " + "lite doesn't implement the corresponding backend." + ) + + +# --------------------------------------------------------------------------- +# GroupedLinear tests — MoE-style expert parallelism via a single module +# that dispatches num_gemms weights against an input split by m_splits. +# --------------------------------------------------------------------------- + +class TestGroupedLinear: + """Verify GroupedLinear works end-to-end in lite mode.""" + + DTYPE = torch.bfloat16 + NUM_GEMMS = 4 + IN_FEATURES = 256 + OUT_FEATURES = 256 + M_SPLITS = [8, 8, 8, 8] # total 32 tokens + + @pytest.fixture(autouse=True) + def _reset_fp8_state(self): + yield + FP8GlobalStateManager.reset() + + def _make_module(self, device, bias=True): + return te.GroupedLinear( + num_gemms=self.NUM_GEMMS, + in_features=self.IN_FEATURES, + out_features=self.OUT_FEATURES, + bias=bias, + params_dtype=self.DTYPE, + parallel_mode=None, + ).to(device) + + @pytest.mark.parametrize("bias", [True, False]) + def test_forward_shape(self, device, bias): + """Output shape is (total_tokens, out_features).""" + mod = self._make_module(device, bias=bias) + total = sum(self.M_SPLITS) + x = torch.randn(total, self.IN_FEATURES, device=device, dtype=self.DTYPE) + y = mod(x, self.M_SPLITS) + assert y.shape == (total, self.OUT_FEATURES) + assert torch.isfinite(y).all() + + @pytest.mark.parametrize("bias", [True, False]) + def test_forward_matches_manual(self, device, bias): + """Output matches F.linear per chunk (reference implementation).""" + torch.manual_seed(42) + mod = self._make_module(device, bias=bias) + total = sum(self.M_SPLITS) + x = torch.randn(total, self.IN_FEATURES, device=device, dtype=self.DTYPE) + y = mod(x, self.M_SPLITS) + + # Manual reference: split input, run each chunk through its expert + chunks = torch.split(x, self.M_SPLITS, dim=0) + y_ref_parts = [] + for i, chunk in enumerate(chunks): + w = getattr(mod, f"weight{i}") + b = getattr(mod, f"bias{i}") if bias else None + y_ref_parts.append(F.linear(chunk, w, b)) + y_ref = torch.cat(y_ref_parts, dim=0) + assert torch.allclose(y, y_ref, atol=1e-3, rtol=1e-3), ( + f"GroupedLinear output differs from manual: max_diff=" + f"{(y - y_ref).abs().max().item()}" + ) + + @pytest.mark.parametrize("bias", [True, False]) + def 
test_backward_grads_finite(self, device, bias): + """Input gradient and all per-expert weight gradients must be finite.""" + mod = self._make_module(device, bias=bias) + total = sum(self.M_SPLITS) + x = torch.randn(total, self.IN_FEATURES, device=device, + dtype=self.DTYPE, requires_grad=True) + y = mod(x, self.M_SPLITS) + y.sum().backward() + torch.cuda.synchronize() + assert x.grad is not None + assert torch.isfinite(x.grad).all() + for i in range(self.NUM_GEMMS): + w_grad = getattr(mod, f"weight{i}").grad + assert w_grad is not None, f"weight{i}.grad is None" + assert torch.isfinite(w_grad).all(), f"weight{i}.grad has NaN/Inf" + if bias: + b_grad = getattr(mod, f"bias{i}").grad + assert b_grad is not None, f"bias{i}.grad is None" + assert torch.isfinite(b_grad).all(), f"bias{i}.grad has NaN/Inf" + + def test_uneven_splits(self, device): + """Non-uniform m_splits should also work (MoE often has imbalanced routing).""" + mod = self._make_module(device, bias=True) + m_splits = [4, 12, 8, 8] # total 32 + total = sum(m_splits) + x = torch.randn(total, self.IN_FEATURES, device=device, + dtype=self.DTYPE, requires_grad=True) + y = mod(x, m_splits) + assert y.shape == (total, self.OUT_FEATURES) + y.sum().backward() + torch.cuda.synchronize() + assert torch.isfinite(x.grad).all() + + @pytest.mark.parametrize("fp8_recipe", _RECIPES_FWD_BWD) + @pytest.mark.xfail( + strict=True, + reason="FP8 GroupedLinear hits dtype mismatch in Triton wrapper " + "(lhs=fp32 vs bias=bf16) — pre-existing issue in " + "triton_kernels/gmm/gmm_common.py, out of scope for lite adapter.", + ) + def test_fp8_forward(self, device, fp8_recipe): + """FP8 GroupedLinear — currently blocked on a Triton GMM bug.""" + mod = self._make_module(device, bias=True) + total = sum(self.M_SPLITS) + x = torch.randn(total, self.IN_FEATURES, device=device, + dtype=self.DTYPE, requires_grad=True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = mod(x, self.M_SPLITS) + y.sum().backward() + torch.cuda.synchronize() + assert torch.isfinite(y).all() + assert torch.isfinite(x.grad).all() + + +# --------------------------------------------------------------------------- +# Grouped-GEMM dispatcher (lite) — Phase 1 covers BF16 only via +# `general_grouped_gemm_triton`. FP8 operands must fail loudly until the +# Phase 2 fused-MoE path lands; otherwise they would silently misroute +# through the BF16 kernel and either crash or miscompute. 
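+#
+# The gate these tests pin down amounts to a dtype check of roughly this shape
+# (illustrative sketch only, not the actual lite dispatcher code):
+#
+#     _FP8 = (torch.float8_e4m3fn, torch.float8_e5m2)
+#     if any(t.dtype in _FP8 for t in list(A) + list(B)):
+#         raise NotImplementedError("FP8 grouped GEMM is not supported in lite Phase 1")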
+# --------------------------------------------------------------------------- + +class TestGroupedGemmDispatch: + """Verify lite's te_general_grouped_gemm dispatcher gating.""" + + NUM_GEMMS = 2 + IN_FEATURES = 64 + OUT_FEATURES = 64 + M_SPLITS = [4, 4] + + def test_fp8_operands_raise_not_implemented(self, device): + total = sum(self.M_SPLITS) + A = [ + torch.empty(self.OUT_FEATURES, self.IN_FEATURES, + device=device, dtype=torch.float8_e4m3fn) + for _ in range(self.NUM_GEMMS) + ] + B = [ + torch.empty(m, self.IN_FEATURES, + device=device, dtype=torch.float8_e4m3fn) + for m in self.M_SPLITS + ] + out = [ + torch.empty(m, self.OUT_FEATURES, device=device, dtype=torch.bfloat16) + for m in self.M_SPLITS + ] + ws = [torch.empty(1024, device=device, dtype=torch.uint8)] + with pytest.raises(NotImplementedError, match="FP8 grouped GEMM"): + tex.te_general_grouped_gemm( + A, True, B, False, out, torch.bfloat16, self.M_SPLITS, + [], None, False, [], False, ws, ws[0].shape[0], + False, False, 0, + ) + + def test_empty_tokens_short_circuit_forward(self, device): + """MoE routing can produce zero local tokens in early training; the + underlying AITER gmm asserts M > 0, so the lite wrapper must + short-circuit instead of forwarding.""" + m_splits = [0, 0] + A = [ + torch.randn(self.OUT_FEATURES, self.IN_FEATURES, + device=device, dtype=torch.bfloat16) + for _ in range(self.NUM_GEMMS) + ] + # Empty input: (0, in_features) + B = [torch.empty(0, self.IN_FEATURES, device=device, dtype=torch.bfloat16)] + out = [torch.empty(0, self.OUT_FEATURES, device=device, dtype=torch.bfloat16)] + ws = [torch.empty(1024, device=device, dtype=torch.uint8)] + # Forward layout (transa=True, transb=False), grad=False. + # Should return without raising, with the empty out tensor untouched. + tex.te_general_grouped_gemm( + A, True, B, False, out, torch.bfloat16, m_splits, + [], None, True, [], False, ws, ws[0].shape[0], + False, False, 0, + ) + assert out[0].shape == (0, self.OUT_FEATURES) + + def test_empty_tokens_short_circuit_wgrad_zeros_out(self, device): + """Wgrad with zero tokens must zero its (G, K, N) output when + accumulate=False so the caller sees the correct zero contribution.""" + m_splits = [0, 0] + A = [torch.empty(0, self.IN_FEATURES, device=device, dtype=torch.bfloat16)] + B = [torch.empty(0, self.OUT_FEATURES, device=device, dtype=torch.bfloat16)] + # Pre-fill out with garbage so we can verify the zero_() actually fired. + out = [ + torch.full((self.OUT_FEATURES, self.IN_FEATURES), 7.0, + device=device, dtype=torch.bfloat16) + for _ in range(self.NUM_GEMMS) + ] + ws = [torch.empty(1024, device=device, dtype=torch.uint8)] + # Wgrad layout: transa=True (wgrad layout NT in the upstream wrapper + # corresponds to transa=False, transb=True from the C++ binding's view + # — but our lite wrapper checks `transa and not transb and grad`, so + # invoke with (transa=True, transb=False, grad=True) to hit it). + tex.te_general_grouped_gemm( + A, True, B, False, out, torch.bfloat16, m_splits, + [], None, True, [], True, ws, ws[0].shape[0], + False, False, 0, + ) + for o in out: + assert torch.all(o == 0), "wgrad output not zeroed under M=0" + + +# --------------------------------------------------------------------------- +# FSDP2 weight-wrap tests — lite's compound modules must wrap FP8 weights in +# FSDPAGTensor when use_fsdp2=True so FSDP2's all-gather calls +# fsdp_pre_all_gather to quantize at gather time, not at parameter init. 
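+#
+# In short, the contract checked below (illustrative usage, not module code):
+#
+#     from transformer_engine.pytorch.tensor.fsdp2_allgather_tensor import FSDPAGTensor
+#     mod = te.LayerNormLinear(256, 256, use_fsdp2=True, params_dtype=torch.bfloat16)
+#     assert isinstance(mod.weight, FSDPAGTensor)        # wrapped for gather-time quantization
+#     assert isinstance(mod.weight, torch.nn.Parameter)  # still a Parameter, so autograd works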
+# --------------------------------------------------------------------------- + +class TestFSDP2WeightWrap: + """Verify lite compound modules emit FSDPAGTensor under use_fsdp2=True.""" + + DTYPE = torch.bfloat16 + HIDDEN = 256 + FFN_HIDDEN = 1024 + + @staticmethod + def _has_fsdpag(): + try: + from transformer_engine.pytorch.tensor.fsdp2_allgather_tensor import FSDPAGTensor # noqa + return True + except ImportError: + return False + + def _fsdpag_cls(self): + from transformer_engine.pytorch.tensor.fsdp2_allgather_tensor import FSDPAGTensor + return FSDPAGTensor + + def test_layernorm_linear_no_wrap_by_default(self, device): + """Default (use_fsdp2=False): weight is a plain Parameter, not FSDPAGTensor.""" + if not self._has_fsdpag(): + pytest.skip("FSDPAGTensor unavailable") + mod = te.LayerNormLinear( + self.HIDDEN, self.HIDDEN, params_dtype=self.DTYPE, device=device, + ) + assert not isinstance(mod.weight, self._fsdpag_cls()) + assert isinstance(mod.weight, torch.nn.Parameter) + + def test_layernorm_linear_wraps_with_use_fsdp2(self, device): + """use_fsdp2=True: weight is wrapped in FSDPAGTensor for FSDP2 all-gather.""" + if not self._has_fsdpag(): + pytest.skip("FSDPAGTensor unavailable") + mod = te.LayerNormLinear( + self.HIDDEN, self.HIDDEN, use_fsdp2=True, + params_dtype=self.DTYPE, device=device, + ) + assert isinstance(mod.weight, self._fsdpag_cls()) + # Must still be a Parameter so autograd works + assert isinstance(mod.weight, torch.nn.Parameter) + + def test_layernorm_mlp_wraps_both_weights(self, device): + """use_fsdp2=True: both fc1_weight and fc2_weight are wrapped.""" + if not self._has_fsdpag(): + pytest.skip("FSDPAGTensor unavailable") + mod = te.LayerNormMLP( + self.HIDDEN, self.FFN_HIDDEN, use_fsdp2=True, + params_dtype=self.DTYPE, device=device, + ) + assert isinstance(mod.fc1_weight, self._fsdpag_cls()) + assert isinstance(mod.fc2_weight, self._fsdpag_cls()) + + def test_layernorm_mlp_no_wrap_by_default(self, device): + """Default (use_fsdp2=False): no FSDPAGTensor wrapping.""" + if not self._has_fsdpag(): + pytest.skip("FSDPAGTensor unavailable") + mod = te.LayerNormMLP( + self.HIDDEN, self.FFN_HIDDEN, params_dtype=self.DTYPE, device=device, + ) + assert not isinstance(mod.fc1_weight, self._fsdpag_cls()) + assert not isinstance(mod.fc2_weight, self._fsdpag_cls()) + + def test_forward_bf16_with_wrapped_weights(self, device): + """bf16 forward+backward works with FSDPAGTensor-wrapped weights (the + wrapper's __torch_dispatch__ unwraps for ordinary ops). + FP8 + use_fsdp2 requires an actual FSDP2 wrap to gather properly — + that path is not tested here because it needs ≥2 GPUs + fully_shard. 
+ """ + if not self._has_fsdpag(): + pytest.skip("FSDPAGTensor unavailable") + mod = te.LayerNormLinear( + self.HIDDEN, self.HIDDEN, use_fsdp2=True, + params_dtype=self.DTYPE, device=device, + ) + x = torch.randn(8, self.HIDDEN, device=device, + dtype=self.DTYPE, requires_grad=True) + y = mod(x) + y.sum().backward() + torch.cuda.synchronize() + assert y.shape == x.shape + assert torch.isfinite(y).all() + assert x.grad is not None + assert torch.isfinite(x.grad).all() + + def test_forward_bf16_layernorm_mlp_with_wrap(self, device): + """Same bf16 smoke test for LayerNormMLP with use_fsdp2=True.""" + if not self._has_fsdpag(): + pytest.skip("FSDPAGTensor unavailable") + mod = te.LayerNormMLP( + self.HIDDEN, self.FFN_HIDDEN, activation="swiglu", + use_fsdp2=True, params_dtype=self.DTYPE, device=device, + ) + x = torch.randn(8, self.HIDDEN, device=device, + dtype=self.DTYPE, requires_grad=True) + y = mod(x) + y.sum().backward() + torch.cuda.synchronize() + assert y.shape == x.shape + assert torch.isfinite(y).all() + assert torch.isfinite(x.grad).all() + + def test_non_hip_silently_ignores_flag(self, device): + """On non-ROCm builds the flag is forced to False (matches full build). + On ROCm this is always True; we just verify the attribute is accessible.""" + mod = te.LayerNormLinear( + self.HIDDEN, self.HIDDEN, use_fsdp2=True, + params_dtype=self.DTYPE, device=device, + ) + assert hasattr(mod, "use_fsdp2") + # On ROCm (lite's target platform), use_fsdp2 should pass through. + from torch.utils.cpp_extension import IS_HIP_EXTENSION + if IS_HIP_EXTENSION: + assert mod.use_fsdp2 is True + else: + assert mod.use_fsdp2 is False + + def test_fsdpag_wraps_parameter_preserve_grad(self, device): + """After wrap, gradients still flow to the underlying _data tensor.""" + if not self._has_fsdpag(): + pytest.skip("FSDPAGTensor unavailable") + mod = te.LayerNormLinear( + self.HIDDEN, self.HIDDEN, use_fsdp2=True, + params_dtype=self.DTYPE, device=device, + ) + assert mod.weight.requires_grad + x = torch.randn(8, self.HIDDEN, device=device, + dtype=self.DTYPE) + mod(x).sum().backward() + torch.cuda.synchronize() + assert mod.weight.grad is not None + assert torch.isfinite(mod.weight.grad).all() + + +# --------------------------------------------------------------------------- +# Multi-tensor kernels +# --------------------------------------------------------------------------- + +class TestMultiTensor: + """Cover the lite replacements for transformer_engine_torch multi-tensor ops. + + The reference semantics come from common/multi_tensor/{scale,l2norm,adam,sgd}.cu + and csrc/extensions/multi_tensor/*.cpp. 
+ """ + + CHUNK = 2048 * 32 + + @staticmethod + def _overflow_buf(device): + return torch.zeros(1, dtype=torch.int32, device=device) + + # --- multi_tensor_scale ------------------------------------------------- + + @pytest.mark.parametrize("in_dtype,out_dtype", [ + (torch.float32, torch.float32), + (torch.float32, torch.bfloat16), + (torch.float16, torch.float32), + (torch.bfloat16, torch.bfloat16), + ]) + def test_scale_writes_out_list(self, device, in_dtype, out_dtype): + overflow = self._overflow_buf(device) + a = torch.full([777], 4.0, dtype=in_dtype, device=device) + b = torch.full([555], 4.0, dtype=in_dtype, device=device) + in_list = [a.clone(), b.clone()] + out_list = [torch.empty_like(t, dtype=out_dtype) for t in in_list] + + tex.multi_tensor_scale(self.CHUNK, overflow, [in_list, out_list], 0.25) + + expected = torch.full_like(out_list[0], 1.0) + torch.testing.assert_close(out_list[0], expected) + torch.testing.assert_close(out_list[1], torch.full_like(out_list[1], 1.0)) + # in_list should not be modified + torch.testing.assert_close(in_list[0], torch.full_like(in_list[0], 4.0)) + assert overflow.item() == 0 + + def test_scale_sets_overflow_on_nan(self, device): + overflow = self._overflow_buf(device) + a = torch.full([64], 4.0, dtype=torch.float32, device=device) + a[3] = float("nan") + out = torch.empty_like(a) + tex.multi_tensor_scale(self.CHUNK, overflow, [[a], [out]], 0.5) + assert overflow.item() == 1 + + def test_scale_sets_overflow_on_inf(self, device): + overflow = self._overflow_buf(device) + a = torch.full([64], 4.0, dtype=torch.float32, device=device) + a[7] = float("inf") + out = torch.empty_like(a) + tex.multi_tensor_scale(self.CHUNK, overflow, [[a], [out]], 0.5) + assert overflow.item() == 1 + + # --- multi_tensor_l2norm ------------------------------------------------ + + def test_l2norm_returns_2_tuple_when_per_tensor_false(self, device): + """Megatron's clip_grad_norm unpacks (norm, _) unconditionally.""" + overflow = self._overflow_buf(device) + a = torch.full([100], 3.0, dtype=torch.float32, device=device) + b = torch.full([100], 4.0, dtype=torch.float32, device=device) + result = tex.multi_tensor_l2norm(self.CHUNK, overflow, [[a, b]], False) + assert isinstance(result, tuple) and len(result) == 2 + norm, per_tensor = result + # sqrt(100*9 + 100*16) = sqrt(2500) = 50 + torch.testing.assert_close(norm, torch.tensor([50.0], device=device)) + assert per_tensor.numel() == 0 + + def test_l2norm_per_tensor_values(self, device): + overflow = self._overflow_buf(device) + a = torch.full([100], 3.0, dtype=torch.float32, device=device) + b = torch.full([100], 4.0, dtype=torch.float32, device=device) + norm, per_tensor = tex.multi_tensor_l2norm( + self.CHUNK, overflow, [[a, b]], True + ) + # norms: sqrt(100*9)=30, sqrt(100*16)=40; total=50 + torch.testing.assert_close(norm, torch.tensor([50.0], device=device)) + torch.testing.assert_close(per_tensor, torch.tensor([30.0, 40.0], device=device)) + + def test_l2norm_mixed_dtypes(self, device): + """l2norm must tolerate fp16/bf16 inputs (upcasts to fp32 internally).""" + overflow = self._overflow_buf(device) + a = torch.full([64], 3.0, dtype=torch.bfloat16, device=device) + b = torch.full([64], 4.0, dtype=torch.float16, device=device) + norm, _ = tex.multi_tensor_l2norm(self.CHUNK, overflow, [[a, b]], False) + expected = (64 * 9 + 64 * 16) ** 0.5 # sqrt(1600) = 40 + torch.testing.assert_close(norm, torch.tensor([expected], device=device), + atol=1e-4, rtol=1e-4) + + # --- multi_tensor_unscale_l2norm 
--------------------------------------- + + def test_unscale_l2norm_returns_2_tuple(self, device): + overflow = self._overflow_buf(device) + a = torch.full([100], 6.0, dtype=torch.float32, device=device) + b = torch.full([100], 8.0, dtype=torch.float32, device=device) + inv_scale = torch.tensor([2.0], device=device) # scale = 0.5 + result = tex.multi_tensor_unscale_l2norm( + self.CHUNK, overflow, [[a, b]], inv_scale, False + ) + assert isinstance(result, tuple) and len(result) == 2 + norm, per_tensor = result + # After unscaling (multiply by 0.5): norms 3 and 4; total 50 over 100 each + # sqrt(100*9 + 100*16) = 50 + torch.testing.assert_close(norm, torch.tensor([50.0], device=device)) + assert per_tensor.numel() == 0 + + def test_unscale_l2norm_per_tensor(self, device): + overflow = self._overflow_buf(device) + a = torch.full([100], 6.0, dtype=torch.float32, device=device) + b = torch.full([100], 8.0, dtype=torch.float32, device=device) + inv_scale = torch.tensor([2.0], device=device) + norm, per_tensor = tex.multi_tensor_unscale_l2norm( + self.CHUNK, overflow, [[a, b]], inv_scale, True + ) + torch.testing.assert_close(norm, torch.tensor([50.0], device=device)) + torch.testing.assert_close(per_tensor, torch.tensor([30.0, 40.0], device=device)) + + # --- multi_tensor_adam -------------------------------------------------- + + @staticmethod + def _adam_reference(p, g, m, v, lr, b1, b2, eps, step, wd, adam_w): + """Reference implementation mirroring common/multi_tensor/adam.cu.""" + p_f = p.float() + g_f = g.float() + if not adam_w and wd != 0.0: + g_f = g_f + wd * p_f + m_new = b1 * m + (1 - b1) * g_f + v_new = b2 * v + (1 - b2) * g_f * g_f + bc1 = 1 - b1 ** step + bc2 = 1 - b2 ** step + denom = (v_new / bc2).sqrt() + eps + update = (m_new / bc1) / denom + if adam_w and wd != 0.0: + update = update + wd * p_f + p_new = p_f - lr * update + return p_new, m_new, v_new + + @pytest.mark.parametrize("adam_w_mode", [True, False]) + @pytest.mark.parametrize("weight_decay", [0.0, 0.01]) + def test_adam_4list_no_master(self, device, adam_w_mode, weight_decay): + overflow = self._overflow_buf(device) + torch.manual_seed(0) + p = torch.randn(256, device=device, dtype=torch.float32) + g = torch.randn(256, device=device, dtype=torch.float32) * 0.01 + m = torch.zeros_like(p) + v = torch.zeros_like(p) + p0, g0 = p.clone(), g.clone() + + tex.multi_tensor_adam( + self.CHUNK, overflow, [[g], [p], [m], [v]], + 1e-3, 0.9, 0.999, 1e-8, 1, adam_w_mode, True, weight_decay, + ) + + p_ref, m_ref, v_ref = self._adam_reference( + p0, g0, torch.zeros_like(p0), torch.zeros_like(p0), + 1e-3, 0.9, 0.999, 1e-8, 1, weight_decay, adam_w_mode, + ) + torch.testing.assert_close(p, p_ref, atol=1e-6, rtol=1e-5) + torch.testing.assert_close(m, m_ref, atol=1e-6, rtol=1e-5) + torch.testing.assert_close(v, v_ref, atol=1e-6, rtol=1e-5) + + def test_adam_5list_master_bf16(self, device): + """Megatron's master-weights path: bf16 params, fp32 master + m + v.""" + overflow = self._overflow_buf(device) + torch.manual_seed(0) + pm = torch.randn(256, device=device, dtype=torch.float32) + p = pm.to(torch.bfloat16) + g = (torch.randn(256, device=device) * 0.01).to(torch.bfloat16) + m = torch.zeros(256, device=device, dtype=torch.float32) + v = torch.zeros(256, device=device, dtype=torch.float32) + pm0 = pm.clone() + + tex.multi_tensor_adam( + self.CHUNK, overflow, [[g], [p], [m], [v], [pm]], + 1e-3, 0.9, 0.999, 1e-8, 1, True, True, 0.0, + ) + + p_ref, _, _ = self._adam_reference( + pm0, g.float(), torch.zeros_like(pm0), torch.zeros_like(pm0), + 
1e-3, 0.9, 0.999, 1e-8, 1, 0.0, True, + ) + # Master is kept in fp32 + torch.testing.assert_close(pm, p_ref, atol=1e-6, rtol=1e-5) + # bf16 shadow is a downcast of master + torch.testing.assert_close(p, p_ref.to(torch.bfloat16), atol=1e-3, rtol=1e-2) + + def test_adam_no_bias_correction(self, device): + overflow = self._overflow_buf(device) + torch.manual_seed(0) + p = torch.randn(64, device=device) + g = torch.randn(64, device=device) * 0.01 + m = torch.zeros_like(p) + v = torch.zeros_like(p) + p0, g0 = p.clone(), g.clone() + + tex.multi_tensor_adam( + self.CHUNK, overflow, [[g], [p], [m], [v]], + 1e-3, 0.9, 0.999, 1e-8, 5, True, False, 0.0, + ) + + # bias_correction=False → bc1=bc2=1 + m_ref = 0.9 * torch.zeros_like(p0) + 0.1 * g0 + v_ref = 0.999 * torch.zeros_like(p0) + 0.001 * g0 * g0 + denom = v_ref.sqrt() + 1e-8 + p_ref = p0 - 1e-3 * (m_ref / denom) + torch.testing.assert_close(p, p_ref, atol=1e-6, rtol=1e-5) + + def test_adam_oversized_state_tensors(self, device): + """Megatron's distributed optimizer passes m/v sized for the full + vocab while the gradient is only this rank's TP shard. The C++ kernel + uses `g.numel()` as the work size, and lite must match that. Only the + first g.numel() elements of m/v/p should be modified. + """ + overflow = self._overflow_buf(device) + torch.manual_seed(0) + # g/p are the TP shard; m/v are the full parameter (8x larger). + shard_rows, full_rows, cols = 64, 512, 32 + g = torch.randn(shard_rows, cols, device=device, dtype=torch.bfloat16) * 0.01 + p = torch.randn(shard_rows, cols, device=device, dtype=torch.bfloat16) + m = torch.randn(full_rows, cols, device=device, dtype=torch.float32) * 0.001 + v = torch.randn(full_rows, cols, device=device, dtype=torch.float32).abs() * 0.001 + + # Snapshot the out-of-shard region to confirm it's untouched. + m_tail = m[shard_rows:].clone() + v_tail = v[shard_rows:].clone() + p0 = p.clone() + m_head0 = m[:shard_rows].clone() + v_head0 = v[:shard_rows].clone() + + tex.multi_tensor_adam( + self.CHUNK, overflow, [[g], [p], [m], [v]], + 1e-3, 0.9, 0.999, 1e-8, 1, True, True, 0.0, + ) + + # Out-of-shard region untouched. + torch.testing.assert_close(m[shard_rows:], m_tail) + torch.testing.assert_close(v[shard_rows:], v_tail) + + # In-shard region updated per Adam math on fp32-upcast grads. 
+ g_f = g.float() + m_ref = 0.9 * m_head0 + 0.1 * g_f + v_ref = 0.999 * v_head0 + 0.001 * g_f * g_f + torch.testing.assert_close(m[:shard_rows], m_ref, atol=1e-5, rtol=1e-4) + torch.testing.assert_close(v[:shard_rows], v_ref, atol=1e-5, rtol=1e-4) + + def test_adam_param_remainder_not_implemented(self, device): + with pytest.raises(NotImplementedError): + tex.multi_tensor_adam_param_remainder() + + def test_adam_fp8_not_implemented(self, device): + with pytest.raises(NotImplementedError): + tex.multi_tensor_adam_fp8() + + def test_adam_capturable_not_implemented(self, device): + with pytest.raises(NotImplementedError): + tex.multi_tensor_adam_capturable() + + # --- multi_tensor_sgd --------------------------------------------------- + + def test_sgd_no_momentum(self, device): + overflow = self._overflow_buf(device) + torch.manual_seed(0) + w = torch.randn(64, device=device, dtype=torch.float32) + g = torch.randn(64, device=device, dtype=torch.float32) * 0.01 + mom = torch.zeros_like(w) + w0, g0 = w.clone(), g.clone() + + tex.multi_tensor_sgd( + self.CHUNK, overflow, [[g], [w], [mom]], + 1e-2, # lr + 0.0, # momentum + 0.0, # dampening + 0.0, # weight_decay + False, # nesterov + True, # first_run + False, # wd_after_momentum + 1.0, # scale + ) + torch.testing.assert_close(w, w0 - 1e-2 * g0, atol=1e-6, rtol=1e-5) + + def test_sgd_with_momentum_first_run(self, device): + overflow = self._overflow_buf(device) + torch.manual_seed(0) + w = torch.randn(64, device=device, dtype=torch.float32) + g = torch.randn(64, device=device, dtype=torch.float32) * 0.01 + mom = torch.zeros_like(w) + w0, g0 = w.clone(), g.clone() + + tex.multi_tensor_sgd( + self.CHUNK, overflow, [[g], [w], [mom]], + 1e-2, 0.9, 0.0, 0.0, False, True, False, 1.0, + ) + # first_run: mom = g; g_eff = mom = g; w -= lr*g + torch.testing.assert_close(mom, g0, atol=1e-6, rtol=1e-5) + torch.testing.assert_close(w, w0 - 1e-2 * g0, atol=1e-6, rtol=1e-5) + + def test_sgd_weight_decay_before_momentum(self, device): + overflow = self._overflow_buf(device) + torch.manual_seed(0) + w = torch.randn(32, device=device, dtype=torch.float32) + g = torch.randn(32, device=device, dtype=torch.float32) * 0.01 + mom = torch.zeros_like(w) + w0, g0 = w.clone(), g.clone() + + tex.multi_tensor_sgd( + self.CHUNK, overflow, [[g], [w], [mom]], + 1e-2, 0.0, 0.0, 0.1, False, True, False, 1.0, + ) + # wd before momentum: g_eff = g + 0.1*w; w -= lr*g_eff + g_eff = g0 + 0.1 * w0 + torch.testing.assert_close(w, w0 - 1e-2 * g_eff, atol=1e-6, rtol=1e-5) + + def test_sgd_4list_writes_fp16_copy(self, device): + overflow = self._overflow_buf(device) + torch.manual_seed(0) + w = torch.randn(32, device=device, dtype=torch.float32) + g = torch.randn(32, device=device, dtype=torch.float32) * 0.01 + mom = torch.zeros_like(w) + w_fp16 = torch.zeros(32, device=device, dtype=torch.float16) + w0, g0 = w.clone(), g.clone() + + tex.multi_tensor_sgd( + self.CHUNK, overflow, [[g], [w], [mom], [w_fp16]], + 1e-2, 0.0, 0.0, 0.0, False, True, False, 1.0, + ) + expected = w0 - 1e-2 * g0 + torch.testing.assert_close(w, expected, atol=1e-6, rtol=1e-5) + torch.testing.assert_close(w_fp16, expected.to(torch.float16), + atol=1e-3, rtol=1e-2) diff --git a/tests/pytorch/test_lite_mori_ep.py b/tests/pytorch/test_lite_mori_ep.py new file mode 100644 index 000000000..d91bdd528 --- /dev/null +++ b/tests/pytorch/test_lite_mori_ep.py @@ -0,0 +1,1472 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Tests for MORI Expert Parallelism integration in tealite. + +This test suite validates the MORI EP integration at two levels: + +1. **Unit tests** (no GPU / no MORI required): Test config validation, API + surface, initialization guards, and error handling using mocks. + +2. **Multi-GPU integration tests** (require MORI + multiple AMD GPUs): Test + actual dispatch/combine with real MORI kernels. Skipped automatically when + MORI is not installed or insufficient GPUs are available. + +Run unit tests (always works): + pytest tests/pytorch/test_lite_mori_ep.py -v -k "unit" + +Run integration tests (requires MORI + multi-GPU): + pytest tests/pytorch/test_lite_mori_ep.py -v -k "integration" +""" + +import os +import sys +import pytest +from unittest import mock +from unittest.mock import MagicMock, patch + +import torch + +os.environ["NVTE_LITE"] = "1" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _mori_installed(): + try: + import mori # noqa: F401 + return True + except ImportError: + return False + + +def _gpu_count(): + if not torch.cuda.is_available(): + return 0 + return torch.cuda.device_count() + + +skip_no_mori = pytest.mark.skipif( + not _mori_installed(), reason="MORI not installed" +) +skip_insufficient_gpus = pytest.mark.skipif( + _gpu_count() < 2, reason=f"Need >=2 GPUs, found {_gpu_count()}" +) + + +# --------------------------------------------------------------------------- +# Unit Tests -- no MORI or GPU required +# --------------------------------------------------------------------------- + +class TestMoriEPAvailability: + """Test availability detection and import guards.""" + + def test_mori_ep_available_returns_bool(self): + from transformer_engine.pytorch._lite.mori_ep import mori_ep_available + result = mori_ep_available() + assert isinstance(result, bool) + + def test_mori_ep_available_reflects_import(self): + """mori_ep_available() should return True iff mori is importable.""" + from transformer_engine.pytorch._lite.mori_ep import mori_ep_available + expected = _mori_installed() + assert mori_ep_available() == expected + + +class TestMoriEPInitGuards: + """Test that init_mori_ep enforces prerequisites.""" + + def test_init_without_mori_raises(self): + """init_mori_ep raises RuntimeError when MORI is not installed.""" + import transformer_engine.pytorch._lite.mori_ep as mod + + with mock.patch.object(mod, "_mori_available", False), \ + mock.patch.object(mod, "_mori", None): + with pytest.raises(RuntimeError, match="MORI is not installed"): + # Reset cached state so it re-checks + orig = mod._mori_available + mod._mori_available = False + try: + mod.init_mori_ep() + finally: + mod._mori_available = orig + + def test_init_without_dist_raises(self): + """init_mori_ep raises RuntimeError if torch.distributed not initialized.""" + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", False), \ + mock.patch("torch.distributed.is_initialized", return_value=False): + with pytest.raises(RuntimeError, match="torch.distributed must be initialized"): + mod.init_mori_ep() + + def test_init_idempotent(self): + """Calling init_mori_ep when already initialized is a no-op.""" + import transformer_engine.pytorch._lite.mori_ep as mod + + with mock.patch.object(mod, 
"_mori_shmem_initialized", True): + # Should not raise or call anything + mod.init_mori_ep() + + def test_finalize_when_not_initialized_is_noop(self): + """finalize_mori_ep when not initialized should be a no-op.""" + import transformer_engine.pytorch._lite.mori_ep as mod + + with mock.patch.object(mod, "_mori_shmem_initialized", False): + mod.finalize_mori_ep() # should not raise + + +class TestMoriExpertParallelConfig: + """Test MoriExpertParallel config validation (mocked MORI).""" + + def _make_ep(self, **kwargs): + """Create a MoriExpertParallel with mocked MORI backend.""" + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + # Mock the kernel type enum + mock_kt = MagicMock() + for name in ["IntraNode", "InterNode", "InterNodeV1", "InterNodeV1LL", "AsyncLL"]: + setattr(mock_kt, name, name) + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + mock_mori.ops.EpDispatchCombineOp = MagicMock() + + defaults = dict( + rank=0, + world_size=8, + hidden_dim=7168, + num_experts_per_rank=32, + num_experts_per_token=8, + max_num_inp_token_per_rank=4096, + ) + defaults.update(kwargs) + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel(**defaults) + + return ep, mock_mori + + def test_default_construction(self): + ep, mock_mori = self._make_ep() + assert ep.rank == 0 + assert ep.world_size == 8 + assert ep.hidden_dim == 7168 + assert ep.num_experts_per_rank == 32 + assert ep.num_experts_per_token == 8 + assert ep.num_experts == 256 # 32 * 8 + + # Verify MORI config was created with correct params + mock_mori.ops.EpDispatchCombineConfig.assert_called_once() + call_kwargs = mock_mori.ops.EpDispatchCombineConfig.call_args + assert call_kwargs.kwargs["rank"] == 0 + assert call_kwargs.kwargs["world_size"] == 8 + assert call_kwargs.kwargs["hidden_dim"] == 7168 + + def test_invalid_kernel_type_raises(self): + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + with pytest.raises(ValueError, match="Unknown kernel_type"): + mod.MoriExpertParallel( + rank=0, world_size=8, hidden_dim=128, + num_experts_per_rank=4, num_experts_per_token=2, + max_num_inp_token_per_rank=64, + kernel_type="nonexistent", + ) + + @pytest.mark.parametrize("kernel_type,expected", [ + ("intra_node", "IntraNode"), + ("inter_node", "InterNode"), + ("inter_node_v1", "InterNodeV1"), + ("inter_node_v1_ll", "InterNodeV1LL"), + ("async_ll", "AsyncLL"), + ]) + def test_kernel_type_mapping(self, kernel_type, expected): + ep, mock_mori = self._make_ep(kernel_type=kernel_type) + call_kwargs = mock_mori.ops.EpDispatchCombineConfig.call_args.kwargs + assert call_kwargs["kernel_type"] == expected + + def test_not_initialized_raises(self): + """Creating MoriExpertParallel without init raises RuntimeError.""" + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", False): + with pytest.raises(RuntimeError, match="MORI shmem not initialized"): + mod.MoriExpertParallel( + rank=0, world_size=8, hidden_dim=128, + num_experts_per_rank=4, 
num_experts_per_token=2, + max_num_inp_token_per_rank=64, + ) + + def test_fp8_direct_cast_sets_external_buf(self): + ep, mock_mori = self._make_ep(quant_type="fp8_direct_cast") + call_kwargs = mock_mori.ops.EpDispatchCombineConfig.call_args.kwargs + assert call_kwargs["use_external_inp_buf"] is True + + def test_no_quant_unsets_external_buf(self): + ep, mock_mori = self._make_ep(quant_type="none") + call_kwargs = mock_mori.ops.EpDispatchCombineConfig.call_args.kwargs + assert call_kwargs["use_external_inp_buf"] is False + + +class TestMoriExpertParallelDispatchCombine: + """Test dispatch/combine API surface with mocked MORI backend.""" + + def _make_ep_with_mock_op(self): + """Create EP with a fully mocked dispatch/combine operator.""" + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + mock_kt = MagicMock() + mock_kt.IntraNode = "IntraNode" + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + + # Mock the op returned by EpDispatchCombineOp() + mock_op = MagicMock() + mock_mori.ops.EpDispatchCombineOp.return_value = mock_op + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel( + rank=0, world_size=2, + hidden_dim=128, + num_experts_per_rank=4, + num_experts_per_token=2, + max_num_inp_token_per_rank=32, + ) + + return ep, mock_op + + @patch("torch.cuda.synchronize") + def test_dispatch_calls_mori_op(self, mock_sync): + ep, mock_op = self._make_ep_with_mock_op() + + # Set up mock return for dispatch + num_recv = 10 + hidden_dim = 128 + topk = 2 + mock_op.dispatch.return_value = ( + torch.zeros(num_recv, hidden_dim), # out tokens + torch.zeros(num_recv, topk), # out weights + None, # out scales + torch.zeros(num_recv, topk, dtype=torch.int32), # out indices + torch.tensor([num_recv], dtype=torch.int32), # total_recv + ) + + input_t = torch.randn(8, hidden_dim) + weights = torch.rand(8, topk) + indices = torch.randint(0, 8, (8, topk), dtype=torch.int32) + + recv_tokens, recv_weights, recv_indices, n_recv = ep.dispatch( + input_t, weights, indices, + ) + + mock_op.dispatch.assert_called_once() + assert n_recv == num_recv + assert recv_tokens.shape == (num_recv, hidden_dim) + mock_sync.assert_called() + + @patch("torch.cuda.synchronize") + def test_combine_calls_mori_op(self, mock_sync): + ep, mock_op = self._make_ep_with_mock_op() + + hidden_dim = 128 + topk = 2 + max_tokens = 32 + mock_op.combine.return_value = ( + torch.zeros(max_tokens, hidden_dim), # output + torch.zeros(max_tokens, topk), # output_weights + ) + + expert_out = torch.randn(10, hidden_dim) + weights = torch.rand(10, topk) + indices = torch.randint(0, 8, (10, topk), dtype=torch.int32) + + output, output_weights = ep.combine(expert_out, weights, indices) + + mock_op.combine.assert_called_once() + assert output.shape == (max_tokens, hidden_dim) + mock_sync.assert_called() + + def test_reset_calls_mori_op(self): + ep, mock_op = self._make_ep_with_mock_op() + ep.reset() + mock_op.reset.assert_called_once() + + @patch("torch.cuda.synchronize") + def test_dispatch_casts_int64_indices_to_int32(self, mock_sync): + ep, mock_op = self._make_ep_with_mock_op() + + hidden_dim = 128 + topk = 2 + mock_op.dispatch.return_value = ( + torch.zeros(5, hidden_dim), + torch.zeros(5, topk), + None, + torch.zeros(5, topk, dtype=torch.int32), + torch.tensor([5], dtype=torch.int32), + ) + + input_t = torch.randn(8, 
hidden_dim) + weights = torch.rand(8, topk) + indices_int64 = torch.randint(0, 8, (8, topk), dtype=torch.int64) + + ep.dispatch(input_t, weights, indices_int64) + + # Check that the indices passed to MORI are int32 + call_args = mock_op.dispatch.call_args + actual_indices = call_args.args[3] # 4th positional arg + assert actual_indices.dtype == torch.int32 + + @patch("torch.cuda.synchronize") + def test_dispatch_and_combine_full_cycle(self, mock_sync): + """Test the convenience dispatch_and_combine method.""" + ep, mock_op = self._make_ep_with_mock_op() + + hidden_dim = 128 + topk = 2 + num_recv = 6 + + mock_op.dispatch.return_value = ( + torch.randn(num_recv, hidden_dim), + torch.rand(num_recv, topk), + None, + torch.randint(0, 8, (num_recv, topk), dtype=torch.int32), + torch.tensor([num_recv], dtype=torch.int32), + ) + mock_op.combine.return_value = ( + torch.randn(32, hidden_dim), + torch.rand(32, topk), + ) + + # Expert function: identity + def expert_fn(tokens, indices, n): + return tokens + + input_t = torch.randn(8, hidden_dim) + weights = torch.rand(8, topk) + indices = torch.randint(0, 8, (8, topk), dtype=torch.int32) + + output, output_weights = ep.dispatch_and_combine( + input_t, weights, indices, expert_fn, + ) + + mock_op.dispatch.assert_called_once() + mock_op.combine.assert_called_once() + mock_op.reset.assert_called_once() + assert output.shape[1] == hidden_dim + + +class TestMaskToIndex: + """Test mask-map ↔ index-map routing conversion.""" + + def test_basic_conversion(self): + from transformer_engine.pytorch._lite.mori_ep import mask_to_index + + # token 0 → experts 1,3; token 1 → experts 0,2 + mask = torch.tensor([[0, 1, 0, 1], + [1, 0, 1, 0]], dtype=torch.int32) + probs = torch.tensor([[0.0, 0.3, 0.0, 0.7], + [0.5, 0.0, 0.5, 0.0]], dtype=torch.float32) + + indices, weights = mask_to_index(mask, probs) + + assert indices.shape == (2, 2) + assert indices.dtype == torch.int32 + assert weights.shape == (2, 2) + + # nonzero produces sorted columns, so expert order is ascending + assert indices[0].tolist() == [1, 3] + assert indices[1].tolist() == [0, 2] + assert torch.allclose(weights[0], torch.tensor([0.3, 0.7])) + assert torch.allclose(weights[1], torch.tensor([0.5, 0.5])) + + def test_no_probs_gives_uniform_weights(self): + from transformer_engine.pytorch._lite.mori_ep import mask_to_index + + mask = torch.tensor([[1, 0, 1], + [0, 1, 1]], dtype=torch.int32) + indices, weights = mask_to_index(mask, probs=None) + + assert indices[0].tolist() == [0, 2] + assert indices[1].tolist() == [1, 2] + assert torch.all(weights == 1.0) + + def test_single_expert_per_token(self): + from transformer_engine.pytorch._lite.mori_ep import mask_to_index + + mask = torch.tensor([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]], dtype=torch.int32) + indices, weights = mask_to_index(mask) + + assert indices.shape == (3, 1) + assert indices.flatten().tolist() == [2, 0, 1] + + def test_empty_input(self): + from transformer_engine.pytorch._lite.mori_ep import mask_to_index + + mask = torch.zeros(0, 4, dtype=torch.int32) + indices, weights = mask_to_index(mask) + assert indices.shape[0] == 0 + assert weights.shape[0] == 0 + + @pytest.mark.parametrize("topk", [1, 2, 4, 8]) + def test_various_topk(self, topk): + from transformer_engine.pytorch._lite.mori_ep import mask_to_index + + num_tokens, num_experts = 16, 32 + mask = torch.zeros(num_tokens, num_experts, dtype=torch.int32) + for i in range(num_tokens): + chosen = torch.randperm(num_experts)[:topk].sort().values + mask[i, chosen] = 1 + + indices, weights = 
mask_to_index(mask) + assert indices.shape == (num_tokens, topk) + # Each row's experts should match the mask + for i in range(num_tokens): + expected = mask[i].nonzero(as_tuple=False).flatten().to(torch.int32) + assert torch.equal(indices[i], expected) + + +class TestIndexToMask: + """Test index-map → mask-map conversion.""" + + def test_basic_conversion(self): + from transformer_engine.pytorch._lite.mori_ep import index_to_mask + + indices = torch.tensor([[1, 3], [0, 2]], dtype=torch.int32) + weights = torch.tensor([[0.3, 0.7], [0.5, 0.5]], dtype=torch.float32) + + mask, probs = index_to_mask(indices, num_experts=4, weights=weights) + + assert mask.shape == (2, 4) + assert mask.dtype == torch.int32 + assert mask[0].tolist() == [0, 1, 0, 1] + assert mask[1].tolist() == [1, 0, 1, 0] + assert probs is not None + assert torch.allclose(probs[0], torch.tensor([0.0, 0.3, 0.0, 0.7])) + assert torch.allclose(probs[1], torch.tensor([0.5, 0.0, 0.5, 0.0])) + + def test_no_weights(self): + from transformer_engine.pytorch._lite.mori_ep import index_to_mask + + indices = torch.tensor([[2], [0]], dtype=torch.int32) + mask, probs = index_to_mask(indices, num_experts=3) + + assert mask[0].tolist() == [0, 0, 1] + assert mask[1].tolist() == [1, 0, 0] + assert probs is None + + +class TestRoundtrip: + """Test mask→index→mask and index→mask→index round-trips.""" + + def test_mask_index_mask(self): + from transformer_engine.pytorch._lite.mori_ep import mask_to_index, index_to_mask + + num_tokens, num_experts, topk = 8, 16, 3 + mask_orig = torch.zeros(num_tokens, num_experts, dtype=torch.int32) + probs_orig = torch.zeros(num_tokens, num_experts, dtype=torch.float32) + for i in range(num_tokens): + chosen = torch.randperm(num_experts)[:topk].sort().values + mask_orig[i, chosen] = 1 + probs_orig[i, chosen] = torch.rand(topk) + + indices, weights = mask_to_index(mask_orig, probs_orig) + mask_rt, probs_rt = index_to_mask(indices, num_experts, weights) + + assert torch.equal(mask_orig, mask_rt) + assert torch.allclose(probs_orig, probs_rt) + + def test_index_mask_index(self): + from transformer_engine.pytorch._lite.mori_ep import mask_to_index, index_to_mask + + num_tokens, num_experts, topk = 8, 16, 2 + # Generate sorted indices (mask_to_index produces sorted output) + indices_orig = torch.stack([ + torch.randperm(num_experts)[:topk].sort().values + for _ in range(num_tokens) + ]).to(torch.int32) + weights_orig = torch.rand(num_tokens, topk) + + mask, probs = index_to_mask(indices_orig, num_experts, weights_orig) + indices_rt, weights_rt = mask_to_index(mask, probs) + + assert torch.equal(indices_orig, indices_rt) + assert torch.allclose(weights_orig, weights_rt) + + +class TestExportedSymbols: + """Verify symbols are exported from _lite/__init__.py.""" + + def test_mori_ep_symbols_exported(self): + from transformer_engine.pytorch._lite import ( + mori_ep_available, + init_mori_ep, + finalize_mori_ep, + is_mori_ep_initialized, + MoriExpertParallel, + ) + assert callable(mori_ep_available) + assert callable(init_mori_ep) + assert callable(finalize_mori_ep) + assert callable(is_mori_ep_initialized) + assert callable(MoriExpertParallel) + + def test_autograd_symbols_exported(self): + from transformer_engine.pytorch._lite import ( + MoriEPDispatch, + MoriEPCombine, + ) + assert callable(MoriEPDispatch.apply) + assert callable(MoriEPCombine.apply) + + def test_routing_conversion_symbols_exported(self): + from transformer_engine.pytorch._lite import mask_to_index, index_to_mask + assert callable(mask_to_index) + assert 
callable(index_to_mask) + + def test_std_moe_symbols_exported(self): + from transformer_engine.pytorch._lite import ( + MoriEPDispatchStdMoE, + MoriEPCombineStdMoE, + ) + assert callable(MoriEPDispatchStdMoE.apply) + assert callable(MoriEPCombineStdMoE.apply) + + +# --------------------------------------------------------------------------- +# Autograd Unit Tests +# --------------------------------------------------------------------------- + +class TestMoriEPCycleState: + """Test the shared cycle state object.""" + + def _make_ep_and_state(self): + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + mock_kt = MagicMock() + mock_kt.IntraNode = "IntraNode" + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + mock_mori.ops.EpDispatchCombineOp.return_value = MagicMock() + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel( + rank=0, world_size=2, hidden_dim=64, + num_experts_per_rank=4, num_experts_per_token=2, + max_num_inp_token_per_rank=16, + ) + state = ep.new_cycle() + + return ep, state + + def test_new_cycle_returns_state(self): + from transformer_engine.pytorch._lite.mori_ep import _MoriEPCycleState + ep, state = self._make_ep_and_state() + assert isinstance(state, _MoriEPCycleState) + assert state.ep is ep + + def test_initial_state_is_empty(self): + ep, state = self._make_ep_and_state() + assert state.fwd_weights is None + assert state.fwd_indices is None + assert state.fwd_num_input == 0 + assert state.fwd_num_recv == 0 + assert state.bwd_recv_weights is None + assert state.bwd_recv_indices is None + assert state.bwd_num_recv == 0 + + +class TestMoriEPDispatchAutograd: + """Test MoriEPDispatch autograd function with mocked MORI backend.""" + + def _make_ep_and_mock(self): + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + mock_kt = MagicMock() + mock_kt.IntraNode = "IntraNode" + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + + mock_op = MagicMock() + mock_mori.ops.EpDispatchCombineOp.return_value = mock_op + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel( + rank=0, world_size=2, hidden_dim=64, + num_experts_per_rank=4, num_experts_per_token=2, + max_num_inp_token_per_rank=16, + ) + + return ep, mock_op + + @patch("torch.cuda.synchronize") + def test_dispatch_forward_saves_state(self, mock_sync): + from transformer_engine.pytorch._lite.mori_ep import MoriEPDispatch + + ep, mock_op = self._make_ep_and_mock() + hidden_dim, topk, num_recv = 64, 2, 5 + + mock_op.dispatch.return_value = ( + torch.randn(num_recv, hidden_dim), + torch.rand(num_recv, topk), + None, + torch.randint(0, 8, (num_recv, topk), dtype=torch.int32), + torch.tensor([num_recv], dtype=torch.int32), + ) + + state = ep.new_cycle() + input_t = torch.randn(4, hidden_dim, requires_grad=True) + weights = torch.rand(4, topk) + indices = torch.randint(0, 8, (4, topk), dtype=torch.int32) + + recv, recv_w, recv_idx = MoriEPDispatch.apply( + input_t, weights, indices, state, + ) + + # Check state was populated + assert state.fwd_num_input == 4 + assert state.fwd_num_recv == num_recv + assert state.fwd_weights is not None + assert state.fwd_indices is not None 
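+        # fwd_* is consumed later in the cycle: test_combine_backward_dispatches_gradients
+        # seeds these fields before running backward on the combined output.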
+ assert recv.shape == (num_recv, hidden_dim) + assert recv_w.shape == (num_recv, topk) + + @patch("torch.cuda.synchronize") + def test_dispatch_backward_calls_combine(self, mock_sync): + """Dispatch backward should call MORI combine to reverse the communication.""" + from transformer_engine.pytorch._lite.mori_ep import MoriEPDispatch + + ep, mock_op = self._make_ep_and_mock() + hidden_dim, topk = 64, 2 + num_recv = 5 + num_input = 4 + + mock_op.dispatch.return_value = ( + torch.randn(num_recv, hidden_dim), + torch.rand(num_recv, topk), + None, + torch.randint(0, 8, (num_recv, topk), dtype=torch.int32), + torch.tensor([num_recv], dtype=torch.int32), + ) + mock_op.combine.return_value = ( + torch.randn(16, hidden_dim), # max_tokens=16 + torch.rand(16, topk), + ) + + state = ep.new_cycle() + input_t = torch.randn(num_input, hidden_dim, requires_grad=True) + weights = torch.rand(num_input, topk) + indices = torch.randint(0, 8, (num_input, topk), dtype=torch.int32) + + recv, recv_w, recv_idx = MoriEPDispatch.apply( + input_t, weights, indices, state, + ) + + # Simulate backward: set bwd state as if combine.backward ran first + state.bwd_recv_weights = torch.rand(num_recv, topk) + state.bwd_recv_indices = torch.randint(0, 8, (num_recv, topk), dtype=torch.int32) + state.bwd_num_recv = num_recv + + # Trigger backward + loss = recv.sum() + loss.backward() + + # Verify combine was called in backward + mock_op.combine.assert_called_once() + mock_op.reset.assert_called_once() + assert input_t.grad is not None + assert input_t.grad.shape == (num_input, hidden_dim) + + +class TestMoriEPCombineAutograd: + """Test MoriEPCombine autograd function with mocked MORI backend.""" + + def _make_ep_and_mock(self): + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + mock_kt = MagicMock() + mock_kt.IntraNode = "IntraNode" + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + + mock_op = MagicMock() + mock_mori.ops.EpDispatchCombineOp.return_value = mock_op + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel( + rank=0, world_size=2, hidden_dim=64, + num_experts_per_rank=4, num_experts_per_token=2, + max_num_inp_token_per_rank=16, + ) + + return ep, mock_op + + @patch("torch.cuda.synchronize") + def test_combine_forward_resets_op(self, mock_sync): + """Combine forward should reset the MORI op after combining.""" + from transformer_engine.pytorch._lite.mori_ep import MoriEPCombine + + ep, mock_op = self._make_ep_and_mock() + hidden_dim, topk = 64, 2 + num_input = 4 + + mock_op.combine.return_value = ( + torch.randn(16, hidden_dim), + torch.rand(16, topk), + ) + + state = ep.new_cycle() + state.fwd_num_input = num_input # as if dispatch already ran + + expert_out = torch.randn(5, hidden_dim, requires_grad=True) + recv_w = torch.rand(5, topk) + recv_idx = torch.randint(0, 8, (5, topk), dtype=torch.int32) + + output, output_w = MoriEPCombine.apply(expert_out, recv_w, recv_idx, state) + + mock_op.combine.assert_called_once() + mock_op.reset.assert_called_once() + assert output.shape == (num_input, hidden_dim) + + @patch("torch.cuda.synchronize") + def test_combine_backward_dispatches_gradients(self, mock_sync): + """Combine backward should dispatch gradients to expert ranks.""" + from transformer_engine.pytorch._lite.mori_ep import MoriEPCombine + + ep, mock_op = self._make_ep_and_mock() + 
hidden_dim, topk = 64, 2 + num_input = 4 + num_expert_tokens = 5 # tokens this rank received from dispatch + + mock_op.combine.return_value = ( + torch.randn(16, hidden_dim), + torch.rand(16, topk), + ) + # Mock the backward dispatch call -- in a real scenario, the backward + # dispatch uses the same routing as forward, so the number of received + # gradient tokens matches the forward expert_output count. + mock_op.dispatch.return_value = ( + torch.randn(num_expert_tokens, hidden_dim), + torch.rand(num_expert_tokens, topk), + None, + torch.randint(0, 8, (num_expert_tokens, topk), dtype=torch.int32), + torch.tensor([num_expert_tokens], dtype=torch.int32), + ) + + state = ep.new_cycle() + state.fwd_num_input = num_input + state.fwd_weights = torch.rand(num_input, topk) + state.fwd_indices = torch.randint(0, 8, (num_input, topk), dtype=torch.int32) + + expert_out = torch.randn(num_expert_tokens, hidden_dim, requires_grad=True) + recv_w = torch.rand(num_expert_tokens, topk) + recv_idx = torch.randint(0, 8, (num_expert_tokens, topk), dtype=torch.int32) + + output, output_w = MoriEPCombine.apply(expert_out, recv_w, recv_idx, state) + + # Reset mock counters (combine was called in forward, reset was called) + mock_op.reset.reset_mock() + + loss = output.sum() + loss.backward() + + # Verify dispatch was called in backward for gradient communication + assert mock_op.dispatch.call_count == 1 + assert expert_out.grad is not None + assert expert_out.grad.shape == (num_expert_tokens, hidden_dim) + # Verify backward state was saved + assert state.bwd_num_recv == num_expert_tokens + + +class TestMoriEPFullAutogradCycle: + """Test a full forward+backward cycle through dispatch → expert → combine.""" + + def _make_ep_and_mock(self): + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + mock_kt = MagicMock() + mock_kt.IntraNode = "IntraNode" + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + + mock_op = MagicMock() + mock_mori.ops.EpDispatchCombineOp.return_value = mock_op + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel( + rank=0, world_size=2, hidden_dim=64, + num_experts_per_rank=4, num_experts_per_token=2, + max_num_inp_token_per_rank=16, + ) + + return ep, mock_op + + @patch("torch.cuda.synchronize") + def test_full_forward_backward_cycle(self, mock_sync): + """Test complete dispatch → expert → combine → backward flow.""" + from transformer_engine.pytorch._lite.mori_ep import ( + MoriEPDispatch, MoriEPCombine, + ) + + ep, mock_op = self._make_ep_and_mock() + hidden_dim, topk = 64, 2 + num_input = 4 + num_recv = 6 + + # Forward dispatch mock + fwd_recv_tokens = torch.randn(num_recv, hidden_dim) + fwd_recv_weights = torch.rand(num_recv, topk) + fwd_recv_indices = torch.randint(0, 8, (num_recv, topk), dtype=torch.int32) + + mock_op.dispatch.return_value = ( + fwd_recv_tokens, + fwd_recv_weights, + None, + fwd_recv_indices, + torch.tensor([num_recv], dtype=torch.int32), + ) + + # Forward combine mock + mock_op.combine.return_value = ( + torch.randn(16, hidden_dim), + torch.rand(16, topk), + ) + + # --- Forward --- + state = ep.new_cycle() + input_t = torch.randn(num_input, hidden_dim, requires_grad=True) + weights = torch.rand(num_input, topk) + indices = torch.randint(0, 8, (num_input, topk), dtype=torch.int32) + + # Step 1: Dispatch + recv, recv_w, recv_idx = 
MoriEPDispatch.apply( + input_t, weights, indices, state, + ) + + # Step 2: Expert computation (simple linear, differentiable) + expert_weight = torch.randn(hidden_dim, hidden_dim, requires_grad=True) + expert_out = recv @ expert_weight + + # Step 3: Combine + output, output_w = MoriEPCombine.apply( + expert_out, recv_w, recv_idx, state, + ) + + # Verify forward calls + assert mock_op.dispatch.call_count == 1 # forward dispatch + assert mock_op.combine.call_count == 1 # forward combine + assert mock_op.reset.call_count == 1 # forward reset + + # --- Backward --- + # Set up backward mocks (dispatch in combine.bwd, combine in dispatch.bwd) + bwd_recv = torch.randn(num_recv, hidden_dim) + mock_op.dispatch.return_value = ( + bwd_recv, + torch.rand(num_recv, topk), + None, + torch.randint(0, 8, (num_recv, topk), dtype=torch.int32), + torch.tensor([num_recv], dtype=torch.int32), + ) + mock_op.combine.return_value = ( + torch.randn(16, hidden_dim), + torch.rand(16, topk), + ) + + loss = output.sum() + loss.backward() + + # Verify backward calls + assert mock_op.dispatch.call_count == 2 # fwd dispatch + bwd dispatch (in combine.bwd) + assert mock_op.combine.call_count == 2 # fwd combine + bwd combine (in dispatch.bwd) + assert mock_op.reset.call_count == 2 # fwd reset + bwd reset + + # Verify gradients exist + assert input_t.grad is not None + assert input_t.grad.shape == (num_input, hidden_dim) + assert expert_weight.grad is not None + assert expert_weight.grad.shape == (hidden_dim, hidden_dim) + + +# --------------------------------------------------------------------------- +# Multi-GPU Integration Tests -- require MORI + AMD GPUs +# --------------------------------------------------------------------------- + +def _run_ep_worker(rank, world_size, hidden_dim, num_experts_per_rank, + num_experts_per_token, max_tokens, results_dict): + """Worker function for multi-GPU dispatch/combine test.""" + import torch + import torch.distributed as dist + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + os.environ.setdefault("MORI_SHMEM_HEAP_SIZE", "1G") + + torch.cuda.set_device(rank) + device = torch.device("cuda", rank) + + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + rank=rank, + world_size=world_size, + device_id=device, + ) + # Register world group for MORI + world_group = dist.group.WORLD + torch._C._distributed_c10d._register_process_group("default", world_group) + + from transformer_engine.pytorch._lite.mori_ep import ( + init_mori_ep, finalize_mori_ep, MoriExpertParallel, + ) + + try: + init_mori_ep() + + ep = MoriExpertParallel( + rank=rank, + world_size=world_size, + hidden_dim=hidden_dim, + num_experts_per_rank=num_experts_per_rank, + num_experts_per_token=num_experts_per_token, + max_num_inp_token_per_rank=max_tokens, + dtype=torch.bfloat16, + kernel_type="intra_node", + ) + + # Generate test data + torch.manual_seed(42 + rank) + num_tokens = max_tokens // 2 # use half capacity + total_experts = num_experts_per_rank * world_size + input_tokens = torch.randn( + num_tokens, hidden_dim, dtype=torch.bfloat16, device=device, + ) + weights = torch.rand( + num_tokens, num_experts_per_token, dtype=torch.float32, device=device, + ) + indices = torch.stack([ + torch.randperm(total_experts, device=device)[:num_experts_per_token] + for _ in range(num_tokens) + ]).to(torch.int32) + + # Dispatch + recv_tokens, recv_weights, recv_indices, num_recv = ep.dispatch( + input_tokens, weights, indices, + ) + + # Simple expert function: identity (pass through) + 
expert_output = recv_tokens[:num_recv].clone().to(torch.bfloat16) + + # Combine + output, output_weights = ep.combine( + expert_output, recv_weights, recv_indices, + ) + ep.reset() + + # Verify basic properties + results_dict[rank] = { + "num_recv": num_recv, + "output_shape": tuple(output.shape), + "num_input_tokens": num_tokens, + "success": True, + } + + except Exception as e: + results_dict[rank] = { + "success": False, + "error": str(e), + } + + finally: + finalize_mori_ep() + dist.destroy_process_group() + + +# --------------------------------------------------------------------------- +# Standard MoE layout tests +# --------------------------------------------------------------------------- + +class TestStdMoEDispatchCombine: + """Test standard MoE per-expert layout dispatch/combine with mocked MORI.""" + + def _make_ep_and_mock(self): + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + mock_kt = MagicMock() + mock_kt.IntraNode = "IntraNode" + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + + mock_op = MagicMock() + mock_mori.ops.EpDispatchCombineOp.return_value = mock_op + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel( + rank=0, world_size=2, hidden_dim=64, + num_experts_per_rank=4, num_experts_per_token=2, + max_num_inp_token_per_rank=16, + kernel_type="intra_node", + ) + + return ep, mock_op + + @patch("torch.cuda.synchronize") + def test_dispatch_standard_moe_returns_per_expert_layout(self, mock_sync): + ep, mock_op = self._make_ep_and_mock() + hidden_dim = 64 + num_experts = 4 + max_tpe = 32 # max_tokens_per_expert = world_size * max_tokens + + mock_op.dispatch_standard_moe.return_value = ( + torch.randn(num_experts, max_tpe, hidden_dim), # packed_tokens + torch.tensor([5, 3, 7, 2], dtype=torch.int32), # recv_count + torch.zeros(num_experts, max_tpe, dtype=torch.int32), # src_info + torch.empty(0), # layout_range + ) + + input_t = torch.randn(8, hidden_dim) + weights = torch.rand(8, 2) + indices = torch.randint(0, 8, (8, 2), dtype=torch.int32) + + packed, recv_count, src_info = ep.dispatch_standard_moe( + input_t, weights, indices, + ) + + mock_op.dispatch_standard_moe.assert_called_once() + assert packed.shape == (num_experts, max_tpe, hidden_dim) + assert recv_count.shape == (num_experts,) + assert recv_count.tolist() == [5, 3, 7, 2] + + @patch("torch.cuda.synchronize") + def test_combine_standard_moe(self, mock_sync): + ep, mock_op = self._make_ep_and_mock() + hidden_dim = 64 + num_experts = 4 + max_tpe = 32 + + mock_op.combine_standard_moe.return_value = ( + torch.randn(16, hidden_dim), # output + None, + ) + + expert_out = torch.randn(num_experts, max_tpe, hidden_dim) + weights = torch.rand(8, 2) + indices = torch.randint(0, 8, (8, 2), dtype=torch.int32) + + output, output_w = ep.combine_standard_moe( + expert_out, weights, indices, + ) + + mock_op.combine_standard_moe.assert_called_once() + assert output.shape[1] == hidden_dim + assert output_w is None + + @patch("torch.cuda.synchronize") + def test_convert_dispatch_to_standard(self, mock_sync): + ep, mock_op = self._make_ep_and_mock() + hidden_dim = 64 + num_experts = 4 + max_tpe = 32 + + mock_op.convert_dispatch_output.return_value = ( + torch.randn(num_experts, max_tpe, hidden_dim), + torch.tensor([3, 4, 2, 1], dtype=torch.int32), + torch.zeros(num_experts, max_tpe, 
dtype=torch.int32), + torch.empty(0), + ) + + dispatch_tokens = torch.randn(20, hidden_dim) + dispatch_indices = torch.randint(0, 8, (20, 2), dtype=torch.int32) + + packed, recv_count, src_info = ep.convert_dispatch_to_standard( + dispatch_tokens, dispatch_indices, + ) + + mock_op.convert_dispatch_output.assert_called_once() + assert packed.shape == (num_experts, max_tpe, hidden_dim) + + @patch("torch.cuda.synchronize") + def test_convert_standard_to_combine_input(self, mock_sync): + ep, mock_op = self._make_ep_and_mock() + hidden_dim = 64 + num_experts = 4 + max_tpe = 32 + max_recv = 20 + + mock_op.convert_combine_input.return_value = torch.randn(max_recv, hidden_dim) + + packed = torch.randn(num_experts, max_tpe, hidden_dim) + src_info = torch.zeros(num_experts, max_tpe, dtype=torch.int32) + + flat = ep.convert_standard_to_combine_input(packed, src_info) + + mock_op.convert_combine_input.assert_called_once() + assert flat.shape == (max_recv, hidden_dim) + + +class TestStdMoECycleState: + """Test standard MoE cycle state creation.""" + + def _make_ep(self): + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + mock_kt = MagicMock() + mock_kt.IntraNode = "IntraNode" + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + mock_mori.ops.EpDispatchCombineOp.return_value = MagicMock() + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel( + rank=0, world_size=2, hidden_dim=64, + num_experts_per_rank=4, num_experts_per_token=2, + max_num_inp_token_per_rank=16, + ) + return ep + + def test_new_std_moe_cycle(self): + from transformer_engine.pytorch._lite.mori_ep import _MoriStdMoECycleState + ep = self._make_ep() + state = ep.new_std_moe_cycle() + assert isinstance(state, _MoriStdMoECycleState) + assert state.ep is ep + assert state.fwd_recv_count is None + assert state.fwd_src_info is None + + +class TestStdMoEAutograd: + """Test standard MoE autograd dispatch/combine cycle.""" + + def _make_ep_and_mock(self): + import transformer_engine.pytorch._lite.mori_ep as mod + + mock_mori = MagicMock() + mock_kt = MagicMock() + mock_kt.IntraNode = "IntraNode" + mock_mori.ops.EpDispatchCombineKernelType = mock_kt + mock_mori.ops.EpDispatchCombineConfig = MagicMock() + + mock_op = MagicMock() + mock_mori.ops.EpDispatchCombineOp.return_value = mock_op + + with mock.patch.object(mod, "_mori_available", True), \ + mock.patch.object(mod, "_mori", mock_mori), \ + mock.patch.object(mod, "_mori_shmem_initialized", True): + ep = mod.MoriExpertParallel( + rank=0, world_size=2, hidden_dim=64, + num_experts_per_rank=4, num_experts_per_token=2, + max_num_inp_token_per_rank=16, + ) + + return ep, mock_op + + @patch("torch.cuda.synchronize") + def test_full_std_moe_forward_backward(self, mock_sync): + """Test dispatch_standard_moe → expert → combine_standard_moe → backward.""" + from transformer_engine.pytorch._lite.mori_ep import ( + MoriEPDispatchStdMoE, MoriEPCombineStdMoE, + ) + + ep, mock_op = self._make_ep_and_mock() + hidden_dim = 64 + topk = 2 + num_input = 4 + num_experts = 4 + max_tpe = 32 + + # Forward dispatch mock + mock_op.dispatch_standard_moe.return_value = ( + torch.randn(num_experts, max_tpe, hidden_dim), + torch.tensor([3, 2, 4, 1], dtype=torch.int32), + torch.zeros(num_experts, max_tpe, dtype=torch.int32), + torch.empty(0), + ) + + # Forward combine mock + 
mock_op.combine_standard_moe.return_value = ( + torch.randn(16, hidden_dim), + None, + ) + + # --- Forward --- + state = ep.new_std_moe_cycle() + input_t = torch.randn(num_input, hidden_dim, requires_grad=True) + weights = torch.rand(num_input, topk) + indices = torch.randint(0, 8, (num_input, topk), dtype=torch.int32) + + packed, recv_count, src_info = MoriEPDispatchStdMoE.apply( + input_t, weights, indices, state, + ) + + assert packed.shape == (num_experts, max_tpe, hidden_dim) + assert state.fwd_recv_count is not None + + # Expert computation: simple per-expert linear + expert_weight = torch.randn(hidden_dim, hidden_dim, requires_grad=True) + expert_out = torch.einsum("eth,hd->etd", packed.float(), expert_weight.float()) + expert_out = expert_out.to(packed.dtype) + + output, _ = MoriEPCombineStdMoE.apply( + expert_out, weights, indices, state, + ) + + # Verify forward calls + assert mock_op.dispatch_standard_moe.call_count == 1 + assert mock_op.combine_standard_moe.call_count == 1 + assert mock_op.reset.call_count == 1 + + # --- Backward --- + # Backward mocks: combine.bwd calls dispatch_standard_moe, + # dispatch.bwd calls combine_standard_moe + mock_op.dispatch_standard_moe.return_value = ( + torch.randn(num_experts, max_tpe, hidden_dim), + torch.tensor([3, 2, 4, 1], dtype=torch.int32), + torch.zeros(num_experts, max_tpe, dtype=torch.int32), + torch.empty(0), + ) + mock_op.combine_standard_moe.return_value = ( + torch.randn(16, hidden_dim), + None, + ) + + loss = output.sum() + loss.backward() + + # Verify backward calls + assert mock_op.dispatch_standard_moe.call_count == 2 # fwd + bwd + assert mock_op.combine_standard_moe.call_count == 2 # fwd + bwd + assert mock_op.reset.call_count == 2 # fwd + bwd + + assert input_t.grad is not None + assert input_t.grad.shape == (num_input, hidden_dim) + assert expert_weight.grad is not None + + @patch("torch.cuda.synchronize") + def test_dispatch_std_moe_saves_state(self, mock_sync): + from transformer_engine.pytorch._lite.mori_ep import MoriEPDispatchStdMoE + + ep, mock_op = self._make_ep_and_mock() + hidden_dim, topk, num_experts, max_tpe = 64, 2, 4, 32 + + mock_op.dispatch_standard_moe.return_value = ( + torch.randn(num_experts, max_tpe, hidden_dim), + torch.tensor([2, 3, 1, 4], dtype=torch.int32), + torch.ones(num_experts, max_tpe, dtype=torch.int32), + torch.empty(0), + ) + + state = ep.new_std_moe_cycle() + input_t = torch.randn(6, hidden_dim, requires_grad=True) + weights = torch.rand(6, topk) + indices = torch.randint(0, 8, (6, topk), dtype=torch.int32) + + MoriEPDispatchStdMoE.apply(input_t, weights, indices, state) + + assert state.fwd_num_input == 6 + assert state.fwd_weights is not None + assert state.fwd_indices is not None + assert state.fwd_recv_count is not None + assert state.fwd_src_info is not None + + +@skip_no_mori +@skip_insufficient_gpus +class TestMoriEPIntegration: + """Integration tests requiring MORI and multiple GPUs.""" + + def test_dispatch_combine_basic(self): + """Test basic dispatch/combine round-trip across GPUs.""" + import torch.multiprocessing as mp + + world_size = min(_gpu_count(), 8) + hidden_dim = 256 + num_experts_per_rank = 4 + num_experts_per_token = 2 + max_tokens = 64 + + manager = mp.Manager() + results = manager.dict() + + mp.spawn( + _run_ep_worker, + args=(world_size, hidden_dim, num_experts_per_rank, + num_experts_per_token, max_tokens, results), + nprocs=world_size, + join=True, + ) + + for rank in range(world_size): + assert rank in results, f"Rank {rank} did not report results" + r = 
results[rank] + assert r["success"], f"Rank {rank} failed: {r.get('error', 'unknown')}" + assert r["num_recv"] >= 0, f"Rank {rank}: invalid num_recv" + assert r["output_shape"][1] == hidden_dim + + +def _run_ep_roundtrip_verify(rank, world_size, hidden_dim, num_experts_per_rank, + num_experts_per_token, max_tokens, results_dict): + """Worker that verifies dispatch/combine preserves token data.""" + import torch + import torch.distributed as dist + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29501" + os.environ.setdefault("MORI_SHMEM_HEAP_SIZE", "1G") + + torch.cuda.set_device(rank) + device = torch.device("cuda", rank) + + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + rank=rank, + world_size=world_size, + device_id=device, + ) + world_group = dist.group.WORLD + torch._C._distributed_c10d._register_process_group("default", world_group) + + from transformer_engine.pytorch._lite.mori_ep import ( + init_mori_ep, finalize_mori_ep, MoriExpertParallel, + ) + + try: + init_mori_ep() + + ep = MoriExpertParallel( + rank=rank, + world_size=world_size, + hidden_dim=hidden_dim, + num_experts_per_rank=num_experts_per_rank, + num_experts_per_token=num_experts_per_token, + max_num_inp_token_per_rank=max_tokens, + dtype=torch.bfloat16, + kernel_type="intra_node", + ) + + # Fixed seed for reproducibility across runs + torch.manual_seed(123 + rank) + num_tokens = 16 + total_experts = num_experts_per_rank * world_size + + input_tokens = torch.randn( + num_tokens, hidden_dim, dtype=torch.bfloat16, device=device, + ) + weights = torch.ones( + num_tokens, num_experts_per_token, dtype=torch.float32, device=device, + ) + indices = torch.stack([ + torch.randperm(total_experts, device=device)[:num_experts_per_token] + for _ in range(num_tokens) + ]).to(torch.int32) + + # Dispatch + recv_tokens, recv_weights, recv_indices, num_recv = ep.dispatch( + input_tokens, weights, indices, + ) + + # Identity expert: just pass tokens through + expert_output = recv_tokens[:num_recv].clone().to(torch.bfloat16) + + # Combine + output, _ = ep.combine(expert_output, recv_weights, recv_indices) + + # After identity expert with weights=1.0, each token's output should + # equal input * (number of unique PEs the token was sent to). + # This matches the MORI test pattern. 
+ for i in range(num_tokens): + pes = set() + for idx in indices[i].cpu().tolist(): + pes.add(idx // num_experts_per_rank) + unique_pes = len(pes) + + expected = (input_tokens[i].float() * unique_pes).to(torch.bfloat16) + got = output[i] + match = torch.allclose(got.float(), expected.float(), atol=1e-2, rtol=1e-2) + if not match: + results_dict[rank] = { + "success": False, + "error": ( + f"Token {i}: expected scale {unique_pes}, " + f"max_diff={torch.abs(got.float() - expected.float()).max().item()}" + ), + } + ep.reset() + return + + ep.reset() + results_dict[rank] = {"success": True, "num_tokens_verified": num_tokens} + + except Exception as e: + results_dict[rank] = {"success": False, "error": str(e)} + + finally: + finalize_mori_ep() + dist.destroy_process_group() + + +@skip_no_mori +@skip_insufficient_gpus +class TestMoriEPRoundtrip: + """Verify data integrity through dispatch/combine round-trip.""" + + def test_identity_expert_roundtrip(self): + """With identity expert and uniform weights, output = input * num_unique_pes.""" + import torch.multiprocessing as mp + + world_size = min(_gpu_count(), 8) + hidden_dim = 128 + num_experts_per_rank = 4 + num_experts_per_token = 2 + max_tokens = 64 + + manager = mp.Manager() + results = manager.dict() + + mp.spawn( + _run_ep_roundtrip_verify, + args=(world_size, hidden_dim, num_experts_per_rank, + num_experts_per_token, max_tokens, results), + nprocs=world_size, + join=True, + ) + + for rank in range(world_size): + assert rank in results, f"Rank {rank} did not report results" + r = results[rank] + assert r["success"], f"Rank {rank} failed: {r.get('error', 'unknown')}" diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index b2b175892..25d4646b9 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -86,6 +86,19 @@ try: __version__ = str(metadata.version("transformer_engine")) except metadata.PackageNotFoundError: - _te_core_installed, _, __version__ = transformer_engine.common.get_te_core_package_info(True) - if not _te_core_installed: - raise + if transformer_engine.common._nvte_lite_mode: + # Lite-only wheels are installed under the `tealite` distribution name, + # so `metadata.version("transformer_engine")` raises. Prefer that, then + # fall back to reading VERSION.txt via build_tools, then a sentinel. + try: + __version__ = str(metadata.version("tealite")) + except metadata.PackageNotFoundError: + try: + from transformer_engine.build_tools.te_version import te_version + __version__ = te_version() + "+lite" + except Exception: + __version__ = "0.0.0+lite" + else: + _te_core_installed, _, __version__ = transformer_engine.common.get_te_core_package_info(True) + if not _te_core_installed: + raise diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 9dbf998e5..df3971f82 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -158,6 +158,15 @@ def load_framework_extension(framework: str) -> None: # Supported frameworks. assert framework in ("jax", "torch"), f"Unsupported framework {framework}" + # Lite mode: use pure-Python replacement instead of compiled C++ extension. 
+ if framework == "torch" and os.environ.get("NVTE_LITE", "0") == "1": + from transformer_engine.pytorch._lite import \ + __name__ as _lite_name # noqa: F401 -- trigger init + import transformer_engine.pytorch._lite as _lite_module + module_name = f"transformer_engine_{framework}" + sys.modules[module_name] = _lite_module + return + # Name of the framework extension library. module_name = f"transformer_engine_{framework}" @@ -402,6 +411,9 @@ def _load_cuda_library(lib_name: str): @functools.cache def is_fp8_fnuz(): + if _TE_LIB_CTYPES is None: + # Lite mode: assume FNUZ based on ROCm convention + return True if te_rocm_build: _TE_LIB_CTYPES.nvte_uses_fp8_fnuz.restype = ctypes.c_bool return _TE_LIB_CTYPES.nvte_uses_fp8_fnuz() @@ -413,7 +425,17 @@ def _load_core_library(): return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_GLOBAL) -if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): +# Detect lite mode: explicit env var or LITE_BUILD marker file (lite-only wheel) +_lite_marker = Path(__file__).parent.parent / "LITE_BUILD" +_nvte_lite_mode = os.environ.get("NVTE_LITE", "0") == "1" or _lite_marker.exists() +if _nvte_lite_mode: + os.environ["NVTE_LITE"] = "1" + +if _nvte_lite_mode: + # In lite mode, skip loading compiled C++ libraries entirely. + _TE_LIB_CTYPES = None + te_rocm_build = True # Lite mode targets ROCm +elif "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): sanity_checks_for_pypi_installation() if not te_rocm_build: try: diff --git a/transformer_engine/common/triton/fused_router.py b/transformer_engine/common/triton/fused_router.py new file mode 100644 index 000000000..eafbca764 --- /dev/null +++ b/transformer_engine/common/triton/fused_router.py @@ -0,0 +1,338 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Triton JIT kernels for fused MoE router operations. + +Fuses score-function (sigmoid/softmax) + top-k + post-processing into a +single kernel launch, matching the C++ ``fused_topk_with_score_function`` +kernel behaviour. +""" + +import triton +import triton.language as tl + + +# --------------------------------------------------------------------------- # +# Helpers +# --------------------------------------------------------------------------- # + +@triton.jit +def _iterative_topk( + scores, + valid, + offs, + selected, + topk_vals, + TOPK: tl.constexpr, + BLOCK_E: tl.constexpr, +): + """Select top-TOPK from *scores* at *valid* & unselected positions. + + Updates *selected* (0/1 int32 mask) and *topk_vals* in-place and returns + the updated pair. 
+ """ + for _k in tl.static_range(TOPK): + avail = valid & (selected == 0) + masked = tl.where(avail, scores, float("-inf")) + best = tl.max(masked, axis=0) + is_best = (masked == best) & avail + candidate = tl.where(is_best, offs, BLOCK_E) + winner = tl.min(candidate, axis=0) + newly = (offs == winner) + selected = tl.where(newly, 1, selected) + topk_vals = tl.where(newly, scores, topk_vals) + return selected, topk_vals + + +# --------------------------------------------------------------------------- # +# Forward kernel +# --------------------------------------------------------------------------- # + +@triton.jit +def _fused_topk_score_fwd_kernel( + logits_ptr, + probs_ptr, + routing_map_ptr, + intermediate_ptr, + bias_ptr, + num_tokens, + scaling_factor, + # compile-time constants ------------------------------------------------ + NUM_EXPERTS: tl.constexpr, + TOPK: tl.constexpr, + SCORE_FN: tl.constexpr, # 0 = sigmoid, 1 = softmax + USE_PRE_SOFTMAX: tl.constexpr, + HAS_BIAS: tl.constexpr, + USE_GROUP_TOPK: tl.constexpr, + NUM_GROUPS: tl.constexpr, # ignored when USE_GROUP_TOPK == False + GROUP_TOPK: tl.constexpr, # ignored when USE_GROUP_TOPK == False + GROUP_SIZE: tl.constexpr, # NUM_EXPERTS // NUM_GROUPS + BLOCK_E: tl.constexpr, # next_power_of_2(NUM_EXPERTS) + BLOCK_G: tl.constexpr, # next_power_of_2(NUM_GROUPS) +): + """One program-instance per token.""" + pid = tl.program_id(0) + if pid >= num_tokens: + return + + offs = tl.arange(0, BLOCK_E) + valid = offs < NUM_EXPERTS + base = pid * NUM_EXPERTS + + # -- load logits -------------------------------------------------------- + logits = tl.load(logits_ptr + base + offs, mask=valid, other=0.0).to(tl.float32) + + # -- Step 1: score function --------------------------------------------- + if SCORE_FN == 0: # sigmoid + scores = tl.sigmoid(logits) + intermediate = scores # saved for backward + if HAS_BIAS: + bias = tl.load(bias_ptr + offs, mask=valid, other=0.0).to(tl.float32) + routing_scores = scores + bias + else: + routing_scores = scores + else: # softmax + if USE_PRE_SOFTMAX: + max_val = tl.max(tl.where(valid, logits, float("-inf")), axis=0) + exp_l = tl.exp(logits - max_val) + exp_l = tl.where(valid, exp_l, 0.0) + sum_exp = tl.sum(exp_l, axis=0) + scores = exp_l / sum_exp + intermediate = scores + routing_scores = scores + else: + routing_scores = logits + scores = logits + intermediate = tl.zeros([BLOCK_E], dtype=tl.float32) + + # -- Step 2: top-k (with optional group selection) ---------------------- + if USE_GROUP_TOPK: + # 2a. Score each group: sum of top-(TOPK//GROUP_TOPK) within group + g_offs = tl.arange(0, BLOCK_G) + g_valid = g_offs < NUM_GROUPS + group_scores = tl.zeros([BLOCK_G], dtype=tl.float32) + + for g in tl.static_range(NUM_GROUPS): + in_group = ((offs // GROUP_SIZE) == g) & valid + g_sel = tl.zeros([BLOCK_E], dtype=tl.int32) + g_sum = tl.zeros([1], dtype=tl.float32) # accumulator + for _ik in tl.static_range(TOPK // GROUP_TOPK): + g_avail = in_group & (g_sel == 0) + g_masked = tl.where(g_avail, routing_scores, float("-inf")) + g_best = tl.max(g_masked, axis=0) + g_is_best = (g_masked == g_best) & g_avail + g_cand = tl.where(g_is_best, offs, BLOCK_E) + g_win = tl.min(g_cand, axis=0) + g_sel = tl.where(offs == g_win, 1, g_sel) + g_sum += g_best + group_scores = tl.where(g_offs == g, g_sum, group_scores) + + # 2b. 
Select top GROUP_TOPK groups + selected_groups = tl.zeros([BLOCK_G], dtype=tl.int32) + for _gk in tl.static_range(GROUP_TOPK): + ga = g_valid & (selected_groups == 0) + gm = tl.where(ga, group_scores, float("-inf")) + gb = tl.max(gm, axis=0) + gib = (gm == gb) & ga + gc = tl.where(gib, g_offs, BLOCK_G) + gw = tl.min(gc, axis=0) + selected_groups = tl.where(g_offs == gw, 1, selected_groups) + + # 2c. Build expert mask from selected groups + expert_mask = tl.zeros([BLOCK_E], dtype=tl.int32) + for g in tl.static_range(NUM_GROUPS): + g_is_sel = tl.sum(tl.where(g_offs == g, selected_groups, 0), axis=0) + in_g = ((offs // GROUP_SIZE) == g) & valid + expert_mask = tl.where(in_g & (g_is_sel > 0), 1, expert_mask) + + # 2d. Top-k over experts in selected groups + topk_scores_for_sel = tl.where(expert_mask != 0, routing_scores, float("-inf")) + else: + topk_scores_for_sel = routing_scores + + selected = tl.zeros([BLOCK_E], dtype=tl.int32) + topk_vals = tl.zeros([BLOCK_E], dtype=tl.float32) + selected, topk_vals = _iterative_topk( + topk_scores_for_sel, valid, offs, selected, topk_vals, + TOPK=TOPK, BLOCK_E=BLOCK_E, + ) + sel_bool = selected != 0 + + # -- Step 3: post-processing -------------------------------------------- + if SCORE_FN == 0: # sigmoid + if HAS_BIAS: + topk_vals = tl.where(sel_bool, topk_vals - bias, topk_vals) + if TOPK > 1: + s = tl.sum(tl.where(sel_bool, topk_vals, 0.0), axis=0) + 1e-9 + topk_vals = tl.where(sel_bool, topk_vals / s, 0.0) + else: # softmax + if not USE_PRE_SOFTMAX: + sel_logits = tl.where(sel_bool, topk_vals, float("-inf")) + mx = tl.max(sel_logits, axis=0) + e = tl.exp(sel_logits - mx) + e = tl.where(sel_bool, e, 0.0) + topk_vals = e / tl.sum(e, axis=0) + intermediate = tl.where(sel_bool, topk_vals, intermediate) + + # scaling + topk_vals = tl.where(sel_bool, topk_vals * scaling_factor, 0.0) + + # -- Step 4: store outputs ---------------------------------------------- + tl.store(probs_ptr + base + offs, + tl.where(sel_bool & valid, topk_vals, 0.0).to(logits.dtype), + mask=valid) + tl.store(routing_map_ptr + base + offs, + selected.to(tl.int8), + mask=valid) + tl.store(intermediate_ptr + base + offs, + tl.where(valid, intermediate, 0.0).to(logits.dtype), + mask=valid) + + +# --------------------------------------------------------------------------- # +# Backward kernel (unchanged — group topk doesn't affect backward) +# --------------------------------------------------------------------------- # + +@triton.jit +def _fused_topk_score_bwd_kernel( + routing_map_ptr, + intermediate_ptr, + grad_probs_ptr, + grad_logits_ptr, + num_tokens, + scaling_factor, + # compile-time constants ------------------------------------------------ + NUM_EXPERTS: tl.constexpr, + TOPK: tl.constexpr, + SCORE_FN: tl.constexpr, + USE_PRE_SOFTMAX: tl.constexpr, + BLOCK_E: tl.constexpr, +): + """One program-instance per token.""" + pid = tl.program_id(0) + if pid >= num_tokens: + return + + offs = tl.arange(0, BLOCK_E) + valid = offs < NUM_EXPERTS + base = pid * NUM_EXPERTS + + grad = tl.load(grad_probs_ptr + base + offs, mask=valid, other=0.0).to(tl.float32) + sel_i8 = tl.load(routing_map_ptr + base + offs, mask=valid, other=0) + sel = sel_i8 != 0 + fwd_out = tl.load(intermediate_ptr + base + offs, mask=valid, other=0.0).to(tl.float32) + + # scale selected grads + grad = tl.where(sel, grad * scaling_factor, grad) + + if SCORE_FN == 0: # sigmoid + if TOPK > 1: + fwd_sel = tl.where(sel, fwd_out, 0.0) + s = tl.sum(fwd_sel, axis=0) + 1e-9 + og = tl.sum(tl.where(sel, fwd_sel * grad, 0.0), axis=0) 
+ grad = tl.where(sel, grad / s - og / (s * s), 0.0) + # mask unselected + grad = tl.where(sel, grad, 0.0) + # sigmoid derivative + grad = grad * fwd_out * (1.0 - fwd_out) + else: # softmax + if not USE_PRE_SOFTMAX: + og = tl.sum(tl.where(sel, fwd_out * grad, 0.0), axis=0) + grad = tl.where(sel, fwd_out * (grad - og), 0.0) + else: + grad = tl.where(sel, grad, 0.0) + dot = tl.sum(tl.where(valid, fwd_out * grad, 0.0), axis=0) + grad = fwd_out * (grad - dot) + + tl.store(grad_logits_ptr + base + offs, + tl.where(valid, grad, 0.0).to(fwd_out.dtype), + mask=valid) + + +# --------------------------------------------------------------------------- # +# Aux-loss score forward kernel +# --------------------------------------------------------------------------- # + +@triton.jit +def _fused_score_aux_loss_fwd_kernel( + logits_ptr, + scores_ptr, + routing_map_ptr, + intermediate_ptr, + num_tokens, + NUM_EXPERTS: tl.constexpr, + TOPK: tl.constexpr, + SCORE_FN: tl.constexpr, + BLOCK_E: tl.constexpr, +): + """Score computation for auxiliary loss — one program per token.""" + pid = tl.program_id(0) + if pid >= num_tokens: + return + + offs = tl.arange(0, BLOCK_E) + valid = offs < NUM_EXPERTS + base = pid * NUM_EXPERTS + + logits = tl.load(logits_ptr + base + offs, mask=valid, other=0.0).to(tl.float32) + + if SCORE_FN == 0: # sigmoid + scores = tl.sigmoid(logits) + else: # softmax + mx = tl.max(tl.where(valid, logits, float("-inf")), axis=0) + e = tl.exp(logits - mx) + e = tl.where(valid, e, 0.0) + scores = e / tl.sum(e, axis=0) + + # top-k for routing map + selected = tl.zeros([BLOCK_E], dtype=tl.int32) + topk_vals = tl.zeros([BLOCK_E], dtype=tl.float32) + selected, topk_vals = _iterative_topk( + scores, valid, offs, selected, topk_vals, + TOPK=TOPK, BLOCK_E=BLOCK_E, + ) + + tl.store(scores_ptr + base + offs, + tl.where(valid, scores, 0.0).to(logits.dtype), mask=valid) + tl.store(routing_map_ptr + base + offs, + selected.to(tl.int8), mask=valid) + tl.store(intermediate_ptr + base + offs, + tl.where(valid, scores, 0.0).to(logits.dtype), mask=valid) + + +# --------------------------------------------------------------------------- # +# Aux-loss score backward kernel +# --------------------------------------------------------------------------- # + +@triton.jit +def _fused_score_aux_loss_bwd_kernel( + intermediate_ptr, + grad_scores_ptr, + grad_logits_ptr, + num_tokens, + NUM_EXPERTS: tl.constexpr, + SCORE_FN: tl.constexpr, + BLOCK_E: tl.constexpr, +): + """Score backward for auxiliary loss — one program per token.""" + pid = tl.program_id(0) + if pid >= num_tokens: + return + + offs = tl.arange(0, BLOCK_E) + valid = offs < NUM_EXPERTS + base = pid * NUM_EXPERTS + + fwd_out = tl.load(intermediate_ptr + base + offs, mask=valid, other=0.0).to(tl.float32) + g = tl.load(grad_scores_ptr + base + offs, mask=valid, other=0.0).to(tl.float32) + + if SCORE_FN == 0: # sigmoid + grad = g * fwd_out * (1.0 - fwd_out) + else: # softmax + dot = tl.sum(tl.where(valid, fwd_out * g, 0.0), axis=0) + grad = fwd_out * (g - dot) + + tl.store(grad_logits_ptr + base + offs, + tl.where(valid, grad, 0.0).to(fwd_out.dtype), mask=valid) diff --git a/transformer_engine/pytorch/_lite/README.md b/transformer_engine/pytorch/_lite/README.md new file mode 100644 index 000000000..d4f2897ff --- /dev/null +++ b/transformer_engine/pytorch/_lite/README.md @@ -0,0 +1,585 @@ +# Transformer Engine Lite (`tealite`) + +``` + ) + ( + ) _ _ _ _ + ___|___ | |_ ___ __ _ | (_) |_ ___ + [_______] | __/ _ \/ _` || | | __/ _ \ + | || __/ (_| || | | || __/ + 
\__\___|\__,_||_|_|\__\___| + + TransformerEngine, by candlelight +``` + +A pure-Python drop-in replacement for the `transformer_engine_torch` C++ extension +module. It targets **ROCm / AMD GPUs** and eliminates the need for C++ compilation +by dispatching to [AITER](https://github.com/ROCm/aiter) kernels (CK / Triton), +standalone Triton kernels, or PyTorch-native ops. + +## Motivation + +The full Transformer Engine build compiles hundreds of C++ / CUDA / HIP sources +via CMake. This takes significant time, couples tightly to toolchain versions, and +makes rapid iteration difficult on ROCm. The lite module sidesteps all of that: + +- **No C++ compilation** -- build a wheel in seconds, not minutes. +- **No git submodules** -- AITER is an optional pip dependency (`pip install amd-aiter`). +- **Transparent activation** -- the module registers itself as + `transformer_engine_torch` via `sys.modules`, so all existing TE code works + without changes. + +## Building the `tealite` Wheel + +```bash +# From the repo root: +NVTE_LITE_ONLY=1 pip install . + +# Or build the wheel without installing: +NVTE_LITE_ONLY=1 python setup.py bdist_wheel +``` + +This produces a wheel named **`tealite`** containing only Python and Triton +sources. A `LITE_BUILD` marker file is embedded in the package so that lite mode +activates automatically at import time -- no environment variable needed. + +### Using lite mode with a full build + +If you have a full Transformer Engine build installed, you can activate lite mode +at runtime instead: + +```bash +NVTE_LITE=1 python train.py +``` + +## Runtime Backend Selection + +Most subsystems follow a tiered fallback: + +1. **AITER** (CK or Triton kernels from `amd-aiter`) -- best performance on MI300X +2. **Triton kernels** (bundled in `transformer_engine/pytorch/triton_kernels/`) +3. **PyTorch-native ops** -- always available, no extra dependencies + +GEMM backend can be forced via `NVTE_LITE_GEMM_BACKEND={pytorch,triton,ck}` +(default `pytorch`, which prefers `torch._scaled_mm` and falls back to AITER). + +## Environment Variables + +| Variable | Scope | Values | Default | Purpose | +|----------|-------|--------|---------|---------| +| `NVTE_LITE_ONLY` | build-time | `0` / `1` | `0` | When `1`, `setup.py` produces the `tealite` wheel (Python + Triton only, no C++ extensions) and embeds a `LITE_BUILD` marker so lite mode activates automatically at import. | +| `NVTE_LITE` | runtime | `0` / `1` | `0` | When `1`, forces lite dispatch at import time on a full build — `transformer_engine.pytorch` registers `_lite` as `transformer_engine_torch` in `sys.modules` instead of loading the C++ extension. Set automatically by `tealite` wheels via the `LITE_BUILD` marker. | +| `NVTE_LITE_GEMM_BACKEND` | runtime | `pytorch`, `triton`, `ck` | `pytorch` | Forces the GEMM backend in `_lite/gemm.py`. `pytorch` prefers `torch._scaled_mm` (hipBLASLt-backed on ROCm) for FP8 and falls back to AITER for FP8 shapes `_scaled_mm` can't serve (per-row scale on the reduction axis, block-scaled, unsupported dtype combos), with dequantize + `torch.matmul` as last resort. `triton` and `ck` route directly to AITER's Triton or CK kernels respectively. Read once at module import. | +| `NVTE_LITE_AMAX_FUSED` | runtime | `0` / `1` | `1` | When `1` (default), `fused_amax_and_scale_update_after_reduction` dispatches to a single Triton multi-tensor-apply kernel that mirrors `delayed_scaling.cu`'s `kernel_bulk`. Set to `0` to fall back to the per-group Python loop (e.g. 
for debugging or on a system where the Triton kernel fails to load). | +| `NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM` | runtime | `0` / `1` | `0` | Opt-in optimization for `LayerNormLinear` / `LayerNormMLP` fused modules: when set, the dgrad GEMM emits BF16 instead of FP8 if the only downstream consumer is the norm backward (which would dequantize anyway). Eliminates a BF16 → FP8 → BF16 round-trip; DelayedScaling amax bookkeeping is preserved via a standalone reduction (`amax_utils.update_amax_from_bf16`). Scoped to `Float8Quantizer` and `Float8CurrentScalingQuantizer`; MXFP8 is skipped because per-block scales can't be reconstructed from amax alone. | +| `NVTE_LITE_DIAG` | runtime | `0` / `1` | `0` | Enables one-shot diagnostic prints from `_lite/{gemm,norms,attention,quantize}.py` (and `module/base.py`) capturing shapes, dtypes, scale layout, scaled-mm rejection reasons, etc. Off by default; intended for triaging numerical or dispatch issues. | + +## Module Structure + +``` +_lite/ + __init__.py # Public API -- mirrors transformer_engine_torch exports + enums.py # Pure-Python re-declarations of C++ enum types + aiter_utils.py # Shared AITER availability detection (lru_cache) + amax_utils.py # BF16 amax-update helper for skip-FP8-dgrad path + + # Compute kernels + gemm.py # GEMM dispatch (torch._scaled_mm, AITER CK/Triton, PyTorch matmul) + grouped_gemm.py # Grouped GEMM for MoE (AITER Triton GMM, BF16/FP16; FP8 NYI) + attention.py # Fused attention (AITER CK, flash-attn stub, SDPA) + norms.py # LayerNorm / RMSNorm (Triton, PyTorch) + activations.py # Activation functions (AITER fused gated, PyTorch) + rope.py # Rotary position embeddings (AITER, PyTorch) + quantize.py # FP8/MXFP8/MXFP4 quantization (Triton cast, PyTorch) + softmax.py # Scaled/masked softmax variants (PyTorch) + dropout.py # Dropout (PyTorch) + transpose.py # FP8 transpose ops + + # Compound modules (pure-Python autograd Functions, lazy-loaded to avoid + # circular import with tex registration; see __init__.py) + fused_layernorm_linear.py # LayerNorm+Linear fused autograd Function + fused_layernorm_mlp.py # LayerNorm+MLP fused autograd Function + + # Structured / MOE + permutation.py # MOE token permutation (Triton sort, PyTorch gather) + router.py # MOE router ops -- topk, aux loss (PyTorch) + padding.py # Multi-row padding / unpadding + + # Distributed + comm.py # Comm-overlap stubs (not available; use torch.distributed) + context_parallel.py # THD <-> BSHD conversion helpers + mori_ep.py # Expert parallelism via MORI (dispatch/combine, autograd) + + # Optimizer + multi_tensor.py # Multi-tensor Adam, SGD, scale, L2 norm (PyTorch) + + # Misc + misc.py # Utility stubs +``` + +--- + +## Feature Status + +Each section below compares the lite module against the full C++ build. + +### GEMM + +| Feature | Lite | Full Build | +|---------|------|------------| +| BF16 / FP16 / FP32 GEMM | AITER Triton or `torch.matmul` | cuBLAS / hipBLAS | +| Per-tensor FP8 x FP8 | AITER CK (`gemm_a8w8`) | cuBLAS | +| Per-row FP8 x FP8 | AITER Triton (`gemm_a8w8_per_token_scale`) | N/A | +| Block-scaled FP8 x FP8 | AITER CK/Triton (`gemm_a8w8_blockscale`) | cuBLAS | +| Mixed precision (FP16 x FP8) | AITER CK (`gemm_a16w8`) | cuBLAS | +| MXFP4 x MXFP4 | AITER CK/Triton (`gemm_a4w4`) | cuBLAS | +| Grouped GEMM | AITER Triton GMM | cuBLAS grouped | +| Bias epilogue | Yes | Yes | +| GELU epilogue | Yes | Yes | +| Accumulation epilogue | Yes | Yes | +| Multi-stream cuBLAS | No | Yes | + +**Gaps:** No multi-stream execution. 
Performance depends on AITER kernel +maturity for each precision/shape combination. The default `pytorch` backend +routes FP8×FP8 through `torch._scaled_mm` (hipBLASLt-backed on ROCm), which +keeps FP8 memory bandwidth — only when `_scaled_mm` rejects the combo +(per-row scale on the reduction axis, certain block-scaled or unsupported +dtype combos) does the GEMM fall through to dequantize + `torch.matmul`, +which loses the FP8 bandwidth advantage. + +--- + +### Attention + +| Feature | Lite | Full Build | +|---------|------|------------| +| Dense attention (BSHD, SBHD) | AITER CK / SDPA | CK / cuDNN / flash-attn | +| Variable-length (THD) | AITER CK varlen | CK / cuDNN | +| Causal masking | Yes | Yes | +| Padding masking | Yes | Yes | +| Sliding window | AITER only | Yes | +| ALiBi / bias types | AITER only | Yes | +| GQA (grouped query) | Yes (head expansion) | Yes | +| Dropout | Yes | Yes | +| KV cache copy | Yes | Yes | +| cuDNN backend | No | Yes | +| flash-attn package | Stub (NotImplementedError) | Yes | +| Softmax stats (LSE) | AITER only; SDPA returns dummy | Yes | + +**Gaps:** No cuDNN attention backend. Flash-attn integration is stubbed out. +SDPA fallback does not return real softmax statistics (LSE), which can affect +numerics in some training configurations. Sliding window and bias types require +AITER -- no PyTorch fallback for those features. + +**FP8 attention flags (`fp8_dpa`, `fp8_mha`):** Not supported. AITER, PyTorch +SDPA, and the stubbed flash-attn path all operate on bf16/fp16 — there is no +FP8 attention kernel in lite. Setting either flag to `True` on the recipe +raises a clear `NotImplementedError` from `get_fused_attn_backend` pointing +back at `fp8_dpa=False / fp8_mha=False`. The default recipe (both flags +`False`, which is the default) continues to work — attention runs bf16 while +GEMMs use FP8. See `TestFP8AttentionFlags` for the contract. + +--- + +### Activations + +| Feature | Lite | Full Build | +|---------|------|------------| +| Non-gated (GeLU, ReLU, SiLU, QGeLU, SReLU) | PyTorch ops | CUDA kernels | +| Gated (GeGLU, SwiGLU, ReGLU, QGeGLU, SReGLU) | AITER fused or PyTorch | CUDA kernels | +| ClampedSwiGLU | PyTorch | CUDA kernel | +| All backward variants | Yes | Yes | +| Fused dbias + dact (non-gated) | Yes | Yes | +| Fused dbias + dact (gated) | No | Yes | +| Fused activation + FP8 quantization (gated) | AITER (`act_mul_and_fp8_group_quant`) | Yes (FULLY_FUSED, FUSED_AMAX_FP8, NVFP4) | +| Fused activation + FP8 quantization (non-gated) | No (quantize post-compute) | Yes | + +**Gaps:** Gated activations (SwiGLU, GeGLU, ReGLU) use AITER's +`act_mul_and_fp8_group_quant` to fuse activation + gate multiply + FP8 quantize +into a single kernel, eliminating the BF16 intermediate round-trip. This covers +both `Float8BlockQuantizer` (per-block scales, `group_size=block_len`) and +`Float8CurrentScalingQuantizer` (per-row scales, `group_size = output_hidden_dim` +so each row gets one scale). Non-gated activations (GeLU, ReLU, SiLU, etc.) +still run as separate ops with post-compute quantize. Gated dbias fusions are +missing. Activations outside the 6 with explicit paths (swiglu/geglu/reglu for +fused + gelu/silu/relu for basic) fall back to unfused PyTorch ops. 
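
+For reference, a minimal unfused sketch of what the fused gated path computes for
+SwiGLU with per-row (CurrentScaling-style) FP8 quantization -- plain PyTorch for
+illustration only, not the AITER kernel; the `[gate, up]` split convention and the
+448 (largest e4m3 value) constant are assumptions:
+
+```python
+import torch
+
+def swiglu_rowwise_fp8_reference(x: torch.Tensor):
+    """Unfused reference: SiLU(gate) * up, then one FP8 scale per output row."""
+    gate, up = x.chunk(2, dim=-1)                  # x: [M, 2*H], split assumed [gate, up]
+    y = torch.nn.functional.silu(gate) * up        # BF16 intermediate (kept in SRAM by the fused kernel)
+    amax = y.abs().amax(dim=-1, keepdim=True)      # one amax per row (group_size = H)
+    scale = 448.0 / amax.clamp(min=1e-12)          # 448 = largest representable e4m3 value
+    y_fp8 = (y * scale).to(torch.float8_e4m3fn)
+    return y_fp8, scale.reciprocal().squeeze(-1)   # FP8 data + per-row scale_inv of shape (M,)
+```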
+ +--- + +### LayerNorm / RMSNorm + +| Feature | Lite | Full Build | +|---------|------|------------| +| LayerNorm forward / backward | AITER Triton > TE Triton > PyTorch | CUDA tuned kernels | +| RMSNorm forward / backward | AITER Triton > TE Triton > PyTorch | CUDA tuned kernels | +| RMSNorm backward + add | Yes | Yes | +| Zero-centered gamma | Yes | Yes | +| Fused RMSNorm + FP8 quant (delayed) | AITER (`fused_rms_fp8_per_tensor_static_quant`) | CUDA kernel | +| Fused RMSNorm + FP8 quant (current, per-row) | AITER (`rmsnorm2d_fwd_with_dynamicquant`) | N/A | +| Fused RMSNorm + MXFP8 quant (block) | AITER (`fused_rms_fp8_group_quant`) — partial | CUDA kernel | +| Output quantization (generic) | Yes | Yes | +| cuDNN backend | No | Yes (optional) | +| Pre-tuned hidden sizes (28 sizes) | No (auto-tune) | Yes | +| Fused LayerNormLinear | Yes (pure-Python autograd Function) | Yes (CUDA) | +| Fused LayerNormMLP | Yes (pure-Python autograd Function) | Yes (CUDA) | +| SM margin (backward) | Ignored | Full per-stage control | + +**Gaps:** No cuDNN backend or pre-tuned CUDA kernels. SM margin control is +ignored in the backward pass. Distributed-parallelism status (TP/SP/FSDP2) +for the fused compound modules is documented in the +[Communication / Distributed](#communication--distributed) section. + +`LayerNormLinear` and `LayerNormMLP` are implemented as pure-Python +`torch.autograd.Function` subclasses in `_lite/fused_layernorm_linear.py` and +`_lite/fused_layernorm_mlp.py`. They reuse the AITER fused norm+quant path +when FP8 is active, then chain to the Linear/MLP GEMMs. This is not the same +thing as the full build's single CUDA kernel, but functionally covers the same +API surface — including `return_bias`, `return_layernorm_output`, and all +supported activations. + +The core norm operations are the strongest lite subsystem. AITER Triton kernels +are the primary backend with TE Triton and PyTorch fallbacks. The fused +RMSNorm+FP8 quantize path for CurrentScaling is a lite-only feature -- it fuses +norm and per-row quantize into a single kernel, which is not available in the +full C++ build (see [FP8 Training](#fp8-training) below). + +--- + +### RoPE (Rotary Position Embeddings) + +| Feature | Lite | Full Build | +|---------|------|------------| +| Basic RoPE (forward / backward) | AITER or PyTorch | CUDA kernel | +| QKV fused RoPE | Yes | Yes | +| Tensor formats (sbhd / bshd / thd) | None (single assumed layout) | All three | +| Interleaved cos/sin layout | No | Yes | +| Partial RoPE (`rotary_percent`) | No | Yes | +| `start_positions` (KV-cache inference) | No | Yes | +| `cu_seqlens` (THD ragged packing) | No | Yes | +| Context parallelism (`cp_size` / `cp_rank`) | Yes | Yes | +| Position interpolation (NTK-like) | No | Yes | +| `RotaryPositionEmbedding` module | No | Yes | + +**Gaps:** Only the simplest layout works -- apply rotation to a dense tensor with +a single assumed layout. Missing `start_positions` blocks KV-cache inference, +missing `cu_seqlens` blocks variable-length batching. Context parallelism is +supported for multi-GPU training with `cp_size` / `cp_rank` parameters. 
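
+For the dense single-layout case that lite supports, the rotation itself is the
+standard rotate-half formulation. A minimal PyTorch sketch (the `[s, b, h, d]`
+layout and function name are illustrative; this is not the `_lite/rope.py` code):
+
+```python
+import torch
+
+def apply_rope_reference(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    """Rotate-half RoPE on a dense [s, b, h, d] tensor (no start_positions / cu_seqlens)."""
+    cos, sin = freqs.cos(), freqs.sin()            # freqs broadcastable to t, e.g. [s, 1, 1, d]
+    d_half = t.shape[-1] // 2
+    t1, t2 = t[..., :d_half], t[..., d_half:]
+    rotated = torch.cat((-t2, t1), dim=-1)         # the "rotate half" trick
+    return (t * cos) + (rotated * sin)
+```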
+
+---
+
+### Quantization
+
+| Feature | Lite | Full Build |
+|---------|------|------------|
+| Per-tensor Float8 (e4m3 / e5m2) | Triton cast kernel | CUDA kernel |
+| Per-row dynamic Float8 (CurrentScaling) | AITER (`dynamic_per_token_quant_fp8_i8`) | N/A |
+| MXFP8 (block-scaled) | Triton cast kernel | CUDA kernel |
+| MXFP4 | Triton cast kernel | CUDA kernel |
+| Dequantize | Yes | Yes |
+| Bias-grad + quantize | Yes | Yes |
+| Multi-tensor quantize | Yes | Yes |
+| Amax compute / update | Yes | Yes |
+| Block-scaling partial amax / cast | Yes | Yes |
+| Fused cast + transpose | Triton (noop variant) | CUDA kernel |
+| FP8 recipe management (`fp8_autocast`, recipes) | Yes (pure Python, shared) | Yes |
+
+**Gaps:** Minimal. The Triton cast kernels cover all major quantization formats.
+When a `Float8CurrentScalingQuantizer` is used and AITER is available, all
+quantize calls automatically use per-row dynamic scaling instead of per-tensor --
+this is strictly better (higher precision, single kernel) and happens
+transparently. Performance difference vs CUDA kernels varies by shape and dtype.
+
+---
+
+### FP8 Training
+
+The lite module supports four FP8 scaling recipes, each with different
+trade-offs. The CurrentScaling per-row path is a **lite-only optimization** that
+is not available in the full C++ build.
+
+| Recipe | Scaling granularity | Lite backend | Full Build |
+|--------|-------------------|--------------|------------|
+| `DelayedScaling` | Per-tensor (history window) | AITER fused norm+quant / Triton cast | CUDA kernels |
+| `Float8CurrentScaling` | **Per-row dynamic** | AITER fused norm+quant / per-token GEMM | Per-tensor CUDA kernels |
+| `MXFP8BlockScaling` | Per-block (128×128 or 1×128) | Triton cast / AITER block GEMM | CUDA kernels |
+| `Float8BlockScaling` | Per-block (128×128) | Triton cast / AITER block GEMM | CUDA kernels |
+
+#### CurrentScaling per-row fusion (lite-only)
+
+The full C++ build implements `Float8CurrentScaling` as **per-tensor** current
+scaling: scan the entire tensor for `amax`, compute one scale, then quantize.
+This requires three kernel launches and three full memory passes:
+
+```
+Kernel 1: RMSNorm(input) → BF16 output [read input, write BF16 to HBM]
+Kernel 2: amax = max(abs(BF16 output)) [read BF16 from HBM, write scalar]
+Kernel 3: FP8 = BF16 output × scale [read BF16 from HBM, write FP8]
+```
+
+The lite module replaces this with AITER's `rmsnorm2d_fwd_with_dynamicquant`,
+which fuses norm + quantize into a **single kernel** with **per-row** scaling.
+Each row computes its own `max(abs(...))` in registers and quantizes before the
+data leaves SRAM. The BF16 intermediate never touches HBM:
+
+```
+Kernel 1: RMSNorm+Quant(input) → FP8 output + yscale(M,) [1 read, 1 write]
+```
+
+This works because per-row scaling removes the global data dependency: row 0's
+scale doesn't depend on row 1's data, so the fused kernel can process each row
+independently.
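
+The difference is easy to see in plain PyTorch (illustrative only; 448 is the
+largest e4m3 value and `y` stands in for the normalized activations):
+
+```python
+import torch
+
+y = torch.randn(4, 8, dtype=torch.float32)  # stand-in for the RMSNorm output
+
+# Per-tensor (full build): one scale computed from *all* rows -- a global dependency.
+scale_tensor = 448.0 / y.abs().amax().clamp(min=1e-12)
+fp8_per_tensor = (y * scale_tensor).to(torch.float8_e4m3fn)
+
+# Per-row (lite fused path): each row gets its own scale, so a row can be
+# quantized as soon as it has been normalized -- no cross-row dependency.
+scale_rows = 448.0 / y.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)  # shape (4, 1)
+fp8_per_row = (y * scale_rows).to(torch.float8_e4m3fn)
+```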
+ +**Forward path:** + +| Step | Kernel | Input → Output | +|------|--------|---------------| +| Fused norm+quant | `rmsnorm2d_fwd_with_dynamicquant` | BF16 → FP8 + scale(M,) | +| GEMM | `gemm_a8w8_per_token_scale` | FP8 × FP8 → BF16 | + +**Backward path (dgrad):** + +| Step | Kernel | Input → Output | +|------|--------|---------------| +| Per-row quant dY | `dynamic_per_token_quant_fp8_i8` | BF16 → FP8 + scale(M,) | +| dgrad GEMM | `gemm_a8w8_per_token_scale` | FP8 × FP8 → BF16 | + +**Backward path (wgrad):** + +Per-row scales are along the reduction axis (M) — incompatible with per-token +GEMM. Falls back to per-tensor `gemm_a8w8_CK`, which is acceptable since the +reduction across tokens averages out outliers. + +**Precision advantage:** Per-row scaling gives each token its own optimal scale +factor. A batch with high-magnitude outlier tokens no longer forces every token +to share a single scale driven by the outlier's magnitude. This is especially +beneficial for long-context training where activation magnitudes vary widely +across sequence positions. + +**Detection is automatic:** When `Float8CurrentScaling` recipe is used in lite +mode and AITER is available, the per-row path activates transparently — no +configuration needed. The `Float8Tensor._scale_inv` field carries shape `(M,)` +instead of `(1,)`, and the GEMM dispatch detects this and routes to +`gemm_a8w8_per_token_scale`. + +--- + +### MOE (Mixture of Experts) + +| Feature | Lite | Full Build | +|---------|------|------------| +| Token permutation (forward / backward) | Triton sort + PyTorch gather | CUDA kernel | +| Token unpermutation | PyTorch gather + scatter | CUDA kernel | +| Top-k routing | Fused Triton kernel | CUDA fused kernel | +| Auxiliary load-balancing loss | Fused Triton kernel | CUDA fused kernel | +| Score functions (softmax, sigmoid) | Fused Triton kernel | CUDA fused kernel | +| Grouped GEMM — BF16 / FP16 (fwd / dgrad / wgrad) | AITER Triton GMM (`gmm` / `ptgmm`) | cuBLAS grouped | +| Grouped GEMM — FP8 | **Not yet supported** (NYI) | cuBLAS grouped | +| `te.GroupedLinear` / `GroupedMLP` (BF16) | Yes — `tex.te_general_grouped_gemm` hot-swap | Yes | + +**Gaps:** Router and permutation ops are functionally complete (fused Triton +kernel for topk/scoring/aux-loss in a single pass; full build uses fused CUDA +kernels). Performance difference is most visible at high expert counts. + +The expert compute path for `GroupedLinear` / `GroupedMLP` is served by +`_lite/grouped_gemm.py`, which adapts AITER's Triton GMM kernels (`gmm`, +`ptgmm`) to the C++ `te_general_grouped_gemm` signature — no `_lite/` +GroupedLinear module is needed; the tex hot-swap is sufficient. **FP8 grouped +GEMM is not yet supported**: AITER's generic GMM family is BF16/FP16 only +(the `p`/`np` prefix is persistent vs non-persistent kernel, not per-tensor +scaling), and FP8 expert compute lives in AITER as a separate fused-MoE op +(`aiter.fused_moe`, `moe_op_gemm_a8w8_blockscale`) with a different API +shape. Run with `TE_FP8=0` for MoE training in lite mode until the Phase 2 +dispatcher lands. See also the `TestGroupedLinear::test_fp8_forward` xfail +under [Known xfails](#known-xfails). 
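
+As a reference for what the BF16 grouped GEMM computes, the per-expert loop below
+is equivalent (illustrative only -- the lite path dispatches to AITER's Triton GMM
+rather than looping, and the `[out, in]` weight layout is an assumption):
+
+```python
+import torch
+
+def grouped_gemm_reference(x, expert_weights, m_splits):
+    """Tokens are pre-sorted by expert; m_splits[i] tokens go through expert i."""
+    outs, start = [], 0
+    for w, m in zip(expert_weights, m_splits):   # one [out, in] weight per expert
+        chunk = x[start:start + m]               # this expert's slice of the tokens
+        outs.append(chunk @ w.t())
+        start += m
+    return torch.cat(outs, dim=0)
+```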
+ +--- + +### Communication / Distributed + +| Feature | Lite | Full Build | +|---------|------|------------| +| Comm-overlap (AG/RS + GEMM) | **Not available** (stubs raise error) | Full support | +| NVSHMEM integration | **Not available** | Full support | +| Expert parallelism (EP) | MORI dispatch/combine | NCCL / NVSHMEM | +| `torch.distributed` | Works normally | Works normally | +| FSDP2 integration | Yes — `use_fsdp2=True` wraps weights in `FSDPAGTensor` (1D mesh; HSDP / 2D mesh not yet plumbed) | Yes | +| Tensor parallelism | No built-in support (compound modules accept `tp_size`/`tp_group`/`parallel_mode` for API compat but ignore them; hardcoded `tp_size=1`) | Integrated in modules | +| Sequence parallelism (Megatron-style) | No built-in support (requires TP) | Integrated in modules | +| Context parallelism | RoPE + attention CP supported; THD <-> BSHD conversion helpers | Full support | + +**Gaps:** Comm-overlap APIs remain stubs. Tensor parallelism and Megatron-style +sequence parallelism (which is a TP optimization) have no built-in support in +lite's fused compound modules — `LayerNormLinear` / `LayerNormMLP` accept the +related kwargs for API compatibility but hardcode `tp_size=1`. The multi-node +story for lite is therefore FSDP/HSDP-shaped, not TP-shaped. FSDP2 with a 1D +mesh is supported and tested; HSDP (2D mesh) requires `device_mesh` plumbing +through the compound modules and is not yet wired. + +Expert parallelism is supported via the MORI integration (see below), which +bridges the most significant distributed gap for MoE workloads. Context +parallelism (sequence sharding without TP, e.g. Ulysses-style) is supported +in attention and RoPE. + +--- + +### Expert Parallelism (MORI) + +The `mori_ep` module integrates AMD's [MORI](https://github.com/ROCm/mori) +(Modular RDMA Interface) library to provide high-performance distributed expert +parallelism for MoE pipelines. MORI handles token dispatch/combine across GPUs +using XGMI (intra-node) and RDMA (inter-node) without requiring C++ extensions. + +**Requirements:** `pip install mori` (or build from source with ROCm 6.4+). +MORI shmem must be initialized after `torch.distributed.init_process_group()`. + +| Feature | Lite (MORI) | Full Build | +|---------|-------------|------------| +| Token dispatch (flat layout) | `MoriExpertParallel.dispatch` | NCCL all-to-all | +| Token combine (flat layout) | `MoriExpertParallel.combine` | NCCL all-to-all | +| Per-expert layout (grouped GEMM) | `dispatch_standard_moe` / `combine_standard_moe` | Custom kernels | +| Layout conversion (flat <-> per-expert) | `convert_dispatch_to_standard` / `convert_standard_to_combine_input` | N/A | +| Autograd (flat layout training) | `MoriEPDispatch` / `MoriEPCombine` | Integrated in MoE module | +| Autograd (per-expert layout training) | `MoriEPDispatchStdMoE` / `MoriEPCombineStdMoE` | Integrated in MoE module | +| Routing map conversion (mask <-> index) | `mask_to_index` / `index_to_mask` | N/A (native format) | +| Intra-node transport (XGMI) | Yes | N/A (NCCL) | +| Inter-node transport (RDMA) | Yes (multiple kernel types) | N/A (NCCL) | +| FP8 quantized dispatch | `fp8_direct_cast` mode | Yes | +| Convenience dispatch+combine cycle | `dispatch_and_combine` | N/A | + +**Kernel types:** `intra_node` (default), `inter_node`, `inter_node_v1`, +`inter_node_v1_ll`, `async_ll`. Standard MoE layout requires `intra_node` or +`inter_node_v1_ll`. 
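
+A sketch of the flat-layout dispatch/combine cycle, following the pattern used by
+the integration tests in this PR (torch.distributed must already be initialized
+on each rank; the sizes below and the identity "expert" are illustrative):
+
+```python
+import torch
+from transformer_engine.pytorch._lite.mori_ep import (
+    init_mori_ep, finalize_mori_ep, MoriExpertParallel,
+)
+
+def ep_cycle(rank: int, world_size: int, hidden_dim: int = 1024) -> torch.Tensor:
+    init_mori_ep()
+    ep = MoriExpertParallel(
+        rank=rank, world_size=world_size, hidden_dim=hidden_dim,
+        num_experts_per_rank=4, num_experts_per_token=2,
+        max_num_inp_token_per_rank=64, dtype=torch.bfloat16,
+        kernel_type="intra_node",
+    )
+    device = torch.device("cuda", rank)
+    tokens = torch.randn(32, hidden_dim, dtype=torch.bfloat16, device=device)
+    weights = torch.rand(32, 2, dtype=torch.float32, device=device)
+    indices = torch.randint(0, 4 * world_size, (32, 2), dtype=torch.int32, device=device)
+
+    recv_tokens, recv_weights, recv_indices, num_recv = ep.dispatch(tokens, weights, indices)
+    expert_out = recv_tokens[:num_recv]            # identity "expert", for illustration
+    output, _ = ep.combine(expert_out, recv_weights, recv_indices)
+    ep.reset()
+    finalize_mori_ep()
+    return output
+```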
+ +**EP gaps vs full build:** + +- **No integration with TE's `MoE` module layer** -- MORI EP is a standalone + primitive. The full build's `MoE` module handles EP dispatch/combine + transparently within its forward pass; with lite, you must call the MORI APIs + explicitly in your training loop. +- **No comm-overlap with expert GEMM** -- dispatch and GEMM run sequentially. + The full build can overlap EP communication with expert computation. +- **No pipeline-parallel EP** -- only data-parallel expert parallelism is + supported. No integration with pipeline stages or interleaved scheduling. +- **No heterogeneous expert placement** -- assumes uniform + `num_experts_per_rank` across all ranks. The full build supports uneven expert + distribution. +- **Standard MoE layout limited to two kernel types** -- `dispatch_standard_moe` + / `combine_standard_moe` require MORI built with `ENABLE_STANDARD_MOE_ADAPT=ON` + and only work with `intra_node` or `inter_node_v1_ll` kernels. + +--- + +### Multi-Tensor Optimizer Ops + +| Feature | Lite | Full Build | +|---------|------|------------| +| Multi-tensor Adam | PyTorch | C++ fused kernel | +| Multi-tensor SGD | PyTorch | C++ fused kernel | +| Multi-tensor scale | PyTorch | C++ fused kernel | +| Multi-tensor L2 norm | PyTorch | C++ fused kernel | +| FP8 Adam | PyTorch | C++ fused kernel | +| Capturable Adam | PyTorch | C++ fused kernel + CUDA graphs | + +**Gaps:** All functionally correct but use per-tensor PyTorch loops instead of +fused multi-tensor C++ kernels. Optimizer step overhead is higher but typically +not the training bottleneck. + +--- + +## Running Tests + +The lite module has a dedicated test suite at `tests/pytorch/test_lite.py`. +All tests run entirely in lite mode (the file sets `NVTE_LITE=1` before +importing TE, so the C++ extension is never loaded). 
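+
+The same import-order rule applies to any standalone script: set the variable
+before the first TE import (minimal sketch):
+
+```python
+import os
+os.environ["NVTE_LITE"] = "1"    # must precede the transformer_engine import
+
+import transformer_engine.pytorch as te  # tex.* now resolves into _lite/
+```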
+ +```bash +# Full lite test suite +pytest tests/pytorch/test_lite.py -v + +# One test class +pytest tests/pytorch/test_lite.py::TestRecipeIntegration -v + +# Tests filtered by name +pytest tests/pytorch/test_lite.py -k "current_scaling" -v +``` + +### Test Coverage + +| Class | What it covers | +|-------|----------------| +| `TestImport` | Module loads, key symbols exist | +| `TestForward` | bf16 forward for Linear / LayerNormLinear / LayerNormMLP / LayerNorm / RMSNorm / TransformerLayer | +| `TestBackward` | Same modules with `loss.backward()` | +| `TestNumerical` | Lite output vs `torch.nn` reference (FP32 exact / BF16 close) | +| `TestTritonNorms` | Triton + AITER norm kernels, fused norm+quant for Float8Quantizer / Float8CurrentScalingQuantizer / MXFP8Quantizer | +| `TestQuantize` | FP8 quantize/dequantize (no-recursion), `bgrad_quantize`, CurrentScaling per-row | +| `TestMXFP8` | MXFP8 tensor detection, E8M0 scale conversion, roundtrip error | +| `TestGemm` | `generic_gemm` with all transpose combinations, bias epilogue, bias-gradient epilogue | +| `TestAttention` | Fused attention: BSHD / SBHD / THD layouts, causal / padding masks, GQA, bias | +| `TestMoERouter` | MoE router top-k, softmax / sigmoid score functions, aux-loss | +| `TestMoEPermutation` | Token permute / unpermute / roundtrip, gradient shapes | +| `TestMoEPadding` | Multi-row pad / unpad / roundtrip across dtypes | +| `TestLiteLayerNormLinear` | LayerNormLinear bf16 forward+backward, LayerNorm/RMSNorm variants, `return_layernorm_output` | +| `TestLiteLayerNormMLP` | LayerNormMLP bf16 forward+backward, non-gated + gated activations | +| `TestFusedGatedActQuant` | AITER fused gated act + block FP8 quantize (swiglu/geglu/reglu × Float8BlockQuantizer) | +| `TestFusedGatedActCurrentScaling` | AITER fused gated act + per-row FP8 quantize (swiglu/geglu/reglu × Float8CurrentScalingQuantizer, `group_size = N/2`) | +| `TestRecipeIntegration` | Full `te.autocast(recipe=...)` path for Linear / LayerNormLinear / LayerNormMLP / TransformerLayer × DelayedScaling / Float8CurrentScaling; multi-step loops; FP8 vs bf16 correlation | +| `TestLiteAPI` | Public symbol presence, tex function signatures, DType enum, module constructor kwargs, regression tests | +| `TestFP8Training` | `optimizer.step()`-driven training — overfit-a-batch (loss must drop), FP8 vs bf16 weight tracking, cache-invalidation | +| `TestFP8AttentionFlags` | `fp8_dpa=True` / `fp8_mha=True` raise clean `NotImplementedError`; default flags work | +| `TestGroupedLinear` | GroupedLinear forward+backward, output matches manual F.linear per chunk, uneven m_splits | + +Total: **~285 tests** covering forward, backward, FP8 recipes (DelayedScaling / +Float8CurrentScaling end-to-end), API contracts, training loops, and MoE ops. +The suite is the primary gate against regressions in the lite build. + +### Known xfails + +- `TestGroupedLinear::test_fp8_forward` — FP8 GroupedLinear hits a dtype + mismatch in the Triton GMM wrapper (`lhs=fp32` vs `bias=bf16`). This is a + pre-existing issue in `triton_kernels/gmm/gmm_common.py`; out of scope for + the lite adapter. The marker is `strict=True`, so if the Triton fix lands + upstream the test will fail-loud (XPASS → FAIL) to force a deliberate flip. + +### Adding new tests + +- Any new kernel or dispatch path added to `_lite/` should get a regression + test in `test_lite.py`. Prefer the test class closest to the feature + (e.g. a new GEMM kernel → `TestGemm`, a new recipe-level feature → + `TestRecipeIntegration`). 
+- Tests that exercise FP8 recipes should use the `_RECIPES_FWD_BWD` or + `_RECIPES_FWD` helpers to parametrize across whatever recipes the hardware + supports, so tests skip cleanly on unsupported hardware. +- FP8-vs-bf16 correlation (cosine similarity ≥ 0.9 for single modules, + ≥ 0.75 for TransformerLayer) is the standard numerical check — catches + silent wrong-dispatch and scale-broadcast bugs. + +--- + +## Summary + +| Subsystem | Functional Coverage | Performance | Key Backend | +|-----------|-------------------|-------------|-------------| +| GEMM | Full (incl. per-row FP8) | Good | `torch._scaled_mm` (hipBLASLt) / AITER CK / Triton | +| Attention | Full | Good (AITER) | AITER CK / SDPA | +| Norms | Full + fused norm+quant | Good (AITER) | AITER Triton / TE Triton | +| FP8 Training | Full (3 recipes) | **Best** (fused per-row) | AITER fused kernels | +| Activations | Full | Moderate | AITER (2 ops) / PyTorch | +| Quantization | Full + per-row dynamic | Good (AITER/Triton) | AITER / Triton cast | +| RoPE | Basic + CP | Moderate | AITER / PyTorch | +| MOE | BF16 full; FP8 grouped GEMM NYI | Good (Triton) | Triton fused router + AITER Triton GMM | +| Expert parallelism | Full (standalone) | Good (MORI) | MORI XGMI/RDMA | +| Comm-overlap | **None** | N/A | Stubs | +| Multi-tensor ops | Full | Lower | PyTorch loops | + +The lite module provides **functional correctness** across all major compute +paths. Performance is competitive for GEMM, attention, norms, and quantization +where AITER or Triton kernels are available. The **FP8 CurrentScaling per-row +fusion** is a lite-only optimization that outperforms the full build's per-tensor +path by eliminating two HBM round-trips per norm+quantize operation. Expert +parallelism is available via MORI for distributed MoE workloads. The remaining +primary gaps are **comm-overlap** (not available), **tensor/sequence +parallelism** (no built-in support in lite's compound modules), **FP8 grouped +GEMM** (BF16/FP16 only — blocks FP8 MoE training, see the MOE section), and a +handful of FP8 attention paths (`fp8_dpa` / `fp8_mha` — see the Attention +section). diff --git a/transformer_engine/pytorch/_lite/SKILLS.md b/transformer_engine/pytorch/_lite/SKILLS.md new file mode 100644 index 000000000..73228b0ab --- /dev/null +++ b/transformer_engine/pytorch/_lite/SKILLS.md @@ -0,0 +1,235 @@ +# Transformer Engine Lite — Working Notes + +Operational knowledge for engineers (and agents) modifying `tealite`. The +[README](README.md) is the feature/coverage reference; this file is the +"what we have learned" complement — invariants, gotchas, dead ends, and +measurement protocol that aren't visible from the code alone. + +## Mental model + +- **Lite is a `sys.modules` swap, not a fork.** `transformer_engine.pytorch` + registers `_lite` as `transformer_engine_torch` at import time when + `NVTE_LITE=1` (or the `LITE_BUILD` marker is present). Every `tex.` call + in the rest of the codebase resolves into `_lite/`. **Implication:** you + almost never need to change call sites in `module/`, `cpp_extensions/`, etc. + — implementing the function in `_lite/` is enough. +- **Tiered fallback per subsystem:** AITER → bundled Triton → PyTorch-native. + If you add a new path, preserve this order; PyTorch fallback must stay + reachable when AITER is missing or rejects the inputs. 
+
+- **Quantizer drives the kernel, not the recipe enum.** The dispatch decisions
+  in `gemm.py`, `quantize.py`, and `norms.py` branch on the *Quantizer class*
+  (`Float8Quantizer`, `Float8CurrentScalingQuantizer`, `Float8BlockQuantizer`,
+  `MXFP8Quantizer`) and on `_scale_inv.numel()` / `_scale_inv.shape` — never
+  on a recipe string. When adding a new path, key off the same.
+
+## Where to add code
+
+| Need to | Do |
+|---|---|
+| Replace a `tex` C++ function in lite | Implement it in the matching `_lite/<module>.py`, export from `_lite/__init__.py` |
+| Wrap a new AITER kernel | Add a thin dispatcher in the relevant `_lite/<module>.py`; gate it on `is_aiter_available()` from `aiter_utils.py` (lru_cached) |
+| Cover a new compound module that already calls `tex.*` | **Do nothing in `_lite/`.** The tex hot-swap is sufficient. Verify with a smoke test. (Validated for `te.GroupedLinear` BF16 — 2026-04-28.) |
+| Change a Quantizer's behavior in lite | Edit the dispatch in `_lite/quantize.py` and `_lite/norms.py`; do **not** modify the shared `Float8Tensor` class — the same class is used by full TE |
+| Add a perf-sensitive elementwise path | Write or reuse a Triton kernel before adding PyTorch ops — fragmented PyTorch elementwise launches are the dominant remaining penalty (see "Performance baselines") |
+
+The fused `LayerNormLinear` / `LayerNormMLP` are pure-Python `autograd.Function`
+subclasses (`fused_layernorm_linear.py`, `fused_layernorm_mlp.py`) loaded
+**lazily** from `__init__.py` to avoid circular import with the tex
+registration. If you touch the `__init__.py` import order, run `TestImport`
+and `TestLiteLayerNormLinear` to verify lazy-load still works.
+
+## Numerical & dispatch hazards
+
+1. **Scale shape selects the kernel family — not just a numeric value.**
+   `torch._scaled_mm` routes per-tensor scalars to the F8NBS/F8B8NBS kernel
+   family (covers same-dtype *and* mixed-dtype FP8); broadcasting the same
+   scalar to `(M,1)` / `(1,N)` forces the rowwise family, which has no
+   mixed-dtype coverage on current ROCm. Per-tensor → 0-dim scalar; per-row →
+   `(M,1)` / `(1,N)`. Never broadcast scalar to rowwise shape. (Fixed in
+   `5a660e9c`.)
+2. **AITER `gemm_a8w8_CK` rejects mixed FP8 dtypes.** Both operands must
+   share the same FP8 dtype. Standard FP8 training uses E4M3 × E5M2 for
+   dgrad/wgrad — ~48% of backward GEMMs hit this. Route mixed-dtype to
+   `torch._scaled_mm` (the default `pytorch` backend already does this).
+3. **AITER fused RMSNorm+FP8 kernels write only `_data`, not the columnwise
+   transpose buffer.** If `make_empty(columnwise_usage=True)` allocated it,
+   the buffer is uninitialized — set `_transpose_invalid = True` after
+   filling `_data`, or downstream `update_usage(columnwise_usage=True)`
+   trusts stale bytes.
+4. **Per-row scales on the reduction axis can't use per-token GEMM.** wgrad
+   under `Float8CurrentScaling` falls back to per-tensor `gemm_a8w8_CK`. This
+   is correct (reduction across tokens averages out outliers); don't try to
+   "fix" it without changing the operand layout.
+5. **`Float8CurrentScalingQuantizer` per-row is strictly better than per-tensor**
+   in lite — the AITER `dynamic_per_token_quant_fp8_i8` path intercepts
+   *before* per-tensor quantize. Don't restore the per-tensor branch as a
+   "default" — it loses precision and adds 2 HBM round-trips (see README §
+   FP8 Training).
+6. **AITER `aiter/ops/triton/fused_fp8_quant.py` line ~83 has a bug:**
+   `out1_col_stride = out2.stride(1)` should be `out1.stride(1)` — crashes
+   when `output_unquantized_inp1=True` and `inp2=None`. 
Reported upstream; + may already be fixed when you read this. +7. **`gemm_a8w8_CK` falls back to default config for untuned shapes.** The + warning *"shape ... not found tuned config in a8w8_tuned_gemm.csv, will + use default config!"* means CK is running with `splitK=0` — no + exception, just slow. Non-round M (e.g. 8184 = 2046×4) is the usual + miss. Either run AITER's a8w8 tuner against your shape set or prefer the + `pytorch` GEMM backend (default). + +## Multi-node planning + +- **TE-lite has no TP/SP** — `fused_layernorm_linear.py:456` and + `fused_layernorm_mlp.py:505` hardcode `self.tp_size = 1`. Kwargs are + accepted for API compatibility; setting `tp_size > 1` blows up downstream + in Megatron's QKV reshape. Multi-node plans for tealite must be + FSDP/HSDP-shaped. +- **Comm-overlap (AG/RS + GEMM) is unimplemented;** `_lite/comm.py` raises. +- **Expert parallelism works via `_lite/mori_ep.py`** but is a standalone + primitive — call its dispatch/combine APIs explicitly; there is no + integration into a TE `MoE` module. +- **MoE BF16 GroupedLinear works for free** through the tex hot-swap → + `_lite/grouped_gemm.py` (AITER Triton GMM). FP8 grouped GEMM is NYI: + AITER's generic GMM is BF16/FP16 only, and FP8 expert compute lives + separately in `aiter.fused_moe`. Run BF16-only for MoE until that path + lands. (`TestGroupedLinear::test_fp8_forward` is xfail-strict.) + +## Performance baselines (LLaMA-3-8B, 8×MI300X, seq=2048, RECOMPUTE=0) + +As of 2026-05-01, lite ≈ full at the same TE commit: + +| Mode | ms/iter | tok/GPU/s | +|---|---:|---:| +| full @ same commit | 1712.5 | 9567 | +| lite | 1712.0 | 9570 | + +**The earlier "lite is 5–10% slower" gap was a stale-build artifact** — +the supposedly-faster full was an older commit (`f141f34b`, 1599 ms). There +is a confirmed ~7% regression between `f141f34b` and HEAD that is *not* a +lite issue; bisect harness in `/root/bisect/`. + +The dominant remaining lite-vs-full kernel-time penalty (when one exists) is +**Triton-fragmented elementwise/copy ops** — top offenders are +`multi_tensor_apply_kernel`, fused SwiGLU+bias kernels, FSDP shape-shuffle +copies. GEMM, FMHA, RCCL, and fused norm+quant are all at parity or better. + +### How to re-profile + +```bash +NVTE_LITE_GEMM_BACKEND=pytorch NVTE_LITE_DIAG=1 +``` + +Sanity-check the diag counters in stdout: + +- `pytorch_scaled_mm_ok` ≈ 2/3 of FP8 GEMMs (fwd + dgrad) +- `pytorch_aiter_fallback_ok` ≈ 1/3 (wgrad — K=8184 hits `k_not_div16`) +- `pytorch_dequant_matmul` should be **zero**. Any hits = both `_scaled_mm` + and AITER rejected; that's a 100–1000× slowdown. + +For Megatron runs, pass `--attention-backend fused` to force the AITER AOT +`fmha_v3_fwd/bwd` path. Without it, Megatron defaults to `auto`, sets +`NVTE_FLASH_ATTN=1`, and routes through ROCm `flash_attn` 2.8.3 which +bypasses `_lite/attention.py` entirely. Don't try to set `NVTE_FLASH_ATTN=0` +directly — Megatron asserts on it; use the CLI flag. + +### Apples-apples discipline + +- Always check **loss-AR async fix** symmetry between full and lite + containers before quoting a %; one side missing the fix is worth ~70 ms + (3.5%) and can flip the apparent winner. +- Always compare at the **same TE commit** — there is a real ~7% regression + in TE itself between Jan and May 2026. +- `RECOMPUTE=0`, `seq_len=2048` is the current standard config for new + measurements (CK earlier needed `seq_len=4098` to hit a tuned config; that + workaround is irrelevant on the `pytorch` backend). 
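+
+For a re-profiling run, a minimal Python preamble looks like this (the env
+vars are the documented `NVTE_LITE_*` flags from this section; everything
+else is illustrative):
+
+```python
+import os
+
+# Must be set before the first transformer_engine import.
+os.environ.setdefault("NVTE_LITE", "1")
+os.environ.setdefault("NVTE_LITE_GEMM_BACKEND", "pytorch")  # default backend, stated explicitly
+os.environ.setdefault("NVTE_LITE_DIAG", "1")                # one-shot dispatch counters on stdout
+
+import transformer_engine.pytorch as te  # noqa: F401
+
+# ... build the model and run a few steps, then check stdout for the
+# [LITE-GEMM] / [LITE-ATTN] counters and confirm pytorch_dequant_matmul == 0.
+```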
+ +## Debug tooling + +| Flag | What it does | +|---|---| +| `NVTE_LITE_DIAG=1` | One-shot prints from `_lite/{gemm,norms,attention,quantize}.py` and `module/base.py`; per-bucket counters (`[LITE-GEMM]`, `[LITE-NORM]`, `[LITE-ATTN]`, `[LITE-QUANT]`, `[LITE-NONCONTIG]`, `[LITE-SCALED-MM-FAIL]`, `[LITE-GEMM-CK-FAIL]`). Zero overhead when off. | +| `NVTE_LITE_AMAX_FUSED=0` | Falls back from the Triton multi-tensor-apply amax/scale kernel to the per-group Python loop (`_lite/quantize.py`, ~14 kernel launches × N groups). For A/B against the fused path. | +| `NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM=1` | Opt-in: skip the BF16→FP8 cast on dgrad output when only the norm backward consumes it. **Shelved**: mechanically perfect (1484 casts eliminated) but wall-time within noise — kept env-gated. See open-question note in dgrad-skip memory if reviving. | +| `NVTE_CONTIG_DIAG=1` `NVTE_CONTIG_DIAG_DUMP_STEP=N` | Counts and times every `prepare_forward` `.contiguous()` materialize per `(module, shape, stride, caller)`. Diff full-vs-lite stdout to see where lite has extra materializes. Phase 1 result (2026-05-01): same single site, same 64 calls/step, **lite 3× faster** at the materialize itself — drops the AITER 3D-strided-input kernel patch from the queue. | + +## Test conventions + +`tests/pytorch/test_lite.py` sets `NVTE_LITE=1` **before** importing TE — so +the C++ extension never loads, even on a full build. Don't reorder those +lines. + +- New kernels / dispatch paths should land with a regression test in + `test_lite.py`, in the class closest to the feature (new GEMM kernel → + `TestGemm`; new recipe-level feature → `TestRecipeIntegration`). +- FP8-recipe tests should parametrize via `_RECIPES_FWD_BWD` / + `_RECIPES_FWD` so they skip cleanly on unsupported hardware. +- The standard numerical check is **FP8-vs-bf16 cosine similarity** ≥ 0.9 + for single modules, ≥ 0.75 for `TransformerLayer`. This catches silent + wrong-dispatch and scale-broadcast bugs that exact-tolerance checks miss. + +### Monkeypatch gotcha + +`_lite.quantize` is shadowed in the package namespace by the `quantize` +function re-exported from `_lite/__init__.py`. To monkeypatch a module-level +kernel attr (e.g. `_aiter_dynamic_per_token_quant`) you must reach the +*module*, not the function — use: + +```python +import sys +mod = sys.modules["transformer_engine.pytorch._lite.quantize"] +monkeypatch.setattr(mod, "_aiter_dynamic_per_token_quant", spy) +``` + +`import transformer_engine.pytorch._lite.quantize as q` resolves `q` to the +function via attribute lookup, **not** the module — patching `q.foo` won't +affect dispatch. `_lite.norms` is not shadowed; `import as` works there. + +## Discipline + +- **Wait for profile data before optimizing.** Code-inspection guesses at + hotspots miss the real bottleneck often enough that this is the safer + default. When something is reported slow, ask for / wait on the top-N + kernels with self-CUDA-time and call counts before writing code. +- **Verify "genuine cost" claims with measurement, not deduction.** The + `prepare_forward` materialize was assumed to be a lite penalty for two + weeks; the contig-diag harness showed lite is 3× *faster* at it. Build the + diff harness before writing the patch. +- **A/B every speculative perf change.** Multiple bypass attempts at + `prepare_forward` (`NVTE_LITE_SKIP_NONCONTIG`, `_to_bshd` strided view, + `ROPE_FUSION=0`) all looked like wins on paper and all regressed wall + time. Net-positive must be observed, not predicted. 
+ +## Dead ends — don't retry + +- **Pad-M to next power of 2** (reverted `eac04dd8` → `ccb1f30b`) — inflated + weight N dims by 12.5%. Current div-by-16 pad (`3ed9d8ae`, only 8184→8192) + is the correct version. +- **Tuning AITER CK CSV for forward shapes** — irrelevant since fwd is on + `_scaled_mm` (hipBLASLt) under the default `pytorch` backend. +- **Dequant + matmul as fallback for FP8** — catastrophically slow (~206 + s/iter). Always fall through to AITER before dequant+matmul (`e8272800`). +- **Broadcasting per-tensor scalar scales to rowwise shapes** — see hazard + #1 above. +- **`NVTE_LITE_SKIP_NONCONTIG`-style env-var bypass of `prepare_forward`'s + `.contiguous()`** — downstream FP8 quantize + GEMM 3D→2D reshape paths + re-materialize anyway, more expensively. Reverted in `1bc68c3f`. +- **`ROPE_FUSION=0`** to avoid the BSHD round-trip — `_to_bshd` then makes 3 + input copies (q+k+v at SBHD→BSHD) instead of 1 output copy. +87 ms. +- **`_to_bshd` strided-view (`e4a05c50`)** — unreachable with current + Megatron BSHD path. Reverted in `c62e9771`. +- **`NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM=1` as default** — works mechanically, + but the standalone amax-only reduction added to preserve DelayedScaling + amax history is memory-bound over the same BF16 tensor and roughly cancels + the savings. Keep env-gated; don't flip default. Open question: an + unexplained +8 `pytorch_scaled_mm_ok` calls under skip=1 — worth chasing + if reviving. + +## Untracked patches & TODOs (as of 2026-05-06) + +- AITER-side defensive K-innermost asserts in + `/root/WORK/aiter/aiter/ops/triton/gemm/basic/gemm_a8w8{,_per_token_scale,_blockscale}.py` + — drafted, **uncommitted**. Send upstream when convenient. +- `[LITE-*]` diag counter print sites — Jason wants stripped before merging + to `dev`. All gated behind `NVTE_LITE_DIAG=1` so they're harmless in + production, but noise in the source. +- AITER `fused_fp8_quant.py:83` upstream bug (`out2.stride(1)` → + `out1.stride(1)`) — report or send a one-line PR. diff --git a/transformer_engine/pytorch/_lite/__init__.py b/transformer_engine/pytorch/_lite/__init__.py new file mode 100644 index 000000000..0cb15161b --- /dev/null +++ b/transformer_engine/pytorch/_lite/__init__.py @@ -0,0 +1,124 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Transformer Engine Lite -- Pure-Python drop-in replacement for transformer_engine_torch. + +This module provides the same API surface as the compiled C++ extension but uses +Triton kernels, AITER, and PyTorch-native operations instead. + +Activate by setting NVTE_LITE=1 before importing transformer_engine. 
+""" + +# Re-export all enums and types +from .enums import ( + DType, + NVTE_Bias_Type, + NVTE_Mask_Type, + NVTE_Softmax_Type, + NVTE_QKV_Format, + NVTE_QKV_Layout, + NVTE_Fused_Attn_Backend, + Float8BlockScaleTensorFormat, + FP8FwdTensors, + FP8BwdTensors, + CommOverlapType, + CommOverlapAlgo, + FP8TensorMeta, + CommOverlapCore, +) + +# Re-export operation implementations +from .activations import ( + gelu, geglu, qgelu, qgeglu, + relu, reglu, srelu, sreglu, + silu, swiglu, clamped_swiglu, + dgelu, dgeglu, dqgelu, dqgeglu, + drelu, dreglu, dsrelu, dsreglu, + dsilu, dswiglu, clamped_dswiglu, + dbias_dgelu, dbias_dsilu, dbias_drelu, dbias_dqgelu, dbias_dsrelu, +) +from .norms import ( + layernorm_fwd, layernorm_bwd, + rmsnorm_fwd, rmsnorm_bwd, rmsnorm_bwd_add, +) +from .quantize import ( + quantize, dequantize, bgrad_quantize, + multi_tensor_quantize, split_quantize, + compute_amax, fused_amax_and_scale_update_after_reduction, + fp8_block_scaling_compute_partial_amax, fp8_block_scaling_partial_cast, +) +from .gemm import generic_gemm +from .grouped_gemm import te_general_grouped_gemm +from .softmax import ( + scaled_softmax_forward, scaled_softmax_backward, + scaled_masked_softmax_forward, scaled_masked_softmax_backward, + scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward, + scaled_aligned_causal_masked_softmax_forward, scaled_aligned_causal_masked_softmax_backward, +) +from .attention import ( + get_fused_attn_backend, + fused_attn_fwd, fused_attn_bwd, + fa_prepare_fwd, fa_prepare_bwd, + copy_to_kv_cache, + convert_thd_to_bshd, convert_bshd_to_thd, +) +from .rope import ( + fused_rope_forward, fused_rope_backward, + fused_qkv_rope_forward, fused_qkv_rope_backward, +) +from .dropout import dropout_fwd, dropout_bwd +from .transpose import fp8_transpose, swap_first_dims +from .permutation import ( + moe_permute_fwd, moe_permute_bwd, + moe_unpermute_fwd, moe_unpermute_bwd, +) +from .multi_tensor import ( + multi_tensor_scale, multi_tensor_l2norm, multi_tensor_unscale_l2norm, + multi_tensor_adam, multi_tensor_adam_param_remainder, + multi_tensor_adam_fp8, + multi_tensor_adam_capturable, multi_tensor_adam_capturable_master, + multi_tensor_sgd, + multi_tensor_compute_scale_and_scale_inv, + multi_tensor_compute_scale_inv_e8m0, +) +from .router import ( + fused_topk_with_score_function_fwd, fused_topk_with_score_function_bwd, + fused_score_for_moe_aux_loss_fwd, fused_score_for_moe_aux_loss_bwd, + fused_moe_aux_loss_fwd, fused_moe_aux_loss_bwd, +) +from .comm import ( + CommOverlapHelper, CommOverlap, CommOverlapP2P, + CommOverlapBase, CommOverlapP2PBase, + bulk_overlap_ag_with_external_gemm, + init_nvshmem_backend, create_nvshmem_tensor, + nvshmem_send_on_current_stream, nvshmem_wait_on_current_stream, nvshmem_finalize, + device_supports_multicast, get_stream_priority_range, ubuf_built_with_mpi, +) +from .misc import get_num_cublas_streams +from .context_parallel import ( + thd_read_half_tensor, thd_second_half_lse_correction, + thd_read_second_half_lse, thd_out_correction, + thd_grad_correction, thd_get_partitioned_indices, +) +from .mori_ep import ( + mori_ep_available, + init_mori_ep, + finalize_mori_ep, + is_mori_ep_initialized, + mask_to_index, + index_to_mask, + MoriExpertParallel, + MoriEPDispatch, + MoriEPCombine, + MoriEPDispatchStdMoE, + MoriEPCombineStdMoE, +) +from .padding import fused_multi_row_padding, fused_multi_row_unpadding + +# Note: fused_layernorm_linear and fused_layernorm_mlp are NOT imported here +# because they import 
`transformer_engine_torch as tex` which resolves to this +# module, creating a circular import. They are accessed via +# transformer_engine.pytorch.module.__init__ when NVTE_LITE=1. diff --git a/transformer_engine/pytorch/_lite/activations.py b/transformer_engine/pytorch/_lite/activations.py new file mode 100644 index 000000000..530c3d095 --- /dev/null +++ b/transformer_engine/pytorch/_lite/activations.py @@ -0,0 +1,475 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Activation functions -- AITER fused gated activations with PyTorch fallback. + +When AITER is available, gated activations (swiglu, geglu) use AITER's +fused kernels (silu_and_mul, gelu_tanh_and_mul) which combine +chunk + activation + gate multiply in a single kernel. For block-scaled +FP8 quantization, AITER's act_mul_and_fp8_group_quant fuses activation + +gate multiply + FP8 cast in a single kernel, eliminating the intermediate +bf16 round-trip. +""" + +import torch +import torch.nn.functional as F +import math + +from .aiter_utils import is_aiter_available, get_aiter + + +# Lazy-loaded references to avoid circular imports +_Float8BlockQuantizer = None +_Float8BlockwiseQTensorStorage = None +_Float8BlockScaleTensorFormat = None +_Float8CurrentScalingQuantizer = None +_Float8Tensor = None +_aiter_act_mul_fp8_group_quant = None +_fused_act_quant_loaded = False + + +def _try_load_fused_act_quant(): + """Lazy-load Float8Block/CurrentScaling types and AITER fused act+quant kernel.""" + global _Float8BlockQuantizer, _Float8BlockwiseQTensorStorage + global _Float8BlockScaleTensorFormat + global _Float8CurrentScalingQuantizer, _Float8Tensor + global _aiter_act_mul_fp8_group_quant + global _fused_act_quant_loaded + + if _fused_act_quant_loaded: + return + _fused_act_quant_loaded = True + + try: + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + ) + from transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage import ( + Float8BlockwiseQTensorStorage, + ) + from .enums import Float8BlockScaleTensorFormat + _Float8BlockQuantizer = Float8BlockQuantizer + _Float8BlockwiseQTensorStorage = Float8BlockwiseQTensorStorage + _Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat + except ImportError: + pass + + try: + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + Float8Tensor, + ) + _Float8CurrentScalingQuantizer = Float8CurrentScalingQuantizer + _Float8Tensor = Float8Tensor + except ImportError: + pass + + if is_aiter_available(): + try: + from aiter.ops.triton.activation import act_mul_and_fp8_group_quant + _aiter_act_mul_fp8_group_quant = act_mul_and_fp8_group_quant + except ImportError: + pass + + +# Map TE activation names → AITER activation strings for fused act+quant +_AITER_ACT_QUANT_MAP = { + "swiglu": "silu", + "geglu": "gelu_tanh", + "reglu": "relu", +} + + +def _aiter_fused_gated_act_quant(input, activation, quantizer): + """Try AITER fused gated activation + block FP8 quantize. + + Returns the quantized tensor, or None if the fused path isn't available + (wrong quantizer type, AITER not installed, etc.). 
+ """ + _try_load_fused_act_quant() + + if _aiter_act_mul_fp8_group_quant is None or _Float8BlockQuantizer is None: + return None + if not isinstance(quantizer, _Float8BlockQuantizer): + return None + + aiter_act = _AITER_ACT_QUANT_MAP.get(activation) + if aiter_act is None: + return None + + # Flatten to 2D for AITER kernel + orig_shape = input.shape + input_2d = input.reshape(-1, input.shape[-1]) + + try: + fp8_data, scale_inv = _aiter_act_mul_fp8_group_quant( + input_2d, aiter_act, group_size=quantizer.block_len, + ) + except (RuntimeError, TypeError): + return None + + # Reshape fp8_data back to original leading dims + half_size = input.shape[-1] // 2 + out_shape = orig_shape[:-1] + (half_size,) + fp8_data = fp8_data.reshape(out_shape) if fp8_data.shape != out_shape else fp8_data + + # Wrap in Float8BlockwiseQTensorStorage + result = _Float8BlockwiseQTensorStorage( + rowwise_data=fp8_data.view(torch.uint8), + rowwise_scale_inv=scale_inv, + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=quantizer.dtype, + quantizer=quantizer, + is_2D_scaled=False, + data_format=_Float8BlockScaleTensorFormat.COMPACT, + ) + return result + + +def _aiter_fused_gated_act_current_scaling(input, activation, quantizer): + """Try AITER fused gated activation + per-row FP8 quantize for CurrentScaling. + + Uses act_mul_and_fp8_group_quant with group_size = output_hidden_dim (N/2), + so each row gets exactly one scale — equivalent to per-row dynamic scaling. + Returns a Float8Tensor with _scale_inv shape (M,), or None if unavailable. + """ + _try_load_fused_act_quant() + + if _aiter_act_mul_fp8_group_quant is None or _Float8CurrentScalingQuantizer is None: + return None + if not isinstance(quantizer, _Float8CurrentScalingQuantizer): + return None + + aiter_act = _AITER_ACT_QUANT_MAP.get(activation) + if aiter_act is None: + return None + + # Flatten to 2D for AITER kernel + orig_shape = input.shape + input_2d = input.reshape(-1, input.shape[-1]) + half_size = input.shape[-1] // 2 + + try: + # group_size = half_size means one group per row → per-row scaling + fp8_data, scale_inv = _aiter_act_mul_fp8_group_quant( + input_2d, aiter_act, group_size=half_size, + ) + except (RuntimeError, TypeError): + return None + + # scale_inv shape: (M, 1) — squeeze to (M,) for per-row convention + M = input_2d.shape[0] + scale_inv = scale_inv.reshape(M) + + # Reshape fp8_data back to original leading dims + out_shape = orig_shape[:-1] + (half_size,) + fp8_data = fp8_data.reshape(out_shape) if fp8_data.shape != out_shape else fp8_data + + # Wrap in Float8Tensor with per-row scale_inv + result = _Float8Tensor( + shape=out_shape, + dtype=input.dtype, + data=fp8_data.view(torch.uint8), + fp8_scale_inv=scale_inv, + fp8_dtype=quantizer.dtype, + data_transpose=None, + quantizer=quantizer, + ) + return result + + +def _apply_quantizer(output, quantizer): + """Apply quantizer if provided, otherwise return as-is.""" + if quantizer is not None and hasattr(quantizer, 'quantize'): + return quantizer.quantize(output) + return output + + +def _aiter_gated_act(input, aiter_fn_name): + """Try AITER fused gated activation. Returns None if unsupported. + + AITER gated activation API: fn(out, input) -> None + Input is (*, 2*H), output is (*, H). Fuses chunk + act + gate multiply. 
+ """ + aiter = get_aiter() + if aiter is None: + return None + fn = getattr(aiter, aiter_fn_name, None) + if fn is None: + return None + try: + half_size = input.shape[-1] // 2 + out_shape = input.shape[:-1] + (half_size,) + out = torch.empty(out_shape, dtype=input.dtype, device=input.device) + fn(out, input) + return out + except (RuntimeError, TypeError): + return None + + +# --------------------------------------------------------------------------- # +# Forward activations +# --------------------------------------------------------------------------- # + +def gelu(input, quantizer): + """GeLU activation (tanh approximation).""" + out = F.gelu(input, approximate='tanh') + return _apply_quantizer(out, quantizer) + + +def geglu(input, quantizer): + """GeGLU: split input in half, apply GELU to first, multiply by second.""" + # Try fused gated act + per-row FP8 quantize (CurrentScaling) + fused = _aiter_fused_gated_act_current_scaling(input, "geglu", quantizer) + if fused is not None: + return fused + # Try fused gated act + block FP8 quantize (single kernel) + fused = _aiter_fused_gated_act_quant(input, "geglu", quantizer) + if fused is not None: + return fused + if is_aiter_available(): + result = _aiter_gated_act(input, 'gelu_tanh_and_mul') + if result is not None: + return _apply_quantizer(result, quantizer) + chunks = input.chunk(2, dim=-1) + out = F.gelu(chunks[0], approximate='tanh') * chunks[1] + return _apply_quantizer(out, quantizer) + + +def qgelu(input, quantizer): + """QuickGELU: x * sigmoid(1.702 * x).""" + out = input * torch.sigmoid(1.702 * input) + return _apply_quantizer(out, quantizer) + + +def qgeglu(input, quantizer): + """Quick GeGLU: gated variant of QuickGELU.""" + chunks = input.chunk(2, dim=-1) + out = (chunks[0] * torch.sigmoid(1.702 * chunks[0])) * chunks[1] + return _apply_quantizer(out, quantizer) + + +def relu(input, quantizer): + """ReLU activation.""" + out = F.relu(input) + return _apply_quantizer(out, quantizer) + + +def reglu(input, quantizer): + """ReGLU: gated variant of ReLU.""" + fused = _aiter_fused_gated_act_current_scaling(input, "reglu", quantizer) + if fused is not None: + return fused + fused = _aiter_fused_gated_act_quant(input, "reglu", quantizer) + if fused is not None: + return fused + chunks = input.chunk(2, dim=-1) + out = F.relu(chunks[0]) * chunks[1] + return _apply_quantizer(out, quantizer) + + +def srelu(input, quantizer): + """Squared ReLU: relu(x)^2.""" + out = F.relu(input).square() + return _apply_quantizer(out, quantizer) + + +def sreglu(input, quantizer): + """Squared ReGLU: gated variant of squared ReLU.""" + chunks = input.chunk(2, dim=-1) + out = F.relu(chunks[0]).square() * chunks[1] + return _apply_quantizer(out, quantizer) + + +def silu(input, quantizer): + """SiLU (Swish) activation.""" + out = F.silu(input) + return _apply_quantizer(out, quantizer) + + +def swiglu(input, quantizer): + """SwiGLU: gated variant of SiLU.""" + fused = _aiter_fused_gated_act_current_scaling(input, "swiglu", quantizer) + if fused is not None: + return fused + fused = _aiter_fused_gated_act_quant(input, "swiglu", quantizer) + if fused is not None: + return fused + if is_aiter_available(): + result = _aiter_gated_act(input, 'silu_and_mul') + if result is not None: + return _apply_quantizer(result, quantizer) + chunks = input.chunk(2, dim=-1) + out = F.silu(chunks[0]) * chunks[1] + return _apply_quantizer(out, quantizer) + + +def clamped_swiglu(input, quantizer, limit=7.0, alpha=1.702): + """SwiGLU with clamping (GPT OSS variant).""" + chunks = 
input.chunk(2, dim=-1) + out = F.silu(chunks[0]) * chunks[1] + out = out.clamp(min=-limit, max=limit) + return _apply_quantizer(out, quantizer) + + +# --------------------------------------------------------------------------- # +# Backward activations +# --------------------------------------------------------------------------- # + +def _gelu_backward(grad, x): + """Backward of tanh-approximated GELU.""" + kBeta = math.sqrt(2.0 / math.pi) + kKappa = 0.044715 + x_cube = x * x * x + inner = kBeta * (x + kKappa * x_cube) + tanh_inner = torch.tanh(inner) + dtanh = 1.0 - tanh_inner * tanh_inner + d_inner = kBeta * (1.0 + 3.0 * kKappa * x * x) + return grad * 0.5 * (1.0 + tanh_inner + x * dtanh * d_inner) + + +def dgelu(grad, fwd_input, quantizer): + """Backward of GeLU.""" + out = _gelu_backward(grad, fwd_input) + return _apply_quantizer(out, quantizer) + + +def dgeglu(grad, fwd_input, quantizer): + """Backward of GeGLU.""" + chunks = fwd_input.chunk(2, dim=-1) + x, gate = chunks[0], chunks[1] + gelu_x = F.gelu(x, approximate='tanh') + dgelu_x = _gelu_backward(grad * gate, x) + dgate = grad * gelu_x + out = torch.cat([dgelu_x, dgate], dim=-1) + return _apply_quantizer(out, quantizer) + + +def dqgelu(grad, fwd_input, quantizer): + """Backward of QuickGELU.""" + sig = torch.sigmoid(1.702 * fwd_input) + out = grad * sig * (1.0 + 1.702 * fwd_input * (1.0 - sig)) + return _apply_quantizer(out, quantizer) + + +def dqgeglu(grad, fwd_input, quantizer): + """Backward of Quick GeGLU.""" + chunks = fwd_input.chunk(2, dim=-1) + x, gate = chunks[0], chunks[1] + sig = torch.sigmoid(1.702 * x) + qgelu_x = x * sig + dqgelu_x = grad * gate * sig * (1.0 + 1.702 * x * (1.0 - sig)) + dgate = grad * qgelu_x + out = torch.cat([dqgelu_x, dgate], dim=-1) + return _apply_quantizer(out, quantizer) + + +def drelu(grad, fwd_input, quantizer): + """Backward of ReLU.""" + out = grad * (fwd_input > 0).to(grad.dtype) + return _apply_quantizer(out, quantizer) + + +def dreglu(grad, fwd_input, quantizer): + """Backward of ReGLU.""" + chunks = fwd_input.chunk(2, dim=-1) + x, gate = chunks[0], chunks[1] + dx = grad * gate * (x > 0).to(grad.dtype) + dgate = grad * F.relu(x) + out = torch.cat([dx, dgate], dim=-1) + return _apply_quantizer(out, quantizer) + + +def dsrelu(grad, fwd_input, quantizer): + """Backward of Squared ReLU.""" + out = grad * 2.0 * F.relu(fwd_input) + return _apply_quantizer(out, quantizer) + + +def dsreglu(grad, fwd_input, quantizer): + """Backward of Squared ReGLU.""" + chunks = fwd_input.chunk(2, dim=-1) + x, gate = chunks[0], chunks[1] + dx = grad * gate * 2.0 * F.relu(x) + dgate = grad * F.relu(x).square() + out = torch.cat([dx, dgate], dim=-1) + return _apply_quantizer(out, quantizer) + + +def dsilu(grad, fwd_input, quantizer): + """Backward of SiLU.""" + sig = torch.sigmoid(fwd_input) + out = grad * sig * (1.0 + fwd_input * (1.0 - sig)) + return _apply_quantizer(out, quantizer) + + +def dswiglu(grad, fwd_input, quantizer): + """Backward of SwiGLU.""" + chunks = fwd_input.chunk(2, dim=-1) + x, gate = chunks[0], chunks[1] + sig = torch.sigmoid(x) + silu_x = x * sig + dx = grad * gate * sig * (1.0 + x * (1.0 - sig)) + dgate = grad * silu_x + out = torch.cat([dx, dgate], dim=-1) + return _apply_quantizer(out, quantizer) + + +def clamped_dswiglu(grad, fwd_input, quantizer, limit=7.0, alpha=1.702): + """Backward of clamped SwiGLU.""" + chunks = fwd_input.chunk(2, dim=-1) + x, gate = chunks[0], chunks[1] + sig = torch.sigmoid(x) + silu_x = x * sig + fwd_out = silu_x * gate + # Zero out gradient where clamped + 
mask = (fwd_out >= -limit) & (fwd_out <= limit) + grad = grad * mask.to(grad.dtype) + dx = grad * gate * sig * (1.0 + x * (1.0 - sig)) + dgate = grad * silu_x + out = torch.cat([dx, dgate], dim=-1) + return _apply_quantizer(out, quantizer) + + +# --------------------------------------------------------------------------- # +# DBias + DAct fusions +# --------------------------------------------------------------------------- # + +def dbias_dgelu(grad, fwd_input, quantizer): + """Fused DGeLU + DBias: returns (dact, dbias).""" + dact = _gelu_backward(grad, fwd_input) + dbias = dact.sum(dim=tuple(range(dact.ndim - 1))) + return _apply_quantizer(dact, quantizer), dbias + + +def dbias_dsilu(grad, fwd_input, quantizer): + """Fused DSiLU + DBias: returns (dact, dbias).""" + sig = torch.sigmoid(fwd_input) + dact = grad * sig * (1.0 + fwd_input * (1.0 - sig)) + dbias = dact.sum(dim=tuple(range(dact.ndim - 1))) + return _apply_quantizer(dact, quantizer), dbias + + +def dbias_drelu(grad, fwd_input, quantizer): + """Fused DReLU + DBias: returns (dact, dbias).""" + dact = grad * (fwd_input > 0).to(grad.dtype) + dbias = dact.sum(dim=tuple(range(dact.ndim - 1))) + return _apply_quantizer(dact, quantizer), dbias + + +def dbias_dqgelu(grad, fwd_input, quantizer): + """Fused DQGeLU + DBias: returns (dact, dbias).""" + sig = torch.sigmoid(1.702 * fwd_input) + dact = grad * sig * (1.0 + 1.702 * fwd_input * (1.0 - sig)) + dbias = dact.sum(dim=tuple(range(dact.ndim - 1))) + return _apply_quantizer(dact, quantizer), dbias + + +def dbias_dsrelu(grad, fwd_input, quantizer): + """Fused DSquaredReLU + DBias: returns (dact, dbias).""" + dact = grad * 2.0 * F.relu(fwd_input) + dbias = dact.sum(dim=tuple(range(dact.ndim - 1))) + return _apply_quantizer(dact, quantizer), dbias diff --git a/transformer_engine/pytorch/_lite/aiter_utils.py b/transformer_engine/pytorch/_lite/aiter_utils.py new file mode 100644 index 000000000..10c506188 --- /dev/null +++ b/transformer_engine/pytorch/_lite/aiter_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""AITER availability detection and common utilities. + +AITER is an optional pip dependency providing CK/Triton kernels for AMD GPUs. +All _lite modules should use these functions instead of per-file import checks. +""" + +import functools + + +@functools.lru_cache(maxsize=1) +def is_aiter_available(): + """Check if AITER is installed and importable.""" + try: + import aiter # noqa: F401 + return True + except ImportError: + return False + + +def get_aiter(): + """Return the aiter module, or None if not installed.""" + if not is_aiter_available(): + return None + import aiter + return aiter + + +def get_aiter_rope(): + """Return aiter.ops.rope module, or None if not available.""" + if not is_aiter_available(): + return None + try: + from aiter.ops import rope + return rope + except (ImportError, AttributeError): + return None diff --git a/transformer_engine/pytorch/_lite/amax_utils.py b/transformer_engine/pytorch/_lite/amax_utils.py new file mode 100644 index 000000000..1d7a4772c --- /dev/null +++ b/transformer_engine/pytorch/_lite/amax_utils.py @@ -0,0 +1,38 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Amax bookkeeping helpers for lite backend. 
+ +These helpers let call sites update a quantizer's amax tensor without +materializing an FP8 tensor. Used by the "skip FP8 dgrad round-trip" +optimization on the norm backward path: when the only downstream consumer +of an FP8 tensor is going to dequantize it immediately, we can skip the +cast entirely and still preserve DelayedScaling amax history by running +a standalone reduction on the BF16 source. +""" + +import torch + + +def update_amax_from_bf16(quantizer, bf16_tensor): + """Update quantizer.amax from a BF16 tensor's abs-max. + + Only meaningful for DelayedScaling (Float8Quantizer), which uses + amax history to pick next step's scalar scale. CurrentScaling computes + per-row scales in-kernel on each use (no cross-step bookkeeping), and + MXFP8 uses per-block scales computed at quantize time — for both, this + call is a no-op. + + Matches the amax update the Triton cast-transpose kernel performs as a + side effect of the FP8 cast (see quantize.py: amax_out=q.amax), so the + stored amax value is identical whether we take the cast path or skip it. + """ + if quantizer is None or not hasattr(quantizer, "amax"): + return + if type(quantizer).__name__ != "Float8Quantizer": + return + if bf16_tensor is None or bf16_tensor.numel() == 0: + return + quantizer.amax.copy_(bf16_tensor.abs().amax()) diff --git a/transformer_engine/pytorch/_lite/attention.py b/transformer_engine/pytorch/_lite/attention.py new file mode 100644 index 000000000..7d099ce1a --- /dev/null +++ b/transformer_engine/pytorch/_lite/attention.py @@ -0,0 +1,828 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Attention operations -- multi-backend: AITER, flash-attn (stub), PyTorch SDPA. + +Backend priority: AITER CK kernels > flash-attn (stubbed) > PyTorch SDPA fallback. +""" + +import math +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from .aiter_utils import is_aiter_available, get_aiter +from .enums import ( + NVTE_Fused_Attn_Backend, NVTE_Mask_Type, NVTE_Bias_Type, + NVTE_QKV_Layout, NVTE_QKV_Format, +) + +# --- Debug dispatch counter (matches _lite/gemm.py probe style) --- +_LITE_DIAG = os.environ.get("NVTE_LITE_DIAG", "0") != "0" + +from collections import Counter as _AttnCounter +_ATTN_CALLS = _AttnCounter() +_FWD_ARGS_PRINTED = False + +def _attn_bump(tag): + if not _LITE_DIAG: + return + _ATTN_CALLS[tag] += 1 + if sum(_ATTN_CALLS.values()) % 500 == 0: + print(f"[LITE-ATTN] {dict(_ATTN_CALLS)}", flush=True) + + +def _attn_probe_fwd_args(q_fmt, q, k, v, attn_bias, causal, wl, wr, dropout_p): + """One-shot dump of fwd arg shape/flags (NVTE_LITE_DIAG=1 only). + + Mirrors the conditions aiter uses to gate its AOT fmha_v3_fwd vs the + slower JIT ck_tile mha_fwd path (see can_impl_fmha_v3_fwd in aiter + mha.py). Helpful when an attention backend regression appears. 
+ """ + global _FWD_ARGS_PRINTED + if not _LITE_DIAG or _FWD_ARGS_PRINTED: + return + _FWD_ARGS_PRINTED = True + print( + f"[LITE-ATTN-FWD] fmt={q_fmt} dtype={q.dtype} " + f"q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)} " + f"causal={causal} window=({wl},{wr}) dropout_p={dropout_p} " + f"bias={attn_bias is not None}", + flush=True, + ) + +# --------------------------------------------------------------------------- +# AITER raw kernel imports (lazy) +# --------------------------------------------------------------------------- +_aiter_fwd = None +_aiter_bwd = None +_aiter_varlen_fwd = None +_aiter_varlen_bwd = None +_aiter_import_attempted = False + + +def _try_load_aiter_attn(): + """Lazy-import AITER raw MHA kernels. Called once, result cached.""" + global _aiter_fwd, _aiter_bwd, _aiter_varlen_fwd, _aiter_varlen_bwd + global _aiter_import_attempted + if _aiter_import_attempted: + return + _aiter_import_attempted = True + if not is_aiter_available(): + return + try: + from aiter.ops.mha import ( + _flash_attn_forward, + _flash_attn_backward, + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, + ) + _aiter_fwd = _flash_attn_forward + _aiter_bwd = _flash_attn_backward + _aiter_varlen_fwd = _flash_attn_varlen_forward + _aiter_varlen_bwd = _flash_attn_varlen_backward + except (ImportError, AttributeError): + pass + + +# --------------------------------------------------------------------------- +# Flash-attention (stubbed -- placeholder for future integration) +# --------------------------------------------------------------------------- +_flash_attn_available = False +# Uncomment when ready to integrate: +# try: +# from flash_attn.flash_attn_interface import ( +# _flash_attn_forward as _fa_fwd, +# _flash_attn_backward as _fa_bwd, +# _flash_attn_varlen_forward as _fa_varlen_fwd, +# _flash_attn_varlen_backward as _fa_varlen_bwd, +# ) +# _flash_attn_available = True +# except ImportError: +# pass + + +# --------------------------------------------------------------------------- +# QKV layout helpers +# --------------------------------------------------------------------------- + +# Map NVTE_QKV_Layout enum values -> (q_format, kv_format) +_LAYOUT_TO_FMT = { + NVTE_QKV_Layout.NVTE_SB3HD: ("sbhd", "sbhd"), + NVTE_QKV_Layout.NVTE_SBH3D: ("sbhd", "sbhd"), + NVTE_QKV_Layout.NVTE_SBHD_SB2HD: ("sbhd", "sbhd"), + NVTE_QKV_Layout.NVTE_SBHD_SBH2D: ("sbhd", "sbhd"), + NVTE_QKV_Layout.NVTE_SBHD_SBHD_SBHD: ("sbhd", "sbhd"), + NVTE_QKV_Layout.NVTE_BS3HD: ("bshd", "bshd"), + NVTE_QKV_Layout.NVTE_BSH3D: ("bshd", "bshd"), + NVTE_QKV_Layout.NVTE_BSHD_BS2HD: ("bshd", "bshd"), + NVTE_QKV_Layout.NVTE_BSHD_BSH2D: ("bshd", "bshd"), + NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: ("bshd", "bshd"), + NVTE_QKV_Layout.NVTE_T3HD: ("thd", "thd"), + NVTE_QKV_Layout.NVTE_TH3D: ("thd", "thd"), + NVTE_QKV_Layout.NVTE_THD_T2HD: ("thd", "thd"), + NVTE_QKV_Layout.NVTE_THD_TH2D: ("thd", "thd"), + NVTE_QKV_Layout.NVTE_THD_THD_THD: ("thd", "thd"), + NVTE_QKV_Layout.NVTE_SBHD_BSHD_BSHD: ("sbhd", "bshd"), + NVTE_QKV_Layout.NVTE_BSHD_SBHD_SBHD: ("bshd", "sbhd"), + NVTE_QKV_Layout.NVTE_THD_BSHD_BSHD: ("thd", "bshd"), + NVTE_QKV_Layout.NVTE_THD_SBHD_SBHD: ("thd", "sbhd"), +} + +# Mask types that imply causal attention +_CAUSAL_MASKS = { + NVTE_Mask_Type.NVTE_CAUSAL_MASK, + NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK, + NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK, + NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, +} + +# Bias types that don't pass a bias tensor to the kernel +_NO_BIAS_TENSOR = {NVTE_Bias_Type.NVTE_NO_BIAS, 
NVTE_Bias_Type.NVTE_ALIBI} + + +def _get_qkv_format(qkv_layout) -> Tuple[str, str]: + """Extract per-tensor format from a TE qkv_layout (enum int or string). + + Returns (q_format, kv_format) where each is one of 'bshd', 'sbhd', 'thd'. + """ + if isinstance(qkv_layout, int): + return _LAYOUT_TO_FMT[qkv_layout] + # String fallback (used by direct tests) + canon = qkv_layout.replace("3", "").replace("2", "") + parts = canon.split("_") + # Filter out "paged", "kv" prefixes + parts = [p for p in parts if p not in ("paged", "kv")] + q_fmt = parts[0] + kv_fmt = parts[-1] if len(parts) > 1 else q_fmt + return q_fmt, kv_fmt + + +def _to_bshd(t: torch.Tensor, fmt: str) -> torch.Tensor: + """Convert tensor from *fmt* to BSHD layout. Returns a contiguous tensor.""" + if fmt == "bshd": + return t + if fmt == "sbhd": + return t.transpose(0, 1).contiguous() + raise ValueError(f"_to_bshd does not handle format '{fmt}' (use varlen path for thd)") + + +def _from_bshd(t: torch.Tensor, fmt: str) -> torch.Tensor: + """Convert tensor from BSHD back to *fmt*.""" + if fmt == "bshd": + return t + if fmt == "sbhd": + return t.transpose(0, 1).contiguous() + raise ValueError(f"_from_bshd does not handle format '{fmt}'") + + +def _is_causal(attn_mask_type) -> bool: + """Check if mask type implies causal attention. Accepts enum int or string.""" + if isinstance(attn_mask_type, int): + return attn_mask_type in _CAUSAL_MASKS + return "causal" in attn_mask_type + + +def _has_bias_tensor(bias_type) -> bool: + """Check if bias type carries an actual bias tensor.""" + if isinstance(bias_type, int): + return bias_type not in _NO_BIAS_TENSOR + return bias_type not in ("no_bias", "alibi") + + +# --------------------------------------------------------------------------- +# Backend selection +# --------------------------------------------------------------------------- + +_FP8_TE_DTYPES = None + + +def _is_fp8_dtype(te_dtype): + """Detect FP8 TE DType values (kFloat8E4M3, kFloat8E5M2).""" + global _FP8_TE_DTYPES + if _FP8_TE_DTYPES is None: + from .enums import DType as TE_DType + _FP8_TE_DTYPES = {TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2} + return te_dtype in _FP8_TE_DTYPES + + +def get_fused_attn_backend( + is_training, + q_type, + kv_type, + qkv_layout, + bias_type, + mask_type, + softmax_type, + dropout, + num_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size_left=-1, + window_size_right=-1, + return_max_logit=False, + cuda_graph=False, +): + """Select the best available attention backend for lite mode. + + Priority: AITER CK > (flash-attn, stubbed) > PyTorch SDPA. + + FP8 attention (fp8_dpa=True / fp8_mha=True in the recipe) is not + implemented in lite — there is no FP8 attention kernel available + through AITER or PyTorch SDPA on ROCm. Raises NotImplementedError + with a clear message if FP8 input dtypes reach this call. + """ + if _is_fp8_dtype(q_type) or _is_fp8_dtype(kv_type): + raise NotImplementedError( + "FP8 attention (fp8_dpa=True or fp8_mha=True) is not supported in " + "NVTE_LITE mode — no FP8 attention kernel is available through " + "AITER or PyTorch SDPA on ROCm. Set fp8_dpa=False and fp8_mha=False " + "on the recipe; attention will run in bf16 while GEMMs use FP8." 
+ ) + + _try_load_aiter_attn() + + # AITER available -- covers causal, padding, sliding window, GQA, bias + if _aiter_varlen_fwd is not None: + return NVTE_Fused_Attn_Backend.NVTE_CK + + # Flash-attention (currently stubbed) + if _flash_attn_available: + return NVTE_Fused_Attn_Backend.NVTE_Flash + + # PyTorch SDPA fallback -- always available (PyTorch >= 2.0) + return NVTE_Fused_Attn_Backend.NVTE_SDPA + + +# --------------------------------------------------------------------------- +# AITER forward / backward +# --------------------------------------------------------------------------- + +def _aiter_attn_fwd( + q, k, v, + cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + attn_scale, dropout, is_training, causal, + window_size, attn_bias, qkv_layout, + cu_seqlens_q_padded=None, cu_seqlens_kv_padded=None, +): + """AITER CK attention forward via raw _flash_attn_*_forward.""" + q_fmt, kv_fmt = _get_qkv_format(qkv_layout) + wl, wr = window_size + _drop = dropout if is_training else 0.0 + + if q_fmt == "thd": + _attn_probe_fwd_args(q_fmt, q, k, v, attn_bias, causal, wl, wr, _drop) + # Q/K/V already in (total, heads, dim) -- use varlen API + out, softmax_lse, _, rng_state = _aiter_varlen_fwd( + q, k, v, + cu_seqlens_q, cu_seqlens_kv, + cu_seqlens_q_padded, cu_seqlens_kv_padded, + max_seqlen_q, max_seqlen_kv, + 0, # min_seqlen_q + _drop, + attn_scale, + causal=causal, + window_size_left=wl, + window_size_right=wr, + bias=attn_bias, + return_lse=True, + ) + else: + # bshd or sbhd -- convert to bshd, use non-varlen API + q_bshd = _to_bshd(q, q_fmt) + k_bshd = _to_bshd(k, kv_fmt) + v_bshd = _to_bshd(v, kv_fmt) + _attn_probe_fwd_args(q_fmt, q_bshd, k_bshd, v_bshd, attn_bias, causal, wl, wr, _drop) + # Pass via keyword to stay resilient to aiter API drift — newer + # aiter releases inserted positional args (sink_size, *_descale) + # between window_size_right and return_lse. + out, softmax_lse, _, rng_state = _aiter_fwd( + q_bshd, k_bshd, v_bshd, + dropout_p=_drop, + softmax_scale=attn_scale, + causal=causal, + window_size_left=wl, + window_size_right=wr, + sink_size=0, + bias=attn_bias, + alibi_slopes=None, + q_descale=None, + k_descale=None, + v_descale=None, + return_lse=True, + return_softmax=False, + how_v3_bf16_cvt=1, + # bshd/sbhd is fixed-length per batch; passing non-None cu_seqlens + # here forces aiter off its AOT v3 fwd kernel onto the slower JIT + # ck_tile mha_fwd path (see can_impl_fmha_v3_fwd in aiter mha.py). 
+ cu_seqlens_q=None, + cu_seqlens_kv=None, + ) + out = _from_bshd(out, q_fmt) + + aux_ctx_tensors = [softmax_lse, rng_state] + return out, aux_ctx_tensors + + +def _aiter_attn_bwd( + d_o, q, k, v, o, softmax_lse, rng_state, + cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + attn_scale, dropout, causal, + window_size, qkv_layout, deterministic, + cu_seqlens_q_padded=None, cu_seqlens_kv_padded=None, +): + """AITER CK attention backward via raw _flash_attn_*_backward.""" + q_fmt, kv_fmt = _get_qkv_format(qkv_layout) + wl, wr = window_size + + if q_fmt == "thd": + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _aiter_varlen_bwd( + d_o, q, k, v, o, softmax_lse, + dq, dk, dv, + cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + dropout, attn_scale, causal, + wl, wr, + None, # alibi_slopes + deterministic, + rng_state=rng_state, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_k_padded=cu_seqlens_kv_padded, + ) + else: + q_bshd = _to_bshd(q, q_fmt) + k_bshd = _to_bshd(k, kv_fmt) + v_bshd = _to_bshd(v, kv_fmt) + o_bshd = _to_bshd(o, q_fmt) + d_o_bshd = _to_bshd(d_o, q_fmt) + dq = torch.empty_like(q_bshd) + dk = torch.empty_like(k_bshd) + dv = torch.empty_like(v_bshd) + _aiter_bwd( + d_o_bshd, q_bshd, k_bshd, v_bshd, o_bshd, softmax_lse, + dq, dk, dv, + None, # dbias + dropout, attn_scale, causal, + wl, wr, + None, # bias (not needed for grad computation) + None, # alibi_slopes + deterministic, + rng_state, + True, # is_v3_atomic_fp32 + 1, # how_v3_bf16_cvt + ) + dq = _from_bshd(dq, q_fmt) + dk = _from_bshd(dk, kv_fmt) + dv = _from_bshd(dv, kv_fmt) + + return dq, dk, dv + + +# --------------------------------------------------------------------------- +# PyTorch SDPA forward / backward +# --------------------------------------------------------------------------- + +def _sdpa_attn_fwd( + q, k, v, + cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + attn_scale, dropout, is_training, causal, + window_size, attn_bias, qkv_layout, +): + """PyTorch SDPA attention forward. + + SDPA expects (batch, heads, seq, dim). We convert from TE formats. + For thd (variable-length packed), we unpack to bshd first. 
+ """ + q_fmt, kv_fmt = _get_qkv_format(qkv_layout) + + if q_fmt == "thd": + # Unpack thd -> bshd for SDPA + q_bshd = convert_thd_to_bshd(q, cu_seqlens_q, None, max_seqlen_q) + k_bshd = convert_thd_to_bshd(k, cu_seqlens_kv, None, max_seqlen_kv) + v_bshd = convert_thd_to_bshd(v, cu_seqlens_kv, None, max_seqlen_kv) + else: + q_bshd = _to_bshd(q, q_fmt) + k_bshd = _to_bshd(k, kv_fmt) + v_bshd = _to_bshd(v, kv_fmt) + + # SDPA expects (B, H, S, D) + q_sdpa = q_bshd.transpose(1, 2) + k_sdpa = k_bshd.transpose(1, 2) + v_sdpa = v_bshd.transpose(1, 2) + + # GQA: expand K/V heads to match Q heads + num_heads_q = q_sdpa.shape[1] + num_heads_kv = k_sdpa.shape[1] + if num_heads_kv < num_heads_q: + repeat = num_heads_q // num_heads_kv + k_sdpa = k_sdpa.repeat_interleave(repeat, dim=1) + v_sdpa = v_sdpa.repeat_interleave(repeat, dim=1) + + # Build attention mask from attn_bias if provided + sdpa_attn_mask = None + if attn_bias is not None: + sdpa_attn_mask = attn_bias + + # SDPA handles dropout and causal natively + with torch.nn.attention.sdpa_kernel( + [torch.nn.attention.SDPBackend.FLASH_ATTENTION, + torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + torch.nn.attention.SDPBackend.MATH] + ): + out_sdpa = F.scaled_dot_product_attention( + q_sdpa, k_sdpa, v_sdpa, + attn_mask=sdpa_attn_mask, + dropout_p=dropout if is_training else 0.0, + is_causal=causal and sdpa_attn_mask is None, + scale=attn_scale, + ) + + # Convert back: (B, H, S, D) -> bshd -> original format + out_bshd = out_sdpa.transpose(1, 2).contiguous() + + if q_fmt == "thd": + batch_size = cu_seqlens_q.shape[0] - 1 + out = convert_bshd_to_thd(out_bshd, cu_seqlens_q, q.shape[0]) + else: + out = _from_bshd(out_bshd, q_fmt) + + # SDPA doesn't expose softmax stats, but backends.py always accesses + # aux_ctx_tensors[0] (contiguity check) and saves them for backward. + # Provide dummy tensors that pass through safely. 
+ batch_size = cu_seqlens_q.shape[0] - 1 + num_heads = q_bshd.shape[2] + dummy_lse = torch.zeros( + batch_size, num_heads, max_seqlen_q, + dtype=torch.float32, device=q_bshd.device, + ) + dummy_rng = torch.zeros(2, dtype=torch.int64, device=q_bshd.device) + aux_ctx_tensors = [dummy_lse, dummy_rng] + return out, aux_ctx_tensors + + +def _sdpa_attn_bwd( + d_o, q, k, v, o, + cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + attn_scale, dropout, causal, + window_size, attn_bias, qkv_layout, +): + """PyTorch SDPA backward via autograd recomputation.""" + q_fmt, kv_fmt = _get_qkv_format(qkv_layout) + + if q_fmt == "thd": + q_in = convert_thd_to_bshd(q, cu_seqlens_q, None, max_seqlen_q) + k_in = convert_thd_to_bshd(k, cu_seqlens_kv, None, max_seqlen_kv) + v_in = convert_thd_to_bshd(v, cu_seqlens_kv, None, max_seqlen_kv) + d_o_in = convert_thd_to_bshd(d_o, cu_seqlens_q, None, max_seqlen_q) + else: + q_in = _to_bshd(q, q_fmt) + k_in = _to_bshd(k, kv_fmt) + v_in = _to_bshd(v, kv_fmt) + d_o_in = _to_bshd(d_o, q_fmt) + + # Re-run forward with autograd to compute gradients + q_g = q_in.detach().requires_grad_(True) + k_g = k_in.detach().requires_grad_(True) + v_g = v_in.detach().requires_grad_(True) + + # (B, S, H, D) -> (B, H, S, D) + q_sdpa = q_g.transpose(1, 2) + k_sdpa = k_g.transpose(1, 2) + v_sdpa = v_g.transpose(1, 2) + + num_heads_q = q_sdpa.shape[1] + num_heads_kv = k_sdpa.shape[1] + if num_heads_kv < num_heads_q: + repeat = num_heads_q // num_heads_kv + k_sdpa = k_sdpa.repeat_interleave(repeat, dim=1) + v_sdpa = v_sdpa.repeat_interleave(repeat, dim=1) + + sdpa_attn_mask = attn_bias if attn_bias is not None else None + + out_sdpa = F.scaled_dot_product_attention( + q_sdpa, k_sdpa, v_sdpa, + attn_mask=sdpa_attn_mask, + dropout_p=0.0, # no dropout in backward recompute for determinism + is_causal=causal and sdpa_attn_mask is None, + scale=attn_scale, + ) + + out_bshd = out_sdpa.transpose(1, 2).contiguous() + d_o_bshd = d_o_in + + out_bshd.backward(d_o_bshd) + + dq_bshd = q_g.grad + dk_bshd = k_g.grad + dv_bshd = v_g.grad + + if q_fmt == "thd": + batch_size = cu_seqlens_q.shape[0] - 1 + dq = convert_bshd_to_thd(dq_bshd, cu_seqlens_q, q.shape[0]) + dk = convert_bshd_to_thd(dk_bshd, cu_seqlens_kv, k.shape[0]) + dv = convert_bshd_to_thd(dv_bshd, cu_seqlens_kv, v.shape[0]) + else: + dq = _from_bshd(dq_bshd, q_fmt) + dk = _from_bshd(dk_bshd, kv_fmt) + dv = _from_bshd(dv_bshd, kv_fmt) + + return dq, dk, dv + + +# --------------------------------------------------------------------------- +# Public API: fused_attn_fwd / fused_attn_bwd +# --------------------------------------------------------------------------- + +def fused_attn_fwd( + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + fake_dtype, + cu_seqlens_q_padded=None, + cu_seqlens_kv_padded=None, + page_table_k=None, + page_table_v=None, + s_quantizer=None, + o_quantizer=None, + attn_bias=None, + softmax_offset=None, + rng_gen=None, + rng_elts_per_thread=0, + return_max_logit=False, + cuda_graph=False, +): + """Fused attention forward -- lite multi-backend dispatcher. + + Signature matches the C++ tex.fused_attn_fwd binding (positional arg order). + Called from transformer_engine.pytorch.cpp_extensions.fused_attn.fused_attn_fwd. + + Returns a list of tensors: [output, *aux_ctx_tensors]. 
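+
+    Return convention (illustrative sketch only; the argument list is elided):
+
+        result = fused_attn_fwd(...)                # [out, softmax_lse, rng_state]
+        out, aux_ctx_tensors = result[0], result[1:]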
+ """ + _try_load_aiter_attn() + + causal = _is_causal(attn_mask_type) + bias_tensor = attn_bias if _has_bias_tensor(bias_type) else None + + # Select backend if not already determined + if _aiter_varlen_fwd is not None: + backend = NVTE_Fused_Attn_Backend.NVTE_CK + elif _flash_attn_available: + backend = NVTE_Fused_Attn_Backend.NVTE_Flash + else: + backend = NVTE_Fused_Attn_Backend.NVTE_SDPA + + if backend == NVTE_Fused_Attn_Backend.NVTE_CK: + _attn_bump("fwd_aiter_ck") + out, aux_ctx_tensors = _aiter_attn_fwd( + q, k, v, cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, is_training, causal, + window_size, bias_tensor, qkv_layout, + cu_seqlens_q_padded, cu_seqlens_kv_padded, + ) + elif backend == NVTE_Fused_Attn_Backend.NVTE_Flash: + _attn_bump("fwd_flash_stub") + raise NotImplementedError( + "Flash-attention backend is stubbed in lite mode. " + "Install AITER or use the SDPA fallback." + ) + else: + _attn_bump("fwd_sdpa") + out, aux_ctx_tensors = _sdpa_attn_fwd( + q, k, v, cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, is_training, causal, + window_size, bias_tensor, qkv_layout, + ) + + # Return format must match C++ extension: list of [output, *aux] + # The Python wrapper (cpp_extensions/fused_attn.py) does: + # return output_tensors[0], output_tensors[1:] + result = [out] + aux_ctx_tensors + return result + + +def fused_attn_bwd( + max_seqlen_q, + max_seqlen_kv, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + deterministic, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + o, + d_o, + fake_dtype, + dqkv_dtype, + aux_ctx_tensors, + cu_seqlens_q_padded=None, + cu_seqlens_kv_padded=None, + s_quantizer=None, + dp_quantizer=None, + dqkv_quantizer=None, + cuda_graph=False, +): + """Fused attention backward -- lite multi-backend dispatcher. + + Signature matches the C++ tex.fused_attn_bwd binding (positional arg order). + Called from transformer_engine.pytorch.cpp_extensions.fused_attn.fused_attn_bwd. + + Returns [dQ, dK, dV, dBias, dSoftmaxOffset]. + """ + _try_load_aiter_attn() + + causal = _is_causal(attn_mask_type) + + if _aiter_varlen_fwd is not None: + backend = NVTE_Fused_Attn_Backend.NVTE_CK + elif _flash_attn_available: + backend = NVTE_Fused_Attn_Backend.NVTE_Flash + else: + backend = NVTE_Fused_Attn_Backend.NVTE_SDPA + + if backend == NVTE_Fused_Attn_Backend.NVTE_CK: + _attn_bump("bwd_aiter_ck") + softmax_lse = aux_ctx_tensors[0] if aux_ctx_tensors else None + rng_state = aux_ctx_tensors[1] if len(aux_ctx_tensors) > 1 else None + dq, dk, dv = _aiter_attn_bwd( + d_o, q, k, v, o, softmax_lse, rng_state, + cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, causal, + window_size, qkv_layout, deterministic, + cu_seqlens_q_padded, cu_seqlens_kv_padded, + ) + elif backend == NVTE_Fused_Attn_Backend.NVTE_Flash: + _attn_bump("bwd_flash_stub") + raise NotImplementedError( + "Flash-attention backward is stubbed in lite mode." 
+ ) + else: + _attn_bump("bwd_sdpa") + dq, dk, dv = _sdpa_attn_bwd( + d_o, q, k, v, o, + cu_seqlens_q, cu_seqlens_kv, + max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, causal, + window_size, None, qkv_layout, + ) + + # Return format matches C++ extension: [dQ, dK, dV, dBias, dSoftmaxOffset] + return [dq, dk, dv, None, None] + + +# --------------------------------------------------------------------------- +# QKV preparation (flash-attn interleaved format conversions) +# --------------------------------------------------------------------------- + +def fa_prepare_fwd(qkvi: torch.Tensor) -> torch.Tensor: + """Convert interleaved QKV from [s, b, n, 3*h] to [3, b, s, n, h]. + + Pure PyTorch replacement for the C++ nvte_prepare_flash_attn_fwd kernel. + """ + s, b, n, three_h = qkvi.shape + h = three_h // 3 + # Reshape to [s, b, n, 3, h] then permute to [3, b, s, n, h] + return qkvi.view(s, b, n, 3, h).permute(3, 1, 0, 2, 4).contiguous() + + +def fa_prepare_bwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Convert 3 x [s, b, n, h] to [b, s, n, 3*h]. + + Pure PyTorch replacement for the C++ nvte_prepare_flash_attn_bwd kernel. + """ + s, b, n, h = q.shape + # Stack on new dim -> [3, s, b, n, h], then permute to [b, s, n, 3, h], reshape + stacked = torch.stack([q, k, v], dim=0) # [3, s, b, n, h] + transposed = stacked.permute(2, 1, 3, 0, 4) # [b, s, n, 3, h] + return transposed.reshape(b, s, n, 3 * h).contiguous() + + +# --------------------------------------------------------------------------- +# KV cache operations +# --------------------------------------------------------------------------- + +def copy_to_kv_cache( + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + cu_new_lens: torch.Tensor, + cu_cached_lens: torch.Tensor, + qkv_format: int, + batch_size: int, + max_ctx_len: int, + max_seq_len: int, + max_pages_per_seq: int, + is_non_paged: bool, +) -> None: + """Copy new KV tokens into a KV cache. + + Pure PyTorch replacement for the C++ nvte_copy_to_kv_cache kernel. + Supports non-paged caches in BSHD and SBHD formats. + """ + new_lens = cu_new_lens[1:] - cu_new_lens[:-1] + cached_lens = cu_cached_lens[1:] - cu_cached_lens[:-1] + + # Determine format from enum value + is_sbhd = (qkv_format == 1) # NVTE_QKV_Format.NVTE_SBHD + + for b in range(batch_size): + nl = int(new_lens[b].item()) + cl = int(cached_lens[b].item()) + if nl == 0: + continue + + if is_sbhd: + # new_k/v: [seq, batch, heads, dim], cache: [seq, batch, heads, dim] + k_cache[cl:cl + nl, b] = new_k[:nl, b] + v_cache[cl:cl + nl, b] = new_v[:nl, b] + else: + # BSHD: new_k/v: [batch, seq, heads, dim], cache: same + k_cache[b, cl:cl + nl] = new_k[b, :nl] + v_cache[b, cl:cl + nl] = new_v[b, :nl] + + +# --------------------------------------------------------------------------- +# THD <-> BSHD format conversion +# --------------------------------------------------------------------------- + +def convert_thd_to_bshd( + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + batch_size: Optional[int], + max_seq_len: int, +) -> torch.Tensor: + """Convert tensor from THD [total, heads, dim] to BSHD [batch, seq, heads, dim]. + + Pure PyTorch replacement for the C++ nvte_convert_thd_to_bshd kernel. + Sequences shorter than max_seq_len are zero-padded. 
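+
+    Worked example (illustrative): with cu_seqlens = [0, 3, 5], max_seq_len = 4
+    and a THD input of shape [5, h, d], the result has shape [2, 4, h, d];
+    out[0, :3] and out[1, :2] hold the two sequences and the remaining
+    positions stay zero.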
+ """ + if batch_size is None: + batch_size = cu_seqlens.shape[0] - 1 + h, d = tensor.shape[1], tensor.shape[2] + out = tensor.new_zeros(batch_size, max_seq_len, h, d) + for b in range(batch_size): + start = int(cu_seqlens[b].item()) + end = int(cu_seqlens[b + 1].item()) + length = end - start + out[b, :length] = tensor[start:end] + return out + + +def convert_bshd_to_thd( + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + total: int, +) -> torch.Tensor: + """Convert tensor from BSHD [batch, seq, heads, dim] to THD [total, heads, dim]. + + Pure PyTorch replacement for the C++ nvte_convert_bshd_to_thd kernel. + Strips padding based on cu_seqlens. + """ + batch_size = tensor.shape[0] + h, d = tensor.shape[2], tensor.shape[3] + out = tensor.new_empty(total, h, d) + for b in range(batch_size): + start = int(cu_seqlens[b].item()) + end = int(cu_seqlens[b + 1].item()) + length = end - start + out[start:end] = tensor[b, :length] + return out diff --git a/transformer_engine/pytorch/_lite/comm.py b/transformer_engine/pytorch/_lite/comm.py new file mode 100644 index 000000000..24d740280 --- /dev/null +++ b/transformer_engine/pytorch/_lite/comm.py @@ -0,0 +1,110 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Communication overlap stubs. + +In lite mode, comm-overlap is not available. These stubs provide the API +surface so that imports succeed, but raise NotImplementedError if actually used. +Use torch.distributed for communication instead. +""" + +from .enums import CommOverlapCore + + +class CommOverlapBase(CommOverlapCore): + """Stub for CommOverlapBase.""" + pass + + +class CommOverlapP2PBase(CommOverlapCore): + """Stub for CommOverlapP2PBase.""" + pass + + +class CommOverlapHelper: + """Stub for CommOverlapHelper.""" + def __init__(self, *args, **kwargs): + pass + + +class CommOverlap(CommOverlapBase): + """Stub for CommOverlap.""" + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "CommOverlap is not available in lite mode. " + "Use torch.distributed for communication." + ) + + def copy_into_buffer(self, *args, **kwargs): + raise NotImplementedError + + def get_buffer(self, *args, **kwargs): + raise NotImplementedError + + def get_communication_stream(self, *args, **kwargs): + raise NotImplementedError + + +class CommOverlapP2P(CommOverlapP2PBase): + """Stub for CommOverlapP2P.""" + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "CommOverlapP2P is not available in lite mode. " + "Use torch.distributed for communication." 
+ ) + + def copy_into_buffer(self, *args, **kwargs): + raise NotImplementedError + + def get_buffer(self, *args, **kwargs): + raise NotImplementedError + + def get_communication_stream(self, *args, **kwargs): + raise NotImplementedError + + +def bulk_overlap_ag_with_external_gemm(*args, **kwargs): + """Stub: Bulk overlap AG with external GEMM.""" + raise NotImplementedError("Communication overlap not available in lite mode.") + + +def init_nvshmem_backend(*args, **kwargs): + """Stub: Initialize NVSHMEM/ROCSHMEM backend.""" + raise NotImplementedError("NVSHMEM/ROCSHMEM not available in lite mode.") + + +def create_nvshmem_tensor(*args, **kwargs): + """Stub: Create NVSHMEM/ROCSHMEM tensor.""" + raise NotImplementedError("NVSHMEM/ROCSHMEM not available in lite mode.") + + +def nvshmem_send_on_current_stream(*args, **kwargs): + """Stub: NVSHMEM send.""" + raise NotImplementedError("NVSHMEM/ROCSHMEM not available in lite mode.") + + +def nvshmem_wait_on_current_stream(*args, **kwargs): + """Stub: NVSHMEM wait.""" + raise NotImplementedError("NVSHMEM/ROCSHMEM not available in lite mode.") + + +def nvshmem_finalize(*args, **kwargs): + """Stub: NVSHMEM finalize.""" + raise NotImplementedError("NVSHMEM/ROCSHMEM not available in lite mode.") + + +def device_supports_multicast(device_id=-1): + """Stub: Check multicast support.""" + return False + + +def get_stream_priority_range(device_id=-1): + """Stub: Get stream priority range.""" + return (0, 0) + + +def ubuf_built_with_mpi(): + """Stub: Check if userbuffers built with MPI.""" + return False diff --git a/transformer_engine/pytorch/_lite/context_parallel.py b/transformer_engine/pytorch/_lite/context_parallel.py new file mode 100644 index 000000000..ff4e5f013 --- /dev/null +++ b/transformer_engine/pytorch/_lite/context_parallel.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Context parallel THD format helpers -- PyTorch-native tensor ops. + +TODO Phase 3: Implement as PyTorch tensor slicing operations. 
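+
+Intended semantics, sketched from the stub docstrings below and assuming a
+(tensor, cu_seqlens, half_idx) style signature (not the final implementation):
+each helper slices per-sequence halves out of a packed THD tensor using
+cu_seqlens. For example, thd_read_half_tensor with cu_seqlens = [0, 4, 10]
+and half_idx = 1 would return rows 2:4 and 7:10, i.e. the second half of
+each sequence.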
+""" + + +def thd_read_half_tensor(*args, **kwargs): + """Read first or second half of each sequence in a THD tensor.""" + raise NotImplementedError("thd_read_half_tensor not yet implemented in lite mode.") + + +def thd_second_half_lse_correction(*args, **kwargs): + """Correct the second half of softmax_lse.""" + raise NotImplementedError("thd_second_half_lse_correction not yet implemented in lite mode.") + + +def thd_read_second_half_lse(*args, **kwargs): + """Read the second half of softmax_lse.""" + raise NotImplementedError("thd_read_second_half_lse not yet implemented in lite mode.") + + +def thd_out_correction(*args, **kwargs): + """Correct THD format output of context parallelism in forward pass.""" + raise NotImplementedError("thd_out_correction not yet implemented in lite mode.") + + +def thd_grad_correction(*args, **kwargs): + """Correct THD format gradients of context parallelism in backward pass.""" + raise NotImplementedError("thd_grad_correction not yet implemented in lite mode.") + + +def thd_get_partitioned_indices(*args, **kwargs): + """Generate partitioned indices for inputs in THD format.""" + raise NotImplementedError("thd_get_partitioned_indices not yet implemented in lite mode.") diff --git a/transformer_engine/pytorch/_lite/dropout.py b/transformer_engine/pytorch/_lite/dropout.py new file mode 100644 index 000000000..148c7fd20 --- /dev/null +++ b/transformer_engine/pytorch/_lite/dropout.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Dropout -- PyTorch-native implementation.""" + +import torch +import torch.nn.functional as F + + +def dropout_fwd(input, dropout_probability, out=None): + """Dropout forward. + + Returns (output, mask) tuple. The C++ version uses 8-bit RNG masks; + here we use standard PyTorch boolean masks. + """ + if dropout_probability == 0.0: + mask = torch.ones_like(input, dtype=torch.uint8) + if out is not None: + out.copy_(input) + return out, mask + return input.clone(), mask + + keep_prob = 1.0 - dropout_probability + mask = (torch.rand_like(input) < keep_prob).to(torch.uint8) + output = input * mask.to(input.dtype) / keep_prob + + if out is not None: + out.copy_(output) + return out, mask + return output, mask + + +def dropout_bwd(grad_output, mask, dropout_probability, grad_input=None): + """Dropout backward.""" + if dropout_probability == 0.0: + if grad_input is not None: + grad_input.copy_(grad_output) + return grad_input + return grad_output.clone() + + keep_prob = 1.0 - dropout_probability + output = grad_output * mask.to(grad_output.dtype) / keep_prob + + if grad_input is not None: + grad_input.copy_(output) + return grad_input + return output diff --git a/transformer_engine/pytorch/_lite/enums.py b/transformer_engine/pytorch/_lite/enums.py new file mode 100644 index 000000000..6c05a5bc4 --- /dev/null +++ b/transformer_engine/pytorch/_lite/enums.py @@ -0,0 +1,167 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Pure-Python re-declarations of C++ enum types from transformer_engine_torch.""" + +import enum + + +class DType(enum.IntEnum): + """Data type enum matching transformer_engine::DType.""" + kByte = 0 + kInt32 = 1 + kFloat32 = 2 + kFloat16 = 3 + kBFloat16 = 4 + kFloat8E4M3 = 5 + kFloat8E5M2 = 6 + kFloat4E2M1 = 7 + + +class NVTE_Bias_Type(enum.IntEnum): + """Bias type for fused attention.""" + NVTE_NO_BIAS = 0 + NVTE_PRE_SCALE_BIAS = 1 + NVTE_POST_SCALE_BIAS = 2 + NVTE_ALIBI = 3 + + +class NVTE_Mask_Type(enum.IntEnum): + """Mask type for fused attention.""" + NVTE_NO_MASK = 0 + NVTE_PADDING_MASK = 1 + NVTE_CAUSAL_MASK = 2 + NVTE_PADDING_CAUSAL_MASK = 3 + NVTE_CAUSAL_BOTTOM_RIGHT_MASK = 4 + NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5 + + +class NVTE_Softmax_Type(enum.IntEnum): + """Softmax type for fused attention.""" + NVTE_VANILLA_SOFTMAX = 0 + NVTE_OFF_BY_ONE_SOFTMAX = 1 + NVTE_LEARNABLE_SOFTMAX = 2 + + +class NVTE_QKV_Format(enum.IntEnum): + """QKV tensor format for fused attention.""" + NVTE_BSHD = 0 + NVTE_SBHD = 1 + NVTE_THD = 2 + NVTE_SBHD_2BSHD = 3 + NVTE_BSHD_2SBHD = 4 + NVTE_THD_2BSHD = 5 + NVTE_THD_2SBHD = 6 + + +class NVTE_QKV_Layout(enum.IntEnum): + """QKV layout for fused attention.""" + NVTE_SB3HD = 0 + NVTE_SBH3D = 1 + NVTE_SBHD_SB2HD = 2 + NVTE_SBHD_SBH2D = 3 + NVTE_SBHD_SBHD_SBHD = 4 + NVTE_BS3HD = 5 + NVTE_BSH3D = 6 + NVTE_BSHD_BS2HD = 7 + NVTE_BSHD_BSH2D = 8 + NVTE_BSHD_BSHD_BSHD = 9 + NVTE_T3HD = 10 + NVTE_TH3D = 11 + NVTE_THD_T2HD = 12 + NVTE_THD_TH2D = 13 + NVTE_THD_THD_THD = 14 + NVTE_SBHD_BSHD_BSHD = 15 + NVTE_BSHD_SBHD_SBHD = 16 + NVTE_THD_BSHD_BSHD = 17 + NVTE_THD_SBHD_SBHD = 18 + NVTE_Paged_KV_BSHD_BSHD_BSHD = 19 + NVTE_Paged_KV_BSHD_SBHD_SBHD = 20 + NVTE_Paged_KV_SBHD_BSHD_BSHD = 21 + NVTE_Paged_KV_SBHD_SBHD_SBHD = 22 + NVTE_Paged_KV_THD_BSHD_BSHD = 23 + NVTE_Paged_KV_THD_SBHD_SBHD = 24 + + +class NVTE_Fused_Attn_Backend(enum.IntEnum): + """Fused attention backend selection (ROCm values).""" + NVTE_AOTriton = 0 + NVTE_CK = 1 + NVTE_No_Backend = 2 + # Lite-mode additions + NVTE_SDPA = 100 + NVTE_Flash = 101 + # Included for API parity with the full build. Lite does not actually + # implement an FP8 attention kernel — get_fused_attn_backend raises + # NotImplementedError when FP8 inputs are requested (fp8_dpa=True). 
+ NVTE_FP8 = 200 + + +class Float8BlockScaleTensorFormat(enum.IntEnum): + """Block scale tensor format.""" + GEMM_READY = 0 + COMPACT = 1 + + +class FP8FwdTensors(enum.IntEnum): + """FP8 forward tensor indices.""" + GEMM1_INPUT = 0 + GEMM1_WEIGHT = 1 + GEMM1_OUTPUT = 2 + GEMM2_INPUT = 3 + GEMM2_WEIGHT = 4 + GEMM2_OUTPUT = 5 + GEMM3_INPUT = 6 + GEMM3_WEIGHT = 7 + GEMM3_OUTPUT = 8 + + +class FP8BwdTensors(enum.IntEnum): + """FP8 backward tensor indices.""" + GRAD_OUTPUT1 = 0 + GRAD_INPUT1 = 1 + GRAD_OUTPUT2 = 2 + GRAD_INPUT2 = 3 + GRAD_OUTPUT3 = 4 + GRAD_INPUT3 = 5 + + +class CommOverlapType(enum.IntEnum): + """Communication overlap type.""" + RS = 0 + AG = 1 + + +class CommOverlapAlgo(enum.IntEnum): + """Communication overlap algorithm.""" + BULK_OVERLAP_AG = 0 + BULK_OVERLAP_RS = 1 + SPLIT_PIPELINED_AG_P2P = 2 + SPLIT_PIPELINED_RS = 3 + SPLIT_PIPELINED_RS_P2P = 4 + ATOMIC_GEMM_RS = 5 + ATOMIC_GEMM_AG_P2P = 6 + ATOMIC_GEMM_RS_P2P = 7 + EXTERNAL_BULK_OVERLAP_AG = 8 + + +class FP8TensorMeta: + """FP8 tensor metadata (pure Python replacement).""" + def __init__(self): + self.scale = None + self.scale_inv = None + self.amax_history = None + + +class CommOverlapCore: + """Stub for CommOverlapCore.""" + def is_atomic_gemm(self): + return False + + def is_p2p_overlap(self): + return False + + def is_fp8_ubuf(self): + return False diff --git a/transformer_engine/pytorch/_lite/fused_layernorm_linear.py b/transformer_engine/pytorch/_lite/fused_layernorm_linear.py new file mode 100644 index 000000000..3fe1614e9 --- /dev/null +++ b/transformer_engine/pytorch/_lite/fused_layernorm_linear.py @@ -0,0 +1,605 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Lite-native LayerNormLinear: fused normalization + linear projection.""" + +import os +from typing import Callable, Optional, Tuple, Union, List + +import torch + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.quantized_tensor import ( + QuantizedTensor, + QuantizedTensorStorage, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from transformer_engine.pytorch.utils import ( + cast_if_needed, + get_default_init_method, + init_method_constant, +) + +from .amax_utils import update_amax_from_bf16 +from .gemm import _gemm_bump + + +__all__ = ["LayerNormLinear"] + + +# Opt-in: skip the FP8 cast on the dgrad GEMM output when the only consumer +# is the norm backward (which dequantizes immediately). Eliminates the +# BF16 -> FP8 -> BF16 round-trip; preserves DelayedScaling amax bookkeeping +# via a standalone reduction. See amax_utils.update_amax_from_bf16. +_SKIP_FP8_DGRAD_FOR_NORM = ( + os.environ.get("NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM", "0") != "0" +) + + +def _can_skip_dgrad_cast(fp8, requires_dgrad, grad_input_quantizer): + """Whether the bwd dgrad path can emit BF16 instead of FP8. + + Decoupled from the specific ctx attribute name so both LayerNormLinear + (ctx.grad_input_quantizer) and LayerNormMLP (ctx.fc1_grad_input_quantizer) + can share the predicate. 
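+
+    With NVTE_LITE_SKIP_FP8_DGRAD_FOR_NORM=1, the shortcut applies only to
+    Float8Quantizer / Float8CurrentScalingQuantizer grad-input quantizers;
+    MXFP8 and other block-scaled quantizers keep the regular FP8 cast path.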
+ """ + if not _SKIP_FP8_DGRAD_FOR_NORM: + return False + if not (fp8 and requires_dgrad): + return False + if grad_input_quantizer is None: + return False + # MXFP8 uses block scales computed at quantize time; amax-only shortcut + # doesn't reconstruct the per-block state. Scope to per-tensor/per-row. + return type(grad_input_quantizer).__name__ in ( + "Float8Quantizer", "Float8CurrentScalingQuantizer", + ) + + +def _get_normalization_funcs(normalization: str): + """Return (fwd_func, bwd_func) for the given normalization type.""" + if normalization == "RMSNorm": + return tex.rmsnorm_fwd, tex.rmsnorm_bwd + elif normalization == "LayerNorm": + return tex.layernorm_fwd, tex.layernorm_bwd + else: + raise ValueError(f"Unsupported normalization: {normalization}") + + +class _LayerNormLinearLite(torch.autograd.Function): + """Autograd function for fused LayerNorm + Linear (lite backend).""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + ln_weight: torch.Tensor, + ln_bias: Optional[torch.Tensor], + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + fp8: bool, + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + activation_dtype: torch.dtype, + return_layernorm_output: bool, + normalization: str, + zero_centered_gamma: bool, + is_grad_enabled: bool, + module: "LayerNormLinear", + is_first_microbatch: Optional[bool], + ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + + # Reshape input + in_features = weight.shape[1] + out_features = weight.shape[0] + inp_shape = inp.shape + inputmat = inp.reshape(-1, in_features) + + # Cast for native AMP + inputmat = cast_if_needed(inputmat, activation_dtype) + ln_weight = cast_if_needed(ln_weight, activation_dtype) + if ln_bias is not None: + ln_bias = cast_if_needed(ln_bias, activation_dtype) + + # Configure norm quantizer + backward_needs_input = is_grad_enabled and weight.requires_grad + if fp8 and input_quantizer is not None: + input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + + # Determine if we can use fused norm+quantize + with_quantized_norm = ( + fp8 + and input_quantizer is not None + and not return_layernorm_output + ) + + # Apply normalization + norm_fwd, _ = _get_normalization_funcs(normalization) + if normalization == "LayerNorm": + ln_out, mu, rsigma = norm_fwd( + inputmat, ln_weight, ln_bias, eps, + None, # ln_out (allocate internally) + input_quantizer if with_quantized_norm else None, + inputmat.dtype, + 0, # sm_margin (unused in lite) + zero_centered_gamma, + ) + else: # RMSNorm + ln_out, mu, rsigma = norm_fwd( + inputmat, ln_weight, eps, + None, # ln_out + input_quantizer if with_quantized_norm else None, + inputmat.dtype, + 0, # sm_margin + zero_centered_gamma, + ) + + # Save unquantized norm output if needed for return + ln_out_return = ln_out if return_layernorm_output else None + + # Quantize norm output if not already done via fused kernel. + # Set columnwise usage up front so the quantizer allocates the + # transpose buffer — required for the Triton cast-transpose path + # (otherwise _transpose_invalid=True forces the PyTorch fallback). 
+ if fp8 and not with_quantized_norm and input_quantizer is not None: + input_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + ln_out = input_quantizer(ln_out) + + # Prepare weight + weightmat = weight + if fp8 and weight_quantizer is not None: + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + ) + weightmat.update_usage(rowwise_usage=True) + else: + weightmat = cast_if_needed(weightmat, activation_dtype) + + # Prepare bias + gemm_bias = cast_if_needed(bias, activation_dtype) if bias is not None else bias + + # Forward GEMM: y = ln_out @ weight^T + bias + bias_dtype = TE_DType[torch.bfloat16 if gemm_bias is None else gemm_bias.dtype] + gemm_out, _, _, _ = tex.generic_gemm( + weightmat, # A + True, # transA (weight is [out, in], need transpose) + ln_out, # B + False, # transB + None, # D (allocate internally) + None, # quantizer + TE_DType[activation_dtype] if activation_dtype in TE_DType else None, + gemm_bias, # bias + bias_dtype, # bias_type (actually bias dtype) + False, # gelu + None, # gelu_in + False, # grad + torch.empty(0), # workspace (unused in lite) + 0, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + + out = gemm_out.view(-1, *inp_shape[1:-1], out_features) + + # Save tensors for backward + if is_grad_enabled: + tensors_to_save, tensor_objects = prepare_for_saving( + inputmat, weightmat, weight, bias, ln_weight, ln_out, mu, rsigma, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + ctx.inp_shape = inp_shape + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.normalization = normalization + ctx.zero_centered_gamma = zero_centered_gamma + ctx.use_bias = bias is not None + ctx.requires_dgrad = inp.requires_grad + ctx.requires_wgrad = weight.requires_grad + ctx.input_quantizer = input_quantizer + ctx.weight_quantizer = weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer + ctx.return_layernorm_output = return_layernorm_output + + if return_layernorm_output: + return out, ln_out_return.view(inp_shape) + return out + + @staticmethod + def backward(ctx, *grad_outputs): + grad_output = grad_outputs[0] + + saved_tensors = ctx.saved_tensors + ( + inputmat, + weightmat, + weight, + bias, + ln_weight, + ln_out, + mu, + rsigma, + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + ctx.tensor_objects = None + + # Prepare grad_output + grad_output = grad_output.reshape(-1, weight.shape[0]) + grad_output = cast_if_needed(grad_output, ctx.activation_dtype) + + # Quantize grad_output for FP8 backward + if ctx.fp8 and ctx.grad_output_quantizer is not None: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + grad_output = ctx.grad_output_quantizer(grad_output) + + # ---- DGRAD: d_ln_out = grad_output @ weight (NN layout) ---- + d_ln_out = None + if ctx.requires_dgrad: + # Configure grad_input quantizer for dgrad output + skip_cast = _can_skip_dgrad_cast( + ctx.fp8, ctx.requires_dgrad, ctx.grad_input_quantizer, + ) + dgrad_quantizer = None + if not skip_cast and ctx.fp8 and ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + 
dgrad_quantizer = ctx.grad_input_quantizer + + bias_dtype = TE_DType[torch.bfloat16] + d_ln_out, _, _, _ = tex.generic_gemm( + weightmat, # A (weight) + False, # transA=False (NN layout) + grad_output, # B + False, # transB + None, # D + dgrad_quantizer, # quantizer — FP8 dgrad output (None when skip_cast) + TE_DType[ctx.activation_dtype] if ctx.activation_dtype in TE_DType else None, + None, # bias + bias_dtype, # bias_type + False, # gelu + None, # gelu_in + False, # grad + torch.empty(0), # workspace + 0, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + + if skip_cast: + # d_ln_out stays BF16; norm bwd consumes it directly. + # Replace the amax side-effect of the FP8 cast with a + # standalone reduction so DelayedScaling history stays current. + update_amax_from_bf16(ctx.grad_input_quantizer, d_ln_out) + _gemm_bump("dgrad_skip_fp8_cast") + + # ---- WGRAD: dW = grad_output^T @ ln_out (NT layout) ---- + dweight = None + dbias = None + if ctx.requires_wgrad: + # Re-quantize ln_out with columnwise usage for NT wgrad GEMM + if ctx.fp8 and ctx.input_quantizer is not None: + if isinstance(ln_out, QuantizedTensorStorage): + ln_out.update_usage(columnwise_usage=True) + else: + ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out = ctx.input_quantizer(ln_out) + + # Configure grad_output with columnwise usage for wgrad + if ctx.fp8 and ctx.grad_output_quantizer is not None: + if isinstance(grad_output, QuantizedTensorStorage): + grad_output.update_usage(columnwise_usage=True) + else: + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) + grad_output = ctx.grad_output_quantizer(grad_output) + + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] + dweight, dbias_gemm, _, _ = tex.generic_gemm( + ln_out, # A (input for wgrad) + False, # transA (N) + grad_output, # B (grad output) + True, # transB (T) → NT layout + None, # D + ctx.grad_weight_quantizer, # quantizer — FP8 wgrad output + TE_DType[ctx.activation_dtype] if ctx.activation_dtype in TE_DType else None, + bias if ctx.use_bias else None, # bias (for grad computation) + bias_dtype, # bias_type + False, # gelu + None, # gelu_in + True, # grad (compute bias gradient) + torch.empty(0), # workspace + 0, # workspace_size + False, # accumulate + False, # use_split_accumulator + ) + if ctx.use_bias: + dbias = dbias_gemm + + # ---- Norm backward ---- + dgrad = None + dgamma = None + dbeta = None + if ctx.requires_dgrad: + _, norm_bwd = _get_normalization_funcs(ctx.normalization) + if ctx.normalization == "LayerNorm": + dgrad, dgamma, dbeta = norm_bwd( + d_ln_out, inputmat, mu, rsigma, ln_weight, + 0, # sm_margin + ctx.zero_centered_gamma, + ) + else: # RMSNorm + dgrad, dgamma = norm_bwd( + d_ln_out, inputmat, rsigma, ln_weight, + 0, # sm_margin + ctx.zero_centered_gamma, + ) + + dgrad = dgrad.view(ctx.inp_shape) + + # Return gradients matching forward signature + return ( + dgrad, # inp + dgamma, # ln_weight + dbeta, # ln_bias + dweight, # weight + dbias, # bias + None, # eps + None, # fp8 + None, # input_quantizer + None, # weight_quantizer + None, # grad_output_quantizer + None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # activation_dtype + None, # return_layernorm_output + None, # normalization + None, # zero_centered_gamma + None, # is_grad_enabled + None, # module + None, # is_first_microbatch + ) + + +class LayerNormLinear(TransformerEngineBaseModule): + """Fused LayerNorm + Linear (lite-native, single-node). 
+ + Applies normalization followed by a linear transformation: + y = weight @ norm(x) + bias + + Parameters + ---------- + in_features : int + Input feature dimension (also the normalization dimension). + out_features : int + Output feature dimension. + eps : float, default = 1e-5 + Epsilon for normalization stability. + bias : bool, default = True + Whether to include a bias term in the linear layer. + normalization : str, default = "LayerNorm" + Type of normalization: "LayerNorm" or "RMSNorm". + init_method : callable, optional + Weight initialization function. + params_dtype : torch.dtype, optional + Data type for parameters (default: current default dtype). + zero_centered_gamma : bool, default = False + If True, gamma is initialized to zero and used as (1 + gamma). + return_layernorm_output : bool, default = False + If True, also return the normalization output. + device : str or torch.device, default = "cuda" + Device for parameters. + """ + + def __init__( + self, + in_features: int, + out_features: int, + eps: float = 1e-5, + bias: bool = True, + normalization: str = "LayerNorm", + init_method: Optional[Callable] = None, + params_dtype: Optional[torch.dtype] = None, + zero_centered_gamma: bool = False, + return_layernorm_output: bool = False, + device: Union[torch.device, str] = "cuda", + # FSDP2 per-parameter sharding: wrap weights in FSDPAGTensor so the + # quantizer runs at all-gather time instead of at init (ROCm only). + use_fsdp2: bool = False, + keep_fp8_weight_transpose_cache: bool = True, + # Accepted for API compatibility with full-build LayerNormLinear but + # ignored in lite mode (no TP/SP/userbuffers support): + return_bias: bool = False, + parallel_mode: Optional[str] = None, + sequence_parallel: bool = False, + tp_group=None, + tp_size: int = 1, + parameters_split: Optional[Union[tuple, dict]] = None, + **kwargs, + ) -> None: + super().__init__() + + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.in_features = in_features + self.out_features = out_features + self.eps = eps + self.use_bias = bias + self.normalization = normalization + assert normalization in ("LayerNorm", "RMSNorm"), "Unsupported normalization type!" + self.zero_centered_gamma = zero_centered_gamma + self.return_layernorm_output = return_layernorm_output + + # FSDP2 flags must be set before register_parameter/reset_parameters + # so the inherited base-class wrap logic (module/base.py) sees them. + # The wrap also requires IS_HIP_EXTENSION; mirror the full build's + # gate so non-ROCm runs silently ignore the flag. 
+ from torch.utils.cpp_extension import IS_HIP_EXTENSION as _IS_HIP + self.use_fsdp2 = use_fsdp2 if _IS_HIP else False + self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache + + # No TP/SP in lite + self.tp_size = 1 + self.sequence_parallel = False + self.set_tensor_parallel_group(None) + + if init_method is None: + init_method = get_default_init_method() + + # Norm parameters + layer_norm_weight = torch.nn.Parameter( + torch.empty(in_features, device=device, dtype=params_dtype) + ) + self.register_parameter( + "layer_norm_weight", + layer_norm_weight, + init_fn=init_method_constant(float(not zero_centered_gamma)), + ) + if normalization != "RMSNorm": + layer_norm_bias = torch.nn.Parameter( + torch.empty(in_features, device=device, dtype=params_dtype) + ) + self.register_parameter( + "layer_norm_bias", + layer_norm_bias, + init_fn=init_method_constant(0.0), + ) + else: + self.layer_norm_bias = None + + # Linear parameters + weight_tensor = torch.empty( + out_features, in_features, device=device, dtype=params_dtype, + ) + self.weight_names = ["weight"] + self.bias_names = ["bias"] + self.parameter_split_sizes = [out_features] + + self.register_parameter( + "weight", + torch.nn.Parameter(weight_tensor), + init_fn=init_method, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + + if self.use_bias: + self.register_parameter( + "bias", + torch.nn.Parameter( + torch.empty(out_features, device=device, dtype=params_dtype) + ), + init_fn=init_method_constant(0.0), + ) + else: + self.bias = torch.Tensor().to(dtype=params_dtype, device=device) + + with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() + if with_fp8_params: + self.init_fp8_metadata() + + self.reset_parameters(defer_init=(device == "meta")) + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: + w = getattr(self, "weight") + if isinstance(w, QuantizedTensor) and self.fp8: + return [w.get_quantized_tensor()] + return [w] + + def _get_weight_quantizers(self) -> List[Quantizer]: + if not self.fp8 and not self.fp8_calibration: + return [None] + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + return [weight_quantizer] + + def _get_quantizers(self, fp8_output: bool = False): + if not self.fp8: + return (None, None, None, None, None) + input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = True + (weight_quantizer,) = self._get_weight_quantizers() + grad_output_quantizer = None + grad_input_quantizer = None + grad_weight_quantizer = None + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer.internal = True + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + return ( + input_quantizer, weight_quantizer, + grad_output_quantizer, grad_input_quantizer, grad_weight_quantizer, + ) + + def set_meta_tensor(self, fwd: bool, recipe) -> None: + super().set_meta_tensor(fwd, recipe) + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + elif recipe.float8_block_scaling(): + self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + + def _customize_quantizers_float8_current_scaling(self, fwd, recipe): + if fwd: + for idx in (tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_WEIGHT): + if idx in self.quantizers["scaling_fwd"]: + q = 
self.quantizers["scaling_fwd"][idx] + if hasattr(recipe, 'fp8_quant_fwd_inp'): + q.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + q.amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + + def _customize_quantizers_float8_blockwise_scaling(self, fwd, recipe): + pass # Block scaling quantizers work with defaults + + def forward( + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None, + fp8_output: bool = False, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + with self.prepare_forward(inp, num_gemms=1): + ( + input_quantizer, + weight_quantizer, + grad_output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + ) = self._get_quantizers() + + out = _LayerNormLinearLite.apply( + inp, + self.layer_norm_weight, + self.layer_norm_bias, + self.weight, + self.bias if self.use_bias else None, + self.eps, + self.fp8, + input_quantizer, + weight_quantizer, + grad_output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + self.activation_dtype, + self.return_layernorm_output, + self.normalization, + self.zero_centered_gamma, + torch.is_grad_enabled(), + self, + is_first_microbatch, + ) + + return out diff --git a/transformer_engine/pytorch/_lite/fused_layernorm_mlp.py b/transformer_engine/pytorch/_lite/fused_layernorm_mlp.py new file mode 100644 index 000000000..4bf2f3c72 --- /dev/null +++ b/transformer_engine/pytorch/_lite/fused_layernorm_mlp.py @@ -0,0 +1,704 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Lite-native LayerNormMLP: fused normalization + two-layer MLP.""" + +from typing import Callable, Dict, Optional, Tuple, Union, List + +import torch +from torch.nn import Parameter + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.quantized_tensor import ( + QuantizedTensor, + QuantizedTensorStorage, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from transformer_engine.pytorch.utils import ( + cast_if_needed, + get_default_init_method, + init_method_constant, +) + +from .amax_utils import update_amax_from_bf16 +from .fused_layernorm_linear import _can_skip_dgrad_cast, _get_normalization_funcs +from .gemm import _gemm_bump + + +__all__ = ["LayerNormMLP"] + + +_GATED_ACTIVATIONS = frozenset({ + "geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu", +}) + +# Maps activation name → (forward_fn, backward_fn, fused_dbias_dact_fn_or_None) +_ACT_FUNC_MAP = { + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "reglu": (tex.reglu, tex.dreglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + "sreglu": (tex.sreglu, tex.dsreglu, None), + "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None), +} + + +def _gemm(A, transA, B, transB, bias, grad, quantizer=None, output_dtype=None, # noqa: E501 + gelu=False, gelu_in=None, accumulate=False): + """Thin wrapper around _lite generic_gemm with sane defaults.""" + bias_dtype 
= TE_DType[torch.bfloat16 if bias is None else bias.dtype] + return tex.generic_gemm( + A, transA, B, transB, + None, # D (allocate internally) + quantizer, + output_dtype, + bias, + bias_dtype, + gelu, + gelu_in, + grad, + torch.empty(0), # workspace (unused in lite) + 0, # workspace_size + accumulate, + False, # use_split_accumulator + ) + + +class _LayerNormMLPLite(torch.autograd.Function): + """Autograd function for fused LayerNorm + MLP (lite backend).""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + ln_weight: torch.Tensor, + ln_bias: Optional[torch.Tensor], + fc1_weight: torch.Tensor, + fc1_bias: Optional[torch.Tensor], + fc2_weight: torch.Tensor, + fc2_bias: Optional[torch.Tensor], + eps: float, + fp8: bool, + fc1_input_quantizer: Optional[Quantizer], + fc1_weight_quantizer: Optional[Quantizer], + fc2_input_quantizer: Optional[Quantizer], + fc2_weight_quantizer: Optional[Quantizer], + fc2_grad_output_quantizer: Optional[Quantizer], + fc1_grad_output_quantizer: Optional[Quantizer], + fc1_grad_input_quantizer: Optional[Quantizer], + fc1_grad_weight_quantizer: Optional[Quantizer], + activation_dtype: torch.dtype, + return_layernorm_output: bool, + normalization: str, + zero_centered_gamma: bool, + activation: str, + activation_params: Optional[Dict], + is_grad_enabled: bool, + module: "LayerNormMLP", + is_first_microbatch: Optional[bool], + ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + + act_fwd, act_bwd, dbias_dact = _ACT_FUNC_MAP[activation] + is_gated = activation in _GATED_ACTIVATIONS + + # Reshape input + hidden_size = fc1_weight.shape[1] + inp_shape = inp.shape + inputmat = inp.reshape(-1, hidden_size) + + # Cast for native AMP + inputmat = cast_if_needed(inputmat, activation_dtype) + ln_weight = cast_if_needed(ln_weight, activation_dtype) + if ln_bias is not None: + ln_bias = cast_if_needed(ln_bias, activation_dtype) + + # Configure norm quantizer + backward_needs_input = is_grad_enabled and fc1_weight.requires_grad + if fp8 and fc1_input_quantizer is not None: + fc1_input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + + with_quantized_norm = ( + fp8 + and fc1_input_quantizer is not None + and not return_layernorm_output + ) + + # ---- Normalization ---- + norm_fwd, _ = _get_normalization_funcs(normalization) + if normalization == "LayerNorm": + ln_out, mu, rsigma = norm_fwd( + inputmat, ln_weight, ln_bias, eps, None, + fc1_input_quantizer if with_quantized_norm else None, + inputmat.dtype, 0, zero_centered_gamma, + ) + else: + ln_out, mu, rsigma = norm_fwd( + inputmat, ln_weight, eps, None, + fc1_input_quantizer if with_quantized_norm else None, + inputmat.dtype, 0, zero_centered_gamma, + ) + + ln_out_return = ln_out if return_layernorm_output else None + + # Quantize norm output if not already fused + if fp8 and not with_quantized_norm and fc1_input_quantizer is not None: + ln_out = fc1_input_quantizer(ln_out) + + # ---- Prepare FC1 weight ---- + fc1_weightmat = fc1_weight + if fp8 and fc1_weight_quantizer is not None: + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + update = is_first_microbatch is None or is_first_microbatch + fc1_weightmat = module.get_weight_workspace( + tensor=fc1_weight, quantizer=fc1_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc1_weight"), + update_workspace=update, + ) + fc1_weightmat.update_usage(rowwise_usage=True) + else: + fc1_weightmat = cast_if_needed(fc1_weightmat, activation_dtype) + + # ---- FC1 GEMM ---- + fc1_bias_cast = 
cast_if_needed(fc1_bias, activation_dtype) if fc1_bias is not None else None + out_dtype = TE_DType[activation_dtype] if activation_dtype in TE_DType else None + + fc1_out, _, gelu_input, _ = _gemm( + fc1_weightmat, True, ln_out, False, + bias=fc1_bias_cast, grad=False, output_dtype=out_dtype, + ) + + # ---- Activation ---- + act_kwargs = activation_params or {} + act_out = act_fwd(fc1_out, fc2_input_quantizer if fp8 else None, **act_kwargs) + + # ---- Prepare FC2 weight ---- + fc2_weightmat = fc2_weight + if fp8 and fc2_weight_quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + update = is_first_microbatch is None or is_first_microbatch + fc2_weightmat = module.get_weight_workspace( + tensor=fc2_weight, quantizer=fc2_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc2_weight"), + update_workspace=update, + ) + fc2_weightmat.update_usage(rowwise_usage=True) + else: + fc2_weightmat = cast_if_needed(fc2_weightmat, activation_dtype) + + # ---- FC2 GEMM ---- + fc2_bias_cast = cast_if_needed(fc2_bias, activation_dtype) if fc2_bias is not None else None + fc2_out, _, _, _ = _gemm( + fc2_weightmat, True, act_out, False, + bias=fc2_bias_cast, grad=False, output_dtype=out_dtype, + ) + + out = fc2_out.view(-1, *inp_shape[1:-1], hidden_size) + + # ---- Save for backward ---- + if is_grad_enabled: + tensors_to_save, tensor_objects = prepare_for_saving( + inputmat, + ln_weight, + ln_out, + fc1_weightmat, fc1_weight, fc1_bias, + fc1_out, + act_out, + fc2_weightmat, fc2_weight, fc2_bias, + mu, rsigma, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + ctx.inp_shape = inp_shape + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.normalization = normalization + ctx.zero_centered_gamma = zero_centered_gamma + ctx.activation = activation + ctx.activation_params = activation_params or {} + ctx.use_fc1_bias = fc1_bias is not None + ctx.use_fc2_bias = fc2_bias is not None + ctx.requires_dgrad = inp.requires_grad + ctx.requires_wgrad = fc1_weight.requires_grad + ctx.fc1_input_quantizer = fc1_input_quantizer + ctx.fc2_input_quantizer = fc2_input_quantizer + ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer + ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer + ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer + ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer + ctx.return_layernorm_output = return_layernorm_output + + if return_layernorm_output: + return out, ln_out_return.view(inp_shape) + return out + + @staticmethod + def backward(ctx, *grad_outputs): + grad_output = grad_outputs[0] + + saved_tensors = ctx.saved_tensors + ( + inputmat, + ln_weight, + ln_out, + fc1_weightmat, fc1_weight, fc1_bias, + fc1_out, + act_out, + fc2_weightmat, fc2_weight, fc2_bias, + mu, rsigma, + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + ctx.tensor_objects = None + + act_fwd, act_bwd, dbias_dact = _ACT_FUNC_MAP[ctx.activation] + out_dtype = TE_DType[ctx.activation_dtype] if ctx.activation_dtype in TE_DType else None + hidden_size = fc1_weight.shape[1] + + grad_output = grad_output.reshape(-1, hidden_size) + grad_output = cast_if_needed(grad_output, ctx.activation_dtype) + + # Quantize grad_output for FP8 (rowwise for FC2 dgrad, columnwise for FC2 wgrad) + if ctx.fp8 and ctx.fc2_grad_output_quantizer is not None: + ctx.fc2_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + grad_output = ctx.fc2_grad_output_quantizer(grad_output) + + # ---- FC2 DGRAD: d_act 
= grad_output @ fc2_weight ---- + d_act, _, _, _ = _gemm( + fc2_weightmat, False, grad_output, False, + bias=None, grad=False, output_dtype=out_dtype, + ) + + # ---- FC2 WGRAD: dW2 = grad_output^T @ act_out (NT layout) ---- + dfc2_weight = None + dfc2_bias = None + if ctx.requires_wgrad: + # Re-quantize act_out with columnwise usage for NT wgrad GEMM + if ctx.fp8 and ctx.fc2_input_quantizer is not None: + if isinstance(act_out, QuantizedTensorStorage): + act_out.update_usage(columnwise_usage=True) + else: + ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) + act_out = ctx.fc2_input_quantizer(act_out) + + # Ensure grad_output has columnwise usage for wgrad + if ctx.fp8 and ctx.fc2_grad_output_quantizer is not None: + if isinstance(grad_output, QuantizedTensorStorage): + grad_output.update_usage(columnwise_usage=True) + + dfc2_weight, dfc2_bias_grad, _, _ = _gemm( + act_out, False, grad_output, True, + bias=fc2_bias if ctx.use_fc2_bias else None, + grad=True, output_dtype=out_dtype, + ) + if ctx.use_fc2_bias: + dfc2_bias = dfc2_bias_grad + + # ---- Activation backward + FC1 bias grad ---- + dfc1_bias = None + if dbias_dact is not None and ctx.use_fc1_bias: + # Fused bias gradient + activation backward + dfc1_out, dfc1_bias = dbias_dact(d_act, fc1_out, None, **ctx.activation_params) + else: + # Separate activation backward + dfc1_out = act_bwd(d_act, fc1_out, None, **ctx.activation_params) + if ctx.use_fc1_bias: + dfc1_bias = dfc1_out.reshape(-1, dfc1_out.shape[-1]).sum(dim=0) + + # Quantize dfc1_out (fc1_grad_output) for FC1 GEMMs + if ctx.fp8 and ctx.fc1_grad_output_quantizer is not None: + ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + dfc1_out = ctx.fc1_grad_output_quantizer(dfc1_out) + + # ---- FC1 DGRAD: d_ln_out = dfc1_out @ fc1_weight ---- + d_ln_out = None + if ctx.requires_dgrad: + # Quantize FC1 dgrad output + skip_cast = _can_skip_dgrad_cast( + ctx.fp8, ctx.requires_dgrad, ctx.fc1_grad_input_quantizer, + ) + dgrad_quantizer = None + if not skip_cast and ctx.fp8 and ctx.fc1_grad_input_quantizer is not None: + ctx.fc1_grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + dgrad_quantizer = ctx.fc1_grad_input_quantizer + + d_ln_out, _, _, _ = _gemm( + fc1_weightmat, False, dfc1_out, False, + bias=None, grad=False, quantizer=dgrad_quantizer, output_dtype=out_dtype, + ) + + if skip_cast: + # d_ln_out stays BF16; norm bwd consumes it directly. 
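+                # As in the LayerNormLinear path, record amax = max(|d_ln_out|)
+                # so the DelayedScaling history for the grad-input quantizer
+                # stays current even though the FP8 cast was skipped.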
+ update_amax_from_bf16(ctx.fc1_grad_input_quantizer, d_ln_out) + _gemm_bump("dgrad_skip_fp8_cast") + + # ---- FC1 WGRAD: dW1 = dfc1_out^T @ ln_out (NT layout) ---- + dfc1_weight = None + if ctx.requires_wgrad: + # Re-quantize ln_out with columnwise usage for NT wgrad GEMM + if ctx.fp8 and ctx.fc1_input_quantizer is not None: + if isinstance(ln_out, QuantizedTensorStorage): + ln_out.update_usage(columnwise_usage=True) + else: + ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) + ln_out = ctx.fc1_input_quantizer(ln_out) + + # Ensure dfc1_out has columnwise usage for wgrad + if ctx.fp8 and ctx.fc1_grad_output_quantizer is not None: + if isinstance(dfc1_out, QuantizedTensorStorage): + dfc1_out.update_usage(columnwise_usage=True) + + dfc1_weight, _, _, _ = _gemm( + ln_out, False, dfc1_out, True, + bias=None, grad=False, quantizer=ctx.fc1_grad_weight_quantizer, + output_dtype=out_dtype, + ) + + # ---- Norm backward ---- + dgrad = None + dgamma = None + dbeta = None + if ctx.requires_dgrad: + _, norm_bwd = _get_normalization_funcs(ctx.normalization) + if ctx.normalization == "LayerNorm": + dgrad, dgamma, dbeta = norm_bwd( + d_ln_out, inputmat, mu, rsigma, ln_weight, + 0, ctx.zero_centered_gamma, + ) + else: + dgrad, dgamma = norm_bwd( + d_ln_out, inputmat, rsigma, ln_weight, + 0, ctx.zero_centered_gamma, + ) + dgrad = dgrad.view(ctx.inp_shape) + + # Return gradients matching forward signature order + return ( + dgrad, # inp + dgamma, # ln_weight + dbeta, # ln_bias + dfc1_weight, # fc1_weight + dfc1_bias, # fc1_bias + dfc2_weight, # fc2_weight + dfc2_bias, # fc2_bias + None, # eps + None, # fp8 + None, # fc1_input_quantizer + None, # fc1_weight_quantizer + None, # fc2_input_quantizer + None, # fc2_weight_quantizer + None, # fc2_grad_output_quantizer + None, # fc1_grad_output_quantizer + None, # fc1_grad_input_quantizer + None, # fc1_grad_weight_quantizer + None, # activation_dtype + None, # return_layernorm_output + None, # normalization + None, # zero_centered_gamma + None, # activation + None, # activation_params + None, # is_grad_enabled + None, # module + None, # is_first_microbatch + ) + + +class LayerNormMLP(TransformerEngineBaseModule): + """Fused LayerNorm + MLP (lite-native, single-node). + + Applies normalization followed by a two-layer MLP: + y = fc2(act(fc1(norm(x)))) + + Parameters + ---------- + hidden_size : int + Input and output feature dimension. + ffn_hidden_size : int + Intermediate (FC1 output) feature dimension. + eps : float, default = 1e-5 + Epsilon for normalization stability. + bias : bool, default = True + Whether to include bias terms in linear layers. + normalization : str, default = "LayerNorm" + Type of normalization: "LayerNorm" or "RMSNorm". + activation : str, default = "gelu" + Activation function name. Supports: gelu, geglu, qgelu, qgeglu, + relu, reglu, srelu, sreglu, silu, swiglu, clamped_swiglu. + activation_params : dict, optional + Additional keyword arguments passed to the activation function. + init_method : callable, optional + Weight initialization for FC1. + output_layer_init_method : callable, optional + Weight initialization for FC2. + params_dtype : torch.dtype, optional + Data type for parameters. + zero_centered_gamma : bool, default = False + If True, gamma is initialized to zero and used as (1 + gamma). + return_layernorm_output : bool, default = False + If True, also return the normalization output. + device : str or torch.device, default = "cuda" + Device for parameters. 
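+
+    Usage sketch (illustrative; the shapes, dtype, and activation choice here
+    are assumptions, not part of the original docstring):
+
+        mlp = LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096,
+                           activation="swiglu", params_dtype=torch.bfloat16)
+        x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.bfloat16)
+        y = mlp(x)   # same shape as x: [8, 128, 1024]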
+ """ + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + eps: float = 1e-5, + bias: bool = True, + normalization: str = "LayerNorm", + activation: str = "gelu", + activation_params: Optional[Dict] = None, + init_method: Optional[Callable] = None, + output_layer_init_method: Optional[Callable] = None, + params_dtype: Optional[torch.dtype] = None, + zero_centered_gamma: bool = False, + return_layernorm_output: bool = False, + device: Union[torch.device, str] = "cuda", + return_bias: bool = False, + # FSDP2 per-parameter sharding: wrap FC1/FC2 weights in FSDPAGTensor + # so the quantizer runs at all-gather time (ROCm only). + use_fsdp2: bool = False, + keep_fp8_weight_transpose_cache: bool = True, + # Accepted for API compatibility with full-build LayerNormMLP but + # ignored in lite mode (no TP/SP/userbuffers support): + sequence_parallel: bool = False, + tp_group=None, + tp_size: int = 1, + set_parallel_mode: bool = False, + fuse_wgrad_accumulation: bool = False, + **kwargs, + ) -> None: + super().__init__() + + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.eps = eps + self.use_bias = bias + self.normalization = normalization + assert normalization in ("LayerNorm", "RMSNorm"), "Unsupported normalization type!" + self.activation = activation + self.activation_params = activation_params + assert activation in _ACT_FUNC_MAP, f"Unsupported activation: {activation}" + self.zero_centered_gamma = zero_centered_gamma + self.return_layernorm_output = return_layernorm_output + self.return_bias = return_bias + + # FSDP2 flags must be set before register_parameter/reset_parameters + # so the inherited base-class wrap logic (module/base.py) sees them. 
+ from torch.utils.cpp_extension import IS_HIP_EXTENSION as _IS_HIP + self.use_fsdp2 = use_fsdp2 if _IS_HIP else False + self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache + + # No TP/SP in lite + self.tp_size = 1 + self.sequence_parallel = False + self.set_tensor_parallel_group(None) + + if init_method is None: + init_method = get_default_init_method() + if output_layer_init_method is None: + output_layer_init_method = get_default_init_method() + + is_gated = activation in _GATED_ACTIVATIONS + fc1_output_features = (2 * ffn_hidden_size) if is_gated else ffn_hidden_size + + # ---- Norm parameters ---- + layer_norm_weight = Parameter( + torch.empty(hidden_size, device=device, dtype=params_dtype) + ) + self.register_parameter( + "layer_norm_weight", + layer_norm_weight, + init_fn=init_method_constant(float(not zero_centered_gamma)), + ) + if normalization != "RMSNorm": + layer_norm_bias = Parameter( + torch.empty(hidden_size, device=device, dtype=params_dtype) + ) + self.register_parameter( + "layer_norm_bias", + layer_norm_bias, + init_fn=init_method_constant(0.0), + ) + else: + self.layer_norm_bias = None + + # ---- FC1 parameters ---- + self.weight_names = ["fc1_weight", "fc2_weight"] + self.bias_names = ["fc1_bias", "fc2_bias"] + self.parameter_split_sizes = [fc1_output_features, hidden_size] + + fc1_weight = Parameter( + torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype) + ) + self.register_parameter( + "fc1_weight", fc1_weight, + init_fn=init_method, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + if self.use_bias: + self.register_parameter( + "fc1_bias", + Parameter(torch.empty(fc1_output_features, device=device, dtype=params_dtype)), + init_fn=init_method_constant(0.0), + ) + else: + self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device) + + # ---- FC2 parameters ---- + fc2_weight = Parameter( + torch.empty(hidden_size, ffn_hidden_size, device=device, dtype=params_dtype) + ) + self.register_parameter( + "fc2_weight", fc2_weight, + init_fn=output_layer_init_method, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) + if self.use_bias: + self.register_parameter( + "fc2_bias", + Parameter(torch.empty(hidden_size, device=device, dtype=params_dtype)), + init_fn=init_method_constant(0.0), + ) + else: + self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device) + + with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() + if with_fp8_params: + self.init_fp8_metadata() + + self.reset_parameters(defer_init=(device == "meta")) + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: + results = [] + for name in self.weight_names: + w = getattr(self, name) + if isinstance(w, QuantizedTensor) and self.fp8: + results.append(w.get_quantized_tensor()) + else: + results.append(w) + return results + + def _get_weight_quantizers(self) -> List[Quantizer]: + if not self.fp8 and not self.fp8_calibration: + return [None, None] + q1 = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + q1.internal = True + q2 = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] + q2.internal = True + return [q1, q2] + + def _get_quantizers(self): + if not self.fp8: + return (None,) * 8 + fc1_input_q = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + fc1_input_q.internal = True + fc1_weight_q, fc2_weight_q = self._get_weight_quantizers() + fc2_input_q = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + fc2_input_q.internal = True + # Backward quantizers 
+ fc2_grad_output_q = None + fc1_grad_output_q = None + fc1_grad_input_q = None + if torch.is_grad_enabled(): + fc2_grad_output_q = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2] + fc2_grad_output_q.internal = True + fc1_grad_output_q = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + fc1_grad_output_q.internal = True + fc1_grad_input_q = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + return ( + fc1_input_q, fc1_weight_q, fc2_input_q, fc2_weight_q, + fc2_grad_output_q, fc1_grad_output_q, fc1_grad_input_q, + None, # fc1_grad_weight_q (not used in full build either) + ) + + def set_meta_tensor(self, fwd: bool, recipe) -> None: + super().set_meta_tensor(fwd, recipe) + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + elif recipe.float8_block_scaling(): + self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + + def _customize_quantizers_float8_current_scaling(self, fwd, recipe): + if fwd: + for idx in (tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_WEIGHT, + tex.FP8FwdTensors.GEMM2_INPUT, tex.FP8FwdTensors.GEMM2_WEIGHT): + if idx in self.quantizers["scaling_fwd"]: + q = self.quantizers["scaling_fwd"][idx] + if hasattr(recipe, 'fp8_quant_fwd_inp'): + q.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + q.amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + + def _customize_quantizers_float8_blockwise_scaling(self, fwd, recipe): + pass + + def forward( + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None, + fp8_output: bool = False, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + with self.prepare_forward(inp, num_gemms=2): + ( + fc1_input_q, fc1_weight_q, + fc2_input_q, fc2_weight_q, + fc2_grad_output_q, fc1_grad_output_q, + fc1_grad_input_q, fc1_grad_weight_q, + ) = self._get_quantizers() + + out = _LayerNormMLPLite.apply( + inp, + self.layer_norm_weight, + self.layer_norm_bias, + self.fc1_weight, + self.fc1_bias if self.use_bias else None, + self.fc2_weight, + self.fc2_bias if self.use_bias else None, + self.eps, + self.fp8, + fc1_input_q, + fc1_weight_q, + fc2_input_q, + fc2_weight_q, + fc2_grad_output_q, + fc1_grad_output_q, + fc1_grad_input_q, + fc1_grad_weight_q, + self.activation_dtype, + self.return_layernorm_output, + self.normalization, + self.zero_centered_gamma, + self.activation, + self.activation_params, + torch.is_grad_enabled(), + self, + is_first_microbatch, + ) + + # Match full-build LayerNormMLP's return contract. Bias is already + # folded into `out` by the fused kernel, so return a zero placeholder + # for the bias tuple slot — TransformerLayer's bias_dropout_add will + # add zero, preserving correctness. + if self.return_bias: + primary = out[0] if self.return_layernorm_output else out + bias_placeholder = torch.zeros( + primary.shape[-1], dtype=primary.dtype, device=primary.device + ) + if self.return_layernorm_output: + return primary, bias_placeholder, out[1] + return primary, bias_placeholder + return out diff --git a/transformer_engine/pytorch/_lite/gemm.py b/transformer_engine/pytorch/_lite/gemm.py new file mode 100644 index 000000000..aa14f41aa --- /dev/null +++ b/transformer_engine/pytorch/_lite/gemm.py @@ -0,0 +1,1062 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""GEMM operations -- multi-backend with AITER, Triton, and PyTorch fallback. + +Backend priority (configurable via NVTE_LITE_GEMM_BACKEND env var): +1. PyTorch torch._scaled_mm (default) -- hipBLASLt-backed on ROCm for FP8, + falls back to AITER for FP8 cases _scaled_mm can't serve +2. AITER Triton GEMM -- dedicated Triton kernels for FP8 and BF16/FP16 +3. AITER CK GEMM -- CK/ASM kernels for FP8 precisions + +Set NVTE_LITE_GEMM_BACKEND to override: + "pytorch" -- prefer torch._scaled_mm for FP8 (hipBLASLt-backed on ROCm); + fall back to AITER for FP8 cases _scaled_mm can't serve + (wgrad with per-row scale on reduction axis, block scaled, + unsupported dtype combos); dequantize + torch.matmul as a + last resort for non-FP8 or when AITER is unavailable. (default) + "triton" -- prefer AITER Triton GEMM kernels + "ck" -- prefer AITER CK kernels +""" + +import os +import torch +import torch.nn.functional as F + +from .aiter_utils import is_aiter_available, get_aiter + +# FP8 dtypes for detection +_FP8_DTYPES = ( + torch.float8_e4m3fn, torch.float8_e5m2, + torch.float8_e4m3fnuz, torch.float8_e5m2fnuz, +) + +_GEMM_BACKEND = os.environ.get("NVTE_LITE_GEMM_BACKEND", "pytorch").lower() + +_LITE_DIAG = os.environ.get("NVTE_LITE_DIAG", "0") != "0" + +from collections import Counter as _GemmCounter +_GEMM_CALLS = _GemmCounter() +_GEMM_BACKEND_PRINTED = False +_CK_FAIL_DIAG_PRINTS = 0 +_SCALED_MM_FAIL_DIAG_PRINTS = 0 +_SCALED_MM_FAIL_DIAG_MAX = 5 + + +def _log_scaled_mm_fail(reason, A, transA, B, transB, x=None, w=None, + x_scale=None, w_scale=None, M=None, N=None, + effective_transA=None, effective_transB=None, err=None): + """Log the first _SCALED_MM_FAIL_DIAG_MAX rejections from _try_scaled_mm. + + Gated by NVTE_LITE_DIAG. Captures shapes, dtypes, scale layout, and the + transpose-only state of the operands so we can classify the fallthrough + pattern (per-row on reduction axis vs shape mismatch vs library reject). + """ + if not _LITE_DIAG: + return + global _SCALED_MM_FAIL_DIAG_PRINTS + if _SCALED_MM_FAIL_DIAG_PRINTS >= _SCALED_MM_FAIL_DIAG_MAX: + return + _SCALED_MM_FAIL_DIAG_PRINTS += 1 + + def _fmt_scale(s): + if s is None: + return "None" + return f"shape={tuple(s.shape)} numel={s.numel()} dtype={s.dtype}" + + def _fmt_operand(t, name): + if t is None: + return f"{name}=None" + trans_only = _is_transpose_only(t) + return (f"{name}: shape={tuple(t.shape)} " + f"dtype={getattr(t, 'dtype', '?')} " + f"transpose_only={trans_only}") + + bits = [ + f"[LITE-SCALED-MM-FAIL #{_SCALED_MM_FAIL_DIAG_PRINTS}] reason={reason}", + _fmt_operand(A, "A") + f" transA={transA}", + _fmt_operand(B, "B") + f" transB={transB}", + ] + if effective_transA is not None or effective_transB is not None: + bits.append( + f"eff_transA={effective_transA} eff_transB={effective_transB}" + ) + if x is not None: + bits.append( + f"x: shape={tuple(x.shape)} dtype={x.dtype} " + f"stride_last={x.stride(-1)}" + ) + if w is not None: + bits.append( + f"w: shape={tuple(w.shape)} dtype={w.dtype} " + f"stride_last={w.stride(-1)}" + ) + bits.append(f"x_scale: {_fmt_scale(x_scale)}") + bits.append(f"w_scale: {_fmt_scale(w_scale)}") + if M is not None or N is not None: + bits.append(f"M={M} N={N}") + if err is not None: + msg = str(err) + if len(msg) > 200: + msg = msg[:200] + "..." 
+ bits.append(f"err={type(err).__name__}: {msg}") + print(" | ".join(bits), flush=True) + +def _gemm_bump(tag): + if not _LITE_DIAG: + return + global _GEMM_BACKEND_PRINTED + if not _GEMM_BACKEND_PRINTED: + _GEMM_BACKEND_PRINTED = True + print(f"[LITE-GEMM-BACKEND] {_GEMM_BACKEND}", flush=True) + _GEMM_CALLS[tag] += 1 + if sum(_GEMM_CALLS.values()) % 500 == 0: + print(f"[LITE-GEMM] {dict(_GEMM_CALLS)}", flush=True) + + +def _resolve_output_dtype(output_dtype): + """Normalize output_dtype (TE_DType | torch.dtype | None) to torch.dtype. + + `cpp_extensions/gemm.py` forwards the user-provided `out_dtype` as a + `TE_DType` enum. The pure-Python path needs a `torch.dtype` to cast the + result; the full build resolves this inside cuBLAS. Returns None when the + caller did not specify an output dtype (the full build uses the D operand + dtype in that case). + """ + if output_dtype is None or isinstance(output_dtype, torch.dtype): + return output_dtype + try: + from transformer_engine.pytorch.triton_kernels.common import ( + te_dtype_to_torch_dtype, + ) + return te_dtype_to_torch_dtype(output_dtype) + except (ImportError, KeyError): + return None + + +def _dequantize_from_transpose(tensor): + """Dequantize a Float8Tensor when only its _transpose is available. + + Columnwise-only tensors (wgrad path) have _data=None; the standard + dequantize() raises NotImplementedError. We dequantize from _transpose + manually: reinterpret uint8 as FP8 dtype, transpose back to logical + shape, and multiply by the per-row or per-tensor scale. + + Transpose is done on the uint8 view (fast byte-level copy) before + reinterpreting as FP8, so the materialization doesn't go through the + slower float8_copy_kernel_cuda path. + """ + t = tensor._transpose + u8 = t if t.dtype == torch.uint8 else t.view(torch.uint8) + # _transpose is [K, d0, d1, ...] (last dim moved to front); invert to + # the logical [d0, d1, ..., K] on uint8, then view as FP8 and cast once. 
+ if u8.ndim == 2: + u8_logical = u8.t().contiguous() + else: + inv_perm = list(range(1, u8.ndim)) + [0] + u8_logical = u8.permute(*inv_perm).contiguous() + if hasattr(tensor, '_fp8_dtype'): + from transformer_engine.pytorch._lite.quantize import _te_dtype_to_torch_fp8 + fp8_logical = u8_logical.view(_te_dtype_to_torch_fp8(tensor._fp8_dtype)) + else: + fp8_logical = u8_logical + logical = fp8_logical.to(torch.bfloat16) + scale_inv = tensor._scale_inv + if scale_inv.numel() == 1: + return logical * scale_inv + # Per-row scale shape (M_flat,); reshape to match logical's leading dims + leading_numel = 1 + for d in logical.shape[:-1]: + leading_numel *= d + if scale_inv.numel() == leading_numel: + return logical * scale_inv.reshape(*logical.shape[:-1], 1) + return logical * scale_inv.reshape(-1, 1) + + +def _dequantize_if_needed(tensor): + """Dequantize FP8/quantized tensor to BF16 for matmul.""" + if _is_mxfp8(tensor): + return tensor.dequantize(dtype=torch.bfloat16) + if _is_blockwise_fp8(tensor): + return tensor.dequantize(dtype=torch.bfloat16) + # Columnwise-only Float8Tensor: _data deleted, must dequantize from _transpose + if (hasattr(tensor, '_data') and tensor._data is None + and hasattr(tensor, '_transpose') and tensor._transpose is not None): + return _dequantize_from_transpose(tensor) + if hasattr(tensor, 'dequantize'): + return tensor.dequantize() + if isinstance(tensor, torch.Tensor) and tensor.dtype in _FP8_DTYPES: + return tensor.to(torch.bfloat16) + return tensor + + +def _is_quantized(tensor): + """Check if tensor is a quantized type with FP8 data.""" + if hasattr(tensor, '_data') and hasattr(tensor, '_scale_inv'): + return True + if _is_mxfp8(tensor): + return True + return False + + +def _get_raw_data(tensor): + """Extract raw data and scale from a quantized tensor, or return tensor as-is.""" + if _is_blockwise_fp8(tensor): + data, scale = _get_blockwise_data(tensor, need_rowwise=True) + return data, scale + if _is_mxfp8(tensor): + # MXFP8 scales are E8M0 uint8 — not directly usable as float scales + # for AITER GEMM dispatch. Return data only; GEMM will dequantize. + return tensor._rowwise_data, None + if hasattr(tensor, '_data') and hasattr(tensor, '_scale_inv'): + data = tensor._data + if data is None: + # Columnwise-only tensor: _data was deleted by update_usage. + # Use transpose if available, otherwise dequantize. + if hasattr(tensor, '_transpose') and tensor._transpose is not None: + data = tensor._transpose + else: + return tensor, None + # Float8Tensor stores FP8 bit patterns as uint8 — reinterpret as the + # actual FP8 dtype so downstream Triton kernels see the correct type. + if data.dtype == torch.uint8 and hasattr(tensor, '_fp8_dtype'): + from transformer_engine.pytorch._lite.quantize import _te_dtype_to_torch_fp8 + data = data.view(_te_dtype_to_torch_fp8(tensor._fp8_dtype)) + return data, tensor._scale_inv + return tensor, None + + +def _is_transpose_only(tensor): + """Return True if tensor has _data=None but _transpose set (columnwise-only).""" + return (hasattr(tensor, '_data') and tensor._data is None + and hasattr(tensor, '_transpose') and tensor._transpose is not None) + + +def _fp8_transposed_operand(tensor, data_2d): + """Return the transposed 2D FP8 operand for a GEMM, preferring the tensor's + _transpose cache to avoid materializing a fresh transposed copy. + + data_2d is the (already-flattened) rowwise [M, K] FP8 view of tensor._data. 
+ If tensor._transpose is populated and valid, we reshape it to [K, M] and + return — the byte layout already matches what data_2d.t().contiguous() + would produce, at zero copy cost. + + When no transpose cache is available, we transpose via the uint8 view + instead of the fp8 view. Same number of bytes copied, but dispatches to + the plain uint8 copy kernel rather than the slow float8_copy_kernel_cuda + that .t().contiguous() on an FP8-dtype tensor hits. + """ + data_buf = getattr(tensor, '_data', None) + trans = getattr(tensor, '_transpose', None) + trans_invalid = getattr(tensor, '_transpose_invalid', True) + # Only use the cache when data_2d came from _data (i.e. the tensor has + # both buffers). If _data is None, data_2d came from _transpose already + # and we actually need to undo that layout via an explicit copy. + can_use_cache = ( + data_buf is not None and trans is not None and not trans_invalid + ) + fp8_dtype = data_2d.dtype + if can_use_cache: + t = trans + if t.ndim > 2: + t = t.reshape(t.shape[0], -1) + if t.dtype == torch.uint8: + t = t.view(fp8_dtype) + return t + # Fallback: uint8-level transpose to avoid float8_copy_kernel_cuda. + d = data_2d + if d.dtype != torch.uint8: + d = d.view(torch.uint8) + return d.t().contiguous().view(fp8_dtype) + + +# --------------------------------------------------------------------------- +# AITER CK GEMM dispatch +# --------------------------------------------------------------------------- + +def _is_per_row_scaled(scale): + """Check if scale tensor is per-row (one scale per token/row). + + Per-row scales have shape (M,) or (M, 1) — 1D with numel > 1. + Block scales are 2D with shape (ceil(M/block), ceil(N/block)). + """ + return (scale is not None + and scale.numel() > 1 + and scale.ndim == 1) + + +def _is_block_scaled(scale): + """Check if scale tensor indicates block scaling (2D multi-element scale). + + Excludes per-row scales (1D) — those use gemm_a8w8_per_token_scale. + """ + return (scale is not None + and scale.numel() > 1 + and not _is_per_row_scaled(scale)) + + +def _is_fp4(tensor): + """Check if tensor is MXFP4 quantized. + + Discriminates from MXFP8 via _fp4_dtype (MXFP4) vs _fp8_dtype (MXFP8). + """ + return (hasattr(tensor, '_rowwise_data') and + hasattr(tensor, '_fp4_dtype') and + not hasattr(tensor, '_is_2D_scaled') and # exclude Float8Blockwise + tensor._rowwise_data is not None) + + +def _is_mxfp8(tensor): + """Check if tensor is MXFP8 quantized (block-scaled FP8, group_size=32). + + MXFP8 uses _rowwise_data/_rowwise_scale_inv (shared attribute names with + MXFP4), distinguished by _fp8_dtype. No AITER GEMM kernel exists on MI300X; + future MI350 kernel hook is in _aiter_ck_gemm/_aiter_triton_gemm. + """ + return (hasattr(tensor, '_rowwise_data') and + hasattr(tensor, '_fp8_dtype') and + not hasattr(tensor, '_is_2D_scaled') and # exclude Float8Blockwise + not hasattr(tensor, '_data') and # exclude Float8Tensor + tensor._rowwise_data is not None) + + +def _get_fp4_data(tensor): + """Extract FP4 data and scale from MXFP4 tensor.""" + return tensor._rowwise_data, tensor._rowwise_scale_inv + + +def _is_blockwise_fp8(tensor): + """Check if tensor is Float8BlockwiseQTensorStorage (2D block-scaled FP8).""" + return hasattr(tensor, '_is_2D_scaled') and hasattr(tensor, '_data_format') + + +def _get_blockwise_data(tensor, need_rowwise=True): + """Extract data and scale from Float8BlockwiseQTensorStorage. + + Returns (data, scale_inv) for the requested orientation. 
+ For GEMM: A (weight) typically needs columnwise, B (activation) needs rowwise. + """ + if need_rowwise and tensor._rowwise_data is not None: + return tensor._rowwise_data, tensor._rowwise_scale_inv + if not need_rowwise and tensor._columnwise_data is not None: + return tensor._columnwise_data, tensor._columnwise_scale_inv + # Fall back to whatever is available + if tensor._rowwise_data is not None: + return tensor._rowwise_data, tensor._rowwise_scale_inv + return tensor._columnwise_data, tensor._columnwise_scale_inv + + +def _reshape_scale_for_scaled_mm(scale, dim, is_row): + """Reshape a Float8 _scale_inv for torch._scaled_mm. + + - Per-tensor scalar (numel==1): return a 0-dim tensor. hipBLASLt's + per-tensor FP8 kernels (same family full TE uses for DelayedScaling) + are selected by scalar scale shape. Broadcasting a scalar to (dim, 1) + would force the rowwise kernel family, which isn't tuned for + mixed-dtype (E4M3×E5M2) on ROCm — that's the "could not find valid + hipblaslt solution" error for dgrad calls. + - Per-row (numel==dim): return `(dim, 1)` (is_row=True) or `(1, dim)` + (is_row=False), the rowwise convention. + - Anything else: None (caller falls through). + """ + if scale is None: + return None + scale = scale.to(torch.float32) if scale.dtype != torch.float32 else scale + if scale.numel() == 1: + return scale.reshape(()) + if scale.numel() == dim: + flat = scale.reshape(dim).contiguous() + return flat.unsqueeze(1) if is_row else flat.unsqueeze(0) + return None + + +def _try_scaled_mm(A, transA, B, transB, output_dtype): + """FP8×FP8 GEMM via torch._scaled_mm (hipBLASLt-backed on ROCm). + + Matches AITER's NT convention: x=[M,K] (rowwise), w=[N,K] (rowwise), + compute x @ w.T. Uses the same `_fp8_transposed_operand` path that + feeds AITER Triton, so operands are K-innermost by construction. + + Returns the result tensor (with original leading B dims restored), or + None when torch._scaled_mm is unavailable, block-scaled (not supported + here), or rejects the inputs (any RuntimeError falls through). + """ + if not hasattr(torch, '_scaled_mm'): + return None + + # Block-scaled uses a different scale layout — fall through. + if _is_blockwise_fp8(A) or _is_blockwise_fp8(B): + _log_scaled_mm_fail("blockwise_fp8", A, transA, B, transB) + return None + + a_data, a_scale = _get_raw_data(A) + b_data, b_scale = _get_raw_data(B) + + # Resolve NT operand form, same logic as _aiter_triton_gemm. + a_transpose_only = _is_transpose_only(A) + b_transpose_only = _is_transpose_only(B) + effective_transA = transA ^ a_transpose_only + effective_transB = transB ^ b_transpose_only + + x_leading = b_data.shape[:-1] if not b_transpose_only else b_data.shape[1:] + if b_data.ndim > 2: + if b_transpose_only: + b_data = b_data.reshape(b_data.shape[0], -1) + else: + b_data = b_data.reshape(-1, b_data.shape[-1]) + if a_data.ndim > 2: + if a_transpose_only: + a_data = a_data.reshape(a_data.shape[0], -1) + else: + a_data = a_data.reshape(-1, a_data.shape[-1]) + + x = b_data if not effective_transB else _fp8_transposed_operand(B, b_data) + w = a_data if effective_transA else _fp8_transposed_operand(A, a_data) + x_scale = b_scale + w_scale = a_scale + + # Per-row on the REDUCTION axis (wgrad corner) is not supported by + # per-row scaled GEMM kernels — fall through to dequant path. 
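+    # Concrete rejected case: the FC2 wgrad call _gemm(act_out, False,
+    # grad_output, True) makes x = grad_output^T with shape
+    # [hidden_out, tokens]. grad_output's per-token scale then has `tokens`
+    # entries, one per reduction-axis element rather than one per row of x,
+    # so the numel-vs-M checks below reject it.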
+ M = x.shape[0] + N = w.shape[0] + if _is_per_row_scaled(x_scale) and x_scale.numel() != M: + _log_scaled_mm_fail("per_row_on_reduction_x", A, transA, B, transB, + x=x, w=w, x_scale=x_scale, w_scale=w_scale, + M=M, N=N, + effective_transA=effective_transA, + effective_transB=effective_transB) + return None + if _is_per_row_scaled(w_scale) and w_scale.numel() != N: + _log_scaled_mm_fail("per_row_on_reduction_w", A, transA, B, transB, + x=x, w=w, x_scale=x_scale, w_scale=w_scale, + M=M, N=N, + effective_transA=effective_transA, + effective_transB=effective_transB) + return None + + x_scale_2d = _reshape_scale_for_scaled_mm(x_scale, M, is_row=True) + w_scale_2d = _reshape_scale_for_scaled_mm(w_scale, N, is_row=False) + if x_scale_2d is None or w_scale_2d is None: + _log_scaled_mm_fail("scale_shape_mismatch", A, transA, B, transB, + x=x, w=w, x_scale=x_scale, w_scale=w_scale, + M=M, N=N, + effective_transA=effective_transA, + effective_transB=effective_transB) + return None + + # hipBLASLt FP8 kernels require mat1's dims divisible by 16. When the + # tokens count isn't a clean multiple (e.g. 8184 = 2046×4), we pad the + # M axis of x (and its per-row scale) with zeros/ones up to the next + # multiple, run the GEMM, and slice the result back. K-dim misalignment + # hits in the wgrad corner (K = tokens after transpose) — that case also + # has separate per-row-on-reduction issues, so we skip it here and let + # the caller fall through to AITER. + K = x.shape[1] + if K % 16 != 0: + _log_scaled_mm_fail("k_not_div16", A, transA, B, transB, + x=x, w=w, x_scale=x_scale_2d, w_scale=w_scale_2d, + M=M, N=N, + effective_transA=effective_transA, + effective_transB=effective_transB) + return None + + pad_rows = (-M) % 16 + if pad_rows: + x = F.pad(x, (0, 0, 0, pad_rows)) # zero-pad new rows + # Only pad x_scale_2d if it's per-row (shape (M, 1)); a 0-dim + # scalar per-tensor scale applies to every row automatically. + if x_scale_2d.ndim == 2 and x_scale_2d.shape[0] == M: + # Value irrelevant (scale × 0 = 0), just non-NaN/Inf. + x_scale_2d = F.pad(x_scale_2d, (0, 0, 0, pad_rows), value=1.0) + + out_dtype = output_dtype if output_dtype is not None else torch.bfloat16 + + try: + result = torch._scaled_mm( + x, w.t(), + scale_a=x_scale_2d, scale_b=w_scale_2d, + out_dtype=out_dtype, + ) + except (RuntimeError, TypeError) as _sm_err: + _log_scaled_mm_fail("torch._scaled_mm_raised", A, transA, B, transB, + x=x, w=w, x_scale=x_scale_2d, w_scale=w_scale_2d, + M=M, N=N, + effective_transA=effective_transA, + effective_transB=effective_transB, + err=_sm_err) + return None + + if pad_rows: + result = result[:M] + + if len(x_leading) > 1: + result = result.reshape(*x_leading, result.shape[-1]) + return result + + +def _aiter_ck_gemm(aiter, a_data, a_scale, b_data, b_scale, + a_is_fp8, b_is_fp8, transA, transB, + A, B): + """Dispatch to AITER CK/ASM kernels. Returns result tensor or None.""" + _gemm_bump("ck_enter") + try: + # MXFP8: No hardware GEMM on MI300X. Fall through to dequant path. + # TODO(MI350): Add aiter.gemm_mxfp8() dispatch when available. 
+ if _is_mxfp8(A) or _is_mxfp8(B): + _gemm_bump("ck_skip_mxfp8") + return None + + # FP4 × FP4 + if _is_fp4(A) and _is_fp4(B): + if hasattr(aiter, 'gemm_a4w4'): + a4_data, a4_scale = _get_fp4_data(A) + b4_data, b4_scale = _get_fp4_data(B) + M, _ = a4_data.shape + N, _ = b4_data.shape + out = torch.empty(M, N, dtype=torch.bfloat16, device=a4_data.device) + return aiter.gemm_a4w4(a4_data, b4_data, a4_scale, b4_scale, out) + + # Float8Blockwise (2D block-scaled, 128×128 blocks) — always block-scaled + a_is_blockwise = _is_blockwise_fp8(A) + b_is_blockwise = _is_blockwise_fp8(B) + + # FP8 × FP8 + if a_is_fp8 and b_is_fp8: + # Determine layout: Y = X @ W^T + # TE: result = B @ A. transA=True means A is (N,K) weight layout. + # When _get_raw_data returned the transpose (columnwise-only + # tensor, e.g. wgrad inputmat), orientation is already flipped. + a_transpose_only = _is_transpose_only(A) + b_transpose_only = _is_transpose_only(B) + effective_transA = transA ^ a_transpose_only + effective_transB = transB ^ b_transpose_only + + # CK FP8 kernels require 2D; flatten N-D leading dims first so + # subsequent .t() works and scales (which are per flattened row) + # stay aligned. _data is [d0, ..., K]; _transpose is [K, d0, ...]. + x_leading_shape = b_data.shape[:-1] if not b_transpose_only else b_data.shape[1:] + if b_data.ndim > 2: + if b_transpose_only: + b_data = b_data.reshape(b_data.shape[0], -1) + else: + b_data = b_data.reshape(-1, b_data.shape[-1]) + if a_data.ndim > 2: + if a_transpose_only: + a_data = a_data.reshape(a_data.shape[0], -1) + else: + a_data = a_data.reshape(-1, a_data.shape[-1]) + + if b_is_blockwise: + x, x_scale = _get_blockwise_data(B, need_rowwise=not transB) + else: + x = b_data if not effective_transB else _fp8_transposed_operand(B, b_data) + x_scale = b_scale + + if a_is_blockwise: + w, w_scale = _get_blockwise_data(A, need_rowwise=transA) + else: + w = a_data if effective_transA else _fp8_transposed_operand(A, a_data) + w_scale = a_scale + + if (_is_block_scaled(x_scale) or _is_block_scaled(w_scale) + or a_is_blockwise or b_is_blockwise): + # Block-scale FP8 (includes Float8Blockwise) + if hasattr(aiter, 'gemm_a8w8_blockscale'): + _gemm_bump("ck_blockscale") + result = aiter.gemm_a8w8_blockscale(x, w, x_scale, w_scale) + if len(x_leading_shape) > 1: + result = result.reshape(*x_leading_shape, result.shape[-1]) + return result + else: + # Per-tensor or per-row FP8. CK's RowwiseScale kernel accepts + # x_scale (M, 1) and w_scale (1, N) — a scalar broadcasts to + # fill, a per-row vector reshapes in place. Per-row scales on + # the reduction axis (wgrad edge case — scales came from the + # non-transposed tensor) can't use CK; fall through to Triton. 
+ M = x.shape[0] + N = w.shape[0] + x_per_row = x_scale.numel() > 1 + w_per_row = w_scale.numel() > 1 + x_ok = (not x_per_row) or (x_scale.numel() == M) + w_ok = (not w_per_row) or (w_scale.numel() == N) + if x_ok and w_ok and hasattr(aiter, 'gemm_a8w8_CK'): + x_scale_ck = ( + x_scale.expand(M).unsqueeze(1).contiguous() + if not x_per_row + else x_scale.reshape(M, 1).contiguous() + ) + w_scale_ck = ( + w_scale.expand(N).unsqueeze(0).contiguous() + if not w_per_row + else w_scale.reshape(1, N).contiguous() + ) + if x_per_row or w_per_row: + _gemm_bump("ck_per_row") + else: + _gemm_bump("ck_per_tensor") + try: + result = aiter.gemm_a8w8_CK(x, w, x_scale_ck, w_scale_ck) + except RuntimeError as _ck_err: + if _LITE_DIAG: + global _CK_FAIL_DIAG_PRINTS + if _CK_FAIL_DIAG_PRINTS < 5: + _CK_FAIL_DIAG_PRINTS += 1 + print( + f"[LITE-GEMM-CK-FAIL #{_CK_FAIL_DIAG_PRINTS}] " + f"x={tuple(x.shape)}/{x.dtype}/contig={x.is_contiguous()} " + f"w={tuple(w.shape)}/{w.dtype}/contig={w.is_contiguous()} " + f"x_scale_ck={tuple(x_scale_ck.shape)} " + f"w_scale_ck={tuple(w_scale_ck.shape)} " + f"err={type(_ck_err).__name__}: {_ck_err}", + flush=True, + ) + raise + if len(x_leading_shape) > 1: + result = result.reshape(*x_leading_shape, result.shape[-1]) + return result + else: + # Per-row scale on reduction axis — CK can't serve. + _gemm_bump("ck_reject_per_row_reduction_axis") + + elif not a_is_fp8 and b_is_fp8: + if hasattr(aiter, 'gemm_a16w8'): + _gemm_bump("ck_a16w8") + a_mat = _dequantize_if_needed(A) + if transA: + a_mat = a_mat.t() + b_mat = b_data.t() if transB else b_data + return aiter.gemm_a16w8(a_mat, b_mat, b_scale) + _gemm_bump("ck_skip_bf16_fp8_no_kernel") + + elif not a_is_fp8 and not b_is_fp8: + _gemm_bump("ck_skip_bf16_bf16") + + else: + # a_is_fp8 and not b_is_fp8 — no CK branch for this combo + _gemm_bump("ck_skip_fp8_bf16") + + except (RuntimeError, TypeError, AttributeError) as _e: + _gemm_bump(f"ck_exception_{type(_e).__name__}") + return None + + +# --------------------------------------------------------------------------- +# AITER Triton GEMM dispatch (dedicated Triton kernels per precision) +# --------------------------------------------------------------------------- + +def _aiter_triton_gemm(A, transA, B, transB, a_data, a_scale, b_data, b_scale, + a_is_fp8, b_is_fp8): + """Dispatch to AITER's dedicated Triton GEMM kernels. + + Per precision: + FP4×FP4: aiter.ops.triton.gemm_afp4wfp4 + FP8×FP8 per-row: aiter.ops.triton.gemm_a8w8_per_token_scale + FP8×FP8 block-scale: aiter.ops.triton.gemm_a8w8_blockscale + FP8×FP8 per-tensor: aiter.ops.triton.gemm_a8w8 + BF16/FP16: aiter.ops.triton.gemm_a16w16 + All kernels compute Y = X @ W^T (weight is internally transposed). + Returns result tensor or None. + """ + try: + # MXFP8: No Triton MXFP8 GEMM on MI300X. Fall through to dequant path. + # TODO(MI350): Add Triton MXFP8 GEMM kernel dispatch when available. 
+ if _is_mxfp8(A) or _is_mxfp8(B): + return None + + # FP4 × FP4 + if _is_fp4(A) and _is_fp4(B): + from aiter.ops.triton.gemm_afp4wfp4 import ( + gemm_afp4wfp4 as triton_gemm_fp4, + ) + a4_data, a4_scale = _get_fp4_data(A) + b4_data, b4_scale = _get_fp4_data(B) + return triton_gemm_fp4(a4_data, b4_data, a4_scale, b4_scale) + + # Float8Blockwise and standard FP8 layout mapping for Y = X @ W^T + a_is_blockwise = _is_blockwise_fp8(A) + b_is_blockwise = _is_blockwise_fp8(B) + + # When _get_raw_data returned the transpose (_data was None), the + # orientation is already flipped — invert the transpose flag so the + # dispatch logic below picks the right direction. This happens for + # columnwise-only tensors in the wgrad path. + a_transpose_only = _is_transpose_only(A) + b_transpose_only = _is_transpose_only(B) + effective_transA = transA ^ a_transpose_only + effective_transB = transB ^ b_transpose_only + + # Triton FP8 kernels require 2D; flatten N-D leading dims of raw data + # before the transpose dispatch (.t() only works on 2D). + # _data has shape [d0, ..., K] → flatten to [prod(d), K]. + # _transpose has shape [K, d0, ...] → flatten to [K, prod(d)]. + x_leading = b_data.shape[:-1] if not b_transpose_only else b_data.shape[1:] + if b_data.ndim > 2: + if b_transpose_only: + b_data = b_data.reshape(b_data.shape[0], -1) + else: + b_data = b_data.reshape(-1, b_data.shape[-1]) + if a_data.ndim > 2: + if a_transpose_only: + a_data = a_data.reshape(a_data.shape[0], -1) + else: + a_data = a_data.reshape(-1, a_data.shape[-1]) + + if b_is_blockwise: + x, x_scale = _get_blockwise_data(B, need_rowwise=not transB) + else: + x = b_data if not effective_transB else _fp8_transposed_operand(B, b_data) + x_scale = b_scale + + if a_is_blockwise: + w, w_scale = _get_blockwise_data(A, need_rowwise=transA) + else: + w = a_data if effective_transA else _fp8_transposed_operand(A, a_data) + w_scale = a_scale + + if a_is_fp8 and b_is_fp8: + + # AITER Triton a8w8 kernels assume K-innermost on both operands + # (stride[-1] == 1). Non-K-innermost operands are numerically + # correct but ~10-100× slower with no diagnostic. Our + # _fp8_transposed_operand path and the raw _data views should + # both be K-innermost; assert to catch any future drift in the + # _transpose_invalid flag or the cast_transpose output layout. + assert x.stride(-1) == 1, ( + f"lite→AITER Triton a8w8: x must be K-innermost, got strides " + f"{tuple(x.stride())}; shape={tuple(x.shape)}" + ) + assert w.stride(-1) == 1, ( + f"lite→AITER Triton a8w8: w must be K-innermost, got strides " + f"{tuple(w.stride())}; shape={tuple(w.shape)}" + ) + + if _is_per_row_scaled(x_scale) or _is_per_row_scaled(w_scale): + # Per-row (per-token) FP8 — from CurrentScaling fused norm+quant. + # Per-row scales are valid only when they index the kernel's + # non-reduction axis (first dim of x and w). This holds for + # forward (X @ W^T) and dgrad (dY @ W), but NOT wgrad + # (dY^T @ X) where the transposes put per-row scales along + # the reduction axis. Verify scale-axis alignment before + # dispatching to the per-token kernel. 
+ x_scale_valid = (x_scale is None or x_scale.numel() == 1 + or x_scale.numel() == x.shape[0]) + w_scale_valid = (w_scale is None or w_scale.numel() == 1 + or w_scale.numel() == w.shape[0]) + if not (x_scale_valid and w_scale_valid): + _gemm_bump("triton_reject_per_row_reduction_axis") + return None # Let caller fall back to dequantize + bf16 GEMM + from aiter.ops.triton.gemm_a8w8_per_token_scale import ( + gemm_a8w8_per_token_scale as triton_a8w8_pt, + ) + # Kernel expects (M, 1) and (N, 1) shaped scales + if x_scale is not None and x_scale.ndim == 1: + x_scale = x_scale.unsqueeze(1) + if w_scale is not None and w_scale.numel() == 1: + w_scale = w_scale.expand(w.shape[0]).unsqueeze(1) + elif w_scale is not None and w_scale.ndim == 1: + w_scale = w_scale.unsqueeze(1) + _gemm_bump("triton_per_row") + result = triton_a8w8_pt(x, w, x_scale, w_scale) + elif (_is_block_scaled(x_scale) or _is_block_scaled(w_scale) + or a_is_blockwise or b_is_blockwise): + from aiter.ops.triton.gemm_a8w8_blockscale import ( + gemm_a8w8_blockscale as triton_a8w8_bs, + ) + _gemm_bump("triton_blockscale") + result = triton_a8w8_bs(x, w, x_scale, w_scale) + else: + # Per-tensor FP8. gemm_a8w8 indexes the scale pointer by row + # (A) / col (B), so a scalar (1,) scale reads out of bounds + # and produces garbage. Expand to (M,) and (N,) so every + # row/col sees the same per-tensor scale. + from aiter.ops.triton.gemm_a8w8 import ( + gemm_a8w8 as triton_a8w8, + ) + x_scale_exp = x_scale.expand(x.shape[0]).contiguous() + w_scale_exp = w_scale.expand(w.shape[0]).contiguous() + _gemm_bump("triton_per_tensor") + result = triton_a8w8(x, w, x_scale_exp, w_scale_exp) + + # Restore the leading N-D shape from x (B operand) on the result + if len(x_leading) > 1: + result = result.reshape(*x_leading, result.shape[-1]) + return result + + elif not a_is_fp8 and b_is_fp8: + try: + from aiter.ops.triton.gemm_a16w8_blockscale import ( + gemm_a16w8_blockscale as triton_a16w8, + ) + x_hp = _dequantize_if_needed(B) + if transB: + x_hp = x_hp.t().contiguous() + return triton_a16w8(x_hp, w, w_scale) + except ImportError: + pass + + elif not a_is_fp8 and not b_is_fp8: + # Skip FP32 — Triton GEMM only supports BF16/FP16 + a_mat = _dequantize_if_needed(A) + if a_mat.dtype == torch.float32: + return None + + from aiter.ops.triton.gemm_a16w16 import ( + gemm_a16w16 as triton_a16w16, + ) + b_mat = _dequantize_if_needed(B) + x = b_mat if not transB else b_mat.t().contiguous() + w = a_mat if transA else a_mat.t().contiguous() + return triton_a16w16(x, w) + + except (RuntimeError, TypeError, AttributeError, ImportError): + pass + return None + + +# --------------------------------------------------------------------------- +# Unified AITER dispatch with backend selection +# --------------------------------------------------------------------------- + +def _aiter_gemm(A, transA, B, transB, D, quantizer, output_dtype, + bias, bias_type, gelu, gelu_in, grad, + accumulate, alpha): + """Dispatch GEMM to AITER backend selected by NVTE_LITE_GEMM_BACKEND. + + Falls back through: preferred backend -> other backends -> None (PyTorch). 
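+
+    For example, NVTE_LITE_GEMM_BACKEND=triton tries Triton first and then
+    CK; any other non-"pytorch" value tries CK first and then Triton. A None
+    return tells the caller to use the PyTorch fallback.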
+ """ + aiter = get_aiter() + if aiter is None: + return None + + a_data, a_scale = _get_raw_data(A) + b_data, b_scale = _get_raw_data(B) + + a_is_fp8 = _is_quantized(A) or (isinstance(a_data, torch.Tensor) and a_data.dtype in _FP8_DTYPES) + b_is_fp8 = _is_quantized(B) or (isinstance(b_data, torch.Tensor) and b_data.dtype in _FP8_DTYPES) + + result = None + + triton_args = (A, transA, B, transB, a_data, a_scale, b_data, b_scale, + a_is_fp8, b_is_fp8) + + if _GEMM_BACKEND == "triton": + result = _aiter_triton_gemm(*triton_args) + if result is None: + result = _aiter_ck_gemm( + aiter, a_data, a_scale, b_data, b_scale, + a_is_fp8, b_is_fp8, transA, transB, A, B, + ) + + else: + # Default "ck" path + result = _aiter_ck_gemm( + aiter, a_data, a_scale, b_data, b_scale, + a_is_fp8, b_is_fp8, transA, transB, A, B, + ) + if result is None: + result = _aiter_triton_gemm(*triton_args) + + if result is None: + _gemm_bump("pytorch_fallback") + return None # Signal caller to use PyTorch fallback + + # --- Post-GEMM epilogues --- + if alpha != 1.0: + result = result * alpha + + bias_grad = torch.Tensor() + if bias is not None and bias.numel() > 0: + if grad: + grad_out = _dequantize_if_needed(B) + bias_grad = grad_out.reshape(-1, grad_out.shape[-1]).sum(dim=0) + else: + result = result + bias + + gelu_input = torch.Tensor() + if gelu and gelu_in is not None: + gelu_in.copy_(result) + gelu_input = gelu_in + result = torch.nn.functional.gelu(result, approximate='tanh') + + if accumulate and D is not None: + D.add_(result) + elif D is not None: + D.copy_(result) + else: + D = result + + if quantizer is not None and hasattr(quantizer, 'quantize'): + D = quantizer.quantize(D) + + return D, bias_grad, gelu_input, None + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def generic_gemm(A, transA, B, transB, D, quantizer, output_dtype, + bias, bias_type, gelu, gelu_in, grad, workspace, workspace_size, + accumulate, use_split_accumulator, + comm_overlap=None, comm_type=None, extra_output=None, + bulk_overlap=False, alpha=1.0, beta=None): + """General matrix-matrix multiply with optional bias, GELU, and accumulation. + + This is the primary GEMM entry point, replacing tex.generic_gemm. + Dispatches to AITER CK/Triton kernels when available, falls back to torch.matmul. + + Backend selection via NVTE_LITE_GEMM_BACKEND env var: + "pytorch" (default), "triton", "ck" + """ + # --- AITER dispatch (all precisions) --- + if _GEMM_BACKEND != "pytorch" and is_aiter_available(): + result = _aiter_gemm( + A, transA, B, transB, D, quantizer, output_dtype, + bias, bias_type, gelu, gelu_in, grad, accumulate, alpha, + ) + if result is not None: + return result[0], result[1], result[2], extra_output + + # --- PyTorch fallback --- + # + # For FP8×FP8, prefer torch._scaled_mm (hipBLASLt-backed on ROCm) — that's + # the same path full TE takes and skips the dequant+matmul round trip. We + # fall through to dequantize + torch.matmul on any exception (unsupported + # scale/layout/dtype combo on this ROCm build). 
+ result = None + if _is_quantized(A) and _is_quantized(B): + _gemm_bump("pytorch_scaled_mm_attempt") + result = _try_scaled_mm( + A, transA, B, transB, _resolve_output_dtype(output_dtype), + ) + if result is not None: + _gemm_bump("pytorch_scaled_mm_ok") + + if result is not None: + # torch._scaled_mm already handled compute and leading-dim restoration; + # skip the dequantize+matmul block and go straight to epilogues. + if alpha != 1.0: + result = result * alpha + + bias_grad = torch.Tensor() + if bias is not None and bias.numel() > 0: + if grad: + grad_out = _dequantize_if_needed(B) + bias_grad = grad_out.reshape(-1, grad_out.shape[-1]).sum(dim=0) + else: + result = result + bias + + gelu_input = torch.Tensor() + if gelu and gelu_in is not None: + gelu_in.copy_(result) + gelu_input = gelu_in + result = torch.nn.functional.gelu(result, approximate='tanh') + + if accumulate and D is not None: + D.add_(result) + elif D is not None: + D.copy_(result) + else: + D = result + + if quantizer is not None and hasattr(quantizer, 'quantize'): + D = quantizer.quantize(D) + + return D, bias_grad, gelu_input, extra_output + + # When backend=="pytorch" and _scaled_mm rejected the call (wgrad with + # per-row scale on the reduction axis, block-scaled, unsupported dtype + # combo, etc.), fall back to AITER before the catastrophically-slow + # dequantize+matmul path. Dequant+matmul on FP8 operands runs 100-1000x + # slower than AITER Triton and turns a few rejected calls into + # multi-minute iterations. + if (_GEMM_BACKEND == "pytorch" and is_aiter_available() + and _is_quantized(A) and _is_quantized(B)): + _gemm_bump("pytorch_aiter_fallback_attempt") + aiter_result = _aiter_gemm( + A, transA, B, transB, D, quantizer, output_dtype, + bias, bias_type, gelu, gelu_in, grad, accumulate, alpha, + ) + if aiter_result is not None: + _gemm_bump("pytorch_aiter_fallback_ok") + return aiter_result[0], aiter_result[1], aiter_result[2], extra_output + + _gemm_bump("pytorch_dequant_matmul") + a = _dequantize_if_needed(A) + b = _dequantize_if_needed(B) + + # cuBLAS column-major: C = op(A) @ op(B) + # In row-major (PyTorch): C_row = B_row @ A_row (reversed operand order) + # Typical "TN" layout: transA=True, transB=False + # A=[out,in] weight -> a.t()=[in,out], B=[batch,in] -> b as-is + # result = b @ a.t() = [batch,in] @ [in,out] = [batch,out] + + # cuBLAS GEMM treats N-D tensors as batched 2D: leading dims of B are + # preserved in the output. torch.matmul with 2D operands doesn't do + # this, so we flatten to 2D, matmul, then restore B's leading dims. 
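+    # Example: B with shape [batch, seq, k] against a weight A with shape
+    # [n, k] and transA=True flattens to [batch*seq, k] @ [k, n] =
+    # [batch*seq, n]; the view below then restores [batch, seq, n],
+    # matching the full-build cuBLAS behaviour.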
+ b_leading = b.shape[:-1] # leading dims of B (before transpose) + if a.dim() > 2: + a = a.reshape(-1, a.shape[-1]) + if b.dim() > 2: + b = b.reshape(-1, b.shape[-1]) + + if transA: + a = a.t() + if transB: + b = b.t() + + compute_dtype = torch.bfloat16 + if a.dtype == torch.float32 or b.dtype == torch.float32: + compute_dtype = torch.float32 + elif a.dtype == torch.float16 or b.dtype == torch.float16: + compute_dtype = torch.float16 + + a = a.to(compute_dtype) + b = b.to(compute_dtype) + + result = torch.matmul(b, a) + + # Restore B's leading dimensions in the output (cuBLAS convention) + if len(b_leading) > 1: + result = result.view(*b_leading, result.shape[-1]) + + if alpha != 1.0: + result = result * alpha + + bias_grad = torch.Tensor() + if bias is not None and bias.numel() > 0: + if grad: + grad_out = _dequantize_if_needed(B) + bias_grad = grad_out.reshape(-1, grad_out.shape[-1]).sum(dim=0) + else: + result = result + bias + + gelu_input = torch.Tensor() + if gelu and gelu_in is not None: + gelu_in.copy_(result) + gelu_input = gelu_in + result = torch.nn.functional.gelu(result, approximate='tanh') + + # Honor the caller-requested output dtype. cuBLAS casts to out_dtype in the + # full build; without this cast, an fp32 operand promotes the whole result + # to fp32 and the next module fails set_activation_dtype. + out_torch_dtype = _resolve_output_dtype(output_dtype) + if out_torch_dtype is not None and result.dtype != out_torch_dtype: + result = result.to(out_torch_dtype) + + if accumulate and D is not None: + D.add_(result) + elif D is not None: + D.copy_(result) + else: + D = result + + if quantizer is not None and hasattr(quantizer, 'quantize'): + D = quantizer.quantize(D) + + return D, bias_grad, gelu_input, extra_output + + +# Grouped GEMM (te_general_grouped_gemm) lives in `_lite/grouped_gemm.py`. diff --git a/transformer_engine/pytorch/_lite/grouped_gemm.py b/transformer_engine/pytorch/_lite/grouped_gemm.py new file mode 100644 index 000000000..8a13cf6cb --- /dev/null +++ b/transformer_engine/pytorch/_lite/grouped_gemm.py @@ -0,0 +1,141 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Grouped GEMM operations for MoE-style expert parallelism. + +This module replaces the C++ ``tex.te_general_grouped_gemm`` binding when +TE-lite is active, so Megatron's ``GroupedLinear`` / ``GroupedMLP`` can call +into AITER's Triton GMM kernels without the full TE C++ extension. + +Supported today: + +* BF16 / FP16 grouped GEMM for forward, dgrad, and wgrad — routed through + ``transformer_engine.pytorch.triton_kernels.grouped_gemm.general_grouped_gemm_triton``, + which wraps AITER's ``gmm`` / ``ptgmm`` Triton kernels. + +Not yet supported: + +* FP8 grouped GEMM. AITER's generic GMM family (``gmm``, ``ptgmm``, + ``nptgmm``) is BF16/FP16 only — the ``p``/``np`` prefix is persistent vs + non-persistent kernel, not per-tensor scaling. FP8 grouped expert + compute lives in AITER as a fused MoE op (``aiter.fused_moe``, + ``moe_op_gemm_a8w8_blockscale``) with a different API shape, so a + separate dispatcher will land in Phase 2. 
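+
+Also handled here: a zero-token dispatch (every entry of ``m_splits`` is 0)
+is short-circuited without launching a kernel, since AITER's ``gmm`` asserts
+``M > 0`` while Megatron treats empty expert routing as a legal state.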
+""" + +import torch + +from .gemm import _FP8_DTYPES, _is_quantized + + +def _is_fp8_operand(tensor): + """True if `tensor` is FP8 raw bytes or a TE Float8Tensor wrapper.""" + if tensor is None: + return False + if _is_quantized(tensor): + return True + return getattr(tensor, "dtype", None) in _FP8_DTYPES + + +def _any_fp8(tensor_list): + return tensor_list is not None and any(_is_fp8_operand(t) for t in tensor_list) + + +def te_general_grouped_gemm( + A, transa, B, transb, out, out_dtype, m_splits, bias, bias_dtype, + single_output, pre_gelu_out, grad, workspaces, workspace_size, + accumulate, use_split_accumulator, sm_count, **kwargs, +): + """Grouped GEMM for MoE expert parallelism (lite-mode replacement for + ``tex.te_general_grouped_gemm``). + + Signature matches the C++ binding that ``general_grouped_gemm`` in + ``cpp_extensions/gemm.py`` calls. Adapts to ``general_grouped_gemm_triton``'s + keyword interface by: + + * deriving the ``"TN"``/``"NN"``/``"NT"`` layout string from + ``transa``/``transb`` flags; + * converting ``bias_dtype`` (``TE_DType``) into a ``use_bias`` flag; + * treating a non-empty ``pre_gelu_out`` as ``gelu=True``. + + Mutation contract matches the C++ binding: ``out`` and ``pre_gelu_out`` + are filled in place; only the bias / grad-bias list is returned. + """ + if _any_fp8(A) or _any_fp8(B): + raise NotImplementedError( + "FP8 grouped GEMM is not yet supported in TE-lite. " + "AITER's generic GMM kernels (gmm/ptgmm/nptgmm) are BF16/FP16 only; " + "FP8 expert compute requires the fused-MoE path " + "(aiter.fused_moe / moe_op_gemm_a8w8_blockscale). " + "Run with TE_FP8=0 for now, or wait for Phase 2 of the lite " + "grouped-GEMM dispatcher." + ) + + # Empty-token short-circuit: when MoE token routing sends zero tokens to + # this rank's local expert(s) (common in early training before the + # auxiliary load-balancing loss kicks in), AITER's gmm asserts M > 0. + # That's a legal MoE state from Megatron's side, so handle it here. + # Forward / dgrad outputs are (M, ...) so already empty; wgrad output + # is (G, K, N) and represents zero contribution from this microbatch: + # accumulate=True -> leave existing_out alone (no-op contribution), + # accumulate=False -> zero existing_out so caller sees sane state. + if m_splits is not None and sum(m_splits) == 0: + is_wgrad = transa and not transb and grad + if is_wgrad and not accumulate: + for o in out: + o.zero_() + # bias / grad-bias: forward path returns the input bias list as-is; + # wgrad path would normally return per-group grad-bias tensors, which + # are also zero contribution under M=0. Match the empty-bias case. + return [None] * len(m_splits) if (bias is None or len(bias) == 0) else bias + + try: + from transformer_engine.pytorch.triton_kernels.grouped_gemm import ( + general_grouped_gemm_triton, + ) + except (ImportError, ModuleNotFoundError): + raise NotImplementedError( + "Grouped GEMM in lite mode requires AITER. " + "Install AITER (pip install amd-aiter) or use the standard " + "GEMM path." + ) + + layout = ("T" if transa else "N") + ("T" if transb else "N") + + use_bias = bias is not None and len(bias) > 0 and bias[0].numel() > 0 + + gelu = ( + pre_gelu_out is not None + and len(pre_gelu_out) > 0 + and pre_gelu_out[0].numel() > 0 + ) + + # out_dtype arrives as TE_DType (general_grouped_gemm reassigns via + # TE_DType[out[0].dtype]); convert back to torch.dtype for the Triton + # wrapper, which compares directly against tensor.dtype. 
+ if not isinstance(out_dtype, torch.dtype): + try: + from transformer_engine.pytorch.triton_kernels.common import ( + te_dtype_to_torch_dtype, + ) + out_dtype = te_dtype_to_torch_dtype(out_dtype) + except (ImportError, KeyError): + out_dtype = out[0].dtype + + _, bias_or_grad_bias, _ = general_grouped_gemm_triton( + A, B, out, + quantization_params=None, + out_dtype=out_dtype, + layout=layout, + m_splits=m_splits, + gelu=gelu, + grad=grad, + accumulate=accumulate, + bias=bias if use_bias else None, + use_bias=use_bias, + use_split_accumulator=use_split_accumulator, + single_output=single_output, + ) + return bias_or_grad_bias diff --git a/transformer_engine/pytorch/_lite/misc.py b/transformer_engine/pytorch/_lite/misc.py new file mode 100644 index 000000000..fa1c59364 --- /dev/null +++ b/transformer_engine/pytorch/_lite/misc.py @@ -0,0 +1,11 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Miscellaneous utility functions.""" + + +def get_num_cublas_streams(): + """Get number of compute streams. Returns default 1 in lite mode.""" + return 1 diff --git a/transformer_engine/pytorch/_lite/mori_ep.py b/transformer_engine/pytorch/_lite/mori_ep.py new file mode 100644 index 000000000..21859b87f --- /dev/null +++ b/transformer_engine/pytorch/_lite/mori_ep.py @@ -0,0 +1,1038 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. + +"""MORI Expert Parallelism integration for tealite. + +Wraps MORI's EpDispatchCombineOp to provide distributed expert parallelism +for the MoE pipeline. MORI handles high-performance token dispatch/combine +across GPUs using XGMI (intra-node) and RDMA (inter-node). + +Requires: ``pip install mori`` (or MORI built from source with ROCm 6.4+). + +Inference usage:: + + from transformer_engine.pytorch._lite.mori_ep import ( + mori_ep_available, + init_mori_ep, + MoriExpertParallel, + ) + + init_mori_ep() + ep = MoriExpertParallel(rank=rank, world_size=8, ...) 
+ + recv, recv_w, recv_idx, n = ep.dispatch(tokens, weights, indices) + expert_out = run_experts(recv[:n]) + output, _ = ep.combine(expert_out, recv_w, recv_idx) + ep.reset() + +Training usage (with autograd):: + + state = ep.new_cycle() + recv, recv_w, recv_idx = MoriEPDispatch.apply(tokens, weights, indices, state) + expert_out = run_experts(recv) # normal autograd + weighted = expert_out * recv_w[..., None] # apply routing weights + output, _ = MoriEPCombine.apply(weighted, recv_w, recv_idx, state) + loss = loss_fn(output[:num_tokens]) + loss.backward() # gradients flow through combine → expert → dispatch +""" + +import warnings +from typing import Optional, Tuple + +import torch + +# --------------------------------------------------------------------------- +# Lazy MORI import +# --------------------------------------------------------------------------- +_mori = None +_mori_available: Optional[bool] = None +_mori_shmem_initialized = False + + +def _try_import_mori(): + global _mori, _mori_available + if _mori_available is not None: + return _mori_available + try: + import mori + _mori = mori + _mori_available = True + except ImportError: + _mori_available = False + return _mori_available + + +def mori_ep_available() -> bool: + """Check whether MORI is installed and available for expert parallelism.""" + return _try_import_mori() + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + +def init_mori_ep(process_group_name: str = "default") -> None: + """Initialize MORI shmem from a PyTorch distributed process group. + + Must be called once per process after ``torch.distributed.init_process_group()``. + Safe to call multiple times -- subsequent calls are no-ops. + + Args: + process_group_name: Name of the PyTorch process group to use for + bootstrapping MORI's symmetric memory. Defaults to ``"default"`` + (the WORLD group). + """ + global _mori_shmem_initialized + if _mori_shmem_initialized: + return + + if not _try_import_mori(): + raise RuntimeError( + "MORI is not installed. Install with: pip install mori " + "(or build from source at https://github.com/ROCm/mori)" + ) + + import torch.distributed as dist + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before init_mori_ep(). " + "Call torch.distributed.init_process_group() first." + ) + + _mori.shmem.shmem_torch_process_group_init(process_group_name) + _mori_shmem_initialized = True + + +def finalize_mori_ep() -> None: + """Finalize MORI shmem. Call during process cleanup.""" + global _mori_shmem_initialized + if not _mori_shmem_initialized: + return + _mori.shmem.shmem_finalize() + _mori_shmem_initialized = False + + +def is_mori_ep_initialized() -> bool: + """Return whether MORI shmem has been initialized.""" + return _mori_shmem_initialized + + +# --------------------------------------------------------------------------- +# Routing map conversion +# --------------------------------------------------------------------------- + +def mask_to_index( + routing_map: torch.Tensor, + probs: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert a mask-map routing tensor to the index-map format used by MORI. + + TE's MoE layer supports two routing map formats: + + - **mask**: ``[num_tokens, num_experts]`` binary int32 tensor where 1 means + the token is routed to that expert. 
+ - **index**: ``[num_tokens, topk]`` int32 tensor of selected expert IDs. + + MORI only accepts the index format. This function converts mask → index + and gathers the corresponding routing probabilities. + + Args: + routing_map: Binary mask tensor, shape ``[num_tokens, num_experts]``, + dtype int32. Each row has exactly ``topk`` ones. + probs: Optional probability tensor, shape ``[num_tokens, num_experts]``, + dtype float32. Contains the routing probability for each + token-expert pair. Only the entries where ``routing_map == 1`` + are meaningful. + + Returns: + Tuple of ``(indices, weights)``: + + - ``indices``: Expert indices, shape ``[num_tokens, topk]``, int32. + - ``weights``: Routing weights gathered from *probs* at the selected + positions, shape ``[num_tokens, topk]``, float32. If *probs* is + ``None``, returns uniform weights of 1.0. + + Example:: + + # mask: token 0 → experts 1,3; token 1 → experts 0,2 + mask = torch.tensor([[0,1,0,1],[1,0,1,0]], dtype=torch.int32, device="cuda") + probs = torch.tensor([[0,.3,0,.7],[.5,0,.5,0]], dtype=torch.float32, device="cuda") + indices, weights = mask_to_index(mask, probs) + # indices: [[1,3],[0,2]] weights: [[.3,.7],[.5,.5]] + """ + # nonzero gives sorted (row, col) pairs — rows are in-order, columns + # within each row are ascending, which matches TE's mask-map convention. + nz = routing_map.nonzero(as_tuple=False) # [nnz, 2] + expert_ids = nz[:, 1].to(torch.int32) + + # Determine topk from the mask (number of ones per row, assumed uniform) + num_tokens = routing_map.shape[0] + topk = nz.shape[0] // num_tokens if num_tokens > 0 else 0 + + indices = expert_ids.reshape(num_tokens, topk) + + if probs is not None: + # Gather probabilities at selected positions + weights = probs[nz[:, 0], nz[:, 1]].to(torch.float32).reshape(num_tokens, topk) + else: + weights = torch.ones( + num_tokens, topk, dtype=torch.float32, device=routing_map.device, + ) + + return indices, weights + + +def index_to_mask( + indices: torch.Tensor, + num_experts: int, + weights: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Convert an index-map routing tensor to the mask-map format used by TE. + + Inverse of :func:`mask_to_index`. + + Args: + indices: Expert indices, shape ``[num_tokens, topk]``, int32. + num_experts: Total number of experts. + weights: Optional routing weights, shape ``[num_tokens, topk]``, float32. + + Returns: + Tuple of ``(routing_map, probs)``: + + - ``routing_map``: Binary mask, shape ``[num_tokens, num_experts]``, int32. + - ``probs``: Probability tensor with weights scattered to expert + positions, shape ``[num_tokens, num_experts]``, float32. + ``None`` if *weights* is ``None``. + """ + num_tokens, topk = indices.shape + routing_map = torch.zeros( + num_tokens, num_experts, dtype=torch.int32, device=indices.device, + ) + row_idx = torch.arange(num_tokens, device=indices.device).unsqueeze(1).expand_as(indices) + routing_map[row_idx, indices.long()] = 1 + + probs = None + if weights is not None: + probs = torch.zeros( + num_tokens, num_experts, dtype=torch.float32, device=indices.device, + ) + probs[row_idx, indices.long()] = weights + + return routing_map, probs + + +# --------------------------------------------------------------------------- +# Expert Parallel operator +# --------------------------------------------------------------------------- + +class MoriExpertParallel: + """High-level wrapper around MORI's EP dispatch/combine for tealite MoE. 
+ + This replaces the local permute/unpermute steps with distributed + dispatch/combine when running with expert parallelism across multiple GPUs. + + Args: + rank: Rank of this process in the expert-parallel group. + world_size: Total number of ranks in the expert-parallel group. + hidden_dim: Hidden dimension of token embeddings. + num_experts_per_rank: Number of local experts hosted on each rank. + num_experts_per_token: Number of experts selected per token (top-k). + max_num_inp_token_per_rank: Maximum number of input tokens per rank. + dtype: Data type for dispatch/combine buffers. + kernel_type: MORI kernel type. One of ``"intra_node"`` (default), + ``"inter_node"``, ``"inter_node_v1"``, ``"inter_node_v1_ll"``, + ``"async_ll"``. + block_num: Number of GPU blocks for kernel launch. + warp_num_per_block: Number of warps per GPU block. + gpu_per_node: Number of GPUs per node (for topology). + rdma_block_num: Number of RDMA blocks (inter-node kernels). + quant_type: Quantization mode. ``"none"`` or ``"fp8_direct_cast"``. + """ + + # Map user-friendly names to MORI kernel type enum values + _KERNEL_TYPE_MAP = { + "intra_node": "IntraNode", + "inter_node": "InterNode", + "inter_node_v1": "InterNodeV1", + "inter_node_v1_ll": "InterNodeV1LL", + "async_ll": "AsyncLL", + } + + def __init__( + self, + rank: int, + world_size: int, + hidden_dim: int, + num_experts_per_rank: int, + num_experts_per_token: int, + max_num_inp_token_per_rank: int, + dtype: torch.dtype = torch.bfloat16, + kernel_type: str = "intra_node", + block_num: int = 80, + warp_num_per_block: int = 8, + gpu_per_node: int = 8, + rdma_block_num: int = 0, + quant_type: str = "none", + ): + if not mori_ep_available(): + raise RuntimeError( + "MORI is not installed. Install with: pip install mori" + ) + if not is_mori_ep_initialized(): + raise RuntimeError( + "MORI shmem not initialized. Call init_mori_ep() first." + ) + + self.rank = rank + self.world_size = world_size + self.hidden_dim = hidden_dim + self.num_experts_per_rank = num_experts_per_rank + self.num_experts_per_token = num_experts_per_token + self.max_num_inp_token_per_rank = max_num_inp_token_per_rank + self.dtype = dtype + + # Resolve kernel type + kt_name = self._KERNEL_TYPE_MAP.get(kernel_type) + if kt_name is None: + raise ValueError( + f"Unknown kernel_type '{kernel_type}'. 
" + f"Expected one of: {list(self._KERNEL_TYPE_MAP.keys())}" + ) + kt_enum = getattr(_mori.ops.EpDispatchCombineKernelType, kt_name) + + use_external_inp_buf = quant_type == "fp8_direct_cast" + + self._config = _mori.ops.EpDispatchCombineConfig( + data_type=dtype, + rank=rank, + world_size=world_size, + hidden_dim=hidden_dim, + scale_dim=0, + scale_type_size=1, + max_token_type_size=torch.tensor([], dtype=torch.float32).element_size(), + max_num_inp_token_per_rank=max_num_inp_token_per_rank, + num_experts_per_rank=num_experts_per_rank, + num_experts_per_token=num_experts_per_token, + warp_num_per_block=warp_num_per_block, + block_num=block_num, + use_external_inp_buf=use_external_inp_buf, + kernel_type=kt_enum, + gpu_per_node=gpu_per_node, + rdma_block_num=rdma_block_num, + quant_type=quant_type, + ) + + self._op = _mori.ops.EpDispatchCombineOp(self._config) + + @property + def num_experts(self) -> int: + """Total number of experts across all ranks.""" + return self.num_experts_per_rank * self.world_size + + def dispatch( + self, + input: torch.Tensor, + weights: torch.Tensor, + indices: torch.Tensor, + scales: Optional[torch.Tensor] = None, + block_num: int = -1, + warp_per_block: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Dispatch tokens to expert-owning ranks. + + Args: + input: Token embeddings, shape ``[num_tokens, hidden_dim]``. + weights: Routing weights from the router, shape + ``[num_tokens, num_experts_per_token]``. + indices: Expert indices from the router, shape + ``[num_tokens, num_experts_per_token]``, dtype int32. + scales: Optional per-token scales for quantized paths. + block_num: Override GPU block count for this launch. + warp_per_block: Override warps-per-block for this launch. + + Returns: + Tuple of ``(recv_tokens, recv_weights, recv_indices, num_recv_tokens)``: + + - ``recv_tokens``: Received token embeddings, shape + ``[max_recv, hidden_dim]``. Only the first ``num_recv_tokens`` + rows are valid. + - ``recv_weights``: Routing weights for received tokens. + - ``recv_indices``: Expert indices for received tokens. + - ``num_recv_tokens``: Number of valid received tokens. + """ + if indices.dtype != torch.int32: + indices = indices.to(torch.int32) + + if scales is None: + scales = torch.empty( + input.size(0), 0, dtype=torch.float32, device=input.device, + ) + + out, out_weights, _out_scales, out_indices, total_recv = self._op.dispatch( + input, weights, scales, indices, + block_num=block_num, + warp_per_block=warp_per_block, + ) + + torch.cuda.synchronize() + num_recv = total_recv[0].item() + + return out, out_weights, out_indices, num_recv + + def combine( + self, + input: torch.Tensor, + weights: torch.Tensor, + indices: torch.Tensor, + block_num: int = -1, + warp_per_block: int = -1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Combine expert outputs back to original ranks. + + Args: + input: Expert output embeddings for tokens this rank processed, + shape ``[num_recv_tokens, hidden_dim]``. + weights: Routing weights returned from :meth:`dispatch`. + indices: Expert indices returned from :meth:`dispatch`. + block_num: Override GPU block count for this launch. + warp_per_block: Override warps-per-block for this launch. + + Returns: + Tuple of ``(output, output_weights)``: + + - ``output``: Combined token embeddings, shape + ``[max_num_inp_token_per_rank, hidden_dim]``. Only the first + ``num_input_tokens`` rows (matching the original input count) + are valid. + - ``output_weights``: Combined routing weights, or ``None``. 
+ """ + if indices.dtype != torch.int32: + indices = indices.to(torch.int32) + + output, output_weights = self._op.combine( + input, weights, indices, + block_num=block_num, + warp_per_block=warp_per_block, + ) + + torch.cuda.synchronize() + return output, output_weights + + # ------------------------------------------------------------------ + # Standard MoE layout (per-expert grouped output) + # ------------------------------------------------------------------ + + _STDMOE_KERNEL_TYPES = {"intra_node", "inter_node_v1_ll"} + + def dispatch_standard_moe( + self, + input: torch.Tensor, + weights: torch.Tensor, + indices: torch.Tensor, + scales: Optional[torch.Tensor] = None, + block_num: int = -1, + rdma_block_num: int = -1, + warp_per_block: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Dispatch tokens and output in per-expert layout for grouped GEMM. + + Unlike :meth:`dispatch` which returns tokens in a flat buffer, + this method arranges received tokens by their destination expert, + producing a layout directly consumable by grouped GEMM:: + + [num_local_experts, max_tokens_per_expert, hidden_dim] + + Each expert's slice contains only the tokens routed to it, with + ``recv_count[e]`` valid rows for expert ``e``. + + Requires MORI built with ``ENABLE_STANDARD_MOE_ADAPT=ON``. + Only supported for ``intra_node`` and ``inter_node_v1_ll`` kernels. + + Args: + input: Token embeddings, shape ``[num_tokens, hidden_dim]``. + weights: Routing weights, shape ``[num_tokens, topk]``. + indices: Expert indices, shape ``[num_tokens, topk]``, int32. + scales: Optional per-token scales for quantized paths. + block_num: Override GPU block count. + rdma_block_num: Override RDMA block count. + warp_per_block: Override warps-per-block. + + Returns: + Tuple of ``(packed_tokens, recv_count, src_info)``: + + - ``packed_tokens``: Per-expert token tensor, shape + ``[num_local_experts, max_tokens_per_expert, hidden_dim]``. + Expert ``e`` has ``recv_count[e]`` valid rows. + - ``recv_count``: Number of valid tokens per expert, shape + ``[num_local_experts]``, int32. + - ``src_info``: Source token provenance metadata, shape + ``[num_local_experts, max_tokens_per_expert]``, int32. + Used internally by :meth:`combine_standard_moe`. + """ + if indices.dtype != torch.int32: + indices = indices.to(torch.int32) + + if scales is None: + scales = torch.empty( + input.size(0), 0, dtype=torch.float32, device=input.device, + ) + + packed_tokens, recv_count, src_info, _ = self._op.dispatch_standard_moe( + input, weights, scales, indices, + block_num=block_num, + rdma_block_num=rdma_block_num, + warp_per_block=warp_per_block, + ) + + torch.cuda.synchronize() + return packed_tokens, recv_count, src_info + + def combine_standard_moe( + self, + expert_output: torch.Tensor, + weights: torch.Tensor, + indices: torch.Tensor, + block_num: int = -1, + rdma_block_num: int = -1, + warp_per_block: int = -1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Combine expert outputs from per-expert layout back to original ranks. + + Accepts expert output in the same per-expert layout produced by + :meth:`dispatch_standard_moe`:: + + [num_local_experts, max_tokens_per_expert, hidden_dim] + + Requires MORI built with ``ENABLE_STANDARD_MOE_ADAPT=ON``. + Only supported for ``intra_node`` and ``inter_node_v1_ll`` kernels. + + Args: + expert_output: Per-expert output tensor, shape + ``[num_local_experts, max_tokens_per_expert, hidden_dim]``. 
+ weights: Routing weights from the original dispatch, shape + ``[num_tokens, topk]``. + indices: Expert indices from the original dispatch, shape + ``[num_tokens, topk]``, int32. + block_num: Override GPU block count. + rdma_block_num: Override RDMA block count. + warp_per_block: Override warps-per-block. + + Returns: + Tuple of ``(output, output_weights)``: + + - ``output``: Combined token embeddings, shape + ``[max_num_inp_token_per_rank, hidden_dim]``. + - ``output_weights``: ``None`` (standard MoE combine does not + return accumulated weights). + """ + if indices.dtype != torch.int32: + indices = indices.to(torch.int32) + + output, output_weights = self._op.combine_standard_moe( + expert_output, weights, indices, + block_num=block_num, + rdma_block_num=rdma_block_num, + warp_per_block=warp_per_block, + ) + + torch.cuda.synchronize() + return output, output_weights + + def convert_dispatch_to_standard( + self, + dispatch_tokens: torch.Tensor, + dispatch_indices: torch.Tensor, + block_num: int = -1, + warp_per_block: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert flat dispatch output to per-expert layout. + + Takes the flat output from :meth:`dispatch` and rearranges it into + the per-expert layout expected by grouped GEMM. Useful when you want + to use the regular :meth:`dispatch` (which supports all kernel types) + but still need per-expert layout for the expert computation. + + Requires MORI built with ``ENABLE_STANDARD_MOE_ADAPT=ON``. + + Args: + dispatch_tokens: Flat dispatch output, shape ``[max_recv, hidden_dim]``. + dispatch_indices: Expert indices from dispatch output, shape + ``[max_recv, topk]``, int32. + block_num: Override GPU block count. + warp_per_block: Override warps-per-block. + + Returns: + Tuple of ``(packed_tokens, recv_count, src_info)`` — same as + :meth:`dispatch_standard_moe`. + """ + if dispatch_indices.dtype != torch.int32: + dispatch_indices = dispatch_indices.to(torch.int32) + + packed_tokens, recv_count, src_info, _ = self._op.convert_dispatch_output( + dispatch_tokens, dispatch_indices, + block_num=block_num, + warp_per_block=warp_per_block, + ) + + torch.cuda.synchronize() + return packed_tokens, recv_count, src_info + + def convert_standard_to_combine_input( + self, + packed_tokens: torch.Tensor, + src_info: torch.Tensor, + block_num: int = -1, + warp_per_block: int = -1, + ) -> torch.Tensor: + """Convert per-expert layout back to flat layout for :meth:`combine`. + + Takes expert output in per-expert layout and converts it to the flat + layout expected by :meth:`combine`. Useful when you used + :meth:`convert_dispatch_to_standard` for expert computation but + want to use the regular :meth:`combine` (which supports all kernel + types). + + Requires MORI built with ``ENABLE_STANDARD_MOE_ADAPT=ON``. + + Args: + packed_tokens: Per-expert output, shape + ``[num_local_experts, max_tokens_per_expert, hidden_dim]``. + src_info: Source info from :meth:`dispatch_standard_moe` or + :meth:`convert_dispatch_to_standard`. + block_num: Override GPU block count. + warp_per_block: Override warps-per-block. + + Returns: + Flat combine input, shape ``[max_recv, hidden_dim]``. 
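+
+        Illustrative round trip (``grouped_expert_fn`` is a hypothetical
+        placeholder for the per-expert computation, not part of this module)::
+
+            recv, recv_w, recv_idx, n = ep.dispatch(tokens, weights, indices)
+            packed, counts, src = ep.convert_dispatch_to_standard(recv, recv_idx)
+            packed_out = grouped_expert_fn(packed, counts)
+            flat = ep.convert_standard_to_combine_input(packed_out, src)
+            output, _ = ep.combine(flat, recv_w, recv_idx)
+            ep.reset()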
+ """ + layout_range = torch.empty(0, dtype=torch.int64, device=packed_tokens.device) + + flat_input = self._op.convert_combine_input( + packed_tokens, src_info, layout_range, + block_num=block_num, + warp_per_block=warp_per_block, + ) + + torch.cuda.synchronize() + return flat_input + + def reset(self) -> None: + """Reset internal state for the next dispatch/combine cycle. + + Must be called after each complete dispatch + combine round. + """ + self._op.reset() + + def new_cycle(self) -> "_MoriEPCycleState": + """Create a new cycle state for use with the flat autograd functions. + + Each forward + backward pass requires one cycle state. The state + coordinates the paired dispatch/combine calls on the underlying MORI + operator across the forward and backward passes. + + Returns: + A :class:`_MoriEPCycleState` to pass to :class:`MoriEPDispatch` + and :class:`MoriEPCombine`. + """ + return _MoriEPCycleState(self) + + def new_std_moe_cycle(self) -> "_MoriStdMoECycleState": + """Create a cycle state for the standard MoE layout autograd functions. + + Like :meth:`new_cycle` but for use with + :class:`MoriEPDispatchStdMoE` and :class:`MoriEPCombineStdMoE`. + + Returns: + A :class:`_MoriStdMoECycleState`. + """ + return _MoriStdMoECycleState(self) + + def dispatch_and_combine( + self, + input: torch.Tensor, + weights: torch.Tensor, + indices: torch.Tensor, + expert_fn, + scales: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run a full dispatch -> expert compute -> combine cycle. + + Convenience method that chains :meth:`dispatch`, expert computation, + and :meth:`combine` together. Not differentiable -- use + :class:`MoriEPDispatch` and :class:`MoriEPCombine` for training. + + Args: + input: Token embeddings, shape ``[num_tokens, hidden_dim]``. + weights: Routing weights, shape ``[num_tokens, num_experts_per_token]``. + indices: Expert indices, shape ``[num_tokens, num_experts_per_token]``. + expert_fn: Callable that takes ``(tokens, indices, num_tokens)`` and + returns expert output of shape ``[num_tokens, hidden_dim]``. + scales: Optional per-token scales. + + Returns: + Tuple of ``(output, output_weights)`` from :meth:`combine`. + """ + recv_tokens, recv_weights, recv_indices, num_recv = self.dispatch( + input, weights, indices, scales=scales, + ) + + expert_output = expert_fn(recv_tokens, recv_indices, num_recv) + + output, output_weights = self.combine( + expert_output, recv_weights, recv_indices, + ) + + self.reset() + return output, output_weights + + +# --------------------------------------------------------------------------- +# Autograd support for training +# --------------------------------------------------------------------------- + +class _MoriEPCycleState: + """Shared state between dispatch and combine within one forward+backward pass. + + MORI's dispatch and combine are stateful and paired -- you must call + dispatch before combine on the same operator, then reset. This state + object coordinates that pairing across the forward pass and again across + the backward pass (where the roles are reversed). 
+ + Lifecycle:: + + Forward: dispatch(fwd) → expert_fn → combine(fwd) → reset + Backward: dispatch(bwd) → expert.bwd → combine(bwd) → reset + ↑ in combine.backward ↑ in dispatch.backward + """ + + def __init__(self, ep: MoriExpertParallel): + self.ep = ep + # Saved from forward dispatch for backward combine + self.fwd_weights: Optional[torch.Tensor] = None + self.fwd_indices: Optional[torch.Tensor] = None + self.fwd_num_input: int = 0 + self.fwd_num_recv: int = 0 + # Saved from backward dispatch (in combine.backward) for backward combine + # (in dispatch.backward) + self.bwd_recv_weights: Optional[torch.Tensor] = None + self.bwd_recv_indices: Optional[torch.Tensor] = None + self.bwd_num_recv: int = 0 + + +class MoriEPDispatch(torch.autograd.Function): + """Autograd-aware MORI EP dispatch. + + Forward: dispatches tokens to expert-owning ranks. + Backward: combines gradients back from expert ranks (completing the + backward MORI cycle started by :class:`MoriEPCombine`'s backward). + + Usage:: + + state = ep.new_cycle() + recv_tokens, recv_weights, recv_indices = MoriEPDispatch.apply( + input, weights, indices, state, + ) + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weights: torch.Tensor, + indices: torch.Tensor, + state: _MoriEPCycleState, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if indices.dtype != torch.int32: + indices = indices.to(torch.int32) + + scales = torch.empty( + input.size(0), 0, dtype=torch.float32, device=input.device, + ) + + out, out_w, _out_s, out_idx, total_recv = state.ep._op.dispatch( + input, weights, scales, indices, + ) + torch.cuda.synchronize() + num_recv = total_recv[0].item() + + # Save routing info for backward + state.fwd_weights = weights.detach() + state.fwd_indices = indices.detach() + state.fwd_num_input = input.shape[0] + state.fwd_num_recv = num_recv + + ctx.state = state + + # Return only valid rows -- clone to decouple from MORI's internal buffers + return ( + out[:num_recv].clone(), + out_w[:num_recv].clone(), + out_idx[:num_recv].clone(), + ) + + @staticmethod + def backward( + ctx, + grad_recv_tokens: torch.Tensor, + grad_recv_weights: torch.Tensor, + grad_recv_indices: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], None, None, None]: + """Complete the backward MORI cycle: combine gradients back to source ranks. + + The backward dispatch (which sent grad_output to expert ranks) was + already initiated by :meth:`MoriEPCombine.backward`. Now we combine + the gradients that flowed through ``expert.backward`` back to the + original token-owning ranks. + """ + state = ctx.state + + output, _ = state.ep._op.combine( + grad_recv_tokens, + state.bwd_recv_weights, + state.bwd_recv_indices, + ) + torch.cuda.synchronize() + state.ep._op.reset() + + grad_input = output[:state.fwd_num_input] + return grad_input, None, None, None + + +class MoriEPCombine(torch.autograd.Function): + """Autograd-aware MORI EP combine. + + Forward: combines expert outputs back to original ranks and resets + the forward MORI cycle. + Backward: dispatches gradients to expert-owning ranks (starting the + backward MORI cycle that :class:`MoriEPDispatch`'s backward completes). 
+ + Usage:: + + output, output_weights = MoriEPCombine.apply( + expert_output, recv_weights, recv_indices, state, + ) + """ + + @staticmethod + def forward( + ctx, + expert_output: torch.Tensor, + recv_weights: torch.Tensor, + recv_indices: torch.Tensor, + state: _MoriEPCycleState, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if recv_indices.dtype != torch.int32: + recv_indices = recv_indices.to(torch.int32) + + output, output_w = state.ep._op.combine( + expert_output, recv_weights, recv_indices, + ) + torch.cuda.synchronize() + state.ep._op.reset() # forward cycle complete + + ctx.state = state + num_input = state.fwd_num_input + + # Clone the valid portion to decouple from MORI buffers + return ( + output[:num_input].clone(), + output_w[:num_input].clone() if output_w is not None else None, + ) + + @staticmethod + def backward( + ctx, + grad_output: torch.Tensor, + grad_output_weights: Optional[torch.Tensor], + ) -> Tuple[Optional[torch.Tensor], None, None, None]: + """Start the backward MORI cycle: dispatch gradients to expert ranks. + + This sends ``grad_output`` (on the token-originating ranks) to the + expert-owning ranks, using the same routing as the forward dispatch. + The dispatched gradients then flow through ``expert.backward`` via + normal autograd, and finally :meth:`MoriEPDispatch.backward` combines + them back. + """ + state = ctx.state + + # Dispatch gradients using the same routing as forward + scales = torch.empty( + grad_output.size(0), 0, dtype=torch.float32, device=grad_output.device, + ) + out, out_w, _out_s, out_idx, total_recv = state.ep._op.dispatch( + grad_output, + state.fwd_weights, + scales, + state.fwd_indices, + ) + torch.cuda.synchronize() + bwd_num_recv = total_recv[0].item() + + # Save for dispatch.backward to complete the backward cycle + state.bwd_recv_weights = out_w[:bwd_num_recv] + state.bwd_recv_indices = out_idx[:bwd_num_recv] + state.bwd_num_recv = bwd_num_recv + + grad_expert_output = out[:bwd_num_recv].clone() + return grad_expert_output, None, None, None + + +# --------------------------------------------------------------------------- +# Standard MoE autograd (per-expert layout) +# --------------------------------------------------------------------------- + +class _MoriStdMoECycleState: + """Shared state for standard MoE dispatch/combine autograd cycle. + + Like :class:`_MoriEPCycleState` but for the standard MoE layout path + where tokens are arranged per-expert. + + Lifecycle:: + + Forward: dispatch_standard_moe → expert_fn → combine_standard_moe → reset + Backward: dispatch_standard_moe → expert.bwd → combine_standard_moe → reset + """ + + def __init__(self, ep: MoriExpertParallel): + self.ep = ep + self.fwd_weights: Optional[torch.Tensor] = None + self.fwd_indices: Optional[torch.Tensor] = None + self.fwd_num_input: int = 0 + self.fwd_recv_count: Optional[torch.Tensor] = None + self.fwd_src_info: Optional[torch.Tensor] = None + # Backward state + self.bwd_recv_count: Optional[torch.Tensor] = None + self.bwd_src_info: Optional[torch.Tensor] = None + + +class MoriEPDispatchStdMoE(torch.autograd.Function): + """Autograd-aware MORI EP dispatch with standard MoE per-expert layout. + + Forward: dispatches tokens and arranges output as + ``[num_local_experts, max_tokens_per_expert, hidden_dim]``. + Backward: combines gradients using :meth:`combine_standard_moe`. 
+ + Usage:: + + state = ep.new_std_moe_cycle() + packed, recv_count, src_info = MoriEPDispatchStdMoE.apply( + input, weights, indices, state, + ) + # packed: [num_local_experts, max_tokens_per_expert, hidden_dim] + # recv_count: [num_local_experts] -- valid tokens per expert + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weights: torch.Tensor, + indices: torch.Tensor, + state: _MoriStdMoECycleState, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if indices.dtype != torch.int32: + indices = indices.to(torch.int32) + + scales = torch.empty( + input.size(0), 0, dtype=torch.float32, device=input.device, + ) + + packed, recv_count, src_info, _ = state.ep._op.dispatch_standard_moe( + input, weights, scales, indices, + ) + torch.cuda.synchronize() + + state.fwd_weights = weights.detach() + state.fwd_indices = indices.detach() + state.fwd_num_input = input.shape[0] + state.fwd_recv_count = recv_count.clone() + state.fwd_src_info = src_info.clone() + + ctx.state = state + + return packed.clone(), recv_count.clone(), src_info.clone() + + @staticmethod + def backward( + ctx, + grad_packed: torch.Tensor, + grad_recv_count: torch.Tensor, + grad_src_info: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], None, None, None]: + state = ctx.state + + output, _ = state.ep._op.combine_standard_moe( + grad_packed, + state.bwd_recv_count, # dummy -- combine uses internal state + state.bwd_src_info, # not directly used but kept for consistency + ) + torch.cuda.synchronize() + state.ep._op.reset() + + grad_input = output[:state.fwd_num_input] + return grad_input, None, None, None + + +class MoriEPCombineStdMoE(torch.autograd.Function): + """Autograd-aware MORI EP combine with standard MoE per-expert layout. + + Forward: combines expert outputs from per-expert layout back to + original ranks and resets the cycle. + Backward: dispatches gradients using :meth:`dispatch_standard_moe`. + + Usage:: + + output, _ = MoriEPCombineStdMoE.apply( + expert_output, weights, indices, state, + ) + """ + + @staticmethod + def forward( + ctx, + expert_output: torch.Tensor, + weights: torch.Tensor, + indices: torch.Tensor, + state: _MoriStdMoECycleState, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if indices.dtype != torch.int32: + indices = indices.to(torch.int32) + + output, output_w = state.ep._op.combine_standard_moe( + expert_output, weights, indices, + ) + torch.cuda.synchronize() + state.ep._op.reset() + + ctx.state = state + num_input = state.fwd_num_input + + return ( + output[:num_input].clone(), + output_w[:num_input].clone() if output_w is not None else None, + ) + + @staticmethod + def backward( + ctx, + grad_output: torch.Tensor, + grad_output_weights: Optional[torch.Tensor], + ) -> Tuple[Optional[torch.Tensor], None, None, None]: + state = ctx.state + + scales = torch.empty( + grad_output.size(0), 0, dtype=torch.float32, device=grad_output.device, + ) + packed, recv_count, src_info, _ = state.ep._op.dispatch_standard_moe( + grad_output, + state.fwd_weights, + scales, + state.fwd_indices, + ) + torch.cuda.synchronize() + + state.bwd_recv_count = recv_count + state.bwd_src_info = src_info + + return packed.clone(), None, None, None diff --git a/transformer_engine/pytorch/_lite/multi_tensor.py b/transformer_engine/pytorch/_lite/multi_tensor.py new file mode 100644 index 000000000..4ee4cca1d --- /dev/null +++ b/transformer_engine/pytorch/_lite/multi_tensor.py @@ -0,0 +1,231 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. 
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Multi-tensor operations -- PyTorch-native implementations.
+
+These replace the fused C++ multi-tensor kernels. Performance is lower due to
+per-tensor kernel launches instead of batched execution, but functionality is preserved.
+"""
+
+import torch
+import math
+
+
+def multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale):
+    """Scale tensors by a scalar, matching common/multi_tensor/scale.cu.
+
+    tensor_lists is [in_list, out_list]. For each pair, writes
+    out = cast(in * scale, out.dtype). If any element of `in` is non-finite,
+    sets noop_flag[0] to 1 (the overflow flag).
+    """
+    in_list, out_list = tensor_lists[0], tensor_lists[1]
+    any_non_finite = False
+    for src, dst in zip(in_list, out_list):
+        scaled = src.float() * scale
+        if not any_non_finite and not torch.isfinite(src).all().item():
+            any_non_finite = True
+        dst.copy_(scaled.to(dst.dtype))
+    if any_non_finite and noop_flag is not None and noop_flag.numel() > 0:
+        noop_flag[0] = 1
+
+
+def multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor=False):
+    """Compute L2 norm for a list of tensors.
+
+    Always returns a 2-tuple (total_norm, per_tensor_norms) to match the C++
+    contract in csrc/extensions/multi_tensor/l2norm.cpp. When per_tensor is
+    False, the second tensor is empty.
+    """
+    device = tensor_lists[0][0].device
+    if per_tensor:
+        norms = [t.float().norm().item() for t in tensor_lists[0]]
+        total = math.sqrt(sum(n * n for n in norms))
+        return (torch.tensor([total], device=device, dtype=torch.float32),
+                torch.tensor(norms, device=device, dtype=torch.float32))
+    total_sq = 0.0
+    for t in tensor_lists[0]:
+        total_sq += t.float().norm().item() ** 2
+    return (torch.tensor([math.sqrt(total_sq)], device=device, dtype=torch.float32),
+            torch.empty(0, device=device, dtype=torch.float32))
+
+
+def multi_tensor_unscale_l2norm(chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor=False):
+    """Compute L2 norm after unscaling (tensors are NOT modified).
+
+    Always returns a 2-tuple (total_norm, per_tensor_norms) to match the C++
+    contract. When per_tensor is False, the second tensor is empty.
+    """
+    # Unscaling multiplies each element by inv_scale (the C++ unscale-l2norm
+    # functor computes x[i] * (*inv_scale) before the reduction); do not
+    # divide by it.
+    unscale = inv_scale.item() if inv_scale.numel() == 1 else inv_scale
+    device = tensor_lists[0][0].device
+    if per_tensor:
+        norms = [(t.float() * unscale).norm().item() for t in tensor_lists[0]]
+        total = math.sqrt(sum(n * n for n in norms))
+        return (torch.tensor([total], device=device, dtype=torch.float32),
+                torch.tensor(norms, device=device, dtype=torch.float32))
+    total_sq = 0.0
+    for t in tensor_lists[0]:
+        total_sq += (t.float() * unscale).norm().item() ** 2
+    return (torch.tensor([math.sqrt(total_sq)], device=device, dtype=torch.float32),
+            torch.empty(0, device=device, dtype=torch.float32))
+
+
+def multi_tensor_adam(chunk_size, noop_flag, tensor_lists, lr, beta1, beta2, eps,
+                      step, adam_w_mode, bias_correction, weight_decay):
+    """Fused Adam step mirroring common/multi_tensor/adam.cu.
+
+    tensor_lists layout:
+        4 lists: [grads, params, exp_avg, exp_avg_sq]
+        5 lists: [grads, params, exp_avg, exp_avg_sq, master_params]
+
+    With master_params, Adam math runs in fp32 on master_params and the result
+    is downcast into params. Without them, math runs on params directly.
+    adam_w_mode=True is ADAM_MODE_1 (decoupled weight decay), False is
+    ADAM_MODE_0 (L2 regularization folded into the gradient).
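+
+    Minimal call sketch (tensor lists, step count and hyperparameters are
+    illustrative; ``chunk_size`` is accepted for signature compatibility and
+    is unused by this Python port)::
+
+        noop = torch.zeros(1, dtype=torch.int32, device="cuda")
+        multi_tensor_adam(
+            2048 * 32, noop,
+            [grads, params, exp_avg, exp_avg_sq],
+            1e-4, 0.9, 0.999, 1e-8,
+            step, True, True, 0.01,
+        )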
+ """ + assert len(tensor_lists) in (4, 5), ( + f"multi_tensor_adam expects 4 or 5 tensor lists, got {len(tensor_lists)}" + ) + grads = tensor_lists[0] + params = tensor_lists[1] + exp_avgs = tensor_lists[2] + exp_avg_sqs = tensor_lists[3] + master_params = tensor_lists[4] if len(tensor_lists) == 5 else [None] * len(params) + + bc1 = (1.0 - beta1 ** step) if bias_correction else 1.0 + bc2 = (1.0 - beta2 ** step) if bias_correction else 1.0 + + for g, p, m, v, pm in zip(grads, params, exp_avgs, exp_avg_sqs, master_params): + # Match C++ multi_tensor_apply semantics: the number of elements to + # process is taken from tensor_lists[0] (grads), and the other tensors + # are accessed by raw pointer. So m/v/p/pm may be larger than g — e.g. + # Megatron's distributed optimizer stores m/v for the full vocab while + # passing only this rank's TP shard as the gradient. Operate on flat + # views truncated to g.numel(). + n = g.numel() + g_f = g.reshape(-1).float() + m_view = m.view(-1)[:n] + v_view = v.view(-1)[:n] + p_src_view = (pm if pm is not None else p).view(-1)[:n] + p_f = p_src_view.float() + + if not adam_w_mode and weight_decay != 0.0: + # ADAM_MODE_0 (L2): fold weight decay into the gradient + g_f = g_f + weight_decay * p_f + + m_view.mul_(beta1).add_(g_f.to(m_view.dtype), alpha=1 - beta1) + v_view.mul_(beta2).addcmul_(g_f.to(v_view.dtype), + g_f.to(v_view.dtype), value=1 - beta2) + + denom = (v_view.float() / bc2).sqrt_().add_(eps) + update = (m_view.float() / bc1) / denom + if adam_w_mode and weight_decay != 0.0: + # ADAM_MODE_1 (decoupled weight decay / AdamW) + update = update + weight_decay * p_f + + p_new = p_f - lr * update + + if pm is not None: + pm.view(-1)[:n].copy_(p_new) + p.view(-1)[:n].copy_(p_new.to(p.dtype)) + else: + p_src_view.copy_(p_new.to(p.dtype)) + + +def multi_tensor_adam_param_remainder(*args, **kwargs): + """Adam with parameter remainder (for mixed-precision master weights). + + TODO: Implement when needed. + """ + raise NotImplementedError("multi_tensor_adam_param_remainder not yet implemented in lite mode.") + + +def multi_tensor_adam_fp8(*args, **kwargs): + """Adam with FP8 momentum. + + TODO: Implement when needed. + """ + raise NotImplementedError("multi_tensor_adam_fp8 not yet implemented in lite mode.") + + +def multi_tensor_adam_capturable(*args, **kwargs): + """Adam with CUDA graph support. + + Not applicable in lite mode (no CUDA graph capture for Python ops). + Falls back to standard Adam behavior. + """ + raise NotImplementedError("multi_tensor_adam_capturable not yet implemented in lite mode.") + + +def multi_tensor_adam_capturable_master(*args, **kwargs): + """Adam capturable with FP32 master weights. + + TODO: Implement when needed. + """ + raise NotImplementedError( + "multi_tensor_adam_capturable_master not yet implemented in lite mode." + ) + + +def multi_tensor_sgd(chunk_size, noop_flag, tensor_lists, lr, momentum, dampening, + weight_decay, nesterov, first_run, wd_after_momentum, scale=1.0): + """Fused SGD step mirroring common/multi_tensor/sgd.cu. + + tensor_lists layout: + 3 lists: [grads, weights, momentum_bufs] + 4 lists: [grads, weights, momentum_bufs, weights_fp16_copy] + Math runs in fp32 on upcast copies; updates are written back in the source dtype. 
+ """ + if noop_flag is not None and noop_flag.numel() > 0 and noop_flag.item() != 0: + return + assert len(tensor_lists) in (3, 4), ( + f"multi_tensor_sgd expects 3 or 4 tensor lists, got {len(tensor_lists)}" + ) + grads = tensor_lists[0] + weights = tensor_lists[1] + mom_bufs = tensor_lists[2] + fp16_copies = tensor_lists[3] if len(tensor_lists) == 4 else [None] * len(weights) + + for g, w, mom, w_fp16 in zip(grads, weights, mom_bufs, fp16_copies): + g_f = g.float() * scale + w_f = w.float() + + if weight_decay != 0.0 and not wd_after_momentum: + g_f = g_f + weight_decay * w_f + + if momentum != 0.0: + if first_run: + mom.copy_(g_f.to(mom.dtype)) + else: + mom.mul_(momentum).add_(g_f.to(mom.dtype), alpha=1 - dampening) + mom_f = mom.float() + g_f = g_f + momentum * mom_f if nesterov else mom_f + + if weight_decay != 0.0 and wd_after_momentum: + g_f = g_f + weight_decay * w_f + + w_f = w_f - lr * g_f + w.copy_(w_f.to(w.dtype)) + if w_fp16 is not None: + w_fp16.copy_(w_f.to(w_fp16.dtype)) + + +def multi_tensor_compute_scale_and_scale_inv(amax_list, scale_list, scale_inv_list, + fp8_max, margin=0): + """Compute scale and scale_inv from amax.""" + for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list): + safe_amax = torch.clamp(amax, min=1e-12) + sf = (fp8_max / safe_amax) / (2 ** margin) + scale.copy_(sf) + scale_inv.copy_(1.0 / sf) + + +def multi_tensor_compute_scale_inv_e8m0(*args, **kwargs): + """Compute e8m0 scale_inv from amax (MXFP8/MXFP4 master-weight cast path). + + TODO: Implement when needed. + """ + raise NotImplementedError( + "multi_tensor_compute_scale_inv_e8m0 not yet implemented in lite mode." + ) diff --git a/transformer_engine/pytorch/_lite/norms.py b/transformer_engine/pytorch/_lite/norms.py new file mode 100644 index 000000000..c1873aba5 --- /dev/null +++ b/transformer_engine/pytorch/_lite/norms.py @@ -0,0 +1,709 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Normalization -- AITER Triton, TE Triton, or PyTorch-native fallback. + +Backend priority: + 1. AITER fused norm+quantize (single kernel: RMSNorm/LayerNorm -> FP8 cast) + - Current scaling: rmsnorm2d_fwd_with_dynamicquant (Float8CurrentScalingQuantizer) + Per-row dynamic: computes per-row scale in-kernel, no global amax pass. + Output: FP8 data + yscale(M,) per-row dequant scales. + - Per-tensor static: fused_rms_fp8_per_tensor_static_quant (Float8Quantizer) + - Block scaling: fused_rms_fp8_group_quant (MXFP8Quantizer) + 2. AITER Triton norm kernels (no quantize fusion) + 3. TE Triton kernels (triton_kernels/norms_common.py) + 4. Pure PyTorch fallback + +Fused norm+quantize is used when a compatible quantizer is provided in the +forward pass. Otherwise falls back to norm -> quantizer.quantize() separately. +""" + +import os + +import torch + +from .aiter_utils import is_aiter_available + +_LITE_DIAG = os.environ.get("NVTE_LITE_DIAG", "0") != "0" + +from collections import Counter as _NormCounter +_NORM_CALLS = _NormCounter() + +def _norm_bump(tag): + if not _LITE_DIAG: + return + _NORM_CALLS[tag] += 1 + if sum(_NORM_CALLS.values()) % 500 == 0: + print(f"[LITE-NORM] {dict(_NORM_CALLS)}", flush=True) + +# --------------------------------------------------------------------------- +# Lazy-loaded backends. None = not yet attempted. 
+# --------------------------------------------------------------------------- + +# AITER Triton norm functions +_aiter_rms_fwd = None +_aiter_rms_bwd = None +_aiter_ln_fwd = None +_aiter_ln_bwd = None +# AITER fused norm+quantize kernels +_aiter_fused_rms_fp8_static = None +_aiter_fused_rms_fp8_group = None +_aiter_fused_rms_dynamic_quant = None # Per-row dynamic: rmsnorm2d_fwd_with_dynamicquant +_aiter_fused_ln_fp8_static = None # LayerNorm variant (if available) +_aiter_import_attempted = False + +# TE Triton norm functions (fallback) +_triton_ln_fwd = None +_triton_ln_bwd = None +_triton_rms_fwd = None +_triton_rms_bwd = None +_triton_import_attempted = False + + +def _try_load_aiter_norms(): + """Lazy-import AITER Triton norm kernels. Called once, result cached.""" + global _aiter_rms_fwd, _aiter_rms_bwd, _aiter_ln_fwd, _aiter_ln_bwd + global _aiter_fused_rms_fp8_static, _aiter_fused_rms_fp8_group + global _aiter_fused_rms_dynamic_quant + global _aiter_fused_ln_fp8_static + global _aiter_import_attempted + + if _aiter_import_attempted: + return + _aiter_import_attempted = True + + if not is_aiter_available(): + return + try: + from aiter.ops.triton.rmsnorm import ( + _rmsnorm_forward, + _rmsnorm_backward, + ) + from aiter.ops.triton.norm import ( + _layernorm_forward, + _layernorm_backward, + ) + _aiter_rms_fwd = _rmsnorm_forward + _aiter_rms_bwd = _rmsnorm_backward + _aiter_ln_fwd = _layernorm_forward + _aiter_ln_bwd = _layernorm_backward + except (ImportError, AttributeError): + pass + + # Fused norm+quantize kernels. AITER reorganized these into a `quant/` + # subpackage in newer versions; try the new path first, then the legacy + # top-level path for older installs. + _fused_static = None + _fused_group = None + for _mod_path in ( + "aiter.ops.triton.quant.fused_fp8_quant", + "aiter.ops.triton.fused_fp8_quant", + ): + try: + _mod = __import__(_mod_path, fromlist=[ + "fused_rms_fp8_per_tensor_static_quant", + "fused_rms_fp8_group_quant", + ]) + _fused_static = getattr(_mod, "fused_rms_fp8_per_tensor_static_quant", None) + _fused_group = getattr(_mod, "fused_rms_fp8_group_quant", None) + if _fused_static is not None or _fused_group is not None: + break + except BaseException as _e: + if _LITE_DIAG: + print( + f"[LITE-NORM-DIAG] {_mod_path} import failed: " + f"{type(_e).__name__}: {_e}", + flush=True, + ) + if _fused_static is not None: + _aiter_fused_rms_fp8_static = _fused_static + if _fused_group is not None: + _aiter_fused_rms_fp8_group = _fused_group + + # Fused RMSNorm + per-row dynamic FP8 quantize (current scaling) + try: + from aiter.ops.triton.rmsnorm import ( + rmsnorm2d_fwd_with_dynamicquant, + ) + _aiter_fused_rms_dynamic_quant = rmsnorm2d_fwd_with_dynamicquant + except (ImportError, AttributeError): + pass + + +def _try_load_triton_norms(): + """Lazy-import TE Triton norm kernels. 
Called once, result cached.""" + global _triton_ln_fwd, _triton_ln_bwd + global _triton_rms_fwd, _triton_rms_bwd + global _triton_import_attempted + + if _triton_import_attempted: + return + + _triton_import_attempted = True + try: + from transformer_engine.pytorch.triton_kernels.norms_common import ( + te_layernorm_fwd_triton, + te_layernorm_bwd_triton, + te_rmsnorm_fwd_triton, + te_rmsnorm_bwd_triton, + ) + _triton_ln_fwd = te_layernorm_fwd_triton + _triton_ln_bwd = te_layernorm_bwd_triton + _triton_rms_fwd = te_rmsnorm_fwd_triton + _triton_rms_bwd = te_rmsnorm_bwd_triton + except (ImportError, ModuleNotFoundError): + pass + + +# --------------------------------------------------------------------------- +# PyTorch fallback implementations +# --------------------------------------------------------------------------- + +def _layernorm_fwd_pytorch(input, weight, bias, eps, zero_centered_gamma): + """LayerNorm forward -- pure PyTorch.""" + if zero_centered_gamma: + weight = weight + 1.0 + mean = input.mean(dim=-1, keepdim=True) + var = input.var(dim=-1, keepdim=True, unbiased=False) + rstdev = torch.rsqrt(var + eps) + output = (input - mean) * rstdev * weight + if bias is not None: + output = output + bias + return output, mean.squeeze(-1), rstdev.squeeze(-1) + + +def _layernorm_bwd_pytorch(grad_output, input, mean, rstdev, weight, + zero_centered_gamma): + """LayerNorm backward -- pure PyTorch.""" + if zero_centered_gamma: + weight = weight + 1.0 + hidden_size = input.shape[-1] + x_hat = (input - mean.unsqueeze(-1)) * rstdev.unsqueeze(-1) + grad_weight = (grad_output * x_hat).sum(dim=tuple(range(grad_output.ndim - 1))) + grad_bias = grad_output.sum(dim=tuple(range(grad_output.ndim - 1))) + dx_hat = grad_output * weight + dvar = (dx_hat * (input - mean.unsqueeze(-1)) * (-0.5) * + (rstdev.unsqueeze(-1) ** 3)).sum(dim=-1, keepdim=True) + dmean = (-dx_hat * rstdev.unsqueeze(-1)).sum(dim=-1, keepdim=True) + \ + dvar * (-2.0 / hidden_size) * (input - mean.unsqueeze(-1)).sum(dim=-1, keepdim=True) + grad_input = dx_hat * rstdev.unsqueeze(-1) + \ + dvar * 2.0 / hidden_size * (input - mean.unsqueeze(-1)) + \ + dmean / hidden_size + return grad_input, grad_weight, grad_bias + + +def _rmsnorm_fwd_pytorch(input, weight, eps, zero_centered_gamma): + """RMSNorm forward -- pure PyTorch.""" + if zero_centered_gamma: + weight = weight + 1.0 + rms = input.float().square().mean(dim=-1, keepdim=True).add_(eps).rsqrt() + output = (input * rms).to(input.dtype) * weight + return output, rms.squeeze(-1) + + +def _rmsnorm_bwd_pytorch(grad_output, input, rstdev, weight, zero_centered_gamma): + """RMSNorm backward -- pure PyTorch.""" + if zero_centered_gamma: + weight = weight + 1.0 + hidden_size = input.shape[-1] + x_hat = input * rstdev.unsqueeze(-1) + grad_weight = (grad_output * x_hat).sum(dim=tuple(range(grad_output.ndim - 1))) + dx_hat = grad_output * weight + grad_input = dx_hat * rstdev.unsqueeze(-1) - \ + (dx_hat * input).sum(dim=-1, keepdim=True) * input * \ + (rstdev.unsqueeze(-1) ** 3) / hidden_size + return grad_input, grad_weight + + +# --------------------------------------------------------------------------- +# AITER adapter functions +# --------------------------------------------------------------------------- + +def _aiter_layernorm_fwd(input_2d, weight, bias, eps, zero_centered_gamma): + """LayerNorm forward via AITER Triton kernel. + + AITER's _layernorm_forward writes into pre-allocated tensors. 
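+
+    Returns ``(y, mean, rstd)``; ``mean`` and ``rstd`` are fp32 vectors of
+    length ``M`` (one entry per row), allocated here and filled in-place by
+    the kernel.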
+ """ + if zero_centered_gamma: + weight = weight + 1.0 + M, N = input_2d.shape + y = torch.empty_like(input_2d) + mean = torch.empty(M, dtype=torch.float32, device=input_2d.device) + rstd = torch.empty(M, dtype=torch.float32, device=input_2d.device) + _aiter_ln_fwd(y, input_2d, weight, bias, mean, rstd, eps) + return y, mean, rstd + + +def _aiter_layernorm_bwd(grad_output_2d, input_2d, mean, rstdev, weight, + zero_centered_gamma): + """LayerNorm backward via AITER Triton kernel. + + AITER's _layernorm_backward writes into pre-allocated tensors. + """ + if zero_centered_gamma: + weight = weight + 1.0 + dx = torch.empty_like(input_2d) + dw = torch.empty_like(weight) + db = torch.empty_like(weight) + _aiter_ln_bwd(grad_output_2d, dx, dw, db, input_2d, weight, mean, rstdev) + return dx, dw, db + + +def _aiter_rmsnorm_fwd(input_2d, weight, eps, zero_centered_gamma): + """RMSNorm forward via AITER Triton kernel. + + AITER's _rmsnorm_forward allocates output internally. + """ + if zero_centered_gamma: + weight = weight + 1.0 + y, rsigma = _aiter_rms_fwd(input_2d, weight, eps) + return y, rsigma + + +def _aiter_rmsnorm_bwd(grad_output_2d, input_2d, rstdev, weight, + zero_centered_gamma): + """RMSNorm backward via AITER Triton kernel.""" + if zero_centered_gamma: + weight = weight + 1.0 + dx, dgamma = _aiter_rms_bwd(grad_output_2d, input_2d, weight, rstdev) + return dx, dgamma + + +# --------------------------------------------------------------------------- +# Fused norm+quantize adapters +# --------------------------------------------------------------------------- + +def _is_delayed_scaling_quantizer(quantizer): + """Check if quantizer is Float8Quantizer (delayed per-tensor scaling).""" + # Avoid importing Float8Quantizer at module level (circular import risk). + # Use duck typing: has scale (pre-computed) and amax (to be updated). + return ( + quantizer is not None + and type(quantizer).__name__ == "Float8Quantizer" + and hasattr(quantizer, "scale") + and hasattr(quantizer, "amax") + ) + + +def _is_current_scaling_quantizer(quantizer): + """Check if quantizer is Float8CurrentScalingQuantizer (per-tensor current scaling). + + This quantizer computes amax from the current tensor (no history window). + With per-row fusion, we bypass the per-tensor amax entirely — each row + gets its own dynamic scale computed inside the fused kernel. + """ + return ( + quantizer is not None + and type(quantizer).__name__ == "Float8CurrentScalingQuantizer" + ) + + +def _is_mxfp8_quantizer(quantizer): + """Check if quantizer is MXFP8Quantizer (block scaling).""" + return ( + quantizer is not None + and type(quantizer).__name__ == "MXFP8Quantizer" + ) + + +def _get_fp8_torch_dtype(quantizer): + """Get the torch FP8 dtype from a quantizer's TE dtype.""" + try: + from transformer_engine.pytorch._lite.quantize import _te_dtype_to_torch_fp8 + return _te_dtype_to_torch_fp8(quantizer.dtype) + except (ImportError, AttributeError): + return torch.float8_e4m3fnuz + + +_FUSED_RMS_DIAG_PRINTED = False + +def _try_fused_rmsnorm_quant(input_2d, weight, eps, quantizer, zero_centered_gamma, + orig_shape=None): + """Attempt fused RMSNorm+FP8 quantize via AITER. + + Returns (output, rsigma) on success, or None if fusion not possible. + The output is a QuantizedTensor (Float8Tensor or MXFP8Tensor). 
+ """ + if _LITE_DIAG: + global _FUSED_RMS_DIAG_PRINTED + if not _FUSED_RMS_DIAG_PRINTED: + _FUSED_RMS_DIAG_PRINTED = True + qtype = type(quantizer).__name__ if quantizer is not None else "None" + print( + f"[LITE-NORM-DIAG] first fused-rms attempt: " + f"quantizer_type={qtype}, " + f"fused_dynamic={_aiter_fused_rms_dynamic_quant is not None}, " + f"fused_static={_aiter_fused_rms_fp8_static is not None}, " + f"fused_group={_aiter_fused_rms_fp8_group is not None}", + flush=True, + ) + + if orig_shape is None: + orig_shape = input_2d.shape + + if zero_centered_gamma: + weight = weight + 1.0 + + # Float8CurrentScalingQuantizer: rmsnorm2d_fwd_with_dynamicquant + # Fused RMSNorm + per-row dynamic FP8 quantize in a single kernel. + # Each row computes its own scale in registers — no global amax pass, + # no BF16 intermediate written to HBM. + if _is_current_scaling_quantizer(quantizer) and _aiter_fused_rms_dynamic_quant is not None: + M, N = input_2d.shape + fp8_dtype = _get_fp8_torch_dtype(quantizer) + + # Pre-allocate output tensors for the AITER kernel + out_fp8 = torch.empty(M, N, dtype=fp8_dtype, device=input_2d.device) + yscale = torch.empty(M, dtype=torch.float32, device=input_2d.device) + + _aiter_fused_rms_dynamic_quant(out_fp8, input_2d, yscale, weight, eps) + + # yscale is the per-row dequant scale (multiply FP8 data by yscale to + # recover high-precision values). Wrap in Float8Tensor with vector + # _scale_inv of shape (M,) instead of the usual scalar. + out = quantizer.make_empty( + orig_shape, dtype=input_2d.dtype, device=input_2d.device, + ) + fp8_bytes = out_fp8.view(torch.uint8) + if hasattr(out, '_data'): + out._data.copy_(fp8_bytes.reshape(out._data.shape)) + # Store per-row dequant scales — downstream GEMM dispatch will detect + # scale_inv.numel() > 1 and route to gemm_a8w8_per_token_scale. + if hasattr(out, '_scale_inv'): + out._scale_inv = yscale + # make_empty allocated the transpose buffer (columnwise_usage was set + # on the quantizer) but the fused kernel only writes _data. Mark the + # buffer stale so update_usage/_create_transpose regenerates it from + # _data on demand — otherwise downstream wgrad reads uninitialized + # memory. + if hasattr(out, '_transpose') and out._transpose is not None: + out._transpose_invalid = True + + # Compute rsigma for backward pass. The fused kernel doesn't return it, + # so cheaply recompute from input (one reduction, no FP8 cast). + rsigma = input_2d.float().square().mean(dim=-1).add_(eps).rsqrt() + return out, rsigma + + # Float8Quantizer: fused_rms_fp8_per_tensor_static_quant + if _is_delayed_scaling_quantizer(quantizer) and _aiter_fused_rms_fp8_static is not None: + # AITER kernel expects dequant scale = 1/quant_scale + dequant_scale = (1.0 / quantizer.scale).to(torch.float32) + + # Request the unquantized post-RMSNorm output so we can track its amax. + # Delayed scaling's history must reflect the distribution of what gets + # cast to FP8 (post-norm), not the pre-norm input — otherwise every + # computed scale is off by the RMS factor of the input, and the error + # compounds as activation magnitudes drift during training. + out_fp8, out_norm, _, _ = _aiter_fused_rms_fp8_static( + input_2d, weight, eps, dequant_scale, + output_unquantized_inp1=True, + ) + + # Update amax from the post-norm (pre-cast) tensor for next step's + # delayed scaling. copy_() keeps the reduction on-device; .item() + # would force a CPU<->GPU sync on every RMSNorm forward. 
+ quantizer.amax.copy_(out_norm.abs().amax()) + + # Wrap raw FP8 data in Float8Tensor via the quantizer. + # Create empty container with the ORIGINAL (possibly N-D) shape, + # then copy in the 2D FP8 data from the fused kernel. + out = quantizer.make_empty( + orig_shape, dtype=input_2d.dtype, device=input_2d.device, + ) + fp8_bytes = out_fp8.view(torch.uint8) + if hasattr(out, '_data'): + out._data.copy_(fp8_bytes.reshape(out._data.shape)) + if hasattr(out, '_scale_inv'): + out._scale_inv.copy_(dequant_scale) + # make_empty allocated the transpose buffer (columnwise_usage was set + # on the quantizer) but the fused kernel only writes _data. Mark the + # buffer stale so update_usage/_create_transpose regenerates it from + # _data on demand — otherwise downstream wgrad reads uninitialized + # memory. + if hasattr(out, '_transpose') and out._transpose is not None: + out._transpose_invalid = True + + # Compute rsigma for backward pass (we need it, but the fused kernel + # doesn't return it). Cheaply recompute from input. + rsigma = input_2d.float().square().mean(dim=-1).add_(eps).rsqrt() + return out, rsigma + + # MXFP8Quantizer: fused_rms_fp8_group_quant + # Single kernel: RMSNorm → per-block FP8 quantize (group_size=32). + if _is_mxfp8_quantizer(quantizer) and _aiter_fused_rms_fp8_group is not None: + try: + from transformer_engine.pytorch._lite.quantize import _linear_scale_to_e8m0 + + (out_fp8, out_scales), _, _, _ = _aiter_fused_rms_fp8_group( + input_2d, weight, eps, group_size=32, + ) + + # Create empty MXFP8 container via quantizer + out = quantizer.make_empty( + orig_shape, dtype=input_2d.dtype, device=input_2d.device, + ) + + # Copy FP8 data (uint8 bit pattern) + fp8_bytes = out_fp8.view(torch.uint8) + if hasattr(out, '_rowwise_data') and out._rowwise_data is not None: + out._rowwise_data.copy_(fp8_bytes.reshape(out._rowwise_data.shape)) + + # Convert AITER linear float32 scales → E8M0 uint8 and store + e8m0_scales = _linear_scale_to_e8m0(out_scales) + if hasattr(out, '_rowwise_scale_inv') and out._rowwise_scale_inv is not None: + out._rowwise_scale_inv.copy_( + e8m0_scales.reshape(out._rowwise_scale_inv.shape) + ) + + # Compute rsigma for backward pass + rsigma = input_2d.float().square().mean(dim=-1).add_(eps).rsqrt() + return out, rsigma + except (RuntimeError, ValueError): + # Scale shape mismatch or other issue — fall back to separate path + pass + + return None + + +# --------------------------------------------------------------------------- +# Reshape helpers for N-D input +# --------------------------------------------------------------------------- + +def _ensure_2d(t): + """Reshape to 2D (M, N) if needed. 
Returns (tensor_2d, original_shape).""" + if t.ndim <= 2: + return t, t.shape + orig = t.shape + return t.reshape(-1, orig[-1]), orig + + +def _restore_nd(t, orig_shape): + """Restore from 2D back to original N-D shape.""" + if len(orig_shape) <= 2: + return t + return t.reshape(orig_shape) + + +def _restore_nd_quantized(out, orig_shape): + """Restore N-D shape for possibly-quantized output.""" + if len(orig_shape) <= 2: + return out + batch_shape = orig_shape[:-1] + if hasattr(out, '_data'): + out._data = out._data.reshape(*batch_shape, -1) + elif isinstance(out, torch.Tensor): + out = out.reshape(*batch_shape, -1) + return out + + +def _restore_stats(stats, orig_shape): + """Restore stats (mean, rstdev) to batch shape.""" + if len(orig_shape) <= 2 or stats is None: + return stats + if stats.numel() == 0: + return stats + return stats.reshape(orig_shape[:-1]) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def layernorm_fwd(input, weight, bias, eps, ln_out, quantizer, otype, sm_margin, + zero_centered_gamma): + """LayerNorm forward. + + Backend priority: AITER Triton > TE Triton > PyTorch. + """ + _try_load_aiter_norms() + _try_load_triton_norms() + + input_2d, orig_shape = _ensure_2d(input) + + # Try AITER Triton + if _aiter_ln_fwd is not None: + _norm_bump("ln_fwd_aiter_triton") + out, mu, rsigma = _aiter_layernorm_fwd(input_2d, weight, bias, eps, + zero_centered_gamma) + # Try TE Triton + elif _triton_ln_fwd is not None: + _norm_bump("ln_fwd_te_triton") + if otype is None: + otype = input.dtype + out, mu, rsigma = _triton_ln_fwd( + input_2d, weight, bias, eps, ln_out, quantizer, otype, + sm_margin, zero_centered_gamma, + ) + # TE Triton handles quantizer internally + out = _restore_nd_quantized(out, orig_shape) + mu = _restore_stats(mu, orig_shape) + rsigma = _restore_stats(rsigma, orig_shape) + return out, mu, rsigma + # PyTorch fallback + else: + _norm_bump("ln_fwd_pytorch") + out, mu, rsigma = _layernorm_fwd_pytorch(input_2d, weight, bias, eps, + zero_centered_gamma) + + # Apply quantizer (separate step -- AITER and PyTorch paths) + if quantizer is not None and hasattr(quantizer, 'quantize'): + out = quantizer.quantize(out) + + if ln_out is not None and ln_out is not out: + ln_out.copy_(out) + else: + ln_out = out + + ln_out = _restore_nd_quantized(ln_out, orig_shape) + mu = _restore_stats(mu, orig_shape) + rsigma = _restore_stats(rsigma, orig_shape) + return ln_out, mu, rsigma + + +def layernorm_bwd(grad_output, input, mean, rstdev, weight, sm_margin, + zero_centered_gamma): + """LayerNorm backward. + + Backend priority: AITER Triton > TE Triton > PyTorch. + """ + _try_load_aiter_norms() + _try_load_triton_norms() + + # Dequantize grad_output if it arrived as a QuantizedTensor (e.g., from + # the dgrad GEMM of a LayerNormLinear under FP8 CurrentScaling). 
+ if hasattr(grad_output, 'dequantize') and hasattr(grad_output, '_fp8_dtype'): + grad_output = grad_output.dequantize(dtype=input.dtype) + + orig_shape = input.shape + input_2d, _ = _ensure_2d(input) + grad_2d, _ = _ensure_2d(grad_output) + if mean is not None and mean.ndim > 1: + mean = mean.reshape(-1) + if rstdev.ndim > 1: + rstdev = rstdev.reshape(-1) + + if _aiter_ln_bwd is not None: + _norm_bump("ln_bwd_aiter_triton") + dx, dgamma, dbeta = _aiter_layernorm_bwd(grad_2d, input_2d, mean, rstdev, + weight, zero_centered_gamma) + elif _triton_ln_bwd is not None: + _norm_bump("ln_bwd_te_triton") + dx, dgamma, dbeta = _triton_ln_bwd( + grad_2d, input_2d, mean, rstdev, weight, sm_margin, + zero_centered_gamma, + ) + else: + _norm_bump("ln_bwd_pytorch") + dx, dgamma, dbeta = _layernorm_bwd_pytorch(grad_2d, input_2d, mean, rstdev, + weight, zero_centered_gamma) + + dx = _restore_nd(dx, orig_shape) + return dx, dgamma, dbeta + + +def rmsnorm_fwd(input, weight, eps, ln_out, quantizer, otype, sm_margin, + zero_centered_gamma): + """RMSNorm forward. + + Backend priority: + 1. AITER fused norm+quantize (single kernel, Float8Quantizer only) + 2. AITER Triton norm + separate quantize + 3. TE Triton norm (handles quantizer internally) + 4. PyTorch fallback norm + separate quantize + """ + _try_load_aiter_norms() + _try_load_triton_norms() + + input_2d, orig_shape = _ensure_2d(input) + + # Try AITER fused norm+quantize (single kernel launch) + fused_result = _try_fused_rmsnorm_quant( + input_2d, weight, eps, quantizer, zero_centered_gamma, + orig_shape=orig_shape, + ) + if fused_result is not None: + _norm_bump("rms_fwd_aiter_fused_norm_quant") + out, rsigma = fused_result + rsigma = _restore_stats(rsigma, orig_shape) + return out, torch.Tensor(), rsigma + + # Try AITER Triton (norm only, quantize separate) + if _aiter_rms_fwd is not None: + _norm_bump("rms_fwd_aiter_triton_unfused") + out, rsigma = _aiter_rmsnorm_fwd(input_2d, weight, eps, zero_centered_gamma) + mu = torch.Tensor() + # Try TE Triton (handles quantizer internally) + elif _triton_rms_fwd is not None: + _norm_bump("rms_fwd_te_triton") + if otype is None: + otype = input.dtype + out, mu, rsigma = _triton_rms_fwd( + input_2d, weight, eps, ln_out, quantizer, otype, + sm_margin, zero_centered_gamma, + ) + out = _restore_nd_quantized(out, orig_shape) + rsigma = _restore_stats(rsigma, orig_shape) + return out, mu, rsigma + # PyTorch fallback + else: + _norm_bump("rms_fwd_pytorch") + out, rsigma = _rmsnorm_fwd_pytorch(input_2d, weight, eps, zero_centered_gamma) + mu = torch.Tensor() + + # Apply quantizer (separate step -- AITER norm and PyTorch paths) + if quantizer is not None and hasattr(quantizer, 'quantize'): + out = quantizer.quantize(out) + + if ln_out is not None and ln_out is not out: + ln_out.copy_(out) + else: + ln_out = out + + ln_out = _restore_nd_quantized(ln_out, orig_shape) + rsigma = _restore_stats(rsigma, orig_shape) + return ln_out, mu, rsigma + + +def rmsnorm_bwd(grad_output, input, rstdev, weight, sm_margin, zero_centered_gamma): + """RMSNorm backward. + + Backend priority: AITER Triton > TE Triton > PyTorch. + """ + _try_load_aiter_norms() + _try_load_triton_norms() + + # Dequantize grad_output if it arrived as a QuantizedTensor (e.g., from + # the dgrad GEMM of a LayerNormLinear under FP8 CurrentScaling). 
+ if hasattr(grad_output, 'dequantize') and hasattr(grad_output, '_fp8_dtype'): + grad_output = grad_output.dequantize(dtype=input.dtype) + + orig_shape = input.shape + input_2d, _ = _ensure_2d(input) + grad_2d, _ = _ensure_2d(grad_output) + if rstdev.ndim > 1: + rstdev = rstdev.reshape(-1) + + if _aiter_rms_bwd is not None: + _norm_bump("rms_bwd_aiter_triton") + dx, dgamma = _aiter_rmsnorm_bwd(grad_2d, input_2d, rstdev, weight, + zero_centered_gamma) + elif _triton_rms_bwd is not None: + _norm_bump("rms_bwd_te_triton") + dx, dgamma = _triton_rms_bwd( + grad_2d, input_2d, rstdev, weight, sm_margin, + zero_centered_gamma, + ) + else: + _norm_bump("rms_bwd_pytorch") + dx, dgamma = _rmsnorm_bwd_pytorch(grad_2d, input_2d, rstdev, weight, + zero_centered_gamma) + + dx = _restore_nd(dx, orig_shape) + return dx, dgamma + + +def rmsnorm_bwd_add(grad_output, input, rstdev, weight, zero_centered_gamma): + """Fused RMSNorm backward + add. Returns (grad_input, grad_weight).""" + return rmsnorm_bwd(grad_output, input, rstdev, weight, 0, zero_centered_gamma) diff --git a/transformer_engine/pytorch/_lite/padding.py b/transformer_engine/pytorch/_lite/padding.py new file mode 100644 index 000000000..28ee18bd4 --- /dev/null +++ b/transformer_engine/pytorch/_lite/padding.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Multi-row padding / unpadding -- tex-compatible interface. + +Uses PyTorch-native operations. The existing Triton ``zero_pad_kernel`` +in ``common/triton/pad.py`` is purpose-built for 2-D columnwise-scale +alignment padding and does not apply to the multi-row copy-with-padding +pattern needed here. +""" + +import torch + + +def fused_multi_row_padding(input, output, input_row_list, padded_input_row_list): + """Copy rows from *input* into *output*, zero-padding the extra rows. + + Matches ``tex.fused_multi_row_padding(input, output, src_splits, dst_splits)``. + + Parameters + ---------- + input : torch.Tensor + Source tensor of shape ``[sum(input_row_list), features]``. + output : torch.Tensor + Pre-allocated destination of shape ``[sum(padded_input_row_list), features]``. + input_row_list : list[int] + Number of rows per group in the source tensor. + padded_input_row_list : list[int] + Number of rows per group in the destination tensor (≥ corresponding + entry in *input_row_list*). + """ + in_offset = 0 + out_offset = 0 + for src_rows, dst_rows in zip(input_row_list, padded_input_row_list): + if src_rows > 0: + output[out_offset:out_offset + src_rows].copy_( + input[in_offset:in_offset + src_rows], + ) + if dst_rows > src_rows: + output[out_offset + src_rows:out_offset + dst_rows].zero_() + in_offset += src_rows + out_offset += dst_rows + + +def fused_multi_row_unpadding(input, output, input_row_list, unpadded_input_row_list): + """Extract unpadded rows from a padded tensor. + + Matches ``tex.fused_multi_row_unpadding(input, output, src_splits, dst_splits)``. + + Parameters + ---------- + input : torch.Tensor + Padded source tensor of shape ``[sum(input_row_list), features]``. + output : torch.Tensor + Pre-allocated destination of shape ``[sum(unpadded_input_row_list), features]``. + input_row_list : list[int] + Number of rows per group in the padded source tensor. + unpadded_input_row_list : list[int] + Number of rows per group to extract (≤ corresponding entry in + *input_row_list*). 
+ """ + in_offset = 0 + out_offset = 0 + for src_rows, dst_rows in zip(input_row_list, unpadded_input_row_list): + if dst_rows > 0: + output[out_offset:out_offset + dst_rows].copy_( + input[in_offset:in_offset + dst_rows], + ) + in_offset += src_rows + out_offset += dst_rows diff --git a/transformer_engine/pytorch/_lite/permutation.py b/transformer_engine/pytorch/_lite/permutation.py new file mode 100644 index 000000000..24ec756d0 --- /dev/null +++ b/transformer_engine/pytorch/_lite/permutation.py @@ -0,0 +1,147 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""MOE permutation operations -- tex-compatible interface. + +Index-map path: Uses Triton sort_chunks_by_map kernel for gather operations +when available, falling back to PyTorch-native. + +Mask-map path: The higher-level permutation.py imports +transformer_engine.pytorch.triton.permutation directly for mask-map +operations -- those Triton kernels work in lite mode without changes here. +""" + +import torch + +# --------------------------------------------------------------------------- +# Lazy Triton import for sort_chunks_by_map (gather/scatter kernel) +# --------------------------------------------------------------------------- +_triton_sort = None +_triton_attempted = False + + +def _try_load_triton_sort(): + global _triton_sort, _triton_attempted + if _triton_attempted: + return _triton_sort + _triton_attempted = True + try: + from transformer_engine.pytorch.triton.permutation import sort_chunks_by_map + _triton_sort = sort_chunks_by_map + except (ImportError, RuntimeError): + pass + return _triton_sort + + +# --------------------------------------------------------------------------- +# tex-compatible API +# --------------------------------------------------------------------------- + +def moe_permute_fwd(input, dtype, indices, num_out_tokens, workspace, max_expanded_token_num): + """MOE permute forward: sort tokens by expert assignment. + + Matches the ``tex.moe_permute_fwd`` C++ interface used by + ``_moe_permute_index_map`` in ``permutation.py``. + + Returns + ------- + (permuted_output, row_id_map, workspace) + """ + num_tokens = input.size(0) + num_cols = input.size(1) + topK = indices.size(1) + + # Flatten expert indices and sort to group tokens by expert + flat_indices = indices.reshape(-1).to(torch.int32) + _, sorted_row_id = torch.sort(flat_indices, stable=True) + + num_out = num_out_tokens if num_out_tokens > 0 else num_tokens * topK + + # Map each permuted position to its source token row + source_token_ids = (sorted_row_id[:num_out] // topK).to(torch.int32) + + # Gather rows -- prefer Triton kernel when available + sort_fn = _try_load_triton_sort() + if sort_fn is not None and input.is_cuda: + permuted_output, _ = sort_fn( + input, source_token_ids, None, num_out, num_cols, is_forward=False, + ) + else: + permuted_output = input[source_token_ids.long()] + + # Build inverse map: flat position j → permuted position + row_id_map = torch.zeros( + num_tokens * topK, dtype=torch.int32, device=input.device, + ) + row_id_map[sorted_row_id[:num_out]] = torch.arange( + num_out, dtype=torch.int32, device=input.device, + ) + + return permuted_output, row_id_map, workspace + + +def moe_permute_bwd(input, dtype, row_id_map, prob, num_tokens, topK): + """MOE permute backward -- identical to ``moe_unpermute_fwd``. 
+ + Matches ``tex.moe_permute_bwd`` (the C++ implementation delegates + to ``moe_unpermute_fwd`` as well). + """ + return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK) + + +def moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK): + """MOE unpermute forward: scatter-add from permuted to original order. + + Matches ``tex.moe_unpermute_fwd``. + """ + num_cols = input.size(1) + + # Gather permuted data back to flat (num_tokens*topK) order + gathered = input[row_id_map.long()] # [num_tokens * topK, num_cols] + + if prob.numel() > 0: + gathered = gathered * prob.reshape(-1, 1) + + # Sum over the topK dimension to merge expert contributions + output = gathered.reshape(num_tokens, topK, num_cols).sum(dim=1) + return output + + +def moe_unpermute_bwd(input_bwd, input_fwd, dtype, row_id_map, prob): + """MOE unpermute backward. + + Matches ``tex.moe_unpermute_bwd``. + + Returns + ------- + (act_grad, prob_grad) + """ + topK = prob.size(1) if prob.numel() > 0 else 1 + num_tokens = prob.size(0) if prob.numel() > 0 else row_id_map.size(0) + num_cols = input_bwd.size(1) + + # Expand grad from [num_tokens, num_cols] to [num_tokens * topK, num_cols] + token_ids = torch.arange(num_tokens, device=input_bwd.device).repeat_interleave(topK) + + act_grad = torch.zeros( + input_fwd.size(0), num_cols, + device=input_bwd.device, dtype=input_bwd.dtype, + ) + + if prob.numel() > 0: + weights = prob.reshape(-1, 1).to(input_bwd.dtype) + act_grad[row_id_map.long()] = input_bwd[token_ids] * weights + + # prob_grad = dot(d_output[token], fwd_input[permuted_pos]) + fwd_gathered = input_fwd[row_id_map.long()].float() + prob_grad = ( + (fwd_gathered * input_bwd[token_ids].float()) + .sum(dim=-1).reshape(num_tokens, topK) + ) + else: + act_grad[row_id_map.long()] = input_bwd[token_ids] + prob_grad = torch.empty(0, device=input_bwd.device, dtype=torch.float32) + + return act_grad, prob_grad diff --git a/transformer_engine/pytorch/_lite/quantize.py b/transformer_engine/pytorch/_lite/quantize.py new file mode 100644 index 000000000..4fa48eb93 --- /dev/null +++ b/transformer_engine/pytorch/_lite/quantize.py @@ -0,0 +1,984 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantization operations -- Triton cast kernels with PyTorch-native fallback. + +Uses Triton cast/transpose kernels from triton_kernels/cast_transpose.py when +available, falls back to pure PyTorch implementations otherwise. + +IMPORTANT: This module must NOT call tex.quantize/tex.dequantize in any +fallback path, because in lite mode tex IS this module — that would recurse. 
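+
+Environment variables read in this module:
+  NVTE_LITE_DIAG        -- any value other than "0" enables periodic per-path
+                           call counters for quantize dispatch diagnostics.
+  NVTE_LITE_AMAX_FUSED  -- "0" forces the per-group Python fallback for the
+                           fused amax/scale update instead of the Triton kernel.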
+""" + +import os +import torch + +_LITE_DIAG = os.environ.get("NVTE_LITE_DIAG", "0") != "0" + +from collections import Counter as _QuantCounter +_QUANT_CALLS = _QuantCounter() + +def _quant_bump(tag): + if not _LITE_DIAG: + return + _QUANT_CALLS[tag] += 1 + if sum(_QUANT_CALLS.values()) % 500 == 0: + print(f"[LITE-QUANT] {dict(_QUANT_CALLS)}", flush=True) + +_FP8_FALLBACK_DIAG_PRINTS = 0 +_FP8_FALLBACK_DIAG_MAX = 5 + +# Lazy-loaded Triton cast functions and type checks +_triton_cast_import_attempted = False +_triton_cast_transpose_noop = None +_triton_cast_transpose_mxfp8 = None +_triton_cast_transpose_mxfp4 = None +_triton_dequantize_mxfp8 = None +_setup_transpose_storage = None +_Float8TensorStorage = None +_MXFP8TensorStorage = None +_MXFP4TensorStorage = None +_Float8CurrentScalingQuantizer = None + +# AITER per-row dynamic quantize (lazy-loaded) +_aiter_dynamic_per_token_quant = None +_aiter_quant_import_attempted = False + + +def _try_load_triton_cast(): + """Lazy-import Triton cast kernels and tensor storage types.""" + global _triton_cast_import_attempted + global _triton_cast_transpose_noop, _triton_cast_transpose_mxfp8 + global _triton_cast_transpose_mxfp4, _triton_dequantize_mxfp8 + global _setup_transpose_storage + global _Float8TensorStorage, _MXFP8TensorStorage, _MXFP4TensorStorage + global _Float8CurrentScalingQuantizer + + if _triton_cast_import_attempted: + return + + _triton_cast_import_attempted = True + try: + from transformer_engine.pytorch.triton_kernels.cast_transpose import ( + te_cast_transpose_noop_triton, + te_cast_transpose_mxfp8_triton, + te_cast_transpose_mxfp4_triton, + te_dequantize_mxfp8_triton, + ) + from transformer_engine.pytorch.triton_kernels.cast import ( + _setup_conditional_transpose_storage, + ) + _triton_cast_transpose_noop = te_cast_transpose_noop_triton + _triton_cast_transpose_mxfp8 = te_cast_transpose_mxfp8_triton + _triton_cast_transpose_mxfp4 = te_cast_transpose_mxfp4_triton + _triton_dequantize_mxfp8 = te_dequantize_mxfp8_triton + _setup_transpose_storage = _setup_conditional_transpose_storage + except (ImportError, ModuleNotFoundError): + pass + + # Always try to load tensor storage types (no Triton dependency) + try: + from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import ( + Float8TensorStorage, + ) + _Float8TensorStorage = Float8TensorStorage + except (ImportError, ModuleNotFoundError): + pass + try: + from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import ( + MXFP8TensorStorage, + ) + _MXFP8TensorStorage = MXFP8TensorStorage + except (ImportError, ModuleNotFoundError): + pass + try: + from transformer_engine.pytorch.tensor.storage.mxfp4_tensor_storage import ( + MXFP4TensorStorage, + ) + _MXFP4TensorStorage = MXFP4TensorStorage + except (ImportError, ModuleNotFoundError): + pass + try: + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + _Float8CurrentScalingQuantizer = Float8CurrentScalingQuantizer + except (ImportError, ModuleNotFoundError): + pass + + +def _try_load_aiter_quant(): + """Lazy-import AITER per-row dynamic quantize kernel.""" + global _aiter_dynamic_per_token_quant, _aiter_quant_import_attempted + + if _aiter_quant_import_attempted: + return + _aiter_quant_import_attempted = True + + try: + from .aiter_utils import is_aiter_available + if not is_aiter_available(): + return + from aiter.ops.triton.quant import dynamic_per_token_quant_fp8_i8 + _aiter_dynamic_per_token_quant = dynamic_per_token_quant_fp8_i8 + except (ImportError, 
AttributeError): + pass + + +def _empty_tensor(): + """Get tensor with no entries and no data.""" + return torch.Tensor().cuda() + + +# --------------------------------------------------------------------------- +# PyTorch fallback for quantize -- no recursion through tex.quantize +# --------------------------------------------------------------------------- + +def _te_dtype_to_torch_fp8(te_dtype): + """Map TE DType enum to torch FP8 dtype.""" + try: + from transformer_engine.pytorch.triton_kernels.common import te_dtype_to_torch_dtype + return te_dtype_to_torch_dtype(te_dtype) + except (KeyError, ImportError): + return torch.float8_e4m3fnuz + + +def _linear_scale_to_e8m0(scale_float32): + """Convert linear float32 scales to E8M0 biased exponent (uint8). + + E8M0 format: value = 2^(exponent - 127), stored as uint8. + Conversion: e8m0 = floor(log2(scale)) + 127, clamped to [0, 254]. + + Args: + scale_float32: float32 tensor of per-group linear dequant scales + Returns: + uint8 tensor of E8M0 biased exponents + """ + scale_clamped = scale_float32.float().clamp(min=2**-127) + exponent = torch.floor(torch.log2(scale_clamped)) + 127 + return exponent.clamp(0, 254).to(torch.uint8) + + +def _quantize_float8_pytorch(input_tensor, quantizer, out): + """Quantize to Float8 using PyTorch ops. No C++ or tex.quantize dependency.""" + if input_tensor.nelement() == 0: + return out + + # Compute amax and scale. Keep both on-device: .item() would force a + # CPU<->GPU sync on every quantize call. + amax_val = input_tensor.abs().amax() + if hasattr(quantizer, 'amax') and quantizer.amax is not None: + quantizer.amax.copy_(amax_val) + + scale = quantizer.scale + scale_inv = out._scale_inv + torch_fp8_dtype = _te_dtype_to_torch_fp8(quantizer.dtype) + + # Scale, cast to FP8, then store as uint8 (FP8 bit pattern) + scaled = input_tensor.float() * scale.float() + fp8_data = scaled.to(torch_fp8_dtype) + out._data.copy_(fp8_data.view(torch.uint8)) + scale_inv.copy_(scale.float().reciprocal()) + + return out + + +def _quantize_per_row_dynamic(input_tensor, quantizer, out): + """Per-row dynamic FP8 quantize via AITER dynamic_per_token_quant_fp8_i8. + + Each row gets its own scale computed in-kernel (no global amax pass). + Output Float8Tensor has _scale_inv shape (M,) instead of scalar. + Used for CurrentScaling in backward (dY quantization) and standalone + quantize calls. + """ + if input_tensor.nelement() == 0: + return out + + input_2d = input_tensor.reshape(-1, input_tensor.shape[-1]) + M, N = input_2d.shape + torch_fp8_dtype = _te_dtype_to_torch_fp8(quantizer.dtype) + + # Pre-allocate output tensors for the AITER kernel + qx = torch.empty(M, N, dtype=torch_fp8_dtype, device=input_2d.device) + scale_out = torch.empty(M, dtype=torch.float32, device=input_2d.device) + + _aiter_dynamic_per_token_quant(qx, input_2d, scale_out) + + # Write FP8 data into the output container + fp8_bytes = qx.view(torch.uint8) + out._data.copy_(fp8_bytes.reshape(out._data.shape)) + # Store per-row dequant scales — downstream GEMM detects numel() > 1 + out._scale_inv = scale_out + # Mark transpose cache stale so update_usage(columnwise=True) will + # regenerate it from the freshly-written _data instead of using the + # uninitialized buffer allocated by make_empty(). + if hasattr(out, '_transpose_invalid'): + out._transpose_invalid = True + + return out + + +def _quantize_mxfp8_pytorch(input_tensor, quantizer, out): + """Quantize to MXFP8 using pure PyTorch ops — no Triton dependency. 
+ + Implements group_size=32 block scaling with E8M0 scale format: + 1. Reshape input into groups of 32 + 2. Compute per-group amax → E8M0 biased exponent + 3. Scale groups, cast to FP8, store as uint8 + """ + if input_tensor.nelement() == 0: + return out + + try: + from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE + group_size = MXFP8_BLOCK_SCALING_SIZE # 32 + except ImportError: + group_size = 32 + + input_2d = input_tensor.reshape(-1, input_tensor.shape[-1]) + M, K = input_2d.shape + torch_fp8_dtype = _te_dtype_to_torch_fp8(quantizer.dtype) + + # Pad K to multiple of group_size if needed + K_padded = ((K + group_size - 1) // group_size) * group_size + if K_padded != K: + input_padded = torch.nn.functional.pad(input_2d, (0, K_padded - K)) + else: + input_padded = input_2d + + # Reshape into groups: (M, K/32, 32) + num_groups = K_padded // group_size + grouped = input_padded.float().reshape(M, num_groups, group_size) + + # Per-group amax + group_amax = grouped.abs().amax(dim=-1) # (M, num_groups) + group_amax = group_amax.clamp(min=2**-127) + + # E8M0 biased exponent: floor(log2(amax)) + 127 + biased_exp = torch.floor(torch.log2(group_amax)) + 127 + biased_exp = biased_exp.clamp(0, 254) + + # Dequant: output = fp8_data * 2^(biased_exp - 127) + # Quantize: fp8_data = input / 2^(biased_exp - 127) + dequant_scale = torch.exp2(biased_exp - 127) # (M, num_groups) + inv_scale = 1.0 / dequant_scale # (M, num_groups) + + # Scale each group and cast to FP8 + scaled = grouped * inv_scale.unsqueeze(-1) # (M, num_groups, 32) + fp8_data = scaled.reshape(M, K_padded)[:, :K].contiguous().to(torch_fp8_dtype) + fp8_bytes = fp8_data.view(torch.uint8) + + # Write into output container + if hasattr(out, '_rowwise_data') and out._rowwise_data is not None: + out._rowwise_data.copy_(fp8_bytes.reshape(out._rowwise_data.shape)) + if hasattr(out, '_rowwise_scale_inv') and out._rowwise_scale_inv is not None: + e8m0 = biased_exp[:, :((K + group_size - 1) // group_size)].to(torch.uint8) + out._rowwise_scale_inv.copy_(e8m0.reshape(out._rowwise_scale_inv.shape)) + + return out + + +def _quantize_pytorch_fallback(tensor, quantizer, output=None, noop=None): + """Pure PyTorch quantize -- never calls tex.quantize (avoids recursion).""" + _try_load_triton_cast() + + if quantizer is None: + if output is not None: + output.copy_(tensor) + return output + return tensor + + # Create output tensor if not provided + out = output + if out is None and hasattr(quantizer, 'make_empty'): + fake_dtype = tensor.dtype if tensor.dtype.is_floating_point else torch.float32 + out = quantizer.make_empty(tensor.shape, dtype=fake_dtype) + if _Float8TensorStorage is not None and isinstance(out, _Float8TensorStorage): + if _setup_transpose_storage is not None: + _setup_transpose_storage(out) + + if out is None: + # No quantizer.make_empty — just return tensor as-is + return tensor + + # Dispatch to appropriate PyTorch fallback based on output type + if _MXFP8TensorStorage is not None and isinstance(out, _MXFP8TensorStorage): + return _quantize_mxfp8_pytorch(tensor.contiguous(), quantizer, out) + if _Float8TensorStorage is not None and isinstance(out, _Float8TensorStorage): + return _quantize_float8_pytorch(tensor.contiguous(), quantizer, out) + + # For other quantized types without Triton, try quantizer.quantize + # but guard against recursion by checking if we'd go through tex.quantize + if hasattr(quantizer, 'quantize'): + # This is safe for non-Float8 quantizers that don't recurse through tex + return 
quantizer.quantize(tensor) + + if output is not None: + output.copy_(tensor) + return output + return tensor + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def quantize(tensor, quantizer, output=None, noop=None): + """Quantize tensor. Uses Triton cast kernels when available.""" + _try_load_triton_cast() + + input_tensor = tensor.contiguous() if tensor is not None else tensor + + # Fast path: no quantizer + if quantizer is None: + if output is not None: + output.copy_(input_tensor) + return output + return input_tensor + + # Create output tensor if not provided + out = output + if out is None and hasattr(quantizer, 'make_empty'): + fake_dtype = input_tensor.dtype if input_tensor.dtype.is_floating_point else torch.float32 + if input_tensor.ndim == 0: + out = quantizer.make_empty((1,), dtype=fake_dtype) + if _Float8TensorStorage and isinstance(out, _Float8TensorStorage): + out._data = out._data.squeeze(0) + if out._transpose is not None: + out._transpose = out._transpose.squeeze(0) + else: + out = quantizer.make_empty(input_tensor.shape, dtype=fake_dtype) + + if _Float8TensorStorage and isinstance(out, _Float8TensorStorage): + if _setup_transpose_storage is not None: + _setup_transpose_storage(out) + + if out is None: + return input_tensor + + # Construct no-op flag + noop_flag = noop if noop is not None else _empty_tensor() + + # Check for empty output + if (_MXFP8TensorStorage and isinstance(out, _MXFP8TensorStorage) + and out._rowwise_data is None and out._columnwise_data is None): + return out + if not (_MXFP8TensorStorage and isinstance(out, _MXFP8TensorStorage)): + if hasattr(out, 'size') and callable(out.size) and out.size().numel() == 0: + return out + + # --- Per-row dynamic quantize (CurrentScaling + AITER) --- + # Must come before per-tensor paths: per-row is strictly better for + # CurrentScaling (fused single kernel, no global amax pass). + _try_load_aiter_quant() + if (_Float8TensorStorage and isinstance(out, _Float8TensorStorage) + and _Float8CurrentScalingQuantizer is not None + and isinstance(quantizer, _Float8CurrentScalingQuantizer) + and _aiter_dynamic_per_token_quant is not None + and input_tensor.nelement() > 0): + _quant_bump("per_row_dynamic_aiter") + return _quantize_per_row_dynamic(input_tensor, quantizer, out) + + # --- Triton dispatch --- + if _Float8TensorStorage and isinstance(out, _Float8TensorStorage): + if input_tensor.nelement() > 0: + if _LITE_DIAG: + global _FP8_FALLBACK_DIAG_PRINTS + if _FP8_FALLBACK_DIAG_PRINTS < _FP8_FALLBACK_DIAG_MAX: + _FP8_FALLBACK_DIAG_PRINTS += 1 + # Also read usage from the quantizer copy the tensor holds — + # that's what _setup_conditional_transpose_storage looked at. + stored_q = getattr(out, "_quantizer", None) + stored_rw = getattr(stored_q, "rowwise_usage", "MISSING") + stored_cw = getattr(stored_q, "columnwise_usage", "MISSING") + print( + f"[LITE-QUANT-DIAG #{_FP8_FALLBACK_DIAG_PRINTS}] Float8 path: " + f"qt={type(quantizer).__name__}, " + f"live_q.rw={getattr(quantizer, 'rowwise_usage', '?')}, " + f"live_q.cw={getattr(quantizer, 'columnwise_usage', '?')}, " + f"stored_q.rw={stored_rw}, stored_q.cw={stored_cw}, " + f"transpose_none={out._transpose is None}, " + f"transpose_invalid={out._transpose_invalid}, " + f"shape={tuple(input_tensor.shape)}", + flush=True, + ) + if _triton_cast_transpose_noop is not None: + # Triton cast+transpose. 
The kernel always writes a transpose, + # so when the caller didn't ask for columnwise data we pass a + # throwaway buffer and drop it. Still much cheaper than the + # pure-PyTorch fallback. + q = out._get_quantizer() + is_current_scaling = ( + _Float8CurrentScalingQuantizer is not None + and isinstance(q, _Float8CurrentScalingQuantizer) + ) + if out._transpose is not None and not out._transpose_invalid: + trans_out = out._transpose + _quant_bump("float8_triton_cast_transpose") + else: + row_length = ( + input_tensor.shape[-1] if input_tensor.ndim > 0 else 1 + ) + num_rows = input_tensor.numel() // row_length + trans_out = torch.empty( + (row_length, num_rows), + dtype=torch.uint8, + device=input_tensor.device, + ) + _quant_bump("float8_triton_rowwise_only") + _triton_cast_transpose_noop( + input_tensor, + noop_flag, + input_scale=q.scale, + cast_out=out._data, + trans_out=trans_out, + amax_out=q.amax, + scale_inv_out=out._scale_inv, + otype=q.dtype, + current_scaling=is_current_scaling, + eps=getattr(q, "amax_epsilon", 0.0), + force_pow_2_scales=getattr(q, "force_pow_2_scales", False), + ) + return out + else: + _quant_bump("float8_pytorch_fallback") + # No Triton cast kernel available — PyTorch fallback. + if hasattr(out, 'remove_caches'): + out.remove_caches() + return _quantize_float8_pytorch(input_tensor, quantizer, out) + + elif _MXFP8TensorStorage and isinstance(out, _MXFP8TensorStorage): + if _triton_cast_transpose_mxfp8 is not None: + _quant_bump("mxfp8_triton") + _triton_cast_transpose_mxfp8(input_tensor, out) + return out + else: + _quant_bump("mxfp8_pytorch_fallback") + return _quantize_mxfp8_pytorch(input_tensor, quantizer, out) + + elif _MXFP4TensorStorage and isinstance(out, _MXFP4TensorStorage): + if _triton_cast_transpose_mxfp4 is not None: + _quant_bump("mxfp4_triton") + _triton_cast_transpose_mxfp4(input_tensor, out) + return out + + # Fallback for unrecognized types + _quant_bump("unrecognized_pytorch_fallback") + return _quantize_pytorch_fallback(tensor, quantizer, output, noop) + + +def dequantize(input, otype): + """Dequantize tensor to the specified output type.""" + _try_load_triton_cast() + + # Determine target torch dtype + if isinstance(otype, torch.dtype): + target_dtype = otype + else: + dtype_map = {0: torch.uint8, 2: torch.float32, 3: torch.float16, 4: torch.bfloat16} + target_dtype = dtype_map.get(int(otype), torch.float32) + + # Triton MXFP8 dequantize + if (_MXFP8TensorStorage and isinstance(input, _MXFP8TensorStorage) + and _triton_dequantize_mxfp8 is not None): + return _triton_dequantize_mxfp8(input, otype) + + # Float8 dequantize -- PyTorch (no Triton kernel exists for this) + if _Float8TensorStorage and isinstance(input, _Float8TensorStorage): + if input._data is not None: + if input._data.nelement() == 0: + return torch.empty_like(input._data, dtype=target_dtype) + # Reinterpret uint8 bits as FP8 dtype, then cast to target + torch_fp8_dtype = _te_dtype_to_torch_fp8(input._fp8_dtype) + fp8_view = input._data.view(torch_fp8_dtype) + hp = fp8_view.to(target_dtype) + scale_inv = input._scale_inv + if scale_inv.numel() == 1: + return hp * scale_inv + # Per-row scale: quantize produced (M_flat,) scale from a 2D view, + # but _data may be stored in N-D shape. Reshape scale to match + # hp's leading dims so broadcast against the last dim works. 
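+ # Illustration: for hp of shape [b, s, h] with a per-row scale_inv of
+ # shape (b*s,), the reshape below yields (b, s, 1) so that
+ # `hp * scale_inv` broadcasts across the last (feature) dimension.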
+ leading_numel = 1 + for d in hp.shape[:-1]: + leading_numel *= d + if scale_inv.numel() == leading_numel: + scale_inv = scale_inv.reshape(*hp.shape[:-1], 1) + else: + scale_inv = scale_inv.reshape( + *scale_inv.shape, *([1] * (hp.ndim - scale_inv.ndim)) + ) + return hp * scale_inv + raise NotImplementedError("Dequantize from transpose not implemented in lite mode") + + # Plain tensor — just cast dtype + if isinstance(input, torch.Tensor): + return input.to(target_dtype) + + # Object with dequantize method (custom quantized types) + if hasattr(input, 'dequantize'): + return input.dequantize() + + return input.to(target_dtype) + + +def bgrad_quantize(input, quantizer): + """Compute bias gradient and quantize. + + Uses separate sum + quantize. Both ops dispatch to optimized CUDA/Triton + kernels individually. A true single-pass fusion would require merging + bgrad accumulation into the cast kernel (te_cast_transpose_noop_triton). + """ + bgrad = input.sum(dim=tuple(range(input.ndim - 1))) + quantized = quantize(input, quantizer) + return bgrad, quantized + + +def multi_tensor_quantize(tensor_list, quantizer_list): + """Quantize multiple tensors with corresponding quantizers.""" + results = [] + for tensor, quant in zip(tensor_list, quantizer_list): + results.append(quantize(tensor, quant)) + return results + + +def split_quantize(tensor, split_sections, quantizer_list): + """Split tensor and quantize each section.""" + splits = torch.split(tensor, split_sections, dim=0) + results = [] + for split, quant in zip(splits, quantizer_list): + results.append(quantize(split, quant)) + return results + + +def compute_amax(input, amax): + """Compute absolute max value in tensor.""" + amax.copy_(input.abs().amax()) + + +def _fp8_max_for_dtype(fp8_dtype): + """Resolve TE FP8 dtype → max representable value, honoring ROCm fnuz clamp.""" + from transformer_engine.common.recipe import _FormatMaxVals + + try: + is_fnuz = torch.float8_e4m3fnuz is not None and torch.cuda.is_available() + except AttributeError: + is_fnuz = False + dtype_name = str(fp8_dtype).rsplit('.', 1)[-1] # "kFloat8E4M3" or "kFloat8E5M2" + if "E4M3" in dtype_name: + return _FormatMaxVals.E4M3.value[1 if is_fnuz else 0] + return _FormatMaxVals.E5M2.value[1 if is_fnuz else 0] + + +# --- Fused amax / scale update (Triton + Python fallback) ------------------- +# +# Replaces the delayed-scaling C++ kernel (common/recipe/delayed_scaling.cu +# kernel_bulk) with a single Triton launch that processes every scale channel +# across every amax-history tensor in the group. See also the Python fallback +# below for the exact slot-convention contract. 
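+ #
+ # Rough shape contract (illustrative; see the Args of
+ # fused_amax_and_scale_update_after_reduction below):
+ #   contiguous_amax : fp32 [sum(N_i)] reduced amaxes, one per scale channel
+ #   amax_histories  : list of fp32 [H, N_i] tensors, H shared across the list
+ #   scales          : list of fp32 [N_i] buffers, updated in place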
+ +_triton_amax_update_loaded = False +_triton_amax_update_kernel = None +_amax_pack_cache = {} # keyed by (id(amax_histories), id(scales)) → packed tensors + + +def _try_load_triton_amax_update(): + global _triton_amax_update_loaded, _triton_amax_update_kernel + if _triton_amax_update_loaded: + return _triton_amax_update_kernel is not None + _triton_amax_update_loaded = True + try: + import triton + import triton.language as tl + except ImportError: + return False + + @triton.jit + def _kernel( + reduction_ptr, # *fp32 [N_total] + hist_ptr_arr, # *int64 [N_total] device pointers + scale_ptr_arr, # *int64 [N_total] + stride_arr, # *int32 [N_total] owner tensor's num_scale (= row stride) + local_ch_arr, # *int32 [N_total] channel index within owner + H, # int amax_history_length + scaled_max, # fp32 fp8_max * 2**-margin + ALGO: tl.constexpr, # 0=max, 1=most_recent + BLOCK_H: tl.constexpr, + ): + pid = tl.program_id(0) + + hist_base = tl.load(hist_ptr_arr + pid).to(tl.pointer_type(tl.float32)) + scale_base = tl.load(scale_ptr_arr + pid).to(tl.pointer_type(tl.float32)) + stride = tl.load(stride_arr + pid) + local_ch = tl.load(local_ch_arr + pid) + new_amax = tl.load(reduction_ptr + pid) + + offs = tl.arange(0, BLOCK_H) + mask = offs < H + + # Current column [H], with slot 0 replaced by the newly-reduced amax + # (matches the lite Python path's `amax_history[0].copy_(amax_chunk)`). + hist = tl.load(hist_base + offs * stride + local_ch, mask=mask, other=0.0) + hist_with_new = tl.where(offs == 0, new_amax, hist) + + if ALGO == 0: + amax_for_max = tl.where(mask, hist_with_new, -float('inf')) + amax_val = tl.max(amax_for_max, axis=0) + else: + amax_val = new_amax + + # Rolled history: out[0]=0, out[i]=hist[i+1] for 0 0; +/-inf → FP32_MAX. + prev_scale = tl.load(scale_base + local_ch) + finite = (amax_val == amax_val) & (amax_val != float('inf')) & (amax_val != -float('inf')) + sf = tl.where(finite & (amax_val > 0.0), scaled_max / amax_val, prev_scale) + FP32_MAX = 3.4028234663852886e38 + sf = tl.where((sf == float('inf')) | (sf == -float('inf')), FP32_MAX, sf) + tl.store(scale_base + local_ch, sf) + + _triton_amax_update_kernel = _kernel + return True + + +def _pack_amax_update_args(amax_histories, scales): + """Build (and cache) the per-channel pointer/stride/local-channel arrays. + + Returns (hist_ptrs, scale_ptrs, strides, local_chs) — all on the same + device as the histories. Caches on list-identity because upstream's + global_amax_history_buffer entries are stable across steps. 
+ """ + cache_key = (id(amax_histories), id(scales), + len(amax_histories), + sum(h.shape[-1] for h in amax_histories)) + cached = _amax_pack_cache.get(cache_key) + if cached is not None: + return cached + + device = amax_histories[0].device + hist_ptrs_cpu, scale_ptrs_cpu, strides_cpu, local_chs_cpu = [], [], [], [] + for h, s in zip(amax_histories, scales): + n = h.shape[-1] + hp = h.data_ptr() + sp = s.data_ptr() + for ch in range(n): + hist_ptrs_cpu.append(hp) + scale_ptrs_cpu.append(sp) + strides_cpu.append(n) + local_chs_cpu.append(ch) + + hist_ptrs = torch.tensor(hist_ptrs_cpu, dtype=torch.int64, device=device) + scale_ptrs = torch.tensor(scale_ptrs_cpu, dtype=torch.int64, device=device) + strides = torch.tensor(strides_cpu, dtype=torch.int32, device=device) + local_chs = torch.tensor(local_chs_cpu, dtype=torch.int32, device=device) + packed = (hist_ptrs, scale_ptrs, strides, local_chs) + _amax_pack_cache[cache_key] = packed + return packed + + +def _fused_amax_and_scale_update_triton( + contiguous_amax, amax_histories, scales, amax_compute_algo, scaled_max, +): + import triton + + H = amax_histories[0].shape[0] + # Guard the assumption that H is shared (C++ kernel_bulk also assumes this). + # If a caller ever violates it, fall back to the Python loop. + if any(h.shape[0] != H for h in amax_histories): + return False + + hist_ptrs, scale_ptrs, strides, local_chs = _pack_amax_update_args( + amax_histories, scales, + ) + n_total = hist_ptrs.numel() + algo = 1 if amax_compute_algo == "most_recent" else 0 + block_h = max(triton.next_power_of_2(H), 16) + + _triton_amax_update_kernel[(n_total,)]( + contiguous_amax, hist_ptrs, scale_ptrs, strides, local_chs, + H, float(scaled_max), + ALGO=algo, BLOCK_H=block_h, num_warps=4, + ) + return True + + +def _fused_amax_and_scale_update_python( + contiguous_amax, amax_histories, scales, amax_compute_algo, scaled_max, +): + """Per-group Python loop (fallback). Kept for A/B against the Triton path.""" + chunk_sizes = [h.shape[-1] for h in amax_histories] + splits = contiguous_amax.split(chunk_sizes) + fp32_max = torch.finfo(torch.float32).max + for amax_history, scale, amax_chunk in zip(amax_histories, scales, splits): + amax_history[0].copy_(amax_chunk) + + if amax_compute_algo == "most_recent": + amax = amax_history[0].clone() + else: + amax = amax_history.max(dim=0).values + + if amax_history.shape[0] > 1: + amax_history.copy_(torch.roll(amax_history, -1, 0)) + amax_history[0].fill_(0.0) + + sf = scaled_max / amax + sf = torch.where(amax > 0.0, sf, scale) + sf = torch.where(torch.isfinite(amax), sf, scale) + sf = torch.where(torch.isinf(sf), torch.full_like(sf, fp32_max), sf) + scale.copy_(sf) + + +def fused_amax_and_scale_update_after_reduction( + contiguous_amax, amax_histories, scales, + amax_compute_algo, fp8_dtype, margin, +): + """Update amax history and FP8 scale after amax reduction (delayed scaling). + + Called by FP8GlobalStateManager.reduce_and_update_fp8_tensors during every + training step. Dispatches to a single Triton multi-tensor-apply kernel that + mirrors common/recipe/delayed_scaling.cu's kernel_bulk; falls back to the + per-group Python loop if Triton is unavailable or NVTE_LITE_AMAX_FUSED=0. 
+ + Args: + contiguous_amax: flat [N_total] fp32 tensor of reduced amaxes + amax_histories: list of [H, N_i] fp32 tensors (H shared across list) + scales: list of [N_i] fp32 scale buffers + amax_compute_algo: "max" or "most_recent" + fp8_dtype: TE DType (kFloat8E4M3 or kFloat8E5M2) + margin: int, scaled_max = fp8_max / 2**margin + """ + if len(amax_histories) == 0: + return + + scaled_max = _fp8_max_for_dtype(fp8_dtype) / (2 ** margin) + + use_triton = os.environ.get("NVTE_LITE_AMAX_FUSED", "1") != "0" + if use_triton and _try_load_triton_amax_update(): + if _fused_amax_and_scale_update_triton( + contiguous_amax, amax_histories, scales, amax_compute_algo, scaled_max, + ): + return + _fused_amax_and_scale_update_python( + contiguous_amax, amax_histories, scales, amax_compute_algo, scaled_max, + ) + + +# --------------------------------------------------------------------------- +# Triton kernels for FP8 block scaling +# --------------------------------------------------------------------------- + +_triton_block_scaling_loaded = False +_triton_block_amax_kernel = None +_triton_block_cast_kernel = None + + +def _try_load_triton_block_scaling(): + """Define Triton kernels for block scaling on first call.""" + global _triton_block_scaling_loaded, _triton_block_amax_kernel, _triton_block_cast_kernel + + if _triton_block_scaling_loaded: + return + _triton_block_scaling_loaded = True + + try: + import triton + import triton.language as tl + + @triton.autotune( + configs=[ + triton.Config({"TILE_ROWS": 4}, num_warps=4), + triton.Config({"TILE_ROWS": 8}, num_warps=4), + triton.Config({"TILE_ROWS": 16}, num_warps=8), + triton.Config({"TILE_ROWS": 32}, num_warps=8), + ], + key=["BLOCK_LEN"], + ) + @triton.jit + def _block_amax_kernel( + input_ptr, amax_ptr, + h, w, + input_row_stride, + num_blocks_w, + BLOCK_LEN: tl.constexpr, + TILE_ROWS: tl.constexpr, + ): + """2D-tiled per-block amax reduction. + + Each program handles one (BLOCK_LEN x BLOCK_LEN) block. + Loads TILE_ROWS rows x BLOCK_LEN cols per iteration, + processing all rows in ceil(BLOCK_LEN / TILE_ROWS) steps. 
+ """ + block_idx = tl.program_id(0) + block_i = block_idx // num_blocks_w + block_j = block_idx % num_blocks_w + + row_start = block_i * BLOCK_LEN + col_start = block_j * BLOCK_LEN + + # 2D offsets for one tile: (TILE_ROWS, BLOCK_LEN) + row_offsets = tl.arange(0, TILE_ROWS) # [TILE_ROWS] + col_offsets = tl.arange(0, BLOCK_LEN) # [BLOCK_LEN] + + max_val = 0.0 + for tile_start in tl.static_range(0, BLOCK_LEN, TILE_ROWS): + rows = row_start + tile_start + row_offsets # [TILE_ROWS] + cols = col_start + col_offsets # [BLOCK_LEN] + + # 2D mask: valid rows AND valid cols + row_mask = rows < h # [TILE_ROWS] + col_mask = cols < w # [BLOCK_LEN] + mask = row_mask[:, None] & col_mask[None, :] # [TILE_ROWS, BLOCK_LEN] + + # 2D load + ptrs = input_ptr + rows[:, None] * input_row_stride + cols[None, :] + vals = tl.load(ptrs, mask=mask, other=0.0) # [TILE_ROWS, BLOCK_LEN] + + max_val = tl.maximum(max_val, tl.max(tl.abs(vals))) + + tl.store(amax_ptr + block_idx, max_val) + + @triton.autotune( + configs=[ + triton.Config({"TILE_ROWS": 4}, num_warps=4), + triton.Config({"TILE_ROWS": 8}, num_warps=4), + triton.Config({"TILE_ROWS": 16}, num_warps=8), + triton.Config({"TILE_ROWS": 32}, num_warps=8), + ], + key=["BLOCK_LEN"], + ) + @triton.jit + def _block_cast_kernel( + input_ptr, output_ptr, scale_ptr, + h, w, + input_row_stride, output_row_stride, + num_blocks_w, + BLOCK_LEN: tl.constexpr, + TILE_ROWS: tl.constexpr, + ): + """2D-tiled per-block scale and copy. + + Each program handles one (BLOCK_LEN x BLOCK_LEN) block. + Loads TILE_ROWS rows x BLOCK_LEN cols per iteration. + """ + block_idx = tl.program_id(0) + block_i = block_idx // num_blocks_w + block_j = block_idx % num_blocks_w + + row_start = block_i * BLOCK_LEN + col_start = block_j * BLOCK_LEN + + s = tl.load(scale_ptr + block_idx) + + row_offsets = tl.arange(0, TILE_ROWS) + col_offsets = tl.arange(0, BLOCK_LEN) + + for tile_start in tl.static_range(0, BLOCK_LEN, TILE_ROWS): + rows = row_start + tile_start + row_offsets + cols = col_start + col_offsets + + row_mask = rows < h + col_mask = cols < w + mask = row_mask[:, None] & col_mask[None, :] + + in_ptrs = input_ptr + rows[:, None] * input_row_stride + cols[None, :] + vals = tl.load(in_ptrs, mask=mask, other=0.0) + + out_ptrs = output_ptr + rows[:, None] * output_row_stride + cols[None, :] + tl.store(out_ptrs, vals * s, mask=mask) + + _triton_block_amax_kernel = _block_amax_kernel + _triton_block_cast_kernel = _block_cast_kernel + + except (ImportError, ModuleNotFoundError): + pass + + +# --------------------------------------------------------------------------- +# PyTorch fallbacks for block scaling (used when Triton unavailable) +# --------------------------------------------------------------------------- + +def _fp8_block_scaling_compute_partial_amax_pytorch(partial, amax, h, w, block_len): + """Vectorized PyTorch fallback for block amax.""" + num_blocks_h = (h + block_len - 1) // block_len + num_blocks_w = (w + block_len - 1) // block_len + + pad_h = num_blocks_h * block_len - h + pad_w = num_blocks_w * block_len - w + if pad_h > 0 or pad_w > 0: + partial = torch.nn.functional.pad(partial, (0, pad_w, 0, pad_h), value=0.0) + + blocked = partial.reshape(num_blocks_h, block_len, num_blocks_w, block_len) + block_amaxes = blocked.abs().amax(dim=(1, 3)) + amax.copy_(block_amaxes.reshape(-1)) + + +def _fp8_block_scaling_partial_cast_pytorch(partial, out, scale, h, w, block_len): + """Vectorized PyTorch fallback for block cast.""" + num_blocks_h = (h + block_len - 1) // block_len + num_blocks_w = (w + 
block_len - 1) // block_len + + pad_h = num_blocks_h * block_len - h + pad_w = num_blocks_w * block_len - w + if pad_h > 0 or pad_w > 0: + partial = torch.nn.functional.pad(partial, (0, pad_w, 0, pad_h), value=0.0) + + blocked = partial.reshape(num_blocks_h, block_len, num_blocks_w, block_len) + scale_2d = scale.reshape(num_blocks_h, num_blocks_w)[:, None, :, None] + scaled = blocked * scale_2d + result = scaled.reshape(num_blocks_h * block_len, num_blocks_w * block_len) + out.copy_(result[:h, :w]) + + +# --------------------------------------------------------------------------- +# Public API for block scaling +# --------------------------------------------------------------------------- + +def fp8_block_scaling_compute_partial_amax(tensor, amax, h, w, start_offset, block_len): + """Compute per-block amax. Uses Triton kernel when available.""" + partial = tensor.view(-1)[start_offset:start_offset + h * w].view(h, w) + num_blocks_h = (h + block_len - 1) // block_len + num_blocks_w = (w + block_len - 1) // block_len + + _try_load_triton_block_scaling() + if _triton_block_amax_kernel is not None: + grid = (num_blocks_h * num_blocks_w,) + _triton_block_amax_kernel[grid]( + partial, amax, + h, w, + partial.stride(0), + num_blocks_w, + BLOCK_LEN=block_len, + ) + return + + _fp8_block_scaling_compute_partial_amax_pytorch(partial, amax, h, w, block_len) + + +def fp8_block_scaling_partial_cast(inp, out, scale, h, w, start_offset, block_len, out_dtype): + """Partial cast with per-block scaling. Uses Triton kernel when available.""" + partial = inp.view(-1)[start_offset:start_offset + h * w].view(h, w) + num_blocks_h = (h + block_len - 1) // block_len + num_blocks_w = (w + block_len - 1) // block_len + + _try_load_triton_block_scaling() + if _triton_block_cast_kernel is not None: + grid = (num_blocks_h * num_blocks_w,) + _triton_block_cast_kernel[grid]( + partial, out, scale, + h, w, + partial.stride(0), out.stride(0), + num_blocks_w, + BLOCK_LEN=block_len, + ) + return + + _fp8_block_scaling_partial_cast_pytorch(partial, out, scale, h, w, block_len) diff --git a/transformer_engine/pytorch/_lite/rope.py b/transformer_engine/pytorch/_lite/rope.py new file mode 100644 index 000000000..7a0ade74f --- /dev/null +++ b/transformer_engine/pytorch/_lite/rope.py @@ -0,0 +1,407 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Rotary Position Embedding (RoPE) -- AITER CK-JIT or PyTorch-native fallback. + +When AITER is available, uses its fused CK-JIT RoPE kernel (single kernel launch). +Otherwise, falls back to PyTorch-native implementation (~8 kernel launches). + +Supports context parallelism (cp_size, cp_rank) for DualChunkSwap +sequence partitioning used by TE's CP implementation. +""" + +import torch +from typing import Optional, Union + +from .aiter_utils import get_aiter_rope + + +# --------------------------------------------------------------------------- +# Context parallelism helpers +# --------------------------------------------------------------------------- + +def _get_freqs_on_this_cp_rank( + freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int +) -> torch.Tensor: + """Slice positional embedding frequencies for this CP rank. + + Implements the DualChunkSwap position mapping: each rank gets two + non-contiguous segments of the full frequency table. + + Args: + freqs: Full frequency tensor, shape ``[s_full, ...]``. 
+ seqlen: Local sequence length on this rank (= s_full / cp_size). + cp_size: Context parallel world size. + cp_rank: Context parallel rank. + + Returns: + Frequency tensor of shape ``[seqlen, ...]`` with the two + DualChunkSwap chunks concatenated. + """ + if cp_size > 1: + cp_seg = seqlen // 2 + full_seqlen = cp_size * seqlen + return torch.cat( + [ + freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg], + freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg], + ] + ) + return freqs[:seqlen] + + +# --------------------------------------------------------------------------- +# QKV format enum values (mirrors NVTE_QKV_Format) +# --------------------------------------------------------------------------- + +_BSHD = 0 +_SBHD = 1 +_THD = 2 + +# AITER rotate_style enum: 0 = NEOX (TE non-interleaved), 1 = GPTJ (TE interleaved) +_NEOX = 0 +_GPTJ = 1 + + +def _seqlen_from_tensor(t, qkv_format_int): + """Return the sequence length from a tensor given its QKV format.""" + if qkv_format_int == _BSHD: + return t.shape[1] + return t.shape[0] + + +def _te_interleaved_to_aiter_style(interleaved): + """Map TE interleaved flag to AITER rotate_style int.""" + return _GPTJ if interleaved else _NEOX + + +# --------------------------------------------------------------------------- +# AITER adapter helpers +# --------------------------------------------------------------------------- + +def _aiter_fwd(aiter_rope, t, freqs, interleaved, qkv_format): + """Call AITER rope_fwd with TE parameter conventions. + + AITER expects SBHD [s, b, h, d]. For BSHD input, transpose around the call. + TE freqs are [s, 1, 1, rot_dim] (already doubled); AITER with + reuse_freqs_front_part=False expects the same shape. + """ + style = _te_interleaved_to_aiter_style(interleaved) + if qkv_format == _BSHD: + t_sbhd = t.transpose(0, 1).contiguous() + out = aiter_rope.rope_fwd(t_sbhd, freqs, style, False, False) + return out.transpose(0, 1).contiguous() + return aiter_rope.rope_fwd(t, freqs, style, False, False) + + +def _aiter_bwd(aiter_rope, grad, freqs, interleaved, qkv_format): + """Call AITER rope_bwd with TE parameter conventions.""" + style = _te_interleaved_to_aiter_style(interleaved) + if qkv_format == _BSHD: + g_sbhd = grad.transpose(0, 1).contiguous() + out = aiter_rope.rope_bwd(g_sbhd, freqs, style, False, False) + return out.transpose(0, 1).contiguous() + return aiter_rope.rope_bwd(grad, freqs, style, False, False) + + +def _aiter_thd_fwd(aiter_rope, t, cu_seqlens, freqs, interleaved): + """Call AITER rope_thd_fwd with TE parameter conventions.""" + style = _te_interleaved_to_aiter_style(interleaved) + return aiter_rope.rope_thd_fwd(t, cu_seqlens, freqs, style, False, False) + + +def _aiter_thd_bwd(aiter_rope, grad, cu_seqlens, freqs, interleaved): + """Call AITER rope_thd_bwd with TE parameter conventions.""" + style = _te_interleaved_to_aiter_style(interleaved) + return aiter_rope.rope_thd_bwd(grad, cu_seqlens, freqs, style, False, False) + + +# --------------------------------------------------------------------------- +# Core PyTorch RoPE (fallback when AITER is unavailable) +# --------------------------------------------------------------------------- + +def _rotate_half(x): + """Rotate the last dimension: [-x2, x1] from [x1, x2].""" + x1, x2 = torch.chunk(x, 2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_half_interleaved(x): + """Rotate with interleaved layout: pairs are (even, odd) indices.""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + x_new = torch.stack((-x2, x1), dim=-1) + 
return x_new.view(*x.shape) + + +def _apply_rope_pytorch(t, freqs, interleaved=False): + """Apply RoPE using PyTorch operations. + + Freqs are raw angle values (not pre-computed cos/sin). + The rotation is: t * cos(freqs) + rotate_half(t) * sin(freqs). + Computation is done in float32 for precision (matching the C++ fused kernel). + + Args: + t: Input tensor, last dim is head_dim. + freqs: Angle tensor, shape broadcastable to t, last dim is rot_dim. + ``rot_dim <= head_dim``; unrotated dims are passed through. + interleaved: If True, use interleaved rotation pattern. + """ + orig_dtype = t.dtype + cos_ = torch.cos(freqs) + sin_ = torch.sin(freqs) + + rot_dim = freqs.shape[-1] + t_rot, t_pass = t[..., :rot_dim].float(), t[..., rot_dim:] + + rotate_fn = _rotate_half_interleaved if interleaved else _rotate_half + t_rot = t_rot * cos_ + rotate_fn(t_rot) * sin_ + return torch.cat((t_rot.to(orig_dtype), t_pass), dim=-1) + + +def _inverse_rope_pytorch(grad_output, freqs, interleaved=False): + """Inverse RoPE rotation for backward pass. + + The inverse of ``t * cos + rotate_half(t) * sin`` is + ``g * cos + rotate_half(g) * (-sin)``, i.e. negate sin. + Computation is done in float32 for precision (matching the C++ fused kernel). + """ + orig_dtype = grad_output.dtype + cos_ = torch.cos(freqs) + sin_ = torch.sin(freqs) + + rot_dim = freqs.shape[-1] + g_rot, g_pass = grad_output[..., :rot_dim].float(), grad_output[..., rot_dim:] + + rotate_fn = _rotate_half_interleaved if interleaved else _rotate_half + g_rot = g_rot * cos_ + rotate_fn(g_rot) * (-sin_) + return torch.cat((g_rot.to(orig_dtype), g_pass), dim=-1) + + +# --------------------------------------------------------------------------- +# Public API: fused_rope_forward / backward +# --------------------------------------------------------------------------- + +def fused_rope_forward( + t: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor] = None, + qkv_format: int = _SBHD, + interleaved: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + cp_size: int = 1, + cp_rank: int = 0, +) -> torch.Tensor: + """Fused RoPE forward -- lite replacement for ``tex.fused_rope_forward``. + + Signature matches the C++ binding so that ``FusedRoPEFunc`` in + ``transformer_engine.pytorch.attention.rope`` can call through + ``tex.fused_rope_forward`` transparently. + """ + _aiter_rope = get_aiter_rope() + + # Determine local sequence length from tensor + format + seqlen = _seqlen_from_tensor(t, qkv_format) + + # Handle start_positions: stack per-batch offset freqs before CP slicing + # start_positions offsets each batch element's freqs independently. 
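+ # Illustration: with seqlen=512, cp_size=1 and start_positions=[0, 128],
+ # the slices freqs[0:512] and freqs[128:640] are concatenated along dim 1,
+ # giving one frequency column per batch element.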
+ if start_positions is not None: + freqs = torch.cat( + [freqs[int(p) : int(p) + seqlen * cp_size] for p in start_positions], dim=1 + ) + # freqs now has shape [seqlen*cp_size, batch, 1, d] + + # Slice frequencies for this CP rank + if cp_size > 1: + freqs = _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank) + else: + freqs = freqs[:seqlen] + + # THD format with cu_seqlens + if qkv_format == _THD and cu_seqlens is not None: + if cp_size > 1: + cu_seqlens = cu_seqlens // cp_size + if _aiter_rope is not None and start_positions is None: + return _aiter_thd_fwd(_aiter_rope, t, cu_seqlens, freqs, interleaved) + # PyTorch fallback: split by sequence, apply per-sequence + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + results = [] + for idx, x in enumerate(torch.split(t, seqlens)): + seq_freqs = freqs[:x.size(0)] + results.append(_apply_rope_pytorch(x.unsqueeze(1), seq_freqs, interleaved).squeeze(1)) + return torch.cat(results) + + # BSHD/SBHD path -- use AITER fused kernel when available + if _aiter_rope is not None and start_positions is None: + return _aiter_fwd(_aiter_rope, t, freqs, interleaved, qkv_format) + + # PyTorch fallback + if qkv_format == _BSHD: + freqs = freqs.transpose(0, 1) if freqs.dim() == 4 else freqs + return _apply_rope_pytorch(t, freqs, interleaved) + + +def fused_rope_backward( + grad_output: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor] = None, + qkv_format: int = _SBHD, + interleaved: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + cp_size: int = 1, + cp_rank: int = 0, +) -> torch.Tensor: + """Fused RoPE backward -- lite replacement for ``tex.fused_rope_backward``. + + RoPE backward is the inverse rotation (negate sin component). + """ + _aiter_rope = get_aiter_rope() + + seqlen = _seqlen_from_tensor(grad_output, qkv_format) + + # Handle start_positions: stack per-batch offset freqs + if start_positions is not None: + freqs = torch.cat( + [freqs[int(p) : int(p) + seqlen * cp_size] for p in start_positions], dim=1 + ) + + if cp_size > 1: + freqs = _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank) + else: + freqs = freqs[:seqlen] + + # THD + if qkv_format == _THD and cu_seqlens is not None: + if cp_size > 1: + cu_seqlens = cu_seqlens // cp_size + if _aiter_rope is not None and start_positions is None: + return _aiter_thd_bwd(_aiter_rope, grad_output, cu_seqlens, freqs, interleaved) + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + results = [] + for idx, g in enumerate(torch.split(grad_output, seqlens)): + seq_freqs = freqs[:g.size(0)] + results.append(_inverse_rope_pytorch(g.unsqueeze(1), seq_freqs, interleaved).squeeze(1)) + return torch.cat(results) + + # BSHD/SBHD -- use AITER fused kernel when available + if _aiter_rope is not None and start_positions is None: + return _aiter_bwd(_aiter_rope, grad_output, freqs, interleaved, qkv_format) + + # PyTorch fallback + if qkv_format == _BSHD: + freqs = freqs.transpose(0, 1) if freqs.dim() == 4 else freqs + return _inverse_rope_pytorch(grad_output, freqs, interleaved) + + +# --------------------------------------------------------------------------- +# Public API: fused_qkv_rope_forward / backward +# --------------------------------------------------------------------------- + +def fused_qkv_rope_forward( + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + start_positions: Optional[torch.Tensor] = None, + qkv_split_arg_list=None, + qkv_format: int = _SBHD, + interleaved: bool = False, + cp_size: int = 1, + cp_rank: int = 
0, +) -> torch.Tensor: + """Fused QKV RoPE forward -- lite replacement for ``tex.fused_qkv_rope_forward``. + + Apply RoPE to Q and K within a packed QKV tensor. + Returns tuple (Q_rotated, K_rotated, V_unchanged). + """ + _aiter_rope = get_aiter_rope() + + seqlen = _seqlen_from_tensor(qkv, qkv_format) + + # Slice frequencies for CP + if cp_size > 1: + q_freqs = _get_freqs_on_this_cp_rank(q_freqs, seqlen, cp_size, cp_rank) + k_freqs = _get_freqs_on_this_cp_rank(k_freqs, seqlen, cp_size, cp_rank) + else: + q_freqs = q_freqs[:seqlen] + k_freqs = k_freqs[:seqlen] + + # Split QKV along the last (head_dim) dimension + if qkv_split_arg_list is not None: + q, k, v = torch.split(qkv, qkv_split_arg_list, dim=-1) + else: + q, k, v = qkv.chunk(3, dim=-1) + + # The C++ kernel reshapes Q/K so each split is expressed as (num_heads, head_dim) + # where head_dim is derived from the K split (which is always 1 head_dim per head). + # e.g. Q [s, b, 64, 512] with K head_dim=128 -> [s, b, 256, 128] + # This allows partial rotation (rot_dim < head_dim) to work correctly. + head_dim = k.shape[-1] + + if q.shape[-1] != head_dim: + new_q_heads = q.shape[-2] * q.shape[-1] // head_dim + q = q.reshape(*q.shape[:-2], new_q_heads, head_dim) + + # Use AITER fused kernel when available + if _aiter_rope is not None and start_positions is None: + q_rot = _aiter_fwd(_aiter_rope, q, q_freqs, interleaved, qkv_format) + k_rot = _aiter_fwd(_aiter_rope, k, k_freqs, interleaved, qkv_format) + return q_rot, k_rot, v + + # PyTorch fallback + if qkv_format == _BSHD: + q_freqs = q_freqs.transpose(0, 1) if q_freqs.dim() == 4 else q_freqs + k_freqs = k_freqs.transpose(0, 1) if k_freqs.dim() == 4 else k_freqs + + q_rot = _apply_rope_pytorch(q, q_freqs, interleaved) + k_rot = _apply_rope_pytorch(k, k_freqs, interleaved) + + return q_rot, k_rot, v + + +def fused_qkv_rope_backward( + grad_output_q: torch.Tensor, + grad_output_k: torch.Tensor, + grad_output_v: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list=None, + qkv_format: int = _SBHD, + interleaved: bool = False, + cp_size: int = 1, + cp_rank: int = 0, +) -> torch.Tensor: + """Fused QKV RoPE backward -- lite replacement for ``tex.fused_qkv_rope_backward``.""" + _aiter_rope = get_aiter_rope() + + seqlen = _seqlen_from_tensor(grad_output_q, qkv_format) + + if cp_size > 1: + q_freqs = _get_freqs_on_this_cp_rank(q_freqs, seqlen, cp_size, cp_rank) + k_freqs = _get_freqs_on_this_cp_rank(k_freqs, seqlen, cp_size, cp_rank) + else: + q_freqs = q_freqs[:seqlen] + k_freqs = k_freqs[:seqlen] + + # Use AITER fused kernel when available + if _aiter_rope is not None: + gq_rot = _aiter_bwd(_aiter_rope, grad_output_q, q_freqs, interleaved, qkv_format) + gk_rot = _aiter_bwd(_aiter_rope, grad_output_k, k_freqs, interleaved, qkv_format) + else: + if qkv_format == _BSHD: + q_freqs = q_freqs.transpose(0, 1) if q_freqs.dim() == 4 else q_freqs + k_freqs = k_freqs.transpose(0, 1) if k_freqs.dim() == 4 else k_freqs + gq_rot = _inverse_rope_pytorch(grad_output_q, q_freqs, interleaved) + gk_rot = _inverse_rope_pytorch(grad_output_k, k_freqs, interleaved) + + # Reshape Q/K grads back to original split dims before concatenation. + # The forward reshaped e.g. [s, b, 64, 512] -> [s, b, 256, 128]; + # backward receives [s, b, 256, 128] and must produce [s, b, 64, 512]. 
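+    # Sanity check on the example shapes above: 64 * 512 == 256 * 128, so the
+    # reshape is a pure view change. The original head count is recovered from
+    # grad_output_v (V is never reshaped in the forward) and the per-head Q
+    # width from qkv_split_arg_list[0].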
+ if qkv_split_arg_list is not None: + q_split_dim = qkv_split_arg_list[0] + v_head_dim = grad_output_v.shape[-2] # original num_heads + if gq_rot.shape[-1] != q_split_dim: + gq_rot = gq_rot.reshape(*gq_rot.shape[:-2], v_head_dim, q_split_dim) + + return torch.cat([gq_rot, gk_rot, grad_output_v], dim=-1) diff --git a/transformer_engine/pytorch/_lite/router.py b/transformer_engine/pytorch/_lite/router.py new file mode 100644 index 000000000..84859a4ca --- /dev/null +++ b/transformer_engine/pytorch/_lite/router.py @@ -0,0 +1,279 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""MOE router operations -- Triton-fused with PyTorch-native fallback.""" + +import torch +import torch.nn.functional as F + + +_EPSILON = 1e-9 + +# --------------------------------------------------------------------------- +# Lazy Triton import +# --------------------------------------------------------------------------- +_triton_router = None +_triton_attempted = False + + +def _try_load_triton_router(): + global _triton_router, _triton_attempted + if _triton_attempted: + return _triton_router + _triton_attempted = True + try: + from transformer_engine.pytorch.triton import fused_router + _triton_router = fused_router + except (ImportError, RuntimeError): + pass + return _triton_router + + +# --------------------------------------------------------------------------- +# Forward +# --------------------------------------------------------------------------- + +def fused_topk_with_score_function_fwd(logits, topk, use_pre_softmax, num_groups, + group_topk, scaling_factor, score_function, + expert_bias): + """Fused topk with score function forward. + + Uses a single Triton kernel when available (no group_topk). + Falls back to PyTorch-native for group_topk or when Triton is unavailable. 
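+
+    Example call (illustrative arguments and shapes only):
+
+    >>> logits = torch.randn(16, 8)  # [num_tokens, num_experts]
+    >>> probs, routing_map, inter = fused_topk_with_score_function_fwd(
+    ...     logits, topk=2, use_pre_softmax=True, num_groups=None,
+    ...     group_topk=None, scaling_factor=None, score_function="softmax",
+    ...     expert_bias=None)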
+ + Returns + ------- + (probs, routing_map, intermediate_output) + """ + triton_mod = _try_load_triton_router() + if triton_mod is not None and logits.is_cuda: + return triton_mod.fused_topk_with_score_function_fwd( + logits, topk, use_pre_softmax, scaling_factor, + score_function, expert_bias, num_groups, group_topk, + ) + + return _fused_topk_fwd_pytorch( + logits, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, + ) + + +def fused_topk_with_score_function_bwd(num_tokens, num_experts, routing_map, + intermediate_output, grad_probs, topk, + use_pre_softmax, scaling_factor, score_function): + """Fused topk with score function backward.""" + triton_mod = _try_load_triton_router() + if triton_mod is not None and grad_probs.is_cuda: + return triton_mod.fused_topk_with_score_function_bwd( + num_tokens, num_experts, routing_map, intermediate_output, + grad_probs, topk, use_pre_softmax, scaling_factor, score_function, + ) + + return _fused_topk_bwd_pytorch( + num_tokens, num_experts, routing_map, intermediate_output, + grad_probs, topk, use_pre_softmax, scaling_factor, score_function, + ) + + +# --------------------------------------------------------------------------- +# Aux-loss score functions +# --------------------------------------------------------------------------- + +def fused_score_for_moe_aux_loss_fwd(logits, topk, score_function): + """Compute scores for MOE auxiliary loss.""" + triton_mod = _try_load_triton_router() + if triton_mod is not None and logits.is_cuda: + return triton_mod.fused_score_for_moe_aux_loss_fwd( + logits, topk, score_function, + ) + + return _score_aux_loss_fwd_pytorch(logits, topk, score_function) + + +def fused_score_for_moe_aux_loss_bwd(num_tokens, num_experts, intermediate_output, + grad_scores, topk, score_function): + """Backward of scores for MOE auxiliary loss.""" + triton_mod = _try_load_triton_router() + if triton_mod is not None and grad_scores.is_cuda: + return triton_mod.fused_score_for_moe_aux_loss_bwd( + num_tokens, num_experts, intermediate_output, + grad_scores, topk, score_function, + ) + + return _score_aux_loss_bwd_pytorch( + intermediate_output, grad_scores, score_function, + ) + + +# --------------------------------------------------------------------------- +# Aux-loss (unchanged -- already minimal, no fusion opportunity) +# --------------------------------------------------------------------------- + +def fused_moe_aux_loss_fwd(probs, tokens_per_expert, total_num_tokens, num_experts, + num_rows, num_cols, topk, coeff): + """MOE auxiliary (load balancing) loss forward. + + Returns + ------- + (aux_loss, Const_buf) + Matches the C++ interface. Const_buf is a scalar tensor holding the + pre-computed gradient coefficient used by the backward pass. + """ + f = tokens_per_expert.float() / total_num_tokens + p = probs.mean(dim=0) + loss = coeff * num_experts * (f * p).sum() + + # Const_buf = (num_experts * coeff) / topk / total_num_tokens^2 + c_coeff = (num_experts * coeff) / topk / (total_num_tokens * total_num_tokens) + const_buf = torch.tensor(c_coeff, dtype=torch.float32, device=probs.device) + + return loss, const_buf + + +def fused_moe_aux_loss_bwd(Const_buf=None, tokens_per_expert=None, + num_rows=None, num_cols=None, grad_aux_loss=None, + **kwargs): + """MOE auxiliary loss backward. 
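+
+    Expands the per-expert coefficient into a dense [num_rows, num_cols]
+    gradient; every token row receives the same value for a given expert: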
+ + grad_probs[j, i] = Const_buf * tokens_per_expert[i] * grad_aux_loss + """ + # Const_buf is a scalar, tokens_per_expert is [num_cols], grad_aux_loss is scalar + # Output: [num_rows, num_cols] + grad_row = Const_buf * tokens_per_expert.float() * grad_aux_loss # [num_cols] + grad_probs = grad_row.unsqueeze(0).expand(num_rows, num_cols).contiguous() + return grad_probs + + +# =========================================================================== # +# PyTorch-native fallbacks +# =========================================================================== # + +def _fused_topk_fwd_pytorch(logits, topk, use_pre_softmax, num_groups, + group_topk, scaling_factor, score_function, + expert_bias): + """PyTorch-native forward (supports group_topk).""" + num_tokens, num_experts = logits.shape + + if score_function == "sigmoid": + use_pre_softmax = False + scores = torch.sigmoid(logits) + intermediate_output = scores.clone() + if expert_bias is not None: + scores = scores + expert_bias + elif score_function == "softmax": + if use_pre_softmax: + scores = F.softmax(logits, dim=-1) + intermediate_output = scores.clone() + else: + scores = logits.clone() + intermediate_output = torch.zeros_like(logits) + else: + raise ValueError(f"score_function must be 'softmax' or 'sigmoid', got '{score_function}'") + + use_group_topk = ( + group_topk is not None and group_topk > 0 + and num_groups is not None and num_groups > 0 + ) + if use_group_topk: + group_size = num_experts // num_groups + group_scores = torch.zeros(num_tokens, num_groups, device=logits.device, dtype=scores.dtype) + for g in range(num_groups): + g_start = g * group_size + g_end = g_start + group_size + g_vals, _ = torch.topk(scores[:, g_start:g_end], k=topk // group_topk, dim=-1) + group_scores[:, g] = g_vals.sum(dim=-1) + _, top_group_indices = torch.topk(group_scores, k=group_topk, dim=-1) + mask = torch.zeros_like(scores, dtype=torch.bool) + for g_idx in range(group_topk): + g = top_group_indices[:, g_idx] + for offset in range(group_size): + mask[torch.arange(num_tokens, device=logits.device), g * group_size + offset] = True + masked_scores = torch.where(mask, scores, torch.tensor(float('-inf'), device=logits.device)) + topk_values, topk_indices = torch.topk(masked_scores, k=topk, dim=-1) + else: + topk_values, topk_indices = torch.topk(scores, k=topk, dim=-1) + + if score_function == "sigmoid" and expert_bias is not None: + topk_values = topk_values - expert_bias[topk_indices] + + if score_function == "softmax" and not use_pre_softmax: + topk_values = F.softmax(topk_values, dim=-1) + intermediate_output.scatter_(1, topk_indices, topk_values) + + if score_function == "sigmoid" and topk > 1: + score_sum = topk_values.sum(dim=-1, keepdim=True) + _EPSILON + topk_values = topk_values / score_sum + + if scaling_factor is not None and scaling_factor > 0: + topk_values = topk_values * scaling_factor + + probs = torch.zeros(num_tokens, num_experts, device=logits.device, dtype=logits.dtype) + probs.scatter_(1, topk_indices, topk_values.to(logits.dtype)) + + routing_map = torch.zeros(num_tokens, num_experts, device=logits.device, dtype=torch.bool) + routing_map.scatter_(1, topk_indices, True) + + return probs, routing_map, intermediate_output + + +def _fused_topk_bwd_pytorch(num_tokens, num_experts, routing_map, + intermediate_output, grad_probs, topk, + use_pre_softmax, scaling_factor, score_function): + """PyTorch-native backward.""" + scaling_factor_val = scaling_factor if scaling_factor is not None and scaling_factor > 0 else 1.0 + grad = 
grad_probs * routing_map.float() * scaling_factor_val + + if score_function == "sigmoid": + if topk > 1: + fwd_out = intermediate_output * routing_map.float() + sum_fwd = fwd_out.sum(dim=-1, keepdim=True) + _EPSILON + out_x_grad = (fwd_out * grad).sum(dim=-1, keepdim=True) + grad = torch.where( + routing_map, + grad / sum_fwd - out_x_grad / (sum_fwd * sum_fwd), + torch.zeros_like(grad), + ) + grad = grad * routing_map.float() + grad = grad * intermediate_output * (1.0 - intermediate_output) + + elif score_function == "softmax": + if not use_pre_softmax: + out_x_grad = (intermediate_output * grad * routing_map.float()).sum(dim=-1, keepdim=True) + grad = torch.where( + routing_map, + intermediate_output * (grad - out_x_grad), + torch.zeros_like(grad), + ) + else: + grad = grad * routing_map.float() + dot = (intermediate_output * grad).sum(dim=-1, keepdim=True) + grad = intermediate_output * (grad - dot) + + return grad + + +def _score_aux_loss_fwd_pytorch(logits, topk, score_function): + """PyTorch-native aux-loss score forward.""" + if score_function == "sigmoid": + scores = torch.sigmoid(logits) + else: + scores = F.softmax(logits, dim=-1) + intermediate_output = scores.clone() + + _, topk_indices = torch.topk(scores, k=topk, dim=-1) + routing_map = torch.zeros_like(logits, dtype=torch.bool) + routing_map.scatter_(1, topk_indices, True) + + return scores, routing_map, intermediate_output + + +def _score_aux_loss_bwd_pytorch(intermediate_output, grad_scores, score_function): + """PyTorch-native aux-loss score backward.""" + if score_function == "sigmoid": + grad_logits = grad_scores * intermediate_output * (1.0 - intermediate_output) + else: + dot = (intermediate_output * grad_scores).sum(dim=-1, keepdim=True) + grad_logits = intermediate_output * (grad_scores - dot) + return grad_logits diff --git a/transformer_engine/pytorch/_lite/softmax.py b/transformer_engine/pytorch/_lite/softmax.py new file mode 100644 index 000000000..df65caf3d --- /dev/null +++ b/transformer_engine/pytorch/_lite/softmax.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused softmax variants -- PyTorch-native implementation. + +Note: These are rarely hit in lite mode since SDPA/AITER handle softmax internally. 
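+
+All backward helpers apply the standard softmax Jacobian-vector product,
+grad_input = y * (g - sum(y * g, dim=-1, keepdim=True)), implemented once in
+_softmax_backward and multiplied by scale_factor via the chain rule.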
+""" + +import torch +import torch.nn.functional as F + + +def _softmax_backward(grad_output, output): + """Common softmax backward: output * (grad - (output * grad).sum(dim=-1, keepdim=True)).""" + dot = (output * grad_output).sum(dim=-1, keepdim=True) + return output * (grad_output - dot) + + +def scaled_softmax_forward(input, scale_factor): + """Scaled softmax forward: softmax(input * scale).""" + return torch.softmax(input * scale_factor, dim=-1) + + +def scaled_softmax_backward(grad_output, output, scale_factor): + """Scaled softmax backward.""" + grad_input = _softmax_backward(grad_output, output) + return grad_input * scale_factor + + +def scaled_masked_softmax_forward(input, mask, scale_factor): + """Scaled masked softmax forward.""" + scaled = input * scale_factor + if mask is not None: + scaled = scaled.masked_fill(mask, float('-inf')) + return torch.softmax(scaled, dim=-1) + + +def scaled_masked_softmax_backward(grad_output, output, scale_factor): + """Scaled masked softmax backward.""" + grad_input = _softmax_backward(grad_output, output) + return grad_input * scale_factor + + +def scaled_upper_triang_masked_softmax_forward(input, scale_factor): + """Scaled upper-triangular masked softmax forward (causal mask).""" + seq_len = input.size(-1) + mask = torch.triu(torch.ones(seq_len, seq_len, device=input.device, dtype=torch.bool), diagonal=1) + scaled = input * scale_factor + scaled = scaled.masked_fill(mask, float('-inf')) + return torch.softmax(scaled, dim=-1) + + +def scaled_upper_triang_masked_softmax_backward(grad_output, output, scale_factor): + """Scaled upper-triangular masked softmax backward.""" + grad_input = _softmax_backward(grad_output, output) + return grad_input * scale_factor + + +def scaled_aligned_causal_masked_softmax_forward(input, scale_factor): + """Scaled bottom-right corner aligned causal masked softmax forward.""" + q_len = input.size(-2) + k_len = input.size(-1) + # Bottom-right aligned causal mask: position i can attend to positions <= i + (k_len - q_len) + row_idx = torch.arange(q_len, device=input.device).unsqueeze(1) + col_idx = torch.arange(k_len, device=input.device).unsqueeze(0) + offset = k_len - q_len + mask = col_idx > (row_idx + offset) + scaled = input * scale_factor + scaled = scaled.masked_fill(mask, float('-inf')) + return torch.softmax(scaled, dim=-1) + + +def scaled_aligned_causal_masked_softmax_backward(grad_output, output, scale_factor): + """Scaled aligned causal masked softmax backward.""" + grad_input = _softmax_backward(grad_output, output) + return grad_input * scale_factor diff --git a/transformer_engine/pytorch/_lite/transpose.py b/transformer_engine/pytorch/_lite/transpose.py new file mode 100644 index 000000000..0ac0e15d0 --- /dev/null +++ b/transformer_engine/pytorch/_lite/transpose.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Transpose operations -- PyTorch-native implementation.""" + +import torch + + +def fp8_transpose(input, dtype, *, out=None): + """FP8 transpose: move the last dim to the front. + + For a 2D tensor [M, K], this is equivalent to .t() → [K, M]. + For an N-D tensor [d0, d1, ..., K], produces [K, d0, d1, ...] — matching + the transpose_shape convention in the quantizer's make_empty(). + dtype is ignored since we work with PyTorch tensors directly. 
+ """ + if input.ndim == 2: + result = input.t().contiguous() + else: + # Permute last axis to front: [..., K] -> [K, ...] + perm = [input.ndim - 1] + list(range(input.ndim - 1)) + result = input.permute(*perm).contiguous() + if out is None: + return result + out.copy_(result.reshape(out.shape) if result.shape != out.shape else result) + return out + + +def swap_first_dims(tensor, *, out): + """Swap first two dimensions of a tensor.""" + result = tensor.transpose(0, 1).contiguous() + out.copy_(result) + return out diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 3cf15efc1..3680a5322 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -3,6 +3,8 @@ # See LICENSE for license information. """Module level PyTorch APIs""" +import os as _os + from .layernorm_linear import LayerNormLinear from .linear import Linear from .grouped_linear import GroupedLinear @@ -12,3 +14,8 @@ from .fp8_padding import Fp8Padding from .fp8_unpadding import Fp8Unpadding from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode + +# In lite mode, replace the full-build fused modules with lite-native versions +if _os.environ.get("NVTE_LITE", "0") == "1": + from .._lite.fused_layernorm_linear import LayerNormLinear # noqa: F811 + from .._lite.fused_layernorm_mlp import LayerNormMLP # noqa: F811 diff --git a/transformer_engine/pytorch/triton/fused_router.py b/transformer_engine/pytorch/triton/fused_router.py new file mode 100644 index 000000000..67089438e --- /dev/null +++ b/transformer_engine/pytorch/triton/fused_router.py @@ -0,0 +1,183 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""PyTorch wrapper functions for fused MoE router Triton kernels.""" + +import torch +import triton + +from transformer_engine.common.triton.fused_router import ( + _fused_topk_score_fwd_kernel, + _fused_topk_score_bwd_kernel, + _fused_score_aux_loss_fwd_kernel, + _fused_score_aux_loss_bwd_kernel, +) + +_SCORE_FN_MAP = {"sigmoid": 0, "softmax": 1} + + +def fused_topk_with_score_function_fwd( + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + scaling_factor: float, + score_function: str, + expert_bias: torch.Tensor | None, + num_groups: int | None = None, + group_topk: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused score-function + top-k forward via Triton. + + Returns (probs, routing_map, intermediate_output). 
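+
+    Launch layout (for reference): one Triton program per token row
+    (grid = (num_tokens,)); BLOCK_E is the next power of two >= num_experts so
+    an entire expert row fits in one block. The routing map is written as int8
+    and cast to bool before returning.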
+ """ + num_tokens, num_experts = logits.shape + block_e = triton.next_power_of_2(num_experts) + score_fn = _SCORE_FN_MAP[score_function] + has_bias = expert_bias is not None + sf = scaling_factor if scaling_factor is not None and scaling_factor > 0 else 1.0 + + use_group_topk = ( + group_topk is not None and group_topk > 0 + and num_groups is not None and num_groups > 0 + ) + ng = num_groups if use_group_topk else 1 + gt = group_topk if use_group_topk else 1 + gs = num_experts // ng if use_group_topk else num_experts + block_g = triton.next_power_of_2(ng) + + probs = torch.empty_like(logits) + routing_map_i8 = torch.empty( + num_tokens, num_experts, dtype=torch.int8, device=logits.device, + ) + intermediate = torch.empty_like(logits) + + grid = (num_tokens,) + _fused_topk_score_fwd_kernel[grid]( + logits, + probs, + routing_map_i8, + intermediate, + expert_bias if has_bias else logits, # dummy ptr when unused + num_tokens, + sf, + NUM_EXPERTS=num_experts, + TOPK=topk, + SCORE_FN=score_fn, + USE_PRE_SOFTMAX=use_pre_softmax, + HAS_BIAS=has_bias, + USE_GROUP_TOPK=use_group_topk, + NUM_GROUPS=ng, + GROUP_TOPK=gt, + GROUP_SIZE=gs, + BLOCK_E=block_e, + BLOCK_G=block_g, + ) + + routing_map = routing_map_i8.to(torch.bool) + return probs, routing_map, intermediate + + +def fused_topk_with_score_function_bwd( + num_tokens: int, + num_experts: int, + routing_map: torch.Tensor, + intermediate_output: torch.Tensor, + grad_probs: torch.Tensor, + topk: int, + use_pre_softmax: bool, + scaling_factor: float, + score_function: str, +) -> torch.Tensor: + """Fused score-function + top-k backward via Triton.""" + block_e = triton.next_power_of_2(num_experts) + score_fn = _SCORE_FN_MAP[score_function] + sf = scaling_factor if scaling_factor is not None and scaling_factor > 0 else 1.0 + + routing_map_i8 = routing_map.to(torch.int8) + grad_logits = torch.empty( + num_tokens, num_experts, + dtype=intermediate_output.dtype, device=grad_probs.device, + ) + + grid = (num_tokens,) + _fused_topk_score_bwd_kernel[grid]( + routing_map_i8, + intermediate_output, + grad_probs, + grad_logits, + num_tokens, + sf, + NUM_EXPERTS=num_experts, + TOPK=topk, + SCORE_FN=score_fn, + USE_PRE_SOFTMAX=use_pre_softmax, + BLOCK_E=block_e, + ) + + return grad_logits + + +def fused_score_for_moe_aux_loss_fwd( + logits: torch.Tensor, + topk: int, + score_function: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused score computation for aux loss via Triton.""" + num_tokens, num_experts = logits.shape + block_e = triton.next_power_of_2(num_experts) + score_fn = _SCORE_FN_MAP[score_function] + + scores = torch.empty_like(logits) + routing_map_i8 = torch.empty( + num_tokens, num_experts, dtype=torch.int8, device=logits.device, + ) + intermediate = torch.empty_like(logits) + + grid = (num_tokens,) + _fused_score_aux_loss_fwd_kernel[grid]( + logits, + scores, + routing_map_i8, + intermediate, + num_tokens, + NUM_EXPERTS=num_experts, + TOPK=topk, + SCORE_FN=score_fn, + BLOCK_E=block_e, + ) + + routing_map = routing_map_i8.to(torch.bool) + return scores, routing_map, intermediate + + +def fused_score_for_moe_aux_loss_bwd( + num_tokens: int, + num_experts: int, + intermediate_output: torch.Tensor, + grad_scores: torch.Tensor, + topk: int, + score_function: str, +) -> torch.Tensor: + """Fused score backward for aux loss via Triton.""" + block_e = triton.next_power_of_2(num_experts) + score_fn = _SCORE_FN_MAP[score_function] + + grad_logits = torch.empty( + num_tokens, num_experts, + dtype=intermediate_output.dtype, 
device=grad_scores.device, + ) + + grid = (num_tokens,) + _fused_score_aux_loss_bwd_kernel[grid]( + intermediate_output, + grad_scores, + grad_logits, + num_tokens, + NUM_EXPERTS=num_experts, + SCORE_FN=score_fn, + BLOCK_E=block_e, + ) + + return grad_logits diff --git a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py index 84cac374c..802ee50b4 100644 --- a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py +++ b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py @@ -76,7 +76,14 @@ def general_grouped_gemm_triton( # A=inputs (list of (m_i, in_features)), B=grad_outputs (list of (m_i, out_features)) A_tensor = A[0] if len(A) == 1 else torch.cat(A, dim=0) # (M, in_features) B_tensor = B[0] if len(B) == 1 else torch.cat(B, dim=0) # (M, out_features) - out_tensor_3d = out # (G, out_features, in_features) + # out is a list of per-expert grad tensors; stack to 3D for ptgmm + if isinstance(out, list): + if len(out) == 1 and out[0].ndim == 3: + out_tensor_3d = out[0] + else: + out_tensor_3d = torch.stack(out, dim=0) # (G, out_features, in_features) + else: + out_tensor_3d = out # Allocate bias_grad OUTPUT buffer if needed (kernel writes to this) bias_grad_tensor = None