
Non-record: Byte-level transformer + JEPA auxiliary loss (val_bpb: 1.1903)#832

Open
jfprincz wants to merge 1 commit into openai:main from jfprincz:submission/byte-jepa-compression-1.1903

Conversation

@jfprincz
Contributor


val_bpb: 1.1903 (sliding window, stride=512) | 14.4 MB | 8xH100 SXM, 600s

Byte-level autoregressive transformer (vocab 260, no tokenizer) with a lightweight JEPA auxiliary loss that contributes ~0.1% of the gradient signal at its peak weight (λ_max=0.001). Beats the sp1024 baseline (1.2244) by 0.034 BPB.

Ablation: JEPA contribution

| | Without JEPA | With JEPA | Delta |
| --- | --- | --- | --- |
| Int6 sliding s512 | 1.2006 | 1.1905 | -0.0101 |
| Step time | 60ms | 63ms | +3ms |
| Params | 24.2M | 24.6M | +459K |

JEPA yields a 0.01 BPB improvement at ~5% step-time overhead. The improvement is consistent across seeds and evaluation methods (pre-quant, post-quant, sliding).

Architecture

13-layer byte-level autoregressive transformer (vocab=260, no BPE/SentencePiece). The primary objective is standard next-byte CE loss. A lightweight JEPA module predicts chunk-level latent representations as an auxiliary signal (λ_max=0.001), improving val_bpb by 0.01 over the pure-AR baseline. Chunk prediction inspired by LeWM.

| Component | Detail |
| --- | --- |
| Backbone | 13L, dim=512, 8H/4KV GQA, MLP 2x, LeakyReLU(0.5)², U-Net skips |
| JEPA projector | Linear(512,256) → RMSNorm → SiLU → Linear(256,256) |
| JEPA predictor | 2-layer MLP, 256d, causal shift with learned start token |
| JEPA injection | Linear(256,512), zero-init, adds predicted latents to residual stream |
| SIGReg | Epps-Pulley, 256 projections, 17 knots; prevents latent collapse |
| Training | Phased: 30% pure AR, 50% AR+JEPA ramp, 20% pure AR |
| Loss | CE + λ(MSE_pred + 0.02·SIGReg), λ ramps 0→0.001 |
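The JEPA rows above can be sketched roughly as follows. This is a minimal illustration, not the submission's code: the chunk-mean pooling, the stop-gradient on targets, and all module names are assumptions; only the layer shapes come from the table.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    # Minimal RMSNorm (stand-in for the norm named in the table).
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class JEPAHead(nn.Module):
    """Sketch of the JEPA side branch: chunk-level latents are projected,
    a small predictor guesses chunk t's latent from chunks < t (causal
    shift via a learned start token), and a zero-init linear maps the
    predictions back to the residual stream."""
    def __init__(self, model_dim=512, latent_dim=256):
        super().__init__()
        self.projector = nn.Sequential(
            nn.Linear(model_dim, latent_dim), RMSNorm(latent_dim),
            nn.SiLU(), nn.Linear(latent_dim, latent_dim))
        self.predictor = nn.Sequential(  # 2-layer MLP, 256d
            nn.Linear(latent_dim, latent_dim), nn.SiLU(),
            nn.Linear(latent_dim, latent_dim))
        self.start = nn.Parameter(torch.zeros(1, 1, latent_dim))
        self.inject = nn.Linear(latent_dim, model_dim)  # zero-init injection
        nn.init.zeros_(self.inject.weight)
        nn.init.zeros_(self.inject.bias)

    def forward(self, h, chunk_size=8):
        B, T, D = h.shape
        # Chunk-mean pooling of hidden states is an assumption.
        chunks = h.view(B, T // chunk_size, chunk_size, D).mean(dim=2)
        z = self.projector(chunks)                        # target latents
        shifted = torch.cat([self.start.expand(B, -1, -1), z[:, :-1]], dim=1)
        z_pred = self.predictor(shifted)                  # causal prediction
        # Stop-gradient on targets is a common JEPA choice (assumed here).
        mse = F.mse_loss(z_pred, z.detach())
        return mse, self.inject(z_pred)                   # loss, residual add
```

Because the injection layer is zero-initialized, the branch perturbs the backbone by nothing at step 0 and only gradually learns to contribute.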

Carried from our sp1024 stack: Muon+WD=0.04, EMA 0.997, XSA last 4 layers, Partial RoPE 16 dims, LN Scale, SmearGate, BigramHash(4096,32), OrthoInit+muP, int6+zstd-22, FA3.
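The SIGReg row in the table (Epps-Pulley, 256 projections, 17 knots) suggests a characteristic-function normality penalty on projected latents. A hedged sketch of that idea, where knot placement, weighting, and standardization are my assumptions and only the projection/knot counts come from the table:

```python
import torch

def sigreg(z, num_proj=256, num_knots=17):
    """Epps-Pulley-style penalty (sketch): project latents onto random
    unit directions, then penalize the gap between the empirical
    characteristic function of each standardized 1-D projection and
    exp(-t^2/2), the CF of N(0,1), evaluated at a grid of knots.
    Matching a Gaussian in every projection discourages collapse."""
    B, D = z.shape
    a = torch.randn(D, num_proj, device=z.device)
    a = a / a.norm(dim=0, keepdim=True)            # random unit directions
    x = z @ a                                      # (B, num_proj)
    x = (x - x.mean(0)) / (x.std(0) + 1e-6)        # standardize projections
    t = torch.linspace(0.1, 3.0, num_knots, device=z.device)  # knots (assumed)
    tx = x.unsqueeze(-1) * t                       # (B, num_proj, num_knots)
    re, im = tx.cos().mean(0), tx.sin().mean(0)    # empirical CF per knot
    gauss = (-0.5 * t * t).exp()                   # CF of a standard Gaussian
    return ((re - gauss) ** 2 + im ** 2).mean()
```

Gaussian latents score near zero; degenerate (e.g. two-point) latents score high, which is what makes this usable as a collapse regularizer.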

Results

| Metric | Value |
| --- | --- |
| Pre-quant val_bpb | 1.2293 |
| Int6 roundtrip val_bpb | 1.2184 |
| Int6 sliding val_bpb (s512) | 1.1905 |
| Steps completed | 9,000 |
| Step time | 63ms |
| Model params | 24,625,001 |
| Artifact size | 14,182,907 bytes |
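The sliding-window numbers above come from stride-512 evaluation; a minimal sketch of that protocol, assuming `model` maps `(1, L)` token IDs to `(1, L, vocab)` logits (the repo's exact eval code may differ):

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_bpb(model, tokens, seq_len=4096, stride=512):
    """Sliding-window eval (sketch, assumes stride <= seq_len): each
    window spans up to seq_len tokens, but only targets not scored by
    an earlier window contribute, so most positions are predicted with
    long left context."""
    n = tokens.numel()
    total_nll, count, next_tgt = 0.0, 0, 1   # next unscored target index
    for begin in range(0, n - 1, stride):
        end = min(begin + seq_len, n)
        x = tokens[begin:end]
        logits = model(x.unsqueeze(0)).squeeze(0)
        nll = F.cross_entropy(logits[:-1], x[1:], reduction="none")
        skip = next_tgt - (begin + 1)        # targets already scored
        total_nll += nll[skip:].sum().item()
        count += nll.numel() - skip
        next_tgt = end
        if end == n:
            break
    return total_nll / count / math.log(2)   # nats/byte -> bits/byte
```

A model emitting uniform logits over the 260-symbol vocab should score exactly log2(260) ≈ 8.02 BPB under this routine, which is a quick sanity check.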

Reproducibility (3 seeds)

| Seed | Steps | Sliding s512 | Artifact |
| --- | --- | --- | --- |
| 2025 | 9,000 | 1.1903 | 14,369,791 |
| 42 | 9,000 | 1.1905 | 14,182,907 |
| 7 | 9,000 | 1.1915 | 14,445,175 |

Mean: 1.1908 | Range: 0.0012 | Submitted: seed 2025

Run command

```bash
NUM_LAYERS=13 VOCAB_SIZE=260 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2.0 \
TRAIN_SEQ_LEN=4096 TRAIN_BATCH_TOKENS=393216 BIGRAM_VOCAB_SIZE=4096 BIGRAM_DIM=32 \
XSA_LAST_N=4 EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 \
ROPE_DIMS=16 LN_SCALE=1 MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500 \
WARMDOWN_ITERS=3000 ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 \
EVAL_STRIDE=512 JEPA_CHUNK_SIZE=8 JEPA_LATENT_DIM=256 JEPA_PROJ_HIDDEN=256 \
JEPA_LAMBDA_MAX=0.001 JEPA_SIGREG_WEIGHT=0.02 JEPA_LR=0.001 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
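One plausible reading of the phased schedule (30% pure AR, 50% AR+JEPA ramp to `JEPA_LAMBDA_MAX=0.001`, final 20% pure AR) as code; the linear ramp shape and hard phase boundaries are my assumptions:

```python
def jepa_lambda(step, total_steps=9000, lam_max=0.001):
    """lambda = 0 for the first 30% of steps (pure AR), a linear ramp
    0 -> lam_max over the middle 50%, and 0 again for the final 20%
    (pure AR). Exact ramp shape is assumed, not stated in the PR."""
    frac = step / total_steps
    if frac < 0.30 or frac >= 0.80:
        return 0.0
    return lam_max * (frac - 0.30) / 0.50

def total_loss(ce, mse_pred, sigreg_val, step, total_steps=9000):
    # Loss row from the architecture table: CE + lambda*(MSE_pred + 0.02*SIGReg).
    return ce + jepa_lambda(step, total_steps) * (mse_pred + 0.02 * sigreg_val)
```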

Data

Uses fineweb10B_byte260: raw UTF-8 bytes tokenized with byte_offset=4 (IDs 4-259 = byte values 0-255). Converted from sp1024 shards via lookup-table decode. No SentencePiece dependency at runtime. Since each token is exactly one byte, BPB = loss / ln(2); no tokenizer correction needed.
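The byte260 mapping can be sketched as follows. What IDs 0-3 are reserved for is not stated above, so treating them as specials is an assumption; the offset and BPB identity are from the text.

```python
import math

BYTE_OFFSET = 4  # IDs 4-259 cover byte values 0-255; IDs 0-3 reserved
                 # (assumed to be special tokens; not stated in the PR)

def encode_bytes(text: str) -> list[int]:
    """Raw UTF-8 bytes -> token IDs: vocab 260, no tokenizer."""
    return [b + BYTE_OFFSET for b in text.encode("utf-8")]

def decode_bytes(ids: list[int]) -> str:
    return bytes(i - BYTE_OFFSET for i in ids if i >= BYTE_OFFSET).decode(
        "utf-8", errors="replace")

def bpb(ce_nats: float) -> float:
    # One token == one byte, so bits-per-byte is just CE in nats / ln 2.
    return ce_nats / math.log(2)
```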

