feat(jax): add 1g1p intranode and internode deepep for jax#344

Open
llying-001 wants to merge 20 commits into main from dev/llying/jax-deepep-internode-1g1p

Conversation

@llying-001
Collaborator

Summary

This PR adds JAX DeepEP internode (multi-node) support plus a new 1-GPU-per-process launch mode for both intranode and internode DeepEP dispatch/combine, on top of the existing in-process (pmap) path.

Previously the JAX DeepEP bindings only supported the in-process (single-process, multi-GPU via pmap) execution model and were intranode-only. To run DeepEP at MoE training scale on ROCm we need:

  1. A per-process execution model so each rank owns exactly one GPU and synchronises via IPC handles + rocSHMEM (matching the torch deep_ep.Buffer(group=ProcessGroup, ...) topology).
  2. Internode dispatch/combine kernels exposed to JAX so EP can span multiple nodes over RDMA.
  3. A way to pin the EP communication domain to a subset of JAX processes (e.g. one mesh axis carries EP while others carry FSDP/DP).
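As an illustration of item 3, the sketch below derives the EP rank subset for one process when the processes are laid out on a 2-D (DP, EP) grid. The helper name `ep_ranks_for` and the row-major rank layout are assumptions made for this example and are not part of the PR; the resulting list is the kind of value that `set_ep_group(ep_ranks)` is described as taking.

```python
from __future__ import annotations


def ep_ranks_for(process_index: int, dp_size: int, ep_size: int) -> list[int]:
    """Return the process ranks that share an EP group with `process_index`.

    Assumes (hypothetically) that ranks are laid out row-major on a
    (dp_size, ep_size) grid, so each row forms one EP group.
    """
    world = dp_size * ep_size
    if not 0 <= process_index < world:
        raise ValueError(f"process_index {process_index} is outside a world of {world}")
    row = process_index // ep_size   # which DP replica this rank belongs to
    start = row * ep_size            # first rank of that row
    return list(range(start, start + ep_size))


# 8 processes arranged as 2 DP replicas of 4-way EP: rank 5 is in the second row.
print(ep_ranks_for(5, dp_size=2, ep_size=4))  # [4, 5, 6, 7]
```

With a different mesh layout (for example EP on the slow axis) the index arithmetic changes, so the grid convention should match the mesh actually used.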

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • New JAX DeepEP runtime layer (primus_turbo/jax/deep_ep/runtime.py, deep_ep/__init__.py):
    • Introduces LaunchMode with MODE_INPROC and MODE_PER_PROCESS, selected via the PRIMUS_TURBO_JAX_DEEPEP_MODE env var.
    • Adds set_ep_group(ep_ranks) / reset_runtime() to decouple the EP rank set from the full JAX world; warns once when EP falls back to jax.process_count().
    • ensure_deepep_runtime() bootstraps the C++ buffer once per process; IPC handles and the rocSHMEM unique id are exchanged via a JAX kv-store all-gather instead of jax.lax.all_gather.
  • New JAX primitives & lax wrappers for internode DeepEP (primitive/moe/moe_dispatch.py, primitive/moe/moe_combine.py, lax/moe/moe_dispatch_combine.py):
    • moe_internode_dispatch_p, moe_internode_cached_dispatch_p, moe_internode_combine_p and their per_process FFI lowerings.
    • Existing moe_dispatch / moe_combine now thread ep_size and launch_mode attrs down to the FFI layer and validate them against the visible device count.
    • Adds a warmup(hidden_bytes, ...) helper to eagerly initialise the runtime outside jax.jit.
  • C++ side (csrc/jax/bindings_jax.cpp, csrc/jax/deep_ep/deep_ep.{h,cpp}, csrc/jax/deep_ep/handler.cpp, csrc/jax/extensions.h):
    • New FFI handlers: moe_dispatch_inproc / moe_dispatch_per_process (and cached/combine variants) plus moe_internode_{dispatch,cached_dispatch,combine}_per_process.
    • New Buffer::InternodeDispatch / Buffer::InternodeCombine plumbing wired to the existing rocSHMEM kernels.
    • New pybind submodule primus_turbo.jax._C.deep_ep exposing create_per_process_buffer, sync_per_process_buffer, destroy_per_process_buffer, is_per_process_buffer_ready, per_process_buffer_nvl_bytes, and (when rocSHMEM is enabled) get_unique_id.
    • Buffer lifecycle hardening: explicit Destroy() with safe fallback in the destructor, NVL peer barrier before closing IPC handles, IPC-handle close only when ipc_synced_, rocSHMEM finalize guarded by DISABLE_ROCSHMEM, plus null-pointer checks.
  • Bug fixes:
    • Fix combine hang caused by wrong handle-matrix order in cached dispatch path.
    • Remove the device-side spin on moe_recv_counter_mapped in notify_dispatch when num_worst_tokens > 0 (XLA static-shape path). The host skips its CPU sync on that path, so back-to-back launches would otherwise spin forever on a stale counter (csrc/kernels/deep_ep/internode.cu).
    • Fix compile error introduced during the internode rewiring.
  • Build (tools/build_utils.py):
    • Auto-discover linkable system libraries (e.g. libionic) via LD_LIBRARY_PATH/LIBRARY_PATH/multiarch defaults when wiring librocshmem.a.
  • Tests & infra:
    • New tests/jax/lax/test_mp_dispatch_combine.py covering multi-process (1 GPU/proc) dispatch/combine via the per_process mode, with a new JaxMultiProcessTestCase helper in tests/jax/test_utils.py.
    • tests/conftest.py recognises JaxMultiProcessTestCase under --dist-only.
    • Remove the obsolete internode unit test.
  • Docs: updated benchmark/README.md with the multi-node DeepEP launch instructions.
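To make the env-var selection described under the runtime layer concrete, here is a minimal sketch of how a launch mode might be resolved from PRIMUS_TURBO_JAX_DEEPEP_MODE. The string values "inproc" and "per_process", the default, and the helper name `resolve_launch_mode` are assumptions for illustration; the actual implementation in runtime.py may differ.

```python
import os
from enum import Enum


class LaunchMode(Enum):
    # String values are assumed; the PR only names MODE_INPROC / MODE_PER_PROCESS.
    MODE_INPROC = "inproc"
    MODE_PER_PROCESS = "per_process"


def resolve_launch_mode(env=None) -> LaunchMode:
    """Map PRIMUS_TURBO_JAX_DEEPEP_MODE to a LaunchMode (hypothetical helper).

    Falls back to the in-process mode when the variable is unset.
    """
    env = os.environ if env is None else env
    raw = env.get("PRIMUS_TURBO_JAX_DEEPEP_MODE", LaunchMode.MODE_INPROC.value)
    try:
        return LaunchMode(raw.strip().lower())
    except ValueError:
        valid = ", ".join(m.value for m in LaunchMode)
        raise ValueError(
            f"unknown DeepEP launch mode {raw!r}; expected one of: {valid}"
        ) from None
```

Passing a plain dict as `env` keeps the helper testable without mutating the real process environment.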

Checklist:

  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI review requested due to automatic review settings May 20, 2026 05:17
Contributor

Copilot AI left a comment

Pull request overview

This PR extends Primus-Turbo’s JAX DeepEP integration to support a 1-GPU-per-process execution model and adds internode (RDMA/rocSHMEM) dispatch/combine paths, enabling multi-process/multi-node Expert Parallel (EP) operation beyond the existing in-process (pmap) intranode mode.

Changes:

  • Introduces a new JAX DeepEP runtime layer to manage launch mode selection, EP-group sizing/pinning, and per-process IPC/RDMA buffer bootstrap.
  • Adds new JAX primitives/lax wrappers for per-process and internode DeepEP dispatch/combine, plus a warmup() helper for eager runtime initialization.
  • Implements corresponding C++ FFI handlers and buffer lifecycle hardening, adds multi-process JAX tests, and improves build-time library discovery for rocSHMEM linking.
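The KV-store all-gather bootstrap mentioned above can be pictured with the following sketch: each rank publishes its opaque payload (for example an IPC handle) under a rank-specific key, then reads every rank's key in order. The `InMemoryKVStore` class is a thread-based stand-in for a distributed key-value store, and all names here are illustrative assumptions rather than the PR's code.

```python
import threading


class InMemoryKVStore:
    """Stand-in for a distributed key-value store (illustrative only)."""

    def __init__(self):
        self._data = {}
        self._cond = threading.Condition()

    def set(self, key: str, value: bytes) -> None:
        with self._cond:
            self._data[key] = value
            self._cond.notify_all()

    def blocking_get(self, key: str, timeout_s: float = 5.0) -> bytes:
        # Wait until some rank has published `key`, or give up after the timeout.
        with self._cond:
            if not self._cond.wait_for(lambda: key in self._data, timeout=timeout_s):
                raise TimeoutError(f"key {key!r} not published within {timeout_s}s")
            return self._data[key]


def kv_all_gather(store, tag: str, rank: int, world: int, payload: bytes) -> list:
    """Publish this rank's payload, then read every rank's payload in rank order."""
    store.set(f"{tag}/{rank}", payload)
    return [store.blocking_get(f"{tag}/{r}") for r in range(world)]


def demo(world: int = 4) -> list:
    """Run `world` threads as fake ranks and return what each one gathered."""
    store = InMemoryKVStore()
    results = [None] * world

    def worker(rank: int) -> None:
        results[rank] = kv_all_gather(store, "ipc", rank, world, f"handle-{rank}".encode())

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

Because the exchange happens on the host through the store, it does not need to run inside a traced computation, which is one plausible reason to prefer it over a device collective for bootstrap data.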

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| tools/build_utils.py | Adds multiarch/env-based library search and optional linking of libionic to improve rocSHMEM static link wiring. |
| tests/jax/test_utils.py | Adds a spawn-based 1-GPU-per-process JAX distributed test harness. |
| tests/jax/lax/test_mp_dispatch_combine.py | New multi-process DeepEP dispatch/combine test coverage (BF16/FP8, fwd/bwd, eval_shape). |
| tests/conftest.py | Marks JaxMultiProcessTestCase as distributed under --dist-only. |
| primus_turbo/jax/primitive/moe/moe_dispatch.py | Adds internode dispatch primitives and routes lowering targets based on DeepEP launch mode. |
| primus_turbo/jax/primitive/moe/moe_combine.py | Adds internode combine primitive and mode-aware lowering for combine. |
| primus_turbo/jax/lax/moe/moe_dispatch_combine.py | Threads ep_size/launch_mode, adds internode dispatch/combine paths, and introduces warmup(). |
| primus_turbo/jax/lax/moe/__init__.py | Re-exports warmup from the MoE lax API. |
| primus_turbo/jax/deep_ep/runtime.py | New runtime manager for mode selection, EP-group pinning, KV-store allgather bootstrap, and per-process buffer lifecycle. |
| primus_turbo/jax/deep_ep/__init__.py | New package init re-exporting selected DeepEP runtime helpers. |
| primus_turbo/jax/__init__.py | Defers heavy imports into initialize() and makes initialization idempotent. |
| csrc/kernels/deep_ep/internode.cu | Fixes a potential device-side spin/hang in static-shape internode dispatch (num_worst_tokens > 0). |
| csrc/jax/extensions.h | Declares new DeepEP per-process and internode handler symbols. |
| csrc/jax/deep_ep/handler.cpp | Adds per-process/internode FFI handlers and validates mode/ep_size consistency. |
| csrc/jax/deep_ep/deep_ep.h | Extends Buffer API for IPC handle sync and adds internode dispatch/combine plumbing declarations. |
| csrc/jax/deep_ep/deep_ep.cpp | Implements per-process buffer singleton, IPC sync, internode dispatch/combine, and hardens destroy/finalize behavior. |
| csrc/jax/bindings_jax.cpp | Registers new FFI targets and exposes _C.deep_ep pybind submodule for per-process buffer management and rocSHMEM helpers. |
| benchmark/README.md | Updates benchmark instructions to use test_intranode.py / test_internode.py. |
Comments suppressed due to low confidence (1)

tests/jax/test_utils.py:217

  • If any worker reports an exception, the code raises immediately without ensuring remaining child processes are terminated. This can leave live processes running (or hanging in jax.distributed) and destabilize the rest of the test session. Ensure processes are killed/joined in a finally before raising, and include timeouts/exitcode checks even when errors is non-empty.
        if errors:
            msg = "\n".join(f"--- Rank {r} ---\n{tb}" for r, tb in errors)
            raise AssertionError(f"Multi-process test failed:\n{msg}")
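One way to follow the suggestion above is to move the raise into a `try` whose `finally` reaps every child. The sketch below is a hedged illustration, not the harness's actual code: `reap_workers` and `finish_run` are hypothetical names, and `procs` is assumed to hold `multiprocessing.Process`-like objects exposing `is_alive()`, `terminate()`, `join()`, and `kill()`.

```python
def reap_workers(procs, join_timeout_s: float = 5.0) -> None:
    """Terminate and join every worker still alive; escalate to kill() if needed."""
    for p in procs:
        if p.is_alive():
            p.terminate()
    for p in procs:
        p.join(join_timeout_s)
        if p.is_alive():
            # terminate() was ignored or too slow; force the process down.
            p.kill()
            p.join(join_timeout_s)


def finish_run(procs, errors) -> None:
    """Raise on worker errors, but only after all children have been reaped."""
    try:
        if errors:
            msg = "\n".join(f"--- Rank {r} ---\n{tb}" for r, tb in errors)
            raise AssertionError(f"Multi-process test failed:\n{msg}")
    finally:
        reap_workers(procs)
```

Because the cleanup sits in `finally`, the `AssertionError` still propagates to the test runner, but no child is left blocked in a collective after it does.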


Comment thread tests/jax/test_utils.py Outdated
Comment thread primus_turbo/jax/__init__.py
Comment thread primus_turbo/jax/deep_ep/__init__.py
Comment thread primus_turbo/jax/lax/moe/moe_dispatch_combine.py Outdated
Copilot AI review requested due to automatic review settings May 20, 2026 09:01
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (1)

primus_turbo/jax/lax/moe/moe_dispatch_combine.py:519

  • Same issue as dispatch: get_ep_size(lock=True) / get_launch_mode(lock=True) are called even when x is a Tracer (so ensure_deepep_runtime() is skipped), which can lock the runtime mode to the env var default (inproc) before auto_detect_mode() has a chance to set per_process in 1-GPU-per-process setups. Consider calling deep_ep_runtime.auto_detect_mode() before locking, or only locking when not tracing.
    config = get_combine_config() if config is None else config
    if not isinstance(x, jax.core.Tracer):
        deep_ep_runtime.ensure_deepep_runtime(hidden_bytes=_get_hidden_bytes(x), config=config)
    ep_size = deep_ep_runtime.get_ep_size(lock=True)
    launch_mode = deep_ep_runtime.get_launch_mode(lock=True)
    internode = deep_ep_runtime.is_internode(lock=True)
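The ordering hazard raised in this comment can be reproduced with a toy model. The `RuntimeState` class below is not the PR's runtime; it only assumes lock-on-first-read semantics and a hypothetical detection heuristic (several processes, one visible device each) to show why detecting before locking matters.

```python
class RuntimeState:
    """Toy model of a lock-on-first-read launch mode (not the PR's runtime)."""

    def __init__(self, env_default: str = "inproc"):
        self._mode = env_default
        self._locked = False

    def auto_detect_mode(self, local_device_count: int, process_count: int) -> None:
        # Assumed heuristic: several processes with one visible device each
        # indicates the 1-GPU-per-process layout.
        if self._locked:
            return  # too late, the mode was already pinned by a locking read
        if process_count > 1 and local_device_count == 1:
            self._mode = "per_process"

    def get_launch_mode(self, lock: bool = False) -> str:
        if lock:
            self._locked = True
        return self._mode


def resolve_mode(state: RuntimeState, local_device_count: int, process_count: int) -> str:
    # Detect before locking so a traced first call cannot pin the env default.
    state.auto_detect_mode(local_device_count, process_count)
    return state.get_launch_mode(lock=True)
```

In this model, a locking read that happens before detection leaves the mode stuck at "inproc", while `resolve_mode` yields "per_process" for the same 1-device, multi-process inputs.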

Comment on lines 206 to +213
    config = get_dispatch_config() if config is None else config
    if not isinstance(x, jax.core.Tracer):
        deep_ep_runtime.ensure_deepep_runtime(hidden_bytes=_get_hidden_bytes(x), config=config)
    ep_size = deep_ep_runtime.get_ep_size(lock=True)
    launch_mode = deep_ep_runtime.get_launch_mode(lock=True)
    num_worst_tokens = num_tokens * ep_size
    internode = deep_ep_runtime.is_internode(lock=True)

Comment on lines 151 to 154
    Buffer::~Buffer() noexcept(false) {
        if (not explicitly_destroy_) {
            Destroy();
        } else if (not destroyed_) {