diff --git a/configs/sglang_kimi_k25_2node.yaml b/configs/sglang_kimi_k25_2node.yaml index 5d3db2be..b7c5a152 100644 --- a/configs/sglang_kimi_k25_2node.yaml +++ b/configs/sglang_kimi_k25_2node.yaml @@ -71,7 +71,7 @@ mooncake: enable_gpu_direct: false device_name: mlx5_7,mlx5_8,mlx5_9,mlx5_10,mlx5_11 # Please change this to correct network devices. # protocol: tcp # Switch to TCP if RDMA is not available. - # kv_lease_ttl_s: 60 # If TCP transfer is slow, increase the lease TTL. Also increase the global_segment_size accordingly. + # kv_lease_ttl_s: 60 # Mooncake master internal lease only; does not affect deletion timing. Increase global_segment_size for TCP. output_dir: ./outputs/train_kimi25_2node_h200 diff --git a/configs/sglang_minimax_m25_5node.yaml b/configs/sglang_minimax_m25_5node.yaml index 0bfd934f..87905418 100644 --- a/configs/sglang_minimax_m25_5node.yaml +++ b/configs/sglang_minimax_m25_5node.yaml @@ -67,7 +67,7 @@ mooncake: protocol: rdma device_name: mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9,mlx5_10,mlx5_11 enable_gpu_direct: false - kv_lease_ttl_s: 3.0 + kv_lease_ttl_s: 3.0 # Mooncake master internal lease only; does not affect deletion timing output_dir: ./outputs/sglang_minimax_m25_5node cache_dir: ./cache/sglang_minimax_m25_5node diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..3d9e5ead --- /dev/null +++ b/conftest.py @@ -0,0 +1,87 @@ +# Copyright (c) 2026 LightSeek Foundation +# MIT License + +"""Root conftest: stubs heavy dependencies when they are not installed. + +When torch is not available (e.g. Mac dev machine), installs: +1. A meta-path finder that mocks torch/mooncake/transformers/etc. +2. A lightweight ``torchspec`` package stub so submodule imports don't + trigger ``torchspec/__init__.py``'s eager model imports. 
+""" + +import importlib.abc +import importlib.machinery +import os +import sys +from types import ModuleType +from unittest.mock import MagicMock + + +def _is_available(name): + try: + __import__(name) + return True + except ImportError: + return False + + +class _MockFinder(importlib.abc.MetaPathFinder): + """Returns mock modules for missing heavy deps.""" + + PREFIXES = ( + "torch", + "mooncake", + "transformers", + "flash_attn", + "triton", + "vllm", + "sglang", + "numba", + "ray", + "omegaconf", + "openai", + "huggingface_hub", + "safetensors", + "accelerate", + "peft", + "wandb", + "datasets", + "tokenizers", + "sentencepiece", + "pyzmq", + "zmq", + ) + + def find_spec(self, fullname, path, target=None): + for prefix in self.PREFIXES: + if fullname == prefix or fullname.startswith(prefix + "."): + return importlib.machinery.ModuleSpec(fullname, self, is_package=True) + return None + + def create_module(self, spec): + class _Proxy(ModuleType): + def __getattr__(self, name): + return MagicMock() + + proxy = _Proxy(spec.name) + proxy.__path__ = [] + proxy.__package__ = spec.name + return proxy + + def exec_module(self, module): + pass + + +if not _is_available("torch"): + # 1. Install the mock finder for heavy deps + sys.meta_path.insert(0, _MockFinder()) + + # 2. Pre-seed torchspec as a namespace package so that + # ``from torchspec.config.mooncake_config import ...`` does NOT + # trigger torchspec/__init__.py (which eagerly imports models → torch). 
+ _root = os.path.dirname(os.path.abspath(__file__)) + _pkg = ModuleType("torchspec") + _pkg.__path__ = [os.path.join(_root, "torchspec")] + _pkg.__package__ = "torchspec" + _pkg.__file__ = os.path.join(_root, "torchspec", "__init__.py") + sys.modules["torchspec"] = _pkg diff --git a/docs/code_architecture.md b/docs/code_architecture.md index ab10ad47..43d2ac9c 100644 --- a/docs/code_architecture.md +++ b/docs/code_architecture.md @@ -54,7 +54,6 @@ torchspec/ │ ├── eagle_store.py # EagleMooncakeStore │ ├── buffers.py # HostBufferPool, GPUReceiveBuffer │ ├── helpers.py # Buffer size calculation -│ ├── deferred_delete.py # Deferred key deletion │ └── utils.py # Mooncake utility helpers ├── data/ # Data pipeline │ ├── dataset.py # load_conversation_dataset() @@ -155,7 +154,7 @@ Distributed tensor transfer for multi-node training: - **`store.py`**: `MooncakeHiddenStateStore` - Base class with RDMA buffer management - **`eagle_store.py`**: `EagleMooncakeStore` - Eagle3-specific wrapper with: - Zero-copy `batch_put_from` for tensor storage - - Deferred deletion (respects 5-second lease TTL) + - Force deletion via `batch_remove(force=True)` - Lazy tensor retrieval interface - **`buffers.py`**: `HostBufferPool` (pre-allocated host buffers), `GPUReceiveBuffer` (GPU Direct RDMA) - **`helpers.py`**: Buffer size calculation and Mooncake master process management @@ -316,7 +315,7 @@ python train.py --config base.yaml --config experiment.yaml training.learning_ra | Module | Purpose | |--------|-------| -| `torchspec/transfer/mooncake/` | Mooncake tensor transfer (RDMA/TCP, buffer pools, deferred delete) | +| `torchspec/transfer/mooncake/` | Mooncake tensor transfer (RDMA/TCP, buffer pools, force delete) | | `torchspec/utils/distributed.py` | Device mesh setup, TP/DP primitives (`get_tp_group`, `get_tp_device_mesh`) | | `torchspec/utils/env.py` | Ray actor env-var forwarding (`get_torchspec_env_vars`) | | `torchspec/utils/logging.py` | Unified logger | diff --git a/pyproject.toml 
b/pyproject.toml index 99a20da9..d97974d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "ninja", "packaging", "pyzmq", - "mooncake-transfer-engine", + "mooncake-transfer-engine>=0.3.10.post1", "openai", "omegaconf", "ray", diff --git a/tests/test_mooncake_force_delete.py b/tests/test_mooncake_force_delete.py new file mode 100644 index 00000000..129a1de1 --- /dev/null +++ b/tests/test_mooncake_force_delete.py @@ -0,0 +1,243 @@ +# Copyright (c) 2026 LightSeek Foundation +# MIT License + +"""Tests for Mooncake Store force delete + hard pin refactoring. + +Depends on conftest.py (project root) to stub torch and mooncake when +running on environments without GPU dependencies (e.g. Mac dev machines). +""" + +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from torchspec.config.mooncake_config import MooncakeConfig +from torchspec.transfer.mooncake.buffers import AsyncPutManager +from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore +from torchspec.transfer.mooncake.store import MooncakeHiddenStateStore + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _make_eagle_store(mock_raw_store): + """Create an EagleMooncakeStore with a mocked internal _store.""" + config = MooncakeConfig() + store = EagleMooncakeStore(config) + store._store = mock_raw_store + store._initialized = True + return store + + +def _make_base_store(mock_raw_store, enable_hard_pin=False): + """Create a MooncakeHiddenStateStore subclass with a mocked internal _store.""" + config = MooncakeConfig(enable_hard_pin=enable_hard_pin) + + class ConcreteStore(MooncakeHiddenStateStore): + pass + + store = ConcreteStore(config) + store._store = mock_raw_store + return store + + +# --------------------------------------------------------------------------- +# Test 1: enable_hard_pin env roundtrip +# 
--------------------------------------------------------------------------- +class TestEnableHardPinConfig: + def test_enable_hard_pin_env_roundtrip(self): + config = MooncakeConfig(enable_hard_pin=True) + assert config.enable_hard_pin is True + + with patch.dict(os.environ, {}, clear=False): + config.export_env() + assert os.environ["MOONCAKE_ENABLE_HARD_PIN"] == "1" + restored = MooncakeConfig.from_env() + assert restored.enable_hard_pin is True + + def test_enable_hard_pin_default_off(self): + config = MooncakeConfig() + assert config.enable_hard_pin is False + + with patch.dict(os.environ, {}, clear=False): + config.export_env() + assert os.environ["MOONCAKE_ENABLE_HARD_PIN"] == "0" + restored = MooncakeConfig.from_env() + assert restored.enable_hard_pin is False + + +# --------------------------------------------------------------------------- +# Tests 2-3: _verify_force_delete +# --------------------------------------------------------------------------- +class TestVerifyForceDelete: + def test_missing_batch_remove_method(self): + mock_raw = MagicMock(spec=[]) # no batch_remove attr + store = _make_base_store(mock_raw) + with pytest.raises(RuntimeError, match="batch_remove.*not found"): + store._verify_force_delete() + + def test_missing_force_param(self): + mock_raw = MagicMock() + mock_raw.batch_remove = MagicMock() + mock_raw.batch_remove.__doc__ = "batch_remove(keys) -> list[int]" + store = _make_base_store(mock_raw) + # Patch importlib.metadata.version to force fallback to docstring check + with patch.dict( + sys.modules, {"importlib.metadata": MagicMock(version=MagicMock(side_effect=Exception))} + ): + with pytest.raises(RuntimeError, match="batch_remove.*force.*not supported"): + store._verify_force_delete() + + def test_valid_batch_remove(self): + mock_raw = MagicMock() + mock_raw.batch_remove = MagicMock() + mock_raw.batch_remove.__doc__ = "batch_remove(keys, force=False) -> list[int]" + store = _make_base_store(mock_raw) + 
store._verify_force_delete() # should not raise + + +# --------------------------------------------------------------------------- +# Tests 4-9: remove_eagle3_tensors +# --------------------------------------------------------------------------- +class TestRemoveEagle3Tensors: + def test_success(self): + mock_raw = MagicMock() + mock_raw.batch_remove.return_value = [0, 0] + store = _make_eagle_store(mock_raw) + store.remove_eagle3_tensors("k1") + mock_raw.batch_remove.assert_called_once_with(["k1_hs", "k1_ids"], force=True) + + def test_not_found_is_success(self): + mock_raw = MagicMock() + mock_raw.batch_remove.return_value = [-704, 0] + store = _make_eagle_store(mock_raw) + store.remove_eagle3_tensors("k1") + assert mock_raw.batch_remove.call_count == 1 + + def test_retry_then_succeed(self): + mock_raw = MagicMock() + mock_raw.batch_remove.side_effect = [ + [-1, 0], # first: k1_hs fails + [0], # second: k1_hs succeeds + ] + store = _make_eagle_store(mock_raw) + store.remove_eagle3_tensors("k1") + assert mock_raw.batch_remove.call_count == 2 + + def test_exhaust_retries(self): + mock_raw = MagicMock() + mock_raw.batch_remove.return_value = [-1, -1] + store = _make_eagle_store(mock_raw) + # Should not raise + store.remove_eagle3_tensors("k1") + assert mock_raw.batch_remove.call_count == 3 + + def test_exception_is_retried(self): + mock_raw = MagicMock() + mock_raw.batch_remove.side_effect = [ + RuntimeError("connection lost"), + [0, 0], + ] + store = _make_eagle_store(mock_raw) + store.remove_eagle3_tensors("k1") + assert mock_raw.batch_remove.call_count == 2 + + def test_all_exceptions(self): + mock_raw = MagicMock() + mock_raw.batch_remove.side_effect = RuntimeError("down") + store = _make_eagle_store(mock_raw) + # Should not raise + store.remove_eagle3_tensors("k1") + assert mock_raw.batch_remove.call_count == 3 + + +# --------------------------------------------------------------------------- +# Tests 10-11: cleanup does not mask put error +# 
--------------------------------------------------------------------------- +class TestCleanupDoesNotMaskPutError: + def test_sync_put_cleanup_does_not_mask_put_error(self): + mock_raw = MagicMock() + mock_raw.batch_put_from.return_value = [-1, 0] + mock_raw.batch_remove.side_effect = RuntimeError("cleanup failed") + store = _make_eagle_store(mock_raw) + store._replicate_config = None + + with pytest.raises(RuntimeError, match="batch_put_from failed"): + store._do_sync_batch_put(["k_hs", "k_ids"], [100, 200], [64, 32]) + + def test_async_put_cleanup_does_not_mask_put_error(self): + mock_raw = MagicMock() + mock_raw.batch_put_from.return_value = [-1, 0] + mock_raw.batch_remove.side_effect = RuntimeError("cleanup failed") + mgr = AsyncPutManager(store=mock_raw, max_workers=1) + with pytest.raises(RuntimeError, match="async batch_put_from failed"): + mgr._do_put(["k_hs", "k_ids"], [100, 200], [64, 32]) + mgr.shutdown() + + +# --------------------------------------------------------------------------- +# Tests 12-13: _build_replicate_config +# --------------------------------------------------------------------------- +class TestBuildReplicateConfig: + def test_supported(self): + mock_config_instance = MagicMock() + mock_config_instance.with_hard_pin = False + mock_store_module = MagicMock() + mock_store_module.ReplicateConfig.return_value = mock_config_instance + + with patch.dict(sys.modules, {"mooncake.store": mock_store_module}): + mock_raw = MagicMock() + store = _make_base_store(mock_raw, enable_hard_pin=True) + store._build_replicate_config() + assert store._replicate_config is not None + assert store._replicate_config.with_hard_pin is True + + def test_unsupported(self): + mock_config_instance = MagicMock(spec=[]) # no with_hard_pin attr + mock_store_module = MagicMock() + mock_store_module.ReplicateConfig.return_value = mock_config_instance + + with patch.dict(sys.modules, {"mooncake.store": mock_store_module}): + mock_raw = MagicMock() + store = 
_make_base_store(mock_raw, enable_hard_pin=True) + store._build_replicate_config() + assert store._replicate_config is None + + +# --------------------------------------------------------------------------- +# Tests 14-15: replicate_config passed through put paths +# --------------------------------------------------------------------------- +class TestReplicateConfigPassthrough: + def test_sync_put_passes_replicate_config(self): + mock_raw = MagicMock() + mock_raw.batch_put_from.return_value = [0, 0] + store = _make_eagle_store(mock_raw) + mock_cfg = MagicMock() + store._replicate_config = mock_cfg + + store._do_sync_batch_put(["k_hs", "k_ids"], [100, 200], [64, 32]) + mock_raw.batch_put_from.assert_called_once_with( + ["k_hs", "k_ids"], [100, 200], [64, 32], config=mock_cfg + ) + + def test_sync_put_no_config_when_none(self): + mock_raw = MagicMock() + mock_raw.batch_put_from.return_value = [0, 0] + store = _make_eagle_store(mock_raw) + store._replicate_config = None + + store._do_sync_batch_put(["k_hs", "k_ids"], [100, 200], [64, 32]) + mock_raw.batch_put_from.assert_called_once_with(["k_hs", "k_ids"], [100, 200], [64, 32]) + + def test_async_put_passes_replicate_config(self): + mock_raw = MagicMock() + mock_raw.batch_put_from.return_value = [0, 0] + mock_cfg = MagicMock() + mgr = AsyncPutManager(store=mock_raw, max_workers=1, replicate_config=mock_cfg) + mgr._do_put(["k_hs", "k_ids"], [100, 200], [64, 32]) + mock_raw.batch_put_from.assert_called_once_with( + ["k_hs", "k_ids"], [100, 200], [64, 32], config=mock_cfg + ) + mgr.shutdown() diff --git a/tools/test_transfer_paths.sh b/tools/test_transfer_paths.sh new file mode 100755 index 00000000..830c59af --- /dev/null +++ b/tools/test_transfer_paths.sh @@ -0,0 +1,115 @@ +#!/bin/bash +# Test all Mooncake transfer path combinations for the force delete refactoring. +# +# Paths tested: +# 1. TCP + host buffer async (default path) +# 2. RDMA + host buffer async +# 3. RDMA + GPU Direct sync +# 4. 
TCP + GPU Direct sync (GDR over TCP, uncommon but valid)
+#
+# Each path runs a short training run (default 30 steps) to verify put/get/delete work end-to-end.
+
+set -euo pipefail
+
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)"
+ROOT_DIR="$(cd -- "$SCRIPT_DIR/.." && pwd)"
+export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3}
+export TORCHINDUCTOR_CACHE_DIR="$ROOT_DIR/cache/compiled_kernels"
+export TORCHSPEC_LOG_LEVEL=INFO
+
+CONFIG="$ROOT_DIR/configs/sglang_qwen3_8b.yaml"
+STEPS=${1:-30}
+PASS=0
+FAIL=0
+RESULTS=()
+
+run_test() {
+    local name="$1"
+    shift
+    local logfile="/tmp/ts_path_test_${name}.log"
+
+    echo ""
+    echo "========================================"
+    echo "TEST: $name"
+    echo " Args: $*"
+    echo "========================================"
+
+    # Stop any leftover Ray
+    ray stop --force 2>/dev/null || true
+    sleep 2
+
+    set +e
+    python3 -m torchspec.train_entry \
+        --config "$CONFIG" \
+        training.training_num_gpus_per_node=2 \
+        inference.inference_num_gpus=2 \
+        inference.inference_num_gpus_per_engine=2 \
+        inference.inference_num_gpus_per_node=4 \
+        inference.sglang.tp_size=2 \
+        training.num_train_steps=$STEPS \
+        "$@" \
+        > "$logfile" 2>&1
+    local rc=$?
+    set -e
+
+    # Check for training completion in Ray worker logs.
+    # NOTE: pipe through cat so grep -c always yields a single integer
+    # (grep -c with multiple files prints per-file "name:count" lines,
+    # and "|| echo 0" would append a second line on no-match exit 1).
+    local step_count
+    step_count=$(cat /tmp/ray/session_latest/logs/worker*.err 2>/dev/null | grep -c "step.*${STEPS}/${STEPS}" || true)
+    local delete_errors
+    delete_errors=$(cat /tmp/ray/session_latest/logs/worker*.err 2>/dev/null | grep -c "Force delete abandoned" || true)
+    local put_errors
+    put_errors=$(cat /tmp/ray/session_latest/logs/worker*.err 2>/dev/null | grep -c "batch_put_from failed" || true)
+
+    if [ "$step_count" -ge 1 ] && [ "$delete_errors" -eq 0 ] && [ "$put_errors" -eq 0 ]; then
+        echo " RESULT: PASS (${STEPS} steps completed, 0 delete errors, 0 put errors)"
+        PASS=$((PASS + 1))
+        RESULTS+=("PASS: $name")
+    else
+        echo " RESULT: FAIL (steps=$step_count, delete_errors=$delete_errors, put_errors=$put_errors, exit=$rc)"
+        echo " Log: $logfile"
+        FAIL=$((FAIL + 1))
+        RESULTS+=("FAIL: $name (steps=$step_count, del_err=$delete_errors, put_err=$put_errors)")
+    fi
+}
+
+echo "========================================"
+echo "Mooncake Transfer Path Tests"
+echo "========================================"
+
+# Test 1: TCP + host buffer async (default)
+run_test "tcp_host_async" \
+    mooncake.protocol=tcp \
+    mooncake.enable_gpu_direct=false
+
+# Test 2: RDMA + host buffer async
+run_test "rdma_host_async" \
+    mooncake.protocol=rdma \
+    mooncake.device_name=mlx5_0 \
+    mooncake.enable_gpu_direct=false
+
+# Test 3: RDMA + GPU Direct
+run_test "rdma_gpu_direct" \
+    mooncake.protocol=rdma \
+    mooncake.device_name=mlx5_0 \
+    mooncake.enable_gpu_direct=true
+
+# Test 4: TCP + GPU Direct
+run_test "tcp_gpu_direct" \
+    mooncake.protocol=tcp \
+    mooncake.enable_gpu_direct=true
+
+# Summary
+echo ""
+echo "========================================"
+echo "SUMMARY"
+echo "========================================"
+for r in "${RESULTS[@]}"; do
+    echo " $r"
+done
+echo ""
+echo "Total: $PASS passed, $FAIL failed"
+echo "========================================"
+
+if [ "$FAIL" -gt 0 ]; then
+    exit 1
+fi
diff --git 
a/torchspec/config/mooncake_config.py b/torchspec/config/mooncake_config.py index a8bd95c4..3b937838 100644 --- a/torchspec/config/mooncake_config.py +++ b/torchspec/config/mooncake_config.py @@ -57,7 +57,7 @@ class MooncakeConfig: gpu_buffer_size: str | int | None = None enable_gpu_direct: bool = False replica_num: int = 1 - enable_soft_pin: bool = False + enable_hard_pin: bool = False host_buffer_size: str | int | None = None get_batch_size: int = 1 max_seq_len: int = 8192 @@ -70,7 +70,7 @@ class MooncakeConfig: get_retry_wait_seconds: float = 0.5 get_retry_log_interval_seconds: float = 10.0 get_retry_max_wait_seconds: float = 60.0 - kv_lease_ttl_s: float = 5.0 + kv_lease_ttl_s: float = 5.0 # Mooncake master lease TTL only; not used for deletion timing def __post_init__(self): # Coerce size fields: accept str ("4GB") or int @@ -146,6 +146,7 @@ def from_flat_args(cls, args) -> "MooncakeConfig": args, "mooncake_get_batch_size", getattr(args, "per_dp_rank_batch_size", 1) ), "kv_lease_ttl_s": getattr(args, "mooncake_kv_lease_ttl_s", 5.0), + "enable_hard_pin": getattr(args, "mooncake_enable_hard_pin", False), "max_seq_len": getattr( args, "mooncake_max_seq_len", @@ -187,6 +188,7 @@ def export_env(self) -> None: self.get_retry_log_interval_seconds ) os.environ["MOONCAKE_GET_RETRY_MAX_WAIT_SECONDS"] = str(self.get_retry_max_wait_seconds) + os.environ["MOONCAKE_ENABLE_HARD_PIN"] = "1" if self.enable_hard_pin else "0" @classmethod def from_env(cls) -> "MooncakeConfig": @@ -237,6 +239,7 @@ def from_env(cls) -> "MooncakeConfig": get_retry_wait_seconds=get_retry_wait_seconds, get_retry_log_interval_seconds=get_retry_log_interval_seconds, get_retry_max_wait_seconds=get_retry_max_wait_seconds, + enable_hard_pin=os.getenv("MOONCAKE_ENABLE_HARD_PIN", "0") == "1", ) @classmethod diff --git a/torchspec/transfer/mooncake/buffers.py b/torchspec/transfer/mooncake/buffers.py index d75ea89e..085fd56f 100644 --- a/torchspec/transfer/mooncake/buffers.py +++ 
b/torchspec/transfer/mooncake/buffers.py @@ -122,8 +122,9 @@ class AsyncPutManager: because ``MooncakeDistributedStore`` is not thread-safe for concurrent puts. """ - def __init__(self, store: Any, max_workers: int = 1): + def __init__(self, store: Any, max_workers: int = 1, replicate_config: Any = None): self._store = store + self._replicate_config = replicate_config self._executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="async_put") self._in_flight: Dict[int, Future] = {} self._last_error: Optional[BaseException] = None @@ -193,17 +194,22 @@ def _do_put( torch.cuda.set_device(device_index) wait_event.synchronize() with self._put_lock: - results = self._store.batch_put_from(keys, buffer_ptrs, sizes) + if self._replicate_config is not None: + results = self._store.batch_put_from( + keys, buffer_ptrs, sizes, config=self._replicate_config + ) + else: + results = self._store.batch_put_from(keys, buffer_ptrs, sizes) failures = [(k, r) for k, r in zip(keys, results) if r != 0] if failures: - for k in keys: - try: - self._store.remove(k) - except Exception: - logger.debug( - "Failed to remove partial key %s after async batch_put_from failure.", - k, - ) + try: + self._store.batch_remove(keys, force=True) + except Exception: + logger.warning( + "Failed to cleanup keys after async batch_put_from failure: %s", + keys, + exc_info=True, + ) detail = ", ".join(f"{k} (code={r})" for k, r in failures) raise RuntimeError(f"async batch_put_from failed: {detail}") diff --git a/torchspec/transfer/mooncake/eagle_store.py b/torchspec/transfer/mooncake/eagle_store.py index caeef92b..3366ae3c 100644 --- a/torchspec/transfer/mooncake/eagle_store.py +++ b/torchspec/transfer/mooncake/eagle_store.py @@ -18,14 +18,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-import atexit import ctypes import time from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch -from torchspec.transfer.mooncake.deferred_delete import DeferredDeleteManager from torchspec.transfer.mooncake.helpers import _format_bytes from torchspec.transfer.mooncake.store import MooncakeHiddenStateStore from torchspec.utils.logging import logger @@ -63,57 +61,11 @@ class EagleMooncakeStore(MooncakeHiddenStateStore): - {key}_ids: input_ids - {key}_lhs: last_hidden_states (if present) - Deletions are deferred to respect Mooncake's lease TTL (config.kv_lease_ttl_s). + Deletions use ``batch_remove(force=True)`` for immediate cleanup. """ TENSOR_SUFFIXES = ["_hs", "_tgt", "_ids", "_lhs"] - def __init__(self, config): - """Initialize Eagle3 Mooncake Store with deferred deletion.""" - super().__init__(config) - self._deferred_delete_manager: Optional[DeferredDeleteManager] = None - self._cleanup_registered = False - - def setup(self, device: torch.device = None) -> None: - """Initialize the Mooncake Store client and deferred delete manager.""" - super().setup(device) - - if self._deferred_delete_manager is None: - lease_ttl_s = self.config.kv_lease_ttl_s - # Initialize deferred delete manager after store is ready - self._deferred_delete_manager = DeferredDeleteManager( - store=self._store, - ttl_buffer_seconds=0.5, # Small buffer for safety - check_interval=1.0, # Check queue every second - max_queue_size=10000, # Max pending deletions - retry_interval=2.0, # Retry failed deletes after 2s - ttl_seconds=lease_ttl_s, # Mooncake lease TTL - ) - logger.debug("Deferred delete manager initialized") - - # Register cleanup on exit - if not self._cleanup_registered: - atexit.register(self._cleanup_deferred_deletes) - self._cleanup_registered = True - - def _cleanup_deferred_deletes(self): - """Cleanup deferred delete manager on exit.""" - if self._deferred_delete_manager is not None: - logger.info("Cleaning up deferred delete manager...") - stats = 
self._deferred_delete_manager.get_stats() - queue_size = self._deferred_delete_manager.get_queue_size() - if queue_size > 0: - logger.warning( - " Shutting down with %d pending deletions. " - "Some Mooncake objects may not be cleaned up.", - queue_size, - ) - self._deferred_delete_manager.stop() - logger.info( - "Deferred delete final stats: %s", - stats, - ) - def put( self, key: str, @@ -242,17 +194,22 @@ def _do_sync_batch_put( ) -> None: """Synchronous batch_put_from with error handling.""" total_bytes = sum(sizes) - results = self._store.batch_put_from(keys, buffer_ptrs, sizes) + if self._replicate_config is not None: + results = self._store.batch_put_from( + keys, buffer_ptrs, sizes, config=self._replicate_config + ) + else: + results = self._store.batch_put_from(keys, buffer_ptrs, sizes) failures = [(k, r) for k, r in zip(keys, results) if r != 0] if failures: - for k in keys: - try: - self._store.remove(k) - except Exception: - logger.debug( - "Failed to remove partial key %s after batch_put_from failure.", - k, - ) + try: + self._store.batch_remove(keys, force=True) + except Exception: + logger.warning( + "Failed to cleanup keys after batch_put_from failure: %s", + keys, + exc_info=True, + ) failure_details = ", ".join(f"{k} (code={r})" for k, r in failures) config_details = ( f"total_bytes={_format_bytes(total_bytes)}, " @@ -506,72 +463,51 @@ def remove_eagle3_tensors( has_last_hidden_states: bool = False, has_target: bool = False, ) -> None: - """ - Queue deferred removal of all tensors associated with an Eagle3 output. + """Force-delete all tensors associated with an Eagle3 output. - Deletions are queued and executed after Mooncake's lease TTL expires. - This prevents deletion failures due to active leases. + Uses ``batch_remove(force=True)`` to bypass lease TTL and delete + immediately after consumption. Retries up to 3 times on failure. + Never raises — deletion is best-effort to avoid breaking the + training fetch path. 
Args: key: Base key used when storing has_last_hidden_states: Whether last_hidden_states was stored has_target: Whether target (logits) was stored """ - keys = [f"{key}_hs", f"{key}_ids"] if has_target: keys.append(f"{key}_tgt") if has_last_hidden_states: keys.append(f"{key}_lhs") - logger.debug( - "Queueing deferred deletion for base_key=%s, num_keys=%d", - key, - len(keys), - ) - - # Queue deletion instead of deleting immediately - if self._deferred_delete_manager is None: - logger.error( - "Deferred delete manager not initialized! Cannot delete %s", - key, - ) - return - - success = self._deferred_delete_manager.enqueue_delete( - keys=keys, - base_key=key, - max_attempts=3, - ) - - if success: - logger.debug( - "Queued deferred deletion for base_key=%s", - key, - ) - else: - logger.error( - "Failed to queue deletion for %s (queue full)", - key, - ) - - def get_deferred_delete_stats(self) -> Dict[str, int]: - """Get statistics from the deferred delete manager. - - Returns: - Dict with keys: enqueued, attempted, succeeded, failed, retried, abandoned, queue_size - """ - if self._deferred_delete_manager is None: - return { - "enqueued": 0, - "attempted": 0, - "succeeded": 0, - "failed": 0, - "retried": 0, - "abandoned": 0, - "queue_size": 0, - } - - stats = self._deferred_delete_manager.get_stats() - stats["queue_size"] = self._deferred_delete_manager.get_queue_size() - return stats + for attempt in range(1, 4): + try: + results = self._store.batch_remove(keys, force=True) + except Exception: + if attempt < 3: + logger.warning( + "batch_remove raised for %s (attempt %d/3)", key, attempt, exc_info=True + ) + time.sleep(0.5) + else: + logger.error( + "Force delete abandoned for %s after 3 exceptions", key, exc_info=True + ) + continue + failed = [(k, r) for k, r in zip(keys, results) if r not in (None, 0, -704)] + if not failed: + logger.debug("Force-deleted %s (%d keys)", key, len(results)) + return + if attempt < 3: + time.sleep(0.5) + logger.warning( + "Retrying 
force delete for %s: %d keys failed (attempt %d/3)", + key, + len(failed), + attempt, + ) + else: + logger.error("Force delete abandoned for %s: %s", key, failed) + return + keys = [k for k, _ in failed] diff --git a/torchspec/transfer/mooncake/store.py b/torchspec/transfer/mooncake/store.py index a87606e3..37219d98 100644 --- a/torchspec/transfer/mooncake/store.py +++ b/torchspec/transfer/mooncake/store.py @@ -20,7 +20,7 @@ import threading from abc import ABC -from typing import Dict, Optional +from typing import Any, Dict, Optional import torch from mooncake.store import MooncakeDistributedStore @@ -56,6 +56,7 @@ def __init__(self, config: MooncakeConfig): self._gpu_send_buffer: Optional[GPUSendBuffer] = None self._gpu_direct_available = False self._copy_stream: Optional[torch.cuda.Stream] = None + self._replicate_config: Any = None def setup(self, device: torch.device | int | None = None) -> None: """Initialize the Mooncake Store client.""" @@ -94,6 +95,9 @@ def setup(self, device: torch.device | int | None = None) -> None: f"and metadata server is available at {self.config.metadata_server}" ) + self._verify_force_delete() + self._build_replicate_config() + pool_size = self.config.async_put_pool_size if pool_size > 0: self._host_buffer_pool = HostBufferPool( @@ -105,7 +109,9 @@ def setup(self, device: torch.device | int | None = None) -> None: for buf in self._host_buffer_pool._buffers: self._register_buffer(buf.ptr, buf.size) - self._async_put_manager = AsyncPutManager(store=self._store, max_workers=pool_size) + self._async_put_manager = AsyncPutManager( + store=self._store, max_workers=pool_size, replicate_config=self._replicate_config + ) logger.info("Async put manager created (pool_size=%d)", pool_size) if self.config.enable_gpu_direct and torch.cuda.is_available(): @@ -223,14 +229,9 @@ def warmup_rdma(self) -> None: buf = self._host_buffer_pool.get_buffer() size = 4096 self._store.batch_put_from([key], [buf.ptr], [size]) - self._store.remove(key) + 
self._store.batch_remove([key], force=True) logger.info("RDMA warmup complete") - def remove(self, key: str) -> None: - """Remove data from Mooncake Store.""" - self._store.remove(key) - logger.debug("Removed data with key: %s", key) - def exists(self, key: str) -> bool: """Check if a key exists in the store (metadata-only, no data download).""" try: @@ -239,6 +240,57 @@ def exists(self, key: str) -> bool: except Exception: return False + def _verify_force_delete(self) -> None: + """Fail-fast if Mooncake doesn't support batch_remove(force=True). + + Requires mooncake-transfer-engine >= 0.3.10.post1. + Primary check uses package version metadata; falls back to docstring + heuristic for non-pip installs. + """ + batch_remove = getattr(self._store, "batch_remove", None) + if batch_remove is None: + raise RuntimeError( + "Mooncake version too old: batch_remove() not found. " + "Requires mooncake-transfer-engine >= 0.3.10.post1." + ) + try: + from importlib.metadata import version + + from packaging.version import Version + + installed = Version(version("mooncake-transfer-engine")) + if installed >= Version("0.3.10.post1"): + return + except Exception: + pass + doc = getattr(batch_remove, "__doc__", "") or "" + if "force" not in doc: + raise RuntimeError( + "Mooncake version too old: batch_remove(force=True) not supported. " + "Requires mooncake-transfer-engine >= 0.3.10.post1." 
+ ) + + def _build_replicate_config(self) -> None: + """Build ReplicateConfig for batch_put_from if hard_pin is enabled and supported.""" + self._replicate_config = None + if not self.config.enable_hard_pin: + return + try: + from mooncake.store import ReplicateConfig + + cfg = ReplicateConfig() + if hasattr(cfg, "with_hard_pin"): + cfg.with_hard_pin = True + self._replicate_config = cfg + logger.info("Hard pin enabled for batch_put_from") + else: + logger.warning( + "enable_hard_pin=True but ReplicateConfig lacks with_hard_pin attr " + "(needs unreleased Mooncake)" + ) + except ImportError: + logger.warning("enable_hard_pin=True but ReplicateConfig not importable") + def close(self) -> None: """Close the Mooncake Store client.""" if self._async_put_manager is not None: