feat: DFlash (block-parallel) draft model training #64
cicirori merged 5 commits into torchspec-project:main
Conversation
Force-pushed b7c6213 to e18b893:
Add DFlash block-parallel draft model training to TorchSpec, enabling disaggregated online training with the SGLang inference backend. DFlash predicts 16-token blocks in parallel using dual-source KV attention, achieving ~5x fewer forward passes per step compared to Eagle3.

Core implementation:
- torchspec/models/draft/dflash.py: DFlashDraftModel architecture (5-layer transformer, dual-source KV, GQA, shared embedding/LM head)
- torchspec/models/dflash.py: training wrapper with anchor sampling, block-causal mask (FlexAttention), CE loss with exponential decay
- torchspec/training/dflash_trainer.py: FSDP2 trainer with WSD scheduler
- tests/test_dflash.py: 67 tests covering config, architecture, anchor sampling, block-causal mask, forward pass, loss, accuracy

Integration:
- Config-based trainer dispatch (DFlashConfig → DFlashTrainer)
- Generalized N-layer target model support (was hardcoded to 3)
- Async data prefetching; min_lr/weight_decay optimizer params
- Checkpoint rotation; lazy SGLang/vLLM imports

Validation (best model P2-WSD, 800K PerfectBlend, 3 epochs, 8x H100):
- Math avg τ = 3.94 (2.7% gap to z-lab's 4.05)
- Decode-only speedup: 3.02x on livecodebench
Force-pushed e18b893 to 26c5138:
- Read target_num_hidden_layers from the target model config instead of hardcoding 36 (Qwen3-8B specific)
- Remove duplicate SglEngine/VllmEngine imports in factory.py
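The first fix above replaces a hardcoded layer count with a lookup on the target model's config. A minimal sketch of such a helper (hypothetical; the attribute names probed here are common HF-config conventions, not necessarily the PR's exact code):

```python
def get_target_num_layers(target_config) -> int:
    """Read the transformer layer count from a target model config object.

    Illustrative helper: instead of hardcoding 36 (Qwen3-8B), probe the
    attribute names used by common HF-style configs and return the first
    integer found.
    """
    for attr in ("num_hidden_layers", "n_layer", "num_layers"):
        value = getattr(target_config, attr, None)
        if isinstance(value, int):
            return value
    raise AttributeError("target config does not expose a layer count")
```

This keeps the trainer agnostic to the target architecture, so swapping Qwen3-8B for another model does not require touching the training code.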
- Explain why accuracy uses a binary mask without decay (intentional)
- Note _apply_rotary_pos_emb is kept as a utility matching SpecForge
- Document the purpose of the RoPE cache +20 buffer
- Reference SpecForge PR #427 for bidirectional intra-block attention
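The first point distinguishes the training loss (CE with exponential decay over block positions) from accuracy (plain binary mask). A pure-Python sketch of such a decayed block loss, under assumptions — the decay constant and exact weighting are illustrative, not taken from the PR:

```python
import math

def block_ce_loss(logits, targets, decay=0.7):
    """Cross-entropy over one predicted block with exponential position decay.

    Hypothetical sketch: position t in the block is weighted decay**t, so
    later (noisier) positions contribute less to the loss. Accuracy, by
    contrast, counts every position equally via a binary mask.

    logits: list of per-position score lists over the vocab
    targets: list of target token ids, one per position
    """
    total, total_w = 0.0, 0.0
    for t, (row, y) in enumerate(zip(logits, targets)):
        m = max(row)                                        # numerical stability
        logz = m + math.log(sum(math.exp(x - m) for x in row))
        nll = logz - row[y]                                 # per-position NLL
        w = decay ** t
        total += w * nll
        total_w += w
    return total / total_w
```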
Thanks for the thorough work! I had just finished writing my own implementation, and after publishing it I came across this PR and made some changes based on it: #71
Hi cicirori, can we merge this PR first? I believe @yubofredwang is reviewing this one.
- Checkpoint cleanup: move _cleanup_old_checkpoints AFTER save at all 3 call sites to prevent a race condition where cleanup deletes existing checkpoints before the new save succeeds (data loss on save failure).
- Checkpoint cleanup: replace shutil.rmtree(ignore_errors=True) with try/except that logs a warning on failure, preventing a silent disk leak.
- DFlashTrainer: raise TypeError for unsupported draft_model_config types instead of silently accepting anything.
- PrefetchedDataFetcher: preserve the original traceback when re-raising prefetch thread exceptions, so the error points to the actual failure site.
- DFlashDraftModel.load_embedding: add weights_only=True to torch.load calls for security and to suppress the deprecation warning.
- load_hf_dataset: narrow the bare except Exception to specific schema inference errors (ValueError, TypeError, etc.) and log the fallback, so auth errors and network failures are not silently swallowed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
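The traceback-preservation fix can be illustrated with a minimal prefetch thread (a sketch, not the real PrefetchedDataFetcher): the worker captures exc_info, and the consumer re-raises with the original traceback attached so the error points at the line that actually failed.

```python
import queue
import sys
import threading

class PrefetchedFetch(threading.Thread):
    """Minimal prefetch sketch: run fn() in a background thread and hand the
    result (or the original exception, traceback intact) to the consumer."""

    def __init__(self, fn):
        super().__init__(daemon=True)
        self._fn = fn
        self._out = queue.Queue(maxsize=1)

    def run(self):
        try:
            self._out.put(("ok", self._fn()))
        except Exception:
            self._out.put(("err", sys.exc_info()))  # keep the traceback object

    def get(self):
        kind, payload = self._out.get()
        if kind == "err":
            _, exc, tb = payload
            raise exc.with_traceback(tb)  # points at the failure site, not get()
        return payload
```

Without `with_traceback`, re-raising a stored exception in the consumer makes every stack trace end at `get()`, hiding which line of the prefetch function actually failed.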
Sure. The linter failed, can you fix that?
- ruff-format: reformat 10 DFlash files to match project style
- ruff: drop unused model vars; rename ambiguous l → loss in test_dflash
- ruff: mark conditional SGL/vLLM engine imports with noqa: F401
- end-of-file-fixer: add trailing newline to .gitignore
- check-shebang-scripts-are-executable: chmod +x two scripts/tools files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Thank you so much!
Objective
Implement DFlash block-parallel draft model training in TorchSpec, enabling disaggregated online training with SGLang inference backend.
DFlash predicts 16-token blocks in parallel using dual-source KV attention, achieving ~5x fewer forward passes per training step compared to Eagle3's autoregressive approach.
Reference: z-lab/dflash, SpecForge DFlash PR
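The block-causal attention pattern described above — causal across blocks, bidirectional within a block (cf. SpecForge PR #427) — can be sketched as a FlexAttention-style mask predicate. This is an illustrative reconstruction of the semantics, not the PR's exact code:

```python
BLOCK_SIZE = 16  # DFlash block width from the PR description

def block_causal(q_idx: int, kv_idx: int, block_size: int = BLOCK_SIZE) -> bool:
    """FlexAttention-style mask predicate (sketch).

    A query may attend to any key whose block index is <= its own:
    causal at block granularity across blocks, fully bidirectional
    inside a block, which is what lets all 16 draft tokens be
    predicted in one forward pass.
    """
    return (kv_idx // block_size) <= (q_idx // block_size)
```

With FlexAttention, a predicate of this shape would be compiled into a sparse block mask once and reused across training steps.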
Architecture
DFlash vs Eagle3
Files Changed
Core Implementation (new)
- torchspec/models/draft/dflash.py: DFlashDraftModel — 5-layer transformer with dual-source KV attention, GQA, shared embedding + LM head
- torchspec/models/dflash.py: DFlashModel — training wrapper with anchor sampling, block-causal mask (FlexAttention), CE loss with exponential decay
- torchspec/training/dflash_trainer.py: DFlashTrainer — FSDP2 integration, WSD LR scheduler, frozen LM head initialization
- torchspec/config/dflash_draft_config.json
- configs/sglang_qwen3_8b_dflash.yaml
Tests
- tests/test_dflash.py
Tooling
- scripts/tools/extract_dflash_checkpoint.py
- scripts/tools/prepare_perfectblend.py
- docs/dflash/README.md
Modified Files
- torchspec/models/target/eagle3_target_model.py: set_aux_hidden_states_layers() generalized to N layers (was hardcoded to 3)
- torchspec/models/ops/flex_attention.py
- torchspec/training/trainer.py
- torchspec/training/data_fetcher.py: PrefetchedDataFetcher, min_loss_tokens filtering, clone safety for batch > 1
- torchspec/training/optimizer.py: min_lr parameter for the LR scheduler
- torchspec/training/eagle3_trainer.py: weight_decay, min_lr, WSD params
- torchspec/config/train_config.py: min_loss_tokens, shuffle_dataset, max_checkpoints
- torchspec/controller/loop.py: checkpoint rotation (max_checkpoints), verbose timing
- torchspec/training/trainer_actor.py: config-based trainer dispatch (DFlashConfig → DFlashTrainer)
- torchspec/train_entry.py
- torchspec/inference/engine/__init__.py
Validation Results
Best model (P2-WSD): 800K PerfectBlend, 3 epochs, WSD LR schedule, 8x H100.
Math avg τ = 3.94 (2.7% gap to z-lab's 4.05). Decode-only speedup reached 3.02x on livecodebench.
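The WSD (warmup-stable-decay) schedule used for the best run can be sketched as follows. The warmup and decay fractions here are illustrative assumptions; the PR only states that a WSD schedule with a min_lr floor was used:

```python
def wsd_lr(step, total_steps, base_lr, min_lr,
           warmup_frac=0.1, decay_frac=0.1):
    """Warmup-Stable-Decay LR schedule sketch.

    Linear warmup to base_lr, a long constant plateau, then a linear
    decay down to min_lr over the final decay_frac of training.
    """
    warmup = int(total_steps * warmup_frac)
    decay_start = total_steps - int(total_steps * decay_frac)
    if step < warmup:
        return base_lr * (step + 1) / warmup          # linear warmup
    if step < decay_start:
        return base_lr                                 # stable plateau
    frac = (step - decay_start) / max(1, total_steps - decay_start)
    return base_lr + (min_lr - base_lr) * frac         # linear decay to floor
```

The plateau makes WSD convenient for continued training: unlike cosine schedules, the decay phase can be re-run from any plateau checkpoint.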
Usage
References