Skip to content

Refactor: moe dispatch combine autotune#312

Open
zhenhuang12 wants to merge 17 commits into
mainfrom
refactor/moe-dispatch-combine-autotune
Open

Refactor: moe dispatch combine autotune#312
zhenhuang12 wants to merge 17 commits into
mainfrom
refactor/moe-dispatch-combine-autotune

Conversation

@zhenhuang12
Copy link
Copy Markdown
Collaborator

@zhenhuang12 zhenhuang12 commented Apr 24, 2026

Description

Refactor MoE dispatch/combine to support multi-backend autotuning and add a new Mori EP backend on ROCm.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add MoEDispatchCombineAutoTuner: shape-keyed tune + handle-bound reuse for paired dispatch/combine, enabled via PRIMUS_TURBO_AUTO_TUNE=1.
  • Add MoriEPBackend (ROCm) and register BackendType.MORI; keep Turbo/DeepEP unchanged.
  • Extract shared utilities to moe_utils.py (bench_kineto, detect_group_topology, inplace_unique).
  • Add autotune sweep test and switch dispatcher tests to self.pg with nccl.

Checklist:

  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI review requested due to automatic review settings April 24, 2026 06:08
cursor[bot]

This comment was marked as outdated.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Refactors the MoE dispatch/combine implementation to support multi-backend autotuning, and introduces a new ROCm Mori EP backend alongside shared MoE utilities and updated distributed tests.

Changes:

  • Added MoEDispatchCombineAutoTuner with shape-keyed caching and handle-bound reuse across dispatch/combine.
  • Added MoriEPBackend and BackendType.MORI, plus ROCm-specific dispatch token-count handling.
  • Extracted shared helpers into moe_utils.py and updated dispatcher tests to use NCCL process groups.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.

Show a summary per file
File Description
tests/pytorch/modules/test_token_dispatcher.py Adds autotune sweep coverage and switches tests to self.pg + NCCL init behavior.
primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py Preserves Mori’s CUDA tokens_per_expert tensor by only wrapping list outputs.
primus_turbo/pytorch/kernels/moe/moe_utils.py New shared utilities for profiling, topology detection, and in-place uniqueness.
primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py Main refactor: autotuner, per-backend configs, Mori backend integration, and tuning logic.
primus_turbo/pytorch/core/backend.py Registers Mori availability and adds BackendType.MORI support in backend manager.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py Outdated
Comment on lines +1414 to +1442
@classmethod
def register_handle(cls, handle: Any, result: EPAutoTuneResult) -> None:
"""Bind ``result`` to ``id(handle)`` for paired-call reuse.

Args:
handle: Handle returned by the backend's ``dispatch`` call.
result: Tune result to associate with the handle.
"""
if handle is None:
return
hid = id(handle)
cache = cls._handle_cache
if hid in cache:
cache.move_to_end(hid)
cache[hid] = result
else:
cache[hid] = result
if len(cache) > cls._HANDLE_CACHE_MAX:
cache.popitem(last=False)

@classmethod
def lookup_handle(cls, handle: Any) -> Optional[EPAutoTuneResult]:
"""Return the result bound to ``handle``, or ``None`` if unknown."""
if handle is None:
return None
res = cls._handle_cache.get(id(handle))
if res is not None:
cls._handle_cache.move_to_end(id(handle))
return res
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MoEDispatchCombineAutoTuner keys _handle_cache by id(handle) only. Python can reuse object ids after GC, so a later, unrelated handle could collide and incorrectly reuse a stale tuned result. To make this robust, consider storing a strong reference to the handle in the cache entry and verifying stored_handle is handle on lookup (bounded by the existing LRU), or using a weakref-based approach where the handle type supports it.

Copilot uses AI. Check for mistakes.
Comment on lines +250 to +255
if backend is None:
targets = list(_buffer_config_per_backend.keys()) or list(_DEFAULT_BUFFER_CONFIG_PER_BACKEND.keys())
for name in targets:
_buffer_config_per_backend[name] = dataclasses.replace(new_cfg)
else:
_buffer_config_per_backend[backend] = new_cfg
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_buffer_global_config(backend=None) only updates _buffer_config_per_backend keys that already exist (derived from _DEFAULT_BUFFER_CONFIG_PER_BACKEND). If a new backend is registered later via register_ep_backend(), it won’t receive the global config unless callers remember to pass backend=<name>. Consider iterating over _BACKEND_REGISTRY.keys() (or taking the union) when backend is None so newly registered backends are included.

Copilot uses AI. Check for mistakes.
Comment on lines +139 to +166
def detect_group_topology(group: dist.ProcessGroup) -> Tuple[int, int]:
"""
Infer node topology for a process group.

Returns:
node_idx: compact node index within the given group.
num_nodes: number of distinct nodes spanned by the group.
"""

node_token = (
os.environ.get("NODE_RANK")
or os.environ.get("GROUP_RANK")
or os.environ.get("SLURM_NODEID")
or socket.gethostname()
)

world = dist.get_world_size(group)
node_tokens = [None] * world
dist.all_gather_object(node_tokens, node_token, group=group)

token_to_idx = {}
for token in node_tokens:
if token not in token_to_idx:
token_to_idx[token] = len(token_to_idx)

node_idx = token_to_idx[node_token]
num_nodes = len(token_to_idx)
return node_idx, num_nodes
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

detect_group_topology() performs an all_gather_object every time it’s called. Since group topology is static, calling this from per-forward paths (e.g., Mori init/config resolution) adds an unnecessary synchronous collective overhead. Consider caching (node_idx, num_nodes) per ProcessGroup (e.g., keyed by id(group) with a weakref) to avoid repeated collectives.

Copilot uses AI. Check for mistakes.
Comment thread tests/pytorch/modules/test_token_dispatcher.py Outdated
Comment thread tests/pytorch/modules/test_token_dispatcher.py Outdated
Comment thread primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py Outdated
Comment thread primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py
Comment on lines +90 to +116
# Parse the profiling table
assert isinstance(kernel_names, (str, tuple))
is_tuple = isinstance(kernel_names, tuple)
prof_lines = prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=100).split("\n")
kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert (
sum([name in line for line in prof_lines]) == 1
), f"Errors of the kernel {name} in the profiling table"

# Save chrome traces
if trace_path is not None:
prof.export_chrome_trace(trace_path)

# Return average kernel durations
units = {"ms": 1e3, "us": 1e6}
kernel_durations = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_durations.append(float(time_str.replace(unit, "")) / scale)
break
break

This comment was marked as low quality.

@zhenhuang12 zhenhuang12 force-pushed the refactor/moe-dispatch-combine-autotune branch from faa7ffc to a645c55 Compare April 24, 2026 06:22
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale comment

Security review (automation)

Reviewed the diff for injection, authn/authz, permission-boundary, secret-leakage, SSRF/XSS/request-forgery, path-traversal, unsafe-deserialization, and supply-chain risks.

Scope

  • primus_turbo/pytorch/core/backend.py (optional mori import + enum)
  • primus_turbo/pytorch/kernels/moe/moe_utils.py (profiling/topology/unique helpers)
  • primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py (autotuner + Mori backend + Triton histogram kernel)
  • primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py
  • tests/pytorch/modules/test_token_dispatcher.py

Key checks

  • Env vars (PRIMUS_TURBO_MORI_NUM_QP_PER_PE, NODE_RANK/GROUP_RANK/SLURM_NODEID) feed only int() casts and rank-id comparisons — not shell, path, SQL, or deserialization sinks.
  • tempfile.NamedTemporaryFile(suffix=".json") uses secure defaults; the JSON is written by prof.export_chrome_trace in the same process (no attacker input).
  • suppress_stdout_stderr manipulates only its own file descriptors.
  • dist.all_gather_object in detect_group_topology pickles over the internal torch.distributed process group (trusted channel).
  • Triton compute_expert_token_info_kernel clamps OOB rows/cols with tl.where + load_mask/flat_mask and clamps invalid/out-of-range expert ids before tl.histogram / tl.atomic_add, preventing OOB reads/writes and illegal bin counts.
  • inplace_unique pads negatives to num_slots and drops the padding bin, keeping scatter_add_ in-range.
  • MoEDispatchCombineAutoTuner._handle_cache is LRU-bounded (_HANDLE_CACHE_MAX=1024); shape cache uses TuneCache(capacity=1024).
  • torch._C._distributed_c10d._register_process_group("mori", group) receives only internally constructed arguments.
  • The new mori dependency is an optional import (guarded by try/except ImportError); no code runs unless installed.

Result
No medium+ severity security issue introduced or exposed by this PR. No new inline findings.

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities

Copilot AI review requested due to automatic review settings April 24, 2026 07:09
@AMD-AGI AMD-AGI deleted a comment from Copilot AI Apr 24, 2026
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale comment

Security review (automation)

Re-validated the diff for the latest push (ed3ecca) covering autotuner, Mori backend, Triton histogram kernel, and shared MoE utilities.

Scope

  • primus_turbo/pytorch/core/backend.py (optional mori import + BackendType.MORI)
  • primus_turbo/pytorch/kernels/moe/moe_utils.py (profiling / topology / unique helpers)
  • primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py (autotuner + Mori backend + Triton histogram kernel)
  • primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py
  • tests/pytorch/modules/test_token_dispatcher.py

Key checks

  • Env vars (PRIMUS_TURBO_MORI_NUM_QP_PER_PE, NODE_RANK/GROUP_RANK/SLURM_NODEID) feed only int() casts and rank-id comparisons — no shell, path, SQL, or deserialization sinks.
  • tempfile.NamedTemporaryFile(suffix=".json") uses secure defaults; content is written by the in-process prof.export_chrome_trace (no attacker input).
  • suppress_stdout_stderr manipulates only its own file descriptors.
  • dist.all_gather_object in detect_group_topology pickles over an internal torch.distributed group (trusted training peers only).
  • Triton compute_expert_token_info_kernel clamps OOB rows/cols via tl.where + load_mask/flat_mask and clamps invalid/out-of-range expert ids before tl.histogram/tl.atomic_add — no OOB reads/writes or illegal bin counts.
  • inplace_unique pads negatives into num_slots and drops the padding bin, keeping scatter_add_ in-range.
  • MoEDispatchCombineAutoTuner._handle_cache is LRU-bounded (_HANDLE_CACHE_MAX=1024) and now identity-checks the stored handle (stored_handle is not handle) to guard against id() reuse; shape cache uses TuneCache(capacity=1024).
  • torch._C._distributed_c10d._register_process_group("mori", group) receives only internally constructed arguments.
  • The new mori dependency is an optional import (guarded by try/except ImportError); no code runs unless installed.

Result
No medium+ severity security issue introduced or exposed by this PR. No inline findings.

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +198 to 220
self._bind_device()
envs = {k: v for k, v in os.environ.items()}
envs.pop("PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND", None)
envs["PRIMUS_TURBO_AUTO_TUNE"] = "1"

with patch.dict(os.environ, envs, clear=True):
MoEDispatchCombineAutoTuner.clear()

# First run triggers autotune and registers the handle mapping.
_run_dispatch_combine(self.rank, self.pg)
first = MoEDispatchCombineAutoTuner.current()
assert first is not None, "autotune should have populated a result"
assert (
len(MoEDispatchCombineAutoTuner._handle_cache) > 0
), "moe_dispatch should have bound the tuned result to at least one handle"

# Second run hits the shape cache (no extra measurements) and
# yields the same backend selection.
_run_dispatch_combine(self.rank, self.pg)
second = MoEDispatchCombineAutoTuner.current()
assert second is first or second.backend_name == first.backend_name


Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_autotune_sweep enables PRIMUS_TURBO_AUTO_TUNE=1 and then runs two full dispatch+combine passes at the default test shape (4096×4096 with topk=8). With the new autotuner this triggers a full config sweep using torch.profiler/Kineto inside the test, which is likely to be very slow and can make CI flaky/time out. Please consider reducing the problem size for this test (smaller NUM_TOKENS/HIDDEN_SIZE), lowering the sweep iteration counts for test mode, or marking/gating the test as “slow” so it doesn’t run in the default unit test shard.

Suggested change
self._bind_device()
envs = {k: v for k, v in os.environ.items()}
envs.pop("PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND", None)
envs["PRIMUS_TURBO_AUTO_TUNE"] = "1"
with patch.dict(os.environ, envs, clear=True):
MoEDispatchCombineAutoTuner.clear()
# First run triggers autotune and registers the handle mapping.
_run_dispatch_combine(self.rank, self.pg)
first = MoEDispatchCombineAutoTuner.current()
assert first is not None, "autotune should have populated a result"
assert (
len(MoEDispatchCombineAutoTuner._handle_cache) > 0
), "moe_dispatch should have bound the tuned result to at least one handle"
# Second run hits the shape cache (no extra measurements) and
# yields the same backend selection.
_run_dispatch_combine(self.rank, self.pg)
second = MoEDispatchCombineAutoTuner.current()
assert second is first or second.backend_name == first.backend_name
global NUM_TOKENS, HIDDEN_SIZE
self._bind_device()
envs = {k: v for k, v in os.environ.items()}
envs.pop("PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND", None)
envs["PRIMUS_TURBO_AUTO_TUNE"] = "1"
old_num_tokens = NUM_TOKENS
old_hidden_size = HIDDEN_SIZE
with patch.dict(os.environ, envs, clear=True):
try:
# Keep the autotune path covered, but shrink the problem size so
# the profiler-backed config sweep remains practical in CI.
NUM_TOKENS = 512
HIDDEN_SIZE = 512
MoEDispatchCombineAutoTuner.clear()
# First run triggers autotune and registers the handle mapping.
_run_dispatch_combine(self.rank, self.pg)
first = MoEDispatchCombineAutoTuner.current()
assert first is not None, "autotune should have populated a result"
assert (
len(MoEDispatchCombineAutoTuner._handle_cache) > 0
), "moe_dispatch should have bound the tuned result to at least one handle"
# Second run hits the shape cache (no extra measurements) and
# yields the same backend selection.
_run_dispatch_combine(self.rank, self.pg)
second = MoEDispatchCombineAutoTuner.current()
assert second is first or second.backend_name == first.backend_name
finally:
NUM_TOKENS = old_num_tokens
HIDDEN_SIZE = old_hidden_size

Copilot uses AI. Check for mistakes.
Comment on lines +1648 to +1651
def _get_backend_name() -> str:
"""Return the user-selected backend name, or ``TURBO`` by default."""
bt = GlobalBackendManager.get_moe_dispatch_combine_backend(PrecisionType.BF16_FP16_FP32)
return bt.name if bt is not None else _DEFAULT_BACKEND_NAME
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_backend_name() now only consults GlobalBackendManager.get_moe_dispatch_combine_backend(). That helper intentionally returns None when PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKEND is set to a custom (non-BackendType) backend name, so this change forces a silent fallback to TURBO and breaks the documented “custom names like UCCL_EP” behavior. Please restore the previous behavior by reading ENV_MOE_DISPATCH_COMBINE_BACKEND directly when GlobalBackendManager returns None, and using the raw env value (uppercased/stripped) as the backend registry lookup key.

Copilot uses AI. Check for mistakes.
Comment on lines +449 to +451
element_size = 1 if fp8_dispatch else 2
hidden_bytes = hidden_size * max(element_size, 2)

Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_DeepEPLikeBackend.init_buffer() computes hidden_bytes as hidden_size * 2 for all non-FP8 inputs (element_size is hard-coded to 2). This can under-allocate buffers when x is float32 (4 bytes/elem) and conflicts with the previous logic that sized by max(inp.element_size(), 2). Either (a) size by the actual input element size (pass dtype/element_size into init_buffer), or (b) explicitly restrict dispatch/combine to BF16/FP16 and raise a clear error for FP32 to avoid silent buffer overruns.

Copilot uses AI. Check for mistakes.
Comment on lines +1428 to +1436
if handle is None:
return
hid = id(handle)
cache = cls._handle_cache
if hid in cache:
cache.move_to_end(hid)
cache[hid] = (handle, result)
if len(cache) > cls._HANDLE_CACHE_MAX:
cache.popitem(last=False)
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MoEDispatchCombineAutoTuner._handle_cache stores strong references to the full handle object (cache[hid] = (handle, result)). If a backend’s handle captures large tensors or GPU resources (e.g., MoriEPBackend’s handle currently includes recv_topk_idx_i32), this cache can retain significant memory until eviction, which is risky in long-running training. Consider storing only the tuned result keyed by id(handle) (accepting the small risk of id reuse), or redesigning handles to be lightweight/opaque IDs so the cache doesn’t keep heavyweight objects alive.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale comment

Security review (automation)

Re-validated after the latest push (8bd040b). The incremental change since the last assessment is a mechanical rename (PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKENDPRIMUS_TURBO_EP_BACKEND, get_moe_dispatch_combine_backendget_ep_backend) and the removal of eager deep_ep / mori imports in primus_turbo/pytorch/core/backend.py (both are now lazily loaded inside the EP backend registry). No trust boundary, permission check, or supply-chain surface changes.

Scope

  • primus_turbo/common/constants.py, docs/examples.md, tests/pytorch/core/test_global_backend_manager.py, tests/pytorch/modules/test_token_dispatcher.py (env-var rename)
  • primus_turbo/pytorch/core/backend.py (dropped eager HAVE_DEEP_EP/HAVE_MORI guards; lazy imports remain inside MoriEPBackend.is_available() / _get_module())
  • primus_turbo/pytorch/kernels/moe/moe_utils.py (profiling / topology / unique helpers)
  • primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py (autotuner + Mori backend + Triton histogram kernel)
  • primus_turbo/pytorch/ops/moe/moe_dispatch_combine.py

Key checks (still green)

  • Env vars (PRIMUS_TURBO_MORI_NUM_QP_PER_PE, NODE_RANK/GROUP_RANK/SLURM_NODEID, PRIMUS_TURBO_EP_BACKEND) feed only int() casts, bounded registry lookups (_BACKEND_REGISTRY), and rank/host token equality — no shell, path, SQL, or deserialization sinks.
  • tempfile.NamedTemporaryFile(suffix=".json") uses secure defaults; content produced by the in-process prof.export_chrome_trace (no attacker input).
  • suppress_stdout_stderr manipulates only its own file descriptors.
  • dist.all_gather_object in detect_group_topology runs over the trusted torch.distributed group.
  • Triton compute_expert_token_info_kernel clamps OOB rows/cols via tl.where + load_mask/flat_mask and clamps invalid/out-of-range expert ids before tl.histogram / tl.atomic_add — no OOB reads/writes or illegal bins.
  • inplace_unique pads negatives into num_slots and drops the padding bin, keeping scatter_add_ in-range.
  • MoEDispatchCombineAutoTuner._handle_cache is LRU-bounded (_HANDLE_CACHE_MAX=1024) and identity-checks the stored handle (stored_handle is not handle) to defend against id() reuse; shape cache uses TuneCache(capacity=1024).
  • torch._C._distributed_c10d._register_process_group("mori", group) receives only internally constructed arguments.
  • mori / deep_ep remain optional dependencies (guarded by try/except ImportError in each backend's is_available() / _get_module()).
  • No secret, token, payload, or model-input values are logged.

Result
No medium+ severity security issue introduced or exposed by this PR. No inline findings.

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities

Copilot AI review requested due to automatic review settings April 29, 2026 08:57
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 60 out of 68 changed files in this pull request and generated 4 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +27 to +29
# loosen timeout for this test to avoid timeout failures
common_distributed.TIMEOUT_DEFAULT = 600

Comment on lines +77 to 80
// Controlled by ``PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM`` (default ``0``).
// When true, all dispatch/combine kernels run on the caller's current CUDA
// stream instead of ``comm_stream``.
bool force_current_stream = true;
Comment on lines +388 to +394
batch_size, seq_len_q, num_heads_q, _ = q.shape
_, _, _, head_dim_v = v.shape
out = torch.empty((batch_size, seq_len_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device)
softmax_lse = torch.empty((batch_size, num_heads_q, seq_len_q), dtype=torch.float32, device=q.device)
S_dmask = torch.empty((0,), dtype=q.dtype, device=q.device)
rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
return out, softmax_lse, S_dmask, rng_state
Comment on lines +154 to +157
if not deepep_use_comm_stream and os.environ.get(ENV_EP_FORCE_CURRENT_STREAM) != "1":
clear_backend_instances()
os.environ[ENV_EP_FORCE_CURRENT_STREAM] = "1"

Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale comment

Security review (automation)

Re-validated after the latest push (b96afca). The incremental changes since the last assessment are: (1) CK GEMM/Grouped GEMM directory restructure (file moves only); (2) numerical correctness fixes in csrc/kernels/quantization/quantization_mxfp8.cu (cast to int64_t to avoid int32 offset overflow), csrc/kernels/reduce/reduce_row.cuh (partial-tile indexing), and csrc/kernels/gemm/turbo/turbo_gemm_mxfp8_kernel.h (__builtin_amdgcn_s_barrier() WAR fence); (3) attention BHSD layout support in attention_aiter_impl.py / flash_attn_interface.py / flash_attn_usp_interface.py / attention_ring.py, with torch.library.custom_op wrappers and matching register_fake shape implementations; (4) a new deepep_use_comm_stream=False knob in DeepEPTokenDispatcher that calls clear_backend_instances() and sets the constant PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM env var to "1"; (5) a _format_kwargs(...) helper that renders tensor shape/dtype (not data) and repr() of scalars/enums into backend mismatch error messages; (6) the nits commit dropping dispatch_config/combine_config from a debug log line.

Scope

  • csrc/kernels/quantization/quantization_mxfp8.cu, csrc/kernels/reduce/reduce_row.cuh, csrc/kernels/gemm/turbo/turbo_gemm_mxfp8_kernel.h, csrc/include/primus_turbo/deep_ep/configs.h, csrc/pytorch/deep_ep/deep_ep.hpp
  • primus_turbo/common/constants.py (new ENV_EP_FORCE_CURRENT_STREAM)
  • primus_turbo/pytorch/core/backend.py (_format_kwargs)
  • primus_turbo/pytorch/deep_ep/buffer.py, primus_turbo/pytorch/modules/moe/token_dispatcher.py (env-var toggle, clear_backend_instances())
  • primus_turbo/pytorch/kernels/attention/attention_aiter_impl.py, primus_turbo/pytorch/ops/attention/{attention_utils.py,flash_attn_interface.py,flash_attn_usp_interface.py,usp/attention_ring.py}
  • primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py (added clear_backend_instances; nits-only logger change)
  • tests/pytorch/{ops/test_attention.py,ops/test_attention_with_cp.py,ops/test_quantization.py,modules/test_token_dispatcher.py}

Key checks (still green)

  • _format_kwargs only emits Tensor(shape=..., dtype=...) for tensors and repr() for scalar/enum kwargs — no tensor contents, weights, gradients, or attacker payload values are logged. Error messages stay free of secrets/PII.
  • os.environ[ENV_EP_FORCE_CURRENT_STREAM] = "1" writes a hard-coded constant; no user-controlled string flows into os.environ or any shell. The C++ side reads the same env var via std::getenv and passes it through std::stoi, with the value bounded to a bool.
  • clear_backend_instances() only mutates the module-private _backend_instances dict — no FS, network, or trust-boundary surface.
  • torch.library.custom_op(..., mutates_args=("dq","dk","dv"), device_types="cuda") wrappers receive only Tensors / scalars from in-process callers; register_fake impls allocate empty tensors and don't dereference user data.
  • BHSD layout handling is restricted to qkv_format in ("bshd","sbhd","bhsd") with explicit assert/raise ValueError on unknown values; no path/format string is interpolated into a sink.
  • Sink-attention fast-path constraints (head_dim_qk == head_dim_v, power-of-2) moved into can_handle() — same-or-stronger validation as before.
  • mxfp8 int64_t cast and reduce_row partial-tile fix close potential signed-int32 offset overflows that could have produced OOB stores on extreme shapes; this is a security/safety improvement, not a regression.
  • MoEDispatchCombineAutoTuner._handle_cache remains LRU-bounded (_HANDLE_CACHE_MAX=1024) with id() reuse defended by stored_handle is not handle identity check; TuneCache(capacity=1024) shape cache unchanged.
  • mori / deep_ep remain optional, lazily imported inside each backend's is_available() / _get_module().
  • No new env vars feed into shell/path/SQL/deserialization sinks; PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM and other env reads (int() casts, registry lookups, rank/host token equality) stay numeric/bounded.

Result
No medium+ severity security issue introduced or exposed by this PR. No inline findings.

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities

Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale comment

Security review (automation)

Re-validated after the latest push (aab8575). Incremental delta since the last assessment (b96afca):

  • primus_turbo/pytorch/core/backend.py: adds BackendType.UCCL enum entry only.
  • primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py: bumps EPBufferConfig.num_sms default from 32 to 64, adds a "UCCL_EP" entry to _DEFAULT_BUFFER_CONFIG_PER_BACKEND, and adds a new UCCLEPBackend(_DeepEPLikeBackend) whose is_available() / _get_module() lazily import uccl.ep. The new class is not wired into _BACKEND_REGISTRY, so it is only reachable if a caller invokes register_ep_backend("UCCL_EP", UCCLEPBackend) explicitly.
  • setup.py: adds a mandatory install_requires entry uccl @ git+https://github.com/uccl-project/uccl.git@9bae94c59229f82efbbd5b78a7f222ce19e74e86.

Key checks

  • Supply chain: the new uccl dependency is fetched from the upstream uccl-project/uccl repo (already referenced by the project's existing DeepEP fallback documentation) and is pinned to a fixed commit SHA, so installs cannot silently move to a different revision via tag/branch retargeting. No version-range or floating-ref injection risk introduced. The constant UCCL_COMMIT is a hard-coded literal — no env-var or user input flows into the install URL.
  • UCCLEPBackend._get_module() only does a lazy import uccl.ep; no attacker-controlled string is passed to import and the import target is fixed.
  • _make_buffer_kwargs() introspects the upstream Buffer constructor with inspect.signature; only the boolean is_intranode (derived from group.size()) is forwarded — no untrusted kwargs propagation.
  • EPBufferConfig default change (num_sms 32 → 64) and new UCCL_EP defaults are pure numeric tuning constants; no trust-boundary impact.
  • Previously validated surfaces (Mori SHMEM init, autotuner LRU bounds + identity-checked handle cache, compute_expert_token_info_kernel OOB clamps, inplace_unique bin-padding, _format_kwargs shape/dtype-only logging, PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM env-var path, mxfp8 int64_t offset cast, BHSD layout qkv_format allowlist, register_fake stubs, custom-op wrappers, attention_utils._infer_qkv_format) remain unchanged in this delta and still hold.
  • No new env-var / file-path / shell / SQL / deserialization / network sinks; no secret or token logging; no permission-check or authn change.

Result
No medium+ severity security issue introduced or exposed by this PR. No inline findings.

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities

Copilot AI review requested due to automatic review settings May 6, 2026 04:21
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale comment

Security review (automation)

Re-validated after the latest push (c4906b5). Incremental delta since the last assessment (aab8575):

  • primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py: a single-line registry entry ("UCCL": UCCLEPBackend) added to _BACKEND_REGISTRY. The UCCLEPBackend class itself (added in the prior commit) follows the same optional-import + _DeepEPLikeBackend pattern already used by DeepEPBackend (lazy import uccl.ep, no eager dependency, identical buffer-kwargs introspection).

Scope re-checked

  • primus_turbo/pytorch/core/backend.py (BackendType.UCCL enum + get_ep_backend env parsing, returns None for unknown enum values).
  • primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py (autotuner, EP backends incl. UCCL, Mori, DeepEP, Triton histogram kernel).
  • primus_turbo/pytorch/kernels/moe/moe_utils.py (bench_kineto, detect_group_topology, inplace_unique).
  • primus_turbo/pytorch/modules/moe/token_dispatcher.py (PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM env mutation).
  • primus_turbo/pytorch/deep_ep/buffer.py and csrc/pytorch/deep_ep/deep_ep.hpp (force-current-stream toggle).
  • setup.py (new uccl install requirement pinned to commit 9bae94c5...).
  • csrc/kernels/... (CK GEMM/Grouped-GEMM directory restructure + numerical-correctness fixes; no new I/O sinks).
  • Attention/USP refactor and tests.

Threat checks performed

  • Injection: env values (PRIMUS_TURBO_EP_BACKEND, PRIMUS_TURBO_MORI_NUM_QP_PER_PE, ENV_EP_FORCE_CURRENT_STREAM) are parsed via BackendType[...]/int(...) with defensive fallbacks; no shell or SQL sinks. The Mori SHMEM process-group name ("mori") is a hard-coded constant.
  • Authn/authz: no auth surfaces touched.
  • Path traversal / unsafe deserialization: bench_kineto writes to a tempfile.NamedTemporaryFile (random path) and json.loads only the file it just wrote; no attacker-controlled path or payload.
  • SSRF / request forgery / XSS: no network or HTML sinks introduced.
  • Secret leakage / insecure logging: log lines emit shapes and tuning metadata only; no credentials or tokens.
  • Supply chain: the new uccl dependency in setup.py is pinned to a specific commit hash on github.com/uccl-project/uccl, which mitigates floating-tag tampering risk and matches existing patterns (e.g. pinned AITER_COMMIT). UCCLEPBackend only imports uccl.ep lazily and is selected via the existing EP backend registry.
  • Process-wide env mutation in DeepEPTokenDispatcher.__init__ (os.environ[ENV_EP_FORCE_CURRENT_STREAM]="1" + clear_backend_instances()): a usability concern previously raised by Copilot, but not a security boundary issue (caller already controls the process), so out of scope here.

Conclusion: no medium-or-higher security findings. No new finding comments posted.

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 61 out of 69 changed files in this pull request and generated 9 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +202 to +221
# Per-backend default buffer configuration. Keys must match the names used
# in ``_BACKEND_REGISTRY`` so lookups can be done by backend name.
_DEFAULT_BUFFER_CONFIG_PER_BACKEND: Dict[str, EPBufferConfig] = {
"TURBO": EPBufferConfig(
num_sms=64,
dispatch_config=None,
combine_config=None,
),
"DEEP_EP": EPBufferConfig(
num_sms=64,
dispatch_config=None,
combine_config=None,
),
"UCCL_EP": EPBufferConfig(
num_sms=64,
dispatch_config=None,
combine_config=None,
),
"MORI": EPBufferConfig(num_sms=64, dispatch_config=None, combine_config=None),
}
Comment on lines +451 to +456
dispatch_config = config.dispatch_config or BufferClass.get_dispatch_config(group.size())
combine_config = config.combine_config or BufferClass.get_combine_config(group.size())

element_size = 1 if fp8_dispatch else 2
hidden_bytes = hidden_size * max(element_size, 2)

Comment on lines +1688 to +1691
def _get_backend_name() -> str:
"""Return the user-selected backend name, or ``TURBO`` by default."""
bt = GlobalBackendManager.get_ep_backend(PrecisionType.BF16_FP16_FP32)
return bt.name if bt is not None else _DEFAULT_BACKEND_NAME
Comment on lines +387 to +394
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, seq_len_q, num_heads_q, _ = q.shape
_, _, _, head_dim_v = v.shape
out = torch.empty((batch_size, seq_len_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device)
softmax_lse = torch.empty((batch_size, num_heads_q, seq_len_q), dtype=torch.float32, device=q.device)
S_dmask = torch.empty((0,), dtype=q.dtype, device=q.device)
rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
return out, softmax_lse, S_dmask, rng_state
Comment on lines +27 to +28
# loosen timeout for this test to avoid timeout failures
common_distributed.TIMEOUT_DEFAULT = 600
Comment thread setup.py
Comment on lines 403 to 406
# Entry points and Install Requires
entry_points = {}
install_requires = []
install_requires = [f"uccl @ git+https://github.com/uccl-project/uccl.git@{UCCL_COMMIT}"]

Comment on lines +77 to 80
// Controlled by ``PRIMUS_TURBO_EP_FORCE_CURRENT_STREAM`` (default ``0``).
// When true, all dispatch/combine kernels run on the caller's current CUDA
// stream instead of ``comm_stream``.
bool force_current_stream = true;
Comment on lines +154 to +156
if not deepep_use_comm_stream and os.environ.get(ENV_EP_FORCE_CURRENT_STREAM) != "1":
clear_backend_instances()
os.environ[ENV_EP_FORCE_CURRENT_STREAM] = "1"
Comment on lines +52 to +54
def test_ep_backend_env(self, monkeypatch):
monkeypatch.setenv("PRIMUS_TURBO_EP_BACKEND", "triton")
assert GlobalBackendManager.get_ep_backend(PrecisionType.FP8) == BackendType.TRITON
@zhenhuang12 zhenhuang12 force-pushed the refactor/moe-dispatch-combine-autotune branch from c4906b5 to f735ba0 Compare May 12, 2026 02:43
Copilot AI review requested due to automatic review settings May 18, 2026 13:18
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stale comment

Security review (automation)

Re-validated after the latest push (5ba14d2). Scope of the incremental diff (vs. 6650792):

  • setup.py: pins uccl to commit 9bae94c5… via git+https://github.com/uccl-project/uccl.git@<sha>.
  • primus_turbo/common/constants.py and primus_turbo/pytorch/core/backend.py: rename PRIMUS_TURBO_MOE_DISPATCH_COMBINE_BACKENDPRIMUS_TURBO_EP_BACKEND; add BackendType.MORI, BackendType.UCCL, and ENV_MORI_NUM_QP_PER_PE.
  • primus_turbo/pytorch/kernels/moe/_backend/ (new): EPBackend protocol, shared _DeepEPLikeBackend, plus TurboEPBackend, DeepEPBackend, UCCLEPBackend, MoriEPBackend. Includes _apply_env_with_nccl_fallback (copies NCCL_*MORI_*/UCCL_* when unset), _align_mori_hip_with_torch (ctypes.CDLL(<torch lib dir>/libamdhip64.so)), _register_and_init_mori_shmem, and an autotune sweep harness.
  • primus_turbo/pytorch/kernels/moe/moe_dispatch_combine_impl.py: per-backend default EPBufferConfig table, shape-keyed MoEDispatchCombineAutoTuner, handle-id binding cache, CUDA-graph-capture warning hook.
  • primus_turbo/pytorch/kernels/moe/moe_utils.py (new): kineto bench helper, detect_group_topology (uses dist.all_gather_object over an env-or-gethostname() token), inplace_unique, suppress_stdout_stderr.
  • token_dispatcher.py / moe_dispatch_combine.py: tensor-vs-list guard and CUDA-graph-capture-safe path.
  • Tests + docs/examples.md: rename / parametrization updates only.

Findings

No new medium or higher severity vulnerabilities are introduced. Notes on items examined and judged not to rise to that bar:

  • setup.py adds a Git-URL install_requires for uccl but pins to a specific commit SHA, so it is not an unpinned/floating dependency.
  • All env-var reads (ENV_EP_BACKEND, ENV_MORI_NUM_QP_PER_PE, NCCL_*MORI_*/UCCL_*) are constrained: backend name is dereferenced through an in-process registry (_BACKEND_REGISTRY[name], raises on unknown), num_qp_per_pe is coerced via int(...), and the NCCL→MORI/UCCL fallback only copies env values that are already present in the same process's environment to a different env-var name. No shell, no eval, no subprocess.
  • _align_mori_hip_with_torch derives the .so path from os.path.dirname(torch.__file__); not attacker-controlled.
  • compute_expert_token_info_kernel (Triton) guards loads/stores with explicit mask= operands and clamps OOB rows/experts before address arithmetic.
  • detect_group_topology uses dist.all_gather_object (pickle under the hood) over a token sourced from NODE_RANK/GROUP_RANK/SLURM_NODEID/socket.gethostname(). Peers in an EP process group share the same trust boundary as the training job, so this is not a new attack surface relative to existing torch.distributed idioms in this codebase.
  • bench_kineto writes traces through tempfile.NamedTemporaryFile; no untrusted-path / TOCTOU exposure.

Prior threads from earlier automation runs

Prior security-review summaries from this automation all reached the same "no findings" verdict. They are being collapsed as outdated; this comment supersedes them.

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated 8 comments.

dispatch_config=None,
combine_config=None,
),
"UCCL_EP": EPBufferConfig(
Comment on lines +212 to +216
_cache: TuneCache = TuneCache(capacity=1024)

_HANDLE_CACHE_MAX: int = 1024
_handle_cache: "OrderedDict[int, Tuple[Any, EPAutoTuneResult]]" = OrderedDict()

Comment on lines +479 to 483
"""Return the user-selected backend name, or ``TURBO`` by default."""
bt = GlobalBackendManager.get_ep_backend(PrecisionType.BF16_FP16_FP32)
return bt.name if bt is not None else _DEFAULT_BACKEND_NAME


Comment thread setup.py Outdated
Comment on lines 405 to 406
install_requires = [f"uccl @ git+https://github.com/uccl-project/uccl.git@{UCCL_COMMIT}"]

assert topk_idx is not None
topk_idx_i32 = topk_idx.to(torch.int32)
else:
assert topk_idx is None, token_weights is None

import mori.jit.hip_driver as _hd

# ensture must have _hip
num_recv_tokens_per_expert = num_recv_tokens_per_expert.tolist()

# hold token_weights for dispatch weights in backward
# it's a workaround to aviod illegal access when token_weights is None
Comment on lines +53 to +54
monkeypatch.setenv("PRIMUS_TURBO_EP_BACKEND", "triton")
assert GlobalBackendManager.get_ep_backend(PrecisionType.FP8) == BackendType.TRITON
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Security review (automation)

Re-validated for the latest push (d094e9f). No medium+ severity security findings.

Incremental delta since the last assessment (5ba14d2):

  • EPBufferConfig split into primus_turbo/pytorch/kernels/moe/_backend/_config.py (refactor, no behavior change).
  • _backend/__init__.py re-exports updated; no new optional-dep imports at module load time.
  • setup.py: uccl @ git+https://github.com/uccl-project/uccl.git@<pinned-commit> moved from install_requires to extras_require["uccl"]. This is a supply-chain hygiene improvement — the unsigned external GitHub source is now opt-in instead of pulled by every default install. The commit hash remains pinned (immutable reference).
  • Minor docstring/comment edits in backend.py, constants.py.

Scope re-checked across the full PR diff:

  • primus_turbo/pytorch/kernels/moe/_backend/{base,mori,deep_ep,turbo,uccl_ep}.py — backend implementations.
  • moe_dispatch_combine_impl.py — autotuner and registry.
  • moe_utils.pybench_kineto, detect_group_topology, Triton helpers.
  • token_dispatcher.py, ops/moe/moe_dispatch_combine.py — graph-capture handling.

Checks performed and outcome:

  • Env-var ingestion (ENV_EP_BACKEND, ENV_MORI_NUM_QP_PER_PE, NCCL_* -> UCCL_* / MORI_* fallback in _apply_env_with_nccl_fallback): values are only set into the current process env or parsed as int; no shell expansion, no path/command injection.
  • ctypes.CDLL(torch_hip) in _align_mori_hip_with_torch: path is derived from torch.__file__ (trusted Python install), gated by os.path.isfile. Not attacker-controlled.
  • Triton kernel compute_expert_token_info_kernel uses bounds masks (load_mask, footprint_mask, bin_mask) and clamps invalid lanes before atomic add; no OOB write path from tensor shapes.
  • tempfile.NamedTemporaryFile(suffix=".json") in bench_kineto: secure temp creation, no path traversal.
  • dist.all_gather_object in detect_group_topology: pickle is used inside an already-trusted EP process group (standard PyTorch distributed pattern, not new attack surface introduced by this PR).
  • lru_cache on _build_mori_op and OrderedDict handle cache: keyed on integers / id() with is re-check; no DoS or aliasing exploit path.
  • Supply chain: UCCL_COMMIT and AITER_COMMIT are pinned to immutable commit hashes; UCCL is now optional.

Result: No injection, authn/authz, secret-leakage, SSRF/XSS, path-traversal, or unsafe-deserialization issues introduced by this PR. No new comments to post.

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants