Refactor: moe dispatch combine autotune #312
Pull request overview
Refactors the MoE dispatch/combine implementation to support multi-backend autotuning, and introduces a new ROCm Mori EP backend alongside shared MoE utilities and updated distributed tests.
Changes:
- Added `MoEDispatchCombineAutoTuner` with shape-keyed caching and handle-bound reuse across dispatch/combine.
- Added `MoriEPBackend` and `BackendType.MORI`, plus ROCm-specific dispatch token-count handling.
- Extracted shared helpers into `moe_utils.py` and updated dispatcher tests to use NCCL process groups.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| tests/pytorch/modules/test_token_dispatcher.py | Adds autotune sweep coverage and switches tests to self.pg + NCCL init behavior. |
| primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py | Preserves Mori’s CUDA tokens_per_expert tensor by only wrapping list outputs. |
| primus_turbo/pytorch/kernels/moe/moe_utils.py | New shared utilities for profiling, topology detection, and in-place uniqueness. |
| primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py | Main refactor: autotuner, per-backend configs, Mori backend integration, and tuning logic. |
| primus_turbo/pytorch/core/backend.py | Registers Mori availability and adds BackendType.MORI support in backend manager. |
| @classmethod | ||
| def register_handle(cls, handle: Any, result: EPAutoTuneResult) -> None: | ||
| """Bind ``result`` to ``id(handle)`` for paired-call reuse. | ||
|
|
||
| Args: | ||
| handle: Handle returned by the backend's ``dispatch`` call. | ||
| result: Tune result to associate with the handle. | ||
| """ | ||
| if handle is None: | ||
| return | ||
| hid = id(handle) | ||
| cache = cls._handle_cache | ||
| if hid in cache: | ||
| cache.move_to_end(hid) | ||
| cache[hid] = result | ||
| else: | ||
| cache[hid] = result | ||
| if len(cache) > cls._HANDLE_CACHE_MAX: | ||
| cache.popitem(last=False) | ||
|
|
||
| @classmethod | ||
| def lookup_handle(cls, handle: Any) -> Optional[EPAutoTuneResult]: | ||
| """Return the result bound to ``handle``, or ``None`` if unknown.""" | ||
| if handle is None: | ||
| return None | ||
| res = cls._handle_cache.get(id(handle)) | ||
| if res is not None: | ||
| cls._handle_cache.move_to_end(id(handle)) | ||
| return res |
MoEDispatchCombineAutoTuner keys _handle_cache by id(handle) only. Python can reuse object ids after GC, so a later, unrelated handle could collide and incorrectly reuse a stale tuned result. To make this robust, consider storing a strong reference to the handle in the cache entry and verifying stored_handle is handle on lookup (bounded by the existing LRU), or using a weakref-based approach where the handle type supports it.
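The first option in this comment (keep a strong reference and verify identity on lookup) could look roughly like the sketch below. `HandleCache` and its method names are hypothetical stand-ins, not the PR's actual class; only the `OrderedDict`-based LRU and the `stored_handle is handle` check come from the discussion.

```python
from collections import OrderedDict
from typing import Any, Optional, Tuple


class HandleCache:
    """Hypothetical LRU cache keyed by id(handle) that also verifies identity."""

    def __init__(self, max_entries: int = 1024) -> None:
        self._max_entries = max_entries
        # id(handle) -> (handle, result). Holding the handle itself means its
        # id cannot be handed to another object while the entry is alive.
        self._entries: "OrderedDict[int, Tuple[Any, Any]]" = OrderedDict()

    def register(self, handle: Any, result: Any) -> None:
        if handle is None:
            return
        hid = id(handle)
        self._entries[hid] = (handle, result)
        self._entries.move_to_end(hid)  # mark as most recently used
        if len(self._entries) > self._max_entries:
            self._entries.popitem(last=False)  # evict least recently used

    def lookup(self, handle: Any) -> Optional[Any]:
        if handle is None:
            return None
        hid = id(handle)
        entry = self._entries.get(hid)
        if entry is None:
            return None
        stored_handle, result = entry
        if stored_handle is not handle:
            # Same id but a different object: treat as a miss and drop it.
            del self._entries[hid]
            return None
        self._entries.move_to_end(hid)
        return result
```

The trade-off is that every cached handle stays alive until it is evicted, so the LRU bound also bounds how much handle state the cache can retain.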
| if backend is None: | ||
| targets = list(_buffer_config_per_backend.keys()) or list(_DEFAULT_BUFFER_CONFIG_PER_BACKEND.keys()) | ||
| for name in targets: | ||
| _buffer_config_per_backend[name] = dataclasses.replace(new_cfg) | ||
| else: | ||
| _buffer_config_per_backend[backend] = new_cfg |
set_buffer_global_config(backend=None) only updates _buffer_config_per_backend keys that already exist (derived from _DEFAULT_BUFFER_CONFIG_PER_BACKEND). If a new backend is registered later via register_ep_backend(), it won’t receive the global config unless callers remember to pass backend=<name>. Consider iterating over _BACKEND_REGISTRY.keys() (or taking the union) when backend is None so newly registered backends are included.
| def detect_group_topology(group: dist.ProcessGroup) -> Tuple[int, int]: | ||
| """ | ||
| Infer node topology for a process group. | ||
|
|
||
| Returns: | ||
| node_idx: compact node index within the given group. | ||
| num_nodes: number of distinct nodes spanned by the group. | ||
| """ | ||
|
|
||
| node_token = ( | ||
| os.environ.get("NODE_RANK") | ||
| or os.environ.get("GROUP_RANK") | ||
| or os.environ.get("SLURM_NODEID") | ||
| or socket.gethostname() | ||
| ) | ||
|
|
||
| world = dist.get_world_size(group) | ||
| node_tokens = [None] * world | ||
| dist.all_gather_object(node_tokens, node_token, group=group) | ||
|
|
||
| token_to_idx = {} | ||
| for token in node_tokens: | ||
| if token not in token_to_idx: | ||
| token_to_idx[token] = len(token_to_idx) | ||
|
|
||
| node_idx = token_to_idx[node_token] | ||
| num_nodes = len(token_to_idx) | ||
| return node_idx, num_nodes |
detect_group_topology() performs an all_gather_object every time it’s called. Since group topology is static, calling this from per-forward paths (e.g., Mori init/config resolution) adds an unnecessary synchronous collective overhead. Consider caching (node_idx, num_nodes) per ProcessGroup (e.g., keyed by id(group) with a weakref) to avoid repeated collectives.
| # Parse the profiling table | ||
| assert isinstance(kernel_names, (str, tuple)) | ||
| is_tuple = isinstance(kernel_names, tuple) | ||
| prof_lines = prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=100).split("\n") | ||
| kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names | ||
| assert all([isinstance(name, str) for name in kernel_names]) | ||
| for name in kernel_names: | ||
| assert ( | ||
| sum([name in line for line in prof_lines]) == 1 | ||
| ), f"Errors of the kernel {name} in the profiling table" | ||
|
|
||
| # Save chrome traces | ||
| if trace_path is not None: | ||
| prof.export_chrome_trace(trace_path) | ||
|
|
||
| # Return average kernel durations | ||
| units = {"ms": 1e3, "us": 1e6} | ||
| kernel_durations = [] | ||
| for name in kernel_names: | ||
| for line in prof_lines: | ||
| if name in line: | ||
| time_str = line.split()[-2] | ||
| for unit, scale in units.items(): | ||
| if unit in time_str: | ||
| kernel_durations.append(float(time_str.replace(unit, "")) / scale) | ||
| break | ||
| break |
faa7ffc to a645c55
Stale comment
Security review (automation)
Reviewed the diff for injection, authn/authz, permission-boundary, secret-leakage, SSRF/XSS/request-forgery, path-traversal, unsafe-deserialization, and supply-chain risks.
Scope
- `primus_turbo/pytorch/core/backend.py` (optional `mori` import + enum)
- `primus_turbo/pytorch/kernels/moe/moe_utils.py` (profiling/topology/unique helpers)
- `primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py` (autotuner + Mori backend + Triton histogram kernel)
- `primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py`
- `tests/pytorch/modules/test_token_dispatcher.py`
Key checks
- Env vars (`PRIMUS_TURBO_MORI_NUM_QP_PER_PE`, `NODE_RANK`/`GROUP_RANK`/`SLURM_NODEID`) feed only `int()` casts and rank-id comparisons, not shell, path, SQL, or deserialization sinks.
- `tempfile.NamedTemporaryFile(suffix=".json")` uses secure defaults; the JSON is written by `prof.export_chrome_trace` in the same process (no attacker input).
- `suppress_stdout_stderr` manipulates only its own file descriptors.
- `dist.all_gather_object` in `detect_group_topology` pickles over the internal torch.distributed process group (trusted channel).
- Triton `compute_expert_token_info_kernel` clamps OOB rows/cols with `tl.where` + `load_mask`/`flat_mask` and clamps invalid/out-of-range expert ids before `tl.histogram`/`tl.atomic_add`, preventing OOB reads/writes and illegal bin counts.
- `inplace_unique` pads negatives to `num_slots` and drops the padding bin, keeping `scatter_add_` in-range.
- `MoEDispatchCombineAutoTuner._handle_cache` is LRU-bounded (`_HANDLE_CACHE_MAX=1024`); shape cache uses `TuneCache(capacity=1024)`.
- `torch._C._distributed_c10d._register_process_group("mori", group)` receives only internally constructed arguments.
- The new `mori` dependency is an optional import (guarded by `try`/`except ImportError`); no code runs unless installed.
Result
No medium+ severity security issue introduced or exposed by this PR. No new inline findings.
Sent by Cursor Automation: Find vulnerabilities
Stale comment
Security review (automation)
Re-validated the diff for the latest push (`ed3ecca`) covering autotuner, Mori backend, Triton histogram kernel, and shared MoE utilities.
Scope
- `primus_turbo/pytorch/core/backend.py` (optional `mori` import + `BackendType.MORI`)
- `primus_turbo/pytorch/kernels/moe/moe_utils.py` (profiling / topology / unique helpers)
- `primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py` (autotuner + Mori backend + Triton histogram kernel)
- `primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py`
- `tests/pytorch/modules/test_token_dispatcher.py`
Key checks
- Env vars (`PRIMUS_TURBO_MORI_NUM_QP_PER_PE`, `NODE_RANK`/`GROUP_RANK`/`SLURM_NODEID`) feed only `int()` casts and rank-id comparisons; no shell, path, SQL, or deserialization sinks.
- `tempfile.NamedTemporaryFile(suffix=".json")` uses secure defaults; content is written by the in-process `prof.export_chrome_trace` (no attacker input).
- `suppress_stdout_stderr` manipulates only its own file descriptors.
- `dist.all_gather_object` in `detect_group_topology` pickles over an internal `torch.distributed` group (trusted training peers only).
- Triton `compute_expert_token_info_kernel` clamps OOB rows/cols via `tl.where` + `load_mask`/`flat_mask` and clamps invalid/out-of-range expert ids before `tl.histogram`/`tl.atomic_add`; no OOB reads/writes or illegal bin counts.
- `inplace_unique` pads negatives into `num_slots` and drops the padding bin, keeping `scatter_add_` in-range.
- `MoEDispatchCombineAutoTuner._handle_cache` is LRU-bounded (`_HANDLE_CACHE_MAX=1024`) and now identity-checks the stored handle (`stored_handle is not handle`) to guard against `id()` reuse; shape cache uses `TuneCache(capacity=1024)`.
- `torch._C._distributed_c10d._register_process_group("mori", group)` receives only internally constructed arguments.
- The new `mori` dependency is an optional import (guarded by `try`/`except ImportError`); no code runs unless installed.
Result
No medium+ severity security issue introduced or exposed by this PR. No inline findings.
Sent by Cursor Automation: Find vulnerabilities
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| self._bind_device() | ||
| envs = {k: v for k, v in os.environ.items()} | ||
| envs.pop("PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND", None) | ||
| envs["PRIMUS_TURBO_AUTO_TUNE"] = "1" | ||
|
|
||
| with patch.dict(os.environ, envs, clear=True): | ||
| MoEDispatchCombineAutoTuner.clear() | ||
|
|
||
| # First run triggers autotune and registers the handle mapping. | ||
| _run_dispatch_combine(self.rank, self.pg) | ||
| first = MoEDispatchCombineAutoTuner.current() | ||
| assert first is not None, "autotune should have populated a result" | ||
| assert ( | ||
| len(MoEDispatchCombineAutoTuner._handle_cache) > 0 | ||
| ), "moe_dispatch should have bound the tuned result to at least one handle" | ||
|
|
||
| # Second run hits the shape cache (no extra measurements) and | ||
| # yields the same backend selection. | ||
| _run_dispatch_combine(self.rank, self.pg) | ||
| second = MoEDispatchCombineAutoTuner.current() | ||
| assert second is first or second.backend_name == first.backend_name | ||
|
|
||
|
|
test_autotune_sweep enables PRIMUS_TURBO_AUTO_TUNE=1 and then runs two full dispatch+combine passes at the default test shape (4096×4096 with topk=8). With the new autotuner this triggers a full config sweep using torch.profiler/Kineto inside the test, which is likely to be very slow and can make CI flaky/time out. Please consider reducing the problem size for this test (smaller NUM_TOKENS/HIDDEN_SIZE), lowering the sweep iteration counts for test mode, or marking/gating the test as “slow” so it doesn’t run in the default unit test shard.
| self._bind_device() | |
| envs = {k: v for k, v in os.environ.items()} | |
| envs.pop("PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND", None) | |
| envs["PRIMUS_TURBO_AUTO_TUNE"] = "1" | |
| with patch.dict(os.environ, envs, clear=True): | |
| MoEDispatchCombineAutoTuner.clear() | |
| # First run triggers autotune and registers the handle mapping. | |
| _run_dispatch_combine(self.rank, self.pg) | |
| first = MoEDispatchCombineAutoTuner.current() | |
| assert first is not None, "autotune should have populated a result" | |
| assert ( | |
| len(MoEDispatchCombineAutoTuner._handle_cache) > 0 | |
| ), "moe_dispatch should have bound the tuned result to at least one handle" | |
| # Second run hits the shape cache (no extra measurements) and | |
| # yields the same backend selection. | |
| _run_dispatch_combine(self.rank, self.pg) | |
| second = MoEDispatchCombineAutoTuner.current() | |
| assert second is first or second.backend_name == first.backend_name | |
| global NUM_TOKENS, HIDDEN_SIZE | |
| self._bind_device() | |
| envs = {k: v for k, v in os.environ.items()} | |
| envs.pop("PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND", None) | |
| envs["PRIMUS_TURBO_AUTO_TUNE"] = "1" | |
| old_num_tokens = NUM_TOKENS | |
| old_hidden_size = HIDDEN_SIZE | |
| with patch.dict(os.environ, envs, clear=True): | |
| try: | |
| # Keep the autotune path covered, but shrink the problem size so | |
| # the profiler-backed config sweep remains practical in CI. | |
| NUM_TOKENS = 512 | |
| HIDDEN_SIZE = 512 | |
| MoEDispatchCombineAutoTuner.clear() | |
| # First run triggers autotune and registers the handle mapping. | |
| _run_dispatch_combine(self.rank, self.pg) | |
| first = MoEDispatchCombineAutoTuner.current() | |
| assert first is not None, "autotune should have populated a result" | |
| assert ( | |
| len(MoEDispatchCombineAutoTuner._handle_cache) > 0 | |
| ), "moe_dispatch should have bound the tuned result to at least one handle" | |
| # Second run hits the shape cache (no extra measurements) and | |
| # yields the same backend selection. | |
| _run_dispatch_combine(self.rank, self.pg) | |
| second = MoEDispatchCombineAutoTuner.current() | |
| assert second is first or second.backend_name == first.backend_name | |
| finally: | |
| NUM_TOKENS = old_num_tokens | |
| HIDDEN_SIZE = old_hidden_size |
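As an alternative to shrinking the shape, the comment's other option (gating the sweep as "slow") could be sketched with the standard library alone. The `RUN_SLOW_TESTS` variable and the test class below are hypothetical; the real suite may use a different base class or a pytest marker instead.

```python
import os
import unittest

# Hypothetical opt-in switch; unset or "0" keeps slow tests out of the default run.
RUN_SLOW = os.environ.get("RUN_SLOW_TESTS", "0") == "1"


class TokenDispatcherTests(unittest.TestCase):
    @unittest.skipUnless(RUN_SLOW, "autotune sweep is slow; set RUN_SLOW_TESTS=1 to run")
    def test_autotune_sweep(self) -> None:
        # Placeholder for the profiler-backed sweep exercised in the PR's test.
        self.assertTrue(True)

    def test_fast_path(self) -> None:
        # Placeholder for a test that always runs in the default shard.
        self.assertTrue(True)
```

With this shape, the default CI shard skips the sweep and a dedicated job can export `RUN_SLOW_TESTS=1` to cover it.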
| def _get_backend_name() -> str: | ||
| """Return the user-selected backend name, or ``TURBO`` by default.""" | ||
| bt = GlobalBackendManager.get_moe_dispatch_combine_backend(PrecisionType.BF16_FP16_FP32) | ||
| return bt.name if bt is not None else _DEFAULT_BACKEND_NAME |
_get_backend_name() now only consults GlobalBackendManager.get_moe_dispatch_combine_backend(). That helper intentionally returns None when PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND is set to a custom (non-BackendType) backend name, so this change forces a silent fallback to TURBO and breaks the documented “custom names like UCCL_EP” behavior. Please restore the previous behavior by reading ENV_MOE_DISPATCH_COMBINE_BACKEND directly when GlobalBackendManager returns None, and using the raw env value (uppercased/stripped) as the backend registry lookup key.
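The requested fallback can be sketched as follows. The `BackendType` subset, `builtin_backend_from_env`, and `get_backend_name` are simplified, hypothetical stand-ins for the PR's enum, `GlobalBackendManager` helper, and `_get_backend_name()`; only the env-var name and the "uppercased/stripped raw value" rule come from the comment.

```python
import os
from enum import Enum, auto
from typing import Optional

ENV_MOE_DISPATCH_COMBINE_BACKEND = "PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND"
DEFAULT_BACKEND_NAME = "TURBO"


class BackendType(Enum):
    """Hypothetical subset of the built-in backend enum."""
    TURBO = auto()
    DEEP_EP = auto()


def builtin_backend_from_env() -> Optional[BackendType]:
    """Return the enum member named by the env var, or None for custom names."""
    raw = os.environ.get(ENV_MOE_DISPATCH_COMBINE_BACKEND, "").strip().upper()
    return BackendType.__members__.get(raw)


def get_backend_name() -> str:
    bt = builtin_backend_from_env()
    if bt is not None:
        return bt.name
    # Not a built-in backend: fall back to the raw env value so custom
    # registry keys such as UCCL_EP still resolve.
    raw = os.environ.get(ENV_MOE_DISPATCH_COMBINE_BACKEND, "").strip().upper()
    return raw or DEFAULT_BACKEND_NAME
```

The returned string is then only a registry lookup key, so an unknown custom name should still fail loudly at lookup time rather than silently selecting `TURBO`.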
| element_size = 1 if fp8_dispatch else 2 | ||
| hidden_bytes = hidden_size * max(element_size, 2) | ||
|
|
_DeepEPLikeBackend.init_buffer() computes hidden_bytes as hidden_size * 2 for all non-FP8 inputs (element_size is hard-coded to 2). This can under-allocate buffers when x is float32 (4 bytes/elem) and conflicts with the previous logic that sized by max(inp.element_size(), 2). Either (a) size by the actual input element size (pass dtype/element_size into init_buffer), or (b) explicitly restrict dispatch/combine to BF16/FP16 and raise a clear error for FP32 to avoid silent buffer overruns.
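Both options from this comment are small enough to sketch. The function names below are hypothetical; the `max(element_size, 2)` rule is taken from the previous logic the comment refers to.

```python
def hidden_bytes_for(hidden_size: int, element_size: int) -> int:
    """Option (a): size the buffer by the actual element size, with a 2-byte floor."""
    if element_size <= 0:
        raise ValueError(f"element_size must be positive, got {element_size}")
    return hidden_size * max(element_size, 2)


def hidden_bytes_strict(hidden_size: int, element_size: int) -> int:
    """Option (b): accept only 1-byte (FP8) or 2-byte (BF16/FP16) inputs."""
    if element_size not in (1, 2):
        raise ValueError(
            f"dispatch/combine supports FP8/BF16/FP16 only; got element_size={element_size}"
        )
    return hidden_size * 2
```

For a float32 input with `hidden_size = 4096`, option (a) yields 16384 bytes where the hard-coded version yields 8192, which is the under-allocation described above.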
| if handle is None: | ||
| return | ||
| hid = id(handle) | ||
| cache = cls._handle_cache | ||
| if hid in cache: | ||
| cache.move_to_end(hid) | ||
| cache[hid] = (handle, result) | ||
| if len(cache) > cls._HANDLE_CACHE_MAX: | ||
| cache.popitem(last=False) |
MoEDispatchCombineAutoTuner._handle_cache stores strong references to the full handle object (cache[hid] = (handle, result)). If a backend’s handle captures large tensors or GPU resources (e.g., MoriEPBackend’s handle currently includes recv_topk_idx_i32), this cache can retain significant memory until eviction, which is risky in long-running training. Consider storing only the tuned result keyed by id(handle) (accepting the small risk of id reuse), or redesigning handles to be lightweight/opaque IDs so the cache doesn’t keep heavyweight objects alive.
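A third possibility, sketched here under assumptions, is to hold the handle through a weak reference so the cache never extends its lifetime while still guarding against `id()` reuse. `WeakHandleCache` is a hypothetical class, and it only works if the handle type supports weak references: plain tuples do not, so a tuple-shaped handle would first need wrapping in a small class.

```python
import weakref
from typing import Any, Dict, Optional, Tuple


class WeakHandleCache:
    """Hypothetical cache mapping id(handle) to a result without owning the handle."""

    def __init__(self) -> None:
        # id(handle) -> (weak reference to handle, result)
        self._entries: Dict[int, Tuple[weakref.ref, Any]] = {}

    def register(self, handle: Any, result: Any) -> None:
        if handle is None:
            return
        hid = id(handle)
        # The callback fires when the handle is collected and removes the entry.
        ref = weakref.ref(handle, lambda dead, hid=hid: self._discard(hid, dead))
        self._entries[hid] = (ref, result)

    def _discard(self, hid: int, dead: weakref.ref) -> None:
        entry = self._entries.get(hid)
        # Only remove the entry if it still belongs to the collected handle.
        if entry is not None and entry[0] is dead:
            del self._entries[hid]

    def lookup(self, handle: Any) -> Optional[Any]:
        if handle is None:
            return None
        entry = self._entries.get(id(handle))
        if entry is None:
            return None
        ref, result = entry
        return result if ref() is handle else None

    def __len__(self) -> int:
        return len(self._entries)
```

Entries disappear as soon as their handle is collected, so no LRU bound is needed for memory safety, although one could still be added to cap the dictionary size.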
Stale comment
Security review (automation)
Re-validated after the latest push (`8bd040b`). The incremental change since the last assessment is a mechanical rename (`PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND` → `PRIMUS_TURBO_EP_BACKEND`, `get_moe_dispatch_combine_backend` → `get_ep_backend`) and the removal of eager `deep_ep`/`mori` imports in `primus_turbo/pytorch/core/backend.py` (both are now lazily loaded inside the EP backend registry). No trust boundary, permission check, or supply-chain surface changes.
Scope
- `primus_turbo/common/constants.py`, `docs/examples.md`, `tests/pytorch/core/test_global_backend_manager.py`, `tests/pytorch/modules/test_token_dispatcher.py` (env-var rename)
- `primus_turbo/pytorch/core/backend.py` (dropped eager `HAVE_DEEP_EP`/`HAVE_MORI` guards; lazy imports remain inside `MoriEPBackend.is_available()`/`_get_module()`)
- `primus_turbo/pytorch/kernels/moe/moe_utils.py` (profiling / topology / unique helpers)
- `primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py` (autotuner + Mori backend + Triton histogram kernel)
- `primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py`
Key checks (still green)
- Env vars (`PRIMUS_TURBO_MORI_NUM_QP_PER_PE`, `NODE_RANK`/`GROUP_RANK`/`SLURM_NODEID`, `PRIMUS_TURBO_EP_BACKEND`) feed only `int()` casts, bounded registry lookups (`_BACKEND_REGISTRY`), and rank/host token equality; no shell, path, SQL, or deserialization sinks.
- `tempfile.NamedTemporaryFile(suffix=".json")` uses secure defaults; content produced by the in-process `prof.export_chrome_trace` (no attacker input).
- `suppress_stdout_stderr` manipulates only its own file descriptors.
- `dist.all_gather_object` in `detect_group_topology` runs over the trusted `torch.distributed` group.
- Triton `compute_expert_token_info_kernel` clamps OOB rows/cols via `tl.where` + `load_mask`/`flat_mask` and clamps invalid/out-of-range expert ids before `tl.histogram`/`tl.atomic_add`; no OOB reads/writes or illegal bins.
- `inplace_unique` pads negatives into `num_slots` and drops the padding bin, keeping `scatter_add_` in-range.
- `MoEDispatchCombineAutoTuner._handle_cache` is LRU-bounded (`_HANDLE_CACHE_MAX=1024`) and identity-checks the stored handle (`stored_handle is not handle`) to defend against `id()` reuse; shape cache uses `TuneCache(capacity=1024)`.
- `torch._C._distributed_c10d._register_process_group("mori", group)` receives only internally constructed arguments.
- `mori`/`deep_ep` remain optional dependencies (guarded by `try`/`except ImportError` in each backend's `is_available()`/`_get_module()`).
- No secret, token, payload, or model-input values are logged.
Result
No medium+ severity security issue introduced or exposed by this PR. No inline findings.
Sent by Cursor Automation: Find vulnerabilities
Pull request overview
Copilot reviewed 60 out of 68 changed files in this pull request and generated 4 comments.
| # loosen timeout for this test to avoid timeout failures | ||
| common_distributed.TIMEOUT_DEFAULT = 600 | ||
|
|
| // Controlled by ``PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM`` (default ``0``). | ||
| // When true, all dispatch/combine kernels run on the caller's current CUDA | ||
| // stream instead of ``comm_stream``. | ||
| bool force_current_stream = true; |
| batch_size, seq_len_q, num_heads_q, _ = q.shape | ||
| _, _, _, head_dim_v = v.shape | ||
| out = torch.empty((batch_size, seq_len_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device) | ||
| softmax_lse = torch.empty((batch_size, num_heads_q, seq_len_q), dtype=torch.float32, device=q.device) | ||
| S_dmask = torch.empty((0,), dtype=q.dtype, device=q.device) | ||
| rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) | ||
| return out, softmax_lse, S_dmask, rng_state |
| if not deepep_use_comm_stream and os.environ.get(ENV_EP_FORCE_CURRENT_STREAM) != "1": | ||
| clear_backend_instances() | ||
| os.environ[ENV_EP_FORCE_CURRENT_STREAM] = "1" | ||
|
|
Stale comment
Security review (automation)
Re-validated after the latest push (`b96afca`). The incremental changes since the last assessment are:
1. CK GEMM/Grouped GEMM directory restructure (file moves only).
2. Numerical correctness fixes in `csrc/kernels/quantization/quantization_mxfp8.cu` (cast to `int64_t` to avoid int32 offset overflow), `csrc/kernels/reduce/reduce_row.cuh` (partial-tile indexing), and `csrc/kernels/gemm/turbo/turbo_gemm_mxfp8_kernel.h` (`__builtin_amdgcn_s_barrier()` WAR fence).
3. Attention BHSD layout support in `attention_aiter_impl.py` / `flash_attn_interface.py` / `flash_attn_usp_interface.py` / `attention_ring.py`, with `torch.library.custom_op` wrappers and matching `register_fake` shape implementations.
4. A new `deepep_use_comm_stream=False` knob in `DeepEPTokenDispatcher` that calls `clear_backend_instances()` and sets the constant `PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM` env var to `"1"`.
5. A `_format_kwargs(...)` helper that renders tensor `shape`/`dtype` (not data) and `repr()` of scalars/enums into backend mismatch error messages.
6. The `nits` commit dropping `dispatch_config`/`combine_config` from a debug log line.
Scope
- `csrc/kernels/quantization/quantization_mxfp8.cu`, `csrc/kernels/reduce/reduce_row.cuh`, `csrc/kernels/gemm/turbo/turbo_gemm_mxfp8_kernel.h`, `csrc/include/primus_turbo/deep_ep/configs.h`, `csrc/pytorch/deep_ep/deep_ep.hpp`
- `primus_turbo/common/constants.py` (new `ENV_EP_FORCE_CURRENT_STREAM`)
- `primus_turbo/pytorch/core/backend.py` (`_format_kwargs`)
- `primus_turbo/pytorch/deep_ep/buffer.py`, `primus_turbo/pytorch/modules/moe/token_dispatcher.py` (env-var toggle, `clear_backend_instances()`)
- `primus_turbo/pytorch/kernels/attention/attention_aiter_impl.py`, `primus_turbo/pytorch/ops/attention/{attention_utils.py,flash_attn_interface.py,flash_attn_usp_interface.py,usp/attention_ring.py}`
- `primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py` (added `clear_backend_instances`; nits-only logger change)
- `tests/pytorch/{ops/test_attention.py,ops/test_attention_with_cp.py,ops/test_quantization.py,modules/test_token_dispatcher.py}`
Key checks (still green)
- `_format_kwargs` only emits `Tensor(shape=..., dtype=...)` for tensors and `repr()` for scalar/enum kwargs; no tensor contents, weights, gradients, or attacker payload values are logged. Error messages stay free of secrets/PII.
- `os.environ[ENV_EP_FORCE_CURRENT_STREAM] = "1"` writes a hard-coded constant; no user-controlled string flows into `os.environ` or any shell. The C++ side reads the same env var via `std::getenv` and passes it through `std::stoi`, with the value bounded to a `bool`.
- `clear_backend_instances()` only mutates the module-private `_backend_instances` dict; no FS, network, or trust-boundary surface.
- `torch.library.custom_op(..., mutates_args=("dq","dk","dv"), device_types="cuda")` wrappers receive only Tensors / scalars from in-process callers; `register_fake` impls allocate empty tensors and don't dereference user data.
- BHSD layout handling is restricted to `qkv_format in ("bshd","sbhd","bhsd")` with explicit `assert`/`raise ValueError` on unknown values; no path/format string is interpolated into a sink.
- Sink-attention fast-path constraints (`head_dim_qk == head_dim_v`, power-of-2) moved into `can_handle()`: same-or-stronger validation as before.
- mxfp8 `int64_t` cast and reduce_row partial-tile fix close potential signed-int32 offset overflows that could have produced OOB stores on extreme shapes; this is a security/safety improvement, not a regression.
- `MoEDispatchCombineAutoTuner._handle_cache` remains LRU-bounded (`_HANDLE_CACHE_MAX=1024`) with `id()` reuse defended by the `stored_handle is not handle` identity check; `TuneCache(capacity=1024)` shape cache unchanged.
- `mori`/`deep_ep` remain optional, lazily imported inside each backend's `is_available()`/`_get_module()`.
- No new env vars feed into shell/path/SQL/deserialization sinks; `PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM` and other env reads (`int()` casts, registry lookups, rank/host token equality) stay numeric/bounded.
Result
No medium+ severity security issue introduced or exposed by this PR. No inline findings.
Sent by Cursor Automation: Find vulnerabilities
Stale comment
Security review (automation)
Re-validated after the latest push (`aab8575`). Incremental delta since the last assessment (`b96afca`):
- `primus_turbo/pytorch/core/backend.py`: adds `BackendType.UCCL` enum entry only.
- `primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py`: bumps `EPBufferConfig.num_sms` default from `32` to `64`, adds a `"UCCL_EP"` entry to `_DEFAULT_BUFFER_CONFIG_PER_BACKEND`, and adds a new `UCCLEPBackend(_DeepEPLikeBackend)` whose `is_available()`/`_get_module()` lazily import `uccl.ep`. The new class is not wired into `_BACKEND_REGISTRY`, so it is only reachable if a caller invokes `register_ep_backend("UCCL_EP", UCCLEPBackend)` explicitly.
- `setup.py`: adds a mandatory `install_requires` entry `uccl @ git+https://github.com/uccl-project/uccl.git@9bae94c59229f82efbbd5b78a7f222ce19e74e86`.
Key checks
- Supply chain: the new `uccl` dependency is fetched from the upstream `uccl-project/uccl` repo (already referenced by the project's existing DeepEP fallback documentation) and is pinned to a fixed commit SHA, so installs cannot silently move to a different revision via tag/branch retargeting. No version-range or floating-ref injection risk introduced. The constant `UCCL_COMMIT` is a hard-coded literal; no env-var or user input flows into the install URL.
- `UCCLEPBackend._get_module()` only does a lazy `import uccl.ep`; no attacker-controlled string is passed to `import` and the import target is fixed.
- `_make_buffer_kwargs()` introspects the upstream `Buffer` constructor with `inspect.signature`; only the boolean `is_intranode` (derived from `group.size()`) is forwarded, so no untrusted kwargs propagate.
- `EPBufferConfig` default change (`num_sms` 32 → 64) and new `UCCL_EP` defaults are pure numeric tuning constants; no trust-boundary impact.
- Previously validated surfaces (Mori SHMEM init, autotuner LRU bounds + identity-checked handle cache, `compute_expert_token_info_kernel` OOB clamps, `inplace_unique` bin-padding, `_format_kwargs` shape/dtype-only logging, `PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM` env-var path, mxfp8 `int64_t` offset cast, BHSD layout `qkv_format` allowlist, `register_fake` stubs, custom-op wrappers, `attention_utils._infer_qkv_format`) remain unchanged in this delta and still hold.
- No new env-var / file-path / shell / SQL / deserialization / network sinks; no secret or token logging; no permission-check or authn change.
Result
No medium+ severity security issue introduced or exposed by this PR. No inline findings.
Sent by Cursor Automation: Find vulnerabilities
Stale comment
Security review (automation)
Re-validated after the latest push (`c4906b5`). Incremental delta since the last assessment (`aab8575`):
- `primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py`: a single-line registry entry (`"UCCL": UCCLEPBackend`) added to `_BACKEND_REGISTRY`. The `UCCLEPBackend` class itself (added in the prior commit) follows the same optional-import + `_DeepEPLikeBackend` pattern already used by `DeepEPBackend` (lazy `import uccl.ep`, no eager dependency, identical buffer-kwargs introspection).
Scope re-checked
- `primus_turbo/pytorch/core/backend.py` (`BackendType.UCCL` enum + `get_ep_backend` env parsing, returns `None` for unknown enum values).
- `primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py` (autotuner, EP backends incl. UCCL, Mori, DeepEP, Triton histogram kernel).
- `primus_turbo/pytorch/kernels/moe/moe_utils.py` (`bench_kineto`, `detect_group_topology`, `inplace_unique`).
- `primus_turbo/pytorch/modules/moe/token_dispatcher.py` (`PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM` env mutation).
- `primus_turbo/pytorch/deep_ep/buffer.py` and `csrc/pytorch/deep_ep/deep_ep.hpp` (force-current-stream toggle).
- `setup.py` (new `uccl` install requirement pinned to commit `9bae94c5...`).
- `csrc/kernels/...` (CK GEMM/Grouped-GEMM directory restructure + numerical-correctness fixes; no new I/O sinks).
- Attention/USP refactor and tests.
Threat checks performed
- Injection: env values (`PRIMUS_TURBO_EP_BACKEND`, `PRIMUS_TURBO_MORI_NUM_QP_PER_PE`, `ENV_EP_FORCE_CURRENT_STREAM`) are parsed via `BackendType[...]`/`int(...)` with defensive fallbacks; no shell or SQL sinks. The Mori SHMEM process-group name (`"mori"`) is a hard-coded constant.
- Authn/authz: no auth surfaces touched.
- Path traversal / unsafe deserialization: `bench_kineto` writes to a `tempfile.NamedTemporaryFile` (random path) and `json.loads` only the file it just wrote; no attacker-controlled path or payload.
- SSRF / request forgery / XSS: no network or HTML sinks introduced.
- Secret leakage / insecure logging: log lines emit shapes and tuning metadata only; no credentials or tokens.
- Supply chain: the new `uccl` dependency in `setup.py` is pinned to a specific commit hash on `github.com/uccl-project/uccl`, which mitigates floating-tag tampering risk and matches existing patterns (e.g. pinned `AITER_COMMIT`). `UCCLEPBackend` only imports `uccl.ep` lazily and is selected via the existing EP backend registry.
- Process-wide env mutation in `DeepEPTokenDispatcher.__init__` (`os.environ[ENV_EP_FORCE_CURRENT_STREAM]="1"` + `clear_backend_instances()`): a usability concern previously raised by Copilot, but not a security boundary issue (caller already controls the process), so out of scope here.
Conclusion: no medium-or-higher security findings. No new finding comments posted.
Sent by Cursor Automation: Find vulnerabilities
Pull request overview
Copilot reviewed 61 out of 69 changed files in this pull request and generated 9 comments.
| # Per-backend default buffer configuration. Keys must match the names used | ||
| # in ``_BACKEND_REGISTRY`` so lookups can be done by backend name. | ||
| _DEFAULT_BUFFER_CONFIG_PER_BACKEND: Dict[str, EPBufferConfig] = { | ||
| "TURBO": EPBufferConfig( | ||
| num_sms=64, | ||
| dispatch_config=None, | ||
| combine_config=None, | ||
| ), | ||
| "DEEP_EP": EPBufferConfig( | ||
| num_sms=64, | ||
| dispatch_config=None, | ||
| combine_config=None, | ||
| ), | ||
| "UCCL_EP": EPBufferConfig( | ||
| num_sms=64, | ||
| dispatch_config=None, | ||
| combine_config=None, | ||
| ), | ||
| "MORI": EPBufferConfig(num_sms=64, dispatch_config=None, combine_config=None), | ||
| } |
| dispatch_config = config.dispatch_config or BufferClass.get_dispatch_config(group.size()) | ||
| combine_config = config.combine_config or BufferClass.get_combine_config(group.size()) | ||
|
|
||
| element_size = 1 if fp8_dispatch else 2 | ||
| hidden_bytes = hidden_size * max(element_size, 2) | ||
|
|
| def _get_backend_name() -> str: | ||
| """Return the user-selected backend name, or ``TURBO`` by default.""" | ||
| bt = GlobalBackendManager.get_ep_backend(PrecisionType.BF16_FP16_FP32) | ||
| return bt.name if bt is not None else _DEFAULT_BACKEND_NAME |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| batch_size, seq_len_q, num_heads_q, _ = q.shape | ||
| _, _, _, head_dim_v = v.shape | ||
| out = torch.empty((batch_size, seq_len_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device) | ||
| softmax_lse = torch.empty((batch_size, num_heads_q, seq_len_q), dtype=torch.float32, device=q.device) | ||
| S_dmask = torch.empty((0,), dtype=q.dtype, device=q.device) | ||
| rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) | ||
| return out, softmax_lse, S_dmask, rng_state |
| # loosen timeout for this test to avoid timeout failures | ||
| common_distributed.TIMEOUT_DEFAULT = 600 |
| # Entry points and Install Requires | ||
| entry_points = {} | ||
| install_requires = [] | ||
| install_requires = [f"uccl @ git+https://github.com/uccl-project/uccl.git@{UCCL_COMMIT}"] | ||
|
|
| // Controlled by ``PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM`` (default ``0``). | ||
| // When true, all dispatch/combine kernels run on the caller's current CUDA | ||
| // stream instead of ``comm_stream``. | ||
| bool force_current_stream = true; |
| if not deepep_use_comm_stream and os.environ.get(ENV_EP_FORCE_CURRENT_STREAM) != "1": | ||
| clear_backend_instances() | ||
| os.environ[ENV_EP_FORCE_CURRENT_STREAM] = "1" |
| def test_ep_backend_env(self, monkeypatch): | ||
| monkeypatch.setenv("PRIMUS_TURBO_EP_BACKEND", "triton") | ||
| assert GlobalBackendManager.get_ep_backend(PrecisionType.FP8) == BackendType.TRITON |
c4906b5 to
f735ba0
Compare
Stale comment
Security review (automation)
Re-validated after the latest push (`5ba14d2`). Scope of the incremental diff (vs. `6650792`):
- `setup.py`: pins `uccl` to commit `9bae94c5…` via `git+https://github.com/uccl-project/uccl.git@<sha>`.
- `primus_turbo/common/constants.py` and `primus_turbo/pytorch/core/backend.py`: rename `PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND` → `PRIMUS_TURBO_EP_BACKEND`; add `BackendType.MORI`, `BackendType.UCCL`, and `ENV_MORI_NUM_QP_PER_PE`.
- `primus_turbo/pytorch/kernels/moe/_backend/` (new): `EPBackend` protocol, shared `_DeepEPLikeBackend`, plus `TurboEPBackend`, `DeepEPBackend`, `UCCLEPBackend`, `MoriEPBackend`. Includes `_apply_env_with_nccl_fallback` (copies `NCCL_*` → `MORI_*`/`UCCL_*` when unset), `_align_mori_hip_with_torch` (`ctypes.CDLL(<torch lib dir>/libamdhip64.so)`), `_register_and_init_mori_shmem`, and an autotune sweep harness.
- `primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py`: per-backend default `EPBufferConfig` table, shape-keyed `MoEDispatchCombineAutoTuner`, handle-id binding cache, CUDA-graph-capture warning hook.
- `primus_turbo/pytorch/kernels/moe/moe_utils.py` (new): kineto bench helper, `detect_group_topology` (uses `dist.all_gather_object` over an env-or-`gethostname()` token), `inplace_unique`, `suppress_stdout_stderr`.
- `token_dispatcher.py` / `moe_dispatch_combine.py`: tensor-vs-list guard and CUDA-graph-capture-safe path.
- Tests + `docs/examples.md`: rename / parametrization updates only.

Findings
No new medium or higher severity vulnerabilities are introduced. Notes on items examined and judged not to rise to that bar:
- `setup.py` adds a Git-URL `install_requires` for `uccl` but pins to a specific commit SHA, so it is not an unpinned/floating dependency.
- All env-var reads (`ENV_EP_BACKEND`, `ENV_MORI_NUM_QP_PER_PE`, `NCCL_*` → `MORI_*`/`UCCL_*`) are constrained: backend name is dereferenced through an in-process registry (`_BACKEND_REGISTRY[name]`, raises on unknown), `num_qp_per_pe` is coerced via `int(...)`, and the NCCL→MORI/UCCL fallback only copies env values that are already present in the same process's environment to a different env-var name. No shell, no eval, no subprocess.
- `_align_mori_hip_with_torch` derives the `.so` path from `os.path.dirname(torch.__file__)`; not attacker-controlled.
- `compute_expert_token_info_kernel` (Triton) guards loads/stores with explicit `mask=` operands and clamps OOB rows/experts before address arithmetic.
- `detect_group_topology` uses `dist.all_gather_object` (pickle under the hood) over a token sourced from `NODE_RANK` / `GROUP_RANK` / `SLURM_NODEID` / `socket.gethostname()`. Peers in an EP process group share the same trust boundary as the training job, so this is not a new attack surface relative to existing `torch.distributed` idioms in this codebase.
- `bench_kineto` writes traces through `tempfile.NamedTemporaryFile`; no untrusted-path / TOCTOU exposure.

Prior threads from earlier automation runs
Prior security-review summaries from this automation all reached the same "no findings" verdict. They are being collapsed as outdated; this comment supersedes them.
Sent by Cursor Automation: Find vulnerabilities
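For readers following the env-var discussion, the "copy `NCCL_*` to `MORI_*`/`UCCL_*` only when unset" behaviour described above could look roughly like the sketch below. This is not the PR's `_apply_env_with_nccl_fallback`; the function name, its parameters, and the `MORI_SOCKET_IFNAME`-style target names are assumptions made for illustration. Only the copy-if-unset rule comes from the review text.

```python
import os
from typing import Dict, Iterable, MutableMapping, Optional


def apply_env_with_nccl_fallback(
    target_prefix: str,
    suffixes: Iterable[str],
    env: Optional[MutableMapping[str, str]] = None,
) -> Dict[str, str]:
    """Copy NCCL_<suffix> to <target_prefix>_<suffix> when the target is unset.

    Hypothetical helper: the real function's signature is not shown in the PR.
    Returns the mapping of variables that were actually copied.
    """
    env = os.environ if env is None else env
    copied: Dict[str, str] = {}
    for suffix in suffixes:
        src = f"NCCL_{suffix}"
        dst = f"{target_prefix}_{suffix}"
        # An explicit backend-specific setting always wins, and a missing
        # NCCL value gives us nothing to copy.
        if dst in env or src not in env:
            continue
        env[dst] = env[src]
        copied[dst] = env[src]
    return copied


# Demo on a plain dict so the real process environment is left untouched.
demo_env = {
    "NCCL_SOCKET_IFNAME": "eth0",
    "NCCL_IB_HCA": "mlx5_0",
    "MORI_IB_HCA": "mlx5_1",  # already set, so it must not be overwritten
}
demo_copied = apply_env_with_nccl_fallback(
    "MORI", ["SOCKET_IFNAME", "IB_HCA", "IB_GID_INDEX"], demo_env
)
```

Values only move between names inside the same process environment, which is why the review treats this path as having no injection surface.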
| _cache: TuneCache = TuneCache(capacity=1024) | ||
|
|
||
| _HANDLE_CACHE_MAX: int = 1024 | ||
| _handle_cache: "OrderedDict[int, Tuple[Any, EPAutoTuneResult]]" = OrderedDict() | ||
| assert topk_idx is not None | ||
| topk_idx_i32 = topk_idx.to(torch.int32) | ||
| else: | ||
| assert topk_idx is None and token_weights is None | |
|
|
||
| import mori.jit.hip_driver as _hd | ||
|
|
||
| # ensure _hip is available | |
| num_recv_tokens_per_expert = num_recv_tokens_per_expert.tolist() | ||
|
|
||
| # hold token_weights so dispatch weights are available in backward | |
| # this is a workaround to avoid an illegal access when token_weights is None | |
Security review (automation)
Re-validated for the latest push (d094e9f). No medium+ severity security findings.
Incremental delta since the last assessment (5ba14d2):
- `EPBufferConfig` split into `primus_turbo/pytorch/kernels/moe/_backend/_config.py` (refactor, no behavior change).
- `_backend/__init__.py` re-exports updated; no new optional-dep imports at module load time.
- `setup.py`: `uccl @ git+https://github.com/uccl-project/uccl.git@<pinned-commit>` moved from `install_requires` to `extras_require["uccl"]`. This is a supply-chain hygiene improvement: the unsigned external GitHub source is now opt-in instead of pulled by every default install. The commit hash remains pinned (immutable reference).
- Minor docstring/comment edits in `backend.py`, `constants.py`.
Scope re-checked across the full PR diff:
- `primus_turbo/pytorch/kernels/moe/_backend/{base,mori,deep_ep,turbo,uccl_ep}.py`: backend implementations.
- `moe_dispatch_combine_impl.py`: autotuner and registry.
- `moe_utils.py`: `bench_kineto`, `detect_group_topology`, Triton helpers.
- `token_dispatcher.py`, `ops/moe/moe_dispatch_combine.py`: graph-capture handling.
Checks performed and outcome:
- Env-var ingestion (`ENV_EP_BACKEND`, `ENV_MORI_NUM_QP_PER_PE`, `NCCL_*` -> `UCCL_*`/`MORI_*` fallback in `_apply_env_with_nccl_fallback`): values are only set into the current process env or parsed as `int`; no shell expansion, no path/command injection.
- `ctypes.CDLL(torch_hip)` in `_align_mori_hip_with_torch`: path is derived from `torch.__file__` (trusted Python install), gated by `os.path.isfile`. Not attacker-controlled.
- Triton kernel `compute_expert_token_info_kernel` uses bounds masks (`load_mask`, `footprint_mask`, `bin_mask`) and clamps invalid lanes before atomic add; no OOB write path from tensor shapes.
- `tempfile.NamedTemporaryFile(suffix=".json")` in `bench_kineto`: secure temp creation, no path traversal.
- `dist.all_gather_object` in `detect_group_topology`: pickle is used inside an already-trusted EP process group (standard PyTorch distributed pattern, not new attack surface introduced by this PR).
- `lru_cache` on `_build_mori_op` and `OrderedDict` handle cache: keyed on integers / `id()` with `is` re-check; no DoS or aliasing exploit path.
- Supply chain: `UCCL_COMMIT` and `AITER_COMMIT` are pinned to immutable commit hashes; UCCL is now optional.
Result: No injection, authn/authz, secret-leakage, SSRF/XSS, path-traversal, or unsafe-deserialization issues introduced by this PR. No new comments to post.
Sent by Cursor Automation: Find vulnerabilities
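As an illustration of the "`OrderedDict` handle cache keyed on `id()` with `is` re-check" item, here is a minimal sketch of how such a bounded cache might behave. It is written under assumptions: the class name `HandleCache`, its methods, and the oldest-first eviction policy are hypothetical and are not taken from `MoEDispatchCombineAutoTuner`; only the `id()` keying, the identity re-check, the skip on a `None` handle, and the size cap of 1024 appear in the PR.

```python
from collections import OrderedDict
from typing import Any, Optional, Tuple


class HandleCache:
    """Bounded map from id(handle) to (handle, result). Hypothetical sketch."""

    def __init__(self, max_entries: int = 1024) -> None:
        self._max = max_entries
        self._entries: "OrderedDict[int, Tuple[Any, Any]]" = OrderedDict()

    def register(self, handle: Any, result: Any) -> None:
        if handle is None:
            return
        hid = id(handle)
        # Re-registering moves the entry to the most-recent end.
        self._entries.pop(hid, None)
        self._entries[hid] = (handle, result)
        # Evict the oldest entries so the cache cannot grow without bound.
        while len(self._entries) > self._max:
            self._entries.popitem(last=False)

    def lookup(self, handle: Any) -> Optional[Any]:
        entry = self._entries.get(id(handle))
        if entry is None:
            return None
        cached_handle, result = entry
        # Storing the handle keeps it alive, so its id cannot be recycled
        # while the entry exists; the identity check is a defensive guard.
        return result if cached_handle is handle else None


# Demo: with a cap of 2, registering a third handle evicts the first.
cache = HandleCache(max_entries=2)
h_a, h_b, h_c = object(), object(), object()
cache.register(h_a, "tune-a")
cache.register(h_b, "tune-b")
cache.register(h_c, "tune-c")
```

Keeping the handle next to its result is what makes the `is` comparison possible; a cache keyed on `id()` alone could return a stale result if an id were reused after the original handle had been freed.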


Description
Refactor MoE dispatch/combine to support multi-backend autotuning and add a new Mori EP backend on ROCm.
Fixes # (issue)
Type of change
Changes
- Add `MoEDispatchCombineAutoTuner`: shape-keyed tune + handle-bound reuse for paired dispatch/combine, enabled via `PRIMUS_TURBO_AUTO_TUNE=1`.
- Add `MoriEPBackend` (ROCm) and register `BackendType.MORI`; keep Turbo/DeepEP unchanged.
- Extract shared helpers into `moe_utils.py` (`bench_kineto`, `detect_group_topology`, `inplace_unique`).
- Update tests to use `self.pg` with `nccl`.

Checklist: