Improve DFlash training: performance and config-driven inference#74
Improve DFlash training: performance and config-driven inference#74
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9ee6964f4b
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| training_num_gpus_per_node: 3 | ||
| training_num_nodes: 1 |
There was a problem hiding this comment.
Fix GPU allocation values in repro config
This repro config claims an 8x B200 layout with 6 training GPUs and 2 inference GPUs (global_batch = 12), but the actual values configure only 3 training GPUs (training_num_gpus_per_node: 3, training_num_nodes: 1) and 1 inference GPU (inference_num_gpus: 1, tp_size: 1), so the run uses 4 GPUs total and a global batch of 6. That mismatch will underutilize hardware and makes the published reproduction numbers non-reproducible with this file as written.
Useful? React with 👍 / 👎.
torchspec/data/utils.py
Outdated
| if not has_target and not has_last_hs: | ||
| raise ValueError( | ||
| "Either 'target' or 'last_hidden_states' is required when 'hidden_states' is provided" | ||
| # DFlash uses CE loss from input_ids — no target/last_hs needed. |
There was a problem hiding this comment.
why change to warning here?
| pass | ||
| except ImportError as _e: | ||
| import logging as _logging | ||
|
|
There was a problem hiding this comment.
use repo level logger system
| except ImportError: | ||
| pass | ||
| except ImportError as _e: | ||
| import logging as _logging |
| pytorch_model_path = os.path.join(model_path, "pytorch_model.bin") | ||
| if os.path.exists(pytorch_model_path): | ||
| state_dict = torch.load(pytorch_model_path, map_location="cpu") | ||
| state_dict = torch.load( |
torchspec/train_entry.py
Outdated
| ): | ||
| from torchspec.models.draft.dflash import build_target_layer_ids | ||
| if isinstance(draft_model_config, DFlashConfig): | ||
| # Fail fast on unsupported configurations |
There was a problem hiding this comment.
can we extract this to a helper?
| # - 3 GPUs for training (FSDP FULL_SHARD) | ||
| # - global_batch = 1 × 4 × 3 = 12 | ||
| # GPU allocation (8x GPU): | ||
| # - 4 GPUs for inference (SGLang engine, tp_size=1, duplicate mode) |
There was a problem hiding this comment.
leave original config, but use correct comments
| def set_vocab_buffers(self, d2t, t2d) -> None: | ||
| self._trainer.draft_model.set_vocab_buffers(d2t, t2d) | ||
| if hasattr(self._trainer, "draft_model") and hasattr( | ||
| self._trainer.draft_model, "set_vocab_buffers" |
There was a problem hiding this comment.
Will there be silent failures here? dflash doesn't have this method, but if you want to support this kind of training, it'll just silently error out.
3975159 to
55637d0
Compare
55637d0 to
b97efb3
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b97efb351a
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| ds = ds.remove_columns(drop_cols) | ||
| return ds | ||
| except Exception: | ||
| except (ValueError, TypeError, ArithmeticError, KeyError) as e: |
There was a problem hiding this comment.
Catch Arrow/datasets errors when falling back to JSON loader
The new exception filter only catches built-in types, so HF schema-inference failures from load_dataset (commonly pyarrow/datasets exceptions for mixed-type columns) now bypass the fallback and abort loading instead of using _load_hub_json_files. This regresses the mixed-schema hub-dataset path this function is explicitly trying to support, and can break training runs that previously recovered by switching to raw JSON streaming.
Useful? React with 👍 / 👎.
| # provides target or last_hidden_states. Eagle3 requires at | ||
| # least one — its trainer accesses batch["last_hidden_states"] | ||
| # directly, which will raise a clear KeyError there. | ||
| pass |
There was a problem hiding this comment.
Keep validation for hidden-states batches in collator
Replacing the guard with pass allows batches that have hidden_states but neither target nor last_hidden_states to flow through all modes, and Eagle3 then fails later in _forward when it pads batch["last_hidden_states"] (which is None). This turns an immediate, actionable data-contract error into a delayed runtime crash, making malformed inference outputs harder to detect and debug during training.
Useful? React with 👍 / 👎.
b97efb3 to
35b2d06
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 35b2d06748
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
torchspec/train_entry.py
Outdated
| if not hasattr(args, "store_last_hidden_states"): | ||
| args.store_last_hidden_states = False |
There was a problem hiding this comment.
Disable Mooncake LHS storage for DFlash by default
_validate_and_configure_dflash never flips args.store_last_hidden_states because it gates on hasattr, but InferenceConfig now always defines this field (default True), so the branch is dead in normal config-driven runs. That means DFlash+SGLang continues storing last_hidden_states even though this path is intended to skip them, adding unnecessary Mooncake traffic/memory pressure and undermining the hotfix/perf goal unless users manually override the config. Consider setting the flag based on value/intent rather than attribute existence.
Useful? React with 👍 / 👎.
dff5a8d to
7acea31
Compare
Correctness fixes: - SglEngine: omit last_hidden_states from DFlash tensor schema, pass spec_training_store_last_hidden_states=False to prevent Mooncake storage leak (orphaned _lhs keys) - SGLang patch: add spec_training_store_last_hidden_states flag to ServerArgs + guard in _send_hidden_states_to_mooncake() - train_entry: move draft config parsing + DFlash validation before dataset loading for fail-fast on defer_tokenization, backend mismatch, and min_loss_tokens misconfiguration - DataCollator: replace hard ValueError with warnings.warn for missing target/last_hidden_states (DFlash doesn't need them) - trainer_actor: guard set_vocab_buffers with hasattr - DFlashTrainer: raise TypeError for unsupported config types - Checkpoint cleanup: move after save (prevent race condition), replace ignore_errors with logged warning - PrefetchedDataFetcher: preserve original traceback on re-raise - load_hf_dataset: narrow bare except to specific schema errors - torch.load: add weights_only=True for security Performance: - Remove torch.compile graph break in _sample_anchor_positions by eliminating .item() call - Use Python scalars in torch.where (avoid per-step CUDA tensor alloc) - Fix sglang_qwen3_8b_dflash.yaml: correct GPU allocation (8 GPU), switch to REPLICATE strategy Cleanup: - Remove dead code: _apply_rotary_pos_emb, get_default_dflash_aux_layer_ids, unused forward() params norm_weight/norm_eps - Add debug logging to silent ImportError in engine/__init__.py - Add configs/dflash_qwen3_8b_repro.yaml, scripts/bench_dflash_opts.py, scripts/eval_dflash.sh Validated: 62 tests passing, E2E training on 8xB200 (760K PerfectBlend, 3 epochs, 18.8h), gsm8k τ=5.55, math500 τ=3.25. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
7acea31 to
47b0475
Compare
Summary
Follow-up improvements for DFlash training (#64). Enhances correctness, training performance, and engine compatibility based on 8xB200 validation.
Highlights
Config-driven inference optimization: Added
inference.store_last_hidden_statesconfig field (defaulttrue). DFlash auto-sets it tofalsefor SGLang, skipping unnecessarylast_hidden_statesstorage in Mooncake. vLLM and other engines remain unaffected — no engine-specific code needed.Faster training: Removed
torch.compilegraph break in anchor sampling (.item()elimination), use Python scalars intorch.whereto avoid per-step CUDA tensor allocation. Combined withREPLICATEFSDP strategy, training reaches ~4.7 step/s on 8xB200.Earlier error detection: DFlash config validation (backend,
defer_tokenization,min_loss_tokens) now runs before dataset loading, catching misconfigurations before any async work begins.Changes
Inference
inference.store_last_hidden_statesconfig field — controls whetherlast_hidden_statesis included in Mooncake tensor schema and SGLang storagespec_training_store_last_hidden_statesflag toServerArgs, guard_lhsstorage pathstore_last_hidden_states=falseTraining
_validate_and_configure_dflash()helper for early config validationtorch.compilegraph break in_sample_anchor_positionstorch.where(avoid per-step CUDA tensor creation)Robustness
shutil.rmtree(ignore_errors=True)→ try/except with logged warningDFlashTrainer.init_modelraisesTypeErrorfor unsupported config typesset_vocab_buffersraisesAttributeErrorif called on incompatible draft modelPrefetchedDataFetcherpreserves original traceback on re-raiseload_hf_datasetnarrows bareexcept Exceptionto specific schema errorstorch.loadcalls addweights_only=TrueCleanup
_apply_rotary_pos_emb,get_default_dflash_aux_layer_ids, unusednorm_weight/norm_epsparamsValueErrorwith comment-only pass for missingtarget/last_hidden_states(DFlash doesn't need them; Eagle3 catches at trainer level)sglang_qwen3_8b_dflash.yamlcomments (correct GPU count and global batch size)configs/dflash_qwen3_8b_repro.yamlfor 4xB200 reproductionValidation
🤖 Generated with Claude Code