Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 71 additions & 7 deletions tests/pytorch/distributed/run_numerics_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,35 @@ def _get_mean_abs_relative_error(a, b):
error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
return torch.mean(error)

@staticmethod
def _shard_strided(t, tp_size, rank, dim, stride):
"""Shard tensor along ``dim`` for given partition stride.

stride == 1: naive split (existing behavior, equivalent to _shard_tensor(t, tp_size, dim)[rank]).
stride > 1: MLM-golden interleaved shard from
``megatron/core/tensor_parallel/layers.py:_initialize_affine_weight_cpu`` —
split into ``stride * tp_size`` chunks, take ``chunks[rank::tp_size]``, concat.
This produces the per-rank ``[block_0_chunk | block_1_chunk | ...]`` layout used
by gated MLPs (SwiGLU FC1 with stride=2).
"""
chunk_count = tp_size * stride
chunk = t.shape[dim] // chunk_count
chunks = list(torch.split(t, chunk, dim=dim))
return torch.cat(chunks[rank::tp_size], dim=dim)

@staticmethod
def _stamp_partition_stride(layer, partition_stride, tp_size):
"""Stamp ``partition_stride`` on layer.weight/bias (mirrors MLM-side TE extension).

TE itself only ever sets stride=1; partition_stride>1 is an MLM-side concept
that the downstream NVTE_TP_INVARIANT_MODE deinterleave path reads from
``ctx.partition_stride``. No-op when stride==1 or tp_size==1.
"""
if partition_stride > 1 and tp_size > 1:
setattr(layer.weight, "partition_stride", partition_stride)
if layer.bias is not None:
setattr(layer.bias, "partition_stride", partition_stride)

@classmethod
def run_linear_preprocess_parallel(
cls,
Expand All @@ -239,14 +268,18 @@ def run_linear_preprocess_parallel(
sequence_parallel=False,
tp_size=1,
rank=0,
stride=1,
):
if tp_size > 1:
if parallel_mode == "column":
# split w in N dim, which should be axis 0
w = cls._shard_tensor(w, tp_size, 0)[rank]
bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
# split gradient in N dim, which should be axis 1
gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
# split w in N dim (axis 0), gradient in N dim (axis 1); stride>1 → interleave.
w = cls._shard_strided(w, tp_size, rank, dim=0, stride=stride)
bias = (
cls._shard_strided(bias, tp_size, rank, dim=0, stride=stride)
if bias is not None
else None
)
gradient = cls._shard_strided(gradient, tp_size, rank, dim=1, stride=stride)
if sequence_parallel:
# split x in M dim, which should be axis 0
x = cls._shard_tensor(x, tp_size, 0)[rank]
Expand Down Expand Up @@ -396,10 +429,16 @@ def run_linear(
run_num_steps=1,
enable_weight_cache=False,
fuse_wgrad_accumulation=False,
partition_stride=1,
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.

``partition_stride > 1`` enables gated-MLP-style sharding (e.g., SwiGLU FC1):
the weight (and gradient, bias) are sharded into interleaved per-rank chunks
per MLM's `_initialize_affine_weight_cpu`, and ``partition_stride`` is stamped
on the constructed layer's weight (mirroring what MLM does post-construction).
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
Expand All @@ -412,7 +451,15 @@ def run_linear(

# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
x,
w,
bias,
gradient,
parallel_mode,
sequence_parallel,
tp_size,
rank,
stride=partition_stride,
)

# set data types
Expand All @@ -438,6 +485,8 @@ def run_linear(
if bias is not None:
layer.bias.copy_(bias)

cls._stamp_partition_stride(layer, partition_stride, tp_size)

if fuse_wgrad_accumulation:
assert (
run_num_steps > 1
Expand Down Expand Up @@ -607,10 +656,15 @@ def run_layernorm_linear(
enable_weight_cache=False,
LayerNormLinearClass=te.LayerNormLinear,
normalization="LayerNorm",
partition_stride=1,
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.

``partition_stride > 1`` enables gated-MLP-style sharding (SwiGLU FC1, stride=2):
weight + gradient are interleaved per-rank via MLM's golden algorithm, and
``partition_stride`` is stamped on the constructed layer's weight.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
Expand All @@ -623,7 +677,15 @@ def run_layernorm_linear(

# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
x,
w,
bias,
gradient,
parallel_mode,
sequence_parallel,
tp_size,
rank,
stride=partition_stride,
)

# set data types
Expand Down Expand Up @@ -652,6 +714,8 @@ def run_layernorm_linear(
if bias is not None:
layer.bias.copy_(bias)

cls._stamp_partition_stride(layer, partition_stride, tp_size)

# Run one step
y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)

Expand Down
206 changes: 206 additions & 0 deletions tests/pytorch/distributed/run_tp_invariant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#!/usr/bin/python3

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""NVTE_TP_INVARIANT_MODE distributed test body. Launched via test_tp_invariant.py.

One invocation runs ONE (parallel_mode, sequence_parallel, expect_bitwise)
combination per pytest parametrize axis:

expect_bitwise=True (NVTE_TP_INVARIANT_MODE=1, with_tp_invariant):
TP=N == TP=1 bit-for-bit.
expect_bitwise=False (NVTE_TP_INVARIANT_MODE=0, without_tp_invariant):
TP=N != TP=1 (stock TP isn't bitwise; guards
with_tp_invariant against trivial-pass).

Wgrad is intentionally not compared — patches gate only the forward
(row-parallel) and dgrad (column-parallel) paths.

Reuses TestDistributedLinearBase from run_numerics_exact for sharding + gather.

LayerNormLinear with partition_stride=2 (SwiGLU FC1 deinterleave) is covered
by ``_check_tp_invariance_deinterleave`` — positive only (no without_tp_invariant
variant), because stock TP dgrad uses the same products as TP=1 just in a
different accumulation order, so bit-difference is layout-dependent (flaky).
"""

import argparse
import datetime
import os
import sys

import run_numerics_exact as rne
import torch
import torch.distributed as dist
from run_numerics_exact import (
TestDistributedLayerNormLinearBase,
TestDistributedLinearBase,
dist_print,
)

# Problem size and dtype shared by every check in this file; passed to
# _prepare_data, which builds x, w, bias, gradient from them.
BATCH, HIDDEN, OUT = 16, 256, 128
DTYPE = torch.bfloat16


def _run_linear(parallel_mode, sequence_parallel):
    """Run TP=1 reference and TP=N for given config; return relevant output pair.

    For parallel_mode='row' returns (y_ref, y_tp). For 'column' returns
    (dgrad_ref, dgrad_tp). Both are full (gathered) tensors regardless of
    sharding, suitable for direct bitwise comparison.
    """
    x, w, bias, gradient = TestDistributedLinearBase._prepare_data(
        BATCH, HIDDEN, OUT, use_bias=False, seed=42, dtype=DTYPE
    )
    # Single-GPU reference: no parallel mode, no sequence parallelism, tp_size=1.
    y_ref, dgrad_ref, _, _ = TestDistributedLinearBase.run_linear(
        x,
        w,
        bias,
        gradient,
        parallel_mode=None,
        sequence_parallel=False,
        tp_group=None,
        tp_size=1,
        rank=0,
    )
    # Same problem, sharded over the NCCL world group.
    y_tp, dgrad_tp, _, _ = TestDistributedLinearBase.run_linear(
        x,
        w,
        bias,
        gradient,
        parallel_mode=parallel_mode,
        sequence_parallel=sequence_parallel,
        tp_group=rne.NCCL_WORLD,
        tp_size=rne.WORLD_SIZE,
        rank=rne.WORLD_RANK,
    )
    # Row-parallel gates the forward output; column-parallel gates dgrad.
    pair = (y_ref, y_tp) if parallel_mode == "row" else (dgrad_ref, dgrad_tp)
    return pair


def _check_tp_invariance(parallel_mode, sequence_parallel, expect_bitwise):
    """Run one check; assert bitwise (with_tp_invariant) or non-bitwise (without_tp_invariant)."""
    os.environ["NVTE_TP_INVARIANT_MODE"] = "1" if expect_bitwise else "0"
    ref, tp = _run_linear(parallel_mode, sequence_parallel)

    # Only rank 0 performs the comparison; other ranks just participated in the run.
    if rne.WORLD_RANK != 0:
        return

    kind = "fwd" if parallel_mode == "row" else "dgrad"
    label = f"{parallel_mode}-parallel {kind} sp={int(sequence_parallel)}"

    # Negative variant first: stock TP must NOT be bitwise, else the positive
    # variant would be a trivial pass.
    if not expect_bitwise:
        assert not torch.equal(
            tp, ref
        ), f"without_tp_invariant: {label} unexpectedly bitwise under NVTE_TP_INVARIANT_MODE=0"
        dist_print(f"[without_tp_invariant] {label}: TP=1 ≠ TP={rne.WORLD_SIZE} (as expected)")
        return

    # Positive variant: exact bitwise match (atol=rtol=0).
    torch.testing.assert_close(
        tp,
        ref,
        atol=0,
        rtol=0,
        msg=f"{label} not bitwise under NVTE_TP_INVARIANT_MODE=1",
    )
    dist_print(f"[with_tp_invariant ] {label}: TP=1 ≡ TP={rne.WORLD_SIZE} bitwise")


def _check_tp_invariance_deinterleave(sequence_parallel):
    """LayerNormLinear column-parallel + partition_stride=2 (SwiGLU FC1) TP-invariance.

    Uses MLM's golden stride=2 sharding (added to ``TestDistributedLayerNormLinearBase``)
    to construct per-rank interleaved weight; verifies our deinterleave correctly inverts
    it so TP=N dgrad bitwise matches the TP=1 reference."""
    os.environ["NVTE_TP_INVARIANT_MODE"] = "1"
    x, w, _, g = TestDistributedLinearBase._prepare_data(
        BATCH, HIDDEN, OUT, use_bias=False, seed=42, dtype=DTYPE
    )
    # TP=1 reference run (stride=1: no interleaving).
    ref_outputs = TestDistributedLayerNormLinearBase.run_layernorm_linear(
        x,
        w,
        None,
        g,
        parallel_mode=None,
        sequence_parallel=False,
        tp_group=None,
        tp_size=1,
        rank=0,
        partition_stride=1,
    )
    # TP=N run with MLM golden stride=2 interleaved sharding.
    tp_outputs = TestDistributedLayerNormLinearBase.run_layernorm_linear(
        x,
        w,
        None,
        g,
        parallel_mode="column",
        sequence_parallel=sequence_parallel,
        tp_group=rne.NCCL_WORLD,
        tp_size=rne.WORLD_SIZE,
        rank=rne.WORLD_RANK,
        partition_stride=2,
    )
    # run_layernorm_linear returns (y, ln_out, dgrad, wgrad, bgrad); dgrad is index 2.
    dgrad_ref = ref_outputs[2]
    dgrad_tp = tp_outputs[2]

    if rne.WORLD_RANK != 0:
        return
    label = f"LN-Linear stride=2 sp={int(sequence_parallel)}"
    torch.testing.assert_close(
        dgrad_tp, dgrad_ref, atol=0, rtol=0, msg=f"{label}: not TP-invariant"
    )
    dist_print(f"{label}: TP=1 ≡ TP={rne.WORLD_SIZE} bitwise via deinterleave")


def main():
    """Parse CLI args, bring up torch.distributed, run the selected check; return 0."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--check-type", choices=["linear", "deinterleave"], default="linear")
    parser.add_argument("--parallel-mode", choices=["row", "column"])
    parser.add_argument("--sequence-parallel", action="store_true")
    parser.add_argument("--expect-bitwise", type=int, choices=[0, 1])
    args = parser.parse_args()

    # Launcher (torchrun) provides rank/world-size/local-rank via environment.
    rank, world_size, local_rank = (
        int(os.getenv(name, "0")) for name in ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    )
    if world_size == 0:
        world_size = 1  # WORLD_SIZE default is "1", never "0"; keep default semantics
    assert world_size <= torch.cuda.device_count()

    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        init_method="env://",
        timeout=datetime.timedelta(seconds=60),
        device_id=torch.device(f"cuda:{local_rank}"),
    )

    # Publish distributed state into run_numerics_exact's module globals,
    # which its helpers read.
    rne.WORLD_RANK, rne.WORLD_SIZE = rank, world_size
    rne.NCCL_WORLD = dist.new_group(backend="nccl")

    if args.check_type == "deinterleave":
        _check_tp_invariance_deinterleave(sequence_parallel=args.sequence_parallel)
    else:  # linear
        assert args.parallel_mode is not None, "--parallel-mode required for linear check"
        assert args.expect_bitwise is not None, "--expect-bitwise required for linear check"
        _check_tp_invariance(
            parallel_mode=args.parallel_mode,
            sequence_parallel=args.sequence_parallel,
            expect_bitwise=bool(args.expect_bitwise),
        )

    dist.destroy_process_group()
    return 0


# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
Loading
Loading