From f26dbfdf519bae615371872a675417210b27e8fe Mon Sep 17 00:00:00 2001 From: changyiyang Date: Fri, 22 May 2026 20:52:29 +0000 Subject: [PATCH 1/5] Add LenVM inference timing analysis Instruments Sampler.forward and LvmGuidedSampler.apply with a lightweight Python wall-clock timer that flushes one JSONL record per decoding step when SGLANG_LVM_TIMING_LOG is set, no-op otherwise. scripts/inference/lenvm_timing.sh drives two server lifecycles (baseline, then LenVM in-proc) against the same GSM8K prompt set. inference/timing/ contains the client wrapper and the analysis that emits a CSV table plus stacked-bar plots for the sampler-side and apply()-internal breakdowns. Co-Authored-By: Claude Opus 4.7 (1M context) --- inference/timing/README.md | 76 ++++++ inference/timing/__init__.py | 0 inference/timing/analyze.py | 254 ++++++++++++++++++ inference/timing/run_timing.py | 237 ++++++++++++++++ scripts/inference/lenvm_timing.sh | 166 ++++++++++++ .../python/sglang/srt/layers/sampler.py | 21 +- .../sglang/srt/lvm/lvm_guided_sampling.py | 13 + sglang-LenVM/python/sglang/srt/lvm/timing.py | 75 ++++++ 8 files changed, 841 insertions(+), 1 deletion(-) create mode 100644 inference/timing/README.md create mode 100755 inference/timing/__init__.py create mode 100644 inference/timing/analyze.py create mode 100644 inference/timing/run_timing.py create mode 100755 scripts/inference/lenvm_timing.sh create mode 100644 sglang-LenVM/python/sglang/srt/lvm/timing.py diff --git a/inference/timing/README.md b/inference/timing/README.md new file mode 100644 index 0000000..eb201db --- /dev/null +++ b/inference/timing/README.md @@ -0,0 +1,76 @@ +# LenVM inference timing analysis + +End-to-end and per-decoding-step latency decomposition for the LenVM-guided +sampling path, compared against an otherwise-identical baseline SGLang server. 
+ +This exists to answer the question raised in the LenVM paper review: how much +wall-clock overhead does LenVM-guided decoding add on top of vanilla decoding, +and where does that overhead live inside each decoding step? + +## What gets measured + +Two server lifecycles drive the comparison, so per-step instrumentation only +captures the configuration under test: + +1. **baseline** — SGLang with `--enable-lvm-guided-sampling` off. Vanilla + chat-completion sampling against `Qwen/Qwen2.5-7B-Instruct`. +2. **lenvm** — Same base model, plus the in-process LenVM value model + (`./models/namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct` + by default), with `--enable-lvm-guided-sampling --lvm-guided-inproc + --lvm-guided-fn lvm_combined_guidance`. + +Both servers point `SGLANG_LVM_TIMING_LOG` at their own JSONL file. The client +hits each server in turn with the same GSM8K prompt set (50 questions × 16 +samples by default), recording end-to-end wall clock per run. + +The instrumentation lives in `sglang-LenVM/python/sglang/srt/lvm/timing.py` and +hooks `Sampler.forward` plus `LvmGuidedSampler.apply`. Per decoding step it +emits one JSONL line with: + +- `t_sampler_total_ms` — total time inside `Sampler.forward` +- `t_pre_lvm_ms` — preprocess + temperature scaling + softmax +- `t_lvm_apply_outer_ms` — full `LvmGuidedSampler.apply` call + - `t_lvm_build_pending_ms` — gather candidates & request state + - `t_lvm_forward_ms` — LenVM extend + launch + collect + - `t_lvm_apply_guidance_ms` — apply value-based adjustment to probs +- `t_sample_ms` — sampling kernel +- `lvm_active`, `batch_size`, `is_greedy` + +## Running it + +```bash +# from repository root, with .venv-infer and .venv-eval already built +bash scripts/inference/lenvm_timing.sh +``` + +Overridable knobs (env vars; see top of the script for defaults): +`BASE_MODEL`, `LENVM_MODEL`, `MAX_QUESTIONS`, `N_SAMPLES`, `MAX_TOKENS`, +`MAX_CONCURRENCY`, `LENVM_TOP_K`, `LENVM_VALUE_SCALE`, `RESULTS_DIR`. 
+ +The script chains three stages: + +1. Start baseline server → run `inference.timing.run_timing` → kill server +2. Start LenVM server → run `inference.timing.run_timing` → kill server +3. `inference.timing.analyze` reads both JSONL streams + meta files and writes + a `summary.csv`, `summary.json`, and two plots into `RESULTS_DIR`. + +## Outputs (under `RESULTS_DIR`) + +- `baseline.timing.jsonl`, `lenvm.timing.jsonl` — per-step records +- `baseline.meta.json`, `lenvm.meta.json` — wall-clock, token counts, cmdline +- `baseline.gpu_samples.csv`, `lenvm.gpu_samples.csv` — `nvidia-smi` 1 Hz log +- `summary.csv`, `summary.json` — aggregated table +- `per_step_breakdown.png` — stacked bar of sampler-side decomposition +- `lvm_apply_breakdown.png` — LenVM `apply()` internal breakdown + +## Caveats + +- The per-step timer adds Python-level `time.perf_counter()` calls on the + decoding hot path. They are no-ops when `SGLANG_LVM_TIMING_LOG` is unset. +- `t_lvm_apply_outer_ms` covers a few short helpers (e.g. + `_get_inproc_provider`, `clean_stale_requests`) not separately broken out; + at production batch sizes the residual is small but visible at low load. +- Single-rank only. For TP/DP > 1 the timer writes from one rank; extend the + log filename with the rank suffix if you need per-worker traces. +- Wall-clock comparisons assume both servers see the same prompts at the same + concurrency. Run them back-to-back on a quiet GPU. diff --git a/inference/timing/__init__.py b/inference/timing/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/inference/timing/analyze.py b/inference/timing/analyze.py new file mode 100644 index 0000000..2b01e32 --- /dev/null +++ b/inference/timing/analyze.py @@ -0,0 +1,254 @@ +"""Compare baseline vs LenVM-guided timing runs produced by lenvm_timing.sh. 
+ +Inputs (from --results-dir): + baseline.timing.jsonl per-step records from the no-LenVM server + lenvm.timing.jsonl per-step records from the LenVM-enabled server + baseline.meta.json end-to-end wall clock + token counts (run_timing.py) + lenvm.meta.json same, for LenVM run + +Outputs (in --results-dir): + summary.csv one row per setting with e2e + per-step aggregates + summary.json same as CSV but JSON + per_step_breakdown.png stacked bar of per-step decomposition + +The per-step decomposition is what reviewers asked for: how much of a decoding +step is base-model forward (inferred from "step not under sampler"), how much +is sampler-side preprocess, how much is the LenVM forward + guidance overlay. +""" + +from __future__ import annotations + +import argparse +import csv +import json +import math +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + + +_PER_STEP_KEYS = [ + "t_sampler_total_ms", + "t_pre_lvm_ms", + "t_lvm_apply_outer_ms", + "t_lvm_build_pending_ms", + "t_lvm_forward_ms", + "t_lvm_apply_guidance_ms", + "t_sample_ms", +] + + +def _iter_records(path: Path) -> Iterable[Dict[str, Any]]: + if not path.exists(): + return [] + out: List[Dict[str, Any]] = [] + with path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + try: + out.append(json.loads(line)) + except json.JSONDecodeError: + continue + return out + + +def _percentile(values: List[float], p: float) -> float: + if not values: + return float("nan") + values = sorted(values) + k = (len(values) - 1) * p + f = math.floor(k) + c = math.ceil(k) + if f == c: + return values[int(k)] + return values[f] + (values[c] - values[f]) * (k - f) + + +def _filter_warmup(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Drop the small-batch greedy warmup steps SGLang emits at server start + (POST /generate handshake with a single token, bs=1, is_greedy=True).""" + return [r for r in records if not r.get("is_greedy", False)] + + +def 
_agg(records: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate per-step records into mean/p50/p95/sum per key.""" + out: Dict[str, Any] = {"n_steps": len(records)} + for key in _PER_STEP_KEYS: + vals: List[float] = [] + for r in records: + v = r.get(key) + if isinstance(v, (int, float)): + vals.append(float(v)) + if not vals: + continue + out[f"{key}_mean"] = sum(vals) / len(vals) + out[f"{key}_p50"] = _percentile(vals, 0.5) + out[f"{key}_p95"] = _percentile(vals, 0.95) + out[f"{key}_sum"] = sum(vals) + out[f"{key}_count"] = len(vals) + batch_sizes = [int(r["batch_size"]) for r in records if isinstance(r.get("batch_size"), int)] + if batch_sizes: + out["batch_size_mean"] = sum(batch_sizes) / len(batch_sizes) + out["batch_size_p50"] = _percentile(batch_sizes, 0.5) + out["batch_size_p95"] = _percentile(batch_sizes, 0.95) + lvm_active = sum(1 for r in records if r.get("lvm_active")) + out["n_steps_with_lvm"] = lvm_active + return out + + +def _load_meta(path: Path) -> Dict[str, Any]: + if not path.exists(): + return {} + return json.loads(path.read_text()) + + +def _row_for(tag: str, meta: Dict[str, Any], agg: Dict[str, Any]) -> Dict[str, Any]: + row: Dict[str, Any] = { + "tag": tag, + "wall_clock_s": meta.get("wall_clock_s"), + "total_output_tokens": (meta.get("summary") or {}).get("total_output_tokens"), + "throughput_output_tokens_per_s": meta.get("throughput_output_tokens_per_s"), + "n_requests": (meta.get("summary") or {}).get("n_requests"), + "output_tokens_mean": (meta.get("summary") or {}).get("output_tokens_mean"), + "output_tokens_p95": (meta.get("summary") or {}).get("output_tokens_p95"), + "latency_s_mean": (meta.get("summary") or {}).get("latency_s_mean"), + "latency_s_p95": (meta.get("summary") or {}).get("latency_s_p95"), + } + row.update(agg) + return row + + +def _stacked_bar(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except Exception as e: 
# pragma: no cover + print(f"matplotlib unavailable, skipping plot: {e}") + return None + + labels = [r["tag"] for r in rows] + sections = ["t_pre_lvm_ms_mean", "t_lvm_apply_outer_ms_mean", "t_sample_ms_mean"] + legend_names = ["pre-LVM (preprocess+softmax)", "LenVM apply (forward+guidance)", "sample kernel"] + values_per_section = [[float(r.get(s, 0.0) or 0.0) for r in rows] for s in sections] + + fig, ax = plt.subplots(figsize=(7, 5)) + bottoms = [0.0] * len(labels) + for sec_vals, name in zip(values_per_section, legend_names): + ax.bar(labels, sec_vals, bottom=bottoms, label=name) + bottoms = [b + v for b, v in zip(bottoms, sec_vals)] + + for i, r in enumerate(rows): + total = bottoms[i] + ax.text(i, total, f"{total:.2f} ms", ha="center", va="bottom") + + ax.set_ylabel("Mean per-step latency inside Sampler.forward (ms)") + ax.set_title("LenVM vs baseline: sampler-side per-step decomposition") + ax.legend(loc="upper left", fontsize="small") + fig.tight_layout() + fig.savefig(out_png, dpi=140) + plt.close(fig) + return out_png + + +def _lvm_sub_breakdown(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except Exception as e: # pragma: no cover + print(f"matplotlib unavailable, skipping plot: {e}") + return None + + lenvm_row = next((r for r in rows if r.get("n_steps_with_lvm", 0) > 0), None) + if lenvm_row is None: + return None + + sub_keys = ["t_lvm_build_pending_ms_mean", "t_lvm_forward_ms_mean", "t_lvm_apply_guidance_ms_mean"] + sub_names = ["build_pending", "LenVM forward (extend+launch+collect)", "apply_guidance"] + sub_vals = [float(lenvm_row.get(k, 0.0) or 0.0) for k in sub_keys] + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.bar(sub_names, sub_vals) + for i, v in enumerate(sub_vals): + ax.text(i, v, f"{v:.2f} ms", ha="center", va="bottom") + ax.set_ylabel("Mean per-step (ms)") + ax.set_title("LenVM apply() internal breakdown") + fig.tight_layout() + 
fig.savefig(out_png, dpi=140) + plt.close(fig) + return out_png + + +def _print_table(rows: List[Dict[str, Any]]) -> None: + cols = [ + ("tag", "tag"), + ("wall_clock_s", "e2e_s"), + ("throughput_output_tokens_per_s", "tok/s"), + ("output_tokens_mean", "out_tok_mean"), + ("n_steps", "n_steps"), + ("t_sampler_total_ms_mean", "samp_total_ms"), + ("t_pre_lvm_ms_mean", "pre_ms"), + ("t_lvm_apply_outer_ms_mean", "lvm_apply_ms"), + ("t_sample_ms_mean", "sample_ms"), + ] + widths = {k: max(len(label), max(len(_fmt(r.get(k))) for r in rows)) for k, label in cols} + header = " | ".join(f"{label:>{widths[k]}}" for k, label in cols) + print(header) + print("-+-".join("-" * widths[k] for k, _ in cols)) + for r in rows: + print(" | ".join(f"{_fmt(r.get(k)):>{widths[k]}}" for k, _ in cols)) + + +def _fmt(v: Any) -> str: + if v is None: + return "" + if isinstance(v, float): + return f"{v:.2f}" + return str(v) + + +def main() -> int: + p = argparse.ArgumentParser() + p.add_argument("--results-dir", type=Path, required=True) + p.add_argument("--baseline-tag", default="baseline") + p.add_argument("--lenvm-tag", default="lenvm") + args = p.parse_args() + + rd = args.results_dir + rows: List[Dict[str, Any]] = [] + for tag in (args.baseline_tag, args.lenvm_tag): + meta = _load_meta(rd / f"{tag}.meta.json") + records = _filter_warmup(list(_iter_records(rd / f"{tag}.timing.jsonl"))) + agg = _agg(records) + rows.append(_row_for(tag, meta, agg)) + + summary_path = rd / "summary.json" + summary_path.write_text(json.dumps(rows, indent=2)) + + csv_path = rd / "summary.csv" + if rows: + keys = sorted({k for r in rows for k in r.keys()}) + with csv_path.open("w", newline="") as f: + w = csv.DictWriter(f, fieldnames=keys) + w.writeheader() + for r in rows: + w.writerow({k: r.get(k) for k in keys}) + + plot_path = _stacked_bar(rows, rd / "per_step_breakdown.png") + sub_plot_path = _lvm_sub_breakdown(rows, rd / "lvm_apply_breakdown.png") + + print(f"summary.json -> {summary_path}") + 
print(f"summary.csv -> {csv_path}") + if plot_path: + print(f"plot -> {plot_path}") + if sub_plot_path: + print(f"lvm sub-plot -> {sub_plot_path}") + print() + _print_table(rows) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/inference/timing/run_timing.py b/inference/timing/run_timing.py new file mode 100644 index 0000000..311d296 --- /dev/null +++ b/inference/timing/run_timing.py @@ -0,0 +1,237 @@ +"""Run a single sample_eval pass and capture wall-clock + token counts. + +Thin wrapper over inference.tradeoff.sample_eval that adds: +- end-to-end wall-clock timing (perf_counter) +- background nvidia-smi sampling at 1 Hz +- aggregate token counts from the run's responses jsonl +- emits {output_dir}/{tag}.meta.json with the above + +The server-side per-step timing (SGLANG_LVM_TIMING_LOG) is written by the +sglang process itself; this script does not touch that file. +""" + +from __future__ import annotations + +import argparse +import json +import shutil +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + + +def _build_sample_eval_argv(args: argparse.Namespace, tag: str) -> List[str]: + cmd: List[str] = [ + sys.executable, + "-m", + "inference.tradeoff.sample_eval", + "--dataset-name", args.dataset_name, + "--server-url", args.server_url, + "--output-dir", str(args.output_dir), + "--tag", tag, + "--stage", "run", + "--max-questions", str(args.max_questions), + "--max-concurrency", str(args.max_concurrency), + "--request-timeout", str(args.request_timeout), + "--max-tokens", str(args.max_tokens), + "--temperature", str(args.temperature), + "--top-p", str(args.top_p), + "--top-k", str(args.top_k), + "--min-p", str(args.min_p), + "--n", str(args.n), + "--http-backend", args.http_backend, + ] + if args.value_scale is not None: + cmd += ["--value-scale", str(args.value_scale)] + if args.value_mode is not None: + cmd += ["--value-mode", 
args.value_mode] + if args.value_gamma is not None: + cmd += ["--value-gamma", str(args.value_gamma)] + return cmd + + +@dataclass +class GpuSample: + t_offset_s: float + gpu_util_pct: List[int] + mem_used_mib: List[int] + + +def _gpu_sampler(stop_path: Path, output_path: Path, period_s: float = 1.0) -> None: + """nvidia-smi sampler subprocess body (invoked via -c).""" + # Not used; we run nvidia-smi from the parent. + raise SystemExit(0) + + +def _start_gpu_sampler(samples_out: Path) -> Optional[subprocess.Popen]: + if not shutil.which("nvidia-smi"): + return None + # nvidia-smi -lms 1000 streams one CSV sample per second; its stdout is redirected to the file. + # Format: timestamp,index,utilization.gpu,memory.used + fh = open(samples_out, "w") + proc = subprocess.Popen( + [ + "nvidia-smi", + "--query-gpu=timestamp,index,utilization.gpu,memory.used", + "--format=csv,nounits", + "-lms", "1000", + ], + stdout=fh, + stderr=subprocess.DEVNULL, + ) + proc._log_fh = fh # type: ignore[attr-defined] + return proc + + +def _stop_gpu_sampler(proc: Optional[subprocess.Popen]) -> None: + if proc is None: + return + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + fh = getattr(proc, "_log_fh", None) + if fh is not None: + fh.close() + + +def _summarize_responses(responses_jsonl: Path) -> Dict[str, Any]: + """Aggregate token counts and per-question latency from sample_eval output.""" + n_requests = 0 + total_output_tokens = 0 + total_prompt_tokens = 0 + output_token_counts: List[int] = [] + latencies: List[float] = [] + with responses_jsonl.open() as f: + for line in f: + try: + row = json.loads(line) + except json.JSONDecodeError: + continue + n_requests += 1 + usage = row.get("usage") or {} + out_tok = usage.get("completion_tokens") or usage.get("output_tokens") or 0 + in_tok = usage.get("prompt_tokens") or usage.get("input_tokens") or 0 + total_output_tokens += int(out_tok or 0) + total_prompt_tokens += int(in_tok or 0) + if out_tok: + 
output_token_counts.append(int(out_tok)) + lat = row.get("latency_s") or row.get("elapsed_s") + if isinstance(lat, (int, float)): + latencies.append(float(lat)) + summary: Dict[str, Any] = { + "n_requests": n_requests, + "total_output_tokens": total_output_tokens, + "total_prompt_tokens": total_prompt_tokens, + } + if output_token_counts: + output_token_counts.sort() + n = len(output_token_counts) + summary["output_tokens_mean"] = sum(output_token_counts) / n + summary["output_tokens_p50"] = output_token_counts[n // 2] + summary["output_tokens_p95"] = output_token_counts[int(n * 0.95)] + summary["output_tokens_max"] = output_token_counts[-1] + if latencies: + latencies.sort() + n = len(latencies) + summary["latency_s_mean"] = sum(latencies) / n + summary["latency_s_p50"] = latencies[n // 2] + summary["latency_s_p95"] = latencies[int(n * 0.95)] + return summary + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--tag", required=True, help="Output prefix (e.g. 
baseline, lenvm)") + p.add_argument("--server-url", required=True) + p.add_argument("--dataset-name", default="gsm8k") + p.add_argument("--output-dir", type=Path, required=True) + p.add_argument("--max-questions", type=int, default=50) + p.add_argument("--max-concurrency", type=int, default=50) + p.add_argument("--n", type=int, default=16) + p.add_argument("--max-tokens", type=int, default=6000) + p.add_argument("--temperature", type=float, default=1.0) + p.add_argument("--top-p", type=float, default=1.0) + p.add_argument("--top-k", type=int, default=-1) + p.add_argument("--min-p", type=float, default=0.01) + p.add_argument("--request-timeout", type=float, default=600000) + p.add_argument("--http-backend", default="aiohttp") + p.add_argument("--value-scale", type=float, default=None) + p.add_argument("--value-mode", default=None) + p.add_argument("--value-gamma", type=float, default=None) + return p.parse_args() + + +def main() -> int: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + # sample_eval writes responses jsonl using the tag in its own canonical + # naming; we reuse the same tag so the file is predictable. 
+ sample_eval_tag = ( + f"{args.tag}_q{args.max_questions}_n{args.n}_p{args.top_p}_" + f"topk{args.top_k}_minp{args.min_p}" + ) + + cmd = _build_sample_eval_argv(args, tag=sample_eval_tag) + + gpu_samples_path = args.output_dir / f"{args.tag}.gpu_samples.csv" + gpu_proc = _start_gpu_sampler(gpu_samples_path) + + t_start = time.perf_counter() + t_wall_start = time.time() + try: + rc = subprocess.call(cmd) + finally: + _stop_gpu_sampler(gpu_proc) + t_end = time.perf_counter() + t_wall_end = time.time() + + if rc != 0: + print(f"sample_eval exited with rc={rc}", file=sys.stderr) + return rc + + # Match sample_eval's compute_paths: {dataset_name}.{tag}.responses.jsonl + responses_path = args.output_dir / f"{args.dataset_name}.{sample_eval_tag}.responses.jsonl" + summary = _summarize_responses(responses_path) if responses_path.exists() else {} + + wall_clock_s = t_end - t_start + meta: Dict[str, Any] = { + "tag": args.tag, + "server_url": args.server_url, + "dataset": args.dataset_name, + "max_questions": args.max_questions, + "n_samples_per_q": args.n, + "max_concurrency": args.max_concurrency, + "max_tokens": args.max_tokens, + "top_k": args.top_k, + "value_scale": args.value_scale, + "value_mode": args.value_mode, + "value_gamma": args.value_gamma, + "wall_clock_s": wall_clock_s, + "wall_start_epoch_s": t_wall_start, + "wall_end_epoch_s": t_wall_end, + "responses_path": str(responses_path), + "summary": summary, + "cmd": cmd, + } + if summary.get("total_output_tokens"): + meta["throughput_output_tokens_per_s"] = ( + summary["total_output_tokens"] / wall_clock_s if wall_clock_s > 0 else 0 + ) + + meta_path = args.output_dir / f"{args.tag}.meta.json" + meta_path.write_text(json.dumps(meta, indent=2)) + print(f"meta -> {meta_path}") + print(f"wall_clock_s={wall_clock_s:.2f} " + f"out_tokens={summary.get('total_output_tokens', '?')} " + f"throughput={meta.get('throughput_output_tokens_per_s', 0):.1f} tok/s") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git 
a/scripts/inference/lenvm_timing.sh b/scripts/inference/lenvm_timing.sh new file mode 100755 index 0000000..861c7a1 --- /dev/null +++ b/scripts/inference/lenvm_timing.sh @@ -0,0 +1,166 @@ +#!/usr/bin/env bash +# Compare end-to-end and per-step decoding cost: baseline vs LenVM-guided. +# +# Two server lifecycles (so per-step timer only captures the configuration +# under test): +# 1. SGLang with --enable-lvm-guided-sampling OFF -> baseline timing. +# 2. SGLang with LenVM enabled + 7B base + 1.5B LenVM -> guided timing. +# Both servers point SGLANG_LVM_TIMING_LOG at distinct JSONL files. The client +# replays the same GSM8K prompt set against each, then analyze.py emits a +# CSV/JSON table and per-step decomposition plot. +# +# Run from repository root. + +set -euo pipefail + +HOST="${HOST:-0.0.0.0}" +PORT="${PORT:-10020}" +DP_SIZE="${DP_SIZE:-1}" +TP_SIZE="${TP_SIZE:-1}" +CONTEXT_LENGTH="${CONTEXT_LENGTH:-30000}" +MEM_FRACTION_STATIC="${MEM_FRACTION_STATIC:-0.4}" +LENVM_MEM_FRACTION_STATIC="${LENVM_MEM_FRACTION_STATIC:-0.4}" + +BASE_MODEL="${BASE_MODEL:-Qwen/Qwen2.5-7B-Instruct}" +LENVM_MODEL="${LENVM_MODEL:-./models/namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct}" +DATASET="${DATASET:-gsm8k}" +MAX_QUESTIONS="${MAX_QUESTIONS:-50}" +MAX_CONCURRENCY="${MAX_CONCURRENCY:-50}" +N_SAMPLES="${N_SAMPLES:-16}" +MAX_TOKENS="${MAX_TOKENS:-6000}" +TEMPERATURE="${TEMPERATURE:-1.0}" +TOP_P="${TOP_P:-1.0}" +MIN_P="${MIN_P:-0.01}" +LENVM_TOP_K="${LENVM_TOP_K:-5}" +LENVM_VALUE_SCALE="${LENVM_VALUE_SCALE:-0}" +LENVM_VALUE_MODE="${LENVM_VALUE_MODE:-centered_exp}" +LENVM_VALUE_GAMMA="${LENVM_VALUE_GAMMA:-0.997}" + +RESULTS_DIR="${RESULTS_DIR:-./results/timing/$(basename "$BASE_MODEL")_vs_$(basename "$LENVM_MODEL")}" +mkdir -p "$RESULTS_DIR" + +# ---- helpers --------------------------------------------------------------- + +wait_for_server() { + local port="$1" + for _ in $(seq 1 1200); do + if curl -sf "http://127.0.0.1:${port}/v1/models" >/dev/null; then return 0; 
fi + sleep 2 + done + echo "Server on port ${port} failed to become ready" >&2 + return 1 +} + +kill_server() { + local pid="$1" + if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + for _ in $(seq 1 30); do + kill -0 "$pid" 2>/dev/null || return 0 + sleep 1 + done + kill -9 "$pid" 2>/dev/null || true + fi +} + +# ---- stage 1: baseline (LenVM disabled) ------------------------------------ + +echo "==> Stage 1: baseline server (no LenVM)" +source .venv-infer/bin/activate + +BASELINE_TIMING_LOG="$RESULTS_DIR/baseline.timing.jsonl" +: > "$BASELINE_TIMING_LOG" + +SGLANG_LVM_TIMING_LOG="$BASELINE_TIMING_LOG" \ +python -m sglang.launch_server \ + --model-path "$BASE_MODEL" \ + --host "$HOST" \ + --port "$PORT" \ + --tp-size "$TP_SIZE" \ + --dp-size "$DP_SIZE" \ + --context-length "$CONTEXT_LENGTH" \ + --mem-fraction-static "$MEM_FRACTION_STATIC" & +SERVER_PID=$! +trap 'kill_server "$SERVER_PID"' EXIT + +wait_for_server "$PORT" +echo "Baseline server ready" + +source .venv-eval/bin/activate +python -m inference.timing.run_timing \ + --tag baseline \ + --server-url "http://127.0.0.1:$PORT" \ + --dataset-name "$DATASET" \ + --max-questions "$MAX_QUESTIONS" \ + --max-concurrency "$MAX_CONCURRENCY" \ + --n "$N_SAMPLES" \ + --max-tokens "$MAX_TOKENS" \ + --temperature "$TEMPERATURE" \ + --top-p "$TOP_P" \ + --top-k -1 \ + --min-p "$MIN_P" \ + --output-dir "$RESULTS_DIR" + +kill_server "$SERVER_PID" +trap - EXIT +SERVER_PID="" + +# ---- stage 2: LenVM (in-proc guidance) ------------------------------------- + +echo "==> Stage 2: LenVM server (7B base + LenVM in-proc)" +source .venv-infer/bin/activate + +LENVM_TIMING_LOG="$RESULTS_DIR/lenvm.timing.jsonl" +: > "$LENVM_TIMING_LOG" + +SGLANG_LVM_TIMING_LOG="$LENVM_TIMING_LOG" \ +python -m sglang.launch_server \ + --model-path "$BASE_MODEL" \ + --host "$HOST" \ + --port "$PORT" \ + --tp-size "$TP_SIZE" \ + --dp-size "$DP_SIZE" \ + --context-length "$CONTEXT_LENGTH" \ + 
--enable-lvm-guided-sampling \ + --lvm-guided-inproc \ + --lvm-guided-inproc-model-path "$LENVM_MODEL" \ + --lvm-guided-inproc-json-model-override-args '{"architectures":["Qwen2ForLengthValueModel"]}' \ + --disable-overlap-schedule \ + --mem-fraction-static "$MEM_FRACTION_STATIC" \ + --lvm-guided-inproc-mem-fraction-static "$LENVM_MEM_FRACTION_STATIC" \ + --lvm-guided-fn sglang.srt.lvm.lvm_guided_sampling:lvm_combined_guidance & +SERVER_PID=$! +trap 'kill_server "$SERVER_PID"' EXIT + +wait_for_server "$PORT" +echo "LenVM server ready" + +source .venv-eval/bin/activate +python -m inference.timing.run_timing \ + --tag lenvm \ + --server-url "http://127.0.0.1:$PORT" \ + --dataset-name "$DATASET" \ + --max-questions "$MAX_QUESTIONS" \ + --max-concurrency "$MAX_CONCURRENCY" \ + --n "$N_SAMPLES" \ + --max-tokens "$MAX_TOKENS" \ + --temperature "$TEMPERATURE" \ + --top-p "$TOP_P" \ + --top-k "$LENVM_TOP_K" \ + --min-p "$MIN_P" \ + --value-scale "$LENVM_VALUE_SCALE" \ + --value-mode "$LENVM_VALUE_MODE" \ + --value-gamma "$LENVM_VALUE_GAMMA" \ + --output-dir "$RESULTS_DIR" + +kill_server "$SERVER_PID" +trap - EXIT +SERVER_PID="" + +# ---- stage 3: analyze ------------------------------------------------------ + +echo "==> Stage 3: analyze" +python -m inference.timing.analyze \ + --results-dir "$RESULTS_DIR" + +echo "Done. 
Results in $RESULTS_DIR" diff --git a/sglang-LenVM/python/sglang/srt/layers/sampler.py b/sglang-LenVM/python/sglang/srt/layers/sampler.py index ccd4bfd..7301561 100644 --- a/sglang-LenVM/python/sglang/srt/layers/sampler.py +++ b/sglang-LenVM/python/sglang/srt/layers/sampler.py @@ -17,6 +17,7 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda, is_npu from sglang.srt.lvm.lvm_guided_sampling import LvmGuidedSampler +from sglang.srt.lvm.timing import get_timer if is_cuda(): from sgl_kernel import ( @@ -45,6 +46,7 @@ def __init__(self, model_runner=None): self.lvm_guided_sampler = LvmGuidedSampler.from_server_args( get_global_server_args(), model_runner=model_runner ) + self._timer = get_timer() if is_dp_attention_enabled(): self.tp_sync_group = get_attention_tp_group().device_group @@ -134,6 +136,10 @@ def forward( positions: The positions of the tokens in the sequence. Used for deterministic sampling to get the unique seed for each position. 
""" + t_total = self._timer.section_start("t_sampler_total_ms") + guided_applied = False + batch_size_meta = int(logits_output.next_token_logits.shape[0]) + logits = logits_output.next_token_logits # Preprocess logits (custom processors and NaN handling) @@ -145,6 +151,7 @@ def forward( if return_logprob: logprobs = torch.nn.functional.log_softmax(logits, dim=-1) else: + t_pre = self._timer.section_start("t_pre_lvm_ms") can_sample_directly_from_probs = ( not sampling_info.need_top_p_sampling and not sampling_info.need_top_k_sampling @@ -181,7 +188,9 @@ def forward( probs = logits del logits - guided_applied = False + self._timer.section_end("t_pre_lvm_ms", t_pre) + + t_lvm = self._timer.section_start("t_lvm_apply_outer_ms") if self.lvm_guided_sampler is not None and sampling_info.reqs is not None: guided_probs = self.lvm_guided_sampler.apply( probs, @@ -194,10 +203,12 @@ def forward( if guided_probs is not None: probs = guided_probs guided_applied = True + self._timer.section_end("t_lvm_apply_outer_ms", t_lvm) if guided_applied: can_sample_directly_from_probs = True + t_sample = self._timer.section_start("t_sample_ms") if can_sample_directly_from_probs: # when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs batch_next_token_ids = sampling_from_probs_torch( @@ -244,6 +255,7 @@ def forward( raise ValueError( f"Invalid sampling backend: {get_global_server_args().sampling_backend}" ) + self._timer.section_end("t_sample_ms", t_sample) if return_logprob: if get_global_server_args().rl_on_policy_target is not None: @@ -291,6 +303,13 @@ def forward( group=self.tp_sync_group, ) + self._timer.section_end("t_sampler_total_ms", t_total) + self._timer.set_meta( + lvm_active=bool(guided_applied), + batch_size=batch_size_meta, + is_greedy=bool(sampling_info.is_all_greedy), + ) + self._timer.flush_step() return batch_next_token_ids def compute_logprobs_only( diff --git a/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py 
b/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py index ea51495..56881b2 100644 --- a/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py +++ b/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py @@ -15,6 +15,7 @@ from sglang.srt.utils.common import dynamic_import from sglang.srt.server_args import get_global_server_args from sglang.srt.lvm.lvm_value_utils import force_eos_value_zero, get_eos_token_ids +from sglang.srt.lvm.timing import get_timer from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.schedule_batch import Req @@ -1765,37 +1766,49 @@ def apply( Returns the modified probs tensor, or None when no guidance is needed (caller should use the original probs). """ + timer = get_timer() inproc = self._get_inproc_provider() if inproc not in (None, False): # Free KV cache for requests that have finished or aborted inproc.clean_stale_requests(set(r.rid for r in reqs)) + t_bp = timer.section_start("t_lvm_build_pending_ms") pending = self._build_pending(probs, reqs, temperatures, top_ps, top_ks, min_ps) + timer.section_end("t_lvm_build_pending_ms", t_bp) if pending is None: return None if pending.send_batch_indices: reqs_send = [pending.req_list[i] for i in pending.send_batch_indices] + timer.set_meta(lvm_n_reqs_with_guidance=len(reqs_send)) if pending.gpu_candidates is not None: # GPU fast path (synchronous). 
if inproc not in (None, False): try: rids_send = [req.rid for req in reqs_send] + t_fwd = timer.section_start("t_lvm_forward_ms") inproc.tree_value_extend(rids_send, pending.prefix_ids_send, reqs_send) gpu_emb = inproc.tree_value_launch_gpu( rids_send, pending.candidate_ids_send, gpu_candidates=pending.gpu_candidates ) gpu_embeddings = inproc.tree_value_collect_gpu(gpu_emb) + timer.section_end("t_lvm_forward_ms", t_fwd) + t_ag = timer.section_start("t_lvm_apply_guidance_ms") self._apply_guidance_gpu(pending, gpu_embeddings) + timer.section_end("t_lvm_apply_guidance_ms", t_ag) return pending.guided except Exception as exc: raise RuntimeError("LenVM GPU guidance path failed") from exc + t_fwd = timer.section_start("t_lvm_forward_ms") lvm_values = self._post_tree_value( [req.rid for req in reqs_send], pending.prefix_ids_send, pending.candidate_ids_send, reqs_send ) + timer.section_end("t_lvm_forward_ms", t_fwd) if lvm_values is None: return None + t_ag = timer.section_start("t_lvm_apply_guidance_ms") self._apply_guidance(pending, lvm_values) + timer.section_end("t_lvm_apply_guidance_ms", t_ag) return pending.guided diff --git a/sglang-LenVM/python/sglang/srt/lvm/timing.py b/sglang-LenVM/python/sglang/srt/lvm/timing.py new file mode 100644 index 0000000..b6e81d3 --- /dev/null +++ b/sglang-LenVM/python/sglang/srt/lvm/timing.py @@ -0,0 +1,75 @@ +"""Lightweight per-step Python wall-clock timer for LenVM-guided decoding. + +Activated by environment variable SGLANG_LVM_TIMING_LOG=/path/to/timing.jsonl. +When unset, all timer operations are no-ops with negligible overhead. + +Captures wall-clock (time.perf_counter) for sections within Sampler.forward and +LvmGuidedSampler.apply. Each scheduler step flushes one JSONL record with the +section durations and metadata. Records are flushed line-by-line so a tail-f +during a run is safe. + +The collector lives in the GPU worker process. 
For DP/TP>1 it would need a +per-rank suffix on the log path; current scope is single-rank smoke testing. +""" + +from __future__ import annotations + +import json +import os +import threading +import time +from typing import Optional + + +class _Timer: + def __init__(self) -> None: + log_path = os.environ.get("SGLANG_LVM_TIMING_LOG") + self.enabled = bool(log_path) + self._log_path: Optional[str] = log_path + self._fh = None + self._step_id = 0 + self._lock = threading.Lock() + self._current_step: dict = {} + if self.enabled: + os.makedirs(os.path.dirname(self._log_path) or ".", exist_ok=True) + self._fh = open(self._log_path, "a", buffering=1) + + def section_start(self, name: str) -> Optional[float]: + if not self.enabled: + return None + return time.perf_counter() + + def section_end(self, name: str, start: Optional[float]) -> None: + if not self.enabled or start is None: + return + elapsed_ms = (time.perf_counter() - start) * 1000.0 + # Accumulate in case a section is entered multiple times per step. 
+ prev = self._current_step.get(name, 0.0) + self._current_step[name] = prev + elapsed_ms + + def set_meta(self, **kwargs) -> None: + if not self.enabled: + return + self._current_step.update(kwargs) + + def flush_step(self) -> None: + """Write one JSONL line for the current step and reset accumulator.""" + if not self.enabled or self._fh is None: + return + if not self._current_step: + return + with self._lock: + self._step_id += 1 + record = {"step": self._step_id, **self._current_step} + self._fh.write(json.dumps(record) + "\n") + self._current_step.clear() + + +_INSTANCE: Optional[_Timer] = None + + +def get_timer() -> _Timer: + global _INSTANCE + if _INSTANCE is None: + _INSTANCE = _Timer() + return _INSTANCE From dc93a52fc1bc8971ee7f7ac8d8b54ebb73a804a5 Mon Sep 17 00:00:00 2001 From: changyiyang Date: Fri, 22 May 2026 21:10:10 +0000 Subject: [PATCH 2/5] Add theoretical FLOPs estimate to timing analysis Reviewer VVLr asked for an "inference FLOPs" comparison (paper currently shows wall-clock implications only). flops.py estimates total FLOPs from the standard 2*N_params forward-pass rule for the base model + (1+k) LenVM forward passes per generated token. analyze.py now emits a second table with GFLOPs/token, total PFLOPs, achieved TFLOPs/s, and ratios. On the 50-question GSM8K run the theoretical FLOPs ratio is 2.17x but the measured wall-clock ratio is 4.59x, so half of the LenVM slowdown comes from GPU underutilization (CPU candidate prep, separate-stream sync) rather than raw extra compute. Also fixes _summarize_responses to dedupe per-question usage (sample_eval writes the request's usage on every choice row, so summing every row inflated total tokens 16x). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- inference/timing/README.md | 10 ++- inference/timing/analyze.py | 117 ++++++++++++++++++++++++++---- inference/timing/flops.py | 77 ++++++++++++++++++++ inference/timing/run_timing.py | 26 +++++-- scripts/inference/lenvm_timing.sh | 4 +- 5 files changed, 209 insertions(+), 25 deletions(-) create mode 100644 inference/timing/flops.py diff --git a/inference/timing/README.md b/inference/timing/README.md index eb201db..d931b0d 100644 --- a/inference/timing/README.md +++ b/inference/timing/README.md @@ -36,6 +36,14 @@ emits one JSONL line with: - `t_sample_ms` — sampling kernel - `lvm_active`, `batch_size`, `is_greedy` +Theoretical FLOPs are computed by `analyze.py` from the standard `2 * N_params` +forward-pass rule, using base model + LenVM checkpoint sizes (`flops.py` +hardcodes the Qwen2.5 family; pass `--base-model` / `--lvm-model` to swap). +For LenVM the per-output-token cost is `2 * N_base + 2 * N_lvm * (1 + k)`, where +`k` comes from `top_k` in `meta.json`. Contrasting the theoretical FLOPs ratio +with the measured wall-clock ratio shows how much of the slowdown is raw +compute increase vs GPU underutilization. + ## Running it ```bash @@ -59,7 +67,7 @@ The script chains three stages: - `baseline.timing.jsonl`, `lenvm.timing.jsonl` — per-step records - `baseline.meta.json`, `lenvm.meta.json` — wall-clock, token counts, cmdline - `baseline.gpu_samples.csv`, `lenvm.gpu_samples.csv` — `nvidia-smi` 1 Hz log -- `summary.csv`, `summary.json` — aggregated table +- `summary.csv`, `summary.json` — aggregated table (incl. 
theoretical FLOPs, achieved TFLOPs/s, ratio row) - `per_step_breakdown.png` — stacked bar of sampler-side decomposition - `lvm_apply_breakdown.png` — LenVM `apply()` internal breakdown diff --git a/inference/timing/analyze.py b/inference/timing/analyze.py index 2b01e32..c56bcb2 100644 --- a/inference/timing/analyze.py +++ b/inference/timing/analyze.py @@ -25,6 +25,8 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional +from inference.timing.flops import baseline_run_flops, lvm_run_flops, resolve_params + _PER_STEP_KEYS = [ "t_sampler_total_ms", @@ -104,21 +106,70 @@ def _load_meta(path: Path) -> Dict[str, Any]: def _row_for(tag: str, meta: Dict[str, Any], agg: Dict[str, Any]) -> Dict[str, Any]: + summary = meta.get("summary") or {} row: Dict[str, Any] = { "tag": tag, "wall_clock_s": meta.get("wall_clock_s"), - "total_output_tokens": (meta.get("summary") or {}).get("total_output_tokens"), + "total_output_tokens": summary.get("total_output_tokens"), + "total_prompt_tokens": summary.get("total_prompt_tokens"), "throughput_output_tokens_per_s": meta.get("throughput_output_tokens_per_s"), - "n_requests": (meta.get("summary") or {}).get("n_requests"), - "output_tokens_mean": (meta.get("summary") or {}).get("output_tokens_mean"), - "output_tokens_p95": (meta.get("summary") or {}).get("output_tokens_p95"), - "latency_s_mean": (meta.get("summary") or {}).get("latency_s_mean"), - "latency_s_p95": (meta.get("summary") or {}).get("latency_s_p95"), + "n_requests": summary.get("n_requests"), + "output_tokens_mean": summary.get("output_tokens_mean"), + "output_tokens_p95": summary.get("output_tokens_p95"), + "latency_s_mean": summary.get("latency_s_mean"), + "latency_s_p95": summary.get("latency_s_p95"), + "value_scale": meta.get("value_scale"), + "top_k": meta.get("top_k"), } row.update(agg) return row +def _add_flops( + row: Dict[str, Any], + *, + base_params_b: Optional[float], + lvm_params_b: Optional[float], + is_lvm: bool, +) -> None: + 
"""Attach theoretical FLOPs + achieved-FLOPs/sec columns to row.""" + if base_params_b is None: + return + prompt_t = int(row.get("total_prompt_tokens") or 0) + output_t = int(row.get("total_output_tokens") or 0) + if not (prompt_t or output_t): + return + if is_lvm: + if lvm_params_b is None: + return + k = row.get("top_k") + if k is None or int(k) < 1: + return + total = lvm_run_flops(prompt_t, output_t, base_params_b, lvm_params_b, int(k)) + per_token = (2 * base_params_b + 2 * lvm_params_b * (1 + int(k))) * 1e9 + else: + total = baseline_run_flops(prompt_t, output_t, base_params_b) + per_token = 2 * base_params_b * 1e9 + row["theoretical_gflops_per_output_token"] = per_token / 1e9 + row["theoretical_pflops_total"] = total / 1e15 + wall = row.get("wall_clock_s") or 0 + if wall > 0: + row["achieved_tflops_per_s"] = total / wall / 1e12 + + +def _ratio_row(rows: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute lenvm / baseline ratios for numeric columns. Assumes rows[0]=baseline, rows[1]=lenvm.""" + if len(rows) != 2: + return {} + base, lvm = rows + ratio: Dict[str, Any] = {"tag": "ratio (lvm/base)"} + for key, b_val in base.items(): + l_val = lvm.get(key) + if isinstance(b_val, (int, float)) and isinstance(l_val, (int, float)) and b_val: + ratio[key] = l_val / b_val + return ratio + + def _stacked_bar(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: try: import matplotlib @@ -181,8 +232,8 @@ def _lvm_sub_breakdown(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Pa return out_png -def _print_table(rows: List[Dict[str, Any]]) -> None: - cols = [ +def _print_tables(rows: List[Dict[str, Any]]) -> None: + timing_cols = [ ("tag", "tag"), ("wall_clock_s", "e2e_s"), ("throughput_output_tokens_per_s", "tok/s"), @@ -193,6 +244,20 @@ def _print_table(rows: List[Dict[str, Any]]) -> None: ("t_lvm_apply_outer_ms_mean", "lvm_apply_ms"), ("t_sample_ms_mean", "sample_ms"), ] + flops_cols = [ + ("tag", "tag"), + ("theoretical_gflops_per_output_token", 
"GFLOPs/tok"), + ("theoretical_pflops_total", "PFLOPs(total)"), + ("achieved_tflops_per_s", "TFLOPs/s"), + ("wall_clock_s", "e2e_s"), + ] + _print_one(rows, timing_cols) + if any("theoretical_gflops_per_output_token" in r for r in rows): + print() + _print_one(rows, flops_cols) + + +def _print_one(rows: List[Dict[str, Any]], cols: List[tuple]) -> None: widths = {k: max(len(label), max(len(_fmt(r.get(k))) for r in rows)) for k, label in cols} header = " | ".join(f"{label:>{widths[k]}}" for k, label in cols) print(header) @@ -214,26 +279,48 @@ def main() -> int: p.add_argument("--results-dir", type=Path, required=True) p.add_argument("--baseline-tag", default="baseline") p.add_argument("--lenvm-tag", default="lenvm") + p.add_argument( + "--base-model", + default="Qwen/Qwen2.5-7B-Instruct", + help="Base generation model name (for theoretical FLOPs lookup).", + ) + p.add_argument( + "--lvm-model", + default="namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct", + help="LenVM checkpoint name (for theoretical FLOPs lookup).", + ) args = p.parse_args() + base_params_b = resolve_params(args.base_model) + lvm_params_b = resolve_params(args.lvm_model) + if base_params_b is None: + print(f"warning: unknown base model '{args.base_model}', skipping FLOPs") + if lvm_params_b is None: + print(f"warning: unknown LenVM model '{args.lvm_model}', skipping LenVM FLOPs") + rd = args.results_dir rows: List[Dict[str, Any]] = [] - for tag in (args.baseline_tag, args.lenvm_tag): + for tag, is_lvm in ((args.baseline_tag, False), (args.lenvm_tag, True)): meta = _load_meta(rd / f"{tag}.meta.json") records = _filter_warmup(list(_iter_records(rd / f"{tag}.timing.jsonl"))) agg = _agg(records) - rows.append(_row_for(tag, meta, agg)) + row = _row_for(tag, meta, agg) + _add_flops(row, base_params_b=base_params_b, lvm_params_b=lvm_params_b, is_lvm=is_lvm) + rows.append(row) + + ratio = _ratio_row(rows) + display_rows = rows + ([ratio] if ratio else []) summary_path = rd / 
"summary.json" - summary_path.write_text(json.dumps(rows, indent=2)) + summary_path.write_text(json.dumps(display_rows, indent=2)) csv_path = rd / "summary.csv" - if rows: - keys = sorted({k for r in rows for k in r.keys()}) + if display_rows: + keys = sorted({k for r in display_rows for k in r.keys()}) with csv_path.open("w", newline="") as f: w = csv.DictWriter(f, fieldnames=keys) w.writeheader() - for r in rows: + for r in display_rows: w.writerow({k: r.get(k) for k in keys}) plot_path = _stacked_bar(rows, rd / "per_step_breakdown.png") @@ -246,7 +333,7 @@ def main() -> int: if sub_plot_path: print(f"lvm sub-plot -> {sub_plot_path}") print() - _print_table(rows) + _print_tables(display_rows) return 0 diff --git a/inference/timing/flops.py b/inference/timing/flops.py new file mode 100644 index 0000000..2fe244d --- /dev/null +++ b/inference/timing/flops.py @@ -0,0 +1,77 @@ +"""Theoretical inference FLOPs estimator for LenVM vs baseline. + +Uses the standard 2*N forward rule: a transformer forward pass over T tokens +costs ~2 * N_params * T FLOPs (matmul-dominated; ignores attention's O(S) +contribution at long context, fine for ballpark inference-overhead numbers). + +Contrasting theoretical FLOPs ratio against measured wall-clock ratio +isolates GPU utilization loss from raw compute increase. Reviewers in PR #2 +asked specifically for "inference FLOPs", so we surface both. +""" + +from __future__ import annotations + +from typing import Optional + + +# Well-known parameter counts (params, including embeddings/lm_head) for the +# Qwen2.5 family used in the LenVM paper. Update via --base-model-params / +# --lvm-model-params CLI args if you swap base/LenVM checkpoints. 
+MODEL_PARAMS_BILLIONS = { + "Qwen/Qwen2.5-0.5B-Instruct": 0.49, + "Qwen/Qwen2.5-1.5B-Instruct": 1.54, + "Qwen/Qwen2.5-3B-Instruct": 3.09, + "Qwen/Qwen2.5-7B-Instruct": 7.61, + "Qwen/Qwen2.5-14B-Instruct": 14.8, + "Qwen/Qwen2.5-32B-Instruct": 32.5, + "Qwen/Qwen2.5-72B-Instruct": 72.7, + # LenVM checkpoints are base-model + a thin value head; value head is + # negligible (~1M params), so we charge the LenVM forward at the base + # model's compute cost. + "namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct": 1.54, + "namezz/lvm-instruct-0327-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct": 1.54, + "namezz/lvm-a-qwen2.5-7b-instruct-b-qwen2.5-0.5b-instruct": 0.49, + "namezz/lvm-rel-a-qwen2.5-3b-instruct-b-qwen2.5-3b-instruct": 3.09, +} + + +def resolve_params(model_name_or_path: str) -> Optional[float]: + """Return param count (in billions) for a model name or local path.""" + if model_name_or_path in MODEL_PARAMS_BILLIONS: + return MODEL_PARAMS_BILLIONS[model_name_or_path] + # Local paths like ./models/namezz/...; match by basename. + for key, val in MODEL_PARAMS_BILLIONS.items(): + if model_name_or_path.endswith(key) or model_name_or_path.endswith(key.split("/")[-1]): + return val + return None + + +def forward_flops_per_token(n_params_billions: float) -> float: + """2 * N params per token (decode-only forward, matmul-dominated).""" + return 2.0 * n_params_billions * 1e9 + + +def baseline_run_flops(prompt_tokens: int, output_tokens: int, base_params_b: float) -> float: + """Vanilla decoding: one base-model forward over every prompt + generated token.""" + return forward_flops_per_token(base_params_b) * (prompt_tokens + output_tokens) + + +def lvm_run_flops( + prompt_tokens: int, + output_tokens: int, + base_params_b: float, + lvm_params_b: float, + k_candidates: int, +) -> float: + """LenVM-guided decoding total FLOPs. 
+ + Per generated token, LenVM adds: + - 1 forward over the LenVM (extend KV by the just-accepted token) + - k forwards over the LenVM (score the top-k candidates) + + Prefill on the LenVM happens lazily as candidates are scored, so we + fold it into output_tokens for an upper-bound estimate. + """ + base = baseline_run_flops(prompt_tokens, output_tokens, base_params_b) + lvm_per_step = forward_flops_per_token(lvm_params_b) * (1 + k_candidates) + return base + lvm_per_step * output_tokens diff --git a/inference/timing/run_timing.py b/inference/timing/run_timing.py index 311d296..ead6e34 100644 --- a/inference/timing/run_timing.py +++ b/inference/timing/run_timing.py @@ -100,12 +100,19 @@ def _stop_gpu_sampler(proc: Optional[subprocess.Popen]) -> None: def _summarize_responses(responses_jsonl: Path) -> Dict[str, Any]: - """Aggregate token counts and per-question latency from sample_eval output.""" + """Aggregate token counts and per-question latency from sample_eval output. + + sample_eval writes one row per choice but duplicates the full-request + `usage` field across every choice's row, so we dedupe by question `idx` + (counting each request's usage once) before summing. `n_requests` and the + per-choice latency stats still sample every row. 
+ """ n_requests = 0 total_output_tokens = 0 total_prompt_tokens = 0 output_token_counts: List[int] = [] latencies: List[float] = [] + seen_idx: set = set() with responses_jsonl.open() as f: for line in f: try: @@ -113,13 +120,16 @@ def _summarize_responses(responses_jsonl: Path) -> Dict[str, Any]: except json.JSONDecodeError: continue n_requests += 1 - usage = row.get("usage") or {} - out_tok = usage.get("completion_tokens") or usage.get("output_tokens") or 0 - in_tok = usage.get("prompt_tokens") or usage.get("input_tokens") or 0 - total_output_tokens += int(out_tok or 0) - total_prompt_tokens += int(in_tok or 0) - if out_tok: - output_token_counts.append(int(out_tok)) + idx = row.get("idx") + if idx is not None and idx not in seen_idx: + seen_idx.add(idx) + usage = row.get("usage") or {} + out_tok = usage.get("completion_tokens") or usage.get("output_tokens") or 0 + in_tok = usage.get("prompt_tokens") or usage.get("input_tokens") or 0 + total_output_tokens += int(out_tok or 0) + total_prompt_tokens += int(in_tok or 0) + if out_tok: + output_token_counts.append(int(out_tok)) lat = row.get("latency_s") or row.get("elapsed_s") if isinstance(lat, (int, float)): latencies.append(float(lat)) diff --git a/scripts/inference/lenvm_timing.sh b/scripts/inference/lenvm_timing.sh index 861c7a1..0bf660b 100755 --- a/scripts/inference/lenvm_timing.sh +++ b/scripts/inference/lenvm_timing.sh @@ -161,6 +161,8 @@ SERVER_PID="" echo "==> Stage 3: analyze" python -m inference.timing.analyze \ - --results-dir "$RESULTS_DIR" + --results-dir "$RESULTS_DIR" \ + --base-model "$BASE_MODEL" \ + --lvm-model "$LENVM_MODEL" echo "Done. 
Results in $RESULTS_DIR" From 497f75b826633cb25fd0226c95a23f8b06270afb Mon Sep 17 00:00:00 2001 From: changyiyang Date: Fri, 22 May 2026 21:28:16 +0000 Subject: [PATCH 3/5] Refactor FLOPs estimator to read HF config.json layer-by-layer Replaces the 2*N rule of thumb with a component-level accounting: * per-layer linear matmuls: Q/K/V/O projections (GQA-aware via num_key_value_heads) + SwiGLU MLP (gate, up, down). * per-layer attention compute: 2*H_q*h*seq_len for each of Q@K^T and attn@V, so attention scales with position over the generation trajectory. * lm_head: 2 * hidden_size * vocab_size, charged per token. ModelConfig.load(name_or_path) resolves the HF cache or local model dir for config.json and falls back to hardcoded Qwen2.5 dims when neither is present. Runs are split into prefill (charged once per unique prompt assuming SGLang prefix cache is on) and decode (per sample). LenVM-guided runs add one tree_value_extend + k candidate forwards through the value model per generated token. analyze.py prints three tables (timing / FLOPs headline / FLOPs by component) and emits a new flops_breakdown.png stacked-bar plot. README updated. Co-Authored-By: Claude Opus 4.7 (1M context) --- inference/timing/README.md | 28 +++- inference/timing/analyze.py | 184 ++++++++++++++++++----- inference/timing/flops.py | 288 ++++++++++++++++++++++++++++-------- 3 files changed, 399 insertions(+), 101 deletions(-) diff --git a/inference/timing/README.md b/inference/timing/README.md index d931b0d..64dd838 100644 --- a/inference/timing/README.md +++ b/inference/timing/README.md @@ -36,13 +36,25 @@ emits one JSONL line with: - `t_sample_ms` — sampling kernel - `lvm_active`, `batch_size`, `is_greedy` -Theoretical FLOPs are computed by `analyze.py` from the standard `2 * N_params` -forward-pass rule, using base model + LenVM checkpoint sizes (`flops.py` -hardcodes the Qwen2.5 family; pass `--base-model` / `--lvm-model` to swap). 
-For LenVM the per-output-token cost is `2 * N_base + 2 * N_lvm * (1 + k)`, where -`k` comes from `top_k` in `meta.json`. Contrasting the theoretical FLOPs ratio -with the measured wall-clock ratio shows how much of the slowdown is raw -compute increase vs GPU underutilization. +Theoretical FLOPs are computed by `analyze.py` at the layer level. `flops.py` +loads each model's `config.json` (HF cache or local dir; falls back to +hardcoded Qwen2.5 dims if missing) and counts: + +- per-layer linear matmuls: Q / K / V / O projections (GQA-aware) + SwiGLU MLP + (gate + up + down) +- per-layer attention compute: `2 * H_q * head_dim * seq_len` for each of + Q@K^T and attn@V (so attention contribution scales with position) +- `lm_head`: `2 * hidden * vocab` + +A baseline run is split into prefill (charged once per unique prompt, since +SGLang prefix caching is on by default) and decode (per sample). A LenVM run +adds, per generated token, one `tree_value_extend` forward plus `k` +candidate forwards through the value model. The analyzer reports both the +total PFLOPs and a per-component split so the linear / attention / lm_head +shares of the baseline are visible alongside the LenVM-specific overhead. +Contrasting the theoretical FLOPs ratio with the measured wall-clock ratio +shows how much of the slowdown is raw compute increase vs GPU +underutilization. ## Running it @@ -70,6 +82,8 @@ The script chains three stages: - `summary.csv`, `summary.json` — aggregated table (incl. 
theoretical FLOPs, achieved TFLOPs/s, ratio row) - `per_step_breakdown.png` — stacked bar of sampler-side decomposition - `lvm_apply_breakdown.png` — LenVM `apply()` internal breakdown +- `flops_breakdown.png` — stacked bar of theoretical FLOPs by component + (base linear / attention / lm_head + LenVM extend / candidates / prefill) ## Caveats diff --git a/inference/timing/analyze.py b/inference/timing/analyze.py index c56bcb2..a4df2fb 100644 --- a/inference/timing/analyze.py +++ b/inference/timing/analyze.py @@ -25,7 +25,11 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional -from inference.timing.flops import baseline_run_flops, lvm_run_flops, resolve_params +from inference.timing.flops import ( + ModelConfig, + baseline_run_flops, + lvm_extra_flops, +) _PER_STEP_KEYS = [ @@ -128,30 +132,69 @@ def _row_for(tag: str, meta: Dict[str, Any], agg: Dict[str, Any]) -> Dict[str, A def _add_flops( row: Dict[str, Any], *, - base_params_b: Optional[float], - lvm_params_b: Optional[float], + meta: Dict[str, Any], + base_cfg: Optional[ModelConfig], + lvm_cfg: Optional[ModelConfig], is_lvm: bool, ) -> None: - """Attach theoretical FLOPs + achieved-FLOPs/sec columns to row.""" - if base_params_b is None: + """Attach layer-level theoretical FLOPs + achieved-FLOPs/sec columns. + + Splits prefill (charged once per unique prompt, assumes SGLang prefix cache) + from decode (charged per sample), and breaks each into linear / attention / + lm_head components. 
+ """ + if base_cfg is None: return - prompt_t = int(row.get("total_prompt_tokens") or 0) - output_t = int(row.get("total_output_tokens") or 0) - if not (prompt_t or output_t): + unique_prompts = int(meta.get("max_questions") or 0) + samples_per = int(meta.get("n_samples_per_q") or 1) + prompt_total = int(row.get("total_prompt_tokens") or 0) + output_total = int(row.get("total_output_tokens") or 0) + if unique_prompts <= 0 or output_total <= 0: return - if is_lvm: - if lvm_params_b is None: - return + # total_prompt_tokens is summed once per unique question (dedup in summarize). + # total_output_tokens is summed across all samples per question, then summed across questions. + mean_prompt = prompt_total / unique_prompts + mean_output = output_total / (unique_prompts * samples_per) + + base = baseline_run_flops( + base_cfg, + unique_prompts=unique_prompts, + samples_per_prompt=samples_per, + mean_prompt_tokens=mean_prompt, + mean_output_tokens=mean_output, + ) + total = base["total"]["total"] + row["base_pflops_total"] = total / 1e15 + row["base_pflops_prefill"] = base["prefill"]["total"] / 1e15 + row["base_pflops_decode"] = base["decode"]["total"] / 1e15 + row["base_pflops_linear"] = base["total"]["linear"] / 1e15 + row["base_pflops_attention"] = base["total"]["attention"] / 1e15 + row["base_pflops_lm_head"] = base["total"]["lm_head"] / 1e15 + + if is_lvm and lvm_cfg is not None: k = row.get("top_k") - if k is None or int(k) < 1: - return - total = lvm_run_flops(prompt_t, output_t, base_params_b, lvm_params_b, int(k)) - per_token = (2 * base_params_b + 2 * lvm_params_b * (1 + int(k))) * 1e9 - else: - total = baseline_run_flops(prompt_t, output_t, base_params_b) - per_token = 2 * base_params_b * 1e9 - row["theoretical_gflops_per_output_token"] = per_token / 1e9 + if k is not None and int(k) >= 1: + extra = lvm_extra_flops( + lvm_cfg, + unique_prompts=unique_prompts, + samples_per_prompt=samples_per, + mean_prompt_tokens=mean_prompt, + mean_output_tokens=mean_output, 
+ k_candidates=int(k), + ) + total += extra["total"]["total"] + row["lvm_pflops_prefill"] = extra["lenvm_prefill"]["total"] / 1e15 + row["lvm_pflops_extend"] = extra["lenvm_extend"]["total"] / 1e15 + row["lvm_pflops_candidates"] = extra["lenvm_candidates"]["total"] / 1e15 + row["lvm_pflops_total"] = extra["total"]["total"] / 1e15 + row["theoretical_pflops_total"] = total / 1e15 + # Per-decode-token cost (excluding prefill share) for a quick "GFLOPs/tok" feel + decode_total = base["decode"]["total"] + (row.get("lvm_pflops_extend", 0.0) + row.get("lvm_pflops_candidates", 0.0)) * 1e15 + n_decode_tokens = unique_prompts * samples_per * mean_output + if n_decode_tokens > 0: + row["theoretical_gflops_per_output_token"] = decode_total / n_decode_tokens / 1e9 + wall = row.get("wall_clock_s") or 0 if wall > 0: row["achieved_tflops_per_s"] = total / wall / 1e12 @@ -203,6 +246,47 @@ def _stacked_bar(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: return out_png +def _flops_component_bar(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: + """Stacked bar of theoretical PFLOPs by component (base linear/attn/lm_head + LenVM).""" + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except Exception as e: # pragma: no cover + print(f"matplotlib unavailable, skipping plot: {e}") + return None + + if not any(r.get("base_pflops_linear") for r in rows): + return None + + labels = [r["tag"] for r in rows] + sections = [ + ("base_pflops_linear", "base linear (Q/K/V/O + MLP)"), + ("base_pflops_attention", "base attention (QK^T + attn@V)"), + ("base_pflops_lm_head", "base lm_head"), + ("lvm_pflops_extend", "LenVM extend forward"), + ("lvm_pflops_candidates", "LenVM candidate forwards"), + ("lvm_pflops_prefill", "LenVM prefill"), + ] + fig, ax = plt.subplots(figsize=(8, 5)) + bottoms = [0.0] * len(labels) + for key, name in sections: + vals = [float(r.get(key) or 0.0) for r in rows] + if max(vals) <= 0: + continue + ax.bar(labels, 
vals, bottom=bottoms, label=name) + bottoms = [b + v for b, v in zip(bottoms, vals)] + for i, total in enumerate(bottoms): + ax.text(i, total, f"{total:.2f} PFLOPs", ha="center", va="bottom") + ax.set_ylabel("Theoretical FLOPs (PFLOPs)") + ax.set_title("Theoretical inference FLOPs by component") + ax.legend(loc="upper left", fontsize="small") + fig.tight_layout() + fig.savefig(out_png, dpi=140) + plt.close(fig) + return out_png + + def _lvm_sub_breakdown(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: try: import matplotlib @@ -244,17 +328,34 @@ def _print_tables(rows: List[Dict[str, Any]]) -> None: ("t_lvm_apply_outer_ms_mean", "lvm_apply_ms"), ("t_sample_ms_mean", "sample_ms"), ] - flops_cols = [ + flops_total_cols = [ ("tag", "tag"), ("theoretical_gflops_per_output_token", "GFLOPs/tok"), - ("theoretical_pflops_total", "PFLOPs(total)"), + ("base_pflops_total", "base PFLOPs"), + ("lvm_pflops_total", "lvm PFLOPs"), + ("theoretical_pflops_total", "total PFLOPs"), ("achieved_tflops_per_s", "TFLOPs/s"), ("wall_clock_s", "e2e_s"), ] + flops_component_cols = [ + ("tag", "tag"), + ("base_pflops_linear", "base.linear"), + ("base_pflops_attention", "base.attn"), + ("base_pflops_lm_head", "base.lm_head"), + ("base_pflops_prefill", "base.prefill"), + ("base_pflops_decode", "base.decode"), + ("lvm_pflops_prefill", "lvm.prefill"), + ("lvm_pflops_extend", "lvm.extend"), + ("lvm_pflops_candidates", "lvm.cands"), + ] _print_one(rows, timing_cols) - if any("theoretical_gflops_per_output_token" in r for r in rows): + if any("theoretical_pflops_total" in r for r in rows): print() - _print_one(rows, flops_cols) + print("== FLOPs headline ==") + _print_one(rows, flops_total_cols) + print() + print("== FLOPs by component (PFLOPs) ==") + _print_one(rows, flops_component_cols) def _print_one(rows: List[Dict[str, Any]], cols: List[tuple]) -> None: @@ -291,12 +392,24 @@ def main() -> int: ) args = p.parse_args() - base_params_b = resolve_params(args.base_model) - lvm_params_b 
= resolve_params(args.lvm_model) - if base_params_b is None: - print(f"warning: unknown base model '{args.base_model}', skipping FLOPs") - if lvm_params_b is None: - print(f"warning: unknown LenVM model '{args.lvm_model}', skipping LenVM FLOPs") + base_cfg: Optional[ModelConfig] = None + lvm_cfg: Optional[ModelConfig] = None + try: + base_cfg = ModelConfig.load(args.base_model) + except FileNotFoundError as e: + print(f"warning: {e}; skipping baseline FLOPs") + try: + lvm_cfg = ModelConfig.load(args.lvm_model) + except FileNotFoundError as e: + print(f"warning: {e}; skipping LenVM FLOPs") + if base_cfg is not None: + print(f"base config (source={base_cfg.source}): L={base_cfg.num_hidden_layers} d={base_cfg.hidden_size} " + f"Hq={base_cfg.num_attention_heads} Hkv={base_cfg.num_key_value_heads} h={base_cfg.head_dim} " + f"ff={base_cfg.intermediate_size} V={base_cfg.vocab_size}") + if lvm_cfg is not None: + print(f"lvm config (source={lvm_cfg.source}): L={lvm_cfg.num_hidden_layers} d={lvm_cfg.hidden_size} " + f"Hq={lvm_cfg.num_attention_heads} Hkv={lvm_cfg.num_key_value_heads} h={lvm_cfg.head_dim} " + f"ff={lvm_cfg.intermediate_size} V={lvm_cfg.vocab_size}") rd = args.results_dir rows: List[Dict[str, Any]] = [] @@ -305,7 +418,7 @@ def main() -> int: records = _filter_warmup(list(_iter_records(rd / f"{tag}.timing.jsonl"))) agg = _agg(records) row = _row_for(tag, meta, agg) - _add_flops(row, base_params_b=base_params_b, lvm_params_b=lvm_params_b, is_lvm=is_lvm) + _add_flops(row, meta=meta, base_cfg=base_cfg, lvm_cfg=lvm_cfg, is_lvm=is_lvm) rows.append(row) ratio = _ratio_row(rows) @@ -325,13 +438,16 @@ def main() -> int: plot_path = _stacked_bar(rows, rd / "per_step_breakdown.png") sub_plot_path = _lvm_sub_breakdown(rows, rd / "lvm_apply_breakdown.png") + flops_plot_path = _flops_component_bar(rows, rd / "flops_breakdown.png") - print(f"summary.json -> {summary_path}") - print(f"summary.csv -> {csv_path}") + print(f"summary.json -> {summary_path}") + print(f"summary.csv 
-> {csv_path}") if plot_path: - print(f"plot -> {plot_path}") + print(f"timing plot -> {plot_path}") if sub_plot_path: - print(f"lvm sub-plot -> {sub_plot_path}") + print(f"lvm sub-plot -> {sub_plot_path}") + if flops_plot_path: + print(f"flops plot -> {flops_plot_path}") print() _print_tables(display_rows) return 0 diff --git a/inference/timing/flops.py b/inference/timing/flops.py index 2fe244d..e606915 100644 --- a/inference/timing/flops.py +++ b/inference/timing/flops.py @@ -1,77 +1,245 @@ -"""Theoretical inference FLOPs estimator for LenVM vs baseline. - -Uses the standard 2*N forward rule: a transformer forward pass over T tokens -costs ~2 * N_params * T FLOPs (matmul-dominated; ignores attention's O(S) -contribution at long context, fine for ballpark inference-overhead numbers). - -Contrasting theoretical FLOPs ratio against measured wall-clock ratio -isolates GPU utilization loss from raw compute increase. Reviewers in PR #2 -asked specifically for "inference FLOPs", so we surface both. +"""Layer-level inference FLOPs estimator for LenVM vs baseline. + +Reads HuggingFace ``config.json`` to count weight matmuls layer-by-layer +instead of relying on the ``2 * N_params`` rule of thumb. The decomposition +matters for LenVM analysis because the value model adds (1 + k) forwards +per generated token, and the relative weight of attention-vs-linear FLOPs +shifts at long sequence lengths. + +For each transformer layer with hidden size ``d``, ``H_q`` attention heads, +``H_kv`` key/value heads, head dim ``h``, FFN dim ``ff`` (SwiGLU: gate + up ++ down), per-token FLOPs are: + +* attention projections: ``2*d*(H_q*h) + 2*2*d*(H_kv*h) + 2*(H_q*h)*d`` + (Q, K, V, output projections; Qwen2.5 uses GQA so K/V are smaller). +* attention compute at position ``p``: ``2 * H_q * h * p * 2`` (Q@K^T + and attention@V over a KV cache of length ``p``). +* SwiGLU MLP: ``2 * d * ff * 3`` (gate + up projections share input; down + projection writes back). 
+* LM head (final layer only): ``2 * d * V``. + +Prefill over a prompt of length ``S`` runs every token in parallel but +each token still attends to the lower-triangular prefix, so attention +FLOPs scale as ``S * (S+1) / 2`` rather than ``S``. + +LenVM-guided decoding adds, per output token: + +* one ``tree_value_extend`` forward (catch the LenVM cache up by the + newly-accepted token), +* ``k`` ``tree_value`` forwards (score the top-k candidates). + +Prefix caching: when the same prompt is shared by multiple samples (n>1), +the base/LenVM prefill is charged once per unique prompt and decode is +charged per sample. The analyzer assumes +``unique_prompts == meta['max_questions']`` and +``samples == unique_prompts * meta['n_samples_per_q']``. """ from __future__ import annotations -from typing import Optional - - -# Well-known parameter counts (params, including embeddings/lm_head) for the -# Qwen2.5 family used in the LenVM paper. Update via --base-model-params / -# --lvm-model-params CLI args if you swap base/LenVM checkpoints. -MODEL_PARAMS_BILLIONS = { - "Qwen/Qwen2.5-0.5B-Instruct": 0.49, - "Qwen/Qwen2.5-1.5B-Instruct": 1.54, - "Qwen/Qwen2.5-3B-Instruct": 3.09, - "Qwen/Qwen2.5-7B-Instruct": 7.61, - "Qwen/Qwen2.5-14B-Instruct": 14.8, - "Qwen/Qwen2.5-32B-Instruct": 32.5, - "Qwen/Qwen2.5-72B-Instruct": 72.7, - # LenVM checkpoints are base-model + a thin value head; value head is - # negligible (~1M params), so we charge the LenVM forward at the base - # model's compute cost. 
- "namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct": 1.54, - "namezz/lvm-instruct-0327-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct": 1.54, - "namezz/lvm-a-qwen2.5-7b-instruct-b-qwen2.5-0.5b-instruct": 0.49, - "namezz/lvm-rel-a-qwen2.5-3b-instruct-b-qwen2.5-3b-instruct": 3.09, +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional + + +# Fallback model dimensions, used only when we cannot locate a config.json +# (e.g. paths that point to a non-existent model). Sourced from the Qwen2.5 +# model cards. Update if you add new base/LenVM checkpoints. +_FALLBACK_CONFIGS: Dict[str, Dict[str, int]] = { + "Qwen/Qwen2.5-0.5B-Instruct": dict(num_hidden_layers=24, hidden_size=896, num_attention_heads=14, num_key_value_heads=2, intermediate_size=4864, vocab_size=151936), + "Qwen/Qwen2.5-1.5B-Instruct": dict(num_hidden_layers=28, hidden_size=1536, num_attention_heads=12, num_key_value_heads=2, intermediate_size=8960, vocab_size=151936), + "Qwen/Qwen2.5-3B-Instruct": dict(num_hidden_layers=36, hidden_size=2048, num_attention_heads=16, num_key_value_heads=2, intermediate_size=11008, vocab_size=151936), + "Qwen/Qwen2.5-7B-Instruct": dict(num_hidden_layers=28, hidden_size=3584, num_attention_heads=28, num_key_value_heads=4, intermediate_size=18944, vocab_size=152064), + "Qwen/Qwen2.5-14B-Instruct": dict(num_hidden_layers=48, hidden_size=5120, num_attention_heads=40, num_key_value_heads=8, intermediate_size=13824, vocab_size=152064), } -def resolve_params(model_name_or_path: str) -> Optional[float]: - """Return param count (in billions) for a model name or local path.""" - if model_name_or_path in MODEL_PARAMS_BILLIONS: - return MODEL_PARAMS_BILLIONS[model_name_or_path] - # Local paths like ./models/namezz/...; match by basename. 
- for key, val in MODEL_PARAMS_BILLIONS.items(): - if model_name_or_path.endswith(key) or model_name_or_path.endswith(key.split("/")[-1]): - return val +@dataclass(frozen=True) +class ModelConfig: + num_hidden_layers: int + hidden_size: int + num_attention_heads: int + num_key_value_heads: int + head_dim: int + intermediate_size: int + vocab_size: int + source: str = "" + + @classmethod + def from_dict(cls, cfg: dict, source: str = "") -> "ModelConfig": + d = cfg["hidden_size"] + Hq = cfg["num_attention_heads"] + Hkv = cfg.get("num_key_value_heads", Hq) + h = cfg.get("head_dim") or (d // Hq) + return cls( + num_hidden_layers=cfg["num_hidden_layers"], + hidden_size=d, + num_attention_heads=Hq, + num_key_value_heads=Hkv, + head_dim=h, + intermediate_size=cfg["intermediate_size"], + vocab_size=cfg["vocab_size"], + source=source, + ) + + @classmethod + def load(cls, name_or_path: str) -> "ModelConfig": + """Locate config.json for a HF model name or a local path.""" + cfg_path = _find_config_json(name_or_path) + if cfg_path is not None: + return cls.from_dict(json.loads(cfg_path.read_text()), source=str(cfg_path)) + fb = _FALLBACK_CONFIGS.get(name_or_path) + if fb is None: + # Strip leading "./" and try basename matches + base = name_or_path.rstrip("/").split("/")[-1] + for key, val in _FALLBACK_CONFIGS.items(): + if key.split("/")[-1] == base: + fb = val + break + if fb is None: + raise FileNotFoundError( + f"Could not locate config.json for {name_or_path!r} and no fallback " + f"dimensions are registered. Add one to _FALLBACK_CONFIGS." 
+ ) + return cls.from_dict(fb, source=f"fallback:{name_or_path}") + + +def _find_config_json(name_or_path: str) -> Optional[Path]: + """Find a HuggingFace-style config.json on disk.""" + p = Path(name_or_path) + if p.is_dir(): + cfg = p / "config.json" + if cfg.exists(): + return cfg + if p.is_file() and p.name == "config.json": + return p + # HF cache: $HF_HOME/hub/models----/snapshots//config.json + hf_home = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface") + cache_root = Path(hf_home) / "hub" + if cache_root.exists() and "/" in name_or_path and not name_or_path.startswith("."): + repo_dir = cache_root / ("models--" + name_or_path.replace("/", "--")) + snapshots = repo_dir / "snapshots" + if snapshots.exists(): + for snap in snapshots.iterdir(): + cfg = snap / "config.json" + if cfg.exists(): + return cfg return None -def forward_flops_per_token(n_params_billions: float) -> float: - """2 * N params per token (decode-only forward, matmul-dominated).""" - return 2.0 * n_params_billions * 1e9 +# ---- Per-token forward FLOPs decomposition --------------------------------- -def baseline_run_flops(prompt_tokens: int, output_tokens: int, base_params_b: float) -> float: - """Vanilla decoding: one base-model forward over every prompt + generated token.""" - return forward_flops_per_token(base_params_b) * (prompt_tokens + output_tokens) +def per_layer_linear_flops(cfg: ModelConfig) -> int: + """FLOPs for one token through a single layer's attention projections + MLP.""" + d = cfg.hidden_size + Hq, Hkv, h = cfg.num_attention_heads, cfg.num_key_value_heads, cfg.head_dim + ff = cfg.intermediate_size + q_proj = 2 * d * (Hq * h) + k_proj = 2 * d * (Hkv * h) + v_proj = 2 * d * (Hkv * h) + o_proj = 2 * (Hq * h) * d + # SwiGLU MLP: gate, up, down. gate and up are both d->ff, down is ff->d. 
+ mlp = 2 * d * ff + 2 * d * ff + 2 * ff * d + return q_proj + k_proj + v_proj + o_proj + mlp -def lvm_run_flops( - prompt_tokens: int, - output_tokens: int, - base_params_b: float, - lvm_params_b: float, - k_candidates: int, -) -> float: - """LenVM-guided decoding total FLOPs. +def per_layer_attn_compute_flops(cfg: ModelConfig, seq_len: int) -> int: + """FLOPs for QK^T and attn@V for one token attending to seq_len positions.""" + # 2 * H_q * head_dim * seq_len for each of (QK^T) and (attn @ V) + return 2 * 2 * cfg.num_attention_heads * cfg.head_dim * seq_len + + +def lm_head_flops(cfg: ModelConfig) -> int: + return 2 * cfg.hidden_size * cfg.vocab_size - Per generated token, LenVM adds: - - 1 forward over the LenVM (extend KV by the just-accepted token) - - k forwards over the LenVM (score the top-k candidates) - Prefill on the LenVM happens lazily as candidates are scored, so we - fold it into output_tokens for an upper-bound estimate. +def forward_token_flops(cfg: ModelConfig, position: int) -> Dict[str, int]: + """Decompose one decode-step FLOPs at the given (1-indexed) position.""" + lin = per_layer_linear_flops(cfg) * cfg.num_hidden_layers + attn = per_layer_attn_compute_flops(cfg, position) * cfg.num_hidden_layers + lmh = lm_head_flops(cfg) + return {"linear": lin, "attention": attn, "lm_head": lmh, "total": lin + attn + lmh} + + +def prefill_flops(cfg: ModelConfig, prompt_len: int) -> Dict[str, int]: + """Decompose prefill FLOPs over a prompt of length prompt_len. + + Linear ops scale linearly with prompt_len. Attention is over the + lower-triangular causal mask, so per-layer attention compute sums to + 2 * H_q * h * (S * (S+1) / 2 * 2) = 2 * H_q * h * S * (S+1). 
+ """ + if prompt_len <= 0: + return {"linear": 0, "attention": 0, "lm_head": 0, "total": 0} + lin = per_layer_linear_flops(cfg) * cfg.num_hidden_layers * prompt_len + attn_per_layer = 2 * 2 * cfg.num_attention_heads * cfg.head_dim * prompt_len * (prompt_len + 1) // 2 + attn = attn_per_layer * cfg.num_hidden_layers + lmh = lm_head_flops(cfg) * prompt_len + return {"linear": lin, "attention": attn, "lm_head": lmh, "total": lin + attn + lmh} + + +def decode_flops_sum(cfg: ModelConfig, prompt_len: int, output_len: int) -> Dict[str, int]: + """Sum decode-step FLOPs for output_len tokens, attending to a growing KV cache.""" + if output_len <= 0: + return {"linear": 0, "attention": 0, "lm_head": 0, "total": 0} + lin_per_token = per_layer_linear_flops(cfg) * cfg.num_hidden_layers + lmh_per_token = lm_head_flops(cfg) + # Closed-form: sum_{t=1..L} (prompt_len + t) = L*prompt_len + L*(L+1)/2 + pos_sum = output_len * prompt_len + output_len * (output_len + 1) // 2 + attn = 2 * 2 * cfg.num_attention_heads * cfg.head_dim * pos_sum * cfg.num_hidden_layers + lin = lin_per_token * output_len + lmh = lmh_per_token * output_len + return {"linear": lin, "attention": attn, "lm_head": lmh, "total": lin + attn + lmh} + + +def implied_param_count(cfg: ModelConfig) -> int: + """Approximate parameter count from config (matches 2N rule by 2*N≈linear FLOPs/token).""" + return per_layer_linear_flops(cfg) * cfg.num_hidden_layers // 2 + lm_head_flops(cfg) // 2 + + +# ---- Run-level aggregation ------------------------------------------------- + + +def baseline_run_flops( + cfg: ModelConfig, + *, + unique_prompts: int, + samples_per_prompt: int, + mean_prompt_tokens: float, + mean_output_tokens: float, +) -> Dict[str, int]: + """Total baseline FLOPs assuming prefix caching: prefill once per question, decode per sample.""" + pre = prefill_flops(cfg, int(round(mean_prompt_tokens))) + dec = decode_flops_sum(cfg, int(round(mean_prompt_tokens)), int(round(mean_output_tokens))) + out = { + "prefill": 
{k: v * unique_prompts for k, v in pre.items()}, + "decode": {k: v * unique_prompts * samples_per_prompt for k, v in dec.items()}, + } + out["total"] = {k: out["prefill"][k] + out["decode"][k] for k in pre.keys()} + return out + + +def lvm_extra_flops( + cfg: ModelConfig, + *, + unique_prompts: int, + samples_per_prompt: int, + mean_prompt_tokens: float, + mean_output_tokens: float, + k_candidates: int, +) -> Dict[str, int]: + """LenVM-only extra FLOPs (on top of baseline). Per generated token: + one extend (1 forward) + k candidates (k forwards). """ - base = baseline_run_flops(prompt_tokens, output_tokens, base_params_b) - lvm_per_step = forward_flops_per_token(lvm_params_b) * (1 + k_candidates) - return base + lvm_per_step * output_tokens + pre = prefill_flops(cfg, int(round(mean_prompt_tokens))) + dec = decode_flops_sum(cfg, int(round(mean_prompt_tokens)), int(round(mean_output_tokens))) + out = { + "lenvm_prefill": {k: v * unique_prompts for k, v in pre.items()}, + # Each output token triggers 1 extend at that position + k candidate forwards. + "lenvm_extend": {k: v * unique_prompts * samples_per_prompt for k, v in dec.items()}, + "lenvm_candidates": {k: v * unique_prompts * samples_per_prompt * k_candidates for k, v in dec.items()}, + } + out["total"] = {k: out["lenvm_prefill"][k] + out["lenvm_extend"][k] + out["lenvm_candidates"][k] for k in pre.keys()} + return out From b1b4ba0a64c2a62ea2a937f265fea71e016ceb22 Mon Sep 17 00:00:00 2001 From: changyiyang Date: Fri, 22 May 2026 21:57:23 +0000 Subject: [PATCH 4/5] Charge LenVM forwards a value-head, not lm_head; add candidate-cost knob Two corrections to the FLOPs estimator after PR review feedback: 1. LenVM checkpoints use MLP2SiLUValueHead (d->d Linear + d->1 Linear, see sglang/srt/models/qwen2_lvm.py::MLP2SiLUValueHead), not the base model's lm_head (d * vocab_size). For Qwen2.5-1.5B that's ~4.7M FLOPs per token instead of ~467M, so the previous accounting was overcharging each LenVM forward by ~15%. 
ModelConfig.head_type is now "lm_head" / "value_head", and ModelConfig.load(...) autodetects by checking for value_head.safetensors or LengthValueModel-style architectures in config.json. 2. lvm_extra_flops gains a candidate_cost_multiplier=1.0 knob. The default matches the current sglang-LenVM in-proc path where each candidate is a separate single-token forward sharing only the extended KV cache. A future implementation that batches k candidates into one forward and amortizes some compute can pass a value < 1.0. Also fix _find_config_json to look under ./models// so the analyzer can resolve checkpoints downloaded with `hf download --local-dir`. Re-running on the same 50-q dataset: theoretical LenVM extra drops from 4.48 PFLOPs to 3.82 PFLOPs, ratio drops from 2.30x to 2.10x. Wall-clock ratio is unchanged at 4.55x, so the GPU-utilization gap widens slightly. README gains a TL;DR Results section + caveats covering both knobs. Co-Authored-By: Claude Opus 4.7 (1M context) --- inference/timing/README.md | 44 +++++++++++++++ inference/timing/analyze.py | 4 +- inference/timing/flops.py | 110 +++++++++++++++++++++++++++++++----- 3 files changed, 141 insertions(+), 17 deletions(-) diff --git a/inference/timing/README.md b/inference/timing/README.md index 64dd838..b472401 100644 --- a/inference/timing/README.md +++ b/inference/timing/README.md @@ -7,6 +7,38 @@ This exists to answer the question raised in the LenVM paper review: how much wall-clock overhead does LenVM-guided decoding add on top of vanilla decoding, and where does that overhead live inside each decoding step? +## TL;DR results + +Reference configuration: 1× H100 SXM, Qwen2.5-7B-Instruct base + the +`lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct` value model, +GSM8K 50 questions × 16 samples, `max_tokens=6000`, T=1.0, top-p=1.0, +baseline `top_k=-1`, LenVM `top_k=5`, value-scale 0 (paper config). 
+ +| metric | baseline | LenVM | ratio | +| --- | ---: | ---: | ---: | +| end-to-end wall clock (s) | 19.22 | 87.44 | **4.55× slower** | +| throughput (output tok/s) | 12,366 | 2,719 | **4.55× drop** | +| mean `Sampler.forward` (ms / step) | 0.19 | 50.31 | **268×** | +| theoretical GFLOPs / output token | 14.24 | 30.24 | **2.12×** | +| achieved throughput (TFLOPs/s) | 179.9 | 83.2 | 0.46× | +| H100 bf16 peak utilization | ~18% | ~9% | — | + +`Sampler.forward` breakdown (mean ms / step): pre-LVM 0.07 → 0.14, +`LvmGuidedSampler.apply` 0.00 → 50.03 (build_pending 12.3, LenVM forward 36.6, +apply_guidance 5.3), sample kernel 0.12 → 0.17. The LenVM `apply()` call is +~99% of the per-step latency delta; pre-LVM and sample-kernel costs are +unchanged. + +FLOPs breakdown (PFLOPs over the run): baseline.linear 3.17, base.attention +0.02, base.lm_head 0.26 — same for both runs. LenVM adds extend 0.63 + +candidates 3.17 + prefill 0.01 = 3.82 PFLOPs extra. Of the extra cost, +~83% is the `top_k=5` candidate forwards through the 1.5B value model. + +**Conclusion**: LenVM nearly doubles the theoretical compute (2.10× total +PFLOPs) and slows wall clock by 4.55×. Half of the slowdown is genuine +extra compute; the other half is GPU underutilization (CPU candidate prep +stalls, separate-stream sync, smaller effective batch on the value model). + ## What gets measured Two server lifecycles drive the comparison, so per-step instrumentation only @@ -94,5 +126,17 @@ The script chains three stages: at production batch sizes the residual is small but visible at low load. - Single-rank only. For TP/DP > 1 the timer writes from one rank; extend the log filename with the rank suffix if you need per-worker traces. +- LenVM-extra FLOPs assume **independent single-token forwards** for each of + the `k` candidates per generated token. 
This matches the current + sglang-LenVM in-proc path (`tree_value_extend` extends KV once, then + each candidate is a separate single-token forward attending to the shared + prefix cache). If a future implementation batches all `k` candidates into + one forward and amortizes some MLP/attention cost, pass + `candidate_cost_multiplier < 1.0` to `lvm_extra_flops` to scale that term. +- LenVM head cost uses the small `MLP2SiLUValueHead` (`d*d + d*out_dim`) not + the base model's `lm_head` (`d * vocab_size`). `ModelConfig.load(...)` + auto-detects this by checking for `value_head.safetensors` next to + `config.json` or `LengthValueModel`/`ValueModel`/`ValueHead` in the + config's `architectures`. Force the choice via `head_type=...` if needed. - Wall-clock comparisons assume both servers see the same prompts at the same concurrency. Run them back-to-back on a quiet GPU. diff --git a/inference/timing/analyze.py b/inference/timing/analyze.py index a4df2fb..1ecc6ca 100644 --- a/inference/timing/analyze.py +++ b/inference/timing/analyze.py @@ -405,11 +405,11 @@ def main() -> int: if base_cfg is not None: print(f"base config (source={base_cfg.source}): L={base_cfg.num_hidden_layers} d={base_cfg.hidden_size} " f"Hq={base_cfg.num_attention_heads} Hkv={base_cfg.num_key_value_heads} h={base_cfg.head_dim} " - f"ff={base_cfg.intermediate_size} V={base_cfg.vocab_size}") + f"ff={base_cfg.intermediate_size} V={base_cfg.vocab_size} head={base_cfg.head_type}") if lvm_cfg is not None: print(f"lvm config (source={lvm_cfg.source}): L={lvm_cfg.num_hidden_layers} d={lvm_cfg.hidden_size} " f"Hq={lvm_cfg.num_attention_heads} Hkv={lvm_cfg.num_key_value_heads} h={lvm_cfg.head_dim} " - f"ff={lvm_cfg.intermediate_size} V={lvm_cfg.vocab_size}") + f"ff={lvm_cfg.intermediate_size} V={lvm_cfg.vocab_size} head={lvm_cfg.head_type}") rd = args.results_dir rows: List[Dict[str, Any]] = [] diff --git a/inference/timing/flops.py b/inference/timing/flops.py index e606915..5f4088b 100644 --- 
a/inference/timing/flops.py +++ b/inference/timing/flops.py @@ -65,10 +65,19 @@ class ModelConfig: head_dim: int intermediate_size: int vocab_size: int + # Head type used at the top of the stack. "lm_head" = 2*d*V vocab projection + # (standard causal LM). "value_head" = small MLP -> scalar (LenVM checkpoints + # ship a MLP2SiLUValueHead: d->d Linear + d->1 Linear, see + # sglang/srt/models/qwen2_lvm.py::MLP2SiLUValueHead). + head_type: str = "lm_head" + value_head_hidden: Optional[int] = None # MLP hidden dim if head_type == "value_head" + value_head_out_dim: int = 1 source: str = "" @classmethod - def from_dict(cls, cfg: dict, source: str = "") -> "ModelConfig": + def from_dict(cls, cfg: dict, source: str = "", *, head_type: str = "lm_head", + value_head_hidden: Optional[int] = None, + value_head_out_dim: int = 1) -> "ModelConfig": d = cfg["hidden_size"] Hq = cfg["num_attention_heads"] Hkv = cfg.get("num_key_value_heads", Hq) @@ -81,18 +90,33 @@ def from_dict(cls, cfg: dict, source: str = "") -> "ModelConfig": head_dim=h, intermediate_size=cfg["intermediate_size"], vocab_size=cfg["vocab_size"], + head_type=head_type, + value_head_hidden=value_head_hidden if value_head_hidden is not None else d, + value_head_out_dim=value_head_out_dim, source=source, ) @classmethod - def load(cls, name_or_path: str) -> "ModelConfig": - """Locate config.json for a HF model name or a local path.""" + def load(cls, name_or_path: str, *, head_type: str = "auto", + value_head_out_dim: int = 1) -> "ModelConfig": + """Locate config.json for a HF model name or a local path. + + head_type="auto" (default) inspects the directory for value_head.safetensors, + treating its presence as a LenVM-style value-head checkpoint. Pass + head_type="value_head" or "lm_head" to force the choice. 
+ """ cfg_path = _find_config_json(name_or_path) if cfg_path is not None: - return cls.from_dict(json.loads(cfg_path.read_text()), source=str(cfg_path)) + cfg_dict = json.loads(cfg_path.read_text()) + ht = head_type if head_type != "auto" else _autodetect_head(cfg_path.parent, cfg_dict) + return cls.from_dict( + cfg_dict, + source=str(cfg_path), + head_type=ht, + value_head_out_dim=value_head_out_dim, + ) fb = _FALLBACK_CONFIGS.get(name_or_path) if fb is None: - # Strip leading "./" and try basename matches base = name_or_path.rstrip("/").split("/")[-1] for key, val in _FALLBACK_CONFIGS.items(): if key.split("/")[-1] == base: @@ -103,7 +127,21 @@ def load(cls, name_or_path: str) -> "ModelConfig": f"Could not locate config.json for {name_or_path!r} and no fallback " f"dimensions are registered. Add one to _FALLBACK_CONFIGS." ) - return cls.from_dict(fb, source=f"fallback:{name_or_path}") + ht = head_type if head_type != "auto" else "lm_head" + return cls.from_dict(fb, source=f"fallback:{name_or_path}", head_type=ht, + value_head_out_dim=value_head_out_dim) + + +def _autodetect_head(model_dir: Path, cfg_dict: dict) -> str: + """Return 'value_head' if a value_head.safetensors sits next to config.json, + or the loaded architecture is a known value-head class. Otherwise 'lm_head'. 
+ """ + if (model_dir / "value_head.safetensors").exists(): + return "value_head" + for arch in cfg_dict.get("architectures", []) or []: + if "LengthValueModel" in arch or "ValueModel" in arch or "ValueHead" in arch: + return "value_head" + return "lm_head" def _find_config_json(name_or_path: str) -> Optional[Path]: @@ -126,6 +164,12 @@ def _find_config_json(name_or_path: str) -> Optional[Path]: cfg = snap / "config.json" if cfg.exists(): return cfg + # Local download dir convention (download_data_and_model.sh writes here): + if "/" in name_or_path and not name_or_path.startswith("."): + local_dir = Path("./models") / name_or_path + cfg = local_dir / "config.json" + if cfg.exists(): + return cfg return None @@ -153,14 +197,33 @@ def per_layer_attn_compute_flops(cfg: ModelConfig, seq_len: int) -> int: def lm_head_flops(cfg: ModelConfig) -> int: + """Vocab projection 2 * hidden * vocab_size.""" return 2 * cfg.hidden_size * cfg.vocab_size +def value_head_flops(cfg: ModelConfig) -> int: + """MLP2SiLUValueHead: hidden->hidden (fc) + hidden->out_dim (summary). + + SiLU activation is negligible vs the two matmuls. The summary projection + is tiny when out_dim=1 (the default) but kept for completeness. 
+ """ + fc = 2 * cfg.hidden_size * (cfg.value_head_hidden or cfg.hidden_size) + summary = 2 * (cfg.value_head_hidden or cfg.hidden_size) * cfg.value_head_out_dim + return fc + summary + + +def head_flops(cfg: ModelConfig) -> int: + """FLOPs at the top of the stack, dispatched by ModelConfig.head_type.""" + if cfg.head_type == "value_head": + return value_head_flops(cfg) + return lm_head_flops(cfg) + + def forward_token_flops(cfg: ModelConfig, position: int) -> Dict[str, int]: """Decompose one decode-step FLOPs at the given (1-indexed) position.""" lin = per_layer_linear_flops(cfg) * cfg.num_hidden_layers attn = per_layer_attn_compute_flops(cfg, position) * cfg.num_hidden_layers - lmh = lm_head_flops(cfg) + lmh = head_flops(cfg) return {"linear": lin, "attention": attn, "lm_head": lmh, "total": lin + attn + lmh} @@ -176,7 +239,7 @@ def prefill_flops(cfg: ModelConfig, prompt_len: int) -> Dict[str, int]: lin = per_layer_linear_flops(cfg) * cfg.num_hidden_layers * prompt_len attn_per_layer = 2 * 2 * cfg.num_attention_heads * cfg.head_dim * prompt_len * (prompt_len + 1) // 2 attn = attn_per_layer * cfg.num_hidden_layers - lmh = lm_head_flops(cfg) * prompt_len + lmh = head_flops(cfg) * prompt_len return {"linear": lin, "attention": attn, "lm_head": lmh, "total": lin + attn + lmh} @@ -185,7 +248,7 @@ def decode_flops_sum(cfg: ModelConfig, prompt_len: int, output_len: int) -> Dict if output_len <= 0: return {"linear": 0, "attention": 0, "lm_head": 0, "total": 0} lin_per_token = per_layer_linear_flops(cfg) * cfg.num_hidden_layers - lmh_per_token = lm_head_flops(cfg) + lmh_per_token = head_flops(cfg) # Closed-form: sum_{t=1..L} (prompt_len + t) = L*prompt_len + L*(L+1)/2 pos_sum = output_len * prompt_len + output_len * (output_len + 1) // 2 attn = 2 * 2 * cfg.num_attention_heads * cfg.head_dim * pos_sum * cfg.num_hidden_layers @@ -229,17 +292,34 @@ def lvm_extra_flops( mean_prompt_tokens: float, mean_output_tokens: float, k_candidates: int, + candidate_cost_multiplier: float 
= 1.0, ) -> Dict[str, int]: - """LenVM-only extra FLOPs (on top of baseline). Per generated token: - one extend (1 forward) + k candidates (k forwards). + """LenVM-only extra FLOPs (on top of baseline) per output token: + + * one ``tree_value_extend`` forward (catch the value-model KV up with the + just-accepted token; a single-token decode at the current position). + * ``k_candidates`` value-model forwards (score the top-k candidate tokens; + each is a single-token decode at the position right after the extend). + + ``candidate_cost_multiplier`` is a knob for future LenVM implementations + that share work across candidates (e.g. batched single-forward scoring + that amortizes some MLP / attention cost across the k candidates). The + default of ``1.0`` matches the current sglang-LenVM in-proc path, where + each candidate is a separate single-token forward sharing only the + extended KV cache. Set to e.g. ``0.6`` if a future implementation can + batch the k candidates into one forward. + + Whether ``cfg.head_type`` is ``lm_head`` or ``value_head`` controls + whether each forward is charged the 2*d*V vocab projection (causal LM) + or the much smaller MLP2SiLUValueHead (LenVM checkpoints). """ pre = prefill_flops(cfg, int(round(mean_prompt_tokens))) dec = decode_flops_sum(cfg, int(round(mean_prompt_tokens)), int(round(mean_output_tokens))) + candidate_scale = unique_prompts * samples_per_prompt * k_candidates * candidate_cost_multiplier out = { - "lenvm_prefill": {k: v * unique_prompts for k, v in pre.items()}, - # Each output token triggers 1 extend at that position + k candidate forwards. 
- "lenvm_extend": {k: v * unique_prompts * samples_per_prompt for k, v in dec.items()}, - "lenvm_candidates": {k: v * unique_prompts * samples_per_prompt * k_candidates for k, v in dec.items()}, + "lenvm_prefill": {k: v * unique_prompts for k, v in pre.items()}, + "lenvm_extend": {k: v * unique_prompts * samples_per_prompt for k, v in dec.items()}, + "lenvm_candidates": {k: int(v * candidate_scale) for k, v in dec.items()}, } out["total"] = {k: out["lenvm_prefill"][k] + out["lenvm_extend"][k] + out["lenvm_candidates"][k] for k in pre.keys()} return out From 799025af78614733dd846ac05028ad50306bed64 Mon Sep 17 00:00:00 2001 From: changyiyang Date: Fri, 22 May 2026 22:24:50 +0000 Subject: [PATCH 5/5] Add top-k ablation aggregator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit inference/timing/sweep_analyze.py reads multiple lenvm_timing.sh result directories (one per top-k setting) and emits: - sweep_summary.csv / sweep_summary.json: per-k row with baseline + LenVM wall clock, achieved TFLOPs/s, theoretical FLOPs ratio, utilization gap, and LenVM apply / forward latency means. - topk_sweep.png: theoretical FLOPs ratio vs measured wall-clock ratio vs achieved TFLOPs/s ratio, plotted against k. Run on k ∈ {1,2,3,4,5} with the existing 50 q × 16 sample setup: k | base_s | lvm_s | wall_ratio | flops_ratio | apply_ms --+--------+-------+------------+-------------+--------- 1 | 24.19 | 23.19 | 0.96 | 1.30 | — (greedy fast path, LenVM skipped) 2 | 23.22 | 80.37 | 3.46 | 1.55 | 45.93 3 | 27.21 | 81.30 | 2.99 | 1.74 | 51.15 4 | 22.30 | 83.53 | 3.75 | 1.91 | 46.23 5 | 18.19 | 80.50 | 4.43 | 2.08 | 48.37 Two takeaways from the sweep: - top-k=1 with temperature=1.0 hits SGLang's is_all_greedy branch in Sampler.forward and the LenVM apply() hook is never invoked. The paper config really starts at k=2. 
- From k=2 to k=5 the LenVM apply latency is flat (~46-51 ms/step) and the LenVM forward latency is flat (~32-34 ms/step), so the measured wall-clock is essentially constant in k while theoretical FLOPs grows linearly. The bottleneck is the per-step sync/CPU-prep cost, not the candidate-set matmul, so increasing k inside this range is roughly free. Co-Authored-By: Claude Opus 4.7 (1M context) --- inference/timing/sweep_analyze.py | 195 ++++++++++++++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 inference/timing/sweep_analyze.py diff --git a/inference/timing/sweep_analyze.py b/inference/timing/sweep_analyze.py new file mode 100644 index 0000000..781dfe2 --- /dev/null +++ b/inference/timing/sweep_analyze.py @@ -0,0 +1,195 @@ +"""Aggregate timing results from multiple ``lenvm_timing.sh`` runs into a top-k sweep. + +Each input ``--results-dirs`` entry is a directory produced by a single +``lenvm_timing.sh`` invocation (so it contains ``summary.json`` + per-stage +``*.meta.json`` + ``*.timing.jsonl``). This script extracts ``LENVM_TOP_K`` from +the LenVM ``meta.json`` and aggregates the per-run metrics into a single CSV + +plot, plus a stdout table. Use this to answer: + +- "How does the wall-clock slowdown scale with the LenVM candidate-set size?" +- "Does the theoretical FLOPs ratio track the measured wall-clock ratio as k + grows, or does the GPU-utilization gap widen?" 
+""" + +from __future__ import annotations + +import argparse +import csv +import json +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +_RATIO_KEYS = [ + "wall_clock_s", + "throughput_output_tokens_per_s", + "theoretical_pflops_total", + "achieved_tflops_per_s", + "t_sampler_total_ms_mean", + "t_lvm_apply_outer_ms_mean", + "t_lvm_forward_ms_mean", +] + + +def _load_summary(rd: Path) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + """Return (baseline_row, lenvm_row, ratio_row) from a summary.json.""" + rows: List[Dict[str, Any]] = json.loads((rd / "summary.json").read_text()) + base = next(r for r in rows if r.get("tag") == "baseline") + lvm = next(r for r in rows if r.get("tag") == "lenvm") + ratio = next((r for r in rows if "ratio" in str(r.get("tag"))), {}) + return base, lvm, ratio + + +def _infer_k(rd: Path, lvm_row: Dict[str, Any]) -> Optional[int]: + k = lvm_row.get("top_k") + if isinstance(k, int) and k > 0: + return k + # Fall back to parsing the directory name (e.g. sweep_q50_n16_k5). 
+ m = re.search(r"_k(\d+)", rd.name) + if m: + return int(m.group(1)) + return None + + +def _percent(x: Optional[float]) -> str: + if x is None: + return "" + return f"{x * 100:.1f}%" + + +def _fmt(v: Any) -> str: + if v is None: + return "" + if isinstance(v, float): + return f"{v:.2f}" + return str(v) + + +def _print_table(rows: List[Dict[str, Any]], cols: List[Tuple[str, str]]) -> None: + widths = {k: max(len(label), max(len(_fmt(r.get(k))) for r in rows)) for k, label in cols} + header = " | ".join(f"{label:>{widths[k]}}" for k, label in cols) + print(header) + print("-+-".join("-" * widths[k] for k, _ in cols)) + for r in rows: + print(" | ".join(f"{_fmt(r.get(k)):>{widths[k]}}" for k, _ in cols)) + + +def _plot(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except Exception as e: + print(f"matplotlib unavailable, skipping plot: {e}") + return None + + rows = [r for r in rows if r.get("k") is not None] + if not rows: + return None + rows = sorted(rows, key=lambda r: r["k"]) + ks = [r["k"] for r in rows] + flops_ratio = [r.get("flops_ratio") for r in rows] + wall_ratio = [r.get("wall_ratio") for r in rows] + util_ratio = [r.get("achieved_tflops_ratio") for r in rows] + + fig, ax = plt.subplots(figsize=(8, 5)) + ax.plot(ks, flops_ratio, marker="o", label="theoretical FLOPs ratio (LenVM / baseline)") + ax.plot(ks, wall_ratio, marker="s", label="measured wall-clock ratio") + if any(v is not None for v in util_ratio): + ax.plot(ks, util_ratio, marker="^", linestyle="--", + label="achieved TFLOPs/s ratio (≤1 means utilization loss)") + ax.axhline(1.0, color="gray", linewidth=0.5, linestyle=":") + ax.set_xlabel("LenVM candidate-set size k") + ax.set_ylabel("Ratio (LenVM / baseline)") + ax.set_title("LenVM overhead vs candidate-set size") + ax.set_xticks(ks) + ax.legend(loc="upper left", fontsize="small") + ax.grid(alpha=0.3) + fig.tight_layout() + 
fig.savefig(out_png, dpi=140) + plt.close(fig) + return out_png + + +def main() -> int: + p = argparse.ArgumentParser() + p.add_argument("--results-dirs", type=Path, nargs="+", required=True, + help="Directories produced by lenvm_timing.sh, one per top-k value.") + p.add_argument("--output-dir", type=Path, required=True) + args = p.parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + rows: List[Dict[str, Any]] = [] + for rd in args.results_dirs: + if not (rd / "summary.json").exists(): + print(f"skipping {rd} — no summary.json") + continue + base, lvm, ratio = _load_summary(rd) + k = _infer_k(rd, lvm) + row: Dict[str, Any] = { + "k": k, + "results_dir": str(rd), + "base_wall_s": base.get("wall_clock_s"), + "lvm_wall_s": lvm.get("wall_clock_s"), + "wall_ratio": ratio.get("wall_clock_s"), + "base_tok_per_s": base.get("throughput_output_tokens_per_s"), + "lvm_tok_per_s": lvm.get("throughput_output_tokens_per_s"), + "tok_per_s_ratio": ratio.get("throughput_output_tokens_per_s"), + "base_pflops": base.get("theoretical_pflops_total"), + "lvm_pflops": lvm.get("theoretical_pflops_total"), + "flops_ratio": ratio.get("theoretical_pflops_total"), + "base_achieved_tflops_per_s": base.get("achieved_tflops_per_s"), + "lvm_achieved_tflops_per_s": lvm.get("achieved_tflops_per_s"), + "achieved_tflops_ratio": ratio.get("achieved_tflops_per_s"), + "lvm_apply_ms_mean": lvm.get("t_lvm_apply_outer_ms_mean"), + "lvm_forward_ms_mean": lvm.get("t_lvm_forward_ms_mean"), + "lvm_build_pending_ms_mean": lvm.get("t_lvm_build_pending_ms_mean"), + "lvm_apply_guidance_ms_mean": lvm.get("t_lvm_apply_guidance_ms_mean"), + } + if row["wall_ratio"] and row["flops_ratio"]: + # Utilization gap = wall-clock ratio / theoretical FLOPs ratio. + # >1 means LenVM is slower wall-clock than its raw extra compute alone would predict. 
+ row["utilization_gap"] = row["wall_ratio"] / row["flops_ratio"] + rows.append(row) + + rows = sorted(rows, key=lambda r: (r.get("k") if r.get("k") is not None else 999)) + + summary_json = args.output_dir / "sweep_summary.json" + summary_json.write_text(json.dumps(rows, indent=2)) + + csv_path = args.output_dir / "sweep_summary.csv" + if rows: + keys = sorted({k for r in rows for k in r.keys()}) + with csv_path.open("w", newline="") as f: + w = csv.DictWriter(f, fieldnames=keys) + w.writeheader() + for r in rows: + w.writerow({k: r.get(k) for k in keys}) + + plot_path = _plot(rows, args.output_dir / "topk_sweep.png") + + print(f"summary.json -> {summary_json}") + print(f"summary.csv -> {csv_path}") + if plot_path: + print(f"plot -> {plot_path}") + print() + _print_table( + rows, + cols=[ + ("k", "k"), + ("base_wall_s", "base_s"), + ("lvm_wall_s", "lvm_s"), + ("wall_ratio", "wall_ratio"), + ("flops_ratio", "flops_ratio"), + ("utilization_gap", "util_gap"), + ("lvm_apply_ms_mean", "apply_ms"), + ("lvm_forward_ms_mean", "lvm_fwd_ms"), + ("lvm_tok_per_s", "lvm_tok/s"), + ], + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())