From 52a29c162c7213ed3673ff06eab1232ee5af9ef0 Mon Sep 17 00:00:00 2001 From: Jinze Xue Date: Mon, 11 May 2026 16:50:17 -0700 Subject: [PATCH 1/4] Add TP-invariant GEMM for bitwise-identical training across TP degrees Gated on NVTE_TP_INVARIANT_MODE=1 (default off; stock paths unchanged). - module/linear.py: row-parallel FWD + BWD full GEMM matching TP=1 K-dim accumulation. - module/layernorm_linear.py: column-parallel BWD dgrad full GEMM with gated deinterleave for SwiGLU FC1 (partition_stride > 1). Companion Megatron-LM PR (gates this code path via env var): https://github.com/NVIDIA/Megatron-LM/pull/4740. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Jinze Xue --- .../pytorch/module/layernorm_linear.py | 182 ++++++++---- transformer_engine/pytorch/module/linear.py | 262 ++++++++++++------ 2 files changed, 311 insertions(+), 133 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 8c88f3ee82..78bae7dea8 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -374,7 +374,15 @@ def forward( # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T + # No _tp_invariant_fwd needed here: LayerNormLinear is column-parallel, + # so forward GEMM K=hidden_size is constant across TP (invariant by construction). + # Only row-parallel forward (linear.py) needs the invariant path (K=in/TP varies). # ------------------------------------------------------ + if os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1" and parallel_mode == "row": + assert False, ( + "NVTE_TP_INVARIANT_MODE row-parallel forward is not implemented in " + "layernorm_linear.py. Use linear.py for row-parallel layers." + ) nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weightmat, @@ -520,6 +528,7 @@ def forward( ctx.input_quantizer = input_quantizer ctx.owns_input = inputmat is not inp ctx.weight = weight + ctx.partition_stride = getattr(weight, 'partition_stride', 1) ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -798,63 +807,130 @@ def backward( # dgrad GEMM # Note: dx = dy * w - nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight - if ctx.backward_override == "dequantized": - if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) - else: - weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) - elif ctx.backward_override == "high_precision": - weight_for_dgrad = saved_weight - if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) - gemm_out, *_, reduce_scatter_out = general_gemm( - weight_for_dgrad, - grad_output, - layout="NN", - grad=True, - quantization_params=ctx.grad_input_quantizer, - out=gemm_out, - out_dtype=ctx.activation_dtype, - use_split_accumulator=use_split_accumulator, - ub=ub_obj_dgrad, - ub_type=ub_type_dgrad, - extra_output=reduce_scatter_out, - bulk_overlap=ctx.ub_bulk_dgrad, - ) - nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") - - # FSDP2 only handles deallocation all-gathered weights that it allocates. - # Columnwise data is derived from rowwise data after allgather for fp8 - # and 2d block-scaled weights in TE managed memory. So we need to clear - # it here. - # (Issues #2681, #2717) - if getattr(ctx, "is_fsdp2", False) and isinstance(weight, QuantizedTensorStorage): - clear_columnwise_cache(weight) - - # Prepare grad input tensor - # Note: Perform tensor-parallel communication dgrad = None dgrad_work = None - if ctx.ub_overlap_rs_dgrad: - dgrad = reduce_scatter_out - elif ctx.ub_bulk_wgrad: - dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) - elif ctx.parallel_mode == "column" and ctx.tp_size > 1: - nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") - dgrad = gemm_out - if ctx.sequence_parallel: - dgrad, dgrad_work = reduce_scatter_along_first_dim( - dgrad, - ctx.tp_group, - async_op=True, + + _tp_invariant_bwd = ( + os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1" + and ctx.parallel_mode == "column" + and ctx.tp_size > 1 + ) + + if _tp_invariant_bwd: + # TP-invariant diagnostic: full dgrad GEMM matching TP=1 accumulation. + assert not ctx.fp8, "NVTE_TP_INVARIANT_MODE does not support FP8" + nvtx_range_push(f"{nvtx_label}.tp_invariant_dgrad") + + def allgather_along_dim(tensor, group, world_size, dim): + chunks = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather( + chunks, tensor.contiguous(), group=group, ) - else: - dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) - nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") + return torch.cat(chunks, dim=dim) + + grad_output_gathered = allgather_along_dim( + grad_output, ctx.tp_group, ctx.tp_size, dim=-1, + ) # [tokens, out/TP] -> [tokens, out] + weight_gathered = allgather_along_dim( + weight, ctx.tp_group, ctx.tp_size, dim=0, + ) # [out/TP, in] -> [out, in] + + # Deinterleave gathered tensors to match TP=1 K-dimension ordering. + # Only needed for gated linear units (partition_stride > 1, e.g. SwiGLU FC1) + # where each rank stores interleaved [gate_i | value_i]. + # After all-gather: [gate_0|val_0 | gate_1|val_1 | ...]. + # TP=1 layout is [gate_all | val_all]. Reorder to match. + # + # For non-gated layers (partition_stride == 1, e.g. QKV), each rank + # stores contiguous GQA groups, and the naive all-gather already + # produces the correct TP=1 ordering. Deinterleaving would corrupt it. + if ctx.partition_stride > 1: + chunk_sz = weight.shape[0] # out_features per rank + half = chunk_sz // 2 + first_w = [weight_gathered[i * chunk_sz : i * chunk_sz + half] for i in range(ctx.tp_size)] + second_w = [weight_gathered[i * chunk_sz + half : (i + 1) * chunk_sz] for i in range(ctx.tp_size)] + weight_gathered = torch.cat(first_w + second_w, dim=0) + + g_dim = grad_output_gathered.shape[-1] // ctx.tp_size + g_half = g_dim // 2 + first_g = [grad_output_gathered[..., i * g_dim : i * g_dim + g_half] for i in range(ctx.tp_size)] + second_g = [grad_output_gathered[..., i * g_dim + g_half : (i + 1) * g_dim] for i in range(ctx.tp_size)] + grad_output_gathered = torch.cat(first_g + second_g, dim=-1) + + grad_output_2d = grad_output_gathered.reshape( + -1, grad_output_gathered.shape[-1], + ) + dgrad = general_gemm( + weight_gathered, grad_output_2d, + layout="NN", grad=True, + out_dtype=ctx.activation_dtype, + ) + if isinstance(dgrad, tuple): + dgrad = dgrad[0] + + # SP: scatter to per-rank chunk along sequence dim. + if ctx.sequence_parallel: + rank = torch.distributed.get_rank(ctx.tp_group) + dgrad = dgrad.chunk(ctx.tp_size, dim=0)[rank].contiguous() + + del grad_output_gathered, weight_gathered, grad_output_2d + nvtx_range_pop(f"{nvtx_label}.tp_invariant_dgrad") else: - dgrad = gemm_out + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight + if ctx.backward_override == "dequantized": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) + elif ctx.backward_override == "high_precision": + weight_for_dgrad = saved_weight + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + gemm_out, *_, reduce_scatter_out = general_gemm( + weight_for_dgrad, + grad_output, + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=gemm_out, + out_dtype=ctx.activation_dtype, + use_split_accumulator=use_split_accumulator, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=reduce_scatter_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") + + # FSDP2 only handles deallocation all-gathered weights that it allocates. + # Columnwise data is derived from rowwise data after allgather for fp8 + # and 2d block-scaled weights in TE managed memory. So we need to clear + # it here. + # (Issues #2681, #2717) + if getattr(ctx, "is_fsdp2", False) and isinstance(weight, QuantizedTensorStorage): + clear_columnwise_cache(weight) + + # Prepare grad input tensor + # Note: Perform tensor-parallel communication + if ctx.ub_overlap_rs_dgrad: + dgrad = reduce_scatter_out + elif ctx.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) + elif ctx.parallel_mode == "column" and ctx.tp_size > 1: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") + dgrad = gemm_out + if ctx.sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, + ) + else: + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") + else: + dgrad = gemm_out # -------------------------------------------------- # Grad input tensor has been computed... diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dcbb9eaf93..f20a2cd605 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Linear API""" +import os from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Tuple, Union, List from functools import reduce @@ -474,19 +475,67 @@ def _linear_forward_impl( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ - nvtx_range_push(f"{nvtx_label}.gemm") - gemm_out, *_, reduce_scatter_out = general_gemm( - weightmat, - inputmat_total, - quantization_params=output_quantizer, - out_dtype=activation_dtype, - bias=bias, - use_split_accumulator=use_split_accumulator, - ub=ub_obj, - ub_type=ub_type, - extra_output=reduce_scatter_out, + _fp32_tp_reduce = ( + os.environ.get("NVTE_FP32_TP_REDUCE", "0") == "1" + and parallel_mode == "row" + and args.tp_size > 1 ) - nvtx_range_pop(f"{nvtx_label}.gemm") + _gemm_out_dtype = torch.float32 if _fp32_tp_reduce else activation_dtype + + _tp_invariant = ( + os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1" + and parallel_mode == "row" + and args.tp_size > 1 + ) + + if _tp_invariant: + # TP-invariant diagnostic: full GEMM matching TP=1 accumulation order. + assert not fp8, "NVTE_TP_INVARIANT_MODE does not support FP8" + nvtx_range_push(f"{nvtx_label}.tp_invariant_gemm") + + def allgather_along_dim(tensor, group, world_size, dim): + chunks = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(chunks, tensor.contiguous(), group=group) + return torch.cat(chunks, dim=dim) + + inputmat_gathered = allgather_along_dim( + inputmat_total, tp_group, tp_world_size, dim=-1, + ) # [tokens, hidden/TP] -> [tokens, hidden] + weight_gathered = allgather_along_dim( + weightmat, tp_group, tp_world_size, dim=-1, + ) # [out, hidden/TP] -> [out, hidden] + + input_2d = inputmat_gathered.reshape(-1, inputmat_gathered.shape[-1]) + out = general_gemm( + weight_gathered, input_2d, + out_dtype=activation_dtype, + bias=bias, + ) + if isinstance(out, tuple): + out = out[0] + out = out.reshape(inputmat_gathered.shape[:-1] + (weight_gathered.shape[0],)) + + # SP: scatter to per-rank chunk along sequence dim. + if sequence_parallel: + rank = torch.distributed.get_rank(tp_group) + out = out.chunk(tp_world_size, dim=0)[rank].contiguous() + + del inputmat_gathered, weight_gathered, input_2d + nvtx_range_pop(f"{nvtx_label}.tp_invariant_gemm") + else: + nvtx_range_push(f"{nvtx_label}.gemm") + gemm_out, *_, reduce_scatter_out = general_gemm( + weightmat, + inputmat_total, + quantization_params=output_quantizer, + out_dtype=_gemm_out_dtype, + bias=bias, + use_split_accumulator=use_split_accumulator, + ub=ub_obj, + ub_type=ub_type, + extra_output=reduce_scatter_out, + ) + nvtx_range_pop(f"{nvtx_label}.gemm") # ------------------------------------------------------ # Finished forward GEMM... # ------------------------------------------------------ @@ -502,22 +551,25 @@ def _linear_forward_impl( # Prepare output tensor # Note: Perform tensor-parallel communication # ------------------------------------------------------ - out = None - if ub_overlap_rs_fprop: - out = reduce_scatter_out - elif parallel_mode == "row" and args.tp_size > 1: - nvtx_range_push(f"{nvtx_label}.row_parallel_comm") - out = gemm_out - if sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif args.tensor_parallel: - if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) - else: - out, _ = allreduce(out, tp_group) - nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") - else: - out = gemm_out + if not _tp_invariant: + out = None + if ub_overlap_rs_fprop: + out = reduce_scatter_out + elif parallel_mode == "row" and args.tp_size > 1: + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") + out = gemm_out + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif args.tensor_parallel: + if symmetric_ar_type is not None: + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + else: + out, _ = allreduce(out, tp_group) + if _fp32_tp_reduce and out.dtype != activation_dtype: + out = out.to(activation_dtype) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") + else: + out = gemm_out # ------------------------------------------------------ # Output tensor is ready to return... # ------------------------------------------------------ @@ -971,62 +1023,112 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. # dgrad GEMM # Note: dx = dy * w + _fp32_tp_reduce_bwd = ( + os.environ.get("NVTE_FP32_TP_REDUCE", "0") == "1" + and bwd_args.parallel_mode == "column" + and bwd_args.tp_size > 1 + ) + _dgrad_out_dtype = torch.float32 if _fp32_tp_reduce_bwd else bwd_args.activation_dtype - nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight_fp8 - if bwd_args.backward_override == "dequantized": - if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype) - else: - weight_for_dgrad = cast_if_needed(weight_for_dgrad, bwd_args.activation_dtype) - elif bwd_args.backward_override == "high_precision": - weight_for_dgrad = saved_weight - if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype) - gemm_out, *_, reduce_scatter_out = general_gemm( - weight_for_dgrad, - grad_output, - layout="NN", - grad=True, - quantization_params=grad_input_quantizer, - out=gemm_out, - out_dtype=bwd_args.activation_dtype, - use_split_accumulator=use_split_accumulator, - ub=ub_obj_dgrad, - ub_type=ub_type_dgrad, - extra_output=reduce_scatter_out, - bulk_overlap=bwd_args.ub_bulk_dgrad, + _tp_invariant_bwd = ( + os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1" + and bwd_args.parallel_mode == "column" + and bwd_args.tp_size > 1 ) - nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") - - # FSDP2 only handles deallocation all-gathered weights that it allocates. - # Columnwise data is derived from rowwise data after allgather for fp8 - # and 2d block-scaled weights in TE managed memory. So we need to clear - # it here. - # (Issues #2681, #2717) - if bwd_args.is_fsdp2 and isinstance(weight_fp8, QuantizedTensorStorage): - clear_columnwise_cache(weight_fp8) - - # Prepare grad input tensor - # Note: Perform tensor-parallel communication - if bwd_args.ub_overlap_rs_dgrad: - dgrad = reduce_scatter_out - elif bwd_args.ub_bulk_wgrad: - dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) - elif bwd_args.parallel_mode == "column" and bwd_args.tp_size > 1: - nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") - dgrad = gemm_out - if bwd_args.sequence_parallel: - dgrad, dgrad_work = reduce_scatter_along_first_dim( - dgrad, - bwd_args.tp_group, - async_op=True, + + if _tp_invariant_bwd: + # TP-invariant diagnostic: full dgrad GEMM matching TP=1 accumulation. + assert not bwd_args.fp8, "NVTE_TP_INVARIANT_MODE does not support FP8" + nvtx_range_push(f"{nvtx_label}.tp_invariant_dgrad") + + def allgather_along_dim(tensor, group, world_size, dim): + chunks = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather( + chunks, tensor.contiguous(), group=group, ) - else: - dgrad, dgrad_work = allreduce(dgrad, bwd_args.tp_group, async_op=True) - nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") + return torch.cat(chunks, dim=dim) + + grad_output_gathered = allgather_along_dim( + grad_output, bwd_args.tp_group, bwd_args.tp_size, dim=-1, + ) # [tokens, out/TP] -> [tokens, out] + weight_gathered = allgather_along_dim( + weight_fp8, bwd_args.tp_group, bwd_args.tp_size, dim=0, + ) # [out/TP, in] -> [out, in] + + grad_output_2d = grad_output_gathered.reshape( + -1, grad_output_gathered.shape[-1], + ) + dgrad = general_gemm( + weight_gathered, grad_output_2d, + layout="NN", grad=True, + out_dtype=bwd_args.activation_dtype, + ) + if isinstance(dgrad, tuple): + dgrad = dgrad[0] + + # SP: scatter to per-rank chunk along sequence dim. + if bwd_args.sequence_parallel: + rank = torch.distributed.get_rank(bwd_args.tp_group) + dgrad = dgrad.chunk(bwd_args.tp_size, dim=0)[rank].contiguous() + + del grad_output_gathered, weight_gathered, grad_output_2d + nvtx_range_pop(f"{nvtx_label}.tp_invariant_dgrad") else: - dgrad = gemm_out + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight_fp8 + if bwd_args.backward_override == "dequantized": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, bwd_args.activation_dtype) + elif bwd_args.backward_override == "high_precision": + weight_for_dgrad = saved_weight + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype) + gemm_out, *_, reduce_scatter_out = general_gemm( + weight_for_dgrad, + grad_output, + layout="NN", + grad=True, + quantization_params=grad_input_quantizer, + out=gemm_out, + out_dtype=_dgrad_out_dtype, + use_split_accumulator=use_split_accumulator, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=reduce_scatter_out, + bulk_overlap=bwd_args.ub_bulk_dgrad, + ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") + + # FSDP2 only handles deallocation all-gathered weights that it allocates. + # Columnwise data is derived from rowwise data after allgather for fp8 + # and 2d block-scaled weights in TE managed memory. So we need to clear + # it here. + # (Issues #2681, #2717) + if bwd_args.is_fsdp2 and isinstance(weight_fp8, QuantizedTensorStorage): + clear_columnwise_cache(weight_fp8) + + # Prepare grad input tensor + # Note: Perform tensor-parallel communication + if bwd_args.ub_overlap_rs_dgrad: + dgrad = reduce_scatter_out + elif bwd_args.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) + elif bwd_args.parallel_mode == "column" and bwd_args.tp_size > 1: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") + dgrad = gemm_out + if bwd_args.sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + bwd_args.tp_group, + async_op=True, + ) + else: + dgrad, dgrad_work = allreduce(dgrad, bwd_args.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") + else: + dgrad = gemm_out # -------------------------------------------------- # Grad input tensor has been computed... From da8c6a525b9d48e3e7ba5e49cf5967edbb0ee8be Mon Sep 17 00:00:00 2001 From: Jinze Xue Date: Tue, 12 May 2026 20:10:47 -0700 Subject: [PATCH 2/4] Modularize TP-invariant helpers + add distributed unit test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract row-parallel fwd / column-parallel dgrad GEMMs to module/_tp_invariant.py (mirrors _common.py precedent); main-file call sites become helper calls. - Drop NVTE_FP32_TP_REDUCE (separable feature; ~15 lines removed). - Add tests/pytorch/distributed/{test,run}_tp_invariant.py: 20 cases covering Linear with/without NVTE_TP_INVARIANT_MODE × parallel_mode × sp × tp_size, plus LayerNormLinear partition_stride=2 (SwiGLU FC1 deinterleave). Reuses TestDistributedLinearBase from run_numerics_exact.py (extended with partition_stride support). Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Jinze Xue --- .../pytorch/distributed/run_numerics_exact.py | 60 +++++- tests/pytorch/distributed/run_tp_invariant.py | 162 +++++++++++++++ .../pytorch/distributed/test_tp_invariant.py | 90 +++++++++ .../pytorch/module/_tp_invariant.py | 189 ++++++++++++++++++ .../pytorch/module/layernorm_linear.py | 76 ++----- transformer_engine/pytorch/module/linear.py | 104 ++-------- 6 files changed, 528 insertions(+), 153 deletions(-) create mode 100644 tests/pytorch/distributed/run_tp_invariant.py create mode 100644 tests/pytorch/distributed/test_tp_invariant.py create mode 100644 transformer_engine/pytorch/module/_tp_invariant.py diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 15ae2dae63..b99e1f9ac9 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -228,6 +228,35 @@ def _get_mean_abs_relative_error(a, b): error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) return torch.mean(error) + @staticmethod + def _shard_strided(t, tp_size, rank, dim, stride): + """Shard tensor along ``dim`` for given partition stride. + + stride == 1: naive split (existing behavior, equivalent to _shard_tensor(t, tp_size, dim)[rank]). + stride > 1: MLM-golden interleaved shard from + ``megatron/core/tensor_parallel/layers.py:_initialize_affine_weight_cpu`` — + split into ``stride * tp_size`` chunks, take ``chunks[rank::tp_size]``, concat. + This produces the per-rank ``[block_0_chunk | block_1_chunk | ...]`` layout used + by gated MLPs (SwiGLU FC1 with stride=2). + """ + chunk_count = tp_size * stride + chunk = t.shape[dim] // chunk_count + chunks = list(torch.split(t, chunk, dim=dim)) + return torch.cat(chunks[rank::tp_size], dim=dim) + + @staticmethod + def _stamp_partition_stride(layer, partition_stride, tp_size): + """Stamp ``partition_stride`` on layer.weight/bias (mirrors MLM-side TE extension). + + TE itself only ever sets stride=1; partition_stride>1 is an MLM-side concept + that the downstream NVTE_TP_INVARIANT_MODE deinterleave path reads from + ``ctx.partition_stride``. No-op when stride==1 or tp_size==1. + """ + if partition_stride > 1 and tp_size > 1: + setattr(layer.weight, 'partition_stride', partition_stride) + if layer.bias is not None: + setattr(layer.bias, 'partition_stride', partition_stride) + @classmethod def run_linear_preprocess_parallel( cls, @@ -239,14 +268,14 @@ def run_linear_preprocess_parallel( sequence_parallel=False, tp_size=1, rank=0, + stride=1, ): if tp_size > 1: if parallel_mode == "column": - # split w in N dim, which should be axis 0 - w = cls._shard_tensor(w, tp_size, 0)[rank] - bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None - # split gradient in N dim, which should be axis 1 - gradient = cls._shard_tensor(gradient, tp_size, 1)[rank] + # split w in N dim (axis 0), gradient in N dim (axis 1); stride>1 → interleave. + w = cls._shard_strided(w, tp_size, rank, dim=0, stride=stride) + bias = cls._shard_strided(bias, tp_size, rank, dim=0, stride=stride) if bias is not None else None + gradient = cls._shard_strided(gradient, tp_size, rank, dim=1, stride=stride) if sequence_parallel: # split x in M dim, which should be axis 0 x = cls._shard_tensor(x, tp_size, 0)[rank] @@ -396,10 +425,16 @@ def run_linear( run_num_steps=1, enable_weight_cache=False, fuse_wgrad_accumulation=False, + partition_stride=1, ): """ If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with the reference single GPU run. + + ``partition_stride > 1`` enables gated-MLP-style sharding (e.g., SwiGLU FC1): + the weight (and gradient, bias) are sharded into interleaved per-rank chunks + per MLM's `_initialize_affine_weight_cpu`, and ``partition_stride`` is stamped + on the constructed layer's weight (mirroring what MLM does post-construction). """ # clone inputs and move to current device # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] @@ -412,7 +447,8 @@ def run_linear( # If Model parallel: split inputs for a given rank x, w, bias, gradient = cls.run_linear_preprocess_parallel( - x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank, + stride=partition_stride, ) # set data types @@ -438,6 +474,8 @@ def run_linear( if bias is not None: layer.bias.copy_(bias) + cls._stamp_partition_stride(layer, partition_stride, tp_size) + if fuse_wgrad_accumulation: assert ( run_num_steps > 1 @@ -607,10 +645,15 @@ def run_layernorm_linear( enable_weight_cache=False, LayerNormLinearClass=te.LayerNormLinear, normalization="LayerNorm", + partition_stride=1, ): """ If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with the reference single GPU run. + + ``partition_stride > 1`` enables gated-MLP-style sharding (SwiGLU FC1, stride=2): + weight + gradient are interleaved per-rank via MLM's golden algorithm, and + ``partition_stride`` is stamped on the constructed layer's weight. """ # clone inputs and move to current device # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] @@ -623,7 +666,8 @@ def run_layernorm_linear( # If Model parallel: split inputs for a given rank x, w, bias, gradient = cls.run_linear_preprocess_parallel( - x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank, + stride=partition_stride, ) # set data types @@ -652,6 +696,8 @@ def run_layernorm_linear( if bias is not None: layer.bias.copy_(bias) + cls._stamp_partition_stride(layer, partition_stride, tp_size) + # Run one step y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) diff --git a/tests/pytorch/distributed/run_tp_invariant.py b/tests/pytorch/distributed/run_tp_invariant.py new file mode 100644 index 0000000000..ed3898a4b1 --- /dev/null +++ b/tests/pytorch/distributed/run_tp_invariant.py @@ -0,0 +1,162 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""NVTE_TP_INVARIANT_MODE distributed test body. Launched via test_tp_invariant.py. + +One invocation runs ONE (parallel_mode, sequence_parallel, expect_bitwise) +combination per pytest parametrize axis: + + expect_bitwise=True (NVTE_TP_INVARIANT_MODE=1, with_tp_invariant): + TP=N == TP=1 bit-for-bit. + expect_bitwise=False (NVTE_TP_INVARIANT_MODE=0, without_tp_invariant): + TP=N != TP=1 (stock TP isn't bitwise; guards + with_tp_invariant against trivial-pass). + +Wgrad is intentionally not compared — patches gate only the forward +(row-parallel) and dgrad (column-parallel) paths. + +Reuses TestDistributedLinearBase from run_numerics_exact for sharding + gather. + +LayerNormLinear with partition_stride=2 (SwiGLU FC1 deinterleave) is covered +by ``_check_tp_invariance_deinterleave`` — positive only (no without_tp_invariant +variant), because stock TP dgrad uses the same products as TP=1 just in a +different accumulation order, so bit-difference is layout-dependent (flaky). +""" + +import argparse +import datetime +import os +import sys + +import run_numerics_exact as rne +import torch +import torch.distributed as dist +from run_numerics_exact import ( + TestDistributedLayerNormLinearBase, + TestDistributedLinearBase, + dist_print, +) + +BATCH, HIDDEN, OUT = 16, 256, 128 +DTYPE = torch.bfloat16 + + +def _run_linear(parallel_mode, sequence_parallel): + """Run TP=1 reference and TP=N for given config; return relevant output pair. + + For parallel_mode='row' returns (y_ref, y_tp). For 'column' returns + (dgrad_ref, dgrad_tp). Both are full (gathered) tensors regardless of + sharding, suitable for direct bitwise comparison. + """ + x, w, bias, gradient = TestDistributedLinearBase._prepare_data( + BATCH, HIDDEN, OUT, use_bias=False, seed=42, dtype=DTYPE, + ) + y_ref, dgrad_ref, _, _ = TestDistributedLinearBase.run_linear( + x, w, bias, gradient, + parallel_mode=None, sequence_parallel=False, + tp_group=None, tp_size=1, rank=0, + ) + y_tp, dgrad_tp, _, _ = TestDistributedLinearBase.run_linear( + x, w, bias, gradient, + parallel_mode=parallel_mode, sequence_parallel=sequence_parallel, + tp_group=rne.NCCL_WORLD, tp_size=rne.WORLD_SIZE, rank=rne.WORLD_RANK, + ) + if parallel_mode == "row": + return y_ref, y_tp + return dgrad_ref, dgrad_tp + + +def _check_tp_invariance(parallel_mode, sequence_parallel, expect_bitwise): + """Run one check; assert bitwise (with_tp_invariant) or non-bitwise (without_tp_invariant).""" + os.environ["NVTE_TP_INVARIANT_MODE"] = "1" if expect_bitwise else "0" + ref, tp = _run_linear(parallel_mode, sequence_parallel) + + if rne.WORLD_RANK != 0: + return + + kind = "fwd" if parallel_mode == "row" else "dgrad" + label = f"{parallel_mode}-parallel {kind} sp={int(sequence_parallel)}" + + if expect_bitwise: + torch.testing.assert_close( + tp, ref, atol=0, rtol=0, + msg=f"{label} not bitwise under NVTE_TP_INVARIANT_MODE=1", + ) + dist_print(f"[with_tp_invariant ] {label}: TP=1 ≡ TP={rne.WORLD_SIZE} bitwise") + else: + assert not torch.equal(tp, ref), ( + f"without_tp_invariant: {label} unexpectedly bitwise under NVTE_TP_INVARIANT_MODE=0" + ) + dist_print(f"[without_tp_invariant] {label}: TP=1 ≠ TP={rne.WORLD_SIZE} (as expected)") + + +def _check_tp_invariance_deinterleave(sequence_parallel): + """LayerNormLinear column-parallel + partition_stride=2 (SwiGLU FC1) TP-invariance. + + Uses MLM's golden stride=2 sharding (added to ``TestDistributedLayerNormLinearBase``) + to construct per-rank interleaved weight; verifies our deinterleave correctly inverts + it so TP=N dgrad bitwise matches the TP=1 reference.""" + os.environ["NVTE_TP_INVARIANT_MODE"] = "1" + x, w, _, g = TestDistributedLinearBase._prepare_data( + BATCH, HIDDEN, OUT, use_bias=False, seed=42, dtype=DTYPE, + ) + _, _, dgrad_ref, _, _ = TestDistributedLayerNormLinearBase.run_layernorm_linear( + x, w, None, g, parallel_mode=None, sequence_parallel=False, + tp_group=None, tp_size=1, rank=0, partition_stride=1, + ) + _, _, dgrad_tp, _, _ = TestDistributedLayerNormLinearBase.run_layernorm_linear( + x, w, None, g, parallel_mode="column", sequence_parallel=sequence_parallel, + tp_group=rne.NCCL_WORLD, tp_size=rne.WORLD_SIZE, rank=rne.WORLD_RANK, + partition_stride=2, + ) + if rne.WORLD_RANK != 0: + return + label = f"LN-Linear stride=2 sp={int(sequence_parallel)}" + torch.testing.assert_close(dgrad_tp, dgrad_ref, atol=0, rtol=0, + msg=f"{label}: not TP-invariant") + dist_print(f"{label}: TP=1 ≡ TP={rne.WORLD_SIZE} bitwise via deinterleave") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--check-type", choices=["linear", "deinterleave"], default="linear") + parser.add_argument("--parallel-mode", choices=["row", "column"]) + parser.add_argument("--sequence-parallel", action="store_true") + parser.add_argument("--expect-bitwise", type=int, choices=[0, 1]) + args = parser.parse_args() + + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + assert world_size <= torch.cuda.device_count() + + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", rank=rank, world_size=world_size, init_method="env://", + timeout=datetime.timedelta(seconds=60), + device_id=torch.device(f"cuda:{local_rank}"), + ) + + rne.WORLD_RANK, rne.WORLD_SIZE = rank, world_size + rne.NCCL_WORLD = dist.new_group(backend="nccl") + + if args.check_type == "linear": + assert args.parallel_mode is not None, "--parallel-mode required for linear check" + assert args.expect_bitwise is not None, "--expect-bitwise required for linear check" + _check_tp_invariance( + parallel_mode=args.parallel_mode, + sequence_parallel=args.sequence_parallel, + expect_bitwise=bool(args.expect_bitwise), + ) + else: # deinterleave + _check_tp_invariance_deinterleave(sequence_parallel=args.sequence_parallel) + + dist.destroy_process_group() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pytorch/distributed/test_tp_invariant.py b/tests/pytorch/distributed/test_tp_invariant.py new file mode 100644 index 0000000000..7d0e6d928b --- /dev/null +++ b/tests/pytorch/distributed/test_tp_invariant.py @@ -0,0 +1,90 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""pytest entry for the NVTE_TP_INVARIANT_MODE distributed correctness test. + +Launches ``run_tp_invariant.py`` under torchrun with per-case CLI args and asserts +the subprocess returns 0. The actual bitwise checks live in ``run_tp_invariant.py``. + +Test matrix (per tp_size): + parallel_mode ∈ {row, column} + sequence_parallel ∈ {False, True} + expect_bitwise ∈ {True (with_tp_invariant, patches on), + False (without_tp_invariant, patches off)} +""" + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +if torch.cuda.device_count() < 2: + pytest.skip( + "TP-invariant test requires at least 2 GPUs.", + allow_module_level=True, + ) + + +TEST_ROOT = Path(__file__).parent.resolve() + + +def _tp_sizes(): + """TP sizes worth exercising on this node (2 and, if available, 4).""" + n = torch.cuda.device_count() + sizes = [2] + if n >= 4: + sizes.append(4) + return sizes + + +@pytest.mark.parametrize("tp_size", _tp_sizes()) +@pytest.mark.parametrize("parallel_mode", ["row", "column"]) +@pytest.mark.parametrize("sequence_parallel", [False, True]) +@pytest.mark.parametrize("expect_bitwise", [True, False], + ids=["with_tp_invariant", "without_tp_invariant"]) +def test_tp_invariant(tp_size, parallel_mode, sequence_parallel, expect_bitwise): + """One TP-invariant correctness check per parameter combination. + + expect_bitwise=True → NVTE_TP_INVARIANT_MODE=1, TP=N must equal TP=1 bit-for-bit. + expect_bitwise=False → NVTE_TP_INVARIANT_MODE=0, TP=N must DIFFER from TP=1 + (without_tp_invariant, guards against trivial-pass).""" + cmd = [ + "torchrun", + f"--nproc_per_node={tp_size}", + str(TEST_ROOT / "run_tp_invariant.py"), + "--check-type", "linear", + "--parallel-mode", parallel_mode, + "--expect-bitwise", str(int(expect_bitwise)), + ] + if sequence_parallel: + cmd.append("--sequence-parallel") + + result = subprocess.run(cmd, env=os.environ, check=False) + assert result.returncode == 0, ( + f"run_tp_invariant.py failed: tp_size={tp_size}, parallel_mode={parallel_mode}, " + f"sequence_parallel={sequence_parallel}, expect_bitwise={expect_bitwise} " + f"(returncode={result.returncode})" + ) + + +@pytest.mark.parametrize("tp_size", _tp_sizes()) +@pytest.mark.parametrize("sequence_parallel", [False, True]) +def test_tp_invariant_deinterleave(tp_size, sequence_parallel): + """LayerNormLinear column-parallel + partition_stride=2 (SwiGLU FC1 layout): dgrad bitwise + matches TP=1 reference. Uses MLM's golden stride=2 sharding to construct per-rank + interleaved weight; verifies our deinterleave correctly inverts it.""" + cmd = [ + "torchrun", + f"--nproc_per_node={tp_size}", + str(TEST_ROOT / "run_tp_invariant.py"), + "--check-type", "deinterleave", + ] + if sequence_parallel: + cmd.append("--sequence-parallel") + result = subprocess.run(cmd, env=os.environ, check=False) + assert result.returncode == 0, ( + f"deinterleave failed: tp_size={tp_size}, sequence_parallel={sequence_parallel}" + ) diff --git a/transformer_engine/pytorch/module/_tp_invariant.py b/transformer_engine/pytorch/module/_tp_invariant.py new file mode 100644 index 0000000000..f4481531b3 --- /dev/null +++ b/transformer_engine/pytorch/module/_tp_invariant.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Internal helpers for the NVTE_TP_INVARIANT_MODE code path. + +Gated on ``NVTE_TP_INVARIANT_MODE=1`` and used by ``linear.py`` and +``layernorm_linear.py``. With the env var unset (default), these helpers are +unreachable and stock GEMM paths are taken unchanged. + +The TP-invariant path performs the full-K (or full-out) GEMM after all-gathering +sharded operands across the TP group. Result: bit-identical numerics across +TP=1/2/4/... because the underlying GEMM K-dimension accumulation order is fixed +regardless of how the operands were sharded. + +Limitations: +- FP8 not supported (callers should assert ``not fp8`` before calling). +- Trades compute for invariance (gathered operands + full GEMM). Off by default. +""" + +from typing import Optional + +import torch + +from ..cpp_extensions import general_gemm +from ..utils import nvtx_range_pop, nvtx_range_push + +__all__ = [ + "allgather_along_dim", + "tp_invariant_row_parallel_gemm", + "tp_invariant_column_parallel_dgrad", +] + + +def allgather_along_dim( + tensor: torch.Tensor, + group, + world_size: int, + dim: int, +) -> torch.Tensor: + """All-gather ``tensor`` from every rank in ``group`` and concat along ``dim``.""" + chunks = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(chunks, tensor.contiguous(), group=group) + return torch.cat(chunks, dim=dim) + + +def tp_invariant_row_parallel_gemm( + weightmat: torch.Tensor, + inputmat_total: torch.Tensor, + bias: Optional[torch.Tensor], + tp_group, + tp_size: int, + sequence_parallel: bool, + activation_dtype: torch.dtype, + nvtx_label: str = "tp_invariant_gemm", +) -> torch.Tensor: + """Row-parallel forward GEMM with TP-invariant numerics. + + All-gathers input + weight along the contracted (K) dim, runs the full GEMM + (matching TP=1 accumulation order), then scatters along the sequence dim + when ``sequence_parallel=True``. + + Args: + weightmat: Local weight shard of shape ``[out, hidden/TP]``. + inputmat_total: Local input shard of shape ``[..., hidden/TP]``. + bias: Optional bias of shape ``[out]``. + tp_group: TP process group. + tp_size: Size of the TP group. + sequence_parallel: If True, scatter the full output along dim 0. + activation_dtype: Output dtype (matches stock GEMM behavior). + nvtx_label: NVTX range label. + + Returns: + Output of shape ``[..., out]`` (full sequence) or ``[.../TP, out]`` + (sequence-scattered when ``sequence_parallel=True``). + """ + nvtx_range_push(nvtx_label) + + inputmat_gathered = allgather_along_dim(inputmat_total, tp_group, tp_size, dim=-1) + weight_gathered = allgather_along_dim(weightmat, tp_group, tp_size, dim=-1) + + input_2d = inputmat_gathered.reshape(-1, inputmat_gathered.shape[-1]) + out = general_gemm( + weight_gathered, input_2d, + out_dtype=activation_dtype, + bias=bias, + ) + if isinstance(out, tuple): + out = out[0] + out = out.reshape(inputmat_gathered.shape[:-1] + (weight_gathered.shape[0],)) + + if sequence_parallel: + rank = torch.distributed.get_rank(tp_group) + out = out.chunk(tp_size, dim=0)[rank].contiguous() + + nvtx_range_pop(nvtx_label) + return out + + +def tp_invariant_column_parallel_dgrad( + weight: torch.Tensor, + grad_output: torch.Tensor, + tp_group, + tp_size: int, + sequence_parallel: bool, + activation_dtype: torch.dtype, + partition_stride: int = 1, + nvtx_label: str = "tp_invariant_dgrad", +) -> torch.Tensor: + """Column-parallel backward dgrad with TP-invariant numerics. + + All-gathers grad_output (along out dim) and weight (along out dim), runs the + full dgrad GEMM, then scatters along the sequence dim under SP. + + For gated MLPs (e.g. SwiGLU FC1 where each rank holds interleaved [gate|val] + halves), ``partition_stride > 1`` triggers a deinterleave step to recover + the TP=1 layout [gate_all | val_all] before the full GEMM: + + Per-rank layout: [g_0 | v_0 | g_1 | v_1 | ... | g_{TP-1} | v_{TP-1}] + TP=1 native layout: [g_0 | g_1 | ... | g_{TP-1} | v_0 | v_1 | ... | v_{TP-1}] + + For non-gated layers (QKV etc., partition_stride=1) the naive all-gather + already matches the TP=1 ordering. + + Args: + weight: Local weight shard of shape ``[out/TP, in]``. + grad_output: Local grad shard of shape ``[..., out/TP]``. + tp_group: TP process group. + tp_size: Size of TP group. + sequence_parallel: If True, scatter dgrad along dim 0. + activation_dtype: Output dtype. + partition_stride: >1 triggers the deinterleave for gated MLPs. + nvtx_label: NVTX range label. + + Returns: + dgrad of shape ``[..., in]`` or sequence-scattered. + """ + nvtx_range_push(nvtx_label) + + grad_output_gathered = allgather_along_dim(grad_output, tp_group, tp_size, dim=-1) + weight_gathered = allgather_along_dim(weight, tp_group, tp_size, dim=0) + + if partition_stride > 1: + # Deinterleave gated [gate|val] halves to TP=1 [gate_all | val_all]. + # Currently only the 2-way gated split (SwiGLU FC1 layout) is handled. + assert partition_stride == 2, ( + f"deinterleave only supports partition_stride=2 (gated halve); got {partition_stride}" + ) + chunk_sz = weight.shape[0] # out_features per rank + half = chunk_sz // 2 + first_w = [ + weight_gathered[i * chunk_sz : i * chunk_sz + half] + for i in range(tp_size) + ] + second_w = [ + weight_gathered[i * chunk_sz + half : (i + 1) * chunk_sz] + for i in range(tp_size) + ] + weight_gathered = torch.cat(first_w + second_w, dim=0) + + g_dim = grad_output_gathered.shape[-1] // tp_size + g_half = g_dim // 2 + first_g = [ + grad_output_gathered[..., i * g_dim : i * g_dim + g_half] + for i in range(tp_size) + ] + second_g = [ + grad_output_gathered[..., i * g_dim + g_half : (i + 1) * g_dim] + for i in range(tp_size) + ] + grad_output_gathered = torch.cat(first_g + second_g, dim=-1) + + grad_output_2d = grad_output_gathered.reshape( + -1, grad_output_gathered.shape[-1], + ) + dgrad = general_gemm( + weight_gathered, grad_output_2d, + layout="NN", grad=True, + out_dtype=activation_dtype, + ) + if isinstance(dgrad, tuple): + dgrad = dgrad[0] + + if sequence_parallel: + rank = torch.distributed.get_rank(tp_group) + dgrad = dgrad.chunk(tp_size, dim=0)[rank].contiguous() + + nvtx_range_pop(nvtx_label) + return dgrad diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 78bae7dea8..6e57a065b1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -28,6 +28,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from ._tp_invariant import tp_invariant_column_parallel_dgrad from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( assert_dim_for_fp8_exec, @@ -374,15 +375,9 @@ def forward( # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T - # No _tp_invariant_fwd needed here: LayerNormLinear is column-parallel, - # so forward GEMM K=hidden_size is constant across TP (invariant by construction). - # Only row-parallel forward (linear.py) needs the invariant path (K=in/TP varies). + # No NVTE_TP_INVARIANT_MODE branch: LayerNormLinear is column-parallel only, + # so forward GEMM K=hidden_size does not vary with TP (invariant by construction). # ------------------------------------------------------ - if os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1" and parallel_mode == "row": - assert False, ( - "NVTE_TP_INVARIANT_MODE row-parallel forward is not implemented in " - "layernorm_linear.py. Use linear.py for row-parallel layers." - ) nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weightmat, @@ -817,64 +812,17 @@ def backward( ) if _tp_invariant_bwd: - # TP-invariant diagnostic: full dgrad GEMM matching TP=1 accumulation. assert not ctx.fp8, "NVTE_TP_INVARIANT_MODE does not support FP8" - nvtx_range_push(f"{nvtx_label}.tp_invariant_dgrad") - - def allgather_along_dim(tensor, group, world_size, dim): - chunks = [torch.empty_like(tensor) for _ in range(world_size)] - torch.distributed.all_gather( - chunks, tensor.contiguous(), group=group, - ) - return torch.cat(chunks, dim=dim) - - grad_output_gathered = allgather_along_dim( - grad_output, ctx.tp_group, ctx.tp_size, dim=-1, - ) # [tokens, out/TP] -> [tokens, out] - weight_gathered = allgather_along_dim( - weight, ctx.tp_group, ctx.tp_size, dim=0, - ) # [out/TP, in] -> [out, in] - - # Deinterleave gathered tensors to match TP=1 K-dimension ordering. - # Only needed for gated linear units (partition_stride > 1, e.g. SwiGLU FC1) - # where each rank stores interleaved [gate_i | value_i]. - # After all-gather: [gate_0|val_0 | gate_1|val_1 | ...]. - # TP=1 layout is [gate_all | val_all]. Reorder to match. - # - # For non-gated layers (partition_stride == 1, e.g. QKV), each rank - # stores contiguous GQA groups, and the naive all-gather already - # produces the correct TP=1 ordering. Deinterleaving would corrupt it. - if ctx.partition_stride > 1: - chunk_sz = weight.shape[0] # out_features per rank - half = chunk_sz // 2 - first_w = [weight_gathered[i * chunk_sz : i * chunk_sz + half] for i in range(ctx.tp_size)] - second_w = [weight_gathered[i * chunk_sz + half : (i + 1) * chunk_sz] for i in range(ctx.tp_size)] - weight_gathered = torch.cat(first_w + second_w, dim=0) - - g_dim = grad_output_gathered.shape[-1] // ctx.tp_size - g_half = g_dim // 2 - first_g = [grad_output_gathered[..., i * g_dim : i * g_dim + g_half] for i in range(ctx.tp_size)] - second_g = [grad_output_gathered[..., i * g_dim + g_half : (i + 1) * g_dim] for i in range(ctx.tp_size)] - grad_output_gathered = torch.cat(first_g + second_g, dim=-1) - - grad_output_2d = grad_output_gathered.reshape( - -1, grad_output_gathered.shape[-1], + dgrad = tp_invariant_column_parallel_dgrad( + weight=weight, + grad_output=grad_output, + tp_group=ctx.tp_group, + tp_size=ctx.tp_size, + sequence_parallel=ctx.sequence_parallel, + activation_dtype=ctx.activation_dtype, + partition_stride=ctx.partition_stride, + nvtx_label=f"{nvtx_label}.tp_invariant_dgrad", ) - dgrad = general_gemm( - weight_gathered, grad_output_2d, - layout="NN", grad=True, - out_dtype=ctx.activation_dtype, - ) - if isinstance(dgrad, tuple): - dgrad = dgrad[0] - - # SP: scatter to per-rank chunk along sequence dim. - if ctx.sequence_parallel: - rank = torch.distributed.get_rank(ctx.tp_group) - dgrad = dgrad.chunk(ctx.tp_size, dim=0)[rank].contiguous() - - del grad_output_gathered, weight_gathered, grad_output_2d - nvtx_range_pop(f"{nvtx_label}.tp_invariant_dgrad") else: nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f20a2cd605..c2c2aff697 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -29,6 +29,10 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore +from ._tp_invariant import ( + tp_invariant_column_parallel_dgrad, + tp_invariant_row_parallel_gemm, +) from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, @@ -475,13 +479,6 @@ def _linear_forward_impl( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ - _fp32_tp_reduce = ( - os.environ.get("NVTE_FP32_TP_REDUCE", "0") == "1" - and parallel_mode == "row" - and args.tp_size > 1 - ) - _gemm_out_dtype = torch.float32 if _fp32_tp_reduce else activation_dtype - _tp_invariant = ( os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1" and parallel_mode == "row" @@ -489,46 +486,24 @@ def _linear_forward_impl( ) if _tp_invariant: - # TP-invariant diagnostic: full GEMM matching TP=1 accumulation order. assert not fp8, "NVTE_TP_INVARIANT_MODE does not support FP8" - nvtx_range_push(f"{nvtx_label}.tp_invariant_gemm") - - def allgather_along_dim(tensor, group, world_size, dim): - chunks = [torch.empty_like(tensor) for _ in range(world_size)] - torch.distributed.all_gather(chunks, tensor.contiguous(), group=group) - return torch.cat(chunks, dim=dim) - - inputmat_gathered = allgather_along_dim( - inputmat_total, tp_group, tp_world_size, dim=-1, - ) # [tokens, hidden/TP] -> [tokens, hidden] - weight_gathered = allgather_along_dim( - weightmat, tp_group, tp_world_size, dim=-1, - ) # [out, hidden/TP] -> [out, hidden] - - input_2d = inputmat_gathered.reshape(-1, inputmat_gathered.shape[-1]) - out = general_gemm( - weight_gathered, input_2d, - out_dtype=activation_dtype, + out = tp_invariant_row_parallel_gemm( + weightmat=weightmat, + inputmat_total=inputmat_total, bias=bias, + tp_group=tp_group, + tp_size=tp_world_size, + sequence_parallel=sequence_parallel, + activation_dtype=activation_dtype, + nvtx_label=f"{nvtx_label}.tp_invariant_gemm", ) - if isinstance(out, tuple): - out = out[0] - out = out.reshape(inputmat_gathered.shape[:-1] + (weight_gathered.shape[0],)) - - # SP: scatter to per-rank chunk along sequence dim. - if sequence_parallel: - rank = torch.distributed.get_rank(tp_group) - out = out.chunk(tp_world_size, dim=0)[rank].contiguous() - - del inputmat_gathered, weight_gathered, input_2d - nvtx_range_pop(f"{nvtx_label}.tp_invariant_gemm") else: nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weightmat, inputmat_total, quantization_params=output_quantizer, - out_dtype=_gemm_out_dtype, + out_dtype=activation_dtype, bias=bias, use_split_accumulator=use_split_accumulator, ub=ub_obj, @@ -565,8 +540,6 @@ def allgather_along_dim(tensor, group, world_size, dim): out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) else: out, _ = allreduce(out, tp_group) - if _fp32_tp_reduce and out.dtype != activation_dtype: - out = out.to(activation_dtype) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") else: out = gemm_out @@ -1023,13 +996,6 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. # dgrad GEMM # Note: dx = dy * w - _fp32_tp_reduce_bwd = ( - os.environ.get("NVTE_FP32_TP_REDUCE", "0") == "1" - and bwd_args.parallel_mode == "column" - and bwd_args.tp_size > 1 - ) - _dgrad_out_dtype = torch.float32 if _fp32_tp_reduce_bwd else bwd_args.activation_dtype - _tp_invariant_bwd = ( os.environ.get("NVTE_TP_INVARIANT_MODE", "0") == "1" and bwd_args.parallel_mode == "column" @@ -1037,42 +1003,16 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. ) if _tp_invariant_bwd: - # TP-invariant diagnostic: full dgrad GEMM matching TP=1 accumulation. assert not bwd_args.fp8, "NVTE_TP_INVARIANT_MODE does not support FP8" - nvtx_range_push(f"{nvtx_label}.tp_invariant_dgrad") - - def allgather_along_dim(tensor, group, world_size, dim): - chunks = [torch.empty_like(tensor) for _ in range(world_size)] - torch.distributed.all_gather( - chunks, tensor.contiguous(), group=group, - ) - return torch.cat(chunks, dim=dim) - - grad_output_gathered = allgather_along_dim( - grad_output, bwd_args.tp_group, bwd_args.tp_size, dim=-1, - ) # [tokens, out/TP] -> [tokens, out] - weight_gathered = allgather_along_dim( - weight_fp8, bwd_args.tp_group, bwd_args.tp_size, dim=0, - ) # [out/TP, in] -> [out, in] - - grad_output_2d = grad_output_gathered.reshape( - -1, grad_output_gathered.shape[-1], + dgrad = tp_invariant_column_parallel_dgrad( + weight=weight_fp8, + grad_output=grad_output, + tp_group=bwd_args.tp_group, + tp_size=bwd_args.tp_size, + sequence_parallel=bwd_args.sequence_parallel, + activation_dtype=bwd_args.activation_dtype, + nvtx_label=f"{nvtx_label}.tp_invariant_dgrad", ) - dgrad = general_gemm( - weight_gathered, grad_output_2d, - layout="NN", grad=True, - out_dtype=bwd_args.activation_dtype, - ) - if isinstance(dgrad, tuple): - dgrad = dgrad[0] - - # SP: scatter to per-rank chunk along sequence dim. - if bwd_args.sequence_parallel: - rank = torch.distributed.get_rank(bwd_args.tp_group) - dgrad = dgrad.chunk(bwd_args.tp_size, dim=0)[rank].contiguous() - - del grad_output_gathered, weight_gathered, grad_output_2d - nvtx_range_pop(f"{nvtx_label}.tp_invariant_dgrad") else: nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 @@ -1092,7 +1032,7 @@ def allgather_along_dim(tensor, group, world_size, dim): grad=True, quantization_params=grad_input_quantizer, out=gemm_out, - out_dtype=_dgrad_out_dtype, + out_dtype=bwd_args.activation_dtype, use_split_accumulator=use_split_accumulator, ub=ub_obj_dgrad, ub_type=ub_type_dgrad, From 580a369423a83b90ba898b838d9f66ea3f7cba17 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 03:21:32 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/distributed/run_numerics_exact.py | 28 +++++-- tests/pytorch/distributed/run_tp_invariant.py | 82 ++++++++++++++----- .../pytorch/distributed/test_tp_invariant.py | 23 ++++-- .../pytorch/module/_tp_invariant.py | 32 ++++---- .../pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/linear.py | 12 ++- 6 files changed, 125 insertions(+), 54 deletions(-) diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index b99e1f9ac9..5c3175ef3b 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -253,9 +253,9 @@ def _stamp_partition_stride(layer, partition_stride, tp_size): ``ctx.partition_stride``. No-op when stride==1 or tp_size==1. """ if partition_stride > 1 and tp_size > 1: - setattr(layer.weight, 'partition_stride', partition_stride) + setattr(layer.weight, "partition_stride", partition_stride) if layer.bias is not None: - setattr(layer.bias, 'partition_stride', partition_stride) + setattr(layer.bias, "partition_stride", partition_stride) @classmethod def run_linear_preprocess_parallel( @@ -274,7 +274,11 @@ def run_linear_preprocess_parallel( if parallel_mode == "column": # split w in N dim (axis 0), gradient in N dim (axis 1); stride>1 → interleave. w = cls._shard_strided(w, tp_size, rank, dim=0, stride=stride) - bias = cls._shard_strided(bias, tp_size, rank, dim=0, stride=stride) if bias is not None else None + bias = ( + cls._shard_strided(bias, tp_size, rank, dim=0, stride=stride) + if bias is not None + else None + ) gradient = cls._shard_strided(gradient, tp_size, rank, dim=1, stride=stride) if sequence_parallel: # split x in M dim, which should be axis 0 @@ -447,7 +451,14 @@ def run_linear( # If Model parallel: split inputs for a given rank x, w, bias, gradient = cls.run_linear_preprocess_parallel( - x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank, + x, + w, + bias, + gradient, + parallel_mode, + sequence_parallel, + tp_size, + rank, stride=partition_stride, ) @@ -666,7 +677,14 @@ def run_layernorm_linear( # If Model parallel: split inputs for a given rank x, w, bias, gradient = cls.run_linear_preprocess_parallel( - x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank, + x, + w, + bias, + gradient, + parallel_mode, + sequence_parallel, + tp_size, + rank, stride=partition_stride, ) diff --git a/tests/pytorch/distributed/run_tp_invariant.py b/tests/pytorch/distributed/run_tp_invariant.py index ed3898a4b1..56a7bf89be 100644 --- a/tests/pytorch/distributed/run_tp_invariant.py +++ b/tests/pytorch/distributed/run_tp_invariant.py @@ -52,17 +52,34 @@ def _run_linear(parallel_mode, sequence_parallel): sharding, suitable for direct bitwise comparison. """ x, w, bias, gradient = TestDistributedLinearBase._prepare_data( - BATCH, HIDDEN, OUT, use_bias=False, seed=42, dtype=DTYPE, + BATCH, + HIDDEN, + OUT, + use_bias=False, + seed=42, + dtype=DTYPE, ) y_ref, dgrad_ref, _, _ = TestDistributedLinearBase.run_linear( - x, w, bias, gradient, - parallel_mode=None, sequence_parallel=False, - tp_group=None, tp_size=1, rank=0, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, ) y_tp, dgrad_tp, _, _ = TestDistributedLinearBase.run_linear( - x, w, bias, gradient, - parallel_mode=parallel_mode, sequence_parallel=sequence_parallel, - tp_group=rne.NCCL_WORLD, tp_size=rne.WORLD_SIZE, rank=rne.WORLD_RANK, + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=rne.NCCL_WORLD, + tp_size=rne.WORLD_SIZE, + rank=rne.WORLD_RANK, ) if parallel_mode == "row": return y_ref, y_tp @@ -82,14 +99,17 @@ def _check_tp_invariance(parallel_mode, sequence_parallel, expect_bitwise): if expect_bitwise: torch.testing.assert_close( - tp, ref, atol=0, rtol=0, + tp, + ref, + atol=0, + rtol=0, msg=f"{label} not bitwise under NVTE_TP_INVARIANT_MODE=1", ) dist_print(f"[with_tp_invariant ] {label}: TP=1 ≡ TP={rne.WORLD_SIZE} bitwise") else: - assert not torch.equal(tp, ref), ( - f"without_tp_invariant: {label} unexpectedly bitwise under NVTE_TP_INVARIANT_MODE=0" - ) + assert not torch.equal( + tp, ref + ), f"without_tp_invariant: {label} unexpectedly bitwise under NVTE_TP_INVARIANT_MODE=0" dist_print(f"[without_tp_invariant] {label}: TP=1 ≠ TP={rne.WORLD_SIZE} (as expected)") @@ -101,22 +121,43 @@ def _check_tp_invariance_deinterleave(sequence_parallel): it so TP=N dgrad bitwise matches the TP=1 reference.""" os.environ["NVTE_TP_INVARIANT_MODE"] = "1" x, w, _, g = TestDistributedLinearBase._prepare_data( - BATCH, HIDDEN, OUT, use_bias=False, seed=42, dtype=DTYPE, + BATCH, + HIDDEN, + OUT, + use_bias=False, + seed=42, + dtype=DTYPE, ) _, _, dgrad_ref, _, _ = TestDistributedLayerNormLinearBase.run_layernorm_linear( - x, w, None, g, parallel_mode=None, sequence_parallel=False, - tp_group=None, tp_size=1, rank=0, partition_stride=1, + x, + w, + None, + g, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + partition_stride=1, ) _, _, dgrad_tp, _, _ = TestDistributedLayerNormLinearBase.run_layernorm_linear( - x, w, None, g, parallel_mode="column", sequence_parallel=sequence_parallel, - tp_group=rne.NCCL_WORLD, tp_size=rne.WORLD_SIZE, rank=rne.WORLD_RANK, + x, + w, + None, + g, + parallel_mode="column", + sequence_parallel=sequence_parallel, + tp_group=rne.NCCL_WORLD, + tp_size=rne.WORLD_SIZE, + rank=rne.WORLD_RANK, partition_stride=2, ) if rne.WORLD_RANK != 0: return label = f"LN-Linear stride=2 sp={int(sequence_parallel)}" - torch.testing.assert_close(dgrad_tp, dgrad_ref, atol=0, rtol=0, - msg=f"{label}: not TP-invariant") + torch.testing.assert_close( + dgrad_tp, dgrad_ref, atol=0, rtol=0, msg=f"{label}: not TP-invariant" + ) dist_print(f"{label}: TP=1 ≡ TP={rne.WORLD_SIZE} bitwise via deinterleave") @@ -135,7 +176,10 @@ def main(): torch.cuda.set_device(local_rank) dist.init_process_group( - backend="nccl", rank=rank, world_size=world_size, init_method="env://", + backend="nccl", + rank=rank, + world_size=world_size, + init_method="env://", timeout=datetime.timedelta(seconds=60), device_id=torch.device(f"cuda:{local_rank}"), ) diff --git a/tests/pytorch/distributed/test_tp_invariant.py b/tests/pytorch/distributed/test_tp_invariant.py index 7d0e6d928b..62603e0d9d 100644 --- a/tests/pytorch/distributed/test_tp_invariant.py +++ b/tests/pytorch/distributed/test_tp_invariant.py @@ -43,8 +43,9 @@ def _tp_sizes(): @pytest.mark.parametrize("tp_size", _tp_sizes()) @pytest.mark.parametrize("parallel_mode", ["row", "column"]) @pytest.mark.parametrize("sequence_parallel", [False, True]) -@pytest.mark.parametrize("expect_bitwise", [True, False], - ids=["with_tp_invariant", "without_tp_invariant"]) +@pytest.mark.parametrize( + "expect_bitwise", [True, False], ids=["with_tp_invariant", "without_tp_invariant"] +) def test_tp_invariant(tp_size, parallel_mode, sequence_parallel, expect_bitwise): """One TP-invariant correctness check per parameter combination. @@ -55,9 +56,12 @@ def test_tp_invariant(tp_size, parallel_mode, sequence_parallel, expect_bitwise) "torchrun", f"--nproc_per_node={tp_size}", str(TEST_ROOT / "run_tp_invariant.py"), - "--check-type", "linear", - "--parallel-mode", parallel_mode, - "--expect-bitwise", str(int(expect_bitwise)), + "--check-type", + "linear", + "--parallel-mode", + parallel_mode, + "--expect-bitwise", + str(int(expect_bitwise)), ] if sequence_parallel: cmd.append("--sequence-parallel") @@ -80,11 +84,12 @@ def test_tp_invariant_deinterleave(tp_size, sequence_parallel): "torchrun", f"--nproc_per_node={tp_size}", str(TEST_ROOT / "run_tp_invariant.py"), - "--check-type", "deinterleave", + "--check-type", + "deinterleave", ] if sequence_parallel: cmd.append("--sequence-parallel") result = subprocess.run(cmd, env=os.environ, check=False) - assert result.returncode == 0, ( - f"deinterleave failed: tp_size={tp_size}, sequence_parallel={sequence_parallel}" - ) + assert ( + result.returncode == 0 + ), f"deinterleave failed: tp_size={tp_size}, sequence_parallel={sequence_parallel}" diff --git a/transformer_engine/pytorch/module/_tp_invariant.py b/transformer_engine/pytorch/module/_tp_invariant.py index f4481531b3..f9964480f1 100644 --- a/transformer_engine/pytorch/module/_tp_invariant.py +++ b/transformer_engine/pytorch/module/_tp_invariant.py @@ -81,7 +81,8 @@ def tp_invariant_row_parallel_gemm( input_2d = inputmat_gathered.reshape(-1, inputmat_gathered.shape[-1]) out = general_gemm( - weight_gathered, input_2d, + weight_gathered, + input_2d, out_dtype=activation_dtype, bias=bias, ) @@ -143,39 +144,36 @@ def tp_invariant_column_parallel_dgrad( if partition_stride > 1: # Deinterleave gated [gate|val] halves to TP=1 [gate_all | val_all]. # Currently only the 2-way gated split (SwiGLU FC1 layout) is handled. - assert partition_stride == 2, ( - f"deinterleave only supports partition_stride=2 (gated halve); got {partition_stride}" - ) + assert ( + partition_stride == 2 + ), f"deinterleave only supports partition_stride=2 (gated halve); got {partition_stride}" chunk_sz = weight.shape[0] # out_features per rank half = chunk_sz // 2 - first_w = [ - weight_gathered[i * chunk_sz : i * chunk_sz + half] - for i in range(tp_size) - ] + first_w = [weight_gathered[i * chunk_sz : i * chunk_sz + half] for i in range(tp_size)] second_w = [ - weight_gathered[i * chunk_sz + half : (i + 1) * chunk_sz] - for i in range(tp_size) + weight_gathered[i * chunk_sz + half : (i + 1) * chunk_sz] for i in range(tp_size) ] weight_gathered = torch.cat(first_w + second_w, dim=0) g_dim = grad_output_gathered.shape[-1] // tp_size g_half = g_dim // 2 first_g = [ - grad_output_gathered[..., i * g_dim : i * g_dim + g_half] - for i in range(tp_size) + grad_output_gathered[..., i * g_dim : i * g_dim + g_half] for i in range(tp_size) ] second_g = [ - grad_output_gathered[..., i * g_dim + g_half : (i + 1) * g_dim] - for i in range(tp_size) + grad_output_gathered[..., i * g_dim + g_half : (i + 1) * g_dim] for i in range(tp_size) ] grad_output_gathered = torch.cat(first_g + second_g, dim=-1) grad_output_2d = grad_output_gathered.reshape( - -1, grad_output_gathered.shape[-1], + -1, + grad_output_gathered.shape[-1], ) dgrad = general_gemm( - weight_gathered, grad_output_2d, - layout="NN", grad=True, + weight_gathered, + grad_output_2d, + layout="NN", + grad=True, out_dtype=activation_dtype, ) if isinstance(dgrad, tuple): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6e57a065b1..bdf96a6630 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -523,7 +523,7 @@ def forward( ctx.input_quantizer = input_quantizer ctx.owns_input = inputmat is not inp ctx.weight = weight - ctx.partition_stride = getattr(weight, 'partition_stride', 1) + ctx.partition_stride = getattr(weight, "partition_stride", 1) ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c2c2aff697..5e9988b42c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1018,13 +1018,19 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. weight_for_dgrad = weight_fp8 if bwd_args.backward_override == "dequantized": if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype) + weight_for_dgrad = weight_for_dgrad.dequantize( + dtype=bwd_args.activation_dtype + ) else: - weight_for_dgrad = cast_if_needed(weight_for_dgrad, bwd_args.activation_dtype) + weight_for_dgrad = cast_if_needed( + weight_for_dgrad, bwd_args.activation_dtype + ) elif bwd_args.backward_override == "high_precision": weight_for_dgrad = saved_weight if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype) + weight_for_dgrad = weight_for_dgrad.dequantize( + dtype=bwd_args.activation_dtype + ) gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From ba9f89269fc06767cb67da418e77e2acd1ff3cb0 Mon Sep 17 00:00:00 2001 From: Jinze Xue Date: Tue, 12 May 2026 23:37:25 -0700 Subject: [PATCH 4/4] _tp_invariant: take gemm_fn from caller (honor downstream patches) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Importing ``general_gemm`` directly into ``_tp_invariant.py`` bypasses downstream monkey-patches that rebind it in caller modules — e.g., Megatron-LM's batch-invariant kernels patch the ``general_gemm`` symbol inside ``module.linear`` and ``module.layernorm_linear`` (their hardcoded target list), but not our new ``module._tp_invariant``. Result: the helper silently called the unpatched cuBLAS path, producing different bits than the BIK Triton path used elsewhere. Pass ``gemm_fn`` from the caller so the helper uses whichever ``general_gemm`` binding the caller's namespace holds. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Jinze Xue --- transformer_engine/pytorch/module/_tp_invariant.py | 13 +++++++++---- .../pytorch/module/layernorm_linear.py | 1 + transformer_engine/pytorch/module/linear.py | 2 ++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/_tp_invariant.py b/transformer_engine/pytorch/module/_tp_invariant.py index f9964480f1..c9a85eb0ce 100644 --- a/transformer_engine/pytorch/module/_tp_invariant.py +++ b/transformer_engine/pytorch/module/_tp_invariant.py @@ -13,16 +13,19 @@ TP=1/2/4/... because the underlying GEMM K-dimension accumulation order is fixed regardless of how the operands were sharded. +``gemm_fn`` is passed in by the caller so downstream monkey-patches of +``general_gemm`` in the caller's namespace (e.g., Megatron's batch-invariant +kernels) are honored. + Limitations: - FP8 not supported (callers should assert ``not fp8`` before calling). - Trades compute for invariance (gathered operands + full GEMM). Off by default. """ -from typing import Optional +from typing import Callable, Optional import torch -from ..cpp_extensions import general_gemm from ..utils import nvtx_range_pop, nvtx_range_push __all__ = [ @@ -52,6 +55,7 @@ def tp_invariant_row_parallel_gemm( tp_size: int, sequence_parallel: bool, activation_dtype: torch.dtype, + gemm_fn: Callable, nvtx_label: str = "tp_invariant_gemm", ) -> torch.Tensor: """Row-parallel forward GEMM with TP-invariant numerics. @@ -80,7 +84,7 @@ def tp_invariant_row_parallel_gemm( weight_gathered = allgather_along_dim(weightmat, tp_group, tp_size, dim=-1) input_2d = inputmat_gathered.reshape(-1, inputmat_gathered.shape[-1]) - out = general_gemm( + out = gemm_fn( weight_gathered, input_2d, out_dtype=activation_dtype, @@ -105,6 +109,7 @@ def tp_invariant_column_parallel_dgrad( tp_size: int, sequence_parallel: bool, activation_dtype: torch.dtype, + gemm_fn: Callable, partition_stride: int = 1, nvtx_label: str = "tp_invariant_dgrad", ) -> torch.Tensor: @@ -169,7 +174,7 @@ def tp_invariant_column_parallel_dgrad( -1, grad_output_gathered.shape[-1], ) - dgrad = general_gemm( + dgrad = gemm_fn( weight_gathered, grad_output_2d, layout="NN", diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bdf96a6630..9067b65b9c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -820,6 +820,7 @@ def backward( tp_size=ctx.tp_size, sequence_parallel=ctx.sequence_parallel, activation_dtype=ctx.activation_dtype, + gemm_fn=general_gemm, partition_stride=ctx.partition_stride, nvtx_label=f"{nvtx_label}.tp_invariant_dgrad", ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5e9988b42c..9a8b305c7a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -495,6 +495,7 @@ def _linear_forward_impl( tp_size=tp_world_size, sequence_parallel=sequence_parallel, activation_dtype=activation_dtype, + gemm_fn=general_gemm, nvtx_label=f"{nvtx_label}.tp_invariant_gemm", ) else: @@ -1011,6 +1012,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. tp_size=bwd_args.tp_size, sequence_parallel=bwd_args.sequence_parallel, activation_dtype=bwd_args.activation_dtype, + gemm_fn=general_gemm, nvtx_label=f"{nvtx_label}.tp_invariant_dgrad", ) else: