
tilde-research/nitrobrew-release

Nitrobrew

Fused, constant-memory KL divergence from hidden states for knowledge distillation. [blog post]

Nitrobrew computes KL divergence between student and teacher softmax distributions without materializing the full [B, T, V] logit tensor. For a 152k vocabulary, this avoids allocating ~2.3 GiB (bf16) per logit tensor for each 8k-token batch, the dominant memory bottleneck in on-policy distillation. Student and teacher may have different hidden dimensions ($D_s \neq D_t$).
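The ~2.3 GiB figure follows from simple arithmetic. The sketch below assumes bf16 logits (2 bytes per element), B = 1 and T = 8192; the helper name `logit_bytes` is illustrative only.

```python
def logit_bytes(batch: int, seq_len: int, vocab: int, bytes_per_elem: int = 2) -> int:
    """Size in bytes of one materialized [B, T, V] logit tensor (default: bf16)."""
    return batch * seq_len * vocab * bytes_per_elem


size = logit_bytes(batch=1, seq_len=8192, vocab=152064)
# One tensor only; naive distillation materializes one for the student and one for the teacher.
print(f"{size / 2**30:.2f} GiB")  # → 2.32 GiB
```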

How it works

Standard distillation computes logits z = x @ W.T for both student and teacher, materializing two [B, T, V] tensors, then computes KL between their softmaxes. Nitrobrew fuses the matmul and KL into a single pass that iterates over the vocabulary in chunks of size chunk_V, using online softmax accumulators to maintain numerical stability.
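The chunked loop with online softmax accumulators can be sketched in plain PyTorch for the forward-KL value. This is a minimal illustration under stated assumptions, not Nitrobrew's kernel: the function name `chunked_forward_kl` is hypothetical, inputs are assumed flattened to [N, D] with N = B·T, and no fused backward is shown. It relies on the identity $\mathrm{KL}(q \Vert p) = \sum_v q(v)[z_t(v) - z_s(v)] - \mathrm{lse}(z_t) + \mathrm{lse}(z_s)$, so only running maxima, running sums, and one q-weighted sum per token need to survive between chunks.

```python
import torch


def chunked_forward_kl(xs, xt, ws, wt, temperature=1.0, chunk_V=2048):
    # xs: [N, D_s], xt: [N, D_t], ws: [V, D_s], wt: [V, D_t]  ->  per-token KL(q || p): [N]
    N, V = xs.shape[0], ws.shape[0]
    kw = dict(dtype=xs.dtype, device=xs.device)
    m_s = torch.full((N,), float("-inf"), **kw)  # running max of student logits
    m_t = torch.full((N,), float("-inf"), **kw)  # running max of teacher logits
    s_s = torch.zeros(N, **kw)  # sum_v exp(z_s(v) - m_s)
    s_t = torch.zeros(N, **kw)  # sum_v exp(z_t(v) - m_t)
    acc = torch.zeros(N, **kw)  # sum_v exp(z_t(v) - m_t) * (z_t(v) - z_s(v))
    for start in range(0, V, chunk_V):
        # Only an [N, chunk_V] slice of each logit matrix exists at any time.
        zs = xs @ ws[start:start + chunk_V].T / temperature
        zt = xt @ wt[start:start + chunk_V].T / temperature
        # Student: online logsumexp, rescaling the old sum to the new running max.
        new_m_s = torch.maximum(m_s, zs.amax(dim=1))
        s_s = s_s * torch.exp(m_s - new_m_s) + torch.exp(zs - new_m_s[:, None]).sum(dim=1)
        m_s = new_m_s
        # Teacher: online logsumexp plus the unnormalized q-weighted logit gap.
        new_m_t = torch.maximum(m_t, zt.amax(dim=1))
        scale = torch.exp(m_t - new_m_t)
        w = torch.exp(zt - new_m_t[:, None])
        s_t = s_t * scale + w.sum(dim=1)
        acc = acc * scale + (w * (zt - zs)).sum(dim=1)
        m_t = new_m_t
    lse_s = m_s + torch.log(s_s)
    lse_t = m_t + torch.log(s_t)
    # KL(q || p) = sum_v q(v) [z_t(v) - z_s(v)] - lse_t + lse_s, with q(v) = w / s_t
    return acc / s_t - lse_t + lse_s
```

For [B, T, D] inputs, flatten first, e.g. `chunked_forward_kl(xs.reshape(-1, xs.shape[-1]), xt.reshape(-1, xt.shape[-1]), ws, wt)`. Reverse KL follows the same pattern with the roles of p and q swapped in the weighted sum.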

Memory: O(B·T·chunk_V) working set instead of O(B·T·V).

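Compute: Same total FLOPs, but because the full [B, T, V] logits are never written out and read back, less time goes to memory traffic and the work becomes more compute-bound.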
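To put the Memory line in numbers, the sketch below compares one full logit tensor against one logit chunk, assuming bf16, B = 1, T = 8192 and the chunk_V = 2048 used in the usage example. It counts a single slice of one model's logits only; the real working set also includes the teacher slice and O(B·T) accumulators, and depends on the implementation.

```python
B, T, V, chunk_V = 1, 8192, 152064, 2048
BYTES_BF16 = 2

full = B * T * V * BYTES_BF16         # one materialized [B, T, V] logit tensor
chunk = B * T * chunk_V * BYTES_BF16  # one [B, T, chunk_V] slice in the streaming loop
print(f"full: {full / 2**20:.0f} MiB, chunk: {chunk / 2**20:.0f} MiB, ratio: {V / chunk_V:.2f}x")
# → full: 2376 MiB, chunk: 32 MiB, ratio: 74.25x
```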

Math

Given student hidden states $x_s \in \mathbb{R}^{B \times T \times D_s}$, teacher hidden states $x_t \in \mathbb{R}^{B \times T \times D_t}$, student unembed $W_s \in \mathbb{R}^{V \times D_s}$, teacher unembed $W_t \in \mathbb{R}^{V \times D_t}$, and temperature $\tau$:

$$z_s = x_s W_s^\top / \tau, \quad z_t = x_t W_t^\top / \tau$$ $$p = \mathrm{softmax}(z_s), \quad q = \mathrm{softmax}(z_t)$$
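Translated literally, these definitions give the memory-hungry computation that Nitrobrew avoids, since both full [B, T, V] logit tensors are built. A minimal sketch; the function name `naive_kls` is hypothetical and it is not the benchmark's baseline.

```python
import torch


def naive_kls(xs, xt, ws, wt, temperature=1.0):
    # Materializes both full logit tensors, each [B, T, V].
    log_p = torch.log_softmax(xs @ ws.T / temperature, dim=-1)  # log p = log softmax(z_s)
    log_q = torch.log_softmax(xt @ wt.T / temperature, dim=-1)  # log q = log softmax(z_t)
    forward = (log_q.exp() * (log_q - log_p)).sum(dim=-1)  # KL(q || p), per token
    reverse = (log_p.exp() * (log_p - log_q)).sum(dim=-1)  # KL(p || q), per token
    return forward, reverse


# Tiny shapes with D_s != D_t; at a 152k vocabulary the [B, T, V] tensors dominate memory.
xs, xt = torch.randn(2, 3, 16), torch.randn(2, 3, 24)
ws, wt = torch.randn(100, 16), torch.randn(100, 24)
fwd, rev = naive_kls(xs, xt, ws, wt, temperature=2.0)
print(fwd.shape, rev.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```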

| Direction | Formula | Gradient $\partial\,\mathrm{KL} / \partial z_s(v)$ |
|---|---|---|
| Forward KL: $\mathrm{KL}(q \Vert p)$ | $\sum_v q(v) [\log q(v) - \log p(v)]$ | $p(v) - q(v)$ |
| Reverse KL: $\mathrm{KL}(p \Vert q)$ | $\sum_v p(v) [\log p(v) - \log q(v)]$ | $p(v) [\log p(v) - \log q(v) - \mathrm{KL}(p \Vert q)]$ |

Importantly, gradients flow through $x_s$ and $W_s$ only. The teacher parameters are treated as fixed.
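Since $z_s = x_s W_s^\top / \tau$, the chain rule turns the per-logit gradient $G = \partial\,\mathrm{KL} / \partial z_s$ into $\partial\,\mathrm{KL}/\partial x_s = G W_s / \tau$ and $\partial\,\mathrm{KL}/\partial W_s = G^\top x_s / \tau$. The sketch below checks this for forward KL (where $G = p - q$) against autograd, with the teacher computed under `torch.no_grad()` so it never receives gradients. Shapes are arbitrary small assumptions. Both products are sums over vocabulary entries, so in principle they can be accumulated chunk by chunk as well.

```python
import torch

torch.manual_seed(0)
N, D_s, D_t, V, tau = 6, 4, 7, 30, 2.0
xs = torch.randn(N, D_s, dtype=torch.float64, requires_grad=True)  # student hidden states
ws = torch.randn(V, D_s, dtype=torch.float64, requires_grad=True)  # student unembed
xt = torch.randn(N, D_t, dtype=torch.float64)                      # teacher hidden states
wt = torch.randn(V, D_t, dtype=torch.float64)                      # teacher unembed

with torch.no_grad():  # teacher is fixed: no graph, no gradients
    log_q = torch.log_softmax(xt @ wt.T / tau, dim=-1)

log_p = torch.log_softmax(xs @ ws.T / tau, dim=-1)
loss = (log_q.exp() * (log_q - log_p)).sum()  # forward KL, summed over tokens
loss.backward()

# Manual chain rule with G = p - q (forward-KL row of the table).
G = (log_p.exp() - log_q.exp()).detach()  # [N, V]
grad_xs = G @ ws.detach() / tau           # [N, D_s]
grad_ws = G.T @ xs.detach() / tau         # [V, D_s]
assert torch.allclose(xs.grad, grad_xs)
assert torch.allclose(ws.grad, grad_ws)
```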

Usage

This repo is a barebones vehicle for the implementation. Requires torch >= 2.1 (for torch.compile).

```python
import torch
from nitrobrew import forward_kl, reverse_kl

# D_s != D_t: student and teacher hidden sizes may differ.
B, T, D_s, D_t, V = 1, 4096, 3584, 5120, 152064

# Only the student tensors require grad; the teacher is treated as fixed.
xs = torch.randn(B, T, D_s, device="cuda", dtype=torch.bfloat16, requires_grad=True)
xt = torch.randn(B, T, D_t, device="cuda", dtype=torch.bfloat16)
ws = torch.randn(V, D_s, device="cuda", dtype=torch.bfloat16, requires_grad=True)
wt = torch.randn(V, D_t, device="cuda", dtype=torch.bfloat16)

# Forward KL: KL(teacher || student)
kl_fwd = forward_kl(xs, xt, ws, wt, temperature=1.0)

# Reverse KL: KL(student || teacher)
kl_rev = reverse_kl(xs, xt, ws, wt, temperature=1.0)

# Both support reduction and chunk_V tuning
loss = forward_kl(xs, xt, ws, wt, reduction="mean", chunk_V=2048)
loss.backward()
```

Benchmarking

```shell
python benchmark.py --vocab 152064 --seq-lens 1024 4096 8192 --d-models 3584 5120 --csv results.csv
```

Note: The naive baseline will OOM at the configurations where Nitrobrew is most useful.

License

Apache 2.0
