Fused, constant-memory KL divergence from hidden states for knowledge distillation. [blog post]
Nitrobrew computes the KL divergence between student and teacher softmax distributions without materializing the full [B, T, V] logit tensor. For a 152k vocabulary, this avoids allocating ~2.3 GiB per 8k-token batch for each bf16 logit tensor (8192 × 152064 × 2 bytes), the dominant memory bottleneck in on-policy distillation. Supports heterogeneous student/teacher dimensions (
Standard distillation computes logits z = x @ W.T for both student and teacher, materializing two [B, T, V] tensors, then computes KL between their softmaxes. Nitrobrew fuses the matmul and KL into a single pass that iterates over the vocabulary in chunks of size chunk_V, using online softmax accumulators to maintain numerical stability.
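To make the chunked pass concrete, here is a minimal, dependency-free sketch of the forward-KL case for a single token position. It is not Nitrobrew's kernel: the function names (`chunk_logits`, `forward_kl_chunked`, `forward_kl_naive`) and the use of plain Python lists are assumptions made for illustration, and a real implementation would work on batched GPU tensors. The sketch relies on the identity KL(p_t ‖ p_s) = Σ_v p_t(v)·(z_t(v) − z_s(v)) − logsumexp(z_t) + logsumexp(z_s), so it only needs a running max, a running sum of exponentials per model, and one running weighted sum.

```python
import math
import random

def chunk_logits(x, W, start, stop):
    """Logits for vocabulary rows W[start:stop] (one dot product per row)."""
    return [sum(a * b for a, b in zip(x, W[v])) for v in range(start, stop)]

def forward_kl_chunked(x_s, x_t, W_s, W_t, chunk_V):
    """KL(teacher || student) for one position, visiting chunk_V vocab rows at a time."""
    V = len(W_s)
    m_t, s_t = -math.inf, 0.0  # teacher: running max, running sum of exp(z_t - m_t)
    m_s, s_s = -math.inf, 0.0  # student: running max, running sum of exp(z_s - m_s)
    acc = 0.0                  # running sum of exp(z_t - m_t) * (z_t - z_s)
    for start in range(0, V, chunk_V):
        stop = min(start + chunk_V, V)
        z_s = chunk_logits(x_s, W_s, start, stop)
        z_t = chunk_logits(x_t, W_t, start, stop)

        # Teacher accumulators: rescale old sums whenever the running max grows.
        new_m_t = max(m_t, max(z_t))
        rescale = math.exp(m_t - new_m_t)  # exp(-inf) == 0.0 on the first chunk
        s_t *= rescale
        acc *= rescale
        for zs, zt in zip(z_s, z_t):
            w = math.exp(zt - new_m_t)
            s_t += w
            acc += w * (zt - zs)
        m_t = new_m_t

        # Student accumulators: only the log-sum-exp is needed.
        new_m_s = max(m_s, max(z_s))
        s_s = s_s * math.exp(m_s - new_m_s) + sum(math.exp(z - new_m_s) for z in z_s)
        m_s = new_m_s

    lse_t = m_t + math.log(s_t)
    lse_s = m_s + math.log(s_s)
    return acc / s_t - lse_t + lse_s

def forward_kl_naive(x_s, x_t, W_s, W_t):
    """Reference: materialize all V logits, then compute KL from log-softmaxes."""
    def log_softmax(z):
        m = max(z)
        lse = m + math.log(sum(math.exp(v - m) for v in z))
        return [v - lse for v in z]
    lp_s = log_softmax(chunk_logits(x_s, W_s, 0, len(W_s)))
    lp_t = log_softmax(chunk_logits(x_t, W_t, 0, len(W_t)))
    return sum(math.exp(lt) * (lt - ls) for lt, ls in zip(lp_t, lp_s))

# Toy sizes with different student/teacher hidden dimensions.
random.seed(0)
V, D_s, D_t = 37, 5, 7
x_s = [random.gauss(0, 1) for _ in range(D_s)]
x_t = [random.gauss(0, 1) for _ in range(D_t)]
W_s = [[random.gauss(0, 1) for _ in range(D_s)] for _ in range(V)]
W_t = [[random.gauss(0, 1) for _ in range(D_t)] for _ in range(V)]

kl_chunked = forward_kl_chunked(x_s, x_t, W_s, W_t, chunk_V=8)
kl_naive = forward_kl_naive(x_s, x_t, W_s, W_t)
print(abs(kl_chunked - kl_naive))  # should be at floating-point noise level
```

In this sketch only `chunk_V` logits per model are alive at any time. The result does not depend on the chunk size, which is what makes `chunk_V` a pure memory/speed knob.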
Memory: O(B·T·chunk_V) working set instead of O(B·T·V).
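Compute: Same total FLOPs as the naive path, but the logits are never written to and read back from GPU memory, so the work shifts from memory-bound toward compute-bound.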
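As a back-of-the-envelope check of the memory claim, the sketch below compares the size of one full logit tensor with one chunk-sized slice. It assumes "8k tokens" means 8192, 2 bytes per element (bf16, the dtype in the usage example), and `chunk_V=2048` (the value used in the usage example). The helper name `logits_bytes` is hypothetical.

```python
def logits_bytes(tokens, vocab, bytes_per_elem=2):
    """Bytes needed for a [tokens, vocab] logit tensor (bf16 by default)."""
    return tokens * vocab * bytes_per_elem

GiB = 1024 ** 3
MiB = 1024 ** 2

full_bytes = logits_bytes(8192, 152064)   # one full [B*T, V] tensor
chunk_bytes = logits_bytes(8192, 2048)    # one [B*T, chunk_V] slice

print(f"full:  {full_bytes / GiB:.2f} GiB")      # full:  2.32 GiB
print(f"chunk: {chunk_bytes / MiB:.1f} MiB")     # chunk: 32.0 MiB
print(f"ratio: {full_bytes / chunk_bytes:.2f}x") # ratio: 74.25x
```

These figures are per logit tensor. The naive path holds one for the student and one for the teacher, plus whatever autograd saves for the backward pass.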
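Given student hidden states $x_s \in \mathbb{R}^{B \times T \times D_s}$ and teacher hidden states $x_t \in \mathbb{R}^{B \times T \times D_t}$, with output projections $W_s \in \mathbb{R}^{V \times D_s}$ and $W_t \in \mathbb{R}^{V \times D_t}$, the logits are $z_s = x_s W_s^\top$ and $z_t = x_t W_t^\top$, and the distributions are $p_s = \mathrm{softmax}(z_s)$ and $p_t = \mathrm{softmax}(z_t)$ (shown at temperature 1):

| Direction | Formula | Gradient w.r.t. $z_s$ |
|---|---|---|
| Forward KL: $\mathrm{KL}(p_t \parallel p_s)$ | $\sum_v p_t(v)\,[\log p_t(v) - \log p_s(v)]$ | $p_s - p_t$ |
| Reverse KL: $\mathrm{KL}(p_s \parallel p_t)$ | $\sum_v p_s(v)\,[\log p_s(v) - \log p_t(v)]$ | $p_s \odot \left(\log p_s - \log p_t - \mathrm{KL}(p_s \parallel p_t)\right)$ |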
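Importantly, gradients flow through both the student hidden states and the student projection weights (xs and ws in the usage example below, both created with requires_grad=True); the teacher tensors xt and wt are created without gradients.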
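The closed-form gradients with respect to the student logits can be sanity-checked numerically. The sketch below is an illustration under stated assumptions, not Nitrobrew code: it works on a single position at temperature 1, uses the standard softmax results (forward KL: p_s − p_t; reverse KL: p_s ⊙ (log p_s − log p_t − KL(p_s ‖ p_t))), and all function names are hypothetical. It compares each closed form against central finite differences.

```python
import math
import random

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def kl(p, q):
    """KL(p || q) for two discrete distributions with full support."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))

def forward_kl_logits(z_s, z_t):
    return kl(softmax(z_t), softmax(z_s))   # KL(teacher || student)

def reverse_kl_logits(z_s, z_t):
    return kl(softmax(z_s), softmax(z_t))   # KL(student || teacher)

def grad_forward(z_s, z_t):
    """Closed form: d KL(p_t || p_s) / d z_s = p_s - p_t."""
    p_s, p_t = softmax(z_s), softmax(z_t)
    return [ps - pt for ps, pt in zip(p_s, p_t)]

def grad_reverse(z_s, z_t):
    """Closed form: d KL(p_s || p_t) / d z_s = p_s * (log p_s - log p_t - KL)."""
    p_s, p_t = softmax(z_s), softmax(z_t)
    d = kl(p_s, p_t)
    return [ps * (math.log(ps) - math.log(pt) - d) for ps, pt in zip(p_s, p_t)]

def numeric_grad(loss, z_s, z_t, eps=1e-6):
    """Central finite differences with respect to each student logit."""
    grad = []
    for j in range(len(z_s)):
        up = z_s[:j] + [z_s[j] + eps] + z_s[j + 1:]
        down = z_s[:j] + [z_s[j] - eps] + z_s[j + 1:]
        grad.append((loss(up, z_t) - loss(down, z_t)) / (2 * eps))
    return grad

random.seed(0)
z_s = [random.gauss(0, 1) for _ in range(10)]
z_t = [random.gauss(0, 1) for _ in range(10)]

err_forward = max(abs(a - n) for a, n in zip(
    grad_forward(z_s, z_t), numeric_grad(forward_kl_logits, z_s, z_t)))
err_reverse = max(abs(a - n) for a, n in zip(
    grad_reverse(z_s, z_t), numeric_grad(reverse_kl_logits, z_s, z_t)))
print(err_forward, err_reverse)  # both should be tiny
```

Both gradients sum to zero over the vocabulary, as expected for a function of softmax outputs, which is a quick extra check when debugging a fused backward pass.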
This repo is a barebones vehicle for the implementation. Requires torch >= 2.1 (for torch.compile).
import torch
from nitrobrew import forward_kl, reverse_kl
B, T, D_s, D_t, V = 1, 4096, 3584, 5120, 152064
xs = torch.randn(B, T, D_s, device="cuda", dtype=torch.bfloat16, requires_grad=True)
xt = torch.randn(B, T, D_t, device="cuda", dtype=torch.bfloat16)
ws = torch.randn(V, D_s, device="cuda", dtype=torch.bfloat16, requires_grad=True)
wt = torch.randn(V, D_t, device="cuda", dtype=torch.bfloat16)
# Forward KL: KL(teacher || student)
kl_fwd = forward_kl(xs, xt, ws, wt, temperature=1.0)
# Reverse KL: KL(student || teacher)
kl_rev = reverse_kl(xs, xt, ws, wt, temperature=1.0)
# Both support reduction and chunk_V tuning
loss = forward_kl(xs, xt, ws, wt, reduction="mean", chunk_V=2048)
loss.backward()

python benchmark.py --vocab 152064 --seq-lens 1024 4096 8192 --d-models 3584 5120 --csv results.csv

Note: The naive baseline will OOM at the configurations where Nitrobrew is most useful.
Apache 2.0
