Add LenVM inference timing analysis (baseline vs LenVM-guided) #2
Open
ChangyiYang wants to merge 5 commits into
Instruments Sampler.forward and LvmGuidedSampler.apply with a lightweight Python wall-clock timer that flushes one JSONL record per decoding step when SGLANG_LVM_TIMING_LOG is set, no-op otherwise. scripts/inference/lenvm_timing.sh drives two server lifecycles (baseline, then LenVM in-proc) against the same GSM8K prompt set. inference/timing/ contains the client wrapper and the analysis that emits a CSV table plus stacked-bar plots for the sampler-side and apply()-internal breakdowns. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
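The env-gated logging pattern described in this commit can be sketched as follows. This is a minimal illustration, not the contents of timing.py: the class name `StepTimer`, its methods, and the record field names are hypothetical; only the `SGLANG_LVM_TIMING_LOG` variable and the one-JSONL-record-per-step behavior come from the commit message.

```python
import json
import os
import time
from contextlib import contextmanager

ENV_VAR = "SGLANG_LVM_TIMING_LOG"  # env var name taken from the commit message


class StepTimer:
    """Hypothetical per-step wall-clock timer (not the real timing.py)."""

    def __init__(self):
        # Read the env var once; unset or empty disables all timing work.
        self.path = os.environ.get(ENV_VAR) or None
        self.record = {}

    @contextmanager
    def section(self, name):
        if self.path is None:
            # Disabled: run the body with no clock reads at all.
            yield
            return
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed_ms = (time.perf_counter() - start) * 1e3
            key = name + "_ms"
            self.record[key] = self.record.get(key, 0.0) + elapsed_ms

    def flush(self, **extra):
        """Append one JSON line for the current decoding step, then reset."""
        if self.path is None:
            return
        self.record.update(extra)
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps(self.record) + "\n")
        self.record = {}
```

A caller would wrap each phase in `with timer.section("sample"):` and call `timer.flush(step=i)` once per step. Because CUDA kernels launch asynchronously, a CPU wall-clock timer like this one only reflects GPU time where the timed code synchronizes; whether the real logger accounts for that is not stated here.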
Reviewer VVLr asked for an "inference FLOPs" comparison (the paper currently shows wall-clock implications only). flops.py estimates total FLOPs from the standard 2*N_params forward-pass rule for the base model + (1+k) LenVM forward passes per generated token. analyze.py now emits a second table with GFLOPs/token, total PFLOPs, achieved TFLOPs/s, and ratios.
On the 50-question GSM8K run the theoretical FLOPs ratio is 2.17x but the measured wall-clock ratio is 4.59x, so roughly half of the LenVM slowdown comes from GPU underutilization (CPU candidate prep, separate-stream sync) rather than raw extra compute.
Also fixes _summarize_responses to dedupe per-question usage (sample_eval writes the request's usage on every choice row, so summing every row inflated total tokens 16x).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
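The arithmetic behind this commit's comparison can be sketched as below. The function names are hypothetical and the parameter counts in the usage lines are round illustrative values, not numbers read from the checkpoints. The reported 2.17x ratio is over total FLOPs for the whole run, so the per-token figure printed here is not expected to reproduce it exactly.

```python
def two_n_flops_per_token(n_base_params, n_lvm_params, k):
    """2*N rule: one base forward plus (1 + k) LenVM forwards per token."""
    base = 2.0 * n_base_params
    lvm_extra = (1 + k) * 2.0 * n_lvm_params
    return base, lvm_extra


def relative_throughput(flops_ratio, wall_ratio):
    """Achieved FLOPs/s of the guided run divided by the baseline's.

    (F_guided / T_guided) / (F_base / T_base) = flops_ratio / wall_ratio.
    """
    return flops_ratio / wall_ratio


# Illustrative parameter counts (assumed, not taken from config.json).
base, extra = two_n_flops_per_token(7.6e9, 1.5e9, k=5)
print(f"rule-of-thumb per-token FLOPs ratio: {(base + extra) / base:.2f}")

# Ratios reported in the commit message: 2.17x FLOPs, 4.59x wall clock.
print(f"relative throughput: {relative_throughput(2.17, 4.59):.2f}")
```

A relative throughput below 1.0 means the guided run does more work but completes it at a lower achieved FLOPs/s, which is the underutilization the commit message describes.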
Replaces the 2*N rule of thumb with a component-level accounting:
* per-layer linear matmuls: Q/K/V/O projections (GQA-aware via num_key_value_heads) + SwiGLU MLP (gate, up, down).
* per-layer attention compute: 2*H_q*h*seq_len for each of Q@K^T and attn@V, so attention scales with position over the generation trajectory.
* lm_head: 2 * hidden_size * vocab_size, charged per token.
ModelConfig.load(name_or_path) resolves the HF cache or local model dir for config.json and falls back to hardcoded Qwen2.5 dims when neither is present.
Runs are split into prefill (charged once per unique prompt assuming SGLang prefix cache is on) and decode (per sample). LenVM-guided runs add one tree_value_extend + k candidate forwards through the value model per generated token.
analyze.py prints three tables (timing / FLOPs headline / FLOPs by component) and emits a new flops_breakdown.png stacked-bar plot. README updated.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
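A minimal sketch of the component-level accounting listed above, for a single decoded token at context length `seq_len`. It is not the code in flops.py: the `Dims` dataclass and function name are hypothetical, `head_dim = hidden_size / num_heads` is an assumption, and biases, norms, softmax, and rotary embeddings are ignored. The example dimensions are believed to match the public Qwen2.5-1.5B config.json and should be checked against the actual file.

```python
from dataclasses import dataclass


@dataclass
class Dims:
    hidden_size: int
    num_layers: int
    num_heads: int
    num_kv_heads: int
    intermediate_size: int
    vocab_size: int

    @property
    def head_dim(self) -> int:
        # Assumption: head_dim is hidden_size / num_heads.
        return self.hidden_size // self.num_heads


def decode_flops_per_token(d: Dims, seq_len: int) -> dict:
    """Matmul FLOPs for one decoded token, split by component."""
    kv_dim = d.num_kv_heads * d.head_dim
    # Q and O are d x d; K and V shrink to d x kv_dim under GQA.
    proj = 2 * d.hidden_size * (2 * d.hidden_size + 2 * kv_dim)
    # SwiGLU MLP: gate, up, and down projections.
    mlp = 2 * 3 * d.hidden_size * d.intermediate_size
    # 2*H_q*h*seq_len for each of Q@K^T and attn@V.
    attn = 2 * (2 * d.num_heads * d.head_dim * seq_len)
    lm_head = 2 * d.hidden_size * d.vocab_size
    return {
        "proj": d.num_layers * proj,
        "mlp": d.num_layers * mlp,
        "attn": d.num_layers * attn,
        "lm_head": lm_head,
    }


# Assumed Qwen2.5-1.5B dimensions (verify against config.json).
qwen_1p5b = Dims(1536, 28, 12, 2, 8960, 151936)
parts = decode_flops_per_token(qwen_1p5b, seq_len=1000)
print({name: f"{value / 1e9:.3f} GFLOPs" for name, value in parts.items()})
```

Only the `attn` term depends on `seq_len`, which is why the total per-token cost grows over the generation trajectory while the projection, MLP, and head terms stay constant.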
Author: @codex review
Two corrections to the FLOPs estimator after PR review feedback:
1. LenVM checkpoints use MLP2SiLUValueHead (d->d Linear + d->1 Linear, see sglang/srt/models/qwen2_lvm.py::MLP2SiLUValueHead), not the base model's lm_head (d * vocab_size). For Qwen2.5-1.5B that's ~4.7M FLOPs per token instead of ~467M, so the previous accounting was overcharging each LenVM forward by ~15%. ModelConfig.head_type is now "lm_head" / "value_head", and ModelConfig.load(...) autodetects by checking for value_head.safetensors or LengthValueModel-style architectures in config.json.
2. lvm_extra_flops gains a candidate_cost_multiplier=1.0 knob. The default matches the current sglang-LenVM in-proc path where each candidate is a separate single-token forward sharing only the extended KV cache. A future implementation that batches k candidates into one forward and amortizes some compute can pass a value < 1.0.
Also fix _find_config_json to look under ./models/<name>/ so the analyzer can resolve checkpoints downloaded with `hf download --local-dir`.
Re-running on the same 50-q dataset: theoretical LenVM extra drops from 4.48 PFLOPs to 3.82 PFLOPs, ratio drops from 2.30x to 2.10x. Wall-clock ratio is unchanged at 4.55x, so the GPU-utilization gap widens slightly. README gains a TL;DR Results section + caveats covering both knobs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
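The two head costs quoted in this commit can be checked with a few lines of arithmetic. The helper names below are hypothetical, and `extra_per_token` is only a guess at how a `candidate_cost_multiplier` could enter the sum, not the actual `lvm_extra_flops` signature. The Qwen2.5-1.5B dimensions (hidden size 1536, vocabulary 151936) are assumed from the public config.

```python
def lm_head_flops(hidden_size: int, vocab_size: int) -> int:
    """Vocab projection: one d x V matmul, 2 FLOPs per multiply-add."""
    return 2 * hidden_size * vocab_size


def value_head_flops(hidden_size: int) -> int:
    """d->d Linear followed by d->1 Linear; activation and biases ignored."""
    return 2 * (hidden_size * hidden_size + hidden_size)


def extra_per_token(forward_flops: float, k: int,
                    candidate_cost_multiplier: float = 1.0) -> float:
    """One extend forward plus k candidate forwards, optionally discounted.

    Hypothetical illustration of the multiplier; 1.0 charges each
    candidate as a full single-token forward.
    """
    return forward_flops + k * candidate_cost_multiplier * forward_flops


d, vocab = 1536, 151936  # assumed Qwen2.5-1.5B dims
print(f"lm_head:    {lm_head_flops(d, vocab) / 1e6:.1f} MFLOPs per token")
print(f"value_head: {value_head_flops(d) / 1e6:.1f} MFLOPs per token")
```

With these dimensions the two functions give roughly 467M and 4.7M FLOPs per token, matching the figures in the commit message.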
inference/timing/sweep_analyze.py reads multiple lenvm_timing.sh result
directories (one per top-k setting) and emits:
- sweep_summary.csv / sweep_summary.json: per-k row with baseline + LenVM
wall clock, achieved TFLOPs/s, theoretical FLOPs ratio, utilization gap,
and LenVM apply / forward latency means.
- topk_sweep.png: theoretical FLOPs ratio vs measured wall-clock ratio vs
achieved TFLOPs/s ratio, plotted against k.
Run on k ∈ {1,2,3,4,5} with the existing 50 q × 16 sample setup:
k | base_s | lvm_s | wall_ratio | flops_ratio | apply_ms
--+--------+-------+------------+-------------+---------
1 | 24.19 | 23.19 | 0.96 | 1.30 | — (greedy fast path, LenVM skipped)
2 | 23.22 | 80.37 | 3.46 | 1.55 | 45.93
3 | 27.21 | 81.30 | 2.99 | 1.74 | 51.15
4 | 22.30 | 83.53 | 3.75 | 1.91 | 46.23
5 | 18.19 | 80.50 | 4.43 | 2.08 | 48.37
Two takeaways from the sweep:
- top-k=1 with temperature=1.0 hits SGLang's is_all_greedy branch in
Sampler.forward and the LenVM apply() hook is never invoked. The paper
config really starts at k=2.
- From k=2 to k=5 the LenVM apply latency is flat (~46-51 ms/step) and the
LenVM forward latency is flat (~32-34 ms/step), so the measured wall-clock
is essentially constant in k while theoretical FLOPs grows linearly. The
bottleneck is the per-step sync/CPU-prep cost, not the candidate-set
matmul, so increasing k inside this range is roughly free.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
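The ratio column in the sweep table can be recomputed from the two wall-clock columns. The sketch below copies the table values from this commit message; the `wall_over_flops` field is a hypothetical stand-in, since the exact definition of the "utilization gap" column in sweep_summary.csv is not given here.

```python
# (k, base_s, lvm_s, flops_ratio) copied from the sweep table above.
SWEEP = [
    (1, 24.19, 23.19, 1.30),
    (2, 23.22, 80.37, 1.55),
    (3, 27.21, 81.30, 1.74),
    (4, 22.30, 83.53, 1.91),
    (5, 18.19, 80.50, 2.08),
]


def summarize(sweep):
    """Recompute wall_ratio and a wall/FLOPs ratio for each k."""
    out = []
    for k, base_s, lvm_s, flops_ratio in sweep:
        wall_ratio = lvm_s / base_s
        out.append({
            "k": k,
            "wall_ratio": round(wall_ratio, 2),
            # Assumed definition: how much slower than FLOPs alone predict.
            "wall_over_flops": round(wall_ratio / flops_ratio, 2),
        })
    return out


def relative_spread(values):
    """(max - min) / min, a crude measure of run-to-run variation."""
    return (max(values) - min(values)) / min(values)


summary = summarize(SWEEP)
lvm_spread = relative_spread([row[2] for row in SWEEP if row[0] >= 2])
base_spread = relative_spread([row[1] for row in SWEEP if row[0] >= 2])
for row in summary:
    print(row)
print(f"lvm_s spread: {lvm_spread:.1%}, base_s spread: {base_spread:.1%}")
```

Over k = 2..5 the LenVM wall clock varies by only a few percent while the baseline runs vary much more, so the movement in wall_ratio across k appears to come mostly from baseline run-to-run variation rather than from k itself.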
Summary
Adds an end-to-end + per-decoding-step latency comparison between vanilla SGLang sampling and LenVM-guided sampling, motivated by the LenVM paper review's repeated ask for inference overhead numbers (VVLr / Msb9 / iyxs: wall-clock cost, candidate-set / value-scoring breakdown, FLOPs, batching behavior).
- sglang-LenVM/python/sglang/srt/lvm/timing.py: env-gated (SGLANG_LVM_TIMING_LOG) per-step JSONL logger; no-op when unset.
- Timed sections: Sampler.forward (pre-LVM / LVM apply / sample sections) and LvmGuidedSampler.apply (build_pending / LVM forward / apply_guidance).
- scripts/inference/lenvm_timing.sh: orchestrator that runs two server lifecycles (baseline → LenVM in-proc) against the same GSM8K prompt set.
- inference/timing/run_timing.py: wraps sample_eval and records wall-clock + nvidia-smi samples + token-count summary.
- inference/timing/flops.py: layer-level theoretical FLOPs estimator that reads each model's HF config.json and counts Q/K/V/O projections (GQA-aware), SwiGLU MLP, position-dependent attention (Q@K^T + attn@V), and the top-of-stack head separately. ModelConfig.head_type distinguishes the base model's lm_head (2*d*V vocab projection) from a LenVM checkpoint's MLP2SiLUValueHead (d*d + d*1); autodetected by looking for value_head.safetensors or LengthValueModel-style architectures next to config.json. Runs are split into prefill (charged once per unique prompt; assumes SGLang prefix cache is on) and decode (per sample). LenVM-guided runs add one tree_value_extend + k candidate forwards per generated token. candidate_cost_multiplier defaults to 1.0 (current sglang-LenVM in-proc behavior, where each candidate is a separate single-token forward sharing only the extended KV cache); set lower if a future implementation batches the k candidates.
- inference/timing/analyze.py: aggregates both JSONL streams + theoretical FLOPs; emits summary.csv / summary.json, three stdout tables, and three plots.
Results
Setup: 1× H100 SXM, Qwen2.5-7B-Instruct + lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct, GSM8K 50 questions × 16 samples, max_tokens=6000, T=1.0, top-p=1.0, baseline top-k=-1, LenVM top-k=5, value-scale=0 (paper config).
End-to-end + sampler-side per-step
[Table not recovered; surviving labels: Sampler.forward (ms / step), apply() outer.]
LenVM apply() internal decomposition (mean ms / step):
[Table not recovered; surviving row labels: build_pending (CPU candidate + prefix prep), apply_guidance (CPU value → probs writeback).]
Layer-level theoretical FLOPs vs measured wall-clock (k=5)
flops.py reads each model's HF config.json and counts matmuls layer by layer: Q/K/V/O projections, SwiGLU MLP, position-dependent attention (Q@K^T + attn@V), and head (vocab projection for the base model, MLP2SiLUValueHead for the LenVM checkpoint: d*d + d*1, not d*V). LenVM-extra FLOPs = LenVM prefill (per unique prompt) + 1 tree_value_extend + k = 5 candidate forwards per output token (matching LENVM_TOP_K=5 in the run config above).
Component decomposition (PFLOPs):
[Table not recovered.]
The theoretical FLOPs ratio is 2.10× (each generated token drives k=5 LenVM candidate evaluations + 1 LenVM extend in addition to the base forward; ~83% of the LenVM-only FLOPs are candidate forwards through the 1.5B value model with a tiny MLP2SiLUValueHead). The measured wall-clock ratio is 4.55×. Roughly half of the LenVM slowdown is genuine extra compute; the other half is GPU underutilization: the base sampler stream stalls on CPU candidate prep, the LenVM forward runs on a separate CUDA stream with sync overhead, and the LenVM forward's effective batch is smaller than the base model's. Achieved TFLOPs/s drops from ~18% of H100 bf16 peak to ~9%.
Top-k ablation (k ∈ {1,2,3,4,5})
Same setup as above; only LENVM_TOP_K varies.
[Per-k results table not recovered.]
Two findings from the sweep:
- With top_k=1 and temperature=1.0, SGLang takes the is_all_greedy fast path inside Sampler.forward and never enters the LenVM apply() hook (all 1193 recorded decode steps have is_greedy=true, lvm_active=false). The reported "k=1" row above is therefore not LenVM with a one-element candidate set; it is the base 7B running greedy decoding while the LenVM weights sit idle on the GPU. Empirical LenVM measurements start at k ≥ 2.
- The bottleneck in LvmGuidedSampler.apply is per-step sync + CPU candidate prep, not the LenVM forward batch. Practically: paper-configured k=5 is no more expensive in wall clock than k=2, so increasing k inside this range is roughly free.
Aggregator: inference/timing/sweep_analyze.py reads multiple result dirs and emits sweep_summary.csv/json + topk_sweep.png. The cluster-side wrapper that drives the sweep (scripts/_run_timing_topk_sweep.sh) is not in this PR because it bakes in local CUDA paths.
0.5B vs 1.5B LenVM ablation
Re-ran the same k ∈ {1..5} sweep with the smaller lvm-a-qwen2.5-7b-instruct-b-qwen2.5-0.5b-instruct value model to isolate model-size vs sync-overhead contributions.
Per-k LenVM wall-clock for both models (k=1 omitted; degenerate greedy path):
[Table not recovered.]
Findings:
- build_pending + apply_guidance (~15 ms) is unchanged across model sizes, so it takes up a larger share of each step with the smaller model.
Practical implication: shrinking the value model from 1.5B to 0.5B cuts theoretical compute by ~37% but wall-clock by only ~20%, which is not a great deal. The system is bottlenecked by per-step sync and CPU prep, so the cheapest production setup is "larger value model, modest k" if quality holds.
Reviewer questions this addresses
- [...] build_pending = 12.25 ms / step; "value-scoring overhead" → apply_guidance = 5.29 ms / step; "value-scoring frequency" → recorded per-step via lvm_active in *.timing.jsonl.
Out of scope for this PR (follow-up notes)
- [...] e2e - sampler_total; instrumenting ModelRunner.forward_decode or Scheduler.run_batch would split base / LenVM directly).
- [...] LENVM_TOP_K env knob is wired; runs are not).
Test plan
- bash scripts/inference/lenvm_timing.sh end-to-end smoke (3 questions / n=4 / max_tokens=1500): done on 1× H100, produces both JSONL streams + plots.
- [...] 2*N rule (14.24 vs 15.22 GFLOPs/tok baseline; difference is attention compute + lm_head being separated out).
- [...] d*d + d*1) verified against sglang/srt/models/qwen2_lvm.py::MLP2SiLUValueHead.
- [...] SGLANG_LVM_TIMING_LOG unset, confirm timer is a no-op and there is no regression in vanilla SGLang behavior.
🤖 Generated with Claude Code