Support (P/D) disaggregation on mooncake#690
Open
ZhangLirong-amd wants to merge 9 commits intomainfrom
Open
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This PR refactors the Prefill/Decode (P/D) KV disaggregation stack to make backend connectors more modular (MoRIIO split into common/engine/connector modules, plus a new Mooncake backend), and adds correctness/performance workarounds for P/D workflows (T0 injection and GPU-side RDMA coherence fencing).
Changes:
- Extend scheduler P/D transition logic (T0 injection, first-decode handling) and plumb additional remote-KV diagnostics metadata through
ScheduledBatch. - Add a GPU memory coherence “fence” path (Triton kernel) and extend the connector interface to report received blocks requiring fencing.
- Refactor MoRIIO connector into smaller modules and add a new Mooncake connector backend; update factory registration and KV transfer proxy behavior.
Reviewed changes
Copilot reviewed 12 out of 13 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| atom/model_engine/scheduler.py | Adds P/D transition handling (including T0 injection) and extra PD diagnostics metadata on decode batches. |
| atom/model_engine/model_runner.py | Introduces Triton-based GPU memory fence for RDMA-written KV blocks and hooks it into KV connector aggregation. |
| atom/kv_transfer/disaggregation/utils.py | Adds shared RDMA chunking utility for tensor registration. |
| atom/kv_transfer/disaggregation/proxy.py | Adjusts decode streaming behavior/timeouts and corrects max_tokens handling for T0 override. |
| atom/kv_transfer/disaggregation/moriio/moriio_engine.py | New MoRIIO wrapper module extracted from connector logic. |
| atom/kv_transfer/disaggregation/moriio/moriio_connector.py | Refactors MoRIIO worker/scheduler connectors to use shared/common utilities and chunked registration. |
| atom/kv_transfer/disaggregation/moriio/moriio_common.py | New MoRIIO shared constants/types + availability checks extracted from connector. |
| atom/kv_transfer/disaggregation/moriio/init.py | Exports refactored MoRIIO connector classes. |
| atom/kv_transfer/disaggregation/mooncake/mooncake_connector.py | Adds new Mooncake push-mode RDMA connector (worker + scheduler). |
| atom/kv_transfer/disaggregation/factory.py | Updates MoRIIO registration paths and registers the new Mooncake backend. |
| atom/kv_transfer/disaggregation/base.py | Extends worker connector interface with get_finished_recv_blocks() for fencing. |
| atom/kv_transfer/disaggregation/aggregator.py | Makes KV completion aggregation robust to duplicate per-worker reports by tracking worker indices. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+1794
to
+1800
| fence_blocks = connector.get_finished_recv_blocks() | ||
| if fence_blocks: | ||
| with torch.cuda.stream(self.async_execute_stream): | ||
| self._gpu_memory_fence(fence_blocks) | ||
| event = torch.cuda.Event() | ||
| event.record(self.async_execute_stream) | ||
| self._fence_event = event |
Comment on lines
389
to
+392
| except Exception as e: | ||
| logger.exception("Error handling request: %s", e) | ||
| logger.exception( | ||
| "[PROXY] Error handling request #%d id=%s: %s", request_nums, request_id, e | ||
| ) |
| len(block_notify_list) > 0 | ||
| ), "block_notify_list cannot be empty in remote allocate message" | ||
|
|
||
| with self.lock: |
Comment on lines
+184
to
+188
| self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 | ||
| ): | ||
| assert self.local_memory_registered, "You have not register local memory data!" | ||
| assert self.moriio_engine is not None, "MoRIIO engine must be set first" | ||
| transfer_status = self.sessions[sess_idx].write( |
Comment on lines
+843
to
+853
| "consumer has %d regions", | ||
| num_regions, | ||
| len(consumer_base_addrs), | ||
| ) | ||
| if len(src_block_ids) != len(dst_block_ids): | ||
| logger.error( | ||
| "[PRODUCER] BLOCK COUNT MISMATCH: src has %d blocks, " | ||
| "dst has %d blocks", | ||
| len(src_block_ids), | ||
| len(dst_block_ids), | ||
| ) |
083294a to
00d534c
Compare
|
|
||
| ## 📢 News | ||
|
|
||
| - **[2026/05]** ATOM now supports **Prefill/Decode (P/D) disaggregation** with [Mooncake](https://github.com/kvcache-ai/Mooncake) RDMA push-mode KV cache transfer. See [PD disaggregation guide](docs/pd_disaggregation_guide.md). |
Comment on lines
+81
to
+83
| Run GSM8K evaluation against the consumer endpoint: | ||
|
|
||
| ```bash |
|
|
||
| Expected accuracy: ~0.95-0.96 (matching non-PD baseline). | ||
| ``` | ||
| ewshot: None, batch_size: 1 |
Comment on lines
+183
to
+189
| def write_remote_data_single( | ||
| self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 | ||
| ): | ||
| assert self.local_memory_registered, "You have not register local memory data!" | ||
| assert self.moriio_engine is not None, "MoRIIO engine must be set first" | ||
| transfer_status = self.sessions[sess_idx].write( | ||
| local_offset, |
| len(block_notify_list) > 0 | ||
| ), "block_notify_list cannot be empty in remote allocate message" | ||
|
|
||
| with self.lock: |
Comment on lines
+839
to
+861
| num_regions = len(self.kv_caches_base_addr) | ||
| if num_regions != len(consumer_base_addrs): | ||
| logger.error( | ||
| "[PRODUCER] REGION COUNT MISMATCH: producer has %d regions, " | ||
| "consumer has %d regions", | ||
| num_regions, | ||
| len(consumer_base_addrs), | ||
| ) | ||
| if len(src_block_ids) != len(dst_block_ids): | ||
| logger.error( | ||
| "[PRODUCER] BLOCK COUNT MISMATCH: src has %d blocks, " | ||
| "dst has %d blocks", | ||
| len(src_block_ids), | ||
| len(dst_block_ids), | ||
| ) | ||
| for region_idx in range(num_regions): | ||
| src_base = self.kv_caches_base_addr[region_idx] | ||
| dst_base = consumer_base_addrs[region_idx] | ||
| bpb = self._per_block_bytes_list[region_idx] | ||
| for sb, db in zip(src_block_ids, dst_block_ids): | ||
| src_addrs.append(src_base + sb * bpb) | ||
| dst_addrs.append(dst_base + db * bpb) | ||
| sizes.append(bpb) |
Comment on lines
+130
to
136
| KVConnectorFactory.register( | ||
| "mooncake", | ||
| worker_module="atom.kv_transfer.disaggregation.mooncake.mooncake_connector", | ||
| worker_class="MooncakeConnector", | ||
| scheduler_module="atom.kv_transfer.disaggregation.mooncake.mooncake_connector", | ||
| scheduler_class="MooncakeConnectorScheduler", | ||
| ) |
|
|
||
| ## 📢 News | ||
|
|
||
| - **[2026/05]** ATOM now supports **Prefill/Decode (P/D) disaggregation** with [Mooncake](https://github.com/kvcache-ai/Mooncake) RDMA push-mode KV cache transfer. See [PD disaggregation guide](docs/pd_disaggregation_guide.md). |
Comment on lines
+81
to
+83
| Run GSM8K evaluation against the consumer endpoint: | ||
|
|
||
| ```bash |
|
|
||
| Expected accuracy: ~0.95-0.96 (matching non-PD baseline). | ||
| ``` | ||
| ewshot: None, batch_size: 1 |
Comment on lines
+70
to
+71
| block_ids_t = torch.tensor(block_ids, dtype=torch.int32, device=kv_cache.device) | ||
| kv_flat = kv_cache.view(torch.int32) |
Comment on lines
+183
to
+195
| def write_remote_data_single( | ||
| self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 | ||
| ): | ||
| assert self.local_memory_registered, "You have not register local memory data!" | ||
| assert self.moriio_engine is not None, "MoRIIO engine must be set first" | ||
| transfer_status = self.sessions[sess_idx].write( | ||
| local_offset, | ||
| remote_offset, | ||
| transfer_size_byte, | ||
| self.moriio_engine.allocate_transfer_uid(), | ||
| ) | ||
| with self.lock: | ||
| self.transfer_status.append(transfer_status) |
| len(block_notify_list) > 0 | ||
| ), "block_notify_list cannot be empty in remote allocate message" | ||
|
|
||
| with self.lock: |
Comment on lines
+663
to
+665
| self._fence_event: Optional[torch.cuda.Event] = ( | ||
| None # for cross-partition fence | ||
| ) |
Comment on lines
+1810
to
+1813
| # GPU memory fence for RDMA-written KV blocks — disabled for | ||
| # same-partition deployments. See gpu_memory_fence() in | ||
| # atom/kv_transfer/disaggregation/utils.py for cross-partition use. | ||
|
|
Comment on lines
+4
to
+24
| """ | ||
| Shared types, constants, enums, and helpers for the MoRIIO KV connector. | ||
|
|
||
| This module has no heavy dependencies (no torch at import time, no RDMA | ||
| engine instances) so it can be imported freely by the other moriio | ||
| submodules. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| import threading | ||
| from enum import Enum | ||
| from typing import Optional | ||
|
|
||
| import msgspec | ||
|
|
||
| from atom.kv_transfer.disaggregation.utils import ( # noqa: F401 | ||
| MAX_RDMA_CHUNK_BYTES, | ||
| chunk_tensor_for_rdma, | ||
| ) |
|
|
||
| ## 📢 News | ||
|
|
||
| - **[2026/05]** ATOM now supports **Prefill/Decode (P/D) disaggregation** with [Mooncake](https://github.com/kvcache-ai/Mooncake) RDMA push-mode KV cache transfer. See [PD disaggregation guide](docs/pd_disaggregation_guide.md). |
Comment on lines
+81
to
+83
| Run GSM8K evaluation against the consumer endpoint: | ||
|
|
||
| ```bash |
|
|
||
| Expected accuracy: ~0.95-0.96 (matching non-PD baseline). | ||
| ``` | ||
| ewshot: None, batch_size: 1 |
Comment on lines
+183
to
+195
| def write_remote_data_single( | ||
| self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 | ||
| ): | ||
| assert self.local_memory_registered, "You have not register local memory data!" | ||
| assert self.moriio_engine is not None, "MoRIIO engine must be set first" | ||
| transfer_status = self.sessions[sess_idx].write( | ||
| local_offset, | ||
| remote_offset, | ||
| transfer_size_byte, | ||
| self.moriio_engine.allocate_transfer_uid(), | ||
| ) | ||
| with self.lock: | ||
| self.transfer_status.append(transfer_status) |
Comment on lines
+233
to
+240
| with _zmq_ctx(zmq.ROUTER, path) as sock: | ||
| while True: | ||
| try: | ||
| identity, msg = sock.recv_multipart() | ||
| self._dispatch_message(msg) | ||
| except Exception as e: | ||
| logger.error("Error processing message: %s", e) | ||
| raise ValueError(f"Error processing message: {e}") from e |
| len(block_notify_list) > 0 | ||
| ), "block_notify_list cannot be empty in remote allocate message" | ||
|
|
||
| with self.lock: |
Comment on lines
+652
to
+675
| # ---- PD diagnostic: block collision check ---- | ||
| if remote_kv_blocks and scheduled_seqs: | ||
| all_blocks_by_seq: dict[int, list[int]] = {} | ||
| for sid, seq in scheduled_seqs.items(): | ||
| all_blocks_by_seq[sid] = list(seq.block_table) | ||
| seen: dict[int, int] = {} | ||
| for sid, blocks in all_blocks_by_seq.items(): | ||
| for blk in blocks: | ||
| if blk in seen: | ||
| logger.error( | ||
| "[PD-DIAG] BLOCK COLLISION! block %d shared by " | ||
| "seq %d and seq %d. " | ||
| "seq %d blocks=%s, seq %d blocks=%s", | ||
| blk, | ||
| seen[blk], | ||
| sid, | ||
| seen[blk], | ||
| all_blocks_by_seq[seen[blk]][:10], | ||
| sid, | ||
| blocks[:10], | ||
| ) | ||
| else: | ||
| seen[blk] = sid | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR refactors the Prefill/Decode (P/D) disaggregation stack and adds Mooncake support for P/D separation.
Goals
Goal in next PR:
Test Result
Submission Checklist