feat(jax): add 1g1p intranode and internode deepep for jax #344
Open
llying-001 wants to merge 20 commits into
… dev/llying/jax-deepep-internode-1g1p
Pull request overview
This PR extends Primus-Turbo’s JAX DeepEP integration to support a 1-GPU-per-process execution model and adds internode (RDMA/rocSHMEM) dispatch/combine paths, enabling multi-process/multi-node Expert Parallel (EP) operation beyond the existing in-process (pmap) intranode mode.
Changes:
- Introduces a new JAX DeepEP runtime layer to manage launch mode selection, EP-group sizing/pinning, and per-process IPC/RDMA buffer bootstrap.
- Adds new JAX primitives/lax wrappers for per-process and internode DeepEP dispatch/combine, plus a `warmup()` helper for eager runtime initialization.
- Implements corresponding C++ FFI handlers and buffer lifecycle hardening, adds multi-process JAX tests, and improves build-time library discovery for rocSHMEM linking.
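The per-process buffer bootstrap listed above is described elsewhere in this PR as an all-gather over a key-value store rather than over `jax.lax.all_gather`. A minimal sketch of that pattern follows, assuming a blocking key-value store. `InMemoryKVStore`, `kv_allgather`, `run_ranks`, the key layout, and the fake handle bytes are hypothetical names made up for illustration; threads stand in for processes, and none of this is the PR's actual runtime code.

```python
import threading


class InMemoryKVStore:
    """Hypothetical blocking key-value store shared by all ranks (illustration only)."""

    def __init__(self):
        self._data = {}
        self._cond = threading.Condition()

    def set(self, key, value):
        with self._cond:
            self._data[key] = value
            self._cond.notify_all()

    def blocking_get(self, key, timeout=10.0):
        with self._cond:
            # Wait until some rank has published this key, or give up.
            if not self._cond.wait_for(lambda: key in self._data, timeout=timeout):
                raise TimeoutError(f"key {key!r} was never published")
            return self._data[key]


def kv_allgather(store, tag, rank, world_size, payload):
    """Publish this rank's payload, then read every rank's payload in rank order."""
    store.set(f"{tag}/{rank}", payload)
    return [store.blocking_get(f"{tag}/{r}") for r in range(world_size)]


def run_ranks(world_size):
    """Simulate `world_size` ranks exchanging opaque handles (threads stand in for processes)."""
    store = InMemoryKVStore()
    results = [None] * world_size

    def worker(rank):
        handle = f"ipc-handle-{rank}".encode()  # placeholder for a real IPC handle
        results[rank] = kv_allgather(store, "ipc", rank, world_size, handle)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

Because each rank only sets its own key and then blocks on the others, the exchange needs no compiled collective and can run before any device buffer exists, which is presumably why a host-side store suits a bootstrap step.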
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tools/build_utils.py | Adds multiarch/env-based library search and optional linking of libionic to improve rocSHMEM static link wiring. |
| tests/jax/test_utils.py | Adds a spawn-based 1-GPU-per-process JAX distributed test harness. |
| tests/jax/lax/test_mp_dispatch_combine.py | New multi-process DeepEP dispatch/combine test coverage (BF16/FP8, fwd/bwd, eval_shape). |
| tests/conftest.py | Marks JaxMultiProcessTestCase as distributed under --dist-only. |
| primus_turbo/jax/primitive/moe/moe_dispatch.py | Adds internode dispatch primitives and routes lowering targets based on DeepEP launch mode. |
| primus_turbo/jax/primitive/moe/moe_combine.py | Adds internode combine primitive and mode-aware lowering for combine. |
| primus_turbo/jax/lax/moe/moe_dispatch_combine.py | Threads ep_size/launch_mode, adds internode dispatch/combine paths, and introduces warmup(). |
| primus_turbo/jax/lax/moe/__init__.py | Re-exports warmup from the MoE lax API. |
| primus_turbo/jax/deep_ep/runtime.py | New runtime manager for mode selection, EP-group pinning, KV-store allgather bootstrap, and per-process buffer lifecycle. |
| primus_turbo/jax/deep_ep/__init__.py | New package init re-exporting selected DeepEP runtime helpers. |
| primus_turbo/jax/__init__.py | Defers heavy imports into initialize() and makes initialization idempotent. |
| csrc/kernels/deep_ep/internode.cu | Fixes a potential device-side spin/hang in static-shape internode dispatch (num_worst_tokens > 0). |
| csrc/jax/extensions.h | Declares new DeepEP per-process and internode handler symbols. |
| csrc/jax/deep_ep/handler.cpp | Adds per-process/internode FFI handlers and validates mode/ep_size consistency. |
| csrc/jax/deep_ep/deep_ep.h | Extends Buffer API for IPC handle sync and adds internode dispatch/combine plumbing declarations. |
| csrc/jax/deep_ep/deep_ep.cpp | Implements per-process buffer singleton, IPC sync, internode dispatch/combine, and hardens destroy/finalize behavior. |
| csrc/jax/bindings_jax.cpp | Registers new FFI targets and exposes _C.deep_ep pybind submodule for per-process buffer management and rocSHMEM helpers. |
| benchmark/README.md | Updates benchmark instructions to use test_intranode.py / test_internode.py. |
Comments suppressed due to low confidence (1)
tests/jax/test_utils.py:217
- If any worker reports an exception, the code raises immediately without ensuring remaining child processes are terminated. This can leave live processes running (or hanging in `jax.distributed`) and destabilize the rest of the test session. Ensure processes are killed/joined in a `finally` before raising, and include timeouts/exitcode checks even when `errors` is non-empty.

      if errors:
          msg = "\n".join(f"--- Rank {r} ---\n{tb}" for r, tb in errors)
          raise AssertionError(f"Multi-process test failed:\n{msg}")
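One way to act on this comment is sketched below: reap every worker in a `finally` block, then raise. This is an illustration and not the harness in `tests/jax/test_utils.py`. It assumes workers are `subprocess.Popen` objects, whereas the real harness is spawn-based and may use `multiprocessing.Process`. `reap_and_report` and its `failed_grace` parameter are hypothetical names.

```python
import subprocess


def reap_and_report(procs, errors, timeout=30.0, failed_grace=1.0):
    """Join or kill every worker, then raise once with all collected problems.

    procs:  list of subprocess.Popen, indexed by rank (assumption of this sketch)
    errors: list of (rank, traceback_text) pairs already reported by workers
    """
    problems = [f"--- Rank {r} ---\n{tb}" for r, tb in errors]
    # When a rank has already failed, its peers are likely blocked, so wait only briefly.
    grace = failed_grace if errors else timeout
    try:
        for rank, proc in enumerate(procs):
            try:
                code = proc.wait(timeout=grace)
            except subprocess.TimeoutExpired:
                problems.append(f"--- Rank {rank} ---\ndid not exit within {grace}s")
                continue
            if code != 0:
                problems.append(f"--- Rank {rank} ---\nexit code {code}")
    finally:
        # Runs on every path, so no child outlives the test.
        for proc in procs:
            if proc.poll() is None:
                proc.kill()
            proc.wait()
    if problems:
        msg = "\n".join(problems)
        raise AssertionError(f"Multi-process test failed:\n{msg}")
```

With `multiprocessing.Process` the same shape applies, using `join(timeout)`, `is_alive()`, `kill()`, and `exitcode` in place of the `Popen` calls.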
Pull request overview
Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (1)
primus_turbo/jax/lax/moe/moe_dispatch_combine.py:519
- Same issue as dispatch: `get_ep_size(lock=True)` / `get_launch_mode(lock=True)` are called even when `x` is a `Tracer` (so `ensure_deepep_runtime()` is skipped), which can lock the runtime mode to the env var default (`inproc`) before `auto_detect_mode()` has a chance to set `per_process` in 1-GPU-per-process setups. Consider calling `deep_ep_runtime.auto_detect_mode()` before locking, or only locking when not tracing.

      config = get_combine_config() if config is None else config
      if not isinstance(x, jax.core.Tracer):
          deep_ep_runtime.ensure_deepep_runtime(hidden_bytes=_get_hidden_bytes(x), config=config)
      ep_size = deep_ep_runtime.get_ep_size(lock=True)
      launch_mode = deep_ep_runtime.get_launch_mode(lock=True)
      internode = deep_ep_runtime.is_internode(lock=True)
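The ordering hazard in this comment can be reproduced with a small stand-in that does not need JAX. `RuntimeStub` and `mode_seen_by_trace` are hypothetical. The detection heuristic (more than one process, one local device each) and the `env` parameter are assumptions of this sketch, not the logic in `runtime.py`. Only the env var name and the two mode strings come from the PR text.

```python
import os

MODE_INPROC = "inproc"
MODE_PER_PROCESS = "per_process"


class RuntimeStub:
    """Hypothetical stand-in for the DeepEP runtime's mode state (illustration only)."""

    def __init__(self, process_count, local_device_count, env=None):
        env = os.environ if env is None else env
        # The mode starts from the env var and defaults to inproc.
        self._mode = env.get("PRIMUS_TURBO_JAX_DEEPEP_MODE", MODE_INPROC)
        self._locked = False
        self._process_count = process_count
        self._local_device_count = local_device_count

    def auto_detect_mode(self):
        # Assumed heuristic: several processes with one device each means per_process.
        if self._locked:
            return  # too late, the mode was already pinned
        if self._process_count > 1 and self._local_device_count == 1:
            self._mode = MODE_PER_PROCESS

    def get_launch_mode(self, lock=False):
        if lock:
            self._locked = True
        return self._mode


def mode_seen_by_trace(runtime, detect_first):
    """Model the traced call path, where the eager runtime setup is skipped."""
    if detect_first:
        runtime.auto_detect_mode()  # the suggested fix: detect before locking
    mode = runtime.get_launch_mode(lock=True)
    runtime.auto_detect_mode()  # a later detection can no longer change the mode
    return mode
```

In this model, eight single-device processes that lock without detecting first stay pinned to `inproc`, while detecting first yields `per_process`. That is the behaviour the comment asks the dispatch and combine wrappers to guarantee.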
Comment on lines 206 to +213

    config = get_dispatch_config() if config is None else config
    if not isinstance(x, jax.core.Tracer):
        deep_ep_runtime.ensure_deepep_runtime(hidden_bytes=_get_hidden_bytes(x), config=config)
    ep_size = deep_ep_runtime.get_ep_size(lock=True)
    launch_mode = deep_ep_runtime.get_launch_mode(lock=True)
    num_worst_tokens = num_tokens * ep_size
    internode = deep_ep_runtime.is_internode(lock=True)
Comment on lines 151 to 154

    Buffer::~Buffer() noexcept(false) {
        if (not explicitly_destroy_) {
            Destroy();
        } else if (not destroyed_) {
Summary
This PR adds JAX DeepEP internode (multi-node) support plus a new 1-GPU-per-process launch mode for both intranode and internode DeepEP dispatch/combine, on top of the existing in-process (pmap) path.
Previously the JAX DeepEP bindings only supported the in-process (single-process, multi-GPU via `pmap`) execution model and were intranode-only. To run DeepEP at MoE training scale on ROCm we need: … `deep_ep.Buffer(group=ProcessGroup, ...)` topology).

Fixes # (issue)
Type of change
Changes
- `primus_turbo/jax/deep_ep/runtime.py`, `deep_ep/__init__.py`:
  - `LaunchMode` with `MODE_INPROC` and `MODE_PER_PROCESS`, selected via the `PRIMUS_TURBO_JAX_DEEPEP_MODE` env var.
  - `set_ep_group(ep_ranks)` / `reset_runtime()` to decouple the EP rank set from the full JAX world; warns once when EP falls back to `jax.process_count()`.
  - `ensure_deepep_runtime()` bootstraps the C++ buffer once per process; IPC handles and the rocSHMEM unique id are exchanged via a JAX `kv-store` all-gather instead of `jax.lax.all_gather`.
- `primitive/moe/moe_dispatch.py`, `primitive/moe/moe_combine.py`, `lax/moe/moe_dispatch_combine.py`:
  - Internode primitives `moe_internode_dispatch_p`, `moe_internode_cached_dispatch_p`, `moe_internode_combine_p` and their `per_process` FFI lowerings.
  - `moe_dispatch` / `moe_combine` now thread `ep_size` and `launch_mode` attrs down to the FFI layer and validate them against the visible device count.
  - `warmup(hidden_bytes, ...)` helper to eagerly initialise the runtime outside `jax.jit`.
- `csrc/jax/bindings_jax.cpp`, `csrc/jax/deep_ep/deep_ep.{h,cpp}`, `csrc/jax/deep_ep/handler.cpp`, `csrc/jax/extensions.h`:
  - FFI handlers `moe_dispatch_inproc` / `moe_dispatch_per_process` (and cached/combine variants) plus `moe_internode_{dispatch,cached_dispatch,combine}_per_process`.
  - `Buffer::InternodeDispatch` / `Buffer::InternodeCombine` plumbing wired to the existing rocSHMEM kernels.
  - pybind submodule `primus_turbo.jax._C.deep_ep` exposing `create_per_process_buffer`, `sync_per_process_buffer`, `destroy_per_process_buffer`, `is_per_process_buffer_ready`, `per_process_buffer_nvl_bytes`, and (when rocSHMEM is enabled) `get_unique_id`.
  - `Buffer` lifecycle hardening: explicit `Destroy()` with safe fallback in the destructor, NVL peer barrier before closing IPC handles, IPC-handle close only when `ipc_synced_`, rocSHMEM finalize guarded by `DISABLE_ROCSHMEM`, plus null-pointer checks.
  - Fix for `moe_recv_counter_mapped` in `notify_dispatch` when `num_worst_tokens > 0` (XLA static-shape path): the host skips its CPU sync there, so back-to-back launches would otherwise spin forever on a stale counter (`csrc/kernels/deep_ep/internode.cu`).
- `tools/build_utils.py`:
  - Library search (including optional `libionic`) via `LD_LIBRARY_PATH` / `LIBRARY_PATH` / multiarch defaults when wiring `librocshmem.a`.
- `tests/jax/lax/test_mp_dispatch_combine.py` covering multi-process (1 GPU/proc) dispatch/combine via the `per_process` mode, with a new `JaxMultiProcessTestCase` helper in `tests/jax/test_utils.py`.
- `tests/conftest.py` recognises `JaxMultiProcessTestCase` under `--dist-only`.
- `benchmark/README.md` updated with the multi-node DeepEP launch instructions.

Checklist: