feat: DFlash (block-parallel) draft model training#64

Merged
cicirori merged 5 commits into torchspec-project:main from zhubohao911:fork/feature/dflash-training
Apr 13, 2026

Conversation

@zhubohao911 (Contributor) commented on Apr 7, 2026

Objective

Implement DFlash block-parallel draft model training in TorchSpec, enabling disaggregated online training with SGLang inference backend.

DFlash predicts 16-token blocks in parallel using dual-source KV attention, achieving ~5x fewer forward passes per training step compared to Eagle3's autoregressive approach.

Reference: z-lab/dflash, SpecForge DFlash PR

Architecture

DFlash vs Eagle3

| | DFlash | Eagle3 |
|---|---|---|
| Loss | CE (cross-entropy) | KL (forward KL) |
| Prediction | Block-parallel (`block_size=16`) | Autoregressive (`ttt_length=7`) |
| Target layers | 5 layers | 3 layers |
| Mask | Block-causal (FlexAttention) | Causal |
| Context | Dual-source KV (`W_proj`) | Input fusion (fc layer) |
| Forward passes per step | 1 | 7 (sequential) |
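The block-causal mask above can be expressed as a simple position predicate, the same shape a FlexAttention `mask_mod` takes (which additionally receives batch and head indices). This is an illustrative sketch, not the PR's code; per the SpecForge PR #427 convention referenced later, attention is bidirectional within a block and causal across blocks:

```python
# Illustrative sketch, not the PR's implementation.
BLOCK_SIZE = 16

def block_causal(q_idx: int, kv_idx: int) -> bool:
    # Allow attention iff the key's block does not lie after the query's block:
    # full bidirectional attention within a 16-token block, causal across blocks.
    return kv_idx // BLOCK_SIZE <= q_idx // BLOCK_SIZE

# Materialize the mask for a 48-token (3-block) sequence.
seq_len = 48
mask = [[block_causal(q, k) for k in range(seq_len)] for q in range(seq_len)]
```

With FlexAttention, the same predicate (operating on `(b, h, q_idx, kv_idx)` index tensors) would be compiled into a sparse `BlockMask` via `create_block_mask` rather than materialized densely as here.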

Files Changed

Core Implementation (new)

| File | Description |
|---|---|
| `torchspec/models/draft/dflash.py` | `DFlashDraftModel` — 5-layer transformer with dual-source KV attention, GQA, shared embedding + LM head |
| `torchspec/models/dflash.py` | `DFlashModel` — training wrapper with anchor sampling, block-causal mask (FlexAttention), CE loss with exponential decay |
| `torchspec/training/dflash_trainer.py` | `DFlashTrainer` — FSDP2 integration, WSD LR scheduler, frozen LM head initialization |
| `torchspec/config/dflash_draft_config.json` | Default draft model config for Qwen3-8B |
| `configs/sglang_qwen3_8b_dflash.yaml` | Multi-GPU training config (SGLang + FSDP) |
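The "CE loss with exponential decay" presumably down-weights later, harder-to-predict positions within each 16-token block. A minimal sketch under that assumption (the decay value and normalization are illustrative, not the PR's exact scheme):

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch: per-token cross-entropy over a predicted block,
# with position i inside the block weighted by decay**i.
def block_ce_loss(logits: torch.Tensor, targets: torch.Tensor, decay: float = 0.8):
    # logits: (batch, block_size, vocab); targets: (batch, block_size)
    block_size = targets.size(1)
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction="none",
    ).reshape(targets.shape)
    weights = decay ** torch.arange(block_size, dtype=per_token.dtype)
    # Normalize so the loss scale is comparable across decay settings.
    return (per_token * weights).sum() / (weights.sum() * targets.size(0))

loss = block_ce_loss(torch.randn(2, 16, 32), torch.randint(0, 32, (2, 16)))
```

Note the PR description mentions that accuracy is intentionally computed with a binary mask, without this decay.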

Tests

| File | Description |
|---|---|
| `tests/test_dflash.py` | 67 tests — config, architecture, anchor sampling, block-causal mask, forward pass, loss, accuracy |

Tooling

| File | Description |
|---|---|
| `scripts/tools/extract_dflash_checkpoint.py` | FSDP distributed checkpoint → single `.pt` extraction |
| `scripts/tools/prepare_perfectblend.py` | PerfectBlend dataset download and normalization |
| `docs/dflash/README.md` | Architecture overview, training guide, and validation results |

Modified Files

| File | Change |
|---|---|
| `torchspec/models/target/eagle3_target_model.py` | Generalized `set_aux_hidden_states_layers()` for N layers (was hardcoded to 3) |
| `torchspec/models/ops/flex_attention.py` | Increased recompile limit, inductor GEMM backend config |
| `torchspec/training/trainer.py` | Async data prefetching, gradient sync for micro-batch accumulation |
| `torchspec/training/data_fetcher.py` | `PrefetchedDataFetcher`, `min_loss_tokens` filtering, clone safety for batch > 1 |
| `torchspec/training/optimizer.py` | Added `min_lr` parameter to LR scheduler |
| `torchspec/training/eagle3_trainer.py` | Plumbed `weight_decay`, `min_lr`, WSD params |
| `torchspec/config/train_config.py` | DFlash-specific config fields, `min_loss_tokens`, `shuffle_dataset`, `max_checkpoints` |
| `torchspec/controller/loop.py` | Checkpoint rotation (`max_checkpoints`), verbose timing |
| `torchspec/training/trainer_actor.py` | Config-based trainer dispatch (`DFlashConfig` → `DFlashTrainer`) |
| `torchspec/train_entry.py` | Auto-set aux layer IDs for DFlash, inductor GEMM fix |
| `torchspec/inference/engine/__init__.py` | Lazy SGLang/vLLM imports (HF-only training support) |
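The `PrefetchedDataFetcher` idea — a background thread fills a bounded queue while the trainer consumes from it — can be sketched roughly as follows. Names and structure are illustrative, not the PR's code; the traceback-preserving re-raise matches a behavior a follow-up commit in this PR calls out:

```python
import queue
import threading

# Illustrative sketch: a worker thread prefetches items into a bounded
# queue; any worker exception is re-raised on the consumer side with its
# original traceback attached, so errors point at the real failure site.
class PrefetchedFetcher:
    _DONE = object()  # sentinel marking end of stream

    def __init__(self, source, depth: int = 2):
        self._q = queue.Queue(maxsize=depth)
        self._exc = None
        self._thread = threading.Thread(target=self._worker, args=(source,), daemon=True)
        self._thread.start()

    def _worker(self, source):
        try:
            for item in source:
                self._q.put(item)
        except Exception as exc:
            self._exc = exc  # keep the exception object (its __traceback__ survives)
        finally:
            self._q.put(self._DONE)

    def __iter__(self):
        while True:
            item = self._q.get()
            if item is self._DONE:
                if self._exc is not None:
                    raise self._exc  # original traceback preserved
                return
            yield item
```

The bounded queue gives backpressure: the prefetch thread stays at most `depth` batches ahead of the training loop.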

Validation Results

Best model (P2-WSD): 800K PerfectBlend, 3 epochs, WSD LR schedule, 8x H100.

| Dataset | Our τ | z-lab τ | E2E Speedup |
|---|---|---|---|
| gsm8k | 3.89 | 3.38 | 2.47x |
| math500 | 4.19 | 4.61 | 2.80x |
| aime24 | 3.98 | 4.12 | 2.60x |
| aime25 | 3.69 | 4.07 | 2.42x |
| humaneval | 4.30 | | 2.68x |
| livecodebench | 4.72 | | 2.96x |

Math avg τ = 3.94 (2.7% gap to z-lab's 4.05). Decode-only speedup reached 3.02x on livecodebench.

Usage

```bash
# Prepare data
python scripts/tools/prepare_perfectblend.py \
  --output data/perfectblend_50k.jsonl --sample-size 50000

# Train (4x H100: 1 inference + 3 training FSDP)
python -m torchspec.train_entry \
  --config configs/sglang_qwen3_8b_dflash.yaml \
  dataset.train_data_path=data/perfectblend_50k.jsonl \
  output_dir=./outputs/qwen3-8b-dflash

# Extract checkpoint
python scripts/tools/extract_dflash_checkpoint.py \
  --checkpoint_dir outputs/qwen3-8b-dflash/checkpoints/iter_NNNNNNN \
  --output dflash_draft.pt
```

References

- z-lab/dflash
- SpecForge DFlash PR

@zhubohao911 changed the title from "Add DFlash training" to "feat: DFlash (block-parallel) draft model training" on Apr 7, 2026
@zhubohao911 force-pushed the fork/feature/dflash-training branch from b7c6213 to e18b893 on April 7, 2026 at 06:56
Add DFlash block-parallel draft model training to TorchSpec, enabling
disaggregated online training with SGLang inference backend. DFlash
predicts 16-token blocks in parallel using dual-source KV attention,
achieving ~5x fewer forward passes per step compared to Eagle3.

Core implementation:
- torchspec/models/draft/dflash.py: DFlashDraftModel architecture
  (5-layer transformer, dual-source KV, GQA, shared embedding/LM head)
- torchspec/models/dflash.py: Training wrapper with anchor sampling,
  block-causal mask (FlexAttention), CE loss with exponential decay
- torchspec/training/dflash_trainer.py: FSDP2 trainer with WSD scheduler
- tests/test_dflash.py: 67 tests covering config, architecture,
  anchor sampling, block-causal mask, forward pass, loss, accuracy

Integration:
- Config-based trainer dispatch (DFlashConfig → DFlashTrainer)
- Generalized N-layer target model support (was hardcoded to 3)
- Async data prefetching, min_lr/weight_decay optimizer params
- Checkpoint rotation, lazy SGLang/vLLM imports

Validation (best model P2-WSD, 800K PerfectBlend, 3 epochs, 8x H100):
- Math avg τ=3.94 (2.7% gap to z-lab's 4.05)
- Decode-only speedup: 3.02x on livecodebench
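The "lazy SGLang/vLLM imports" item in the integration list can be sketched with a registry that defers `importlib.import_module` until an engine is actually requested, so HF-only training never touches the heavy backends. The module paths and class names below are hypothetical, not the PR's actual layout:

```python
import importlib

# Hypothetical backend registry: name -> (module path, class name).
_ENGINES = {
    "sglang": ("torchspec.inference.engine.sgl", "SglEngine"),
    "vllm": ("torchspec.inference.engine.vllm", "VllmEngine"),
}

def get_engine_class(backend: str):
    try:
        module_path, cls_name = _ENGINES[backend]
    except KeyError:
        raise ValueError(f"unknown backend: {backend!r}") from None
    # The heavy import happens here, not at package import time, so a
    # missing SGLang/vLLM install only errors when that backend is used.
    module = importlib.import_module(module_path)
    return getattr(module, cls_name)
```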
@zhubohao911 force-pushed the fork/feature/dflash-training branch from e18b893 to 26c5138 on April 7, 2026 at 06:59
- Read target_num_hidden_layers from target model config instead of
  hardcoding 36 (Qwen3-8B specific)
- Remove duplicate SglEngine/VllmEngine imports in factory.py
- Explain why accuracy uses binary mask without decay (intentional)
- Note _apply_rotary_pos_emb kept as utility matching SpecForge
- Document RoPE cache +20 buffer purpose
- Reference SpecForge PR #427 for bidirectional intra-block attention
@zhubohao911 zhubohao911 marked this pull request as ready for review April 7, 2026 07:25
@cicirori cicirori self-assigned this Apr 12, 2026
@cicirori (Collaborator) commented:
Thanks for the thorough work! I happened to finish writing, and after publishing I came across this PR and made some changes based on it. #71

@zhubohao911 (Contributor, author) commented:

> Thanks for the thorough work! I happened to finish writing, and after publishing I came across this PR and made some changes based on it. #71

Hi cicirori

Can we merge this PR first? I believe @yubofredwang is reviewing this one.

cicirori added a commit that referenced this pull request Apr 13, 2026
- Checkpoint cleanup: move _cleanup_old_checkpoints AFTER save at all 3
  call sites to prevent race condition where cleanup deletes existing
  checkpoints before new save succeeds (data loss on save failure).

- Checkpoint cleanup: replace shutil.rmtree(ignore_errors=True) with
  try/except that logs warning on failure, preventing silent disk leak.

- DFlashTrainer: raise TypeError for unsupported draft_model_config
  types instead of silently accepting anything.

- PrefetchedDataFetcher: preserve original traceback when re-raising
  prefetch thread exceptions, so error points to actual failure site.

- DFlashDraftModel.load_embedding: add weights_only=True to torch.load
  calls for security and to suppress deprecation warning.

- load_hf_dataset: narrow bare except Exception to specific schema
  inference errors (ValueError, TypeError, etc.) and log the fallback,
  so auth errors and network failures are not silently swallowed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
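The save-before-cleanup ordering described in the first two bullets might look like the sketch below (function and file names are hypothetical): the new checkpoint is fully written before any old one is deleted, so a failed save can never destroy the last good checkpoint, and removal failures are logged rather than silently ignored.

```python
import shutil
from pathlib import Path

# Hypothetical sketch of save-then-rotate checkpointing.
def save_and_rotate(ckpt_root: Path, step: int, state: bytes, max_checkpoints: int = 3):
    new_dir = ckpt_root / f"iter_{step:07d}"
    new_dir.mkdir(parents=True, exist_ok=True)
    # Save first: if this raises, existing checkpoints are untouched.
    (new_dir / "state.bin").write_bytes(state)
    # Only after a successful save, rotate out the oldest checkpoints.
    ckpts = sorted(p for p in ckpt_root.iterdir() if p.name.startswith("iter_"))
    for old in ckpts[:-max_checkpoints]:
        try:
            shutil.rmtree(old)
        except OSError as exc:
            # Warn instead of swallowing: a silent failure leaks disk space.
            print(f"warning: failed to remove {old}: {exc}")
```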
@cicirori (Collaborator) commented:

> > Thanks for the thorough work! I happened to finish writing, and after publishing I came across this PR and made some changes based on it. #71
>
> Hi cicirori
>
> Can we merge this PR first? I believe @yubofredwang is reviewing this one.

sure, linter failed. can you fix that?

- ruff-format: reformat 10 DFlash files to match project style
- ruff: drop unused model vars, rename ambiguous l → loss in test_dflash
- ruff: mark conditional SGL/vLLM engine imports with noqa: F401
- end-of-file-fixer: add trailing newline to .gitignore
- check-shebang-scripts-are-executable: chmod +x two scripts/tools files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@cicirori cicirori merged commit 760a552 into torchspec-project:main Apr 13, 2026
1 check passed
@zhubohao911 (Contributor, author) commented:

> > > Thanks for the thorough work! I happened to finish writing, and after publishing I came across this PR and made some changes based on it. #71
> >
> > Hi cicirori
> > Can we merge this PR first? I believe @yubofredwang is reviewing this one.
>
> sure, linter failed. can you fix that?

Thank you so much!
