Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions configs/dflash_qwen3_8b_repro.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# DFlash reproduction config for Qwen3-8B
#
# GPU allocation (matches the values below):
#   - 1 GPU for inference (SGLang, tp_size=1)
#   - 3 GPUs for training (FSDP REPLICATE)
#   - global_batch = micro_batch_size × dp × accum = 1 × 3 × 2 = 6
#
# NOTE(review): the original header claimed 2 inference GPUs (tp=2) and
# 6 training GPUs (global_batch 12) on 8x B200, which contradicted the
# settings below (inference_num_gpus=1, tp_size=1,
# training_num_gpus_per_node=3). Confirm which allocation was intended
# and adjust either the comment or the values.
# NOTE(review): indentation reconstructed from an indent-stripped paste;
# confirm the nesting of the dflash_* keys and the mooncake section
# against the schema expected by the loader.

model:
  target_model_path: Qwen/Qwen3-8B
  trust_remote_code: true
  draft_model_config: torchspec/config/dflash_draft_config.json

dataset:
  train_data_path: data/perfectblend_800k.jsonl
  eval_data_path: null
  eval_interval: 500
  chat_template: qwen
  prompt_key: conversations
  min_loss_tokens: 32

training:
  attention_backend: flex_attention
  micro_batch_size: 1
  draft_accumulation_steps: 2
  # Written as 6.0e-4 / 6.0e-5 (not 6e-4): PyYAML's YAML 1.1 float
  # resolver requires a dot before the exponent, so bare "6e-4" loads
  # as a *string*, not a float.
  learning_rate: 6.0e-4
  min_lr: 6.0e-5
  weight_decay: 0.01
  max_concurrent_batches: 1
  max_grad_norm: 1.0
  max_seq_length: 2048
  num_epochs: 3
  seed: 42
  training_num_gpus_per_node: 3
  training_num_nodes: 1
  ttt_length: 7
  fsdp_strategy: REPLICATE
  fsdp_reduce_dtype: bfloat16
  prefetch_depth: 8
  save_interval: 2000
  save_per_epoch: true
  max_checkpoints: 3
  warmup_ratio: 0.04

  # DFlash-specific parameters
  # NOTE(review): assumed to live under `training:` (mirrors the sibling
  # sglang_qwen3_8b_dflash.yaml layout) — verify against the config schema.
  dflash_block_size: 16
  dflash_num_anchors: 512
  dflash_loss_decay_gamma: 7.0
  dflash_num_target_layers: 5

inference:
  inference_engine_type: sgl
  inference_num_gpus: 1
  inference_num_gpus_per_engine: 1
  inference_num_gpus_per_node: 4
  max_sample_pool_size: 128
  inference_buffer_threshold: 64
  inference_batch_size: 8
  sglang:
    tp_size: 1
    mem_fraction_static: 0.7

# NOTE(review): assumed top-level (transfer-engine settings, not part of
# `inference:`) — confirm.
mooncake:
  master_server_address: null
  metadata_server: null
  protocol: tcp
  global_segment_size: 16GB
  local_buffer_size: 4GB

output_dir: ./outputs/dflash_repro
cache_dir: ./cache/dflash_repro
84 changes: 84 additions & 0 deletions configs/sglang_qwen3_8b_dflash.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# DFlash training config for Qwen3-8B target model
#
# Convergence-optimized config with WSD schedule recommendations:
#   - accum=2: 2x more optimizer steps for better convergence
#   - weight_decay=0.01: AdamW regularization for generalization
#   - min_lr=6e-5: prevent LR death in later epochs (10% of peak)
#
# GPU allocation (8x GPU):
#   - 4 GPUs for inference (SGLang engine, tp_size=1, duplicate mode)
#   - 4 GPUs for training (FSDP REPLICATE)
#   - global_batch = micro_batch_size × dp_size × accum = 1 × 4 × 2 = 8
#
# Usage:
#   python -m torchspec.train_entry --config configs/sglang_qwen3_8b_dflash.yaml \
#     dataset.train_data_path=/path/to/data.jsonl output_dir=./outputs/dflash
#
# NOTE(review): indentation reconstructed from an indent-stripped paste;
# confirm the nesting of the dflash_* keys and the mooncake section
# against the schema expected by the loader.

model:
  target_model_path: Qwen/Qwen3-8B
  trust_remote_code: true
  draft_model_config: torchspec/config/dflash_draft_config.json

dataset:
  train_data_path: ../examples/data/sample_conversations.jsonl
  eval_data_path: null
  eval_interval: 100
  chat_template: qwen
  prompt_key: conversations
  min_loss_tokens: 32

training:
  attention_backend: flex_attention
  micro_batch_size: 1
  draft_accumulation_steps: 2  # was 4 → 2x more optimizer steps
  # Written as 6.0e-4 / 6.0e-5 (not 6e-4): PyYAML's YAML 1.1 float
  # resolver requires a dot before the exponent, so bare "6e-4" loads
  # as a *string*, not a float.
  learning_rate: 6.0e-4
  min_lr: 6.0e-5  # 10% of peak — prevents LR death in later epochs
  weight_decay: 0.01  # AdamW regularization for better generalization
  max_concurrent_batches: 1
  max_grad_norm: 1.0
  max_seq_length: 2048
  num_epochs: 3
  seed: 42
  training_num_gpus_per_node: 4
  training_num_nodes: 1
  ttt_length: 7
  fsdp_strategy: REPLICATE
  prefetch_depth: 8
  save_interval: 1000
  save_per_epoch: true
  max_checkpoints: 2
  warmup_ratio: 0.04

  # DFlash-specific parameters
  # NOTE(review): assumed to live under `training:` — verify against the
  # config schema.
  dflash_block_size: 16
  dflash_num_anchors: 512
  dflash_loss_decay_gamma: 7.0
  dflash_num_target_layers: 5

inference:
  inference_engine_type: sgl
  inference_num_gpus: 4
  inference_num_gpus_per_engine: 1
  inference_num_gpus_per_node: 4
  max_sample_pool_size: 64
  inference_buffer_threshold: 32
  inference_batch_size: 8
  sglang:
    tp_size: 1
    mem_fraction_static: 0.7

# NOTE(review): assumed top-level (transfer-engine settings, not part of
# `inference:`) — confirm.
mooncake:
  master_server_address: null
  metadata_server: null
  protocol: tcp
  global_segment_size: 16GB
  local_buffer_size: 4GB

output_dir: ./outputs/qwen3-8b-dflash
cache_dir: ./cache/qwen3-8b-dflash
model_download_dir: null

debug:
  save_debug_train_data: null
  debug_train_only: false
  debug_inference_only: false
1 change: 1 addition & 0 deletions docker/justfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ build:
release-primary:
ARG_TAG_POSTFIX="${ARG_TAG_POSTFIX:-""}" ARG_BUILD_EXTRA_ARGS="" just _release-raw


_build-only:
#!/bin/bash
set -euxo pipefail
Expand Down
Loading
Loading