Improve DFlash training: performance and config-driven inference#74

Merged
cicirori merged 1 commit into main from feat/dflash-fixes on Apr 13, 2026
Conversation

cicirori (Collaborator) commented on Apr 13, 2026

Summary

Follow-up improvements for DFlash training (#64). Enhances correctness, training performance, and engine compatibility based on 8xB200 validation.

Highlights

Config-driven inference optimization: Added inference.store_last_hidden_states config field (default true). DFlash auto-sets it to false for SGLang, skipping unnecessary last_hidden_states storage in Mooncake. vLLM and other engines remain unaffected — no engine-specific code needed.

Faster training: Removed the torch.compile graph break in anchor sampling (eliminated the .item() call) and switched torch.where to Python scalars, avoiding a per-step CUDA tensor allocation. Combined with the REPLICATE FSDP strategy, training reaches ~4.7 steps/s on 8xB200.

Earlier error detection: DFlash config validation (backend, defer_tokenization, min_loss_tokens) now runs before dataset loading, catching misconfigurations before any async work begins.

Changes

Inference

  • Add inference.store_last_hidden_states config field — controls whether last_hidden_states is included in Mooncake tensor schema and SGLang storage
  • SGLang patch: add spec_training_store_last_hidden_states flag to ServerArgs, guard _lhs storage path
  • DFlash + SGLang auto-configures store_last_hidden_states=false
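The config-driven behavior can be pictured with a small sketch; the class and function names below are illustrative stand-ins, not the repo's actual wiring:

```python
from dataclasses import dataclass

# Hypothetical mirror of the new field (default true, per the PR).
@dataclass
class InferenceConfig:
    store_last_hidden_states: bool = True

def auto_configure_for_dflash(cfg: InferenceConfig, backend: str) -> InferenceConfig:
    # DFlash computes its loss from input_ids, so SGLang does not need to
    # push last_hidden_states into Mooncake; other engines keep the
    # default and need no engine-specific code.
    if backend == "sglang":
        cfg.store_last_hidden_states = False
    return cfg
```

Because the gate lives in config rather than in the engine, vLLM and other backends see no behavior change.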

Training

  • Extract _validate_and_configure_dflash() helper for early config validation
  • Move draft config parsing before dataset loading for fail-fast
  • Remove torch.compile graph break in _sample_anchor_positions
  • Use Python scalars in torch.where (avoid per-step CUDA tensor creation)
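The torch.where change boils down to passing a Python scalar instead of materializing a tensor each step. A minimal sketch (hypothetical function name; assumes PyTorch's scalar overload of torch.where, present in recent releases):

```python
import torch

def floor_scores(scores: torch.Tensor, floor: float) -> torch.Tensor:
    # The scalar `floor` goes straight into the kernel; the old pattern
    # torch.where(cond, scores, torch.tensor(floor, device=scores.device))
    # allocated a fresh CUDA tensor on every training step.
    return torch.where(scores > floor, scores, floor)
```

Dropping the .item() call in anchor sampling is the same idea in reverse: .item() forces a device-to-host sync, which torch.compile treats as a graph break.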

Robustness

  • Checkpoint cleanup moved after save (prevents data loss on save failure)
  • shutil.rmtree(ignore_errors=True) → try/except with logged warning
  • DFlashTrainer.init_model raises TypeError for unsupported config types
  • set_vocab_buffers raises AttributeError if called on incompatible draft model
  • PrefetchedDataFetcher preserves original traceback on re-raise
  • load_hf_dataset narrows bare except Exception to specific schema errors
  • torch.load calls add weights_only=True
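The traceback-preserving re-raise can be sketched in isolation (class name and structure here are illustrative, not the repo's actual PrefetchedDataFetcher):

```python
import sys
import threading

class PrefetchSketch:
    # Run a producer in a background thread; surface its result or error.
    def __init__(self, produce):
        self._produce = produce
        self._result = None
        self._exc_info = None
        self._thread = threading.Thread(target=self._run)
        self._thread.start()

    def _run(self):
        try:
            self._result = self._produce()
        except Exception:
            # Capture (type, value, traceback) where the error happened.
            self._exc_info = sys.exc_info()

    def get(self):
        self._thread.join()
        if self._exc_info is not None:
            _, exc, tb = self._exc_info
            # Re-raise with the worker's original frames, instead of a
            # bare `raise RuntimeError(str(exc))` that would lose them.
            raise exc.with_traceback(tb)
        return self._result
```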

Cleanup

  • Remove dead code: _apply_rotary_pos_emb, get_default_dflash_aux_layer_ids, unused norm_weight/norm_eps params
  • DataCollator: replace ValueError with comment-only pass for missing target/last_hidden_states (DFlash doesn't need them; Eagle3 catches at trainer level)
  • Fix sglang_qwen3_8b_dflash.yaml comments (correct GPU count and global batch size)
  • Add configs/dflash_qwen3_8b_repro.yaml for 4xB200 reproduction

Validation

  • 62 unit tests passing (B200 GPU)
  • E2E smoke test: 20 steps on 4xB200, checkpoint saved successfully
  • Full training: 760K PerfectBlend, 3 epochs, 18.8h on 8xB200
  • Benchmarks: gsm8k τ=5.55 (accuracy 94%), math500 τ=3.25

🤖 Generated with Claude Code


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9ee6964f4b

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +33 to +34
```yaml
training_num_gpus_per_node: 3
training_num_nodes: 1
```

P2: Fix GPU allocation values in repro config

This repro config claims an 8x B200 layout with 6 training GPUs and 2 inference GPUs (global_batch = 12), but the actual values configure only 3 training GPUs (training_num_gpus_per_node: 3, training_num_nodes: 1) and 1 inference GPU (inference_num_gpus: 1, tp_size: 1), so the run uses 4 GPUs total and a global batch of 6. That mismatch will underutilize hardware and makes the published reproduction numbers non-reproducible with this file as written.


```diff
-if not has_target and not has_last_hs:
-    raise ValueError(
-        "Either 'target' or 'last_hidden_states' is required when 'hidden_states' is provided"
+# DFlash uses CE loss from input_ids — no target/last_hs needed.
```
cicirori (author):

why change to warning here?

```diff
-    pass
+except ImportError as _e:
+    import logging as _logging
```

cicirori (author):

use repo level logger system

```diff
-except ImportError:
-    pass
+except ImportError as _e:
+    import logging as _logging
```
cicirori (author):

same above

```diff
 pytorch_model_path = os.path.join(model_path, "pytorch_model.bin")
 if os.path.exists(pytorch_model_path):
-    state_dict = torch.load(pytorch_model_path, map_location="cpu")
+    state_dict = torch.load(
```
cicirori (author):

why here?

```diff
 ):
     from torchspec.models.draft.dflash import build_target_layer_ids
+    if isinstance(draft_model_config, DFlashConfig):
+        # Fail fast on unsupported configurations
```
cicirori (author):

can we extract this to a helper?

```diff
-# - 3 GPUs for training (FSDP FULL_SHARD)
-# - global_batch = 1 × 4 × 3 = 12
+# GPU allocation (8x GPU):
+# - 4 GPUs for inference (SGLang engine, tp_size=1, duplicate mode)
```
cicirori (author):

leave original config, but use correct comments

```diff
 def set_vocab_buffers(self, d2t, t2d) -> None:
-    self._trainer.draft_model.set_vocab_buffers(d2t, t2d)
+    if hasattr(self._trainer, "draft_model") and hasattr(
+        self._trainer.draft_model, "set_vocab_buffers"
```
cicirori (author):

Could this fail silently? DFlash doesn't have this method, but if someone later wants to support that kind of training, the call will just silently no-op.

cicirori force-pushed the feat/dflash-fixes branch 2 times, most recently from 3975159 to 55637d0, on April 13, 2026 18:04
cicirori changed the title from "fix: DFlash correctness, performance, and code quality improvements" to "style: DFlash code style alignment" on Apr 13, 2026
cicirori force-pushed the feat/dflash-fixes branch from 55637d0 to b97efb3 on April 13, 2026 18:07

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review


Reviewed commit: b97efb351a


```diff
         ds = ds.remove_columns(drop_cols)
         return ds
-    except Exception:
+    except (ValueError, TypeError, ArithmeticError, KeyError) as e:
```

P1: Catch Arrow/datasets errors when falling back to JSON loader

The new exception filter only catches built-in types, so HF schema-inference failures from load_dataset (commonly pyarrow/datasets exceptions for mixed-type columns) now bypass the fallback and abort loading instead of using _load_hub_json_files. This regresses the mixed-schema hub-dataset path this function is explicitly trying to support, and can break training runs that previously recovered by switching to raw JSON streaming.
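One way to address this, sketched under assumptions (the repo's real function names and exact exception set are not confirmed by this PR), is to extend the narrowed tuple with the loader library's own error types:

```python
# Built-in schema-style errors, plus pyarrow's invalid-data error when the
# library is available; HF datasets commonly surfaces pyarrow errors for
# mixed-type columns.
SCHEMA_ERRORS: tuple = (ValueError, TypeError, ArithmeticError, KeyError)
try:
    import pyarrow
    SCHEMA_ERRORS = SCHEMA_ERRORS + (pyarrow.lib.ArrowInvalid,)
except ImportError:
    pass

def load_with_fallback(load_primary, load_json_fallback):
    # Fall back to raw JSON streaming only on schema-style failures;
    # anything else still propagates.
    try:
        return load_primary()
    except SCHEMA_ERRORS:
        return load_json_fallback()
```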


```diff
+# provides target or last_hidden_states. Eagle3 requires at
+# least one — its trainer accesses batch["last_hidden_states"]
+# directly, which will raise a clear KeyError there.
+pass
```

P2: Keep validation for hidden-states batches in collator

Replacing the guard with pass allows batches that have hidden_states but neither target nor last_hidden_states to flow through all modes, and Eagle3 then fails later in _forward when it pads batch["last_hidden_states"] (which is None). This turns an immediate, actionable data-contract error into a delayed runtime crash, making malformed inference outputs harder to detect and debug during training.


cicirori force-pushed the feat/dflash-fixes branch from b97efb3 to 35b2d06 on April 13, 2026 18:17
cicirori changed the title from "style: DFlash code style alignment" to "Improve DFlash training: correctness, performance, and config-driven inference" on Apr 13, 2026
cicirori changed the title from "Improve DFlash training: correctness, performance, and config-driven inference" to "Improve DFlash training: performance and config-driven inference" on Apr 13, 2026

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review


Reviewed commit: 35b2d06748


Comment on lines +222 to +223
```diff
+if not hasattr(args, "store_last_hidden_states"):
+    args.store_last_hidden_states = False
```

P2: Disable Mooncake LHS storage for DFlash by default

_validate_and_configure_dflash never flips args.store_last_hidden_states because it gates on hasattr, but InferenceConfig now always defines this field (default True), so the branch is dead in normal config-driven runs. That means DFlash+SGLang continues storing last_hidden_states even though this path is intended to skip them, adding unnecessary Mooncake traffic/memory pressure and undermining the hotfix/perf goal unless users manually override the config. Consider setting the flag based on value/intent rather than attribute existence.
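The dead branch Codex describes is easy to reproduce in isolation (illustrative names; the real args object is the repo's config, not this sketch):

```python
from dataclasses import dataclass

@dataclass
class Args:
    store_last_hidden_states: bool = True  # field always defined

def gate_on_hasattr(args: Args) -> Args:
    # Dead branch: a dataclass instance always has its declared fields,
    # so this never flips the default.
    if not hasattr(args, "store_last_hidden_states"):
        args.store_last_hidden_states = False
    return args

def gate_on_intent(args: Args, user_overrode: bool) -> Args:
    # Value/intent-based gating, as Codex suggests: flip the default
    # unless the user explicitly set the flag.
    if not user_overrode:
        args.store_last_hidden_states = False
    return args
```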


cicirori force-pushed the feat/dflash-fixes branch 2 times, most recently from dff5a8d to 7acea31, on April 13, 2026 18:29

Commit message:
Correctness fixes:
- SglEngine: omit last_hidden_states from DFlash tensor schema, pass
  spec_training_store_last_hidden_states=False to prevent Mooncake
  storage leak (orphaned _lhs keys)
- SGLang patch: add spec_training_store_last_hidden_states flag to
  ServerArgs + guard in _send_hidden_states_to_mooncake()
- train_entry: move draft config parsing + DFlash validation before
  dataset loading for fail-fast on defer_tokenization, backend mismatch,
  and min_loss_tokens misconfiguration
- DataCollator: replace hard ValueError with warnings.warn for missing
  target/last_hidden_states (DFlash doesn't need them)
- trainer_actor: guard set_vocab_buffers with hasattr
- DFlashTrainer: raise TypeError for unsupported config types
- Checkpoint cleanup: move after save (prevent race condition), replace
  ignore_errors with logged warning
- PrefetchedDataFetcher: preserve original traceback on re-raise
- load_hf_dataset: narrow bare except to specific schema errors
- torch.load: add weights_only=True for security

Performance:
- Remove torch.compile graph break in _sample_anchor_positions by
  eliminating .item() call
- Use Python scalars in torch.where (avoid per-step CUDA tensor alloc)
- Fix sglang_qwen3_8b_dflash.yaml: correct GPU allocation (8 GPU),
  switch to REPLICATE strategy

Cleanup:
- Remove dead code: _apply_rotary_pos_emb, get_default_dflash_aux_layer_ids,
  unused forward() params norm_weight/norm_eps
- Add debug logging to silent ImportError in engine/__init__.py
- Add configs/dflash_qwen3_8b_repro.yaml, scripts/bench_dflash_opts.py,
  scripts/eval_dflash.sh

Validated: 62 tests passing, E2E training on 8xB200 (760K PerfectBlend,
3 epochs, 18.8h), gsm8k τ=5.55, math500 τ=3.25.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
cicirori force-pushed the feat/dflash-fixes branch from 7acea31 to 47b0475 on April 13, 2026 18:31
cicirori merged commit 2893c2e into main on Apr 13, 2026 (1 check passed).
cicirori deleted the feat/dflash-fixes branch on April 13, 2026 18:38.