Merged
70 changes: 70 additions & 0 deletions configs/dflash_qwen3_8b_repro.yaml
@@ -0,0 +1,70 @@
# DFlash reproduction config for Qwen3-8B on 4x B200
#
# GPU allocation (4x B200):
# - 1 GPU for inference (SGLang tp=1)
# - 3 GPUs for training (REPLICATE)
# - global_batch = micro_batch_size × dp_size × accum = 1 × 3 × 2 = 6

model:
target_model_path: Qwen/Qwen3-8B
trust_remote_code: true
draft_model_config: torchspec/config/dflash_draft_config.json

dataset:
train_data_path: data/perfectblend_800k.jsonl
eval_data_path: null
eval_interval: 500
chat_template: qwen
prompt_key: conversations
min_loss_tokens: 32

training:
attention_backend: flex_attention
micro_batch_size: 1
draft_accumulation_steps: 2
learning_rate: 6e-4
min_lr: 6e-5
weight_decay: 0.01
max_concurrent_batches: 1
max_grad_norm: 1.0
max_seq_length: 2048
num_epochs: 3
seed: 42
training_num_gpus_per_node: 3
training_num_nodes: 1
Comment on lines +33 to +34

P2: Fix GPU allocation values in repro config

This repro config claims an 8x B200 layout with 6 training GPUs and 2 inference GPUs (global_batch = 12), but the actual values configure only 3 training GPUs (`training_num_gpus_per_node: 3`, `training_num_nodes: 1`) and 1 inference GPU (`inference_num_gpus: 1`, `tp_size: 1`), so the run uses 4 GPUs total and a global batch of 6. That mismatch underutilizes hardware and makes the published reproduction numbers non-reproducible with this file as written.

ttt_length: 7
fsdp_strategy: REPLICATE
fsdp_reduce_dtype: bfloat16
prefetch_depth: 8
save_interval: 2000
save_per_epoch: true
max_checkpoints: 3
warmup_ratio: 0.04

dflash_block_size: 16
dflash_num_anchors: 512
dflash_loss_decay_gamma: 7.0
dflash_num_target_layers: 5

inference:
inference_engine_type: sgl
store_last_hidden_states: false
inference_num_gpus: 1
inference_num_gpus_per_engine: 1
inference_num_gpus_per_node: 4
max_sample_pool_size: 128
inference_buffer_threshold: 64
inference_batch_size: 8
sglang:
tp_size: 1
mem_fraction_static: 0.7

mooncake:
master_server_address: null
metadata_server: null
protocol: tcp
global_segment_size: 16GB
local_buffer_size: 4GB

output_dir: ./outputs/dflash_repro
cache_dir: ./cache/dflash_repro
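The GPU-allocation arithmetic in the config header (global_batch = micro_batch_size × dp_size × accum) is easy to get wrong, as the review comment above notes. A minimal sanity-check sketch of that arithmetic follows; the helper name and the idea of validating before launch are illustrative assumptions, not part of torchspec, and the values mirror this yaml file.

```python
# Sanity-check sketch for the repro config's GPU/batch arithmetic.
# check_gpu_allocation is a hypothetical helper, not a torchspec API.

def check_gpu_allocation(total_gpus: int,
                         training_gpus: int,
                         inference_gpus: int,
                         micro_batch_size: int,
                         accum_steps: int) -> int:
    """Return the global batch size, or raise if the GPU split is inconsistent."""
    if training_gpus + inference_gpus != total_gpus:
        raise ValueError(
            f"GPU split {training_gpus}+{inference_gpus} != {total_gpus}"
        )
    # With fsdp_strategy: REPLICATE (no sharding), every training GPU
    # is a full data-parallel rank.
    dp_size = training_gpus
    return micro_batch_size * dp_size * accum_steps

# Values from the config above: 4x B200, 3 training + 1 inference.
print(check_gpu_allocation(4, 3, 1, micro_batch_size=1, accum_steps=2))  # -> 6
```

A check like this catches exactly the header/value drift the P2 comment describes before a multi-GPU run is launched.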
9 changes: 5 additions & 4 deletions configs/sglang_qwen3_8b_dflash.yaml
@@ -5,10 +5,10 @@
# - weight_decay=0.01: AdamW regularization for generalization
# - min_lr=6e-5: prevent LR death in later epochs (10% of peak)
#
# GPU allocation (4x H100):
# - 1 GPU for inference (SGLang engine, tp_size=1)
# - 3 GPUs for training (FSDP FULL_SHARD)
# - global_batch = 1 × 4 × 3 = 12
# GPU allocation (8x GPU):
# - 4 GPUs for inference (SGLang engine, tp_size=1, duplicate mode)
Collaborator Author

leave original config, but use correct comments
# - 4 GPUs for training (FSDP FULL_SHARD)
# - global_batch = micro_batch_size × dp_size × accum = 1 × 4 × 2 = 8
#
# Usage:
# python -m torchspec.train_entry --config configs/sglang_qwen3_8b_dflash.yaml \
@@ -58,6 +58,7 @@ training:

inference:
inference_engine_type: sgl
store_last_hidden_states: false
inference_num_gpus: 4
inference_num_gpus_per_engine: 1
inference_num_gpus_per_node: 4
30 changes: 16 additions & 14 deletions patches/sglang/v0.5.10.post1/sglang.patch
@@ -4,19 +4,19 @@ torchspec sglang patch (base: 94f03a39db)
python/sglang/srt/entrypoints/http_server.py | 17 +++
python/sglang/srt/layers/logits_processor.py | 54 ++++++++
python/sglang/srt/managers/detokenizer_manager.py | 3 +
python/sglang/srt/managers/io_struct.py | 48 ++++++++
python/sglang/srt/managers/io_struct.py | 48 +++++++
python/sglang/srt/managers/schedule_batch.py | 54 +++++++-
python/sglang/srt/managers/scheduler.py | 57 ++++++++-
.../managers/scheduler_output_processor_mixin.py | 137 ++++++++++++++++++---
.../managers/scheduler_output_processor_mixin.py | 138 ++++++++++++++++++---
python/sglang/srt/managers/tokenizer_manager.py | 16 +++
python/sglang/srt/managers/utils.py | 5 +-
.../srt/model_executor/forward_batch_info.py | 4 +
python/sglang/srt/model_executor/model_runner.py | 16 +++
python/sglang/srt/models/qwen3_next.py | 8 ++
python/sglang/srt/models/qwen3_next_mtp.py | 3 +
python/sglang/srt/server_args.py | 23 ++++
python/sglang/srt/server_args.py | 24 ++++
.../sglang/srt/speculative/spec_training_info.py | 50 ++++++++
16 files changed, 481 insertions(+), 22 deletions(-)
16 files changed, 483 insertions(+), 22 deletions(-)

diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index d864e4aba..b74514924 100644
@@ -544,7 +544,7 @@ index 67af2d0de..36d6e539c 100644
# Release the closure and large GPU tensors that are no longer needed.
# The delay_sample_func closure captures forward_batch (which holds
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 496cd9665..25ccb527b 100644
index 496cd9665..45ceb0e20 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -218,21 +218,22 @@ class SchedulerOutputProcessorMixin:
@@ -584,7 +584,7 @@ index 496cd9665..25ccb527b 100644

if req.grammar is not None:
# FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -893,6 +894,74 @@ class SchedulerOutputProcessorMixin:
@@ -893,6 +894,75 @@ class SchedulerOutputProcessorMixin:
if req.input_token_ids_logprobs_idx is None:
req.input_token_ids_logprobs_idx = []

@@ -638,7 +638,8 @@ index 496cd9665..25ccb527b 100644
+ )
+
+ last_hidden_states = None
+ if logits_output.last_hidden_states is not None:
+ store_lhs = getattr(self.server_args, "spec_training_store_last_hidden_states", True)
+ if store_lhs and logits_output.last_hidden_states is not None:
+ last_hidden_states = logits_output.last_hidden_states[
+ hidden_state_offset : hidden_state_offset + seq_len
+ ]
@@ -659,7 +660,7 @@ index 496cd9665..25ccb527b 100644
def stream_output(
self: Scheduler,
reqs: List[Req],
@@ -954,6 +1023,18 @@ class SchedulerOutputProcessorMixin:
@@ -954,6 +1024,18 @@ class SchedulerOutputProcessorMixin:
routed_experts = None
customized_info = {}

@@ -678,7 +679,7 @@ index 496cd9665..25ccb527b 100644
time_stats = []

if return_logprob:
@@ -1058,6 +1139,13 @@ class SchedulerOutputProcessorMixin:
@@ -1058,6 +1140,13 @@ class SchedulerOutputProcessorMixin:
spec_accepted_tokens.append(req.spec_accepted_tokens)
spec_acceptance_histogram.append(req.spec_acceptance_histogram)

@@ -692,7 +693,7 @@ index 496cd9665..25ccb527b 100644
if return_logprob:
if (
req.return_logprob
@@ -1128,9 +1216,15 @@ class SchedulerOutputProcessorMixin:
@@ -1128,9 +1217,15 @@ class SchedulerOutputProcessorMixin:
output_token_ids_logprobs_idx.append([])

if req.return_hidden_states:
@@ -711,7 +712,7 @@ index 496cd9665..25ccb527b 100644
if req.return_routed_experts:
if routed_experts is None:
routed_experts = []
@@ -1197,6 +1291,15 @@ class SchedulerOutputProcessorMixin:
@@ -1197,6 +1292,15 @@ class SchedulerOutputProcessorMixin:
retraction_counts=retraction_counts,
load=load,
dp_ranks=dp_ranks,
@@ -895,22 +896,23 @@ index b2bdbbbe8..833c055ab 100644
def forward(
self,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index d91ced805..b9bb901f1 100644
index d91ced805..2708630cc 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -514,6 +514,11 @@ class ServerArgs:
@@ -514,6 +514,12 @@ class ServerArgs:
speculative_ngram_capacity: int = 10 * 1000 * 1000
enable_multi_layer_eagle: bool = False

+ # Spec training (for speculative decoding model training)
+ enable_spec_training_mooncake: bool = False
+ enable_aux_hidden_states: bool = False
+ aux_hidden_state_layer_ids: Optional[List[int]] = None
+ spec_training_store_last_hidden_states: bool = True
+
# Expert parallelism
ep_size: int = 1
moe_a2a_backend: Literal[
@@ -4896,6 +4901,24 @@ class ServerArgs:
@@ -4896,6 +4902,24 @@ class ServerArgs:
help="Enable multi-layer Eagle speculative decoding.",
)

56 changes: 53 additions & 3 deletions tests/test_dflash.py
@@ -575,7 +575,7 @@ def test_loss_decreases_over_steps(self):

# Loss should generally decrease (allow some noise, check last < first)
self.assertLess(losses[-1], losses[0], "Loss did not decrease over 10 training steps")
self.assertTrue(all(math.isfinite(loss) for loss in losses), "Non-finite loss encountered")
self.assertTrue(all(math.isfinite(v) for v in losses), "Non-finite loss encountered")

def test_gradient_accumulation(self):
"""Two half-LR steps with accumulated gradients should produce finite grads."""
@@ -914,7 +914,7 @@ def test_longer_sequence(self):
model, *data = self._make_model_and_data(seq_len=128, num_anchors=8)
losses, _ = self._train_steps(model, *data, steps=20)
self.assertLess(losses[-1], losses[0])
self.assertTrue(all(math.isfinite(loss) for loss in losses))
self.assertTrue(all(math.isfinite(v) for v in losses))

def test_large_block_size(self):
"""Block size 8 should still converge."""
@@ -986,7 +986,7 @@ def test_loss_mask_with_padding(self):
lm_w,
steps=15,
)
self.assertTrue(all(math.isfinite(loss) for loss in losses))
self.assertTrue(all(math.isfinite(v) for v in losses))
self.assertLess(losses[-1], losses[0])


@@ -1133,5 +1133,55 @@ def test_dflash_draft_config_json_loads(self):
self.assertEqual(config.target_num_hidden_layers, 36)


class TestDFlashHotfixes(unittest.TestCase):
"""Tests for incremental correctness fixes (tensor schema, collator, validation)."""

def test_collator_accepts_dflash_batch(self):
"""DataCollator should accept hidden_states without target/last_hidden_states."""
from torchspec.data.utils import DataCollatorWithPadding

collator = DataCollatorWithPadding()
samples = []
for seq_len in [32, 48]:
samples.append(
{
"input_ids": torch.randint(0, 1000, (1, seq_len)),
"hidden_states": torch.randn(1, seq_len, 384),
"loss_mask": torch.ones(1, seq_len),
}
)

# Should not raise (DFlash does not provide target/last_hidden_states)
batch = collator(samples)
self.assertIsNotNone(batch["hidden_states"])
self.assertIsNone(batch["target"])
self.assertIsNone(batch["last_hidden_states"])

def test_min_loss_tokens_validation(self):
"""min_loss_tokens < 2 * block_size should be caught."""
from argparse import Namespace

# Simulate the validation logic from train_entry
args = Namespace(
dflash_block_size=16,
min_loss_tokens=10, # < 2 * 16 = 32
)
block_size = getattr(args, "dflash_block_size", 16)
min_loss = getattr(args, "min_loss_tokens", 0)

with self.assertRaises(ValueError):
if min_loss < 2 * block_size:
raise ValueError(
f"DFlash requires dataset.min_loss_tokens >= 2 * training.dflash_block_size "
f"({min_loss} < {2 * block_size})."
)

def test_min_loss_tokens_validation_passes(self):
"""min_loss_tokens >= 2 * block_size should pass."""
block_size = 16
min_loss = 32 # == 2 * 16
self.assertGreaterEqual(min_loss, 2 * block_size)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions torchspec/config/inference_config.py
@@ -117,6 +117,7 @@ class InferenceConfig:
inference_num_gpus_per_node: int = 8
last_hidden_states_prenorm: Optional[bool] = None
max_sample_pool_size: int = 0
store_last_hidden_states: bool = True
sglang: SGLangConfig = field(default_factory=SGLangConfig)
vllm: VllmConfig = field(default_factory=VllmConfig)

23 changes: 13 additions & 10 deletions torchspec/controller/loop.py
@@ -103,7 +103,10 @@ def _cleanup_old_checkpoints(checkpoint_dir: str | None, max_checkpoints: int) -
to_delete = iter_dirs[: len(iter_dirs) - max_checkpoints]
for old_dir in to_delete:
logger.info(f"Removing old checkpoint: {old_dir}")
shutil.rmtree(old_dir, ignore_errors=True)
try:
shutil.rmtree(old_dir)
except OSError as e:
logger.warning(f"Failed to remove old checkpoint {old_dir}: {e}")


def _safe_training_cleanup(
@@ -373,16 +376,16 @@ def training_loop(
progress.update(1)

if _is_save_interval_step(completed_steps, args.save_interval):
max_ckpts = getattr(args, "max_checkpoints", 0)
if max_ckpts > 0:
_cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts)
eval_metrics = run_eval(completed_steps, train_group, eval_enabled)
logger.info(f"Saving checkpoint at step {completed_steps}...")
train_group.save_model(completed_steps)
last_saved_step = completed_steps
best_eval_score = update_checkpoint_eval_meta(
args.checkpoint_dir, completed_steps, eval_metrics, best_eval_score
)
max_ckpts = getattr(args, "max_checkpoints", 0)
if max_ckpts > 0:
_cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts)

_maybe_sync_draft_weights(args, completed_steps, train_group, inference_engines)

@@ -398,9 +401,6 @@
and args.checkpoint_dir
and last_saved_step != completed_steps
):
max_ckpts = getattr(args, "max_checkpoints", 0)
if max_ckpts > 0:
_cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts)
eval_metrics = run_eval(completed_steps, train_group, eval_enabled)
logger.info(
f"Saving checkpoint at end of epoch {current_epoch} "
@@ -411,6 +411,9 @@
best_eval_score = update_checkpoint_eval_meta(
args.checkpoint_dir, completed_steps, eval_metrics, best_eval_score
)
max_ckpts = getattr(args, "max_checkpoints", 0)
if max_ckpts > 0:
_cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts)

if completed_steps < num_steps:
current_epoch += 1
@@ -429,15 +432,15 @@

# Always save a final checkpoint unless saved.
if args.checkpoint_dir and last_saved_step != completed_steps:
max_ckpts = getattr(args, "max_checkpoints", 0)
if max_ckpts > 0:
_cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts)
eval_metrics = run_eval(completed_steps, train_group, eval_enabled)
logger.info(f"Saving final checkpoint at step {completed_steps}...")
train_group.save_model(completed_steps, force_sync=True)
best_eval_score = update_checkpoint_eval_meta(
args.checkpoint_dir, completed_steps, eval_metrics, best_eval_score
)
max_ckpts = getattr(args, "max_checkpoints", 0)
if max_ckpts > 0:
_cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts)

final_status = ray.get(controller.get_full_status.remote())
logger.info(
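The loop.py changes above consistently move `_cleanup_old_checkpoints` from before `save_model` to after it, so a failed save can never shrink the set of retained checkpoints. A minimal sketch of that save-then-prune ordering, under assumed `step_<N>` directory naming (torchspec's actual checkpoint layout and helper names may differ):

```python
# Sketch of the save-then-prune ordering this diff adopts: prune oldest
# checkpoints only after the new one exists on disk.
import os
import shutil
import tempfile

def save_and_prune(checkpoint_dir: str, step: int, max_checkpoints: int) -> list[str]:
    # "Save" a checkpoint (stand-in for train_group.save_model).
    os.makedirs(os.path.join(checkpoint_dir, f"step_{step}"), exist_ok=True)
    # Prune only after the save above succeeded.
    dirs = sorted(
        (d for d in os.listdir(checkpoint_dir) if d.startswith("step_")),
        key=lambda d: int(d.split("_")[1]),
    )
    for old in dirs[: max(0, len(dirs) - max_checkpoints)]:
        try:
            shutil.rmtree(os.path.join(checkpoint_dir, old))
        except OSError:
            pass  # log-and-continue, mirroring the diff's warning path
    return sorted(os.listdir(checkpoint_dir))

with tempfile.TemporaryDirectory() as tmp:
    for s in (100, 200, 300, 400):
        kept = save_and_prune(tmp, s, max_checkpoints=3)
print(kept)  # the three newest checkpoints remain
```

With `max_checkpoints: 3` (as in the repro config), only the three newest step directories survive each save.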
13 changes: 9 additions & 4 deletions torchspec/data/utils.py
@@ -120,9 +120,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
has_target = all(item.get("target") is not None for item in features)
has_last_hs = all(item.get("last_hidden_states") is not None for item in features)
if not has_target and not has_last_hs:
raise ValueError(
"Either 'target' or 'last_hidden_states' is required when 'hidden_states' is provided"
)
pass
P2: Keep validation for hidden-states batches in collator

Replacing the guard with `pass` allows batches that have `hidden_states` but neither `target` nor `last_hidden_states` to flow through all modes, and Eagle3 then fails later in `_forward` when it pads `batch["last_hidden_states"]` (which is `None`). This turns an immediate, actionable data-contract error into a delayed runtime crash, making malformed inference outputs harder to detect and debug during training.

if has_target:
batch["target"] = torch.cat(
[self.paddingtensor(item["target"], max_length) for item in features]
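One hedged way to address the P2 comment above without blocking DFlash batches is to gate the guard on a mode flag instead of deleting it. The sketch below uses a hypothetical `require_target_tensors` flag; `DataCollatorWithPadding` has no such parameter today, and only the guard logic (not padding) is reproduced.

```python
# Sketch: keep the collator's data-contract check, gated per mode.
# `require_target_tensors` is a hypothetical flag, not an existing
# DataCollatorWithPadding parameter.
from typing import Any, Dict, List

class CollatorGuardSketch:
    def __init__(self, require_target_tensors: bool = True):
        self.require_target_tensors = require_target_tensors

    def validate(self, features: List[Dict[str, Any]]) -> None:
        has_target = all(f.get("target") is not None for f in features)
        has_last_hs = all(f.get("last_hidden_states") is not None for f in features)
        if self.require_target_tensors and not has_target and not has_last_hs:
            raise ValueError(
                "Either 'target' or 'last_hidden_states' is required "
                "when 'hidden_states' is provided"
            )

eagle3 = CollatorGuardSketch(require_target_tensors=True)
dflash = CollatorGuardSketch(require_target_tensors=False)
batch = [{"hidden_states": object(), "target": None, "last_hidden_states": None}]
dflash.validate(batch)  # accepted: DFlash provides neither tensor
try:
    eagle3.validate(batch)
except ValueError:
    print("eagle3 rejects the malformed batch early")
```

This preserves the early, actionable error for Eagle3 while matching the behavior `test_collator_accepts_dflash_batch` expects for DFlash.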
@@ -459,7 +457,14 @@ def load_hf_dataset(data_path: str):
if drop_cols:
ds = ds.remove_columns(drop_cols)
return ds
except Exception:
except (ValueError, TypeError, ArithmeticError, KeyError) as e:
P1: Catch Arrow/datasets errors when falling back to JSON loader

The new exception filter only catches built-in types, so HF schema-inference failures from `load_dataset` (commonly pyarrow/datasets exceptions for mixed-type columns) now bypass the fallback and abort loading instead of using `_load_hub_json_files`. This regresses the mixed-schema hub-dataset path this function is explicitly trying to support, and can break training runs that previously recovered by switching to raw JSON streaming.

# Schema inference failures (e.g., mixed-type columns in Arrow/Parquet).
# Fall back to manual JSON download.
import logging

logging.getLogger(__name__).info(
f"load_dataset failed for '{data_path}' ({e}), falling back to JSON download"
)
return IterableDataset.from_generator(
_load_hub_json_files, gen_kwargs={"data_path": data_path}
)
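One way to act on the P1 comment above without reverting to a bare `except Exception` is to build the exception tuple dynamically, adding the Arrow/datasets classes only when those packages are importable. This is a sketch, and the exact set of exception classes worth including is an assumption; the optional imports keep it runnable even where pyarrow or datasets are absent.

```python
# Sketch: assemble the fallback exception tuple for load_hf_dataset's
# try/except, including optional Arrow/datasets schema-inference errors.
def fallback_exceptions() -> tuple:
    excs = [ValueError, TypeError, KeyError, ArithmeticError]
    try:
        import pyarrow.lib
        excs.append(pyarrow.lib.ArrowInvalid)  # e.g. mixed-type column inference
    except ImportError:
        pass
    try:
        from datasets.exceptions import DatasetGenerationError
        excs.append(DatasetGenerationError)
    except ImportError:
        pass
    return tuple(excs)

# Usage inside load_hf_dataset (an `except` clause accepts any tuple expression):
#     except fallback_exceptions() as e:
#         ...  # fall back to _load_hub_json_files
```

This keeps the JSON fallback working for the mixed-schema hub-dataset path while still letting genuinely unexpected errors propagate.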