# Improve DFlash training: performance and config-driven inference (#74)
**New file** — DFlash reproduction config for Qwen3-8B (`@@ -0,0 +1,70 @@`):

```yaml
# DFlash reproduction config for Qwen3-8B on 4x B200
#
# GPU allocation (4x B200):
# - 1 GPU for inference (SGLang tp=1)
# - 3 GPUs for training (REPLICATE)
# - global_batch = micro_batch_size × dp_size × accum = 1 × 3 × 2 = 6

model:
  target_model_path: Qwen/Qwen3-8B
  trust_remote_code: true
  draft_model_config: torchspec/config/dflash_draft_config.json

dataset:
  train_data_path: data/perfectblend_800k.jsonl
  eval_data_path: null
  eval_interval: 500
  chat_template: qwen
  prompt_key: conversations
  min_loss_tokens: 32

training:
  attention_backend: flex_attention
  micro_batch_size: 1
  draft_accumulation_steps: 2
  learning_rate: 6e-4
  min_lr: 6e-5
  weight_decay: 0.01
  max_concurrent_batches: 1
  max_grad_norm: 1.0
  max_seq_length: 2048
  num_epochs: 3
  seed: 42
  training_num_gpus_per_node: 3
  training_num_nodes: 1
  ttt_length: 7
  fsdp_strategy: REPLICATE
  fsdp_reduce_dtype: bfloat16
  prefetch_depth: 8
  save_interval: 2000
  save_per_epoch: true
  max_checkpoints: 3
  warmup_ratio: 0.04

dflash_block_size: 16
dflash_num_anchors: 512
dflash_loss_decay_gamma: 7.0
dflash_num_target_layers: 5

inference:
  inference_engine_type: sgl
  store_last_hidden_states: false
  inference_num_gpus: 1
  inference_num_gpus_per_engine: 1
  inference_num_gpus_per_node: 4
  max_sample_pool_size: 128
  inference_buffer_threshold: 64
  inference_batch_size: 8
  sglang:
    tp_size: 1
    mem_fraction_static: 0.7

mooncake:
  master_server_address: null
  metadata_server: null
  protocol: tcp
  global_segment_size: 16GB
  local_buffer_size: 4GB

output_dir: ./outputs/dflash_repro
cache_dir: ./cache/dflash_repro
```
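The GPU split and effective batch size stated in the header comment follow directly from a handful of fields in this config. A quick arithmetic sanity check (values copied from the file above; deriving `dp_size` as GPUs-per-node × nodes is an assumption about how the trainer computes it under `REPLICATE`):

```python
# Values copied from the config above.
micro_batch_size = 1
training_num_gpus_per_node = 3
training_num_nodes = 1
draft_accumulation_steps = 2
inference_num_gpus = 1

# Assumed: data-parallel size spans all training GPUs under REPLICATE.
dp_size = training_num_gpus_per_node * training_num_nodes
global_batch = micro_batch_size * dp_size * draft_accumulation_steps
total_gpus = dp_size + inference_num_gpus

print(dp_size, global_batch, total_gpus)  # 3 6 4
```

This matches the header comment: 3 training GPUs, 1 inference GPU, and a global batch of 6 on a 4-GPU node.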
**`configs/sglang_qwen3_8b_dflash.yaml`** — header comment corrected to match the actual 8-GPU layout:

```diff
@@ -5,10 +5,10 @@
 # - weight_decay=0.01: AdamW regularization for generalization
 # - min_lr=6e-5: prevent LR death in later epochs (10% of peak)
 #
-# GPU allocation (4x H100):
-# - 1 GPU for inference (SGLang engine, tp_size=1)
-# - 3 GPUs for training (FSDP FULL_SHARD)
-# - global_batch = 1 × 4 × 3 = 12
+# GPU allocation (8x GPU):
+# - 4 GPUs for inference (SGLang engine, tp_size=1, duplicate mode)
+# - 4 GPUs for training (FSDP FULL_SHARD)
+# - global_batch = micro_batch_size × dp_size × accum = 1 × 4 × 2 = 8
 #
 # Usage:
 #   python -m torchspec.train_entry --config configs/sglang_qwen3_8b_dflash.yaml \
```

> **Author (collaborator):** leave the original config, but use correct comments
|
@@ -58,6 +58,7 @@ training: | |
|
|
||
| inference: | ||
| inference_engine_type: sgl | ||
| store_last_hidden_states: false | ||
| inference_num_gpus: 4 | ||
| inference_num_gpus_per_engine: 1 | ||
| inference_num_gpus_per_node: 4 | ||
|
|
||
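With `tp_size: 1` and one GPU per engine, the four inference GPUs run four duplicate SGLang engines. A small sketch of that arithmetic (the engine-count rule — total inference GPUs divided by GPUs per engine — is an assumption about how the launcher replicates engines in duplicate mode):

```python
# Values from the inference section above.
inference_num_gpus = 4
inference_num_gpus_per_engine = 1  # matches sglang tp_size: 1

# Assumed: engines are replicated across the inference GPUs.
assert inference_num_gpus % inference_num_gpus_per_engine == 0
num_engines = inference_num_gpus // inference_num_gpus_per_engine

print(num_engines)  # 4
```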
**Data collator** — the validation guard in `__call__` is replaced with `pass`:

```diff
@@ -120,9 +120,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
         has_target = all(item.get("target") is not None for item in features)
         has_last_hs = all(item.get("last_hidden_states") is not None for item in features)
         if not has_target and not has_last_hs:
-            raise ValueError(
-                "Either 'target' or 'last_hidden_states' is required when 'hidden_states' is provided"
-            )
+            pass
         if has_target:
             batch["target"] = torch.cat(
                 [self.paddingtensor(item["target"], max_length) for item in features]
```

> **Review comment:** Replacing the guard with `pass` silently accepts batches that provide neither `target` nor `last_hidden_states`, so the missing data surfaces later as a less informative downstream error instead of failing fast here.
|
@@ -459,7 +457,14 @@ def load_hf_dataset(data_path: str): | |
| if drop_cols: | ||
| ds = ds.remove_columns(drop_cols) | ||
| return ds | ||
| except Exception: | ||
| except (ValueError, TypeError, ArithmeticError, KeyError) as e: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The new exception filter only catches built-in types, so HF schema-inference failures from Useful? React with 👍 / 👎. |
||
| # Schema inference failures (e.g., mixed-type columns in Arrow/Parquet). | ||
| # Fall back to manual JSON download. | ||
| import logging | ||
|
|
||
| logging.getLogger(__name__).info( | ||
| f"load_dataset failed for '{data_path}' ({e}), falling back to JSON download" | ||
| ) | ||
| return IterableDataset.from_generator( | ||
| _load_hub_json_files, gen_kwargs={"data_path": data_path} | ||
| ) | ||
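The narrowed `except` tuple behaves differently for exceptions defined by a library. A standalone sketch of the concern (`SchemaError` is a hypothetical stand-in for a `datasets`-specific exception type, not a real class):

```python
class SchemaError(RuntimeError):
    """Hypothetical stand-in for a library-specific failure type."""


def load_with_fallback(exc_cls):
    """Mimic the try/except shape of load_hf_dataset with the narrowed filter."""
    try:
        raise exc_cls("schema inference failed")
    except (ValueError, TypeError, ArithmeticError, KeyError):
        return "json-fallback"  # the manual-download path in the real code


assert load_with_fallback(ValueError) == "json-fallback"  # built-in: caught

propagated = False
try:
    load_with_fallback(SchemaError)  # library-defined type: not in the tuple
except SchemaError:
    propagated = True
assert propagated  # escapes the handler instead of falling back
```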
> **Review comment:** This repro config claims an 8x B200 layout with 6 training GPUs and 2 inference GPUs (`global_batch = 12`), but the actual values configure only 3 training GPUs (`training_num_gpus_per_node: 3`, `training_num_nodes: 1`) and 1 inference GPU (`inference_num_gpus: 1`, `tp_size: 1`), so the run uses 4 GPUs total and a global batch of 6. That mismatch will underutilize hardware and makes the published reproduction numbers non-reproducible with this file as written.
global_batch = 12), but the actual values configure only 3 training GPUs (training_num_gpus_per_node: 3,training_num_nodes: 1) and 1 inference GPU (inference_num_gpus: 1,tp_size: 1), so the run uses 4 GPUs total and a global batch of 6. That mismatch will underutilize hardware and makes the published reproduction numbers non-reproducible with this file as written.Useful? React with 👍 / 👎.