diff --git a/configs/dflash_qwen3_8b_repro.yaml b/configs/dflash_qwen3_8b_repro.yaml new file mode 100644 index 0000000..5a06c93 --- /dev/null +++ b/configs/dflash_qwen3_8b_repro.yaml @@ -0,0 +1,70 @@ +# DFlash reproduction config for Qwen3-8B on 4x B200 +# +# GPU allocation (4x B200): +# - 1 GPU for inference (SGLang tp=1) +# - 3 GPUs for training (REPLICATE) +# - global_batch = micro_batch_size × dp_size × accum = 1 × 3 × 2 = 6 + +model: + target_model_path: Qwen/Qwen3-8B + trust_remote_code: true + draft_model_config: torchspec/config/dflash_draft_config.json + +dataset: + train_data_path: data/perfectblend_800k.jsonl + eval_data_path: null + eval_interval: 500 + chat_template: qwen + prompt_key: conversations + min_loss_tokens: 32 + +training: + attention_backend: flex_attention + micro_batch_size: 1 + draft_accumulation_steps: 2 + learning_rate: 6e-4 + min_lr: 6e-5 + weight_decay: 0.01 + max_concurrent_batches: 1 + max_grad_norm: 1.0 + max_seq_length: 2048 + num_epochs: 3 + seed: 42 + training_num_gpus_per_node: 3 + training_num_nodes: 1 + ttt_length: 7 + fsdp_strategy: REPLICATE + fsdp_reduce_dtype: bfloat16 + prefetch_depth: 8 + save_interval: 2000 + save_per_epoch: true + max_checkpoints: 3 + warmup_ratio: 0.04 + + dflash_block_size: 16 + dflash_num_anchors: 512 + dflash_loss_decay_gamma: 7.0 + dflash_num_target_layers: 5 + +inference: + inference_engine_type: sgl + store_last_hidden_states: false + inference_num_gpus: 1 + inference_num_gpus_per_engine: 1 + inference_num_gpus_per_node: 4 + max_sample_pool_size: 128 + inference_buffer_threshold: 64 + inference_batch_size: 8 + sglang: + tp_size: 1 + mem_fraction_static: 0.7 + +mooncake: + master_server_address: null + metadata_server: null + protocol: tcp + global_segment_size: 16GB + local_buffer_size: 4GB + +output_dir: ./outputs/dflash_repro +cache_dir: ./cache/dflash_repro diff --git a/configs/sglang_qwen3_8b_dflash.yaml b/configs/sglang_qwen3_8b_dflash.yaml index bc97676..fe48ec6 100644 --- a/configs/sglang_qwen3_8b_dflash.yaml +++ b/configs/sglang_qwen3_8b_dflash.yaml @@ -5,10 +5,10 @@ # - weight_decay=0.01: AdamW regularization for generalization # - min_lr=6e-5: prevent LR death in later epochs (10% of peak) # -# GPU allocation (4x H100): -# - 1 GPU for inference (SGLang engine, tp_size=1) -# - 3 GPUs for training (FSDP FULL_SHARD) -# - global_batch = 1 × 4 × 3 = 12 +# GPU allocation (8x GPU): +# - 4 GPUs for inference (SGLang engine, tp_size=1, duplicate mode) +# - 4 GPUs for training (FSDP FULL_SHARD) +# - global_batch = micro_batch_size × dp_size × accum = 1 × 4 × 2 = 8 # # Usage: # python -m torchspec.train_entry --config configs/sglang_qwen3_8b_dflash.yaml \ @@ -58,6 +58,7 @@ training: inference: inference_engine_type: sgl + store_last_hidden_states: false inference_num_gpus: 4 inference_num_gpus_per_engine: 1 inference_num_gpus_per_node: 4 diff --git a/patches/sglang/v0.5.10.post1/sglang.patch b/patches/sglang/v0.5.10.post1/sglang.patch index 0fe3718..b98c36d 100644 --- a/patches/sglang/v0.5.10.post1/sglang.patch +++ b/patches/sglang/v0.5.10.post1/sglang.patch @@ -4,19 +4,19 @@ torchspec sglang patch (base: 94f03a39db) python/sglang/srt/entrypoints/http_server.py | 17 +++ python/sglang/srt/layers/logits_processor.py | 54 ++++++++ python/sglang/srt/managers/detokenizer_manager.py | 3 + - python/sglang/srt/managers/io_struct.py | 48 ++++++++ + python/sglang/srt/managers/io_struct.py | 48 +++++++ python/sglang/srt/managers/schedule_batch.py | 54 +++++++- python/sglang/srt/managers/scheduler.py | 57 ++++++++- - .../managers/scheduler_output_processor_mixin.py | 137 ++++++++++++++++++--- + .../managers/scheduler_output_processor_mixin.py | 138 ++++++++++++++++++--- python/sglang/srt/managers/tokenizer_manager.py | 16 +++ python/sglang/srt/managers/utils.py | 5 +- .../srt/model_executor/forward_batch_info.py | 4 + python/sglang/srt/model_executor/model_runner.py | 16 +++ python/sglang/srt/models/qwen3_next.py | 8 ++ python/sglang/srt/models/qwen3_next_mtp.py | 3 + - python/sglang/srt/server_args.py | 23 ++++ + python/sglang/srt/server_args.py | 24 ++++ .../sglang/srt/speculative/spec_training_info.py | 50 ++++++++ - 16 files changed, 481 insertions(+), 22 deletions(-) + 16 files changed, 483 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index d864e4aba..b74514924 100644 @@ -544,7 +544,7 @@ index 67af2d0de..36d6e539c 100644 # Release the closure and large GPU tensors that are no longer needed. # The delay_sample_func closure captures forward_batch (which holds diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py -index 496cd9665..25ccb527b 100644 +index 496cd9665..45ceb0e20 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -218,21 +218,22 @@ class SchedulerOutputProcessorMixin: @@ -584,7 +584,7 @@ index 496cd9665..25ccb527b 100644 if req.grammar is not None: # FIXME: this try-except block is for handling unexpected xgrammar issue. -@@ -893,6 +894,74 @@ class SchedulerOutputProcessorMixin: +@@ -893,6 +894,75 @@ class SchedulerOutputProcessorMixin: if req.input_token_ids_logprobs_idx is None: req.input_token_ids_logprobs_idx = [] @@ -638,7 +638,8 @@ index 496cd9665..25ccb527b 100644 + ) + + last_hidden_states = None -+ if logits_output.last_hidden_states is not None: ++ store_lhs = getattr(self.server_args, "spec_training_store_last_hidden_states", True) ++ if store_lhs and logits_output.last_hidden_states is not None: + last_hidden_states = logits_output.last_hidden_states[ + hidden_state_offset : hidden_state_offset + seq_len + ] @@ -659,7 +660,7 @@ index 496cd9665..25ccb527b 100644 def stream_output( self: Scheduler, reqs: List[Req], -@@ -954,6 +1023,18 @@ class SchedulerOutputProcessorMixin: +@@ -954,6 +1024,18 @@ class SchedulerOutputProcessorMixin: routed_experts = None customized_info = {} @@ -678,7 +679,7 @@ index 496cd9665..25ccb527b 100644 time_stats = [] if return_logprob: -@@ -1058,6 +1139,13 @@ class SchedulerOutputProcessorMixin: +@@ -1058,6 +1140,13 @@ class SchedulerOutputProcessorMixin: spec_accepted_tokens.append(req.spec_accepted_tokens) spec_acceptance_histogram.append(req.spec_acceptance_histogram) @@ -692,7 +693,7 @@ index 496cd9665..25ccb527b 100644 if return_logprob: if ( req.return_logprob -@@ -1128,9 +1216,15 @@ class SchedulerOutputProcessorMixin: +@@ -1128,9 +1217,15 @@ class SchedulerOutputProcessorMixin: output_token_ids_logprobs_idx.append([]) if req.return_hidden_states: @@ -711,7 +712,7 @@ index 496cd9665..25ccb527b 100644 if req.return_routed_experts: if routed_experts is None: routed_experts = [] -@@ -1197,6 +1291,15 @@ class SchedulerOutputProcessorMixin: +@@ -1197,6 +1292,15 @@ class SchedulerOutputProcessorMixin: retraction_counts=retraction_counts, load=load, dp_ranks=dp_ranks, @@ -895,10 +896,10 @@ index b2bdbbbe8..833c055ab 100644 def forward( self, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py -index d91ced805..b9bb901f1 100644 +index d91ced805..2708630cc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py -@@ -514,6 +514,11 @@ class ServerArgs: +@@ -514,6 +514,12 @@ class ServerArgs: speculative_ngram_capacity: int = 10 * 1000 * 1000 enable_multi_layer_eagle: bool = False @@ -906,11 +907,12 @@ index d91ced805..b9bb901f1 100644 + enable_spec_training_mooncake: bool = False + enable_aux_hidden_states: bool = False + aux_hidden_state_layer_ids: Optional[List[int]] = None ++ spec_training_store_last_hidden_states: bool = True + # Expert parallelism ep_size: int = 1 moe_a2a_backend: Literal[ -@@ -4896,6 +4901,24 @@ class ServerArgs: +@@ -4896,6 +4902,24 @@ class ServerArgs: help="Enable multi-layer Eagle speculative decoding.", ) diff --git a/tests/test_dflash.py b/tests/test_dflash.py index dd4e6de..e9561a0 100644 --- a/tests/test_dflash.py +++ b/tests/test_dflash.py @@ -575,7 +575,7 @@ def test_loss_decreases_over_steps(self): # Loss should generally decrease (allow some noise, check last < first) self.assertLess(losses[-1], losses[0], "Loss did not decrease over 10 training steps") - self.assertTrue(all(math.isfinite(loss) for loss in losses), "Non-finite loss encountered") + self.assertTrue(all(math.isfinite(v) for v in losses), "Non-finite loss encountered") def test_gradient_accumulation(self): """Two half-LR steps with accumulated gradients should produce finite grads.""" @@ -914,7 +914,7 @@ def test_longer_sequence(self): model, *data = self._make_model_and_data(seq_len=128, num_anchors=8) losses, _ = self._train_steps(model, *data, steps=20) self.assertLess(losses[-1], losses[0]) - self.assertTrue(all(math.isfinite(loss) for loss in losses)) + self.assertTrue(all(math.isfinite(v) for v in losses)) def test_large_block_size(self): """Block size 8 should still converge.""" @@ -986,7 +986,7 @@ def test_loss_mask_with_padding(self): lm_w, steps=15, ) - self.assertTrue(all(math.isfinite(loss) for loss in losses)) + self.assertTrue(all(math.isfinite(v) for v in losses)) self.assertLess(losses[-1], losses[0]) @@ -1133,5 +1133,55 @@ def test_dflash_draft_config_json_loads(self): self.assertEqual(config.target_num_hidden_layers, 36) +class TestDFlashHotfixes(unittest.TestCase): + """Tests for incremental correctness fixes (tensor schema, collator, validation).""" + + def test_collator_accepts_dflash_batch(self): + """DataCollator should accept hidden_states without target/last_hidden_states.""" + from torchspec.data.utils import DataCollatorWithPadding + + collator = DataCollatorWithPadding() + samples = [] + for seq_len in [32, 48]: + samples.append( + { + "input_ids": torch.randint(0, 1000, (1, seq_len)), + "hidden_states": torch.randn(1, seq_len, 384), + "loss_mask": torch.ones(1, seq_len), + } + ) + + # Should not raise (DFlash does not provide target/last_hidden_states) + batch = collator(samples) + self.assertIsNotNone(batch["hidden_states"]) + self.assertIsNone(batch["target"]) + self.assertIsNone(batch["last_hidden_states"]) + + def test_min_loss_tokens_validation(self): + """min_loss_tokens < 2 * block_size should be caught.""" + from argparse import Namespace + + # Simulate the validation logic from train_entry + args = Namespace( + dflash_block_size=16, + min_loss_tokens=10, # < 2 * 16 = 32 + ) + block_size = getattr(args, "dflash_block_size", 16) + min_loss = getattr(args, "min_loss_tokens", 0) + + with self.assertRaises(ValueError): + if min_loss < 2 * block_size: + raise ValueError( + f"DFlash requires dataset.min_loss_tokens >= 2 * training.dflash_block_size " + f"({min_loss} < {2 * block_size})." + ) + + def test_min_loss_tokens_validation_passes(self): + """min_loss_tokens >= 2 * block_size should pass.""" + block_size = 16 + min_loss = 32 # == 2 * 16 + self.assertGreaterEqual(min_loss, 2 * block_size) + + if __name__ == "__main__": unittest.main() diff --git a/torchspec/config/inference_config.py b/torchspec/config/inference_config.py index 704ee9a..33a440d 100644 --- a/torchspec/config/inference_config.py +++ b/torchspec/config/inference_config.py @@ -117,6 +117,7 @@ class InferenceConfig: inference_num_gpus_per_node: int = 8 last_hidden_states_prenorm: Optional[bool] = None max_sample_pool_size: int = 0 + store_last_hidden_states: bool = True sglang: SGLangConfig = field(default_factory=SGLangConfig) vllm: VllmConfig = field(default_factory=VllmConfig) diff --git a/torchspec/controller/loop.py b/torchspec/controller/loop.py index ab23f36..b4278b0 100644 --- a/torchspec/controller/loop.py +++ b/torchspec/controller/loop.py @@ -103,7 +103,10 @@ def _cleanup_old_checkpoints(checkpoint_dir: str | None, max_checkpoints: int) - to_delete = iter_dirs[: len(iter_dirs) - max_checkpoints] for old_dir in to_delete: logger.info(f"Removing old checkpoint: {old_dir}") - shutil.rmtree(old_dir, ignore_errors=True) + try: + shutil.rmtree(old_dir) + except OSError as e: + logger.warning(f"Failed to remove old checkpoint {old_dir}: {e}") def _safe_training_cleanup( @@ -373,9 +376,6 @@ def training_loop( progress.update(1) if _is_save_interval_step(completed_steps, args.save_interval): - max_ckpts = getattr(args, "max_checkpoints", 0) - if max_ckpts > 0: - _cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts) eval_metrics = run_eval(completed_steps, train_group, eval_enabled) logger.info(f"Saving checkpoint at step {completed_steps}...") train_group.save_model(completed_steps) @@ -383,6 +383,9 @@ def training_loop( best_eval_score = update_checkpoint_eval_meta( args.checkpoint_dir, completed_steps, eval_metrics, best_eval_score ) + max_ckpts = getattr(args, "max_checkpoints", 0) + if max_ckpts > 0: + _cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts) _maybe_sync_draft_weights(args, completed_steps, train_group, inference_engines) @@ -398,9 +401,6 @@ def training_loop( and args.checkpoint_dir and last_saved_step != completed_steps ): - max_ckpts = getattr(args, "max_checkpoints", 0) - if max_ckpts > 0: - _cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts) eval_metrics = run_eval(completed_steps, train_group, eval_enabled) logger.info( f"Saving checkpoint at end of epoch {current_epoch} " @@ -411,6 +411,9 @@ def training_loop( best_eval_score = update_checkpoint_eval_meta( args.checkpoint_dir, completed_steps, eval_metrics, best_eval_score ) + max_ckpts = getattr(args, "max_checkpoints", 0) + if max_ckpts > 0: + _cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts) if completed_steps < num_steps: current_epoch += 1 @@ -429,15 +432,15 @@ def training_loop( # Always save a final checkpoint unless saved. if args.checkpoint_dir and last_saved_step != completed_steps: - max_ckpts = getattr(args, "max_checkpoints", 0) - if max_ckpts > 0: - _cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts) eval_metrics = run_eval(completed_steps, train_group, eval_enabled) logger.info(f"Saving final checkpoint at step {completed_steps}...") train_group.save_model(completed_steps, force_sync=True) best_eval_score = update_checkpoint_eval_meta( args.checkpoint_dir, completed_steps, eval_metrics, best_eval_score ) + max_ckpts = getattr(args, "max_checkpoints", 0) + if max_ckpts > 0: + _cleanup_old_checkpoints(args.checkpoint_dir, max_ckpts) final_status = ray.get(controller.get_full_status.remote()) logger.info( diff --git a/torchspec/data/utils.py b/torchspec/data/utils.py index 4469e46..c2416f9 100644 --- a/torchspec/data/utils.py +++ b/torchspec/data/utils.py @@ -120,9 +120,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: has_target = all(item.get("target") is not None for item in features) has_last_hs = all(item.get("last_hidden_states") is not None for item in features) if not has_target and not has_last_hs: - raise ValueError( - "Either 'target' or 'last_hidden_states' is required when 'hidden_states' is provided" - ) + pass if has_target: batch["target"] = torch.cat( [self.paddingtensor(item["target"], max_length) for item in features] @@ -459,7 +457,14 @@ def load_hf_dataset(data_path: str): if drop_cols: ds = ds.remove_columns(drop_cols) return ds - except Exception: + except (ValueError, TypeError, ArithmeticError, KeyError) as e: + # Schema inference failures (e.g., mixed-type columns in Arrow/Parquet). + # Fall back to manual JSON download. + import logging + + logging.getLogger(__name__).info( + f"load_dataset failed for '{data_path}' ({e}), falling back to JSON download" + ) return IterableDataset.from_generator( _load_hub_json_files, gen_kwargs={"data_path": data_path} ) diff --git a/torchspec/inference/engine/__init__.py b/torchspec/inference/engine/__init__.py index cf7c395..bf478b4 100644 --- a/torchspec/inference/engine/__init__.py +++ b/torchspec/inference/engine/__init__.py @@ -35,12 +35,16 @@ from torchspec.inference.engine.sgl_engine import SglEngine # noqa: F401 __all__.append("SglEngine") -except ImportError: - pass +except ImportError as _e: + from torchspec.utils.logging import logger as _logger + + _logger.debug("SglEngine not available: %s", _e) try: from torchspec.inference.engine.vllm_engine import VllmEngine # noqa: F401 __all__.append("VllmEngine") -except ImportError: - pass +except ImportError as _e: + from torchspec.utils.logging import logger as _logger + + _logger.debug("VllmEngine not available: %s", _e) diff --git a/torchspec/inference/engine/sgl_engine.py b/torchspec/inference/engine/sgl_engine.py index 07a466a..39abe1c 100644 --- a/torchspec/inference/engine/sgl_engine.py +++ b/torchspec/inference/engine/sgl_engine.py @@ -101,6 +101,7 @@ def __init__( self._mooncake_config = None self._mooncake_store = None self._hidden_size = None + self._store_last_hidden_states = True self.local_gpu_id = None setup_file_logging("inference", self.rank, group=engine_group) @@ -163,6 +164,8 @@ def init( mooncake_config.metadata_server, ) + self._store_last_hidden_states = getattr(self.args, "store_last_hidden_states", True) + # Get configuration mem_fraction = getattr(self.args, "sglang_mem_fraction_static", 0.8) pp_size = getattr(self.args, "sglang_pp_size", 1) @@ -243,6 +246,11 @@ def init( "chunked_prefill_size": -1, "allow_auto_truncate": True, **({"context_length": max_seq_length} if max_seq_length else {}), + **( + {"spec_training_store_last_hidden_states": False} + if not self._store_last_hidden_states + else {} + ), } ) @@ -532,16 +540,20 @@ def _get_tensor_shapes(self, seq_len: int) -> dict: # IMPORTANT: Sglang stores tensors WITHOUT batch dimension in mooncake # We must request the SAME shapes that sglang stored, otherwise we get size mismatch # The collator will add batch dimension when needed - return { + shapes = { "hidden_states": (seq_len, concat_hidden_size), # 2D without batch dim "input_ids": (seq_len,), # 1D without batch dim - "last_hidden_states": (seq_len, hidden_size), # 2D without batch dim } + if self._store_last_hidden_states: + shapes["last_hidden_states"] = (seq_len, hidden_size) + return shapes def _get_tensor_dtypes(self) -> dict: """Get tensor dtypes for mooncake metadata.""" - return { + dtypes = { "hidden_states": HIDDEN_STATES_STORAGE_DTYPE, "input_ids": torch.long, - "last_hidden_states": HIDDEN_STATES_STORAGE_DTYPE, } + if self._store_last_hidden_states: + dtypes["last_hidden_states"] = HIDDEN_STATES_STORAGE_DTYPE + return dtypes diff --git a/torchspec/models/dflash.py b/torchspec/models/dflash.py index 53378de..886ce4b 100644 --- a/torchspec/models/dflash.py +++ b/torchspec/models/dflash.py @@ -26,7 +26,7 @@ Matches SpecForge's OnlineDFlashModel (specforge/core/dflash.py). """ -from typing import List, Optional, Tuple +from typing import List, Tuple import torch import torch.nn as nn @@ -137,22 +137,11 @@ def _sample_anchor_positions( valid = loss_mask[:, : max_anchor + 1] > 0.5 valid_counts = valid.sum(dim=1) - if int(valid_counts.max().item()) == 0: - logger.warning( - f"No valid anchor positions in batch (max_anchor={max_anchor}, " - f"block_size={bs}). Returning dummy anchors with " - f"keep_mask=False so loss is zero. Consider setting " - f"dataset.min_loss_tokens >= 2*block_size." - ) - anchors = torch.zeros(bsz, max_n, dtype=torch.long, device=device) - keep_mask = torch.zeros(bsz, max_n, dtype=torch.bool, device=device) - return anchors, keep_mask - indices = torch.arange(max_anchor + 1, device=device).unsqueeze(0).expand(bsz, -1) - masked_indices = torch.where(valid, indices, torch.tensor(seq_len + 1, device=device)) + masked_indices = torch.where(valid, indices, seq_len + 1) random_vals = torch.rand(bsz, max_anchor + 1, device=device) - random_vals = torch.where(valid, random_vals, torch.tensor(2.0, device=device)) + random_vals = torch.where(valid, random_vals, 2.0) _, sorted_idx = random_vals.sort(dim=1) gathered = torch.gather(masked_indices, 1, sorted_idx) @@ -168,7 +157,7 @@ def _sample_anchor_positions( keep_mask = torch.arange(max_n, device=device).unsqueeze(0) < valid_counts.unsqueeze( 1 ).clamp(max=max_n) - anchors = torch.where(keep_mask, anchors, torch.tensor(0, dtype=torch.long, device=device)) + anchors = torch.where(keep_mask, anchors, 0) return anchors, keep_mask @@ -226,8 +215,6 @@ def forward( hidden_states_list: List[torch.Tensor], loss_mask: torch.Tensor, lm_head_weight: torch.Tensor, - norm_weight: Optional[torch.Tensor] = None, - norm_eps: float = 1e-6, ) -> Tuple[torch.Tensor, torch.Tensor]: """Full DFlash training forward pass. diff --git a/torchspec/models/draft/dflash.py b/torchspec/models/draft/dflash.py index 918a104..9b769a7 100644 --- a/torchspec/models/draft/dflash.py +++ b/torchspec/models/draft/dflash.py @@ -464,7 +464,9 @@ def load_embedding( return pytorch_model_path = os.path.join(model_path, "pytorch_model.bin") if os.path.exists(pytorch_model_path): - state_dict = torch.load(pytorch_model_path, map_location="cpu") + state_dict = torch.load( + pytorch_model_path, map_location="cpu", weights_only=True + ) self.embed_tokens.weight.copy_(state_dict[embedding_key]) return raise FileNotFoundError( @@ -478,7 +480,9 @@ def load_embedding( with safe_open(os.path.join(model_path, ckpt_file), framework="pt") as f: self.embed_tokens.weight.copy_(f.get_tensor(embedding_key)) else: - state_dict = torch.load(os.path.join(model_path, ckpt_file)) + state_dict = torch.load( + os.path.join(model_path, ckpt_file), map_location="cpu", weights_only=True + ) self.embed_tokens.weight.copy_(state_dict[embedding_key]) else: local_cache_path = snapshot_download(repo_id=model_path) diff --git a/torchspec/train_entry.py b/torchspec/train_entry.py index 12de62a..8c0d5ca 100644 --- a/torchspec/train_entry.py +++ b/torchspec/train_entry.py @@ -195,6 +195,41 @@ def _get_draft_model_config(args): return AutoDraftModelConfig.from_dict(config_dict) +def _validate_and_configure_dflash(args, draft_model_config) -> None: + """Validate DFlash-specific config and auto-set aux layer IDs. + + Called before dataset loading to fail fast on misconfigurations. + """ + from torchspec.models.draft.dflash import DFlashConfig + + if not isinstance(draft_model_config, DFlashConfig): + return + + if getattr(args, "inference_engine_type", "hf") != "sgl": + raise NotImplementedError("DFlash currently supports only inference_engine_type='sgl'.") + if getattr(args, "defer_tokenization", False): + raise NotImplementedError("DFlash does not support defer_tokenization=True.") + block_size = getattr(args, "dflash_block_size", 16) + min_loss = getattr(args, "min_loss_tokens", 0) + if min_loss < 2 * block_size: + raise ValueError( + f"DFlash requires dataset.min_loss_tokens >= 2 * training.dflash_block_size " + f"({min_loss} < {2 * block_size}). Set dataset.min_loss_tokens={2 * block_size}." + ) + + # Auto-set aux layer IDs from draft config if not explicitly provided + if not getattr(args, "aux_hidden_states_layers", None): + from torchspec.models.draft.dflash import build_target_layer_ids + + target_layer_ids = getattr(draft_model_config, "target_layer_ids", None) + if target_layer_ids is None: + num_target = getattr(draft_model_config, "num_target_layers", 5) + target_num_hidden = getattr(draft_model_config, "target_num_hidden_layers", 36) + target_layer_ids = build_target_layer_ids(num_target, target_num_hidden) + args.aux_hidden_states_layers = target_layer_ids + logger.info(f"DFlash: set aux_hidden_states_layers = {target_layer_ids}") + + def train_async_no_generation(args): """Entry point for Eagle3 online training. @@ -219,6 +254,13 @@ def train_async_no_generation(args): scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=driver_node_id, soft=False), ).remote(args, args.dp_size) + # [1.5] Parse draft config + DFlash validation (before any async work) + with timer.phase("Parse draft model config"): + draft_model_config = _get_draft_model_config(args) + args.draft_model_config_obj = draft_model_config + + _validate_and_configure_dflash(args, draft_model_config) + # [2] Kick off dataset loading on controller (async — runs on actor while driver continues) timer.begin_async("Dataset loading") dataset_size_ref = controller.load_dataset.remote(args) @@ -226,25 +268,6 @@ def train_async_no_generation(args): # [3] Do initialization that doesn't depend on dataset in parallel with timer.phase("Driver-side init"): - draft_model_config = _get_draft_model_config(args) - args.draft_model_config_obj = draft_model_config - - # Auto-set aux layer IDs for DFlash (5 layers) if not explicitly provided - from torchspec.models.draft.dflash import DFlashConfig - - if isinstance(draft_model_config, DFlashConfig) and not getattr( - args, "aux_hidden_states_layers", None - ): - from torchspec.models.draft.dflash import build_target_layer_ids - - target_layer_ids = getattr(draft_model_config, "target_layer_ids", None) - if target_layer_ids is None: - num_target = getattr(draft_model_config, "num_target_layers", 5) - target_num_hidden = getattr(draft_model_config, "target_num_hidden_layers", 36) - target_layer_ids = build_target_layer_ids(num_target, target_num_hidden) - args.aux_hidden_states_layers = target_layer_ids - logger.info(f"DFlash: set aux_hidden_states_layers = {target_layer_ids}") - pgs = create_placement_groups(args) launch_mooncake_master(args) mooncake_config = build_mooncake_config(args) diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py index e40a48f..4fd06cc 100644 --- a/torchspec/training/data_fetcher.py +++ b/torchspec/training/data_fetcher.py @@ -378,7 +378,11 @@ def _prefetch_loop(self) -> None: for batch in self.inner: self._queue.put(batch) except Exception as e: - self._error = e + # Preserve the original traceback so re-raise in __next__ + # points to the actual failure site, not to __next__ itself. + import sys + + self._error = e.with_traceback(sys.exc_info()[2]) finally: self._queue.put(self._SENTINEL) diff --git a/torchspec/training/dflash_trainer.py b/torchspec/training/dflash_trainer.py index 0f36b27..3c76054 100644 --- a/torchspec/training/dflash_trainer.py +++ b/torchspec/training/dflash_trainer.py @@ -78,7 +78,10 @@ def init_model( elif isinstance(draft_model_config, DFlashConfig): config = draft_model_config else: - config = draft_model_config + raise TypeError( + f"Unsupported draft_model_config type: {type(draft_model_config).__name__}. " + f"Expected str, dict, or DFlashConfig." + ) if not hasattr(config, "num_target_layers") or config.num_target_layers is None: config.num_target_layers = self.num_target_layers diff --git a/torchspec/training/trainer_actor.py b/torchspec/training/trainer_actor.py index 1e9f5d0..524f47e 100644 --- a/torchspec/training/trainer_actor.py +++ b/torchspec/training/trainer_actor.py @@ -106,7 +106,15 @@ def save_draft_model_for_serving(self, output_dir: str) -> None: self._trainer.save_draft_model_for_serving(output_dir) def set_vocab_buffers(self, d2t, t2d) -> None: - self._trainer.draft_model.set_vocab_buffers(d2t, t2d) + if hasattr(self._trainer, "draft_model") and hasattr( + self._trainer.draft_model, "set_vocab_buffers" + ): + self._trainer.draft_model.set_vocab_buffers(d2t, t2d) + else: + raise AttributeError( + "set_vocab_buffers called but draft model does not support vocab pruning. " + "DFlash training should not use vocab pruning — check train_entry config." + ) def set_eval_queue(self, queue, mooncake_config=None, per_dp_rank_batch_size: int = 1): return self._trainer.set_eval_queue( diff --git a/torchspec/utils/misc.py b/torchspec/utils/misc.py index 040d420..b7c1b8b 100644 --- a/torchspec/utils/misc.py +++ b/torchspec/utils/misc.py @@ -86,22 +86,3 @@ def get_default_eagle3_aux_layer_ids(model_path: str) -> List[int]: config = getattr(config, "text_config", config) num_layers = config.num_hidden_layers return [1, num_layers // 2 - 1, num_layers - 4] - - -def get_default_dflash_aux_layer_ids(model_path: str, num_target_layers: int = 5) -> List[int]: - """Get default auxiliary hidden state layer IDs for DFlash. - - Uses the same uniform spacing algorithm as DFlashDraftModel.build_target_layer_ids(). - - Args: - model_path: Path to the HuggingFace model checkpoint. - num_target_layers: Number of target layers to capture (default: 5). - - Returns: - List of uniformly spaced layer IDs. - """ - from torchspec.models.draft.dflash import build_target_layer_ids - - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - config = getattr(config, "text_config", config) - return build_target_layer_ids(num_target_layers, config.num_hidden_layers)