andreidhoang/reasoning-model

reasoning_model

Educational purpose: from-scratch PyTorch implementations of Qwen3.5 and Gemma4 that follow the HuggingFace implementations, plus a minimal GRPO (DeepSeek-R1 style) RL training pipeline for math reasoning. There is no transformers dependency: the models load stock HuggingFace safetensors checkpoints directly.

This repo is a learning artifact, not a production framework. It exists to make the full stack — attention kernels, KV cache, sampling, RL loss, reward functions — legible line by line.

Why this exists

  • Learn by building. Treat LLM internals as first-class code you can read, modify, and benchmark — not a black box behind AutoModelForCausalLM.from_pretrained.
  • Minimal dependency tree. Only torch + tokenizers + safetensors + huggingface_hub. No transformers, trl, deepspeed, or vllm at runtime — every layer of the stack is visible in this repo.
  • GRPO for reasoning. ~600 LOC of training code that follows the DeepSeek-R1 recipe directly, with unit tests for the parts that are easy to get wrong (left-padded position_ids, rollout-step → shift-label alignment, k3 KL estimator).

Install

pip install torch safetensors tokenizers huggingface_hub pyarrow matplotlib pytest wandb

CUDA 12+ PyTorch recommended (bf16 autocast). Tested on RTX A5000 (24 GB).

Repo layout

reasoning_model/
├── models/          Qwen3.5 + Gemma4 from-scratch modules (load HF safetensors)
├── configs/         @dataclass configs per model + GRPO config
├── engine/          Batched inference engine: generate, KV cache, ChatML, log-probs
├── training/        GRPO: rollout → reward → advantage → PPO update, + tools
├── cli/             argparse entry points
├── bench/           Benchmarking harness (baseline vs optimized subprocess runner)
├── compat/          Minimal HF-like shims (CausalLMOutputWithPast, DynamicCache)
├── tests/           pytest — unit + integration
├── dev/             Experiment reports (markdown + plots)
└── runs/            Training run artifacts (logs + plots; weights gitignored)

30-second tour

1. Inference

import torch
from reasoning_model import Qwen3_5InferenceEngine

engine = Qwen3_5InferenceEngine.from_pretrained(
    "/path/to/Qwen3.5-0.8B", device="cuda", dtype=torch.bfloat16,
)
print(engine.chat([{"role": "user", "content": "Explain RoPE in one paragraph."}]))

CLI:

python -m reasoning_model.cli.chat --checkpoint /path/to/Qwen3.5-0.8B

2. Benchmark baseline vs optimized kernels

python -m reasoning_model.cli.bench_suite --list
python -m reasoning_model.cli.bench_suite --preset standard

Runs each registered baseline/optimized pair in an isolated subprocess with a timeout and profiling, then writes a JSON + Markdown comparison.

3. GRPO training (DeepSeek-R1 recipe)

# Smoke test (built-in 8-problem toy dataset, ~90s total on A5000):
python -m reasoning_model.cli.train_grpo \
  --checkpoint /path/to/Qwen3.5-0.8B \
  --dataset toy --total-steps 15 --group-size 4 --prompt-batch 2 \
  --max-new-tokens 96 --lr 3e-6 --warmup-steps 3 \
  --out-dir ./runs/grpo-toy

# Real GSM8K run (set expandable_segments to avoid fp32-logits OOM):
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
python -m reasoning_model.cli.train_grpo \
  --checkpoint /path/to/Qwen3.5-0.8B \
  --dataset gsm8k --total-steps 500 --group-size 8 --prompt-batch 4 \
  --max-new-tokens 512 --inner-epochs 4 --micro-bsz 2 \
  --lr 1e-6 --warmup-steps 20 --save-every 50 \
  --out-dir ./runs/grpo-gsm8k

See dev/grpo_training_report.md for live numbers, training curves, and hand-built OOM triage notes.

Results preview

Toy dataset, 15 steps. Per-step rewards bounce around (only 2 prompts × 4 rollouts per step, so per-step variance is high), but the windowed mean reward rises from 0.55 over steps 1–5 to 0.88 over steps 11–15. The peak per-step reward was 1.06 (steps 11 and 13). This is essentially memorization of 8 problems, not a generalization result; it is a smoke test.

[Figure: toy run training curves]

GSM8K, 50 steps (killed at first checkpoint for inspection). Held-out eval on 50 random GSM8K problems (seed=0, T=1.0, max_new=512), 4 sampling-seed trials per checkpoint:

| checkpoint | accuracy (4 trials) | accuracy mean | format (4 trials) | format mean |
|---|---|---|---|---|
| baseline /root/qwen3.5-0.8b | 34 / 32 / 24 / 28 % | 29.5 % | 52 / 54 / 44 / 52 % | 50.5 % |
| runs/grpo-gsm8k-500/step-50 | 38 / 40 / 40 / 36 % | 38.5 % | 64 / 56 / 66 / 66 % | 63.0 % |
| Δ | | +9 pp | | +12.5 pp |

The accuracy distributions do not overlap: every step-50 trial beat every baseline trial (worst step-50 trial 36 %, best baseline trial 34 %). The gain is therefore unlikely to be sampling noise, even though n=50 per trial is small.
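
The table arithmetic is easy to re-check by hand. The snippet below is a small sketch that recomputes the means, the deltas, and the separation check from the per-trial numbers in the table; the variable names are illustrative and not taken from the repo.

```python
from statistics import mean

# Per-trial percentages copied from the results table (50 problems per trial).
baseline_acc = [34, 32, 24, 28]
step50_acc = [38, 40, 40, 36]
baseline_fmt = [52, 54, 44, 52]
step50_fmt = [64, 56, 66, 66]

# Deltas are in percentage points (pp), not relative percent.
acc_delta = mean(step50_acc) - mean(baseline_acc)
fmt_delta = mean(step50_fmt) - mean(baseline_fmt)

# "No overlap" means the worst step-50 trial still beats the best baseline trial.
acc_separated = min(step50_acc) > max(baseline_acc)
fmt_separated = min(step50_fmt) > max(baseline_fmt)

print(f"accuracy delta: {acc_delta:+.1f} pp, separated: {acc_separated}")
print(f"format delta:   {fmt_delta:+.1f} pp, separated: {fmt_separated}")
```

Full separation across 4 trials per side is an informal check, not a significance test; more trials or a larger eval subset would tighten the estimate.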

[Figure: GSM8K run training curves]

Panel read: ratio_mean sits dead on 1.0, kl_mean grows linearly to ~0.005 (far under the 0.1 danger line), eos_frac rises as the model learns to terminate, entropy falls as the policy sharpens, and the lr panel shows the 20-step linear warmup ramp to 1e-6 then barely-started cosine decay.
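
To see why the cosine decay has "barely started" at step 50, here is a minimal sketch of a warmup-plus-cosine schedule. It assumes 0-indexed steps, a decay horizon equal to --total-steps (500), and a floor of 0; the function name lr_at and the min_lr parameter are hypothetical, and the repo's scheduler may differ in these details.

```python
import math

def lr_at(step: int, peak_lr: float = 1e-6, warmup_steps: int = 20,
          total_steps: int = 500, min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then cosine decay to min_lr (assumed floor)."""
    if step < warmup_steps:
        # Ramp linearly so the last warmup step lands exactly on peak_lr.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for s in (0, 19, 49, 250, 499):
    print(s, f"{lr_at(s):.3e}")
```

Under these assumptions the learning rate at step 49 is still about 99 % of the peak, which matches a curve that looks flat right after warmup.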

The GRPO pipeline in one paragraph

For each prompt, sample G completions from the current policy. Score each with rule-based rewards (format: <think>…</think> + \boxed{…}; accuracy: boxed number vs gold). Subtract the group mean and divide by the group std to get a per-sequence scalar advantage A. Run μ inner PPO epochs with the clipped surrogate -min(ratio·A, clip(ratio, 1±ε)·A), where ratio = π_θ/π_old is the probability ratio against the rollout-time policy, plus a k3 KL penalty β·(r − log r − 1) to a frozen reference copy of the base model, where r = π_ref/π_θ. Save as HF-format safetensors that the inference engine can re-load unchanged. No value network, no learned reward model, no trl.
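
The paragraph above can be sketched for a single response token in plain Python. This is an illustration under stated assumptions, not the repo's batched tensor code: the function names, clip_eps=0.2, beta=0.04, the eps in the denominator, and the use of the population std are all placeholders, and the k3 term takes r = π_ref/π_θ.

```python
import math
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-6):
    """Per-sequence advantage: (reward - group mean) / (group std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std (assumption; sample std is also common)
    return [(r - mu) / (sigma + eps) for r in rewards]

def grpo_token_loss(logp_new, logp_old, logp_ref, advantage,
                    clip_eps=0.2, beta=0.04):
    """Clipped surrogate plus k3 KL penalty for one token (illustrative values)."""
    ratio = math.exp(logp_new - logp_old)                 # pi_theta / pi_old
    clipped = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    pg_loss = -min(ratio * advantage, clipped * advantage)
    log_r = logp_ref - logp_new                           # log(pi_ref / pi_theta)
    kl = math.exp(log_r) - log_r - 1.0                    # k3 estimator, never negative
    return pg_loss + beta * kl, pg_loss, kl

# One group of G=4 completions: two correct, two wrong.
advs = group_advantages([1.0, 0.0, 0.0, 1.0])
# Before any update, the policy, rollout policy, and reference agree on every token.
total, pg, kl = grpo_token_loss(-1.0, -1.0, -1.0, advs[0])
print(advs, total, pg, kl)
```

When all three log-probs agree, the ratio is 1 and the KL term is 0, so the loss reduces to -A. A group whose rewards are all identical gets zero advantage everywhere and contributes no policy gradient.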

Planning a run

Before kicking off any training, eyeball VRAM + wall-clock:

python -m reasoning_model.training.plan_run \
  --prompt-batch 4 --group-size 8 --max-new-tokens 512 \
  --inner-epochs 4 --micro-bsz 2 --total-steps 500

Prints the 14×P rule VRAM budget (weights + grads + fp32 AdamW + ref), plus the dynamic KV cache, training activations, and the often-forgotten (B_micro, T, V) × 4 × 2 fp32 logits transient (this caused the initial OOM and was the #1 fix in the report). Also prints tokens-per-step and a theoretical + realistic wall-clock per step.
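
For a rough feel of the numbers, the budget can be approximated in a few lines. This is a back-of-the-envelope sketch, not the plan_run implementation: the 2 + 2 + 8 + 2 bytes-per-parameter split is an assumed reading of the 14×P rule, and the parameter count, sequence length, and vocabulary size in the example are placeholders (read the real values from the checkpoint config).

```python
GIB = 2 ** 30

def static_bytes(n_params: int) -> int:
    """14 bytes per parameter (assumed split): bf16 weights (2) + bf16 grads (2)
    + fp32 AdamW moments m and v (8) + bf16 frozen reference copy (2)."""
    return 14 * n_params

def logits_transient_bytes(micro_bsz: int, seq_len: int, vocab_size: int) -> int:
    """(B_micro, T, V) elements x 4 bytes (fp32) x 2, as in the formula above."""
    return micro_bsz * seq_len * vocab_size * 4 * 2

# Placeholder inputs: a nominal 0.8B-parameter model, 768 tokens per sequence
# (prompt + response), and a hypothetical 150k-entry vocabulary.
static = static_bytes(800_000_000)
logits = logits_transient_bytes(micro_bsz=2, seq_len=768, vocab_size=150_000)
print(f"static: {static / GIB:.2f} GiB, logits transient: {logits / GIB:.2f} GiB")
```

The logits term scales linearly with micro-batch size, sequence length, and vocabulary size, so --micro-bsz is the most direct knob for shrinking it. KV cache and training activations come on top of both numbers.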

Monitoring

Core metrics logged every step — scan the report for healthy ranges:

| Family | Keys |
|---|---|
| reward | reward/mean, acc_frac, format_frac, std |
| advantage | adv/mean (≈0), adv/std (≈1) |
| PPO health | ratio_mean (≈1), ratio_clip_frac (~0.05), kl_mean (<0.05) |
| loss | pg_loss, kl_loss, total_loss |
| rollout shape | resp_len/mean, resp_len/p95, eos_frac, entropy |
| optim | grad_norm, lr |

Live logging goes to stdout (always on) and, optionally, to wandb (--wandb-project NAME).

Offline plots from a tee'd log:

python -m reasoning_model.training.plot_log runs/grpo-gsm8k/train.log

Qualitative spot-check at any checkpoint:

python -m reasoning_model.training.inspect_rollouts \
  --checkpoint runs/grpo-gsm8k/step-100 \
  --prompt "What is 7 * 8?" --gold 56 --G 4

Held-out eval on a deterministic GSM8K subset:

python -m reasoning_model.training.eval_checkpoint \
  --checkpoint runs/grpo-gsm8k/step-100 --n 50 --seed 0

Health invariants (the single highest-value debug check)

At step 1, verify:

  • ratio_mean ≈ 1.000 — policy identical to rollout-time policy (nothing updated yet).
  • kl_mean ≈ 0 — reference identical to policy (both loaded from the same checkpoint).
  • adv/mean ≈ 0, adv/std ≈ 1 — group normalization is doing its job.

If any of those are off, there's a bug in left-padding + position_ids, log-prob alignment, or reference loading. Test test_compute_response_logprobs_left_pad_invariance catches the most common one.
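
For context on the left-padding bug class: with left-padded batches, position ids have to be derived from the attention mask rather than from the raw column index, otherwise the same prompt gets different positions depending on how much padding its batch-mates force. Below is a minimal pure-Python sketch of one common way to do this; the function name is hypothetical and the repo's tensor implementation may differ.

```python
def position_ids_from_mask(attention_mask):
    """attention_mask: rows of 0/1 ints, with padding (0s) on the left.

    Real tokens are numbered 0, 1, 2, ... from the first unmasked slot;
    pad slots get a dummy 0, which is harmless because they are masked out.
    """
    out = []
    for row in attention_mask:
        seen = 0
        ids = []
        for m in row:
            seen += m                      # running count of real tokens so far
            ids.append(max(seen - 1, 0))
        out.append(ids)
    return out

mask = [
    [0, 0, 1, 1, 1],   # 3-token prompt, left-padded by 2
    [1, 1, 1, 1, 1],   # 5-token prompt, no padding
]
print(position_ids_from_mask(mask))  # [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]
```

On tensors the same idea is usually written as a cumulative sum over the mask, minus one, clamped at zero. The invariance property to test is that a prompt's real-token positions do not change when extra left padding is added.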

Tests

pytest reasoning_model/tests/ -q
# 11 passed

Unit coverage: rewards edge-cases, group normalization, loss signs under ratio/advantage combinations, log-prob shape+grad, left-pad invariance, step→shift coordinate scatter.
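
As an illustration of the kind of reward edge cases worth testing, here is a small sketch of rule-based format and accuracy rewards. It is an assumption-laden stand-in, not the repo's reward code: the function names, the regexes, the 0/1 reward values, and the number normalization (stripping commas and dollar signs) are all hypothetical.

```python
import re

THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
BOXED_RE = re.compile(r"\\boxed\{([^{}]*)\}")   # no nested braces, by design

def format_reward(text: str) -> float:
    r"""1.0 if a <think>...</think> block is followed by a \boxed{...} answer."""
    think = THINK_RE.search(text)
    if think is None:
        return 0.0
    # Only accept a boxed answer that appears after the reasoning block.
    return 1.0 if BOXED_RE.search(text, think.end()) else 0.0

def _to_number(s: str):
    """Parse a numeric string, tolerating thousands separators and '$'."""
    try:
        return float(s.replace(",", "").replace("$", "").strip())
    except ValueError:
        return None

def accuracy_reward(text: str, gold: str) -> float:
    r"""1.0 if the last \boxed{...} value equals gold numerically."""
    matches = BOXED_RE.findall(text)
    if not matches:
        return 0.0
    pred, target = _to_number(matches[-1]), _to_number(gold)
    return 1.0 if pred is not None and pred == target else 0.0

sample = "<think>7 * 8 = 56</think> The answer is \\boxed{56}"
print(format_reward(sample), accuracy_reward(sample, "56"))
```

Edge cases to cover include a missing think block, a boxed answer that appears before the reasoning, non-numeric boxed content, and formatted numbers such as 1,000.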

Experiment reports

Reports live in dev/ (markdown + plots); the GRPO write-up is dev/grpo_training_report.md.

Non-goals (for v1 of the GRPO trainer)

Multi-GPU / DDP, LoRA/PEFT, vLLM rollouts, process reward models, SymPy-based math verifier, async/off-policy replay, HF transformers or trl dependency. Each is a future PR; none blocks v1.

License / provenance

Model implementations follow the public architectures of Qwen3.5 and Gemma4; loader consumes the official HF safetensors checkpoints unchanged. GRPO trainer written from scratch following the DeepSeek-R1 paper.
