Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions tests/jax/test_distributed_moe_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Distributed tests for the experimental ``transformer_engine.jax.flax._MoEBlock``."""

import sys

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax.sharding import Mesh, PartitionSpec

from utils import assert_allclose, is_devices_enough


@pytest.fixture(autouse=True, scope="function")
def _inject_moe(request):
"""Lazy-load ``_MoEBlock`` only for tests marked ``triton``."""
if not request.node.get_closest_marker("triton"):
yield
return

from transformer_engine.jax import MeshResource, autocast

# The class is intentionally exposed as ``_MoEBlock`` (experimental);
# aliasing to ``MoEBlock`` here keeps the test bodies readable.
from transformer_engine.jax.flax import _MoEBlock as MoEBlock
from transformer_engine.jax.flax.moe import PermutationBackend

mod = sys.modules[__name__]
mod.MeshResource = MeshResource
mod.autocast = autocast
mod.MoEBlock = MoEBlock
mod.PermutationBackend = PermutationBackend
yield


DTYPE = jnp.bfloat16
# Must be divisible by ep*fsdp = 4 so the batch dim can be sharded over
# the full ('ep','fsdp') axis tuple under Experiment 3.
BATCH_SIZE = 4
SEQUENCE_LENGTH = 16
HIDDEN_SIZE = 64
INTERMEDIATE_SIZE = 128
NUM_EXPERTS = 8
NUM_EXPERTS_PER_TOK = 2


def _make_inputs(key: jax.Array) -> jax.Array:
return jax.random.normal(key, (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=DTYPE)


def _unwrap_partitioned(x):
return x.value if hasattr(x, "value") else x


@pytest.mark.triton
class TestDistributedMoEBlock:
@pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"])
def test_ep2_fsdp2_matches_single_device(self, permutation_backend):
if not is_devices_enough(4):
pytest.skip("MoE distributed test requires 4 devices for EP=2 x FSDP=2.")

permutation_backend = PermutationBackend(permutation_backend)
key = jax.random.PRNGKey(11)
init_key, data_key = jax.random.split(key)
inputs = _make_inputs(data_key)

base_kwargs = dict(
num_experts=NUM_EXPERTS,
num_experts_per_tok=NUM_EXPERTS_PER_TOK,
intermediate_size=INTERMEDIATE_SIZE,
permutation_backend=permutation_backend,
aux_loss_coeff=1e-2,
dtype=DTYPE,
)

single_block = MoEBlock(**base_kwargs)

def _make_loss_and_grad(block):
"""Build a jitted ``value_and_grad`` over ``(variables, x)``.

Capturing ``block`` in a closure (so it isn't a jit input)
sidesteps having to mark it as static -- Flax modules are
registered pytrees but they carry Python-level config that
jit treats as part of the trace.
"""

def loss_fn(variables, x):
output, aux_loss = block.apply(variables, x)
loss = jnp.mean(output.astype(jnp.float32) ** 2)
if aux_loss is not None:
loss = loss + aux_loss.astype(jnp.float32)
return loss, (output, aux_loss)

return jax.jit(jax.value_and_grad(loss_fn, has_aux=True))

with autocast(enabled=False, mesh_resource=MeshResource()):
single_variables = single_block.init(init_key, inputs)
(single_loss, (single_output, single_aux)), single_grads = _make_loss_and_grad(
single_block
)(single_variables, inputs)

devices = np.asarray(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ("ep", "fsdp"))
# FSDP-style sharding: weights are sharded on a *non-contracting*
# weight axis (gathered before the GEMM); activations stay sharded on
# the *batch* axis throughout - the same fsdp mesh axis is reused for
# both. The TE primitives' custom_partitioning rules expect activations
# FSDP-sharded on batch, so we declare ("batch", "fsdp") AND pass
# ``input_axes=("batch", None, None)`` to enforce it on the inputs to
# the block. ("embed", "fsdp") shards the weight's hidden dim, which
# is gathered inside grouped_dense's custom_partitioning before GEMM
# (no reshard of activations needed because their layout is unchanged).
logical_axis_rules = (
("exp", "ep"),
("batch", "fsdp"),
("embed", "fsdp"),
)
# ``data_parallelism_axes=("fsdp",)`` opts in to the true-FSDP
# behavior: the ``shard_map``'s in_specs/out_specs become
# ``P(("ep","fsdp"), None, None)`` for the batch dim, so each
# device owns ``B/(ep*fsdp)`` unique tokens (no redundant compute
# across fsdp peers within an ep group).
sharded_block = MoEBlock(
data_parallelism_axes=("fsdp",),
input_axes=("batch", None, None),
**base_kwargs,
)

# ``MoEBlock`` resolves the EP axis from
# ``global_mesh_resource().ep_resource`` (set via ``autocast``),
# so the ``ep`` axis on the mesh is wired in by passing
# ``ep_resource="ep"`` here -- no per-instance config needed.
with mesh, autocast(
enabled=False,
mesh_resource=MeshResource(fsdp_resource="fsdp", ep_resource="ep"),
):
with nn.logical_axis_rules(logical_axis_rules):
# ``MoEBlock`` registers params via ``with_logical_partitioning``
# which only attaches LogicallyPartitioned metadata; the
# underlying jax.Array stays single-device unless ``init``
# is run inside ``jax.jit`` with ``out_shardings``. Use the
# canonical Flax-Linen pattern (mirrors
# ``examples/jax/encoder/test_model_parallel_encoder.py``):
# 1. ``jax.eval_shape`` to trace abstract variables (keeps
# the LogicallyPartitioned wrappers; only the inner
# arrays become ShapeDtypeStruct);
# 2. ``nn.get_partition_spec`` to extract a tree of logical
# PartitionSpecs from those wrappers (treats
# LogicallyPartitioned as a leaf);
# 3. ``nn.logical_to_mesh_sharding`` to resolve those
# logical specs to NamedShardings via the active rules;
# 4. ``jax.jit(init, out_shardings=...)`` to actually
# place the params on-device with those shardings.
abstract_variables = jax.eval_shape(sharded_block.init, init_key, inputs)
logical_partition_spec = nn.get_partition_spec(abstract_variables)
out_shardings = nn.logical_to_mesh_sharding(
logical_partition_spec, mesh, logical_axis_rules
)
sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)(
init_key, inputs
)
(sharded_loss, (sharded_output, sharded_aux)), sharded_grads = _make_loss_and_grad(
sharded_block
)(sharded_variables, inputs)

wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"])
wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"])
wo = _unwrap_partitioned(sharded_variables["params"]["wo"])
assert wi_0.sharding.spec == PartitionSpec("ep", "fsdp", None)
assert wi_1.sharding.spec == PartitionSpec("ep", "fsdp", None)
assert wo.sharding.spec == PartitionSpec("ep", None, "fsdp")

assert_allclose(sharded_output, single_output, dtype=DTYPE, atol=5e-2, rtol=5e-2)
assert_allclose(sharded_loss, single_loss, dtype=jnp.float32, atol=5e-2, rtol=5e-2)
assert_allclose(sharded_aux, single_aux, dtype=jnp.float32, atol=5e-2, rtol=5e-2)

# The sharded path runs the same math on each ep-shard but
# accumulates gradients via psum across (ep, fsdp), which changes
# floating-point reduction order vs the single-device run. Under
# bf16 with these toy shapes the observed max-abs grad diff is on
# the order of a few units of bf16 eps (~1e-2). 5e-2 / 5e-2
# leaves headroom for accumulation jitter without masking real
# divergence; matches the cross-backend bf16 grad tolerance in
# ``tests/jax/test_moe_block.py::test_pure_jax_matches_triton``.
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
grad_single = _unwrap_partitioned(single_grads["params"][name])
grad_sharded = _unwrap_partitioned(sharded_grads["params"][name])
assert_allclose(
grad_sharded,
grad_single,
dtype=DTYPE,
atol=5e-2,
rtol=5e-2,
err_msg=f"Distributed gradient mismatch for {name}",
)
Loading
Loading