diff --git a/inference/timing/README.md b/inference/timing/README.md
new file mode 100644
index 0000000..b472401
--- /dev/null
+++ b/inference/timing/README.md
@@ -0,0 +1,142 @@
+# LenVM inference timing analysis
+
+End-to-end and per-decoding-step latency decomposition for the LenVM-guided
+sampling path, compared against an otherwise-identical baseline SGLang server.
+
+This exists to answer the question raised in the LenVM paper review: how much
+wall-clock overhead does LenVM-guided decoding add on top of vanilla decoding,
+and where does that overhead live inside each decoding step?
+
+## TL;DR results
+
+Reference configuration: 1× H100 SXM, Qwen2.5-7B-Instruct base + the
+`lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct` value model,
+GSM8K 50 questions × 16 samples, `max_tokens=6000`, T=1.0, top-p=1.0,
+baseline `top_k=-1`, LenVM `top_k=5`, value-scale 0 (paper config).
+
+| metric | baseline | LenVM | ratio |
+| --- | ---: | ---: | ---: |
+| end-to-end wall clock (s) | 19.22 | 87.44 | **4.55× slower** |
+| throughput (output tok/s) | 12,366 | 2,719 | **4.55× drop** |
+| mean `Sampler.forward` (ms / step) | 0.19 | 50.31 | **268×** |
+| theoretical GFLOPs / output token | 14.24 | 30.24 | **2.12×** |
+| achieved throughput (TFLOPs/s) | 179.9 | 83.2 | 0.46× |
+| H100 bf16 peak utilization | ~18% | ~9% | n/a |
+
+`Sampler.forward` breakdown (mean ms / step): pre-LVM 0.07 → 0.14,
+`LvmGuidedSampler.apply` 0.00 → 50.03 (build_pending 12.3, LenVM forward 36.6,
+apply_guidance 5.3), sample kernel 0.12 → 0.17. The LenVM `apply()` call is
+~99% of the per-step latency delta; pre-LVM and sample-kernel costs each grow
+by less than 0.1 ms.
+
+FLOPs breakdown (PFLOPs over the run): base.linear 3.17, base.attention
+0.02, base.lm_head 0.26, the same for both runs. LenVM adds extend 0.63 +
+candidates 3.17 + prefill 0.01 = 3.82 PFLOPs extra. Of the extra cost,
+~83% is the `top_k=5` candidate forwards through the 1.5B value model.
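
The headline ratios can be re-derived from the component numbers quoted above. The following is a minimal sketch rather than part of `analyze.py`: `headline_ratios` is a hypothetical helper, and its inputs are the rounded PFLOPs figures from this README. Summing the rounded LenVM components gives 3.81 rather than the quoted 3.82, presumably because of rounding.

```python
# Component PFLOPs as quoted in the TL;DR (rounded to two decimals).
BASE_PFLOPS = {"linear": 3.17, "attention": 0.02, "lm_head": 0.26}
LENVM_EXTRA_PFLOPS = {"extend": 0.63, "candidates": 3.17, "prefill": 0.01}
WALL_CLOCK_RATIO = 4.55  # measured LenVM / baseline end-to-end ratio


def headline_ratios(base, extra, wall_clock_ratio):
    """Hypothetical helper: derive the summary ratios from component PFLOPs."""
    base_total = sum(base.values())
    extra_total = sum(extra.values())
    # Theoretical compute of the LenVM run relative to the baseline run.
    compute_ratio = (base_total + extra_total) / base_total
    return {
        "base_total": base_total,
        "extra_total": extra_total,
        "compute_ratio": compute_ratio,
        # Share of the LenVM-specific cost spent on top-k candidate forwards.
        "candidate_share": extra["candidates"] / extra_total,
        # Slowdown factor not explained by extra compute (underutilization).
        "residual_slowdown": wall_clock_ratio / compute_ratio,
    }


if __name__ == "__main__":
    ratios = headline_ratios(BASE_PFLOPS, LENVM_EXTRA_PFLOPS, WALL_CLOCK_RATIO)
    for name, value in ratios.items():
        print(f"{name}: {value:.2f}")
```

The wall-clock slowdown factors multiplicatively into the compute ratio and the residual, which is why the two contributions are compared as factors rather than as additive seconds.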
+
+**Conclusion**: LenVM roughly doubles the theoretical compute (2.10× total
+PFLOPs) and slows wall clock by 4.55×. The 2.10× is genuine extra compute;
+the remaining ~2.2× is GPU underutilization (CPU candidate prep stalls,
+separate-stream sync, smaller effective batch on the value model).
+
+## What gets measured
+
+Two server lifecycles drive the comparison, so per-step instrumentation only
+captures the configuration under test:
+
+1. **baseline**: SGLang with `--enable-lvm-guided-sampling` off. Vanilla
+   chat-completion sampling against `Qwen/Qwen2.5-7B-Instruct`.
+2. **lenvm**: Same base model, plus the in-process LenVM value model
+   (`./models/namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct`
+   by default), with `--enable-lvm-guided-sampling --lvm-guided-inproc
+   --lvm-guided-fn lvm_combined_guidance`.
+
+Both servers point `SGLANG_LVM_TIMING_LOG` at their own JSONL file. The client
+hits each server in turn with the same GSM8K prompt set (50 questions × 16
+samples by default), recording end-to-end wall clock per run.
+
+The instrumentation lives in `sglang-LenVM/python/sglang/srt/lvm/timing.py` and
+hooks `Sampler.forward` plus `LvmGuidedSampler.apply`. Per decoding step it
+emits one JSONL line with:
+
+- `t_sampler_total_ms`: total time inside `Sampler.forward`
+- `t_pre_lvm_ms`: preprocess + temperature scaling + softmax
+- `t_lvm_apply_outer_ms`: full `LvmGuidedSampler.apply` call
+  - `t_lvm_build_pending_ms`: gather candidates & request state
+  - `t_lvm_forward_ms`: LenVM extend + launch + collect
+  - `t_lvm_apply_guidance_ms`: apply value-based adjustment to probs
+- `t_sample_ms`: sampling kernel
+- `lvm_active`, `batch_size`, `is_greedy`
+
+Theoretical FLOPs are computed by `analyze.py` at the layer level.
`flops.py`
+loads each model's `config.json` (HF cache or local dir; falls back to
+hardcoded Qwen2.5 dims if missing) and counts:
+
+- per-layer linear matmuls: Q / K / V / O projections (GQA-aware) + SwiGLU MLP
+  (gate + up + down)
+- per-layer attention compute: `2 * H_q * head_dim * seq_len` for each of
+  Q@K^T and attn@V (so attention contribution scales with position)
+- `lm_head`: `2 * hidden * vocab`
+
+A baseline run is split into prefill (charged once per unique prompt, since
+SGLang prefix caching is on by default) and decode (per sample). A LenVM run
+adds, per generated token, one `tree_value_extend` forward plus `k`
+candidate forwards through the value model. The analyzer reports both the
+total PFLOPs and a per-component split so the linear / attention / lm_head
+shares of the baseline are visible alongside the LenVM-specific overhead.
+Contrasting the theoretical FLOPs ratio with the measured wall-clock ratio
+shows how much of the slowdown is raw compute increase vs GPU
+underutilization.
+
+## Running it
+
+```bash
+# from repository root, with .venv-infer and .venv-eval already built
+bash scripts/inference/lenvm_timing.sh
+```
+
+Overridable knobs (env vars; see top of the script for defaults):
+`BASE_MODEL`, `LENVM_MODEL`, `MAX_QUESTIONS`, `N_SAMPLES`, `MAX_TOKENS`,
+`MAX_CONCURRENCY`, `LENVM_TOP_K`, `LENVM_VALUE_SCALE`, `RESULTS_DIR`.
+
+The script chains three stages:
+
+1. Start baseline server → run `inference.timing.run_timing` → kill server
+2. Start LenVM server → run `inference.timing.run_timing` → kill server
+3. `inference.timing.analyze` reads both JSONL streams + meta files and writes
+   a `summary.csv`, `summary.json`, and three plots into `RESULTS_DIR`.
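
Stage 3 essentially reduces each per-step JSONL stream to a handful of statistics per timing key. A minimal sketch of that reduction follows, loosely mirroring what `_filter_warmup` and `_agg` do in `analyze.py`. The names `summarize_steps`, `percentile`, and `SAMPLE_LINES`, along with the sample values, are hypothetical and only illustrate the record shape described above.

```python
import json
import math
from typing import Dict, Iterable, List


def percentile(values: List[float], p: float) -> float:
    """Linearly interpolated percentile for p in [0, 1]."""
    if not values:
        return float("nan")
    ordered = sorted(values)
    k = (len(ordered) - 1) * p
    lo, hi = math.floor(k), math.ceil(k)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (k - lo)


def summarize_steps(lines: Iterable[str], key: str = "t_sampler_total_ms") -> Dict[str, float]:
    """Aggregate one timing key over JSONL lines, skipping greedy warmup steps."""
    vals: List[float] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            rec = json.loads(line)
        except json.JSONDecodeError:
            continue  # tolerate a partially written trailing line
        if not isinstance(rec, dict) or rec.get("is_greedy", False):
            continue
        v = rec.get(key)
        if isinstance(v, (int, float)):
            vals.append(float(v))
    if not vals:
        return {"n_steps": 0}
    return {
        "n_steps": len(vals),
        "mean": sum(vals) / len(vals),
        "p50": percentile(vals, 0.5),
        "p95": percentile(vals, 0.95),
    }


# Made-up records for illustration only; real logs come from the timing hook.
SAMPLE_LINES = [
    '{"t_sampler_total_ms": 0.2, "is_greedy": true, "batch_size": 1}',
    '{"t_sampler_total_ms": 50.0, "is_greedy": false, "batch_size": 64}',
    '{"t_sampler_total_ms": 52.0, "is_greedy": false, "batch_size": 64}',
    '{"t_sampler_total_ms": 48.0, "is_greedy": false, "batch_size": 64}',
    '{"t_sampler_total_ms": 49.',
]

if __name__ == "__main__":
    print(summarize_steps(SAMPLE_LINES))
```

Skipping greedy records keeps the server's startup warmup requests from dragging the per-step mean down, and silently dropping malformed lines makes the summary usable even if a server was killed mid-write.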
+
+## Outputs (under `RESULTS_DIR`)
+
+- `baseline.timing.jsonl`, `lenvm.timing.jsonl`: per-step records
+- `baseline.meta.json`, `lenvm.meta.json`: wall-clock, token counts, cmdline
+- `baseline.gpu_samples.csv`, `lenvm.gpu_samples.csv`: `nvidia-smi` 1 Hz log
+- `summary.csv`, `summary.json`: aggregated table (incl. theoretical FLOPs, achieved TFLOPs/s, ratio row)
+- `per_step_breakdown.png`: stacked bar of sampler-side decomposition
+- `lvm_apply_breakdown.png`: LenVM `apply()` internal breakdown
+- `flops_breakdown.png`: stacked bar of theoretical FLOPs by component
+  (base linear / attention / lm_head + LenVM extend / candidates / prefill)
+
+## Caveats
+
+- The per-step timer adds Python-level `time.perf_counter()` calls on the
+  decoding hot path. They are no-ops when `SGLANG_LVM_TIMING_LOG` is unset.
+- `t_lvm_apply_outer_ms` covers a few short helpers (e.g.
+  `_get_inproc_provider`, `clean_stale_requests`) not separately broken out;
+  at production batch sizes the residual is small but visible at low load.
+- Single-rank only. For TP/DP > 1 the timer writes from one rank; extend the
+  log filename with the rank suffix if you need per-worker traces.
+- LenVM-extra FLOPs assume **independent single-token forwards** for each of
+  the `k` candidates per generated token. This matches the current
+  sglang-LenVM in-proc path (`tree_value_extend` extends KV once, then
+  each candidate is a separate single-token forward attending to the shared
+  prefix cache). If a future implementation batches all `k` candidates into
+  one forward and amortizes some MLP/attention cost, pass
+  `candidate_cost_multiplier < 1.0` to `lvm_extra_flops` to scale that term.
+- LenVM head cost uses the small `MLP2SiLUValueHead` (`d*d + d*out_dim`), not
+  the base model's `lm_head` (`d * vocab_size`).
`ModelConfig.load(...)` + auto-detects this by checking for `value_head.safetensors` next to + `config.json` or `LengthValueModel`/`ValueModel`/`ValueHead` in the + config's `architectures`. Force the choice via `head_type=...` if needed. +- Wall-clock comparisons assume both servers see the same prompts at the same + concurrency. Run them back-to-back on a quiet GPU. diff --git a/inference/timing/__init__.py b/inference/timing/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/inference/timing/analyze.py b/inference/timing/analyze.py new file mode 100644 index 0000000..1ecc6ca --- /dev/null +++ b/inference/timing/analyze.py @@ -0,0 +1,457 @@ +"""Compare baseline vs LenVM-guided timing runs produced by lenvm_timing.sh. + +Inputs (from --results-dir): + baseline.timing.jsonl per-step records from the no-LenVM server + lenvm.timing.jsonl per-step records from the LenVM-enabled server + baseline.meta.json end-to-end wall clock + token counts (run_timing.py) + lenvm.meta.json same, for LenVM run + +Outputs (in --results-dir): + summary.csv one row per setting with e2e + per-step aggregates + summary.json same as CSV but JSON + per_step_breakdown.png stacked bar of per-step decomposition + +The per-step decomposition is what reviewers asked for: how much of a decoding +step is base-model forward (inferred from "step not under sampler"), how much +is sampler-side preprocess, how much is the LenVM forward + guidance overlay. 
+""" + +from __future__ import annotations + +import argparse +import csv +import json +import math +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +from inference.timing.flops import ( + ModelConfig, + baseline_run_flops, + lvm_extra_flops, +) + + +_PER_STEP_KEYS = [ + "t_sampler_total_ms", + "t_pre_lvm_ms", + "t_lvm_apply_outer_ms", + "t_lvm_build_pending_ms", + "t_lvm_forward_ms", + "t_lvm_apply_guidance_ms", + "t_sample_ms", +] + + +def _iter_records(path: Path) -> Iterable[Dict[str, Any]]: + if not path.exists(): + return [] + out: List[Dict[str, Any]] = [] + with path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + try: + out.append(json.loads(line)) + except json.JSONDecodeError: + continue + return out + + +def _percentile(values: List[float], p: float) -> float: + if not values: + return float("nan") + values = sorted(values) + k = (len(values) - 1) * p + f = math.floor(k) + c = math.ceil(k) + if f == c: + return values[int(k)] + return values[f] + (values[c] - values[f]) * (k - f) + + +def _filter_warmup(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Drop the small-batch greedy warmup steps SGLang emits at server start + (POST /generate handshake with a single token, bs=1, is_greedy=True).""" + return [r for r in records if not r.get("is_greedy", False)] + + +def _agg(records: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate per-step records into mean/p50/p95/sum per key.""" + out: Dict[str, Any] = {"n_steps": len(records)} + for key in _PER_STEP_KEYS: + vals: List[float] = [] + for r in records: + v = r.get(key) + if isinstance(v, (int, float)): + vals.append(float(v)) + if not vals: + continue + out[f"{key}_mean"] = sum(vals) / len(vals) + out[f"{key}_p50"] = _percentile(vals, 0.5) + out[f"{key}_p95"] = _percentile(vals, 0.95) + out[f"{key}_sum"] = sum(vals) + out[f"{key}_count"] = len(vals) + batch_sizes = [int(r["batch_size"]) for r in records if 
isinstance(r.get("batch_size"), int)] + if batch_sizes: + out["batch_size_mean"] = sum(batch_sizes) / len(batch_sizes) + out["batch_size_p50"] = _percentile(batch_sizes, 0.5) + out["batch_size_p95"] = _percentile(batch_sizes, 0.95) + lvm_active = sum(1 for r in records if r.get("lvm_active")) + out["n_steps_with_lvm"] = lvm_active + return out + + +def _load_meta(path: Path) -> Dict[str, Any]: + if not path.exists(): + return {} + return json.loads(path.read_text()) + + +def _row_for(tag: str, meta: Dict[str, Any], agg: Dict[str, Any]) -> Dict[str, Any]: + summary = meta.get("summary") or {} + row: Dict[str, Any] = { + "tag": tag, + "wall_clock_s": meta.get("wall_clock_s"), + "total_output_tokens": summary.get("total_output_tokens"), + "total_prompt_tokens": summary.get("total_prompt_tokens"), + "throughput_output_tokens_per_s": meta.get("throughput_output_tokens_per_s"), + "n_requests": summary.get("n_requests"), + "output_tokens_mean": summary.get("output_tokens_mean"), + "output_tokens_p95": summary.get("output_tokens_p95"), + "latency_s_mean": summary.get("latency_s_mean"), + "latency_s_p95": summary.get("latency_s_p95"), + "value_scale": meta.get("value_scale"), + "top_k": meta.get("top_k"), + } + row.update(agg) + return row + + +def _add_flops( + row: Dict[str, Any], + *, + meta: Dict[str, Any], + base_cfg: Optional[ModelConfig], + lvm_cfg: Optional[ModelConfig], + is_lvm: bool, +) -> None: + """Attach layer-level theoretical FLOPs + achieved-FLOPs/sec columns. + + Splits prefill (charged once per unique prompt, assumes SGLang prefix cache) + from decode (charged per sample), and breaks each into linear / attention / + lm_head components. 
+ """ + if base_cfg is None: + return + unique_prompts = int(meta.get("max_questions") or 0) + samples_per = int(meta.get("n_samples_per_q") or 1) + prompt_total = int(row.get("total_prompt_tokens") or 0) + output_total = int(row.get("total_output_tokens") or 0) + if unique_prompts <= 0 or output_total <= 0: + return + # total_prompt_tokens is summed once per unique question (dedup in summarize). + # total_output_tokens is summed across all samples per question, then summed across questions. + mean_prompt = prompt_total / unique_prompts + mean_output = output_total / (unique_prompts * samples_per) + + base = baseline_run_flops( + base_cfg, + unique_prompts=unique_prompts, + samples_per_prompt=samples_per, + mean_prompt_tokens=mean_prompt, + mean_output_tokens=mean_output, + ) + total = base["total"]["total"] + row["base_pflops_total"] = total / 1e15 + row["base_pflops_prefill"] = base["prefill"]["total"] / 1e15 + row["base_pflops_decode"] = base["decode"]["total"] / 1e15 + row["base_pflops_linear"] = base["total"]["linear"] / 1e15 + row["base_pflops_attention"] = base["total"]["attention"] / 1e15 + row["base_pflops_lm_head"] = base["total"]["lm_head"] / 1e15 + + if is_lvm and lvm_cfg is not None: + k = row.get("top_k") + if k is not None and int(k) >= 1: + extra = lvm_extra_flops( + lvm_cfg, + unique_prompts=unique_prompts, + samples_per_prompt=samples_per, + mean_prompt_tokens=mean_prompt, + mean_output_tokens=mean_output, + k_candidates=int(k), + ) + total += extra["total"]["total"] + row["lvm_pflops_prefill"] = extra["lenvm_prefill"]["total"] / 1e15 + row["lvm_pflops_extend"] = extra["lenvm_extend"]["total"] / 1e15 + row["lvm_pflops_candidates"] = extra["lenvm_candidates"]["total"] / 1e15 + row["lvm_pflops_total"] = extra["total"]["total"] / 1e15 + + row["theoretical_pflops_total"] = total / 1e15 + # Per-decode-token cost (excluding prefill share) for a quick "GFLOPs/tok" feel + decode_total = base["decode"]["total"] + (row.get("lvm_pflops_extend", 0.0) + 
row.get("lvm_pflops_candidates", 0.0)) * 1e15 + n_decode_tokens = unique_prompts * samples_per * mean_output + if n_decode_tokens > 0: + row["theoretical_gflops_per_output_token"] = decode_total / n_decode_tokens / 1e9 + + wall = row.get("wall_clock_s") or 0 + if wall > 0: + row["achieved_tflops_per_s"] = total / wall / 1e12 + + +def _ratio_row(rows: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute lenvm / baseline ratios for numeric columns. Assumes rows[0]=baseline, rows[1]=lenvm.""" + if len(rows) != 2: + return {} + base, lvm = rows + ratio: Dict[str, Any] = {"tag": "ratio (lvm/base)"} + for key, b_val in base.items(): + l_val = lvm.get(key) + if isinstance(b_val, (int, float)) and isinstance(l_val, (int, float)) and b_val: + ratio[key] = l_val / b_val + return ratio + + +def _stacked_bar(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except Exception as e: # pragma: no cover + print(f"matplotlib unavailable, skipping plot: {e}") + return None + + labels = [r["tag"] for r in rows] + sections = ["t_pre_lvm_ms_mean", "t_lvm_apply_outer_ms_mean", "t_sample_ms_mean"] + legend_names = ["pre-LVM (preprocess+softmax)", "LenVM apply (forward+guidance)", "sample kernel"] + values_per_section = [[float(r.get(s, 0.0) or 0.0) for r in rows] for s in sections] + + fig, ax = plt.subplots(figsize=(7, 5)) + bottoms = [0.0] * len(labels) + for sec_vals, name in zip(values_per_section, legend_names): + ax.bar(labels, sec_vals, bottom=bottoms, label=name) + bottoms = [b + v for b, v in zip(bottoms, sec_vals)] + + for i, r in enumerate(rows): + total = bottoms[i] + ax.text(i, total, f"{total:.2f} ms", ha="center", va="bottom") + + ax.set_ylabel("Mean per-step latency inside Sampler.forward (ms)") + ax.set_title("LenVM vs baseline: sampler-side per-step decomposition") + ax.legend(loc="upper left", fontsize="small") + fig.tight_layout() + fig.savefig(out_png, dpi=140) + 
plt.close(fig) + return out_png + + +def _flops_component_bar(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: + """Stacked bar of theoretical PFLOPs by component (base linear/attn/lm_head + LenVM).""" + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except Exception as e: # pragma: no cover + print(f"matplotlib unavailable, skipping plot: {e}") + return None + + if not any(r.get("base_pflops_linear") for r in rows): + return None + + labels = [r["tag"] for r in rows] + sections = [ + ("base_pflops_linear", "base linear (Q/K/V/O + MLP)"), + ("base_pflops_attention", "base attention (QK^T + attn@V)"), + ("base_pflops_lm_head", "base lm_head"), + ("lvm_pflops_extend", "LenVM extend forward"), + ("lvm_pflops_candidates", "LenVM candidate forwards"), + ("lvm_pflops_prefill", "LenVM prefill"), + ] + fig, ax = plt.subplots(figsize=(8, 5)) + bottoms = [0.0] * len(labels) + for key, name in sections: + vals = [float(r.get(key) or 0.0) for r in rows] + if max(vals) <= 0: + continue + ax.bar(labels, vals, bottom=bottoms, label=name) + bottoms = [b + v for b, v in zip(bottoms, vals)] + for i, total in enumerate(bottoms): + ax.text(i, total, f"{total:.2f} PFLOPs", ha="center", va="bottom") + ax.set_ylabel("Theoretical FLOPs (PFLOPs)") + ax.set_title("Theoretical inference FLOPs by component") + ax.legend(loc="upper left", fontsize="small") + fig.tight_layout() + fig.savefig(out_png, dpi=140) + plt.close(fig) + return out_png + + +def _lvm_sub_breakdown(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except Exception as e: # pragma: no cover + print(f"matplotlib unavailable, skipping plot: {e}") + return None + + lenvm_row = next((r for r in rows if r.get("n_steps_with_lvm", 0) > 0), None) + if lenvm_row is None: + return None + + sub_keys = ["t_lvm_build_pending_ms_mean", "t_lvm_forward_ms_mean", "t_lvm_apply_guidance_ms_mean"] 
+ sub_names = ["build_pending", "LenVM forward (extend+launch+collect)", "apply_guidance"] + sub_vals = [float(lenvm_row.get(k, 0.0) or 0.0) for k in sub_keys] + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.bar(sub_names, sub_vals) + for i, v in enumerate(sub_vals): + ax.text(i, v, f"{v:.2f} ms", ha="center", va="bottom") + ax.set_ylabel("Mean per-step (ms)") + ax.set_title("LenVM apply() internal breakdown") + fig.tight_layout() + fig.savefig(out_png, dpi=140) + plt.close(fig) + return out_png + + +def _print_tables(rows: List[Dict[str, Any]]) -> None: + timing_cols = [ + ("tag", "tag"), + ("wall_clock_s", "e2e_s"), + ("throughput_output_tokens_per_s", "tok/s"), + ("output_tokens_mean", "out_tok_mean"), + ("n_steps", "n_steps"), + ("t_sampler_total_ms_mean", "samp_total_ms"), + ("t_pre_lvm_ms_mean", "pre_ms"), + ("t_lvm_apply_outer_ms_mean", "lvm_apply_ms"), + ("t_sample_ms_mean", "sample_ms"), + ] + flops_total_cols = [ + ("tag", "tag"), + ("theoretical_gflops_per_output_token", "GFLOPs/tok"), + ("base_pflops_total", "base PFLOPs"), + ("lvm_pflops_total", "lvm PFLOPs"), + ("theoretical_pflops_total", "total PFLOPs"), + ("achieved_tflops_per_s", "TFLOPs/s"), + ("wall_clock_s", "e2e_s"), + ] + flops_component_cols = [ + ("tag", "tag"), + ("base_pflops_linear", "base.linear"), + ("base_pflops_attention", "base.attn"), + ("base_pflops_lm_head", "base.lm_head"), + ("base_pflops_prefill", "base.prefill"), + ("base_pflops_decode", "base.decode"), + ("lvm_pflops_prefill", "lvm.prefill"), + ("lvm_pflops_extend", "lvm.extend"), + ("lvm_pflops_candidates", "lvm.cands"), + ] + _print_one(rows, timing_cols) + if any("theoretical_pflops_total" in r for r in rows): + print() + print("== FLOPs headline ==") + _print_one(rows, flops_total_cols) + print() + print("== FLOPs by component (PFLOPs) ==") + _print_one(rows, flops_component_cols) + + +def _print_one(rows: List[Dict[str, Any]], cols: List[tuple]) -> None: + widths = {k: max(len(label), max(len(_fmt(r.get(k))) for r in 
rows)) for k, label in cols} + header = " | ".join(f"{label:>{widths[k]}}" for k, label in cols) + print(header) + print("-+-".join("-" * widths[k] for k, _ in cols)) + for r in rows: + print(" | ".join(f"{_fmt(r.get(k)):>{widths[k]}}" for k, _ in cols)) + + +def _fmt(v: Any) -> str: + if v is None: + return "" + if isinstance(v, float): + return f"{v:.2f}" + return str(v) + + +def main() -> int: + p = argparse.ArgumentParser() + p.add_argument("--results-dir", type=Path, required=True) + p.add_argument("--baseline-tag", default="baseline") + p.add_argument("--lenvm-tag", default="lenvm") + p.add_argument( + "--base-model", + default="Qwen/Qwen2.5-7B-Instruct", + help="Base generation model name (for theoretical FLOPs lookup).", + ) + p.add_argument( + "--lvm-model", + default="namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct", + help="LenVM checkpoint name (for theoretical FLOPs lookup).", + ) + args = p.parse_args() + + base_cfg: Optional[ModelConfig] = None + lvm_cfg: Optional[ModelConfig] = None + try: + base_cfg = ModelConfig.load(args.base_model) + except FileNotFoundError as e: + print(f"warning: {e}; skipping baseline FLOPs") + try: + lvm_cfg = ModelConfig.load(args.lvm_model) + except FileNotFoundError as e: + print(f"warning: {e}; skipping LenVM FLOPs") + if base_cfg is not None: + print(f"base config (source={base_cfg.source}): L={base_cfg.num_hidden_layers} d={base_cfg.hidden_size} " + f"Hq={base_cfg.num_attention_heads} Hkv={base_cfg.num_key_value_heads} h={base_cfg.head_dim} " + f"ff={base_cfg.intermediate_size} V={base_cfg.vocab_size} head={base_cfg.head_type}") + if lvm_cfg is not None: + print(f"lvm config (source={lvm_cfg.source}): L={lvm_cfg.num_hidden_layers} d={lvm_cfg.hidden_size} " + f"Hq={lvm_cfg.num_attention_heads} Hkv={lvm_cfg.num_key_value_heads} h={lvm_cfg.head_dim} " + f"ff={lvm_cfg.intermediate_size} V={lvm_cfg.vocab_size} head={lvm_cfg.head_type}") + + rd = args.results_dir + rows: List[Dict[str, Any]] = [] + for 
tag, is_lvm in ((args.baseline_tag, False), (args.lenvm_tag, True)): + meta = _load_meta(rd / f"{tag}.meta.json") + records = _filter_warmup(list(_iter_records(rd / f"{tag}.timing.jsonl"))) + agg = _agg(records) + row = _row_for(tag, meta, agg) + _add_flops(row, meta=meta, base_cfg=base_cfg, lvm_cfg=lvm_cfg, is_lvm=is_lvm) + rows.append(row) + + ratio = _ratio_row(rows) + display_rows = rows + ([ratio] if ratio else []) + + summary_path = rd / "summary.json" + summary_path.write_text(json.dumps(display_rows, indent=2)) + + csv_path = rd / "summary.csv" + if display_rows: + keys = sorted({k for r in display_rows for k in r.keys()}) + with csv_path.open("w", newline="") as f: + w = csv.DictWriter(f, fieldnames=keys) + w.writeheader() + for r in display_rows: + w.writerow({k: r.get(k) for k in keys}) + + plot_path = _stacked_bar(rows, rd / "per_step_breakdown.png") + sub_plot_path = _lvm_sub_breakdown(rows, rd / "lvm_apply_breakdown.png") + flops_plot_path = _flops_component_bar(rows, rd / "flops_breakdown.png") + + print(f"summary.json -> {summary_path}") + print(f"summary.csv -> {csv_path}") + if plot_path: + print(f"timing plot -> {plot_path}") + if sub_plot_path: + print(f"lvm sub-plot -> {sub_plot_path}") + if flops_plot_path: + print(f"flops plot -> {flops_plot_path}") + print() + _print_tables(display_rows) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/inference/timing/flops.py b/inference/timing/flops.py new file mode 100644 index 0000000..5f4088b --- /dev/null +++ b/inference/timing/flops.py @@ -0,0 +1,325 @@ +"""Layer-level inference FLOPs estimator for LenVM vs baseline. + +Reads HuggingFace ``config.json`` to count weight matmuls layer-by-layer +instead of relying on the ``2 * N_params`` rule of thumb. The decomposition +matters for LenVM analysis because the value model adds (1 + k) forwards +per generated token, and the relative weight of attention-vs-linear FLOPs +shifts at long sequence lengths. 
+ +For each transformer layer with hidden size ``d``, ``H_q`` attention heads, +``H_kv`` key/value heads, head dim ``h``, FFN dim ``ff`` (SwiGLU: gate + up ++ down), per-token FLOPs are: + +* attention projections: ``2*d*(H_q*h) + 2*2*d*(H_kv*h) + 2*(H_q*h)*d`` + (Q, K, V, output projections; Qwen2.5 uses GQA so K/V are smaller). +* attention compute at position ``p``: ``2 * H_q * h * p * 2`` (Q@K^T + and attention@V over a KV cache of length ``p``). +* SwiGLU MLP: ``2 * d * ff * 3`` (gate + up projections share input; down + projection writes back). +* LM head (final layer only): ``2 * d * V``. + +Prefill over a prompt of length ``S`` runs every token in parallel but +each token still attends to the lower-triangular prefix, so attention +FLOPs scale as ``S * (S+1) / 2`` rather than ``S``. + +LenVM-guided decoding adds, per output token: + +* one ``tree_value_extend`` forward (catch the LenVM cache up by the + newly-accepted token), +* ``k`` ``tree_value`` forwards (score the top-k candidates). + +Prefix caching: when the same prompt is shared by multiple samples (n>1), +the base/LenVM prefill is charged once per unique prompt and decode is +charged per sample. The analyzer assumes +``unique_prompts == meta['max_questions']`` and +``samples == unique_prompts * meta['n_samples_per_q']``. +""" + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional + + +# Fallback model dimensions, used only when we cannot locate a config.json +# (e.g. paths that point to a non-existent model). Sourced from the Qwen2.5 +# model cards. Update if you add new base/LenVM checkpoints. 
+_FALLBACK_CONFIGS: Dict[str, Dict[str, int]] = { + "Qwen/Qwen2.5-0.5B-Instruct": dict(num_hidden_layers=24, hidden_size=896, num_attention_heads=14, num_key_value_heads=2, intermediate_size=4864, vocab_size=151936), + "Qwen/Qwen2.5-1.5B-Instruct": dict(num_hidden_layers=28, hidden_size=1536, num_attention_heads=12, num_key_value_heads=2, intermediate_size=8960, vocab_size=151936), + "Qwen/Qwen2.5-3B-Instruct": dict(num_hidden_layers=36, hidden_size=2048, num_attention_heads=16, num_key_value_heads=2, intermediate_size=11008, vocab_size=151936), + "Qwen/Qwen2.5-7B-Instruct": dict(num_hidden_layers=28, hidden_size=3584, num_attention_heads=28, num_key_value_heads=4, intermediate_size=18944, vocab_size=152064), + "Qwen/Qwen2.5-14B-Instruct": dict(num_hidden_layers=48, hidden_size=5120, num_attention_heads=40, num_key_value_heads=8, intermediate_size=13824, vocab_size=152064), +} + + +@dataclass(frozen=True) +class ModelConfig: + num_hidden_layers: int + hidden_size: int + num_attention_heads: int + num_key_value_heads: int + head_dim: int + intermediate_size: int + vocab_size: int + # Head type used at the top of the stack. "lm_head" = 2*d*V vocab projection + # (standard causal LM). "value_head" = small MLP -> scalar (LenVM checkpoints + # ship a MLP2SiLUValueHead: d->d Linear + d->1 Linear, see + # sglang/srt/models/qwen2_lvm.py::MLP2SiLUValueHead). 
+ head_type: str = "lm_head" + value_head_hidden: Optional[int] = None # MLP hidden dim if head_type == "value_head" + value_head_out_dim: int = 1 + source: str = "" + + @classmethod + def from_dict(cls, cfg: dict, source: str = "", *, head_type: str = "lm_head", + value_head_hidden: Optional[int] = None, + value_head_out_dim: int = 1) -> "ModelConfig": + d = cfg["hidden_size"] + Hq = cfg["num_attention_heads"] + Hkv = cfg.get("num_key_value_heads", Hq) + h = cfg.get("head_dim") or (d // Hq) + return cls( + num_hidden_layers=cfg["num_hidden_layers"], + hidden_size=d, + num_attention_heads=Hq, + num_key_value_heads=Hkv, + head_dim=h, + intermediate_size=cfg["intermediate_size"], + vocab_size=cfg["vocab_size"], + head_type=head_type, + value_head_hidden=value_head_hidden if value_head_hidden is not None else d, + value_head_out_dim=value_head_out_dim, + source=source, + ) + + @classmethod + def load(cls, name_or_path: str, *, head_type: str = "auto", + value_head_out_dim: int = 1) -> "ModelConfig": + """Locate config.json for a HF model name or a local path. + + head_type="auto" (default) inspects the directory for value_head.safetensors, + treating its presence as a LenVM-style value-head checkpoint. Pass + head_type="value_head" or "lm_head" to force the choice. + """ + cfg_path = _find_config_json(name_or_path) + if cfg_path is not None: + cfg_dict = json.loads(cfg_path.read_text()) + ht = head_type if head_type != "auto" else _autodetect_head(cfg_path.parent, cfg_dict) + return cls.from_dict( + cfg_dict, + source=str(cfg_path), + head_type=ht, + value_head_out_dim=value_head_out_dim, + ) + fb = _FALLBACK_CONFIGS.get(name_or_path) + if fb is None: + base = name_or_path.rstrip("/").split("/")[-1] + for key, val in _FALLBACK_CONFIGS.items(): + if key.split("/")[-1] == base: + fb = val + break + if fb is None: + raise FileNotFoundError( + f"Could not locate config.json for {name_or_path!r} and no fallback " + f"dimensions are registered. 
Add one to _FALLBACK_CONFIGS." + ) + ht = head_type if head_type != "auto" else "lm_head" + return cls.from_dict(fb, source=f"fallback:{name_or_path}", head_type=ht, + value_head_out_dim=value_head_out_dim) + + +def _autodetect_head(model_dir: Path, cfg_dict: dict) -> str: + """Return 'value_head' if a value_head.safetensors sits next to config.json, + or the loaded architecture is a known value-head class. Otherwise 'lm_head'. + """ + if (model_dir / "value_head.safetensors").exists(): + return "value_head" + for arch in cfg_dict.get("architectures", []) or []: + if "LengthValueModel" in arch or "ValueModel" in arch or "ValueHead" in arch: + return "value_head" + return "lm_head" + + +def _find_config_json(name_or_path: str) -> Optional[Path]: + """Find a HuggingFace-style config.json on disk.""" + p = Path(name_or_path) + if p.is_dir(): + cfg = p / "config.json" + if cfg.exists(): + return cfg + if p.is_file() and p.name == "config.json": + return p + # HF cache: $HF_HOME/hub/models--<org>--<repo>/snapshots/<revision>/config.json + hf_home = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface") + cache_root = Path(hf_home) / "hub" + if cache_root.exists() and "/" in name_or_path and not name_or_path.startswith("."): + repo_dir = cache_root / ("models--" + name_or_path.replace("/", "--")) + snapshots = repo_dir / "snapshots" + if snapshots.exists(): + for snap in snapshots.iterdir(): + cfg = snap / "config.json" + if cfg.exists(): + return cfg + # Local download dir convention (download_data_and_model.sh writes here): + if "/" in name_or_path and not name_or_path.startswith("."): + local_dir = Path("./models") / name_or_path + cfg = local_dir / "config.json" + if cfg.exists(): + return cfg + return None + + +# ---- Per-token forward FLOPs decomposition --------------------------------- + + +def per_layer_linear_flops(cfg: ModelConfig) -> int: + """FLOPs for one token through a single layer's attention projections + MLP.""" + d = cfg.hidden_size + Hq, Hkv, h = 
cfg.num_attention_heads, cfg.num_key_value_heads, cfg.head_dim + ff = cfg.intermediate_size + q_proj = 2 * d * (Hq * h) + k_proj = 2 * d * (Hkv * h) + v_proj = 2 * d * (Hkv * h) + o_proj = 2 * (Hq * h) * d + # SwiGLU MLP: gate, up, down. gate and up are both d->ff, down is ff->d. + mlp = 2 * d * ff + 2 * d * ff + 2 * ff * d + return q_proj + k_proj + v_proj + o_proj + mlp + + +def per_layer_attn_compute_flops(cfg: ModelConfig, seq_len: int) -> int: + """FLOPs for QK^T and attn@V for one token attending to seq_len positions.""" + # 2 * H_q * head_dim * seq_len for each of (QK^T) and (attn @ V) + return 2 * 2 * cfg.num_attention_heads * cfg.head_dim * seq_len + + +def lm_head_flops(cfg: ModelConfig) -> int: + """Vocab projection 2 * hidden * vocab_size.""" + return 2 * cfg.hidden_size * cfg.vocab_size + + +def value_head_flops(cfg: ModelConfig) -> int: + """MLP2SiLUValueHead: hidden->hidden (fc) + hidden->out_dim (summary). + + SiLU activation is negligible vs the two matmuls. The summary projection + is tiny when out_dim=1 (the default) but kept for completeness. 
+ """ + fc = 2 * cfg.hidden_size * (cfg.value_head_hidden or cfg.hidden_size) + summary = 2 * (cfg.value_head_hidden or cfg.hidden_size) * cfg.value_head_out_dim + return fc + summary + + +def head_flops(cfg: ModelConfig) -> int: + """FLOPs at the top of the stack, dispatched by ModelConfig.head_type.""" + if cfg.head_type == "value_head": + return value_head_flops(cfg) + return lm_head_flops(cfg) + + +def forward_token_flops(cfg: ModelConfig, position: int) -> Dict[str, int]: + """Decompose one decode-step FLOPs at the given (1-indexed) position.""" + lin = per_layer_linear_flops(cfg) * cfg.num_hidden_layers + attn = per_layer_attn_compute_flops(cfg, position) * cfg.num_hidden_layers + lmh = head_flops(cfg) + return {"linear": lin, "attention": attn, "lm_head": lmh, "total": lin + attn + lmh} + + +def prefill_flops(cfg: ModelConfig, prompt_len: int) -> Dict[str, int]: + """Decompose prefill FLOPs over a prompt of length prompt_len. + + Linear ops scale linearly with prompt_len. Attention is over the + lower-triangular causal mask, so per-layer attention compute sums to + 2 * H_q * h * (S * (S+1) / 2 * 2) = 2 * H_q * h * S * (S+1). 
+    """
+    if prompt_len <= 0:
+        return {"linear": 0, "attention": 0, "lm_head": 0, "total": 0}
+    lin = per_layer_linear_flops(cfg) * cfg.num_hidden_layers * prompt_len
+    attn_per_layer = 2 * 2 * cfg.num_attention_heads * cfg.head_dim * prompt_len * (prompt_len + 1) // 2
+    attn = attn_per_layer * cfg.num_hidden_layers
+    lmh = head_flops(cfg) * prompt_len
+    return {"linear": lin, "attention": attn, "lm_head": lmh, "total": lin + attn + lmh}
+
+
+def decode_flops_sum(cfg: ModelConfig, prompt_len: int, output_len: int) -> Dict[str, int]:
+    """Sum decode-step FLOPs for output_len tokens, attending to a growing KV cache."""
+    if output_len <= 0:
+        return {"linear": 0, "attention": 0, "lm_head": 0, "total": 0}
+    lin_per_token = per_layer_linear_flops(cfg) * cfg.num_hidden_layers
+    lmh_per_token = head_flops(cfg)
+    # Closed-form: sum_{t=1..L} (prompt_len + t) = L*prompt_len + L*(L+1)/2
+    pos_sum = output_len * prompt_len + output_len * (output_len + 1) // 2
+    attn = 2 * 2 * cfg.num_attention_heads * cfg.head_dim * pos_sum * cfg.num_hidden_layers
+    lin = lin_per_token * output_len
+    lmh = lmh_per_token * output_len
+    return {"linear": lin, "attention": attn, "lm_head": lmh, "total": lin + attn + lmh}
+
+
+def implied_param_count(cfg: ModelConfig) -> int:
+    """Approximate parameter count from config via the 2N rule (linear FLOPs per token ≈ 2 * N)."""
+    return per_layer_linear_flops(cfg) * cfg.num_hidden_layers // 2 + lm_head_flops(cfg) // 2
+
+
+# ---- Run-level aggregation -------------------------------------------------
+
+
+def baseline_run_flops(
+    cfg: ModelConfig,
+    *,
+    unique_prompts: int,
+    samples_per_prompt: int,
+    mean_prompt_tokens: float,
+    mean_output_tokens: float,
+) -> Dict[str, Dict[str, int]]:
+    """Total baseline FLOPs assuming prefix caching: prefill once per question, decode per sample."""
+    pre = prefill_flops(cfg, int(round(mean_prompt_tokens)))
+    dec = decode_flops_sum(cfg, int(round(mean_prompt_tokens)), int(round(mean_output_tokens)))
+    out = {
+        "prefill": {k: v * unique_prompts for k, v in pre.items()},
+        "decode": {k: v * unique_prompts * samples_per_prompt for k, v in dec.items()},
+    }
+    out["total"] = {k: out["prefill"][k] + out["decode"][k] for k in pre.keys()}
+    return out
+
+
+def lvm_extra_flops(
+    cfg: ModelConfig,
+    *,
+    unique_prompts: int,
+    samples_per_prompt: int,
+    mean_prompt_tokens: float,
+    mean_output_tokens: float,
+    k_candidates: int,
+    candidate_cost_multiplier: float = 1.0,
+) -> Dict[str, Dict[str, int]]:
+    """LenVM-only extra FLOPs (on top of baseline) for the whole run. Each output token adds:
+
+    * one ``tree_value_extend`` forward (catch the value-model KV up with the
+      just-accepted token; a single-token decode at the current position).
+    * ``k_candidates`` value-model forwards (score the top-k candidate tokens;
+      each is a single-token decode at the position right after the extend).
+
+    ``candidate_cost_multiplier`` is a knob for future LenVM implementations
+    that share work across candidates (e.g. batched single-forward scoring
+    that amortizes some MLP / attention cost across the k candidates). The
+    default of ``1.0`` matches the current sglang-LenVM in-proc path, where
+    each candidate is a separate single-token forward sharing only the
+    extended KV cache. Set to e.g. ``0.6`` if a future implementation can
+    batch the k candidates into one forward.
+
+    Whether ``cfg.head_type`` is ``lm_head`` or ``value_head`` controls
+    whether each forward is charged the 2*d*V vocab projection (causal LM)
+    or the much smaller MLP2SiLUValueHead (LenVM checkpoints).
+    """
+    pre = prefill_flops(cfg, int(round(mean_prompt_tokens)))
+    dec = decode_flops_sum(cfg, int(round(mean_prompt_tokens)), int(round(mean_output_tokens)))
+    candidate_scale = unique_prompts * samples_per_prompt * k_candidates * candidate_cost_multiplier
+    out = {
+        "lenvm_prefill": {k: v * unique_prompts for k, v in pre.items()},
+        "lenvm_extend": {k: v * unique_prompts * samples_per_prompt for k, v in dec.items()},
+        "lenvm_candidates": {k: int(v * candidate_scale) for k, v in dec.items()},
+    }
+    out["total"] = {k: out["lenvm_prefill"][k] + out["lenvm_extend"][k] + out["lenvm_candidates"][k] for k in pre.keys()}
+    return out
diff --git a/inference/timing/run_timing.py b/inference/timing/run_timing.py
new file mode 100644
index 0000000..ead6e34
--- /dev/null
+++ b/inference/timing/run_timing.py
@@ -0,0 +1,247 @@
+"""Run a single sample_eval pass and capture wall-clock + token counts.
+
+Thin wrapper over inference.tradeoff.sample_eval that adds:
+- end-to-end wall-clock timing (perf_counter)
+- background nvidia-smi sampling at 1 Hz
+- aggregate token counts from the run's responses jsonl
+- emits <output-dir>/<tag>.meta.json with the above
+
+The server-side per-step timing (SGLANG_LVM_TIMING_LOG) is written by the
+sglang process itself; this script does not touch that file.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import shutil
+import subprocess
+import sys
+import time
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+
+def _build_sample_eval_argv(args: argparse.Namespace, tag: str) -> List[str]:
+    cmd: List[str] = [
+        sys.executable,
+        "-m",
+        "inference.tradeoff.sample_eval",
+        "--dataset-name", args.dataset_name,
+        "--server-url", args.server_url,
+        "--output-dir", str(args.output_dir),
+        "--tag", tag,
+        "--stage", "run",
+        "--max-questions", str(args.max_questions),
+        "--max-concurrency", str(args.max_concurrency),
+        "--request-timeout", str(args.request_timeout),
+        "--max-tokens", str(args.max_tokens),
+        "--temperature", str(args.temperature),
+        "--top-p", str(args.top_p),
+        "--top-k", str(args.top_k),
+        "--min-p", str(args.min_p),
+        "--n", str(args.n),
+        "--http-backend", args.http_backend,
+    ]
+    if args.value_scale is not None:
+        cmd += ["--value-scale", str(args.value_scale)]
+    if args.value_mode is not None:
+        cmd += ["--value-mode", args.value_mode]
+    if args.value_gamma is not None:
+        cmd += ["--value-gamma", str(args.value_gamma)]
+    return cmd
+
+
+@dataclass
+class GpuSample:
+    t_offset_s: float
+    gpu_util_pct: List[int]
+    mem_used_mib: List[int]
+
+
+def _gpu_sampler(stop_path: Path, output_path: Path, period_s: float = 1.0) -> None:
+    """nvidia-smi sampler subprocess body (invoked via -c)."""
+    # Not used; we run nvidia-smi from the parent.
+    raise SystemExit(0)
+
+
+def _start_gpu_sampler(samples_out: Path) -> Optional[subprocess.Popen]:
+    if not shutil.which("nvidia-smi"):
+        return None
+    # nvidia-smi --query-gpu with -lms 1000 streams one sample per second; stdout is written to the file.
+    # Format: timestamp,index,utilization.gpu,memory.used
+    fh = open(samples_out, "w")
+    proc = subprocess.Popen(
+        [
+            "nvidia-smi",
+            "--query-gpu=timestamp,index,utilization.gpu,memory.used",
+            "--format=csv,nounits",
+            "-lms", "1000",
+        ],
+        stdout=fh,
+        stderr=subprocess.DEVNULL,
+    )
+    proc._log_fh = fh  # type: ignore[attr-defined]
+    return proc
+
+
+def _stop_gpu_sampler(proc: Optional[subprocess.Popen]) -> None:
+    if proc is None:
+        return
+    proc.terminate()
+    try:
+        proc.wait(timeout=5)
+    except subprocess.TimeoutExpired:
+        proc.kill()
+    fh = getattr(proc, "_log_fh", None)
+    if fh is not None:
+        fh.close()
+
+
+def _summarize_responses(responses_jsonl: Path) -> Dict[str, Any]:
+    """Aggregate token counts and per-question latency from sample_eval output.
+
+    sample_eval writes one row per choice but duplicates the full-request
+    `usage` field across every choice's row, so we dedupe by question `idx`
+    (counting each request's usage once) before summing. `n_requests` and the
+    per-choice latency stats still sample every row.
+    """
+    n_requests = 0
+    total_output_tokens = 0
+    total_prompt_tokens = 0
+    output_token_counts: List[int] = []
+    latencies: List[float] = []
+    seen_idx: set = set()
+    with responses_jsonl.open() as f:
+        for line in f:
+            try:
+                row = json.loads(line)
+            except json.JSONDecodeError:
+                continue
+            n_requests += 1
+            idx = row.get("idx")
+            if idx is not None and idx not in seen_idx:
+                seen_idx.add(idx)
+                usage = row.get("usage") or {}
+                out_tok = usage.get("completion_tokens") or usage.get("output_tokens") or 0
+                in_tok = usage.get("prompt_tokens") or usage.get("input_tokens") or 0
+                total_output_tokens += int(out_tok or 0)
+                total_prompt_tokens += int(in_tok or 0)
+                if out_tok:
+                    output_token_counts.append(int(out_tok))
+            lat = row.get("latency_s") or row.get("elapsed_s")
+            if isinstance(lat, (int, float)):
+                latencies.append(float(lat))
+    summary: Dict[str, Any] = {
+        "n_requests": n_requests,
+        "total_output_tokens": total_output_tokens,
+        "total_prompt_tokens": total_prompt_tokens,
+    }
+    if output_token_counts:
+        output_token_counts.sort()
+        n = len(output_token_counts)
+        summary["output_tokens_mean"] = sum(output_token_counts) / n
+        summary["output_tokens_p50"] = output_token_counts[n // 2]
+        summary["output_tokens_p95"] = output_token_counts[int(n * 0.95)]
+        summary["output_tokens_max"] = output_token_counts[-1]
+    if latencies:
+        latencies.sort()
+        n = len(latencies)
+        summary["latency_s_mean"] = sum(latencies) / n
+        summary["latency_s_p50"] = latencies[n // 2]
+        summary["latency_s_p95"] = latencies[int(n * 0.95)]
+    return summary
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser()
+    p.add_argument("--tag", required=True, help="Output prefix (e.g. baseline, lenvm)")
+    p.add_argument("--server-url", required=True)
+    p.add_argument("--dataset-name", default="gsm8k")
+    p.add_argument("--output-dir", type=Path, required=True)
+    p.add_argument("--max-questions", type=int, default=50)
+    p.add_argument("--max-concurrency", type=int, default=50)
+    p.add_argument("--n", type=int, default=16)
+    p.add_argument("--max-tokens", type=int, default=6000)
+    p.add_argument("--temperature", type=float, default=1.0)
+    p.add_argument("--top-p", type=float, default=1.0)
+    p.add_argument("--top-k", type=int, default=-1)
+    p.add_argument("--min-p", type=float, default=0.01)
+    p.add_argument("--request-timeout", type=float, default=600000)
+    p.add_argument("--http-backend", default="aiohttp")
+    p.add_argument("--value-scale", type=float, default=None)
+    p.add_argument("--value-mode", default=None)
+    p.add_argument("--value-gamma", type=float, default=None)
+    return p.parse_args()
+
+
+def main() -> int:
+    args = parse_args()
+    args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    # sample_eval writes responses jsonl using the tag in its own canonical
+    # naming; we reuse the same tag so the file is predictable.
+    sample_eval_tag = (
+        f"{args.tag}_q{args.max_questions}_n{args.n}_p{args.top_p}_"
+        f"topk{args.top_k}_minp{args.min_p}"
+    )
+
+    cmd = _build_sample_eval_argv(args, tag=sample_eval_tag)
+
+    gpu_samples_path = args.output_dir / f"{args.tag}.gpu_samples.csv"
+    gpu_proc = _start_gpu_sampler(gpu_samples_path)
+
+    t_start = time.perf_counter()
+    t_wall_start = time.time()
+    try:
+        rc = subprocess.call(cmd)
+    finally:
+        _stop_gpu_sampler(gpu_proc)
+        t_end = time.perf_counter()
+        t_wall_end = time.time()
+
+    if rc != 0:
+        print(f"sample_eval exited with rc={rc}", file=sys.stderr)
+        return rc
+
+    # Match sample_eval's compute_paths: <dataset>.<tag>.responses.jsonl
+    responses_path = args.output_dir / f"{args.dataset_name}.{sample_eval_tag}.responses.jsonl"
+    summary = _summarize_responses(responses_path) if responses_path.exists() else {}
+
+    wall_clock_s = t_end - t_start
+    meta: Dict[str, Any] = {
+        "tag": args.tag,
+        "server_url": args.server_url,
+        "dataset": args.dataset_name,
+        "max_questions": args.max_questions,
+        "n_samples_per_q": args.n,
+        "max_concurrency": args.max_concurrency,
+        "max_tokens": args.max_tokens,
+        "top_k": args.top_k,
+        "value_scale": args.value_scale,
+        "value_mode": args.value_mode,
+        "value_gamma": args.value_gamma,
+        "wall_clock_s": wall_clock_s,
+        "wall_start_epoch_s": t_wall_start,
+        "wall_end_epoch_s": t_wall_end,
+        "responses_path": str(responses_path),
+        "summary": summary,
+        "cmd": cmd,
+    }
+    if summary.get("total_output_tokens"):
+        meta["throughput_output_tokens_per_s"] = (
+            summary["total_output_tokens"] / wall_clock_s if wall_clock_s > 0 else 0
+        )
+
+    meta_path = args.output_dir / f"{args.tag}.meta.json"
+    meta_path.write_text(json.dumps(meta, indent=2))
+    print(f"meta -> {meta_path}")
+    print(f"wall_clock_s={wall_clock_s:.2f} "
+          f"out_tokens={summary.get('total_output_tokens', '?')} "
+          f"throughput={meta.get('throughput_output_tokens_per_s', 0):.1f} tok/s")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/inference/timing/sweep_analyze.py b/inference/timing/sweep_analyze.py
new file mode 100644
index 0000000..781dfe2
--- /dev/null
+++ b/inference/timing/sweep_analyze.py
@@ -0,0 +1,195 @@
+"""Aggregate timing results from multiple ``lenvm_timing.sh`` runs into a top-k sweep.
+
+Each input ``--results-dirs`` entry is a directory produced by a single
+``lenvm_timing.sh`` invocation (so it contains ``summary.json`` + per-stage
+``*.meta.json`` + ``*.timing.jsonl``). This script extracts ``LENVM_TOP_K`` from
+the LenVM ``meta.json`` and aggregates the per-run metrics into a single CSV +
+plot, plus a stdout table. Use this to answer:
+
+- "How does the wall-clock slowdown scale with the LenVM candidate-set size?"
+- "Does the theoretical FLOPs ratio track the measured wall-clock ratio as k
+  grows, or does the GPU-utilization gap widen?"
+"""
+
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+
+_RATIO_KEYS = [
+    "wall_clock_s",
+    "throughput_output_tokens_per_s",
+    "theoretical_pflops_total",
+    "achieved_tflops_per_s",
+    "t_sampler_total_ms_mean",
+    "t_lvm_apply_outer_ms_mean",
+    "t_lvm_forward_ms_mean",
+]
+
+
+def _load_summary(rd: Path) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+    """Return (baseline_row, lenvm_row, ratio_row) from a summary.json."""
+    rows: List[Dict[str, Any]] = json.loads((rd / "summary.json").read_text())
+    base = next(r for r in rows if r.get("tag") == "baseline")
+    lvm = next(r for r in rows if r.get("tag") == "lenvm")
+    ratio = next((r for r in rows if "ratio" in str(r.get("tag"))), {})
+    return base, lvm, ratio
+
+
+def _infer_k(rd: Path, lvm_row: Dict[str, Any]) -> Optional[int]:
+    k = lvm_row.get("top_k")
+    if isinstance(k, int) and k > 0:
+        return k
+    # Fall back to parsing the directory name (e.g. sweep_q50_n16_k5).
+    m = re.search(r"_k(\d+)", rd.name)
+    if m:
+        return int(m.group(1))
+    return None
+
+
+def _percent(x: Optional[float]) -> str:
+    if x is None:
+        return ""
+    return f"{x * 100:.1f}%"
+
+
+def _fmt(v: Any) -> str:
+    if v is None:
+        return ""
+    if isinstance(v, float):
+        return f"{v:.2f}"
+    return str(v)
+
+
+def _print_table(rows: List[Dict[str, Any]], cols: List[Tuple[str, str]]) -> None:
+    widths = {k: max([len(label)] + [len(_fmt(r.get(k))) for r in rows]) for k, label in cols}
+    header = " | ".join(f"{label:>{widths[k]}}" for k, label in cols)
+    print(header)
+    print("-+-".join("-" * widths[k] for k, _ in cols))
+    for r in rows:
+        print(" | ".join(f"{_fmt(r.get(k)):>{widths[k]}}" for k, _ in cols))
+
+
+def _plot(rows: List[Dict[str, Any]], out_png: Path) -> Optional[Path]:
+    try:
+        import matplotlib
+        matplotlib.use("Agg")
+        import matplotlib.pyplot as plt
+    except Exception as e:
+        print(f"matplotlib unavailable, skipping plot: {e}")
+        return None
+
+    rows = [r for r in rows if r.get("k") is not None]
+    if not rows:
+        return None
+    rows = sorted(rows, key=lambda r: r["k"])
+    ks = [r["k"] for r in rows]
+    flops_ratio = [r.get("flops_ratio") for r in rows]
+    wall_ratio = [r.get("wall_ratio") for r in rows]
+    util_ratio = [r.get("achieved_tflops_ratio") for r in rows]
+
+    fig, ax = plt.subplots(figsize=(8, 5))
+    ax.plot(ks, flops_ratio, marker="o", label="theoretical FLOPs ratio (LenVM / baseline)")
+    ax.plot(ks, wall_ratio, marker="s", label="measured wall-clock ratio")
+    if any(v is not None for v in util_ratio):
+        ax.plot(ks, util_ratio, marker="^", linestyle="--",
+                label="achieved TFLOPs/s ratio (≤1 means utilization loss)")
+    ax.axhline(1.0, color="gray", linewidth=0.5, linestyle=":")
+    ax.set_xlabel("LenVM candidate-set size k")
+    ax.set_ylabel("Ratio (LenVM / baseline)")
+    ax.set_title("LenVM overhead vs candidate-set size")
+    ax.set_xticks(ks)
+    ax.legend(loc="upper left", fontsize="small")
+    ax.grid(alpha=0.3)
+    fig.tight_layout()
+    fig.savefig(out_png, dpi=140)
+    plt.close(fig)
+    return out_png
+
+
+def main() -> int:
+    p = argparse.ArgumentParser()
+    p.add_argument("--results-dirs", type=Path, nargs="+", required=True,
+                   help="Directories produced by lenvm_timing.sh, one per top-k value.")
+    p.add_argument("--output-dir", type=Path, required=True)
+    args = p.parse_args()
+    args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    rows: List[Dict[str, Any]] = []
+    for rd in args.results_dirs:
+        if not (rd / "summary.json").exists():
+            print(f"skipping {rd} — no summary.json")
+            continue
+        base, lvm, ratio = _load_summary(rd)
+        k = _infer_k(rd, lvm)
+        row: Dict[str, Any] = {
+            "k": k,
+            "results_dir": str(rd),
+            "base_wall_s": base.get("wall_clock_s"),
+            "lvm_wall_s": lvm.get("wall_clock_s"),
+            "wall_ratio": ratio.get("wall_clock_s"),
+            "base_tok_per_s": base.get("throughput_output_tokens_per_s"),
+            "lvm_tok_per_s": lvm.get("throughput_output_tokens_per_s"),
+            "tok_per_s_ratio": ratio.get("throughput_output_tokens_per_s"),
+            "base_pflops": base.get("theoretical_pflops_total"),
+            "lvm_pflops": lvm.get("theoretical_pflops_total"),
+            "flops_ratio": ratio.get("theoretical_pflops_total"),
+            "base_achieved_tflops_per_s": base.get("achieved_tflops_per_s"),
+            "lvm_achieved_tflops_per_s": lvm.get("achieved_tflops_per_s"),
+            "achieved_tflops_ratio": ratio.get("achieved_tflops_per_s"),
+            "lvm_apply_ms_mean": lvm.get("t_lvm_apply_outer_ms_mean"),
+            "lvm_forward_ms_mean": lvm.get("t_lvm_forward_ms_mean"),
+            "lvm_build_pending_ms_mean": lvm.get("t_lvm_build_pending_ms_mean"),
+            "lvm_apply_guidance_ms_mean": lvm.get("t_lvm_apply_guidance_ms_mean"),
+        }
+        if row["wall_ratio"] and row["flops_ratio"]:
+            # Utilization gap = wall-clock ratio / theoretical FLOPs ratio.
+            # >1 means LenVM is slower wall-clock than its raw extra compute alone would predict.
+            row["utilization_gap"] = row["wall_ratio"] / row["flops_ratio"]
+        rows.append(row)
+
+    rows = sorted(rows, key=lambda r: (r.get("k") if r.get("k") is not None else 999))
+
+    summary_json = args.output_dir / "sweep_summary.json"
+    summary_json.write_text(json.dumps(rows, indent=2))
+
+    csv_path = args.output_dir / "sweep_summary.csv"
+    if rows:
+        keys = sorted({k for r in rows for k in r.keys()})
+        with csv_path.open("w", newline="") as f:
+            w = csv.DictWriter(f, fieldnames=keys)
+            w.writeheader()
+            for r in rows:
+                w.writerow({k: r.get(k) for k in keys})
+
+    plot_path = _plot(rows, args.output_dir / "topk_sweep.png")
+
+    print(f"summary.json -> {summary_json}")
+    print(f"summary.csv -> {csv_path}")
+    if plot_path:
+        print(f"plot -> {plot_path}")
+    print()
+    _print_table(
+        rows,
+        cols=[
+            ("k", "k"),
+            ("base_wall_s", "base_s"),
+            ("lvm_wall_s", "lvm_s"),
+            ("wall_ratio", "wall_ratio"),
+            ("flops_ratio", "flops_ratio"),
+            ("utilization_gap", "util_gap"),
+            ("lvm_apply_ms_mean", "apply_ms"),
+            ("lvm_forward_ms_mean", "lvm_fwd_ms"),
+            ("lvm_tok_per_s", "lvm_tok/s"),
+        ],
+    )
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/scripts/inference/lenvm_timing.sh b/scripts/inference/lenvm_timing.sh
new file mode 100755
index 0000000..0bf660b
--- /dev/null
+++ b/scripts/inference/lenvm_timing.sh
@@ -0,0 +1,168 @@
+#!/usr/bin/env bash
+# Compare end-to-end and per-step decoding cost: baseline vs LenVM-guided.
+#
+# Two server lifecycles (so per-step timer only captures the configuration
+# under test):
+#   1. SGLang with --enable-lvm-guided-sampling OFF -> baseline timing.
+#   2. SGLang with LenVM enabled + 7B base + 1.5B LenVM -> guided timing.
+# Both servers point SGLANG_LVM_TIMING_LOG at distinct JSONL files. The client
+# replays the same GSM8K prompt set against each, then analyze.py emits a
+# CSV/JSON table and per-step decomposition plot.
+#
+# Run from repository root.
+
+set -euo pipefail
+
+HOST="${HOST:-0.0.0.0}"
+PORT="${PORT:-10020}"
+DP_SIZE="${DP_SIZE:-1}"
+TP_SIZE="${TP_SIZE:-1}"
+CONTEXT_LENGTH="${CONTEXT_LENGTH:-30000}"
+MEM_FRACTION_STATIC="${MEM_FRACTION_STATIC:-0.4}"
+LENVM_MEM_FRACTION_STATIC="${LENVM_MEM_FRACTION_STATIC:-0.4}"
+
+BASE_MODEL="${BASE_MODEL:-Qwen/Qwen2.5-7B-Instruct}"
+LENVM_MODEL="${LENVM_MODEL:-./models/namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct}"
+DATASET="${DATASET:-gsm8k}"
+MAX_QUESTIONS="${MAX_QUESTIONS:-50}"
+MAX_CONCURRENCY="${MAX_CONCURRENCY:-50}"
+N_SAMPLES="${N_SAMPLES:-16}"
+MAX_TOKENS="${MAX_TOKENS:-6000}"
+TEMPERATURE="${TEMPERATURE:-1.0}"
+TOP_P="${TOP_P:-1.0}"
+MIN_P="${MIN_P:-0.01}"
+LENVM_TOP_K="${LENVM_TOP_K:-5}"
+LENVM_VALUE_SCALE="${LENVM_VALUE_SCALE:-0}"
+LENVM_VALUE_MODE="${LENVM_VALUE_MODE:-centered_exp}"
+LENVM_VALUE_GAMMA="${LENVM_VALUE_GAMMA:-0.997}"
+
+RESULTS_DIR="${RESULTS_DIR:-./results/timing/$(basename "$BASE_MODEL")_vs_$(basename "$LENVM_MODEL")}"
+mkdir -p "$RESULTS_DIR"
+
+# ---- helpers ---------------------------------------------------------------
+
+wait_for_server() {
+  local port="$1"
+  for _ in $(seq 1 1200); do
+    if curl -sf "http://127.0.0.1:${port}/v1/models" >/dev/null; then return 0; fi
+    sleep 2
+  done
+  echo "Server on port ${port} failed to become ready" >&2
+  return 1
+}
+
+kill_server() {
+  local pid="$1"
+  if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then
+    kill "$pid" 2>/dev/null || true
+    for _ in $(seq 1 30); do
+      kill -0 "$pid" 2>/dev/null || return 0
+      sleep 1
+    done
+    kill -9 "$pid" 2>/dev/null || true
+  fi
+}
+
+# ---- stage 1: baseline (LenVM disabled) ------------------------------------
+
+echo "==> Stage 1: baseline server (no LenVM)"
+source .venv-infer/bin/activate
+
+BASELINE_TIMING_LOG="$RESULTS_DIR/baseline.timing.jsonl"
+: > "$BASELINE_TIMING_LOG"
+
+SGLANG_LVM_TIMING_LOG="$BASELINE_TIMING_LOG" \
+python -m sglang.launch_server \
+  --model-path "$BASE_MODEL" \
+  --host "$HOST" \
+  --port "$PORT" \
+  --tp-size "$TP_SIZE" \
+  --dp-size "$DP_SIZE" \
+  --context-length "$CONTEXT_LENGTH" \
+  --mem-fraction-static "$MEM_FRACTION_STATIC" &
+SERVER_PID=$!
+trap 'kill_server "$SERVER_PID"' EXIT
+
+wait_for_server "$PORT"
+echo "Baseline server ready"
+
+source .venv-eval/bin/activate
+python -m inference.timing.run_timing \
+  --tag baseline \
+  --server-url "http://127.0.0.1:$PORT" \
+  --dataset-name "$DATASET" \
+  --max-questions "$MAX_QUESTIONS" \
+  --max-concurrency "$MAX_CONCURRENCY" \
+  --n "$N_SAMPLES" \
+  --max-tokens "$MAX_TOKENS" \
+  --temperature "$TEMPERATURE" \
+  --top-p "$TOP_P" \
+  --top-k -1 \
+  --min-p "$MIN_P" \
+  --output-dir "$RESULTS_DIR"
+
+kill_server "$SERVER_PID"
+trap - EXIT
+SERVER_PID=""
+
+# ---- stage 2: LenVM (in-proc guidance) -------------------------------------
+
+echo "==> Stage 2: LenVM server (7B base + LenVM in-proc)"
+source .venv-infer/bin/activate
+
+LENVM_TIMING_LOG="$RESULTS_DIR/lenvm.timing.jsonl"
+: > "$LENVM_TIMING_LOG"
+
+SGLANG_LVM_TIMING_LOG="$LENVM_TIMING_LOG" \
+python -m sglang.launch_server \
+  --model-path "$BASE_MODEL" \
+  --host "$HOST" \
+  --port "$PORT" \
+  --tp-size "$TP_SIZE" \
+  --dp-size "$DP_SIZE" \
+  --context-length "$CONTEXT_LENGTH" \
+  --enable-lvm-guided-sampling \
+  --lvm-guided-inproc \
+  --lvm-guided-inproc-model-path "$LENVM_MODEL" \
+  --lvm-guided-inproc-json-model-override-args '{"architectures":["Qwen2ForLengthValueModel"]}' \
+  --disable-overlap-schedule \
+  --mem-fraction-static "$MEM_FRACTION_STATIC" \
+  --lvm-guided-inproc-mem-fraction-static "$LENVM_MEM_FRACTION_STATIC" \
+  --lvm-guided-fn sglang.srt.lvm.lvm_guided_sampling:lvm_combined_guidance &
+SERVER_PID=$!
+trap 'kill_server "$SERVER_PID"' EXIT + +wait_for_server "$PORT" +echo "LenVM server ready" + +source .venv-eval/bin/activate +python -m inference.timing.run_timing \ + --tag lenvm \ + --server-url "http://127.0.0.1:$PORT" \ + --dataset-name "$DATASET" \ + --max-questions "$MAX_QUESTIONS" \ + --max-concurrency "$MAX_CONCURRENCY" \ + --n "$N_SAMPLES" \ + --max-tokens "$MAX_TOKENS" \ + --temperature "$TEMPERATURE" \ + --top-p "$TOP_P" \ + --top-k "$LENVM_TOP_K" \ + --min-p "$MIN_P" \ + --value-scale "$LENVM_VALUE_SCALE" \ + --value-mode "$LENVM_VALUE_MODE" \ + --value-gamma "$LENVM_VALUE_GAMMA" \ + --output-dir "$RESULTS_DIR" + +kill_server "$SERVER_PID" +trap - EXIT +SERVER_PID="" + +# ---- stage 3: analyze ------------------------------------------------------ + +echo "==> Stage 3: analyze" +python -m inference.timing.analyze \ + --results-dir "$RESULTS_DIR" \ + --base-model "$BASE_MODEL" \ + --lvm-model "$LENVM_MODEL" + +echo "Done. Results in $RESULTS_DIR" diff --git a/sglang-LenVM/python/sglang/srt/layers/sampler.py b/sglang-LenVM/python/sglang/srt/layers/sampler.py index ccd4bfd..7301561 100644 --- a/sglang-LenVM/python/sglang/srt/layers/sampler.py +++ b/sglang-LenVM/python/sglang/srt/layers/sampler.py @@ -17,6 +17,7 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda, is_npu from sglang.srt.lvm.lvm_guided_sampling import LvmGuidedSampler +from sglang.srt.lvm.timing import get_timer if is_cuda(): from sgl_kernel import ( @@ -45,6 +46,7 @@ def __init__(self, model_runner=None): self.lvm_guided_sampler = LvmGuidedSampler.from_server_args( get_global_server_args(), model_runner=model_runner ) + self._timer = get_timer() if is_dp_attention_enabled(): self.tp_sync_group = get_attention_tp_group().device_group @@ -134,6 +136,10 @@ def forward( positions: The positions of the tokens in the sequence. Used for deterministic sampling to get the unique seed for each position. 
""" + t_total = self._timer.section_start("t_sampler_total_ms") + guided_applied = False + batch_size_meta = int(logits_output.next_token_logits.shape[0]) + logits = logits_output.next_token_logits # Preprocess logits (custom processors and NaN handling) @@ -145,6 +151,7 @@ def forward( if return_logprob: logprobs = torch.nn.functional.log_softmax(logits, dim=-1) else: + t_pre = self._timer.section_start("t_pre_lvm_ms") can_sample_directly_from_probs = ( not sampling_info.need_top_p_sampling and not sampling_info.need_top_k_sampling @@ -181,7 +188,9 @@ def forward( probs = logits del logits - guided_applied = False + self._timer.section_end("t_pre_lvm_ms", t_pre) + + t_lvm = self._timer.section_start("t_lvm_apply_outer_ms") if self.lvm_guided_sampler is not None and sampling_info.reqs is not None: guided_probs = self.lvm_guided_sampler.apply( probs, @@ -194,10 +203,12 @@ def forward( if guided_probs is not None: probs = guided_probs guided_applied = True + self._timer.section_end("t_lvm_apply_outer_ms", t_lvm) if guided_applied: can_sample_directly_from_probs = True + t_sample = self._timer.section_start("t_sample_ms") if can_sample_directly_from_probs: # when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs batch_next_token_ids = sampling_from_probs_torch( @@ -244,6 +255,7 @@ def forward( raise ValueError( f"Invalid sampling backend: {get_global_server_args().sampling_backend}" ) + self._timer.section_end("t_sample_ms", t_sample) if return_logprob: if get_global_server_args().rl_on_policy_target is not None: @@ -291,6 +303,13 @@ def forward( group=self.tp_sync_group, ) + self._timer.section_end("t_sampler_total_ms", t_total) + self._timer.set_meta( + lvm_active=bool(guided_applied), + batch_size=batch_size_meta, + is_greedy=bool(sampling_info.is_all_greedy), + ) + self._timer.flush_step() return batch_next_token_ids def compute_logprobs_only( diff --git a/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py 
b/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py index ea51495..56881b2 100644 --- a/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py +++ b/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py @@ -15,6 +15,7 @@ from sglang.srt.utils.common import dynamic_import from sglang.srt.server_args import get_global_server_args from sglang.srt.lvm.lvm_value_utils import force_eos_value_zero, get_eos_token_ids +from sglang.srt.lvm.timing import get_timer from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.schedule_batch import Req @@ -1765,37 +1766,49 @@ def apply( Returns the modified probs tensor, or None when no guidance is needed (caller should use the original probs). """ + timer = get_timer() inproc = self._get_inproc_provider() if inproc not in (None, False): # Free KV cache for requests that have finished or aborted inproc.clean_stale_requests(set(r.rid for r in reqs)) + t_bp = timer.section_start("t_lvm_build_pending_ms") pending = self._build_pending(probs, reqs, temperatures, top_ps, top_ks, min_ps) + timer.section_end("t_lvm_build_pending_ms", t_bp) if pending is None: return None if pending.send_batch_indices: reqs_send = [pending.req_list[i] for i in pending.send_batch_indices] + timer.set_meta(lvm_n_reqs_with_guidance=len(reqs_send)) if pending.gpu_candidates is not None: # GPU fast path (synchronous). 
if inproc not in (None, False): try: rids_send = [req.rid for req in reqs_send] + t_fwd = timer.section_start("t_lvm_forward_ms") inproc.tree_value_extend(rids_send, pending.prefix_ids_send, reqs_send) gpu_emb = inproc.tree_value_launch_gpu( rids_send, pending.candidate_ids_send, gpu_candidates=pending.gpu_candidates ) gpu_embeddings = inproc.tree_value_collect_gpu(gpu_emb) + timer.section_end("t_lvm_forward_ms", t_fwd) + t_ag = timer.section_start("t_lvm_apply_guidance_ms") self._apply_guidance_gpu(pending, gpu_embeddings) + timer.section_end("t_lvm_apply_guidance_ms", t_ag) return pending.guided except Exception as exc: raise RuntimeError("LenVM GPU guidance path failed") from exc + t_fwd = timer.section_start("t_lvm_forward_ms") lvm_values = self._post_tree_value( [req.rid for req in reqs_send], pending.prefix_ids_send, pending.candidate_ids_send, reqs_send ) + timer.section_end("t_lvm_forward_ms", t_fwd) if lvm_values is None: return None + t_ag = timer.section_start("t_lvm_apply_guidance_ms") self._apply_guidance(pending, lvm_values) + timer.section_end("t_lvm_apply_guidance_ms", t_ag) return pending.guided diff --git a/sglang-LenVM/python/sglang/srt/lvm/timing.py b/sglang-LenVM/python/sglang/srt/lvm/timing.py new file mode 100644 index 0000000..b6e81d3 --- /dev/null +++ b/sglang-LenVM/python/sglang/srt/lvm/timing.py @@ -0,0 +1,75 @@ +"""Lightweight per-step Python wall-clock timer for LenVM-guided decoding. + +Activated by environment variable SGLANG_LVM_TIMING_LOG=/path/to/timing.jsonl. +When unset, all timer operations are no-ops with negligible overhead. + +Captures wall-clock (time.perf_counter) for sections within Sampler.forward and +LvmGuidedSampler.apply. Each scheduler step flushes one JSONL record with the +section durations and metadata. Records are flushed line-by-line so a tail-f +during a run is safe. + +The collector lives in the GPU worker process. 
For DP/TP>1 it would need a
+per-rank suffix on the log path; current scope is single-rank smoke testing.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import threading
+import time
+from typing import Optional
+
+
+class _Timer:
+    def __init__(self) -> None:
+        log_path = os.environ.get("SGLANG_LVM_TIMING_LOG")
+        self.enabled = bool(log_path)
+        self._log_path: Optional[str] = log_path
+        self._fh = None
+        self._step_id = 0
+        self._lock = threading.Lock()
+        self._current_step: dict = {}
+        if self.enabled:
+            os.makedirs(os.path.dirname(self._log_path) or ".", exist_ok=True)
+            self._fh = open(self._log_path, "a", buffering=1)
+
+    def section_start(self, name: str) -> Optional[float]:
+        if not self.enabled:
+            return None
+        return time.perf_counter()
+
+    def section_end(self, name: str, start: Optional[float]) -> None:
+        if not self.enabled or start is None:
+            return
+        elapsed_ms = (time.perf_counter() - start) * 1000.0
+        # Accumulate in case a section is entered multiple times per step.
+        prev = self._current_step.get(name, 0.0)
+        self._current_step[name] = prev + elapsed_ms
+
+    def set_meta(self, **kwargs) -> None:
+        if not self.enabled:
+            return
+        self._current_step.update(kwargs)
+
+    def flush_step(self) -> None:
+        """Write one JSONL line for the current step and reset accumulator."""
+        if not self.enabled or self._fh is None:
+            return
+        if not self._current_step:
+            return
+        with self._lock:
+            self._step_id += 1
+            record = {"step": self._step_id, **self._current_step}
+            self._fh.write(json.dumps(record) + "\n")
+            self._current_step.clear()
+
+
+_INSTANCE: Optional[_Timer] = None
+
+
+def get_timer() -> _Timer:
+    global _INSTANCE
+    if _INSTANCE is None:
+        _INSTANCE = _Timer()
+    return _INSTANCE
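
Reviewer note: a minimal sketch, not part of this patch, of how the per-step records written by `flush_step` could be summarized offline. It assumes each JSONL line is an object carrying a `step` counter plus whatever `t_*_ms` sections and metadata were recorded for that step. The helper name `mean_section_ms` is hypothetical, and the aggregation in `analyze.py` may differ.

```python
import json
from collections import defaultdict
from typing import Dict, Iterable


def mean_section_ms(lines: Iterable[str]) -> Dict[str, float]:
    """Average every numeric ``t_*_ms`` field over the steps that recorded it.

    Hypothetical helper: it only assumes the record shape produced by
    ``_Timer.flush_step`` (one JSON object per line).
    """
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            # A run killed mid-write can leave a truncated last line; skip it.
            continue
        if not isinstance(record, dict):
            continue
        for key, value in record.items():
            is_section = key.startswith("t_") and key.endswith("_ms")
            if is_section and isinstance(value, (int, float)):
                totals[key] += float(value)
                counts[key] += 1
    return {key: totals[key] / counts[key] for key in totals}


if __name__ == "__main__":
    sample = [
        '{"step": 1, "t_sampler_total_ms": 0.2, "lvm_active": false}',
        '{"step": 2, "t_sampler_total_ms": 0.4, "t_lvm_forward_ms": 36.0, "lvm_active": true}',
    ]
    print(mean_section_ms(sample))
```

Passing an open file handle for a finished run's timing log works the same way, since a file iterates line by line. The mean is taken only over steps in which a section appears, so sections that are skipped on some steps are not averaged against zeros.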