Educational, from-scratch PyTorch implementations of Qwen3.5 and Gemma4 that mirror the HuggingFace reference models, plus a minimal GRPO (DeepSeek-R1-style) RL training pipeline for math reasoning. No `transformers` dependency: stock HuggingFace safetensors checkpoints load directly.
This repo is a learning artifact, not a production framework. It exists to make the full stack — attention kernels, KV cache, sampling, RL loss, reward functions — legible line by line.
- Learn by building. Treat LLM internals as first-class code you can read, modify, and benchmark, not a black box behind `AutoModelForCausalLM.from_pretrained`.
- Minimal dependency tree. Only `torch` + `tokenizers` + `safetensors` + `huggingface_hub`. No `transformers`, `trl`, `deepspeed`, or `vllm` at runtime; every layer of the stack is visible in this repo.
- GRPO for reasoning. ~600 LOC of training code that follows the DeepSeek-R1 recipe directly, with unit tests for the parts that are easy to get wrong (left-padded `position_ids`, rollout-step → shift-label alignment, k3 KL estimator).
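Left padding is the first of those traps: in a left-padded batch, position ids must start at 0 on each row's first real token, not at column 0. A minimal sketch in plain Python, assuming a 0/1 attention mask per row; `left_padded_position_ids` is a hypothetical helper, not this repo's API, and it mirrors the common `cumsum(mask) - 1` trick:

```python
def left_padded_position_ids(attention_mask: list[list[int]]) -> list[list[int]]:
    """Return RoPE position ids for a left-padded batch.

    Real tokens get 0, 1, 2, ... starting at the first non-pad column.
    Pad columns get 0; their value does not matter because they are masked out.
    """
    batch_positions = []
    for row in attention_mask:
        positions, running = [], 0
        for m in row:
            running += m  # cumulative count of real tokens so far
            positions.append(running - 1 if m else 0)
        batch_positions.append(positions)
    return batch_positions


mask = [
    [0, 0, 1, 1, 1],  # 3-token prompt, left-padded by 2
    [1, 1, 1, 1, 1],  # full-length prompt
]
print(left_padded_position_ids(mask))
# [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]
```

Without this, the same prompt receives different RoPE phases depending on how much padding its batch-mates force on it, which is the kind of bug a left-pad invariance test is meant to catch.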
```bash
pip install torch safetensors tokenizers huggingface_hub pyarrow matplotlib pytest wandb
```

A CUDA 12+ build of PyTorch is recommended (bf16 autocast). Tested on an RTX A5000 (24 GB).
```text
reasoning_model/
├── models/     Qwen3.5 + Gemma4 from-scratch modules (load HF safetensors)
├── configs/    @dataclass configs per model + GRPO config
├── engine/     Batched inference engine: generate, KV cache, ChatML, log-probs
├── training/   GRPO: rollout → reward → advantage → PPO update, + tools
├── cli/        argparse entry points
├── bench/      Benchmarking harness (baseline vs optimized subprocess runner)
├── compat/     Minimal HF-like shims (CausalLMOutputWithPast, DynamicCache)
├── tests/      pytest: unit + integration
├── dev/        Experiment reports (markdown + plots)
└── runs/       Training run artifacts (logs + plots; weights gitignored)
```
```python
import torch
from reasoning_model import Qwen3_5InferenceEngine

engine = Qwen3_5InferenceEngine.from_pretrained(
    "/path/to/Qwen3.5-0.8B", device="cuda", dtype=torch.bfloat16,
)
print(engine.chat([{"role": "user", "content": "Explain RoPE in one paragraph."}]))
```

CLI:
```bash
python -m reasoning_model.cli.chat --checkpoint /path/to/Qwen3.5-0.8B
```

Benchmarks:

```bash
python -m reasoning_model.cli.bench_suite --list
python -m reasoning_model.cli.bench_suite --preset standard
```

The suite runs each registered baseline/optimized pair in an isolated subprocess with a timeout and profiling, then writes a JSON + Markdown comparison.
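The isolation idea can be sketched with the standard library alone. This is not the repo's `bench` harness: `run_isolated`, `compare`, `to_markdown`, and the demo snippets are hypothetical, and "profiling" is reduced to wall-clock timing.

```python
import json
import subprocess
import sys
import time


def run_isolated(code: str, timeout_s: float = 60.0) -> dict:
    """Run a Python snippet in a fresh interpreter and time it.

    A separate process keeps allocator state, warm caches and any CUDA
    context from leaking between benchmarks.
    """
    start = time.perf_counter()
    try:
        proc = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True, text=True, timeout=timeout_s,
        )
        status = "ok" if proc.returncode == 0 else "error"
        stderr = proc.stderr[-2000:]
    except subprocess.TimeoutExpired:  # run() kills the child before raising
        status, stderr = "timeout", ""
    return {"status": status, "seconds": time.perf_counter() - start, "stderr": stderr}


def compare(name: str, baseline: str, optimized: str, timeout_s: float = 60.0) -> dict:
    """Run a baseline/optimized pair and compute the speedup if both succeeded."""
    base = run_isolated(baseline, timeout_s)
    opt = run_isolated(optimized, timeout_s)
    speedup = None
    if base["status"] == opt["status"] == "ok" and opt["seconds"] > 0:
        speedup = base["seconds"] / opt["seconds"]
    return {"name": name, "baseline": base, "optimized": opt, "speedup": speedup}


def to_markdown(results: list[dict]) -> str:
    """Render results as a Markdown comparison table."""
    lines = ["| bench | baseline s | optimized s | speedup |", "|---|---|---|---|"]
    for r in results:
        sp = f"{r['speedup']:.2f}x" if r["speedup"] else "n/a"
        lines.append(
            f"| {r['name']} | {r['baseline']['seconds']:.3f} "
            f"| {r['optimized']['seconds']:.3f} | {sp} |"
        )
    return "\n".join(lines)


if __name__ == "__main__":
    results = [compare(
        "sum-loop-vs-builtin",
        baseline="s = 0\nfor i in range(10**5): s += i",
        optimized="s = sum(range(10**5))",
    )]
    print(json.dumps(results, indent=2))
    print(to_markdown(results))
```

Interpreter startup time is included in each measurement here; a real harness would time inside the child process or subtract a no-op baseline.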
```bash
# Smoke test (built-in 8-problem toy dataset, ~90 s total on A5000):
python -m reasoning_model.cli.train_grpo \
  --checkpoint /path/to/Qwen3.5-0.8B \
  --dataset toy --total-steps 15 --group-size 4 --prompt-batch 2 \
  --max-new-tokens 96 --lr 3e-6 --warmup-steps 3 \
  --out-dir ./runs/grpo-toy

# Real GSM8K run (set expandable_segments to avoid fp32-logits OOM):
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
python -m reasoning_model.cli.train_grpo \
  --checkpoint /path/to/Qwen3.5-0.8B \
  --dataset gsm8k --total-steps 500 --group-size 8 --prompt-batch 4 \
  --max-new-tokens 512 --inner-epochs 4 --micro-bsz 2 \
  --lr 1e-6 --warmup-steps 20 --save-every 50 \
  --out-dir ./runs/grpo-gsm8k
```

See `dev/grpo_training_report.md` for detailed numbers, training curves, and hand-built OOM triage notes.
Toy dataset, 15 steps. Per-step rewards are noisy (only 2 prompts × 4 rollouts = 8 samples per step), but the windowed mean reward rises from 0.55 over steps 1–5 to 0.88 over steps 11–15, with a per-step peak of 1.06 (steps 11 and 13). This is essentially memorization of 8 problems, not a generalization result; it is a smoke test.
GSM8K, 50 steps (killed at the first checkpoint for inspection). Held-out eval on 50 random GSM8K problems (seed=0, T=1.0, max_new=512), 3–4 sampling-seed trials per checkpoint:

| checkpoint | accuracy (4 trials) | mean | format (4 trials) | mean |
|---|---|---|---|---|
| baseline `/root/qwen3.5-0.8b` | 34 / 32 / 24 / 28 % | 29.5 % | 52 / 54 / 44 / 52 % | 50.5 % |
| `runs/grpo-gsm8k-500/step-50` | 38 / 40 / 40 / 36 % | 38.5 % | 64 / 56 / 66 / 66 % | 63.0 % |
| Δ | | +9 pp | | +12.5 pp |
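The accuracy ranges do not overlap: every step-50 trial (36–40 %) beat every baseline trial (24–34 %). That makes sampling-seed noise an unlikely explanation for the gain, though n=50 is still small, and because every trial reuses the same 50 problems, the spread does not capture problem-selection variance.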
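The means, deltas, and the non-overlap claim can be rechecked directly from the per-trial numbers in the table (a small standalone script; it does not read the repo's eval output):

```python
from statistics import mean

# Per-trial percentages copied from the results table (50 problems per trial).
baseline = {"acc": [34, 32, 24, 28], "fmt": [52, 54, 44, 52]}
step50 = {"acc": [38, 40, 40, 36], "fmt": [64, 56, 66, 66]}

for key in ("acc", "fmt"):
    b, s = mean(baseline[key]), mean(step50[key])
    print(f"{key}: baseline {b:.1f}%  step-50 {s:.1f}%  delta {s - b:+.1f} pp")

# Non-overlap: the worst step-50 trial still beats the best baseline trial.
print("accuracy ranges overlap:", min(step50["acc"]) <= max(baseline["acc"]))
# acc: baseline 29.5%  step-50 38.5%  delta +9.0 pp
# fmt: baseline 50.5%  step-50 63.0%  delta +12.5 pp
# accuracy ranges overlap: False
```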
Reading the training-curve panels: `ratio_mean` stays at 1.0, `kl_mean` grows linearly to ~0.005 (far below the 0.1 danger line), `eos_frac` rises as the model learns to terminate, entropy falls as the policy sharpens, and the lr panel shows the 20-step linear warmup to 1e-6 followed by the very start of the cosine decay (the run stopped at step 50 of 500).
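The lr panel corresponds to a linear-warmup-then-cosine schedule. A sketch under assumptions: the peak of 1e-6 and the 20 warmup steps come from the GSM8K command, while the 1-indexed steps, the 500-step horizon, and a decay floor of 0 are assumptions (the trainer may use a nonzero minimum LR); `lr_at` is a hypothetical helper, not the repo's scheduler.

```python
import math


def lr_at(step: int, peak_lr: float = 1e-6, warmup_steps: int = 20,
          total_steps: int = 500, min_lr: float = 0.0) -> float:
    """Learning rate for a 1-indexed optimizer step.

    Linear ramp up to peak_lr over warmup_steps, then cosine decay from
    peak_lr down to min_lr at total_steps.
    """
    if step <= warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * min(progress, 1.0)))


for s in (1, 10, 20, 50, 260, 500):
    print(s, f"{lr_at(s):.3e}")
```

Under these assumptions the LR at step 50 is still about 99 % of the peak, which is why the decay looks barely started in a run stopped there.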
Each training step works as follows. For each prompt, sample G completions from the current policy. Score each with rule-based rewards (format: `<think>…</think>` + `\boxed{…}`; accuracy: boxed number vs. the gold answer). Subtract the group mean and divide by the group std to get a per-sequence scalar advantage A. Run μ inner PPO epochs on the clipped surrogate -min(ratio·A, clip(ratio, 1−ε, 1+ε)·A), where ratio = π_θ/π_old per token, plus a k3 KL penalty β·(r − log r − 1) toward a frozen reference copy of the base model, where r = π_ref/π_θ per token. Checkpoints are saved as HF-format safetensors that the inference engine can reload unchanged. No value network, no learned reward model, no `trl`.
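The math of one update can be written out on plain floats. This is a didactic sketch, not the repo's trainer: the reward weights (1.0 for format, 1.0 for accuracy), the regexes, ε = 0.2, β = 0.04, the 1e-6 std epsilon, and all function names are assumptions, and the real code works on batched `torch` tensors with masking over response tokens.

```python
import math
import re
from statistics import mean, pstdev

# Assumed format: a <think>...</think> block followed later by \boxed{...}.
FORMAT_RE = re.compile(r"<think>.*?</think>.*\\boxed\{[^{}]*\}", re.DOTALL)
BOXED_RE = re.compile(r"\\boxed\{([^{}]*)\}")


def reward(completion: str, gold: str, format_w: float = 1.0, acc_w: float = 1.0) -> float:
    """Rule-based reward: format bonus + exact numeric match of the last \\boxed{}."""
    r = format_w if FORMAT_RE.search(completion) else 0.0
    boxed = BOXED_RE.findall(completion)
    if boxed:
        try:
            if float(boxed[-1].replace(",", "")) == float(gold):
                r += acc_w
        except ValueError:
            pass  # a non-numeric boxed answer earns no accuracy reward
    return r


def group_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """(r - mean) / (std + eps) within one prompt's group of G samples."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def token_loss(logp_new: float, logp_old: float, logp_ref: float, adv: float,
               clip_eps: float = 0.2, beta: float = 0.04) -> float:
    """Clipped PPO surrogate plus k3 KL penalty for a single response token."""
    ratio = math.exp(logp_new - logp_old)             # pi_theta / pi_old
    clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps)
    pg = -min(ratio * adv, clipped * adv)             # pessimistic (clipped) objective
    r = math.exp(logp_ref - logp_new)                 # pi_ref / pi_theta
    kl = r - math.log(r) - 1                          # k3 estimator, always >= 0
    return pg + beta * kl


group = ["<think>7*8=56</think> \\boxed{56}", "\\boxed{54}", "<think>hm</think> \\boxed{56}", "56"]
adv = group_advantages([reward(c, "56") for c in group])
print(adv)
```

A group whose completions all earn the same reward gets zero advantage for every sample, so it contributes only the KL term to the update.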
Before kicking off any training, eyeball VRAM + wall-clock:
```bash
python -m reasoning_model.training.plan_run \
  --prompt-batch 4 --group-size 8 --max-new-tokens 512 \
  --inner-epochs 4 --micro-bsz 2 --total-steps 500
```

This prints the 14×P-rule VRAM budget (weights + grads + fp32 AdamW + ref), plus the dynamic KV cache, training activations, and the often-forgotten (B_micro, T, V) × 4 × 2 fp32 logits transient (this caused the initial OOM and was the #1 fix in the report). It also prints tokens per step and a theoretical plus a realistic wall-clock time per step.
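The two dominant terms can be reproduced by hand. A sketch under assumptions: the 14 bytes per parameter is taken to split as bf16 weights (2) + bf16 grads (2) + fp32 AdamW moments (8) + bf16 reference copy (2), and the ×2 in the logits term is taken to mean the fp32 logits plus one same-sized intermediate (for example a log-softmax output). The parameter count, sequence length, and vocabulary size below are placeholders; read the real values from the checkpoint's config and your run settings.

```python
GIB = 2**30


def static_train_bytes(n_params: int) -> int:
    """14 x P rule (assumed split): weights + grads + AdamW m/v + frozen reference."""
    bf16_weights, bf16_grads, fp32_adam_mv, bf16_ref = 2, 2, 8, 2
    return n_params * (bf16_weights + bf16_grads + fp32_adam_mv + bf16_ref)


def logits_transient_bytes(micro_bsz: int, seq_len: int, vocab: int) -> int:
    """(B_micro, T, V) x 4 bytes (fp32) x 2 live copies."""
    return micro_bsz * seq_len * vocab * 4 * 2


n_params = 800_000_000  # nominal "0.8B"; use the real count from the checkpoint
T = 256 + 512           # assumed prompt length + --max-new-tokens 512
V = 151_936             # placeholder vocab size; read vocab_size from config.json

print(f"static : {static_train_bytes(n_params) / GIB:.2f} GiB")
print(f"logits : {logits_transient_bytes(2, T, V) / GIB:.2f} GiB per micro-batch")
```

The logits term grows linearly in B_micro and T, and its large V factor is what makes it easy to underestimate next to the per-parameter budget.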
Core metrics logged every step (typical healthy values in parentheses; the report has more detail):
| Family | Keys |
|---|---|
| reward | reward/mean, acc_frac, format_frac, std |
| advantage | adv/mean (≈0), adv/std (≈1) |
| PPO health | ratio_mean (≈1), ratio_clip_frac (~0.05), kl_mean (<0.05) |
| loss | pg_loss, kl_loss, total_loss |
| rollout shape | resp_len/mean, resp_len/p95, eos_frac, entropy |
| optim | grad_norm, lr |
Live: stdout (always on) plus optional wandb (`--wandb-project NAME`).

Offline plots from a tee'd log:

```bash
python -m reasoning_model.training.plot_log runs/grpo-gsm8k/train.log
```

Qualitative spot-check at any checkpoint:

```bash
python -m reasoning_model.training.inspect_rollouts \
  --checkpoint runs/grpo-gsm8k/step-100 \
  --prompt "What is 7 * 8?" --gold 56 --G 4
```

Held-out eval on a deterministic GSM8K subset:

```bash
python -m reasoning_model.training.eval_checkpoint \
  --checkpoint runs/grpo-gsm8k/step-100 --n 50 --seed 0
```

At step 1, verify:

- `ratio_mean ≈ 1.000`: the policy is identical to the rollout-time policy (nothing has been updated yet).
- `kl_mean ≈ 0`: the reference is identical to the policy (both are loaded from the same checkpoint).
- `adv/mean ≈ 0`, `adv/std ≈ 1`: group normalization is doing its job.
If any of these is off, there is most likely a bug in left-padding + `position_ids`, log-prob alignment, or reference loading. The test `test_compute_response_logprobs_left_pad_invariance` catches the most common one.
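Those step-1 checks are easy to automate against the logged metrics dict. A sketch: the metric keys match the metrics table, but the tolerances and the `check_step1` helper are assumptions, not something the trainer ships.

```python
def check_step1(metrics: dict[str, float],
                ratio_tol: float = 1e-3, kl_tol: float = 1e-4,
                adv_mean_tol: float = 1e-3, adv_std_tol: float = 0.1) -> list[str]:
    """Return human-readable problems with step-1 metrics (empty list = healthy)."""
    problems = []
    if abs(metrics["ratio_mean"] - 1.0) > ratio_tol:
        problems.append("ratio_mean != 1: training-time log-probs disagree with rollout-time "
                        "log-probs (check left-pad position_ids / log-prob alignment)")
    if abs(metrics["kl_mean"]) > kl_tol:
        problems.append("kl_mean != 0: reference and policy disagree (check reference loading)")
    if abs(metrics["adv/mean"]) > adv_mean_tol:
        problems.append("adv/mean != 0: group normalization looks wrong")
    if abs(metrics["adv/std"] - 1.0) > adv_std_tol:
        problems.append("adv/std != 1: group normalization looks wrong "
                        "(groups with identical rewards can also pull this below 1)")
    return problems


print(check_step1({"ratio_mean": 1.0, "kl_mean": 0.0, "adv/mean": 0.0, "adv/std": 0.98}))
```

The adv/std tolerance is deliberately loose: zero-variance groups contribute zero advantages, and a sample (rather than population) std shrinks the value slightly.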
```bash
pytest reasoning_model/tests/ -q
# 11 passed
```

Unit coverage: reward edge cases, group normalization, loss signs under ratio/advantage combinations, log-prob shape + grad, left-pad invariance, step→shift coordinate scatter.
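The step→shift item refers to the off-by-one between a causal LM's outputs and its targets: the output at position t predicts the token at t+1. A sketch in plain Python, where `response_logprobs` and the per-position log-prob table are hypothetical stand-ins for the repo's tensor code, for a sequence whose first `prompt_len` positions (including left padding) are prompt:

```python
def response_logprobs(logprob_rows: list[list[float]], input_ids: list[int],
                      prompt_len: int) -> list[float]:
    """Log-prob of each response token under the model.

    logprob_rows[t][v] is log p(next token = v | tokens[0..t]), i.e. the model
    output at position t, so response token t is scored by row t - 1.
    Requires prompt_len >= 1.
    """
    return [logprob_rows[t - 1][input_ids[t]] for t in range(prompt_len, len(input_ids))]


# Toy vocab of 3 tokens, 4 positions; positions 0-1 are prompt. Values are made up.
rows = [
    [-0.1, -2.0, -3.0],  # predicts position 1 (a prompt token, not scored)
    [-1.5, -0.3, -2.5],  # predicts position 2
    [-2.2, -1.1, -0.4],  # predicts position 3
    [-0.7, -0.9, -1.9],  # would predict position 4, which does not exist yet
]
ids = [0, 2, 1, 2]
print(response_logprobs(rows, ids, prompt_len=2))
# [-0.3, -0.4]
```

The last row is never used, and left-pad invariance means that adding padding (larger `prompt_len`, rows shifted right by the same amount) must return the same list.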
- dev/grpo_training_report.md — GRPO pipeline, toy (15 steps) and GSM8K (50 steps) results with embedded training plots, VRAM budget, OOM triage notes, planner calibration.
- dev/inference_time_scaling_report.md — CoT / self-consistency / self-refinement on GSM8K.
Not in v1: multi-GPU/DDP, LoRA/PEFT, vLLM rollouts, process reward models, a SymPy-based math verifier, async/off-policy replay, and any HF `transformers` or `trl` dependency. Each is a future PR; none blocks v1.
Model implementations follow the public architectures of Qwen3.5 and Gemma4; the loader consumes the official HF safetensors checkpoints unchanged. The GRPO trainer is written from scratch following the DeepSeek-R1 paper.

