Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/sglang_kimi_k25_2node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ mooncake:
enable_gpu_direct: false
device_name: mlx5_7,mlx5_8,mlx5_9,mlx5_10,mlx5_11 # Please change this to correct network devices.
# protocol: tcp # Switch to TCP if RDMA is not available.
# kv_lease_ttl_s: 60 # If TCP transfer is slow, increase the lease TTL. Also increase the global_segment_size accordingly.
# kv_lease_ttl_s: 60 # Mooncake master internal lease only; does not affect deletion timing. Increase global_segment_size for TCP.


output_dir: ./outputs/train_kimi25_2node_h200
Expand Down
2 changes: 1 addition & 1 deletion configs/sglang_minimax_m25_5node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ mooncake:
protocol: rdma
device_name: mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9,mlx5_10,mlx5_11
enable_gpu_direct: false
kv_lease_ttl_s: 3.0
kv_lease_ttl_s: 3.0 # Mooncake master internal lease only; does not affect deletion timing

output_dir: ./outputs/sglang_minimax_m25_5node
cache_dir: ./cache/sglang_minimax_m25_5node
Expand Down
87 changes: 87 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2026 LightSeek Foundation
# MIT License

"""Root conftest: stubs heavy dependencies when they are not installed.

When torch is not available (e.g. Mac dev machine), installs:
1. A meta-path finder that mocks torch/mooncake/transformers/etc.
2. A lightweight ``torchspec`` package stub so submodule imports don't
trigger ``torchspec/__init__.py``'s eager model imports.
"""

import importlib.abc
import importlib.machinery
import os
import sys
from types import ModuleType
from unittest.mock import MagicMock


def _is_available(name):
try:
__import__(name)
return True
except ImportError:
return False


class _MockFinder(importlib.abc.MetaPathFinder):
"""Returns mock modules for missing heavy deps."""

PREFIXES = (
"torch",
"mooncake",
"transformers",
"flash_attn",
"triton",
"vllm",
"sglang",
"numba",
"ray",
"omegaconf",
"openai",
"huggingface_hub",
"safetensors",
"accelerate",
"peft",
"wandb",
"datasets",
"tokenizers",
"sentencepiece",
"pyzmq",
"zmq",
)

def find_spec(self, fullname, path, target=None):
for prefix in self.PREFIXES:
if fullname == prefix or fullname.startswith(prefix + "."):
return importlib.machinery.ModuleSpec(fullname, self, is_package=True)
return None

def create_module(self, spec):
class _Proxy(ModuleType):
def __getattr__(self, name):
return MagicMock()

proxy = _Proxy(spec.name)
proxy.__path__ = []
proxy.__package__ = spec.name
return proxy

def exec_module(self, module):
pass


# Only stub when torch is genuinely missing; on GPU machines the real
# packages are used untouched.
if not _is_available("torch"):
    # 1. Install the mock finder for heavy deps
    sys.meta_path.insert(0, _MockFinder())

    # 2. Pre-seed torchspec as a namespace package so that
    #    ``from torchspec.config.mooncake_config import ...`` does NOT
    #    trigger torchspec/__init__.py (which eagerly imports models → torch).
    _root = os.path.dirname(os.path.abspath(__file__))
    _pkg = ModuleType("torchspec")
    # __path__ points at the real package directory on disk, so real
    # submodules (torchspec/config/, torchspec/transfer/, ...) still import
    # normally -- only the eager package __init__ is bypassed.
    _pkg.__path__ = [os.path.join(_root, "torchspec")]
    _pkg.__package__ = "torchspec"
    _pkg.__file__ = os.path.join(_root, "torchspec", "__init__.py")
    sys.modules["torchspec"] = _pkg
5 changes: 2 additions & 3 deletions docs/code_architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ torchspec/
│ ├── eagle_store.py # EagleMooncakeStore
│ ├── buffers.py # HostBufferPool, GPUReceiveBuffer
│ ├── helpers.py # Buffer size calculation
│ ├── deferred_delete.py # Deferred key deletion
│ └── utils.py # Mooncake utility helpers
├── data/ # Data pipeline
│ ├── dataset.py # load_conversation_dataset()
Expand Down Expand Up @@ -155,7 +154,7 @@ Distributed tensor transfer for multi-node training:
- **`store.py`**: `MooncakeHiddenStateStore` - Base class with RDMA buffer management
- **`eagle_store.py`**: `EagleMooncakeStore` - Eagle3-specific wrapper with:
- Zero-copy `batch_put_from` for tensor storage
- Deferred deletion (respects 5-second lease TTL)
- Force deletion via `batch_remove(force=True)`
- Lazy tensor retrieval interface
- **`buffers.py`**: `HostBufferPool` (pre-allocated host buffers), `GPUReceiveBuffer` (GPU Direct RDMA)
- **`helpers.py`**: Buffer size calculation and Mooncake master process management
Expand Down Expand Up @@ -316,7 +315,7 @@ python train.py --config base.yaml --config experiment.yaml training.learning_ra

| Module | Purpose |
|--------|-------|
| `torchspec/transfer/mooncake/` | Mooncake tensor transfer (RDMA/TCP, buffer pools, deferred delete) |
| `torchspec/transfer/mooncake/` | Mooncake tensor transfer (RDMA/TCP, buffer pools, force delete) |
| `torchspec/utils/distributed.py` | Device mesh setup, TP/DP primitives (`get_tp_group`, `get_tp_device_mesh`) |
| `torchspec/utils/env.py` | Ray actor env-var forwarding (`get_torchspec_env_vars`) |
| `torchspec/utils/logging.py` | Unified logger |
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"ninja",
"packaging",
"pyzmq",
"mooncake-transfer-engine",
"mooncake-transfer-engine>=0.3.10.post1",
"openai",
"omegaconf",
"ray",
Expand Down
243 changes: 243 additions & 0 deletions tests/test_mooncake_force_delete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
# Copyright (c) 2026 LightSeek Foundation
# MIT License

"""Tests for Mooncake Store force delete + hard pin refactoring.

Depends on conftest.py (project root) to stub torch and mooncake when
running on environments without GPU dependencies (e.g. Mac dev machines).
"""

import os
import sys
from unittest.mock import MagicMock, patch

import pytest

from torchspec.config.mooncake_config import MooncakeConfig
from torchspec.transfer.mooncake.buffers import AsyncPutManager
from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore
from torchspec.transfer.mooncake.store import MooncakeHiddenStateStore


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_eagle_store(mock_raw_store):
    """Build an EagleMooncakeStore whose raw Mooncake client is *mock_raw_store*."""
    eagle = EagleMooncakeStore(MooncakeConfig())
    eagle._store = mock_raw_store
    eagle._initialized = True  # skip real connection setup
    return eagle


def _make_base_store(mock_raw_store, enable_hard_pin=False):
    """Build a concrete MooncakeHiddenStateStore wired to *mock_raw_store*."""

    class ConcreteStore(MooncakeHiddenStateStore):
        """Trivial concrete subclass; the base class is abstract-ish."""

    concrete = ConcreteStore(MooncakeConfig(enable_hard_pin=enable_hard_pin))
    concrete._store = mock_raw_store
    return concrete


# ---------------------------------------------------------------------------
# Test 1: enable_hard_pin env roundtrip
# ---------------------------------------------------------------------------
class TestEnableHardPinConfig:
    """enable_hard_pin must survive an export_env()/from_env() round trip."""

    def test_enable_hard_pin_env_roundtrip(self):
        cfg = MooncakeConfig(enable_hard_pin=True)
        assert cfg.enable_hard_pin is True

        # patch.dict restores os.environ afterwards, so export_env's
        # mutations do not leak into other tests.
        with patch.dict(os.environ, {}, clear=False):
            cfg.export_env()
            assert os.environ["MOONCAKE_ENABLE_HARD_PIN"] == "1"
            assert MooncakeConfig.from_env().enable_hard_pin is True

    def test_enable_hard_pin_default_off(self):
        cfg = MooncakeConfig()
        assert cfg.enable_hard_pin is False

        with patch.dict(os.environ, {}, clear=False):
            cfg.export_env()
            assert os.environ["MOONCAKE_ENABLE_HARD_PIN"] == "0"
            assert MooncakeConfig.from_env().enable_hard_pin is False


# ---------------------------------------------------------------------------
# Tests 2-3: _verify_force_delete
# ---------------------------------------------------------------------------
class TestVerifyForceDelete:
    """_verify_force_delete must reject clients lacking batch_remove(force=...)."""

    def test_missing_batch_remove_method(self):
        raw = MagicMock(spec=[])  # spec=[] -> no batch_remove attribute at all
        base = _make_base_store(raw)
        with pytest.raises(RuntimeError, match="batch_remove.*not found"):
            base._verify_force_delete()

    def test_missing_force_param(self):
        raw = MagicMock()
        raw.batch_remove = MagicMock()
        raw.batch_remove.__doc__ = "batch_remove(keys) -> list[int]"
        base = _make_base_store(raw)
        # Break importlib.metadata.version so the check falls back to
        # inspecting the docstring, which lacks a ``force`` parameter.
        broken_metadata = MagicMock(version=MagicMock(side_effect=Exception))
        with patch.dict(sys.modules, {"importlib.metadata": broken_metadata}):
            with pytest.raises(RuntimeError, match="batch_remove.*force.*not supported"):
                base._verify_force_delete()

    def test_valid_batch_remove(self):
        raw = MagicMock()
        raw.batch_remove = MagicMock()
        raw.batch_remove.__doc__ = "batch_remove(keys, force=False) -> list[int]"
        _make_base_store(raw)._verify_force_delete()  # must not raise


# ---------------------------------------------------------------------------
# Tests 4-9: remove_eagle3_tensors
# ---------------------------------------------------------------------------
class TestRemoveEagle3Tensors:
    """remove_eagle3_tensors: best-effort force delete with up to 3 attempts."""

    def test_success(self):
        raw = MagicMock()
        raw.batch_remove.return_value = [0, 0]
        _make_eagle_store(raw).remove_eagle3_tensors("k1")
        raw.batch_remove.assert_called_once_with(["k1_hs", "k1_ids"], force=True)

    def test_not_found_is_success(self):
        raw = MagicMock()
        raw.batch_remove.return_value = [-704, 0]  # -704: key absent, counts as done
        _make_eagle_store(raw).remove_eagle3_tensors("k1")
        assert raw.batch_remove.call_count == 1

    def test_retry_then_succeed(self):
        raw = MagicMock()
        # First call: k1_hs fails; second call retries only k1_hs and succeeds.
        raw.batch_remove.side_effect = [[-1, 0], [0]]
        _make_eagle_store(raw).remove_eagle3_tensors("k1")
        assert raw.batch_remove.call_count == 2

    def test_exhaust_retries(self):
        raw = MagicMock()
        raw.batch_remove.return_value = [-1, -1]
        # Persistent failure is swallowed after 3 attempts, never raised.
        _make_eagle_store(raw).remove_eagle3_tensors("k1")
        assert raw.batch_remove.call_count == 3

    def test_exception_is_retried(self):
        raw = MagicMock()
        raw.batch_remove.side_effect = [RuntimeError("connection lost"), [0, 0]]
        _make_eagle_store(raw).remove_eagle3_tensors("k1")
        assert raw.batch_remove.call_count == 2

    def test_all_exceptions(self):
        raw = MagicMock()
        raw.batch_remove.side_effect = RuntimeError("down")
        # Repeated exceptions are also swallowed after 3 attempts.
        _make_eagle_store(raw).remove_eagle3_tensors("k1")
        assert raw.batch_remove.call_count == 3


# ---------------------------------------------------------------------------
# Tests 10-11: cleanup does not mask put error
# ---------------------------------------------------------------------------
class TestCleanupDoesNotMaskPutError:
    """A failed cleanup (batch_remove raising) must not hide the put error."""

    def test_sync_put_cleanup_does_not_mask_put_error(self):
        mock_raw = MagicMock()
        mock_raw.batch_put_from.return_value = [-1, 0]  # put partially fails
        mock_raw.batch_remove.side_effect = RuntimeError("cleanup failed")
        store = _make_eagle_store(mock_raw)
        store._replicate_config = None

        # The original put failure must surface, not the cleanup exception.
        with pytest.raises(RuntimeError, match="batch_put_from failed"):
            store._do_sync_batch_put(["k_hs", "k_ids"], [100, 200], [64, 32])

    def test_async_put_cleanup_does_not_mask_put_error(self):
        mock_raw = MagicMock()
        mock_raw.batch_put_from.return_value = [-1, 0]
        mock_raw.batch_remove.side_effect = RuntimeError("cleanup failed")
        mgr = AsyncPutManager(store=mock_raw, max_workers=1)
        try:
            with pytest.raises(RuntimeError, match="async batch_put_from failed"):
                mgr._do_put(["k_hs", "k_ids"], [100, 200], [64, 32])
        finally:
            # BUGFIX: shutdown was previously skipped when the assertion
            # failed, leaking the manager's worker pool; always release it.
            mgr.shutdown()


# ---------------------------------------------------------------------------
# Tests 12-13: _build_replicate_config
# ---------------------------------------------------------------------------
class TestBuildReplicateConfig:
    """_build_replicate_config: hard pin is used only when ReplicateConfig supports it."""

    @staticmethod
    def _patched_store_module(config_instance):
        # Context manager that makes ``import mooncake.store`` yield a mock
        # whose ReplicateConfig() returns *config_instance*.
        module = MagicMock()
        module.ReplicateConfig.return_value = config_instance
        return patch.dict(sys.modules, {"mooncake.store": module})

    def test_supported(self):
        replicate_cfg = MagicMock()
        replicate_cfg.with_hard_pin = False
        with self._patched_store_module(replicate_cfg):
            base = _make_base_store(MagicMock(), enable_hard_pin=True)
            base._build_replicate_config()
            assert base._replicate_config is not None
            assert base._replicate_config.with_hard_pin is True

    def test_unsupported(self):
        replicate_cfg = MagicMock(spec=[])  # no with_hard_pin attribute
        with self._patched_store_module(replicate_cfg):
            base = _make_base_store(MagicMock(), enable_hard_pin=True)
            base._build_replicate_config()
            assert base._replicate_config is None


# ---------------------------------------------------------------------------
# Tests 14-15: replicate_config passed through put paths
# ---------------------------------------------------------------------------
class TestReplicateConfigPassthrough:
    """replicate_config must be forwarded (or omitted) on both put paths."""

    def test_sync_put_passes_replicate_config(self):
        mock_raw = MagicMock()
        mock_raw.batch_put_from.return_value = [0, 0]
        store = _make_eagle_store(mock_raw)
        mock_cfg = MagicMock()
        store._replicate_config = mock_cfg

        store._do_sync_batch_put(["k_hs", "k_ids"], [100, 200], [64, 32])
        mock_raw.batch_put_from.assert_called_once_with(
            ["k_hs", "k_ids"], [100, 200], [64, 32], config=mock_cfg
        )

    def test_sync_put_no_config_when_none(self):
        mock_raw = MagicMock()
        mock_raw.batch_put_from.return_value = [0, 0]
        store = _make_eagle_store(mock_raw)
        store._replicate_config = None

        store._do_sync_batch_put(["k_hs", "k_ids"], [100, 200], [64, 32])
        # No config kwarg at all when replication is disabled.
        mock_raw.batch_put_from.assert_called_once_with(["k_hs", "k_ids"], [100, 200], [64, 32])

    def test_async_put_passes_replicate_config(self):
        mock_raw = MagicMock()
        mock_raw.batch_put_from.return_value = [0, 0]
        mock_cfg = MagicMock()
        mgr = AsyncPutManager(store=mock_raw, max_workers=1, replicate_config=mock_cfg)
        try:
            mgr._do_put(["k_hs", "k_ids"], [100, 200], [64, 32])
            mock_raw.batch_put_from.assert_called_once_with(
                ["k_hs", "k_ids"], [100, 200], [64, 32], config=mock_cfg
            )
        finally:
            # BUGFIX: shutdown was previously skipped when an assertion
            # failed, leaking the manager's worker pool; always release it.
            mgr.shutdown()
Loading
Loading