diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py
index 15ae2dae63..5c3175ef3b 100644
--- a/tests/pytorch/distributed/run_numerics_exact.py
+++ b/tests/pytorch/distributed/run_numerics_exact.py
@@ -228,6 +228,35 @@ def _get_mean_abs_relative_error(a, b):
         error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
         return torch.mean(error)
 
+    @staticmethod
+    def _shard_strided(t, tp_size, rank, dim, stride):
+        """Shard tensor along ``dim`` for the given partition stride.
+
+        stride == 1: naive split (existing behavior, equivalent to ``_shard_tensor(t, tp_size, dim)[rank]``).
+        stride > 1: MLM-golden interleaved shard from
+        ``megatron/core/tensor_parallel/layers.py:_initialize_affine_weight_cpu`` —
+        split into ``stride * tp_size`` chunks, take ``chunks[rank::tp_size]``, concat.
+        This produces the per-rank ``[block_0_chunk | block_1_chunk | ...]`` layout used
+        by gated MLPs (SwiGLU FC1 with stride=2).
+        """
+        chunk_count = tp_size * stride
+        chunk = t.shape[dim] // chunk_count
+        chunks = list(torch.split(t, chunk, dim=dim))
+        return torch.cat(chunks[rank::tp_size], dim=dim)
+
+    @staticmethod
+    def _stamp_partition_stride(layer, partition_stride, tp_size):
+        """Stamp ``partition_stride`` on layer.weight/bias (mirrors MLM-side TE extension).
+
+        TE itself only ever sets stride=1; partition_stride>1 is an MLM-side concept
+        that the downstream NVTE_TP_INVARIANT_MODE deinterleave path reads from
+        ``ctx.partition_stride``. No-op when stride==1 or tp_size==1.
+        """
+        if partition_stride > 1 and tp_size > 1:
+            setattr(layer.weight, "partition_stride", partition_stride)
+            if layer.bias is not None:
+                setattr(layer.bias, "partition_stride", partition_stride)
+
     @classmethod
     def run_linear_preprocess_parallel(
         cls,
@@ -239,14 +268,18 @@ def run_linear_preprocess_parallel(
         sequence_parallel=False,
         tp_size=1,
         rank=0,
+        stride=1,
     ):
         if tp_size > 1:
             if parallel_mode == "column":
-                # split w in N dim, which should be axis 0
-                w = cls._shard_tensor(w, tp_size, 0)[rank]
-                bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
-                # split gradient in N dim, which should be axis 1
-                gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
+                # split w in N dim (axis 0), gradient in N dim (axis 1); stride>1 → interleave.
+                w = cls._shard_strided(w, tp_size, rank, dim=0, stride=stride)
+                bias = (
+                    cls._shard_strided(bias, tp_size, rank, dim=0, stride=stride)
+                    if bias is not None
+                    else None
+                )
+                gradient = cls._shard_strided(gradient, tp_size, rank, dim=1, stride=stride)
                 if sequence_parallel:
                     # split x in M dim, which should be axis 0
                     x = cls._shard_tensor(x, tp_size, 0)[rank]
@@ -396,10 +429,16 @@ def run_linear(
         run_num_steps=1,
         enable_weight_cache=False,
         fuse_wgrad_accumulation=False,
+        partition_stride=1,
     ):
         """
         If Model parallel, split inputs for a given rank and return the gathered output
         and gradients, so that they can be compared with the reference single GPU run.
+
+        ``partition_stride > 1`` enables gated-MLP-style sharding (e.g., SwiGLU FC1):
+        the weight (and gradient, bias) are sharded into interleaved per-rank chunks
+        per MLM's ``_initialize_affine_weight_cpu``, and ``partition_stride`` is stamped
+        on the constructed layer's weight (mirroring what MLM does post-construction).
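+
+        Illustrative layout (hypothetical sizes, N=8 output rows 0..7, tp_size=2,
+        stride=2): ``_shard_strided`` splits them into 4 chunks [01][23][45][67];
+        rank 0 takes chunks[0::2] → rows [0 1 4 5], rank 1 takes chunks[1::2] →
+        rows [2 3 6 7].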
""" # clone inputs and move to current device # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] @@ -412,7 +451,15 @@ def run_linear( # If Model parallel: split inputs for a given rank x, w, bias, gradient = cls.run_linear_preprocess_parallel( - x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + x, + w, + bias, + gradient, + parallel_mode, + sequence_parallel, + tp_size, + rank, + stride=partition_stride, ) # set data types @@ -438,6 +485,8 @@ def run_linear( if bias is not None: layer.bias.copy_(bias) + cls._stamp_partition_stride(layer, partition_stride, tp_size) + if fuse_wgrad_accumulation: assert ( run_num_steps > 1 @@ -607,10 +656,15 @@ def run_layernorm_linear( enable_weight_cache=False, LayerNormLinearClass=te.LayerNormLinear, normalization="LayerNorm", + partition_stride=1, ): """ If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with the reference single GPU run. + + ``partition_stride > 1`` enables gated-MLP-style sharding (SwiGLU FC1, stride=2): + weight + gradient are interleaved per-rank via MLM's golden algorithm, and + ``partition_stride`` is stamped on the constructed layer's weight. """ # clone inputs and move to current device # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] @@ -623,7 +677,15 @@ def run_layernorm_linear( # If Model parallel: split inputs for a given rank x, w, bias, gradient = cls.run_linear_preprocess_parallel( - x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + x, + w, + bias, + gradient, + parallel_mode, + sequence_parallel, + tp_size, + rank, + stride=partition_stride, ) # set data types @@ -652,6 +714,8 @@ def run_layernorm_linear( if bias is not None: layer.bias.copy_(bias) + cls._stamp_partition_stride(layer, partition_stride, tp_size) + # Run one step y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) diff --git a/tests/pytorch/distributed/run_tp_invariant.py b/tests/pytorch/distributed/run_tp_invariant.py new file mode 100644 index 0000000000..56a7bf89be --- /dev/null +++ b/tests/pytorch/distributed/run_tp_invariant.py @@ -0,0 +1,206 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""NVTE_TP_INVARIANT_MODE distributed test body. Launched via test_tp_invariant.py. + +One invocation runs ONE (parallel_mode, sequence_parallel, expect_bitwise) +combination per pytest parametrize axis: + + expect_bitwise=True (NVTE_TP_INVARIANT_MODE=1, with_tp_invariant): + TP=N == TP=1 bit-for-bit. + expect_bitwise=False (NVTE_TP_INVARIANT_MODE=0, without_tp_invariant): + TP=N != TP=1 (stock TP isn't bitwise; guards + with_tp_invariant against trivial-pass). + +Wgrad is intentionally not compared — patches gate only the forward +(row-parallel) and dgrad (column-parallel) paths. + +Reuses TestDistributedLinearBase from run_numerics_exact for sharding + gather. + +LayerNormLinear with partition_stride=2 (SwiGLU FC1 deinterleave) is covered +by ``_check_tp_invariance_deinterleave`` — positive only (no without_tp_invariant +variant), because stock TP dgrad uses the same products as TP=1 just in a +different accumulation order, so bit-difference is layout-dependent (flaky). 
+"""
+
+import argparse
+import datetime
+import os
+import sys
+
+import run_numerics_exact as rne
+import torch
+import torch.distributed as dist
+from run_numerics_exact import (
+    TestDistributedLayerNormLinearBase,
+    TestDistributedLinearBase,
+    dist_print,
+)
+
+BATCH, HIDDEN, OUT = 16, 256, 128
+DTYPE = torch.bfloat16
+
+
+def _run_linear(parallel_mode, sequence_parallel):
+    """Run TP=1 reference and TP=N for given config; return relevant output pair.
+
+    For parallel_mode='row' returns (y_ref, y_tp). For 'column' returns
+    (dgrad_ref, dgrad_tp). Both are full (gathered) tensors regardless of
+    sharding, suitable for direct bitwise comparison.
+    """
+    x, w, bias, gradient = TestDistributedLinearBase._prepare_data(
+        BATCH,
+        HIDDEN,
+        OUT,
+        use_bias=False,
+        seed=42,
+        dtype=DTYPE,
+    )
+    y_ref, dgrad_ref, _, _ = TestDistributedLinearBase.run_linear(
+        x,
+        w,
+        bias,
+        gradient,
+        parallel_mode=None,
+        sequence_parallel=False,
+        tp_group=None,
+        tp_size=1,
+        rank=0,
+    )
+    y_tp, dgrad_tp, _, _ = TestDistributedLinearBase.run_linear(
+        x,
+        w,
+        bias,
+        gradient,
+        parallel_mode=parallel_mode,
+        sequence_parallel=sequence_parallel,
+        tp_group=rne.NCCL_WORLD,
+        tp_size=rne.WORLD_SIZE,
+        rank=rne.WORLD_RANK,
+    )
+    if parallel_mode == "row":
+        return y_ref, y_tp
+    return dgrad_ref, dgrad_tp
+
+
+def _check_tp_invariance(parallel_mode, sequence_parallel, expect_bitwise):
+    """Run one check; assert bitwise (with_tp_invariant) or non-bitwise (without_tp_invariant)."""
+    os.environ["NVTE_TP_INVARIANT_MODE"] = "1" if expect_bitwise else "0"
+    ref, tp = _run_linear(parallel_mode, sequence_parallel)
+
+    if rne.WORLD_RANK != 0:
+        return
+
+    kind = "fwd" if parallel_mode == "row" else "dgrad"
+    label = f"{parallel_mode}-parallel {kind} sp={int(sequence_parallel)}"
+
+    if expect_bitwise:
+        torch.testing.assert_close(
+            tp,
+            ref,
+            atol=0,
+            rtol=0,
+            msg=f"{label} not bitwise under NVTE_TP_INVARIANT_MODE=1",
+        )
+        dist_print(f"[with_tp_invariant   ] {label}: TP=1 ≡ TP={rne.WORLD_SIZE} bitwise")
+    else:
+        assert not torch.equal(
+            tp, ref
+        ), f"without_tp_invariant: {label} unexpectedly bitwise under NVTE_TP_INVARIANT_MODE=0"
+        dist_print(f"[without_tp_invariant] {label}: TP=1 ≠ TP={rne.WORLD_SIZE} (as expected)")
+
+
+def _check_tp_invariance_deinterleave(sequence_parallel):
+    """LayerNormLinear column-parallel + partition_stride=2 (SwiGLU FC1) TP-invariance.
+
+    Uses MLM's golden stride=2 sharding (added to ``TestDistributedLayerNormLinearBase``)
+    to construct per-rank interleaved weight; verifies our deinterleave correctly inverts
+    it so TP=N dgrad bitwise matches the TP=1 reference."""
+    os.environ["NVTE_TP_INVARIANT_MODE"] = "1"
+    x, w, _, g = TestDistributedLinearBase._prepare_data(
+        BATCH,
+        HIDDEN,
+        OUT,
+        use_bias=False,
+        seed=42,
+        dtype=DTYPE,
+    )
+    _, _, dgrad_ref, _, _ = TestDistributedLayerNormLinearBase.run_layernorm_linear(
+        x,
+        w,
+        None,
+        g,
+        parallel_mode=None,
+        sequence_parallel=False,
+        tp_group=None,
+        tp_size=1,
+        rank=0,
+        partition_stride=1,
+    )
+    _, _, dgrad_tp, _, _ = TestDistributedLayerNormLinearBase.run_layernorm_linear(
+        x,
+        w,
+        None,
+        g,
+        parallel_mode="column",
+        sequence_parallel=sequence_parallel,
+        tp_group=rne.NCCL_WORLD,
+        tp_size=rne.WORLD_SIZE,
+        rank=rne.WORLD_RANK,
+        partition_stride=2,
+    )
+    if rne.WORLD_RANK != 0:
+        return
+    label = f"LN-Linear stride=2 sp={int(sequence_parallel)}"
+    torch.testing.assert_close(
+        dgrad_tp, dgrad_ref, atol=0, rtol=0, msg=f"{label}: not TP-invariant"
+    )
+    dist_print(f"{label}: TP=1 ≡ TP={rne.WORLD_SIZE} bitwise via deinterleave")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--check-type", choices=["linear", "deinterleave"], default="linear")
+    parser.add_argument("--parallel-mode", choices=["row", "column"])
+    parser.add_argument("--sequence-parallel", action="store_true")
+    parser.add_argument("--expect-bitwise", type=int, choices=[0, 1])
+    args = parser.parse_args()
+
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    local_rank = int(os.getenv("LOCAL_RANK", "0"))
+    assert world_size <= torch.cuda.device_count()
+
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group(
+        backend="nccl",
+        rank=rank,
+        world_size=world_size,
+        init_method="env://",
+        timeout=datetime.timedelta(seconds=60),
+        device_id=torch.device(f"cuda:{local_rank}"),
+    )
+
+    rne.WORLD_RANK, rne.WORLD_SIZE = rank, world_size
+    rne.NCCL_WORLD = dist.new_group(backend="nccl")
+
+    if args.check_type == "linear":
+        assert args.parallel_mode is not None, "--parallel-mode required for linear check"
+        assert args.expect_bitwise is not None, "--expect-bitwise required for linear check"
+        _check_tp_invariance(
+            parallel_mode=args.parallel_mode,
+            sequence_parallel=args.sequence_parallel,
+            expect_bitwise=bool(args.expect_bitwise),
+        )
+    else:  # deinterleave
+        _check_tp_invariance_deinterleave(sequence_parallel=args.sequence_parallel)
+
+    dist.destroy_process_group()
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/tests/pytorch/distributed/test_tp_invariant.py b/tests/pytorch/distributed/test_tp_invariant.py
new file mode 100644
index 0000000000..62603e0d9d
--- /dev/null
+++ b/tests/pytorch/distributed/test_tp_invariant.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""pytest entry for the NVTE_TP_INVARIANT_MODE distributed correctness test.
+
+Launches ``run_tp_invariant.py`` under torchrun with per-case CLI args and asserts
+the subprocess returns 0. The actual bitwise checks live in ``run_tp_invariant.py``.
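+
+For instance, the deinterleave case at tp_size=2 shells out to:
+
+    torchrun --nproc_per_node=2 run_tp_invariant.py --check-type deinterleave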
+
+Test matrix (per tp_size):
+    parallel_mode ∈ {row, column}
+    sequence_parallel ∈ {False, True}
+    expect_bitwise ∈ {True (with_tp_invariant, patches on),
+                      False (without_tp_invariant, patches off)}
+"""
+
+import os
+import subprocess
+from pathlib import Path
+
+import pytest
+import torch
+
+if torch.cuda.device_count() < 2:
+    pytest.skip(
+        "TP-invariant test requires at least 2 GPUs.",
+        allow_module_level=True,
+    )
+
+
+TEST_ROOT = Path(__file__).parent.resolve()
+
+
+def _tp_sizes():
+    """TP sizes worth exercising on this node (2 and, if available, 4)."""
+    n = torch.cuda.device_count()
+    sizes = [2]
+    if n >= 4:
+        sizes.append(4)
+    return sizes
+
+
+@pytest.mark.parametrize("tp_size", _tp_sizes())
+@pytest.mark.parametrize("parallel_mode", ["row", "column"])
+@pytest.mark.parametrize("sequence_parallel", [False, True])
+@pytest.mark.parametrize(
+    "expect_bitwise", [True, False], ids=["with_tp_invariant", "without_tp_invariant"]
+)
+def test_tp_invariant(tp_size, parallel_mode, sequence_parallel, expect_bitwise):
+    """One TP-invariant correctness check per parameter combination.
+
+    expect_bitwise=True  → NVTE_TP_INVARIANT_MODE=1, TP=N must equal TP=1 bit-for-bit.
+    expect_bitwise=False → NVTE_TP_INVARIANT_MODE=0, TP=N must DIFFER from TP=1
+    (without_tp_invariant, guards against trivial-pass)."""
+    cmd = [
+        "torchrun",
+        f"--nproc_per_node={tp_size}",
+        str(TEST_ROOT / "run_tp_invariant.py"),
+        "--check-type",
+        "linear",
+        "--parallel-mode",
+        parallel_mode,
+        "--expect-bitwise",
+        str(int(expect_bitwise)),
+    ]
+    if sequence_parallel:
+        cmd.append("--sequence-parallel")
+
+    result = subprocess.run(cmd, env=os.environ, check=False)
+    assert result.returncode == 0, (
+        f"run_tp_invariant.py failed: tp_size={tp_size}, parallel_mode={parallel_mode}, "
+        f"sequence_parallel={sequence_parallel}, expect_bitwise={expect_bitwise} "
+        f"(returncode={result.returncode})"
+    )
+
+
+@pytest.mark.parametrize("tp_size", _tp_sizes())
+@pytest.mark.parametrize("sequence_parallel", [False, True])
+def test_tp_invariant_deinterleave(tp_size, sequence_parallel):
+    """LayerNormLinear column-parallel + partition_stride=2 (SwiGLU FC1 layout): dgrad bitwise
+    matches TP=1 reference. Uses MLM's golden stride=2 sharding to construct per-rank
+    interleaved weight; verifies our deinterleave correctly inverts it."""
+    cmd = [
+        "torchrun",
+        f"--nproc_per_node={tp_size}",
+        str(TEST_ROOT / "run_tp_invariant.py"),
+        "--check-type",
+        "deinterleave",
+    ]
+    if sequence_parallel:
+        cmd.append("--sequence-parallel")
+    result = subprocess.run(cmd, env=os.environ, check=False)
+    assert (
+        result.returncode == 0
+    ), f"deinterleave failed: tp_size={tp_size}, sequence_parallel={sequence_parallel}"
diff --git a/transformer_engine/pytorch/module/_tp_invariant.py b/transformer_engine/pytorch/module/_tp_invariant.py
new file mode 100644
index 0000000000..c9a85eb0ce
--- /dev/null
+++ b/transformer_engine/pytorch/module/_tp_invariant.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Internal helpers for the NVTE_TP_INVARIANT_MODE code path.
+
+Gated on ``NVTE_TP_INVARIANT_MODE=1`` and used by ``linear.py`` and
+``layernorm_linear.py``. With the env var unset (default), these helpers are
+unreachable and stock GEMM paths are taken unchanged.
+
+The TP-invariant path performs the full-K (or full-out) GEMM after all-gathering
+sharded operands across the TP group. Result: bit-identical numerics across
+TP=1/2/4/... because the underlying GEMM K-dimension accumulation order is fixed
+regardless of how the operands were sharded.
+
+``gemm_fn`` is passed in by the caller so downstream monkey-patches of
+``general_gemm`` in the caller's namespace (e.g., Megatron's batch-invariant
+kernels) are honored.
+
+Limitations:
+- FP8 not supported (callers should assert ``not fp8`` before calling).
+- Trades compute for invariance (gathered operands + full GEMM). Off by default.
+"""
+
+from typing import Callable, Optional
+
+import torch
+
+from ..utils import nvtx_range_pop, nvtx_range_push
+
+__all__ = [
+    "allgather_along_dim",
+    "tp_invariant_row_parallel_gemm",
+    "tp_invariant_column_parallel_dgrad",
+]
+
+
+def allgather_along_dim(
+    tensor: torch.Tensor,
+    group,
+    world_size: int,
+    dim: int,
+) -> torch.Tensor:
+    """All-gather ``tensor`` from every rank in ``group`` and concat along ``dim``."""
+    chunks = [torch.empty_like(tensor) for _ in range(world_size)]
+    torch.distributed.all_gather(chunks, tensor.contiguous(), group=group)
+    return torch.cat(chunks, dim=dim)
+
+
+def tp_invariant_row_parallel_gemm(
+    weightmat: torch.Tensor,
+    inputmat_total: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    tp_group,
+    tp_size: int,
+    sequence_parallel: bool,
+    activation_dtype: torch.dtype,
+    gemm_fn: Callable,
+    nvtx_label: str = "tp_invariant_gemm",
+) -> torch.Tensor:
+    """Row-parallel forward GEMM with TP-invariant numerics.
+
+    All-gathers input + weight along the contracted (K) dim, runs the full GEMM
+    (matching TP=1 accumulation order), then scatters along the sequence dim
+    when ``sequence_parallel=True``.
+
+    Args:
+        weightmat: Local weight shard of shape ``[out, hidden/TP]``.
+        inputmat_total: Local input shard of shape ``[..., hidden/TP]``.
+        bias: Optional bias of shape ``[out]``.
+        tp_group: TP process group.
+        tp_size: Size of the TP group.
+        sequence_parallel: If True, scatter the full output along dim 0.
+        activation_dtype: Output dtype (matches stock GEMM behavior).
+        gemm_fn: GEMM callable (``general_gemm`` or a caller-side patched variant).
+        nvtx_label: NVTX range label.
+
+    Returns:
+        Output of shape ``[..., out]`` (full sequence) or ``[.../TP, out]``
+        (sequence-scattered when ``sequence_parallel=True``).
+    """
+    nvtx_range_push(nvtx_label)
+
+    inputmat_gathered = allgather_along_dim(inputmat_total, tp_group, tp_size, dim=-1)
+    weight_gathered = allgather_along_dim(weightmat, tp_group, tp_size, dim=-1)
+
+    input_2d = inputmat_gathered.reshape(-1, inputmat_gathered.shape[-1])
+    out = gemm_fn(
+        weight_gathered,
+        input_2d,
+        out_dtype=activation_dtype,
+        bias=bias,
+    )
+    if isinstance(out, tuple):
+        out = out[0]
+    out = out.reshape(inputmat_gathered.shape[:-1] + (weight_gathered.shape[0],))
+
+    if sequence_parallel:
+        rank = torch.distributed.get_rank(tp_group)
+        out = out.chunk(tp_size, dim=0)[rank].contiguous()
+
+    nvtx_range_pop(nvtx_label)
+    return out
+
+
+def tp_invariant_column_parallel_dgrad(
+    weight: torch.Tensor,
+    grad_output: torch.Tensor,
+    tp_group,
+    tp_size: int,
+    sequence_parallel: bool,
+    activation_dtype: torch.dtype,
+    gemm_fn: Callable,
+    partition_stride: int = 1,
+    nvtx_label: str = "tp_invariant_dgrad",
+) -> torch.Tensor:
+    """Column-parallel backward dgrad with TP-invariant numerics.
+
+    All-gathers grad_output (along out dim) and weight (along out dim), runs the
+    full dgrad GEMM, then scatters along the sequence dim under SP.
+
+    For gated MLPs (e.g. SwiGLU FC1 where each rank holds interleaved [gate|val]
+    halves), ``partition_stride > 1`` triggers a deinterleave step to recover
+    the TP=1 layout [gate_all | val_all] before the full GEMM:
+
+        Per-rank layout:     [g_0 | v_0 | g_1 | v_1 | ... | g_{TP-1} | v_{TP-1}]
+        TP=1 native layout:  [g_0 | g_1 | ... | g_{TP-1} | v_0 | v_1 | ... | v_{TP-1}]
+
+    For non-gated layers (QKV etc., partition_stride=1) the naive all-gather
+    already matches the TP=1 ordering.
+
+    Args:
+        weight: Local weight shard of shape ``[out/TP, in]``.
+        grad_output: Local grad shard of shape ``[..., out/TP]``.
+        tp_group: TP process group.
+        tp_size: Size of TP group.
+        sequence_parallel: If True, scatter dgrad along dim 0.
+        activation_dtype: Output dtype.
+        gemm_fn: GEMM callable (``general_gemm`` or a caller-side patched variant).
+        partition_stride: >1 triggers the deinterleave for gated MLPs.
+        nvtx_label: NVTX range label.
+
+    Returns:
+        dgrad of shape ``[..., in]`` or sequence-scattered.
+    """
+    nvtx_range_push(nvtx_label)
+
+    grad_output_gathered = allgather_along_dim(grad_output, tp_group, tp_size, dim=-1)
+    weight_gathered = allgather_along_dim(weight, tp_group, tp_size, dim=0)
+
+    if partition_stride > 1:
+        # Deinterleave gated [gate|val] halves to TP=1 [gate_all | val_all].
+        # Currently only the 2-way gated split (SwiGLU FC1 layout) is handled.
+        assert (
+            partition_stride == 2
+        ), f"deinterleave only supports partition_stride=2 (gated halves); got {partition_stride}"
+        chunk_sz = weight.shape[0]  # out_features per rank
+        half = chunk_sz // 2
+        first_w = [weight_gathered[i * chunk_sz : i * chunk_sz + half] for i in range(tp_size)]
+        second_w = [
+            weight_gathered[i * chunk_sz + half : (i + 1) * chunk_sz] for i in range(tp_size)
+        ]
+        weight_gathered = torch.cat(first_w + second_w, dim=0)
+
+        g_dim = grad_output_gathered.shape[-1] // tp_size
+        g_half = g_dim // 2
+        first_g = [
+            grad_output_gathered[..., i * g_dim : i * g_dim + g_half] for i in range(tp_size)
+        ]
+        second_g = [
+            grad_output_gathered[..., i * g_dim + g_half : (i + 1) * g_dim] for i in range(tp_size)
+        ]
+        grad_output_gathered = torch.cat(first_g + second_g, dim=-1)
+
+    grad_output_2d = grad_output_gathered.reshape(
+        -1,
+        grad_output_gathered.shape[-1],
+    )
+    dgrad = gemm_fn(
+        weight_gathered,
+        grad_output_2d,
+        layout="NN",
+        grad=True,
+        out_dtype=activation_dtype,
+    )
+    if isinstance(dgrad, tuple):
+        dgrad = dgrad[0]
+
+    if sequence_parallel:
+        rank = torch.distributed.get_rank(tp_group)
+        dgrad = dgrad.chunk(tp_size, dim=0)[rank].contiguous()
+
+    nvtx_range_pop(nvtx_label)
+    return dgrad
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 8c88f3ee82..9067b65b9c 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -28,6 +28,7 @@
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
+from ._tp_invariant import tp_invariant_column_parallel_dgrad
 from ..quantization import FP8GlobalStateManager, QuantizerRole
 from ..utils import (
     assert_dim_for_fp8_exec,
@@ -374,6 +375,8 @@ def forward(
         # ------------------------------------------------------
         # Forward GEMM
         # Note: y = x * w^T
+        # No NVTE_TP_INVARIANT_MODE branch: LayerNormLinear is column-parallel only,
+        # so forward GEMM K=hidden_size does not vary with TP (invariant by construction).
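+        # (Only GEMMs that contract a sharded dim need the gather-based helpers in
+        # _tp_invariant.py: row-parallel fwd in linear.py and column-parallel dgrad below.)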
         # ------------------------------------------------------
         nvtx_range_push(f"{nvtx_label}.gemm")
         gemm_out, *_, reduce_scatter_out = general_gemm(
@@ -520,6 +523,7 @@ def forward(
         ctx.input_quantizer = input_quantizer
         ctx.owns_input = inputmat is not inp
         ctx.weight = weight
+        ctx.partition_stride = getattr(weight, "partition_stride", 1)
         ctx.activation_dtype = activation_dtype
         ctx.fp8 = fp8
         ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
@@ -798,63 +802,84 @@ def backward(
 
         # dgrad GEMM
         # Note: dx = dy * w
-        nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
-        weight_for_dgrad = weight
-        if ctx.backward_override == "dequantized":
-            if isinstance(weight_for_dgrad, QuantizedTensorStorage):
-                weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype)
-            else:
-                weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype)
-        elif ctx.backward_override == "high_precision":
-            weight_for_dgrad = saved_weight
-            if isinstance(weight_for_dgrad, QuantizedTensorStorage):
-                weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype)
-        gemm_out, *_, reduce_scatter_out = general_gemm(
-            weight_for_dgrad,
-            grad_output,
-            layout="NN",
-            grad=True,
-            quantization_params=ctx.grad_input_quantizer,
-            out=gemm_out,
-            out_dtype=ctx.activation_dtype,
-            use_split_accumulator=use_split_accumulator,
-            ub=ub_obj_dgrad,
-            ub_type=ub_type_dgrad,
-            extra_output=reduce_scatter_out,
-            bulk_overlap=ctx.ub_bulk_dgrad,
-        )
-        nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
-
-        # FSDP2 only handles deallocation all-gathered weights that it allocates.
-        # Columnwise data is derived from rowwise data after allgather for fp8
-        # and 2d block-scaled weights in TE managed memory. So we need to clear
-        # it here.
-        # (Issues #2681, #2717)
-        if getattr(ctx, "is_fsdp2", False) and isinstance(weight, QuantizedTensorStorage):
-            clear_columnwise_cache(weight)
-
-        # Prepare grad input tensor
-        # Note: Perform tensor-parallel communication
         dgrad = None
         dgrad_work = None
-        if ctx.ub_overlap_rs_dgrad:
-            dgrad = reduce_scatter_out
-        elif ctx.ub_bulk_wgrad:
-            dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
-        elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
-            nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
-            dgrad = gemm_out
-            if ctx.sequence_parallel:
-                dgrad, dgrad_work = reduce_scatter_along_first_dim(
-                    dgrad,
-                    ctx.tp_group,
-                    async_op=True,
-                )
-            else:
-                dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
-            nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
+
+        _tp_invariant_bwd = (
+            os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1"
+            and ctx.parallel_mode == "column"
+            and ctx.tp_size > 1
+        )
+
+        if _tp_invariant_bwd:
+            assert not ctx.fp8, "NVTE_TP_INVARIANT_MODE does not support FP8"
+            dgrad = tp_invariant_column_parallel_dgrad(
+                weight=weight,
+                grad_output=grad_output,
+                tp_group=ctx.tp_group,
+                tp_size=ctx.tp_size,
+                sequence_parallel=ctx.sequence_parallel,
+                activation_dtype=ctx.activation_dtype,
+                gemm_fn=general_gemm,
+                partition_stride=ctx.partition_stride,
+                nvtx_label=f"{nvtx_label}.tp_invariant_dgrad",
+            )
         else:
-            dgrad = gemm_out
+            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+            weight_for_dgrad = weight
+            if ctx.backward_override == "dequantized":
+                if isinstance(weight_for_dgrad, QuantizedTensorStorage):
+                    weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype)
+                else:
+                    weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype)
+            elif ctx.backward_override == "high_precision":
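+                # "high_precision": reuse the saved high-precision weight for dgrad.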
+                weight_for_dgrad = saved_weight
+                if isinstance(weight_for_dgrad, QuantizedTensorStorage):
+                    weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype)
+            gemm_out, *_, reduce_scatter_out = general_gemm(
+                weight_for_dgrad,
+                grad_output,
+                layout="NN",
+                grad=True,
+                quantization_params=ctx.grad_input_quantizer,
+                out=gemm_out,
+                out_dtype=ctx.activation_dtype,
+                use_split_accumulator=use_split_accumulator,
+                ub=ub_obj_dgrad,
+                ub_type=ub_type_dgrad,
+                extra_output=reduce_scatter_out,
+                bulk_overlap=ctx.ub_bulk_dgrad,
+            )
+            nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
+
+            # FSDP2 only handles deallocation all-gathered weights that it allocates.
+            # Columnwise data is derived from rowwise data after allgather for fp8
+            # and 2d block-scaled weights in TE managed memory. So we need to clear
+            # it here.
+            # (Issues #2681, #2717)
+            if getattr(ctx, "is_fsdp2", False) and isinstance(weight, QuantizedTensorStorage):
+                clear_columnwise_cache(weight)
+
+            # Prepare grad input tensor
+            # Note: Perform tensor-parallel communication
+            if ctx.ub_overlap_rs_dgrad:
+                dgrad = reduce_scatter_out
+            elif ctx.ub_bulk_wgrad:
+                dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
+            elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
+                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
+                dgrad = gemm_out
+                if ctx.sequence_parallel:
+                    dgrad, dgrad_work = reduce_scatter_along_first_dim(
+                        dgrad,
+                        ctx.tp_group,
+                        async_op=True,
+                    )
+                else:
+                    dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
+                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
+            else:
+                dgrad = gemm_out
 
         # --------------------------------------------------
         # Grad input tensor has been computed...
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index dcbb9eaf93..9a8b305c7a 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
"""Linear API""" +import os from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Tuple, Union, List from functools import reduce @@ -28,6 +29,10 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore +from ._tp_invariant import ( + tp_invariant_column_parallel_dgrad, + tp_invariant_row_parallel_gemm, +) from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, @@ -474,19 +479,39 @@ def _linear_forward_impl( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ - nvtx_range_push(f"{nvtx_label}.gemm") - gemm_out, *_, reduce_scatter_out = general_gemm( - weightmat, - inputmat_total, - quantization_params=output_quantizer, - out_dtype=activation_dtype, - bias=bias, - use_split_accumulator=use_split_accumulator, - ub=ub_obj, - ub_type=ub_type, - extra_output=reduce_scatter_out, + _tp_invariant = ( + os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1" + and parallel_mode == "row" + and args.tp_size > 1 ) - nvtx_range_pop(f"{nvtx_label}.gemm") + + if _tp_invariant: + assert not fp8, "NVTE_TP_INVARIANT_MODE does not support FP8" + out = tp_invariant_row_parallel_gemm( + weightmat=weightmat, + inputmat_total=inputmat_total, + bias=bias, + tp_group=tp_group, + tp_size=tp_world_size, + sequence_parallel=sequence_parallel, + activation_dtype=activation_dtype, + gemm_fn=general_gemm, + nvtx_label=f"{nvtx_label}.tp_invariant_gemm", + ) + else: + nvtx_range_push(f"{nvtx_label}.gemm") + gemm_out, *_, reduce_scatter_out = general_gemm( + weightmat, + inputmat_total, + quantization_params=output_quantizer, + out_dtype=activation_dtype, + bias=bias, + use_split_accumulator=use_split_accumulator, + ub=ub_obj, + ub_type=ub_type, + extra_output=reduce_scatter_out, + ) + nvtx_range_pop(f"{nvtx_label}.gemm") # ------------------------------------------------------ # Finished forward GEMM... # ------------------------------------------------------ @@ -502,22 +527,23 @@ def _linear_forward_impl( # Prepare output tensor # Note: Perform tensor-parallel communication # ------------------------------------------------------ - out = None - if ub_overlap_rs_fprop: - out = reduce_scatter_out - elif parallel_mode == "row" and args.tp_size > 1: - nvtx_range_push(f"{nvtx_label}.row_parallel_comm") - out = gemm_out - if sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif args.tensor_parallel: - if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) - else: - out, _ = allreduce(out, tp_group) - nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") - else: - out = gemm_out + if not _tp_invariant: + out = None + if ub_overlap_rs_fprop: + out = reduce_scatter_out + elif parallel_mode == "row" and args.tp_size > 1: + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") + out = gemm_out + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif args.tensor_parallel: + if symmetric_ar_type is not None: + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + else: + out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") + else: + out = gemm_out # ------------------------------------------------------ # Output tensor is ready to return... # ------------------------------------------------------ @@ -971,62 +997,86 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. 
         # dgrad GEMM
         # Note: dx = dy * w
-
-        nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
-        weight_for_dgrad = weight_fp8
-        if bwd_args.backward_override == "dequantized":
-            if isinstance(weight_for_dgrad, QuantizedTensorStorage):
-                weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype)
-            else:
-                weight_for_dgrad = cast_if_needed(weight_for_dgrad, bwd_args.activation_dtype)
-        elif bwd_args.backward_override == "high_precision":
-            weight_for_dgrad = saved_weight
-            if isinstance(weight_for_dgrad, QuantizedTensorStorage):
-                weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype)
-        gemm_out, *_, reduce_scatter_out = general_gemm(
-            weight_for_dgrad,
-            grad_output,
-            layout="NN",
-            grad=True,
-            quantization_params=grad_input_quantizer,
-            out=gemm_out,
-            out_dtype=bwd_args.activation_dtype,
-            use_split_accumulator=use_split_accumulator,
-            ub=ub_obj_dgrad,
-            ub_type=ub_type_dgrad,
-            extra_output=reduce_scatter_out,
-            bulk_overlap=bwd_args.ub_bulk_dgrad,
+        _tp_invariant_bwd = (
+            os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1"
+            and bwd_args.parallel_mode == "column"
+            and bwd_args.tp_size > 1
         )
-        nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
-
-        # FSDP2 only handles deallocation all-gathered weights that it allocates.
-        # Columnwise data is derived from rowwise data after allgather for fp8
-        # and 2d block-scaled weights in TE managed memory. So we need to clear
-        # it here.
-        # (Issues #2681, #2717)
-        if bwd_args.is_fsdp2 and isinstance(weight_fp8, QuantizedTensorStorage):
-            clear_columnwise_cache(weight_fp8)
-
-        # Prepare grad input tensor
-        # Note: Perform tensor-parallel communication
-        if bwd_args.ub_overlap_rs_dgrad:
-            dgrad = reduce_scatter_out
-        elif bwd_args.ub_bulk_wgrad:
-            dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
-        elif bwd_args.parallel_mode == "column" and bwd_args.tp_size > 1:
-            nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
-            dgrad = gemm_out
-            if bwd_args.sequence_parallel:
-                dgrad, dgrad_work = reduce_scatter_along_first_dim(
-                    dgrad,
-                    bwd_args.tp_group,
-                    async_op=True,
-                )
-            else:
-                dgrad, dgrad_work = allreduce(dgrad, bwd_args.tp_group, async_op=True)
-            nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
+
+        if _tp_invariant_bwd:
+            assert not bwd_args.fp8, "NVTE_TP_INVARIANT_MODE does not support FP8"
+            dgrad = tp_invariant_column_parallel_dgrad(
+                weight=weight_fp8,
+                grad_output=grad_output,
+                tp_group=bwd_args.tp_group,
+                tp_size=bwd_args.tp_size,
+                sequence_parallel=bwd_args.sequence_parallel,
+                activation_dtype=bwd_args.activation_dtype,
+                gemm_fn=general_gemm,
+                nvtx_label=f"{nvtx_label}.tp_invariant_dgrad",
+            )
         else:
-            dgrad = gemm_out
+            nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+            weight_for_dgrad = weight_fp8
+            if bwd_args.backward_override == "dequantized":
+                if isinstance(weight_for_dgrad, QuantizedTensorStorage):
+                    weight_for_dgrad = weight_for_dgrad.dequantize(
+                        dtype=bwd_args.activation_dtype
+                    )
+                else:
+                    weight_for_dgrad = cast_if_needed(
+                        weight_for_dgrad, bwd_args.activation_dtype
+                    )
+            elif bwd_args.backward_override == "high_precision":
+                weight_for_dgrad = saved_weight
+                if isinstance(weight_for_dgrad, QuantizedTensorStorage):
+                    weight_for_dgrad = weight_for_dgrad.dequantize(
+                        dtype=bwd_args.activation_dtype
+                    )
+            gemm_out, *_, reduce_scatter_out = general_gemm(
+                weight_for_dgrad,
+                grad_output,
+                layout="NN",
+                grad=True,
+                quantization_params=grad_input_quantizer,
+                out=gemm_out,
+                out_dtype=bwd_args.activation_dtype,
+                use_split_accumulator=use_split_accumulator,
+                ub=ub_obj_dgrad,
+                ub_type=ub_type_dgrad,
+                extra_output=reduce_scatter_out,
+                bulk_overlap=bwd_args.ub_bulk_dgrad,
+            )
+            nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
+
+            # FSDP2 only handles deallocation all-gathered weights that it allocates.
+            # Columnwise data is derived from rowwise data after allgather for fp8
+            # and 2d block-scaled weights in TE managed memory. So we need to clear
+            # it here.
+            # (Issues #2681, #2717)
+            if bwd_args.is_fsdp2 and isinstance(weight_fp8, QuantizedTensorStorage):
+                clear_columnwise_cache(weight_fp8)
+
+            # Prepare grad input tensor
+            # Note: Perform tensor-parallel communication
+            if bwd_args.ub_overlap_rs_dgrad:
+                dgrad = reduce_scatter_out
+            elif bwd_args.ub_bulk_wgrad:
+                dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
+            elif bwd_args.parallel_mode == "column" and bwd_args.tp_size > 1:
+                nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
+                dgrad = gemm_out
+                if bwd_args.sequence_parallel:
+                    dgrad, dgrad_work = reduce_scatter_along_first_dim(
+                        dgrad,
+                        bwd_args.tp_group,
+                        async_op=True,
+                    )
+                else:
+                    dgrad, dgrad_work = allreduce(dgrad, bwd_args.tp_group, async_op=True)
+                nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
+            else:
+                dgrad = gemm_out
 
         # --------------------------------------------------
         # Grad input tensor has been computed...