From 313e3c0c78dcdb73313a5a8d3f5e4ef1c8a3473c Mon Sep 17 00:00:00 2001
From: Yubo Wang
Date: Mon, 13 Apr 2026 22:26:17 +0000
Subject: [PATCH 1/3] Rewrite sglang integration test with tensor dumps and
 harden Ray worker env

- Rewrite test_sglang_engine_integration.py with structured tests (short
  seqs, long seqs, text prompts) that dump all tensors to .pt files for
  cross-engine comparison against vllm. Auto-resolve aux layer IDs from
  model config.
- Propagate TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC and TORCHINDUCTOR_FX_GRAPH_CACHE
  to Ray training workers to avoid NCCL watchdog kills and enable inductor
  FX graph caching.
- Update torchspec env key allowlist: drop legacy NCCL_IB_* keys, add the
  two new torch/inductor env vars.
---
 tests/test_sglang_engine_integration.py | 404 ++++++++++++++++++++----
 torchspec/ray/train_group.py            |   8 +
 torchspec/utils/env.py                  |   5 +-
 3 files changed, 345 insertions(+), 72 deletions(-)

diff --git a/tests/test_sglang_engine_integration.py b/tests/test_sglang_engine_integration.py
index c626e29a..94a6c7f6 100644
--- a/tests/test_sglang_engine_integration.py
+++ b/tests/test_sglang_engine_integration.py
@@ -1,104 +1,370 @@
-"""Standalone integration script that tests Mooncake hidden states collection behavior."""
+"""Integration test for sglang extract_hidden_states + MooncakeHiddenStatesConnector.
+Uses the same engine setup as SglEngine: enable_aux_hidden_states with
+aux_hidden_state_layer_ids, and enable_spec_training_mooncake for Mooncake
+hidden states transfer.
+
+Dumps all tensors to local .pt files for cross-engine comparison (e.g. vs vllm).
+
+Tests:
+    1. Short sequences via input_ids
+    2. Longer sequences via input_ids
+    3. Text prompts (deferred tokenization path)
+
+Usage:
+    # Start mooncake master first (ports match the MOONCAKE_* defaults below):
+    #   mooncake_master --port 51135 &
+    #   etcd --listen-client-urls http://0.0.0.0:8763 --advertise-client-urls http://localhost:8763 &
+    #
+    python tests/test_sglang_engine_integration.py [--model MODEL] [--tp TP] [--dump-dir DIR]
+"""
+
+import argparse
 import os
+import socket
+import sys
+from pathlib import Path
 
-import sglang as sgl
-import torch
-from transformers import AutoTokenizer
+# Ensure the repo root is on sys.path so the editable install of torchspec
+# isn't shadowed by /root/torchspec (a second repo clone in the home dir).
+_REPO_ROOT = str(Path(__file__).resolve().parent.parent)
+if _REPO_ROOT not in sys.path:
+    sys.path.insert(0, _REPO_ROOT)
 
-from torchspec.transfer.mooncake import EagleMooncakeStore, MooncakeConfig
+import torch  # noqa: E402
+from transformers import AutoConfig  # noqa: E402
 
-os.environ["MOONCAKE_MASTER_HOST"] = "0.0.0.0"
-os.environ["MOONCAKE_MASTER_PORT"] = "50051"
-os.environ["MOONCAKE_METADATA_PORT"] = "8090"
+# ---------------------------------------------------------------------------
+# Mooncake env setup (must happen before any sglang import)
+# ---------------------------------------------------------------------------
+try:
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.connect(("8.8.8.8", 80))
+    LOCAL_IP = s.getsockname()[0]
+    s.close()
+except Exception:
+    LOCAL_IP = "localhost"
 
-if __name__ == "__main__":
-    model_path = "Qwen/Qwen3-8B"
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    eos_token_id = tokenizer.eos_token_id
+os.environ.setdefault("MOONCAKE_MASTER_HOST", LOCAL_IP)
+os.environ.setdefault("MOONCAKE_MASTER_PORT", "51135")
+os.environ.setdefault("MOONCAKE_METADATA_PORT", "8763")
+os.environ.setdefault("MOONCAKE_LOCAL_HOSTNAME", LOCAL_IP)
+os.environ.setdefault("MOONCAKE_MASTER_SERVER", f"{LOCAL_IP}:51135")
 
-    input_ids_list = [
-        [1, 2345, 6789],
-        [100, 200, 300, 400],
-        [500, 600],
-    ]
+
+def get_aux_layer_ids(model_path: str) -> tuple[list[int], int, int]:
+    """Replicate SglEngine's aux layer resolution: default Eagle3 layers (no final layer).
+
+    Unlike vllm, sglang captures last_hidden_states automatically from the
+    model's final layer output, so aux_hidden_state_layer_ids should NOT
+    include the final layer.
+    """
+    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    cfg = getattr(cfg, "text_config", cfg)
+    num_layers = cfg.num_hidden_layers
+    aux_ids = [1, num_layers // 2 - 1, num_layers - 4]
+    return aux_ids, cfg.hidden_size, num_layers
+
+
+def create_engine(model_path: str, tp_size: int, aux_layer_ids: list[int]):
+    import sglang as sgl
 
     engine = sgl.Engine(
         model_path=model_path,
+        tp_size=tp_size,
+        mem_fraction_static=0.7,
+        trust_remote_code=True,
         disable_radix_cache=True,
         disable_cuda_graph=True,
         enable_return_hidden_states=True,
         enable_aux_hidden_states=True,
-        aux_hidden_state_layer_ids=[2, 4, 6],
+        aux_hidden_state_layer_ids=list(aux_layer_ids),
         enable_spec_training_mooncake=True,
-        log_level="info",
-        tp_size=4,
+        chunked_prefill_size=-1,
+        log_level="warning",
+    )
+    return engine
+
+
+def fetch_and_dump(
+    mooncake_store,
+    key: str,
+    seq_len: int,
+    hidden_dim: int,
+    last_hidden_dim: int,
+    dump_dir: Path,
+    label: str,
+) -> dict[str, torch.Tensor]:
+    """Retrieve tensors from mooncake, verify shapes, save to disk."""
+    shapes = {
+        "hidden_states": (seq_len, hidden_dim),
+        "input_ids": (seq_len,),
+        "last_hidden_states": (seq_len, last_hidden_dim),
+    }
+    dtypes = {
+        "hidden_states": torch.bfloat16,
+        "input_ids": torch.long,
+        "last_hidden_states": torch.bfloat16,
+    }
+    data = mooncake_store.get(key, shapes=shapes, dtypes=dtypes, device="cuda")
+
+    tensors = {
+        "hidden_states": data.hidden_states.cpu(),
+        "input_ids": data.input_ids.cpu(),
+        "last_hidden_states": data.last_hidden_states.cpu(),
+    }
+
+    assert tensors["hidden_states"].shape == (seq_len, hidden_dim), (
+        f"hidden_states shape {tensors['hidden_states'].shape} != expected {(seq_len, hidden_dim)}"
+    )
+    assert tensors["input_ids"].shape == (seq_len,)
+    assert tensors["last_hidden_states"].shape == (seq_len, last_hidden_dim)
+
+    dump_path = dump_dir / f"sglang_{label}.pt"
dump_dir / f"sglang_{label}.pt" + torch.save(tensors, dump_path) + print(f" Saved: {dump_path}") + print( + f" hidden_states: {tensors['hidden_states'].shape}, dtype={tensors['hidden_states'].dtype}" + ) + print( + f" input_ids: {tensors['input_ids'].shape}, first_10={tensors['input_ids'][:10].tolist()}" + ) + print( + f" last_hidden_states: {tensors['last_hidden_states'].shape}, dtype={tensors['last_hidden_states'].dtype}" ) + hs = tensors["hidden_states"].float() + lhs = tensors["last_hidden_states"].float() + print(f" hidden_states norm={hs.norm():.4f}, mean={hs.mean():.6f}, std={hs.std():.6f}") + print( + f" last_hidden_states norm={lhs.norm():.4f}, mean={lhs.mean():.6f}, std={lhs.std():.6f}" + ) + + return tensors + + +def run_test_input_ids( + engine, + mooncake_store, + input_ids_list: list[list[int]], + data_ids: list[str], + hidden_dim: int, + last_hidden_dim: int, + dump_dir: Path, + test_name: str, +): + """Run test with pre-tokenized input_ids.""" + print(f"\n{'=' * 60}") + print(f"TEST: {test_name}") + print(f"{'=' * 60}") + results = engine.generate( input_ids=input_ids_list, - spec_training_data_id=["data_id_1", "data_id_2", "data_id_3"], - spec_training_prompt_length=[1, 2, 1], - spec_training_response_length=[5, 10, 8], - sampling_params={"max_new_tokens": 32}, + spec_training_data_id=data_ids, + sampling_params={"max_new_tokens": 1}, + return_hidden_states=True, + ) + + for i, result in enumerate(results): + meta = result["meta_info"] + store_keys = meta.get("spec_training_mooncake_store_keys", []) + seq_len = len(input_ids_list[i]) + + assert meta.get("hidden_states") is None, "hidden_states should be None when using mooncake" + assert len(store_keys) > 0, f"Request {data_ids[i]}: no mooncake store keys returned" + + key = store_keys[0] + print(f"\n Request {data_ids[i]}: seq_len={seq_len}, mooncake_key={key}") + + label = f"{test_name}_{data_ids[i]}" + fetch_and_dump( + mooncake_store, + key, + seq_len, + hidden_dim, + last_hidden_dim, + dump_dir, + label, + ) + + print(f"\n✓ {test_name} passed") + + +def run_test_text_prompts( + engine, + mooncake_store, + text_prompts: list[str], + data_ids: list[str], + hidden_dim: int, + last_hidden_dim: int, + dump_dir: Path, + test_name: str, +): + """Run test with text prompts (defer tokenization).""" + print(f"\n{'=' * 60}") + print(f"TEST: {test_name}") + print(f"{'=' * 60}") + + results = engine.generate( + prompt=text_prompts, + spec_training_data_id=data_ids, + sampling_params={"max_new_tokens": 1}, return_hidden_states=True, ) - print("=== Batch Results ===") - all_keys = [] - seq_lens = [] for i, result in enumerate(results): - output_ids = result["output_ids"] - hidden_states = result["meta_info"].get("hidden_states") - mooncake_keys = result["meta_info"].get("spec_training_mooncake_store_keys") + meta = result["meta_info"] + store_keys = meta.get("spec_training_mooncake_store_keys", []) + seq_len = meta.get("prompt_tokens") + + assert meta.get("hidden_states") is None, "hidden_states should be None when using mooncake" + assert len(store_keys) > 0, f"Request {data_ids[i]}: no mooncake store keys returned" + assert seq_len is not None, f"Request {data_ids[i]}: prompt_tokens missing from meta_info" + + key = store_keys[0] + print(f"\n Request {data_ids[i]}: seq_len={seq_len}, mooncake_key={key}") + + label = f"{test_name}_{data_ids[i]}" + fetch_and_dump( + mooncake_store, + key, + seq_len, + hidden_dim, + last_hidden_dim, + dump_dir, + label, + ) - print(f"\n--- Request {i} ---") - print(f"output_ids: {output_ids}") - 
print(f"num tokens generated: {len(output_ids)}") - print(f"spec_training_data_id: {result['meta_info'].get('spec_training_data_id')}") + print(f"\n✓ {test_name} passed") - print(f"\n Hidden states in meta_info: {hidden_states}") - assert hidden_states is None, "hidden_states should be None when using mooncake" - print(f"\n Mooncake store keys: {mooncake_keys}") - assert mooncake_keys and len(mooncake_keys) > 0, "mooncake_store_keys should not be empty" +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="Qwen/Qwen3-8B") + parser.add_argument("--tp", type=int, default=4) + parser.add_argument("--dump-dir", default="./tensor_dumps") + parser.add_argument( + "--aux-layers", + type=int, + nargs="*", + default=None, + help="Override aux layer IDs (training layers only; final layer is automatic)", + ) + args = parser.parse_args() + + dump_dir = Path(args.dump_dir) + dump_dir.mkdir(parents=True, exist_ok=True) + + auto_aux_ids, hidden_size, num_layers = get_aux_layer_ids(args.model) + if args.aux_layers is not None: + aux_layer_ids = list(args.aux_layers) + else: + aux_layer_ids = auto_aux_ids + + num_training_layers = len(aux_layer_ids) + hidden_dim = num_training_layers * hidden_size + last_hidden_dim = hidden_size + + print(f"Model: {args.model}") + print(f"TP size: {args.tp}") + print(f"Aux layer IDs: {aux_layer_ids} (sglang captures last_hidden_states automatically)") + print(f" training layers: {aux_layer_ids} -> hidden_dim={hidden_dim}") + print(f" last_hidden_states from final model layer -> last_hidden_dim={last_hidden_dim}") + print(f"Hidden size: {hidden_size}") + print(f"Num layers: {num_layers}") + print(f"Dump dir: {dump_dir}") - all_keys.extend(mooncake_keys) - seq_lens.append(len(input_ids_list[i])) + meta = { + "engine": "sglang", + "model": args.model, + "aux_layer_ids": aux_layer_ids, + "num_training_layers": num_training_layers, + "hidden_size": hidden_size, + "hidden_dim": hidden_dim, + "last_hidden_dim": last_hidden_dim, + } + torch.save(meta, dump_dir / "sglang_meta.pt") - print(f"\n All meta_info keys: {list(result['meta_info'].keys())}") + # Import mooncake before creating sglang engine — sglang's subprocess + # forking can interfere with the import chain through torchspec.config.__init__ + from torchspec.config.mooncake_config import MooncakeConfig + from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore + + engine = create_engine(args.model, args.tp, aux_layer_ids) - print("\n=== Fetching data from Mooncake Store ===") mooncake_config = MooncakeConfig.from_env() mooncake_store = EagleMooncakeStore(mooncake_config) mooncake_store.setup(device="cuda") - hidden_dim = 12288 - last_hidden_dim = 4096 - - for i, key in enumerate(all_keys): - seq_len = seq_lens[i] - shapes = { - "hidden_states": (seq_len, hidden_dim), - "loss_mask": (seq_len,), - "input_ids": (seq_len,), - "last_hidden_states": (seq_len, last_hidden_dim), - } - dtypes = { - "hidden_states": torch.bfloat16, - "loss_mask": torch.long, - "input_ids": torch.long, - "last_hidden_states": torch.bfloat16, - } - - data = mooncake_store.get(key, shapes=shapes, dtypes=dtypes, device="cuda") - print(f"\n Key: {key}") - print( - f" hidden_states: shape={data.hidden_states.shape}, dtype={data.hidden_states.dtype}" - ) - print(f" loss_mask: {data.loss_mask.tolist()}") - print(f" input_ids: {data.input_ids.tolist()}") - print(f" last_hidden_states: shape={data.last_hidden_states.shape}") + # ── Test 1: Short sequences (raw token IDs) ────────────────────────── + 
+        [1, 2345, 6789],
+        [100, 200, 300, 400],
+        [500, 600],
+    ]
+    data_ids = ["short_0", "short_1", "short_2"]
+
+    run_test_input_ids(
+        engine,
+        mooncake_store,
+        input_ids_list,
+        data_ids,
+        hidden_dim,
+        last_hidden_dim,
+        dump_dir,
+        "short_seqs",
+    )
+
+    # ── Test 2: Longer sequences (raw token IDs) ─────────────────────────
+    long_input_ids = [
+        list(range(1, 101)),
+        list(range(200, 351)),
+        list(range(400, 465)),
+    ]
+    long_data_ids = ["long_0", "long_1", "long_2"]
+
+    run_test_input_ids(
+        engine,
+        mooncake_store,
+        long_input_ids,
+        long_data_ids,
+        hidden_dim,
+        last_hidden_dim,
+        dump_dir,
+        "long_seqs",
+    )
+
+    # ── Test 3: Text prompts (deferred tokenization) ─────────────────────
+    text_prompts = [
+        "Hello, world!",
+        "The quick brown fox jumps over the lazy dog.",
+        "Once upon a time in a land far away, there lived a brave knight.",
+    ]
+    prompt_data_ids = ["prompt_0", "prompt_1", "prompt_2"]
+
+    run_test_text_prompts(
+        engine,
+        mooncake_store,
+        text_prompts,
+        prompt_data_ids,
+        hidden_dim,
+        last_hidden_dim,
+        dump_dir,
+        "text_prompts",
+    )
+
+    # ── Summary ──────────────────────────────────────────────────────────
+    print(f"\n{'=' * 60}")
+    print("All tests passed!")
+    print(f"Tensor dumps saved to: {dump_dir}/")
+    print(f"{'=' * 60}")
+
+    pt_files = sorted(dump_dir.glob("sglang_*.pt"))
+    for f in pt_files:
+        print(f"  {f.name}")
 
-    print("\n✓ Test completed - hidden states sent to mooncake and retrieved successfully")
     engine.shutdown()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torchspec/ray/train_group.py b/torchspec/ray/train_group.py
index 09474d6f..86379eb0 100644
--- a/torchspec/ray/train_group.py
+++ b/torchspec/ray/train_group.py
@@ -89,6 +89,14 @@ def _allocate_gpus_for_training(self, pg, num_gpus_per_actor):
         }
         if "TORCHINDUCTOR_CACHE_DIR" in os.environ:
             env_vars["TORCHINDUCTOR_CACHE_DIR"] = os.environ["TORCHINDUCTOR_CACHE_DIR"]
+        env_vars.setdefault(
+            "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC",
+            os.environ.get("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "1800"),
+        )
+        env_vars.setdefault(
+            "TORCHINDUCTOR_FX_GRAPH_CACHE",
+            os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE", "1"),
+        )
 
         TrainRayActor = ray.remote(num_gpus=1, runtime_env={"env_vars": env_vars})(
             self._training_class
diff --git a/torchspec/utils/env.py b/torchspec/utils/env.py
index 074da6fb..b3a84191 100644
--- a/torchspec/utils/env.py
+++ b/torchspec/utils/env.py
@@ -12,13 +12,12 @@
     "MC_LOG_LEVEL",
     "MODELOPT_MAX_TOKENS_PER_EXPERT",
     "NCCL_DEBUG",
-    "NCCL_IB_DISABLE",
-    "NCCL_IB_HCA",
-    "NCCL_NET_GDR_LEVEL",
     "NCCL_SOCKET_IFNAME",
     "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN",
     "SGLANG_DISABLE_CUDNN_CHECK",
     "SGLANG_VLM_CACHE_SIZE_MB",
+    "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC",
+    "TORCHINDUCTOR_FX_GRAPH_CACHE",
     "TORCHSPEC_LOG_DIR",
     "TORCHSPEC_LOG_LEVEL",
     "TP_SOCKET_IFNAME",

From 0d00601fe7bb6137991139e51c855e824f37c9f0 Mon Sep 17 00:00:00 2001
From: Yubo Wang
Date: Mon, 13 Apr 2026 22:34:15 +0000
Subject: [PATCH 2/3] Fix MOONCAKE_MASTER_SERVER built from hardcoded port
 instead of configured value

MOONCAKE_MASTER_SERVER was always set to {LOCAL_IP}:51135 regardless of the
resolved MOONCAKE_MASTER_PORT value. Build it from the actual port so
overriding MOONCAKE_MASTER_PORT (e.g. 50051) propagates correctly.
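For illustration, with the master actually listening on 50051 and
MOONCAKE_MASTER_PORT=50051 exported before the test starts, the patched
env block now resolves like this (a walk-through with example values,
not output captured from a real run):

    os.environ.setdefault("MOONCAKE_MASTER_PORT", "51135")  # keeps the exported "50051"
    _MC_PORT = os.environ["MOONCAKE_MASTER_PORT"]           # -> "50051"
    os.environ.setdefault("MOONCAKE_MASTER_SERVER", f"{LOCAL_IP}:{_MC_PORT}")
    # MOONCAKE_MASTER_SERVER now ends in :50051; before this patch it was
    # always built as f"{LOCAL_IP}:51135".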
---
 tests/test_sglang_engine_integration.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_sglang_engine_integration.py b/tests/test_sglang_engine_integration.py
index 94a6c7f6..cad67203 100644
--- a/tests/test_sglang_engine_integration.py
+++ b/tests/test_sglang_engine_integration.py
@@ -49,7 +49,8 @@
 os.environ.setdefault("MOONCAKE_MASTER_PORT", "51135")
 os.environ.setdefault("MOONCAKE_METADATA_PORT", "8763")
 os.environ.setdefault("MOONCAKE_LOCAL_HOSTNAME", LOCAL_IP)
-os.environ.setdefault("MOONCAKE_MASTER_SERVER", f"{LOCAL_IP}:51135")
+_MC_PORT = os.environ["MOONCAKE_MASTER_PORT"]
+os.environ.setdefault("MOONCAKE_MASTER_SERVER", f"{LOCAL_IP}:{_MC_PORT}")
 
 
 def get_aux_layer_ids(model_path: str) -> tuple[list[int], int, int]:

From 0a6f579c4132c0eb224d265f7d12c35d80e4bcb2 Mon Sep 17 00:00:00 2001
From: Yubo Wang
Date: Mon, 13 Apr 2026 22:41:56 +0000
Subject: [PATCH 3/3] Use mmap=True when loading eval cache to reduce peak
 memory

Avoids doubling memory usage when the eval cache checkpoint is large,
since mmap lazily pages in tensors instead of reading the full file.
---
 torchspec/training/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchspec/training/trainer.py b/torchspec/training/trainer.py
index f1208d9b..ec944c40 100644
--- a/torchspec/training/trainer.py
+++ b/torchspec/training/trainer.py
@@ -273,7 +273,7 @@ def load_eval_cache(self, cache_dir: str) -> int:
         if not os.path.exists(path):
             return 0
         try:
-            self._eval_cache = torch.load(path, weights_only=False)
+            self._eval_cache = torch.load(path, weights_only=False, mmap=True)
         except Exception as e:
             logger.warning(f"[Rank {self.dp_rank}] Corrupt eval cache at {path}, ignoring: {e}")
             return 0