Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions docs/lenvm-guided-sampling-optimization.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# LenVM guided sampling optimization summary

Latency-oriented changes for the in-process LenVM guided decoding path in the
SGLang fork. This follows the timing question from PR #2: how much overhead
does LenVM add, and where does that overhead remain after optimizing the hot
path?

## TL;DR results

Reference benchmark: 1x H100 SXM, `Qwen/Qwen2.5-7B-Instruct` base model plus
`namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct`,
GSM8K 50 questions x 16 samples, `max_tokens=6000`, temperature 1.0,
top-p 1.0, min-p 0.01. Baseline uses `top_k=-1`; LenVM uses `top_k=5`,
`value_mode=centered_exp`, `value_scale=0.001`, `gamma=0.997`.

`value_scale=0.001` is used as a near-neutral active-LenVM configuration. In
the optimized path, `centered_exp` with scale 0 is a true no-op and skips LenVM
entirely, so scale 0 is not useful for measuring active LenVM overhead.

| run | wall clock | avg completion tokens / choice | acc_first | acc_any |
| --- | ---: | ---: | ---: | ---: |
| baseline | 19.84 s | 295.96 | 0.88 | 1.00 |
| optimized LenVM | 80.02 s | 298.36 | 0.90 | 0.96 |
| ratio | 4.03x slower | +0.81% | - | - |

Compared with the original PR #2 reference result, 19.22 s -> 87.44 s
(4.55x slower), this reduces the guided run wall clock by 8.5%, the slowdown
ratio by 11.4%, and the incremental LenVM overhead by 11.8%.

The model memory reported by SGLang on the benchmark node was 14.30 GB for the
bf16 7B base model and 3.03 GB for the bf16 1.5B LenVM value model.

## Time breakdown

Small profiling run: same 7B/1.5B model pair, GSM8K 10 questions x 4 samples,
`top_k=5`, scale 0.001, with `SGLANG_LVM_TIMING=1`.

Baseline `Sampler.forward` is about 0.70-0.77 ms per call; most of that is the
FlashInfer sample backend. With LenVM active, `Sampler.forward` rises to about
13-15 ms after warmup, and `LvmGuidedSampler.apply` accounts for roughly 97% of
that time.

The final timing log at 800 guided calls reports:

| component | average time |
| --- | ---: |
| `Sampler.forward` | 13.22 ms |
| `LvmGuidedSampler.apply` | 12.85 ms |
| `build_pending` | 1.31 ms |
| fused extend+candidate path | 1.10 ms |
| fallback prefix extend | 9.23 ms |
| fallback candidate launch | 8.84 ms |
| GPU guidance application | 1.08 ms |

The branch-specific fallback timings are not additive with `apply_total`, but
they show the remaining bottleneck clearly: most residual latency comes from
requests that still fall back to two LVM forwards, one to extend the LenVM KV
cache and one to score candidate tokens.

## What changed

- Add a request-level fast precheck before initializing the in-process LenVM
provider. Batches with no active value-guidance request return immediately,
which avoids polluting baseline runs.
- Treat neutral guidance settings as no-ops: `centered_exp` / `value_bias`
scale 0, `mul` scale <= 0 or scale 1, and other expectation modes at scale 1.
- Keep compacted candidate ids, probabilities, and masks on GPU for the common
expectation-guidance path. Only row metadata is copied to CPU.
- Apply guidance in place and return only rows that changed, preserving
top-k/top-p/min-p filtering for unmodified rows without cloning the full
`[batch, vocab]` probability tensor.
- Add a fused in-process LenVM path that extends a tiny prefix delta and scores
candidates in one forward when the request layout supports it.
- Add FlashInfer prefill argument generation for tree-value attention masks and
vectorize the candidate self-attention diagonal.
- Avoid `.tolist()` GPU synchronization in Qwen2/Qwen3/Qwen2.5-VL LenVM value
slicing by carrying prefix and candidate lengths in the tree-value spec.
- Cache request EOS ids for repeated candidate filtering.
- Add opt-in timing logs controlled by `SGLANG_LVM_TIMING`,
`SGLANG_LVM_TIMING_INTERVAL`, and `SGLANG_LVM_TIMING_SKIP_CALLS`.

## Validation

Focused unit coverage was added for:

- GPU candidate compaction and CPU fallback candidate lists.
- No-op scale skipping.
- Baseline batches not initializing the in-process LenVM runner.
- Mixed value-guidance modes on the GPU path.
- Fused tree-value prefix/candidate mask construction.
- FlashInfer prefill argument generation.

Commands run:

```bash
python -m compileall -q \
sglang-LenVM/python/sglang/srt/layers/sampler.py \
sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py \
sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py \
sglang-LenVM/python/sglang/srt/lvm/lvm_value_utils.py \
sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py \
sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py \
sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py \
sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py

git diff --check
```

The focused tests were also executed directly through `.venv-infer/bin/python`
because the workspace does not currently have `pytest` installed.

Slurm benchmark artifacts:

- `results/timing/pr7b_q50n16_s001_84065_20260523_211347/`
- `results/timing/pr7b_prof_s001_84067_20260523_212103/`

## Remaining optimization target

The fused path is still all-or-nothing at the batch level in many situations.
The profiling run reached only 63 fused calls out of 800 guided calls. The next
high-value optimization is to split each batch into fusible and fallback rows:
run the fused path for eligible rows and use the two-phase path only for the
rest. That should directly attack the remaining prefix-extend and
candidate-launch overhead.
128 changes: 127 additions & 1 deletion sglang-LenVM/python/sglang/srt/layers/sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import os
import time
from typing import Callable, Dict, List, Optional, Tuple

import torch
Expand Down Expand Up @@ -31,6 +33,103 @@

logger = logging.getLogger(__name__)

_SAMPLER_TIMING_TOTALS: Dict[str, float] = {}
_SAMPLER_TIMING_COUNTS: Dict[str, int] = {}
_SAMPLER_TIMING_EVENTS: Dict[str, int] = {}


def _sampler_timing_enabled() -> bool:
return os.environ.get("SGLANG_LVM_TIMING", "").strip().lower() in {
"1",
"true",
"yes",
"on",
}


def _sampler_timing_interval() -> int:
raw = os.environ.get("SGLANG_LVM_TIMING_INTERVAL", "200")
try:
return max(int(raw), 1)
except ValueError:
return 200


def _sampler_timing_skip_calls() -> int:
raw = os.environ.get("SGLANG_LVM_TIMING_SKIP_CALLS", "0")
try:
return max(int(raw), 0)
except ValueError:
return 0


def _sampler_timing_tic() -> float:
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.perf_counter()


def _sampler_timing_toc(sample: Optional[Dict[str, float]], name: str, tic: float) -> None:
if sample is None:
return
if torch.cuda.is_available():
torch.cuda.synchronize()
sample[name] = sample.get(name, 0.0) + (time.perf_counter() - tic) * 1000.0


def _record_sampler_timing(
sample: Optional[Dict[str, float]], *, guided_applied: bool = False
) -> None:
if sample is None:
return

seen = _SAMPLER_TIMING_EVENTS.get("_seen", 0) + 1
_SAMPLER_TIMING_EVENTS["_seen"] = seen
if seen <= _sampler_timing_skip_calls():
return

for key, value in sample.items():
_SAMPLER_TIMING_TOTALS[key] = _SAMPLER_TIMING_TOTALS.get(key, 0.0) + float(value)
_SAMPLER_TIMING_COUNTS[key] = _SAMPLER_TIMING_COUNTS.get(key, 0) + 1

_SAMPLER_TIMING_EVENTS["calls"] = _SAMPLER_TIMING_EVENTS.get("calls", 0) + 1
if guided_applied:
_SAMPLER_TIMING_EVENTS["guided"] = _SAMPLER_TIMING_EVENTS.get("guided", 0) + 1

calls = _SAMPLER_TIMING_EVENTS["calls"]
if calls % _sampler_timing_interval() != 0:
return

forward_avg = _SAMPLER_TIMING_TOTALS.get("sampler_forward", 0.0) / max(
_SAMPLER_TIMING_COUNTS.get("sampler_forward", 0), 1
)
ordered = (
"sampler_forward",
"preprocess_logits",
"temperature",
"token_temp_scale",
"softmax",
"lvm_apply",
"sample_direct",
"sample_backend",
"logprob",
"tp_sync",
)
pieces = [
f"calls={calls}",
f"guided={_SAMPLER_TIMING_EVENTS.get('guided', 0)}",
]
for name in ordered:
count = _SAMPLER_TIMING_COUNTS.get(name, 0)
if count <= 0:
continue
avg = _SAMPLER_TIMING_TOTALS[name] / count
if name == "sampler_forward" or forward_avg <= 0:
pieces.append(f"{name}={avg:.3f}ms")
else:
pieces.append(f"{name}={avg:.3f}ms/{avg / forward_avg * 100:.1f}%")
logger.info("[sampler_timing] %s", " ".join(pieces))

SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
_CUSTOM_SAMPLER_FACTORIES: Dict[str, Callable[[], "Sampler"]] = {}
Expand Down Expand Up @@ -134,16 +233,26 @@ def forward(
positions: The positions of the tokens in the sequence. Used for deterministic sampling
to get the unique seed for each position.
"""
timing_sample = {} if _sampler_timing_enabled() else None
t_forward = _sampler_timing_tic() if timing_sample is not None else 0.0
guided_applied = False

logits = logits_output.next_token_logits

# Preprocess logits (custom processors and NaN handling)
t = _sampler_timing_tic() if timing_sample is not None else 0.0
logits = self._preprocess_logits(logits, sampling_info)
_sampler_timing_toc(timing_sample, "preprocess_logits", t)

if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
t = _sampler_timing_tic() if timing_sample is not None else 0.0
batch_next_token_ids = torch.argmax(logits, -1)
_sampler_timing_toc(timing_sample, "sample_direct", t)
if return_logprob:
t = _sampler_timing_tic() if timing_sample is not None else 0.0
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
_sampler_timing_toc(timing_sample, "logprob", t)
else:
can_sample_directly_from_probs = (
not sampling_info.need_top_p_sampling
Expand All @@ -164,25 +273,31 @@ def forward(
)

# Post process logits
t = _sampler_timing_tic() if timing_sample is not None else 0.0
logits.div_(sampling_info.temperatures)
_sampler_timing_toc(timing_sample, "temperature", t)

# Per-token temperature scaling: for boosted tokens, divide their
# logits by an additional divisor (equivalent to temperature T/d).
# Done in-place on logits before softmax — zero overhead on the
# subsequent sampling path.
if self.lvm_guided_sampler is not None and sampling_info.reqs is not None:
t = _sampler_timing_tic() if timing_sample is not None else 0.0
self._apply_token_temp_scale(logits, sampling_info.reqs)
_sampler_timing_toc(timing_sample, "token_temp_scale", t)

# For ascend backend, softmax is not needed before sampling
if not get_global_server_args().sampling_backend == "ascend" or (
return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB
):
t = _sampler_timing_tic() if timing_sample is not None else 0.0
logits[:] = torch.softmax(logits, dim=-1)
_sampler_timing_toc(timing_sample, "softmax", t)
probs = logits
del logits

guided_applied = False
if self.lvm_guided_sampler is not None and sampling_info.reqs is not None:
t = _sampler_timing_tic() if timing_sample is not None else 0.0
guided_probs = self.lvm_guided_sampler.apply(
probs,
sampling_info.reqs,
Expand All @@ -191,6 +306,7 @@ def forward(
sampling_info.top_ks,
sampling_info.min_ps,
)
_sampler_timing_toc(timing_sample, "lvm_apply", t)
if guided_probs is not None:
probs = guided_probs
guided_applied = True
Expand All @@ -200,12 +316,15 @@ def forward(

if can_sample_directly_from_probs:
# when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs
t = _sampler_timing_tic() if timing_sample is not None else 0.0
batch_next_token_ids = sampling_from_probs_torch(
probs,
sampling_seed=sampling_info.sampling_seed,
positions=positions,
)
_sampler_timing_toc(timing_sample, "sample_direct", t)
else:
t = _sampler_timing_tic() if timing_sample is not None else 0.0
if get_global_server_args().sampling_backend == "flashinfer":
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
Expand Down Expand Up @@ -244,8 +363,10 @@ def forward(
raise ValueError(
f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
)
_sampler_timing_toc(timing_sample, "sample_backend", t)

if return_logprob:
t = _sampler_timing_tic() if timing_sample is not None else 0.0
if get_global_server_args().rl_on_policy_target is not None:
logprobs = logprobs_via_logsoftmax_kernel
del logprobs_via_logsoftmax_kernel
Expand All @@ -257,6 +378,7 @@ def forward(
del probs_without_temp_scaling
else:
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
_sampler_timing_toc(timing_sample, "logprob", t)

# Attach logprobs to logits_output (in-place modification)
if return_logprob:
Expand Down Expand Up @@ -285,12 +407,16 @@ def forward(
# In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
# When using xgrammar, this becomes more likely so we also do the sync when grammar is used.

t = _sampler_timing_tic() if timing_sample is not None else 0.0
torch.distributed.all_reduce(
batch_next_token_ids,
op=dist.ReduceOp.MIN,
group=self.tp_sync_group,
)
_sampler_timing_toc(timing_sample, "tp_sync", t)

_sampler_timing_toc(timing_sample, "sampler_forward", t_forward)
_record_sampler_timing(timing_sample, guided_applied=guided_applied)
return batch_next_token_ids

def compute_logprobs_only(
Expand Down
Loading