62 changes: 62 additions & 0 deletions records/track_non_record_16mb/2026-03-25-MaskedDiffusion/README.md
@@ -0,0 +1,62 @@
This is a non-record submission that replaces the autoregressive `train_gpt.py` baseline with a masked diffusion language model implemented in `train_mdlm.py`. The MDLM is from ["Simple and Effective Masked Diffusion Language Models"](https://arxiv.org/pdf/2406.07524), and the code is inspired by [that paper's repo](https://github.com/kuleshov-group/mdlm).

The model keeps much of the training stack from the original baseline, but swaps the causal next-token objective for a bidirectional masked denoising objective with iterative sampling. Since the addition of the conditioning weights pushes us over the 16MB limit, I adopt the common int6+int8 mixed-quantization, zstd-22 compression strategy from (#), so both models have 9 layers.
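For concreteness, here is a minimal sketch of the int6 + zstd-22 export idea, assuming per-tensor symmetric quantization and a simple 4-values-into-3-bytes packing; the actual export in `train_mdlm.py` may differ in layout and in where the int8 half of the mix is used:

```python
# Hypothetical sketch of the int6 + zstd-22 export path; the real
# train_mdlm.py export may pack and mix precisions differently.
import numpy as np
import zstandard as zstd  # pip install zstandard

def quantize_int6(w: np.ndarray):
    """Symmetric per-tensor quantization to 6-bit ints in [-31, 31]."""
    scale = np.abs(w).max() / 31.0
    q = np.clip(np.round(w / scale), -31, 31).astype(np.int8)
    return q, scale

def pack_int6(q: np.ndarray) -> bytes:
    """Pack 6-bit values four-at-a-time into 3 bytes, after shifting to [1, 63]."""
    u = (q.astype(np.int16) + 32).astype(np.uint8).ravel()
    u = np.pad(u, (0, (-len(u)) % 4))  # pad to a multiple of 4 values
    a, b, c, d = u[0::4], u[1::4], u[2::4], u[3::4]
    out = np.empty(3 * len(a), dtype=np.uint8)
    out[0::3] = (a << 2) | (b >> 4)
    out[1::3] = ((b & 0x0F) << 4) | (c >> 2)
    out[2::3] = ((c & 0x03) << 6) | d
    return out.tobytes()

w = np.random.randn(512, 512).astype(np.float32)
q, scale = quantize_int6(w)
payload = pack_int6(q)
compressed = zstd.ZstdCompressor(level=22).compress(payload)  # 22 is zstd's max standard level
print(f"packed: {len(payload)} B, compressed: {len(compressed)} B, scale: {scale:.5f}")
```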

## Config
- Tokenizer/data: reuses the FineWeb SP-1024 tokenizer; one extra \[MASK\] token is added, for a vocab size of 1025
- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2`; identical to the baseline
- Dropout 0: arguably less important in diffusion models, since the noise they already handle acts as a regularizer in much the same way dropout does.
- Attention: bidirectional transformer with GQA-style `NUM_KV_HEADS=4` and adaLN conditioning
- Conditioning: timestep-conditioned denoiser with a reduced internal conditioning width `cond_dim=max(model_dim//4, 64)`; see the sketch after this list
- Batch/sequence defaults: `TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=512`. Lower sequence length than the AR baseline because it's a bidirectional model
- Sampling defaults: `SAMPLER=ddpm_cache SAMPLING_SCHEDULE=linear SAMPLING_STEPS=256`. N.B. there is probably lots of fun to be had with the sampling schedule!
- Variational eval: `VAR_EVAL_STEPS=32`. I'm interested in whether using more eval steps gives better performance, which would be a kind of test-time compute. Validation takes ~2 min on 8xH100. Running at `VAR_EVAL_STEPS=128` gets our upper bound down to 1.59, but at the cost of an 8-minute evaluation, which would make the submission invalid.
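As promised above, a rough sketch of the adaLN timestep conditioning at reduced width; the module and variable names are mine, and this illustrates the general technique rather than the exact `train_mdlm.py` implementation:

```python
# Hypothetical sketch of adaLN timestep conditioning with a reduced
# conditioning width; names are illustrative, not from train_mdlm.py.
import torch
import torch.nn as nn

MODEL_DIM = 512
COND_DIM = max(MODEL_DIM // 4, 64)  # = 128, matching cond_dim:128 in the log

class AdaLN(nn.Module):
    """LayerNorm whose scale and shift are regressed from a timestep embedding."""
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * dim)
        nn.init.zeros_(self.to_scale_shift.weight)  # start out as a plain LayerNorm
        nn.init.zeros_(self.to_scale_shift.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, dim); cond: (batch, cond_dim) timestep embedding
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

x = torch.randn(4, 512, MODEL_DIM)
cond = torch.randn(4, COND_DIM)
print(AdaLN(MODEL_DIM, COND_DIM)(x, cond).shape)  # torch.Size([4, 512, 512])
```

Keeping the conditioning pathway at `model_dim//4` rather than full width is what limits the extra parameter cost of the conditioning weights.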

## Metrics
- `val_loss` is the continuous-time SUBS denoising objective used for training.
- `val_var_bpb` is the compression-facing metric for this folder. It is a byte-normalized variational upper bound on NLL obtained by discretizing the same absorbing-mask process at evaluation time.

### Variational BPB

The variational BPB reported here is not apples-to-apples comparable with the validation BPB from the autoregressive models, which makes this a particularly special non-record submission. The variational metric was added under duress, because there is no perfect analogue of the autoregressive models' losses in the diffusion regime:
- A masked diffusion model does not natively provide an exact autoregressive factorization of `p(x)` token-by-token, so the training loss is not directly convertible to an exact codelength.
- Obtaining the exact codelength for the continuous-time process would require integrating over latent corruption trajectories, which is not tractable within our evaluation budget.
- To make compression more comparable with the AR baselines, eval instead reports a discrete absorbing-mask variational bound (its structure is sketched after this list):
- terminal KL term `KL(q(x_T | x_0) || p(x_T))`
- plus a sum of reverse-process KL terms across `VAR_EVAL_STEPS`
- This is still an upper bound rather than an exact BPB, but it is much more principled than simply converting the denoising loss into BPB units as if it were analogous to cross-entropy.
- This also lets us measure the impact of discretization as a form of test-time compute by varying `VAR_EVAL_STEPS`, which anecdotally has a meaningful effect on the metric: running at `VAR_EVAL_STEPS=128` gets our upper bound down to 1.59, at the cost of an 8-minute evaluation.
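Schematically, the discretized bound has the standard two-part shape (a sketch of the structure described above with `T = VAR_EVAL_STEPS`, not a transcription of the eval code; base-2 logs assumed for the bits normalization):

```latex
% T = VAR_EVAL_STEPS; the total is divided by the sequence's byte count
% to give bits per byte.
-\log_2 p(x_0) \;\le\;
\underbrace{\mathrm{KL}\big(q(x_T \mid x_0)\,\|\,p(x_T)\big)}_{\text{terminal term}}
\;+\;
\underbrace{\sum_{t=1}^{T} \mathbb{E}_{q}\,
  \mathrm{KL}\big(q(x_{t-1} \mid x_t, x_0)\,\|\,p_\theta(x_{t-1} \mid x_t)\big)}_{\text{reverse-process terms}}
```

Under the absorbing-mask process, each reverse-process KL essentially reduces to a weighted cross-entropy over the positions unmasked at that step, which is what keeps the bound cheap to evaluate.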

Command:
```bash
RUN_ID=baseline_mdlm DATA_PATH=./data/datasets/fineweb10B_sp1024/ TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VAL_LOSS_EVERY=0 VAR_EVAL_STEPS=32 COND_DIM=128 torchrun --standalone --nproc_per_node=8 records/track_non_record_16mb/2026-03-25-MaskedDiffusion/train_mdlm.py
```

Recommended knobs to play with:
- `TRAIN_SEQ_LEN`: diffusion currently defaults to `512` rather than the AR baseline's `1024` because shorter windows improve throughput and increase the number of independent timestep samples per step, and in theory should not harm diffusion models as much.
- `SAMPLING_STEPS` / `SAMPLING_SCHEDULE`: test-time compute knobs for generation; they do not affect `val_var_bpb`, but could be cool to play with and visualize (see the sampling sketch after this list)
- `VAR_EVAL_STEPS`: tighter but slower variational evaluation.
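For intuition about those sampling knobs, here is a minimal sketch of absorbing-state sampling under a linear schedule; it illustrates the generic unmasking loop, not the `ddpm_cache` implementation, and the `model(x, t)` signature is an assumption:

```python
# Hypothetical sketch of masked (absorbing-state) diffusion sampling with a
# linear schedule: start fully masked and reveal more positions each step.
import torch

@torch.no_grad()
def sample(model, seq_len=256, steps=256, mask_index=1024, device="cpu"):
    x = torch.full((1, seq_len), mask_index, dtype=torch.long, device=device)
    for i in range(steps):
        t = 1.0 - i / steps          # current noise level (1.0 = fully masked)
        s = 1.0 - (i + 1) / steps    # next, slightly lower noise level
        logits = model(x, torch.tensor([t], device=device))  # assumed signature
        logits[..., mask_index] = float("-inf")  # never emit the mask token itself
        proposal = torch.distributions.Categorical(logits=logits).sample()
        # Under a linear schedule a still-masked position is revealed with
        # probability (t - s) / t; at the final step this hits 1.0, so every
        # remaining mask gets filled.
        reveal = (x == mask_index) & (torch.rand(x.shape, device=device) < (t - s) / t)
        x = torch.where(reveal, proposal, x)  # revealed tokens stay fixed afterwards
    return x
```

As I understand the `ddpm_cache` variant, it additionally reuses the denoiser output across steps where no position changed, which is where its speedup over plain ddpm sampling comes from.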

Things that really didn't work:
- This is the best method from a bad crop of diffusion LM methods; I have also implemented a continuous diffusion model à la DiffusionLM, which I may push later, but it sucks even more
- Don't attempt to tie the embedding and head weights: the loss spikes, since they're not doing a symmetric task.

This folder is a proof-of-concept diffusion adaptation rather than a final tuned submission. With enough work, this could plausibly compete with the very worst autoregressive approaches. I don't care to do that, though, because I don't really find diffusion LMs that cool.

Files in this folder:
- `train_mdlm.py` - single-file masked diffusion training/eval script
- `train.log` - training log on Hyperbolic 8xH100
- `submission.json`
- `README.md`

Metrics:
- best pre-quant `val_loss`: 2.6564 (step:13818/20000)
- best pre-quant `val_var_bpb`: 1.6259
- post-quant roundtrip `val_loss`: 2.6553
- post-quant roundtrip `val_var_bpb`: 1.6252
- final_quant_roundtrip_exact `val_loss`: 2.65530960
- final_quant_roundtrip_exact `val_var_bpb`: 1.62519980
- step time / wallclock: 600032 ms total for 13818 steps (step avg: 43.42 ms)
- compressed artifact size: 15,313,980 bytes (int6+zstd22, payload: 21,537,286, raw_torch: 21,589,365, payload_ratio: 3.91x)
- total submission size int6+zstd22: 15,379,114 bytes

17 changes: 17 additions & 0 deletions records/track_non_record_16mb/2026-03-25-MaskedDiffusion/submission.json
@@ -0,0 +1,17 @@
{
"author": "Spruce Campbell",
"github_id": "mtybadger",
"name": "Masked Diffusion",
"blurb": "replaces the autoregressive `train_gpt.py` baseline with a masked diffusion language model implemented in`train_mdlm.py`. The MDLM is from 'Simple and Effective Masked Diffusion Language Models'",
"date": "2026-03-26T06:19:23Z",
"track": "non-record-unlimited-compute-16mb",
"val_loss": 2.65530960,
"val_var_bpb": 1.62519980,
"pre_quant_val_loss": 2.6564,
"pre_quant_val_var_bpb": 1.6259,
"step_stop": 13818,
"wallclock_seconds": 600.032,
"bytes_total": 15379114,
"bytes_model_int6_zstd22": 15313980,
"bytes_code": 65134
}
129 changes: 129 additions & 0 deletions records/track_non_record_16mb/2026-03-25-MaskedDiffusion/train.log
@@ -0,0 +1,129 @@
(parameter-golf) ubuntu@g139:~/parameter-golf$ RUN_ID=baseline_mdlm DATA_PATH=./data/datasets/fineweb10B_sp1024/ TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VAL_LOSS_EVERY=0 VAR_EVAL_STEPS=32 COND_DIM=128 torchrun --standalone --nproc_per_node=8 records/track_non_record_16mb/2026-03-25-MaskedDiffusion/train_mdlm.py
W0326 06:16:51.355000 82012 torch/distributed/run.py:852]
W0326 06:16:51.355000 82012 torch/distributed/run.py:852] *****************************************
W0326 06:16:51.355000 82012 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0326 06:16:51.355000 82012 torch/distributed/run.py:852] *****************************************
logs/baseline_mdlm.txt
val_var_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model var_eval_steps:32
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:21327617
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:bidirectional_gqa num_heads:8 num_kv_heads:4 cond_dim:128
base_vocab_size:1024 model_vocab_size:1025 mask_index:1024
embed_lr:0.6 head_lr:0.008 matrix_lr:0.04 scalar_lr:0.04
noise:loglinear noise_eps:0.001 sampling_eps:0.001 sampler:ddpm_cache sampling_schedule:linear sampling_steps:256 var_eval_steps:32 time_conditioning:1 antithetic:1
export:int6_weights:1 qat:1 use_zstd:1 zstd_level:22
train_batch_tokens:524288 train_seq_len:512 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:1/20000 train_loss:6.8238 train_time:94ms step_avg:93.79ms
step:2/20000 train_loss:7.3788 train_time:118ms step_avg:58.87ms
step:3/20000 train_loss:9.4127 train_time:161ms step_avg:53.62ms
step:4/20000 train_loss:6.3292 train_time:204ms step_avg:51.08ms
step:5/20000 train_loss:7.1945 train_time:248ms step_avg:49.54ms
step:6/20000 train_loss:7.1751 train_time:291ms step_avg:48.49ms
step:7/20000 train_loss:6.8156 train_time:334ms step_avg:47.76ms
step:8/20000 train_loss:6.5614 train_time:378ms step_avg:47.24ms
step:9/20000 train_loss:6.4500 train_time:421ms step_avg:46.82ms
step:10/20000 train_loss:6.4149 train_time:465ms step_avg:46.47ms
step:200/20000 train_loss:3.7389 train_time:8730ms step_avg:43.65ms
step:400/20000 train_loss:3.2580 train_time:17434ms step_avg:43.58ms
step:600/20000 train_loss:3.3925 train_time:26132ms step_avg:43.55ms
step:800/20000 train_loss:3.2095 train_time:34833ms step_avg:43.54ms
step:1000/20000 train_loss:3.1363 train_time:43529ms step_avg:43.53ms
step:1200/20000 train_loss:3.2055 train_time:52222ms step_avg:43.52ms
step:1400/20000 train_loss:3.1691 train_time:60913ms step_avg:43.51ms
step:1600/20000 train_loss:2.9069 train_time:69598ms step_avg:43.50ms
step:1800/20000 train_loss:2.9766 train_time:78278ms step_avg:43.49ms
step:2000/20000 train_loss:2.9621 train_time:86954ms step_avg:43.48ms
step:2200/20000 train_loss:2.8625 train_time:95635ms step_avg:43.47ms
step:2400/20000 train_loss:2.8688 train_time:104317ms step_avg:43.47ms
step:2600/20000 train_loss:2.9924 train_time:112998ms step_avg:43.46ms
step:2800/20000 train_loss:2.9578 train_time:121691ms step_avg:43.46ms
step:3000/20000 train_loss:2.9934 train_time:130369ms step_avg:43.46ms
step:3200/20000 train_loss:2.8874 train_time:139050ms step_avg:43.45ms
step:3400/20000 train_loss:2.8836 train_time:147731ms step_avg:43.45ms
step:3600/20000 train_loss:2.9650 train_time:156419ms step_avg:43.45ms
step:3800/20000 train_loss:2.8891 train_time:165108ms step_avg:43.45ms
step:4000/20000 train_loss:2.8482 train_time:173795ms step_avg:43.45ms
step:4200/20000 train_loss:2.8927 train_time:182530ms step_avg:43.46ms
step:4400/20000 train_loss:2.8284 train_time:191217ms step_avg:43.46ms
step:4600/20000 train_loss:2.8546 train_time:199904ms step_avg:43.46ms
step:4800/20000 train_loss:2.8326 train_time:208586ms step_avg:43.46ms
step:5000/20000 train_loss:2.8753 train_time:217280ms step_avg:43.46ms
step:5200/20000 train_loss:2.7950 train_time:225966ms step_avg:43.45ms
step:5400/20000 train_loss:2.8134 train_time:234649ms step_avg:43.45ms
step:5600/20000 train_loss:2.9530 train_time:243337ms step_avg:43.45ms
step:5800/20000 train_loss:2.8332 train_time:252023ms step_avg:43.45ms
step:6000/20000 train_loss:2.7806 train_time:260707ms step_avg:43.45ms
step:6200/20000 train_loss:2.9239 train_time:269394ms step_avg:43.45ms
step:6400/20000 train_loss:2.7349 train_time:278086ms step_avg:43.45ms
step:6600/20000 train_loss:2.8448 train_time:286782ms step_avg:43.45ms
step:6800/20000 train_loss:2.5784 train_time:295475ms step_avg:43.45ms
step:7000/20000 train_loss:2.7528 train_time:304158ms step_avg:43.45ms
step:7200/20000 train_loss:2.8779 train_time:312834ms step_avg:43.45ms
step:7400/20000 train_loss:2.6941 train_time:321509ms step_avg:43.45ms
step:7600/20000 train_loss:2.5846 train_time:330186ms step_avg:43.45ms
step:7800/20000 train_loss:2.7477 train_time:338900ms step_avg:43.45ms
step:8000/20000 train_loss:2.6850 train_time:347592ms step_avg:43.45ms
step:8200/20000 train_loss:2.7181 train_time:356279ms step_avg:43.45ms
step:8400/20000 train_loss:2.6815 train_time:365007ms step_avg:43.45ms
step:8600/20000 train_loss:2.8039 train_time:373679ms step_avg:43.45ms
step:8800/20000 train_loss:2.7745 train_time:382356ms step_avg:43.45ms
step:9000/20000 train_loss:2.6659 train_time:391032ms step_avg:43.45ms
step:9200/20000 train_loss:2.6866 train_time:399712ms step_avg:43.45ms
step:9400/20000 train_loss:2.7190 train_time:408385ms step_avg:43.45ms
step:9600/20000 train_loss:2.9681 train_time:417057ms step_avg:43.44ms
step:9800/20000 train_loss:2.7854 train_time:425727ms step_avg:43.44ms
step:10000/20000 train_loss:2.6448 train_time:434402ms step_avg:43.44ms
step:10200/20000 train_loss:2.8168 train_time:443076ms step_avg:43.44ms
step:10400/20000 train_loss:2.5484 train_time:451751ms step_avg:43.44ms
step:10600/20000 train_loss:2.6880 train_time:460430ms step_avg:43.44ms
step:10800/20000 train_loss:2.7680 train_time:469095ms step_avg:43.43ms
step:11000/20000 train_loss:2.7202 train_time:477753ms step_avg:43.43ms
step:11200/20000 train_loss:2.7286 train_time:486418ms step_avg:43.43ms
step:11400/20000 train_loss:2.6380 train_time:495095ms step_avg:43.43ms
step:11600/20000 train_loss:2.6872 train_time:503777ms step_avg:43.43ms
step:11800/20000 train_loss:2.7488 train_time:512453ms step_avg:43.43ms
step:12000/20000 train_loss:2.7631 train_time:521120ms step_avg:43.43ms
step:12200/20000 train_loss:2.7447 train_time:529792ms step_avg:43.43ms
step:12400/20000 train_loss:2.6834 train_time:538514ms step_avg:43.43ms
step:12600/20000 train_loss:2.6894 train_time:547192ms step_avg:43.43ms
step:12800/20000 train_loss:2.6447 train_time:555870ms step_avg:43.43ms
step:13000/20000 train_loss:2.7812 train_time:564563ms step_avg:43.43ms
step:13200/20000 train_loss:2.6591 train_time:573233ms step_avg:43.43ms
step:13400/20000 train_loss:2.5699 train_time:581899ms step_avg:43.43ms
step:13600/20000 train_loss:2.6995 train_time:590568ms step_avg:43.42ms
step:13800/20000 train_loss:2.5933 train_time:599237ms step_avg:43.42ms

step:13818/20000 val_loss:2.6564 val_var_bpb:1.6259 train_time:600032ms step_avg:43.42ms
stopping_early: wallclock_cap train_time:600032ms step:13818/20000
peak memory allocated: 9681 MiB reserved: 10488 MiB
Serialized model: 84298718 bytes
Code size: 65134 bytes
Total submission size: 84363852 bytes
Serialized model int6+zstd22: 15313980 bytes (payload:21537286 raw_torch:21589365 payload_ratio:3.91x)
Total submission size int6+zstd22: 15379114 bytes
final_quant_roundtrip val_loss:2.6553 val_var_bpb:1.6252 eval_time:132549ms
final_quant_roundtrip_exact val_loss:2.65530960 val_var_bpb:1.62519980