From 374565b2f7d661d2e7f65b0c0f67af14016b06b6 Mon Sep 17 00:00:00 2001 From: namezhenzhang Date: Sat, 23 May 2026 21:30:56 +0000 Subject: [PATCH] Optimize LenVM guided sampling latency --- docs/lenvm-guided-sampling-optimization.md | 124 ++++ .../python/sglang/srt/layers/sampler.py | 128 +++- .../sglang/srt/lvm/lvm_guided_sampling.py | 572 +++++++++++++++--- .../sglang/srt/lvm/lvm_inproc_runner.py | 221 ++++++- .../python/sglang/srt/lvm/lvm_value_utils.py | 9 +- .../python/sglang/srt/lvm/tree_value_spec.py | 44 +- .../sglang/srt/models/qwen2_5_vl_lvm.py | 17 +- .../python/sglang/srt/models/qwen2_lvm.py | 18 +- .../python/sglang/srt/models/qwen3_lvm.py | 17 +- .../srt/test_lvm_guided_sampling_fast_path.py | 280 +++++++++ sglang-LenVM/test/srt/test_tree_value_spec.py | 126 ++++ 11 files changed, 1457 insertions(+), 99 deletions(-) create mode 100644 docs/lenvm-guided-sampling-optimization.md create mode 100644 sglang-LenVM/test/srt/test_lvm_guided_sampling_fast_path.py create mode 100644 sglang-LenVM/test/srt/test_tree_value_spec.py diff --git a/docs/lenvm-guided-sampling-optimization.md b/docs/lenvm-guided-sampling-optimization.md new file mode 100644 index 0000000..3b03f57 --- /dev/null +++ b/docs/lenvm-guided-sampling-optimization.md @@ -0,0 +1,124 @@ +# LenVM guided sampling optimization summary + +Latency-oriented changes for the in-process LenVM guided decoding path in the +SGLang fork. This follows the timing question from PR #2: how much overhead +does LenVM add, and where does that overhead remain after optimizing the hot +path? + +## TL;DR results + +Reference benchmark: 1x H100 SXM, `Qwen/Qwen2.5-7B-Instruct` base model plus +`namezz/lvm-math-0402-a-qwen2.5-7b-instruct-b-qwen2.5-1.5b-instruct`, +GSM8K 50 questions x 16 samples, `max_tokens=6000`, temperature 1.0, +top-p 1.0, min-p 0.01. Baseline uses `top_k=-1`; LenVM uses `top_k=5`, +`value_mode=centered_exp`, `value_scale=0.001`, `gamma=0.997`. 
+ +`value_scale=0.001` is used as a near-neutral active-LenVM configuration. In +the optimized path, `centered_exp` with scale 0 is a true no-op and skips LenVM +entirely, so scale 0 is not useful for measuring active LenVM overhead. + +| run | wall clock | avg completion tokens / choice | acc_first | acc_any | +| --- | ---: | ---: | ---: | ---: | +| baseline | 19.84 s | 295.96 | 0.88 | 1.00 | +| optimized LenVM | 80.02 s | 298.36 | 0.90 | 0.96 | +| ratio | 4.03x slower | +0.81% | - | - | + +Compared with the original PR #2 reference result (baseline 19.22 s -> guided +87.44 s, 4.55x slower), the optimized run reduces the guided wall clock by 8.5%, +the slowdown ratio by 11.4%, and the incremental LenVM overhead by 11.8%. + +The model memory reported by SGLang on the benchmark node was 14.30 GB for the +bf16 7B base model and 3.03 GB for the bf16 1.5B LenVM value model. + +## Time breakdown + +Small profiling run: same 7B/1.5B model pair, GSM8K 10 questions x 4 samples, +`top_k=5`, scale 0.001, with `SGLANG_LVM_TIMING=1`. + +Baseline `Sampler.forward` is about 0.70-0.77 ms per call; most of that is the +FlashInfer sample backend. With LenVM active, `Sampler.forward` rises to about +13-15 ms after warmup, and `LvmGuidedSampler.apply` accounts for roughly 97% of +that time. + +The final timing log at 800 guided calls reports: + +| component | average time | +| --- | ---: | +| `Sampler.forward` | 13.22 ms | +| `LvmGuidedSampler.apply` | 12.85 ms | +| `build_pending` | 1.31 ms | +| fused extend+candidate path | 1.10 ms | +| fallback prefix extend | 9.23 ms | +| fallback candidate launch | 8.84 ms | +| GPU guidance application | 1.08 ms | + +Each average is taken only over the calls that actually recorded that +component, so the branch-specific fallback timings are not additive with +`apply_total`. They still show the remaining bottleneck clearly: most residual +latency comes from requests that fall back to two LVM forwards, one to extend +the LenVM KV cache and one to score candidate tokens.
+ +## What changed + +- Add a request-level fast precheck before initializing the in-process LenVM + provider. Batches with no active value-guidance request return immediately, + which avoids polluting baseline runs. +- Treat neutral guidance settings as no-ops: `centered_exp` / `value_bias` + scale 0, `mul` scale <= 0 or scale 1, and other expectation modes at scale 1. +- Keep compacted candidate ids, probabilities, and masks on GPU for the common + expectation-guidance path. Only row metadata is copied to CPU. +- Apply guidance in place and return only rows that changed, preserving + top-k/top-p/min-p filtering for unmodified rows without cloning the full + `[batch, vocab]` probability tensor. +- Add a fused in-process LenVM path that extends a tiny prefix delta and scores + candidates in one forward when the request layout supports it. +- Add FlashInfer prefill argument generation for tree-value attention masks and + vectorize the candidate self-attention diagonal. +- Avoid `.tolist()` GPU synchronization in Qwen2/Qwen3/Qwen2.5-VL LenVM value + slicing by carrying prefix and candidate lengths in the tree-value spec. +- Cache request EOS ids for repeated candidate filtering. +- Add opt-in timing logs controlled by `SGLANG_LVM_TIMING`, + `SGLANG_LVM_TIMING_INTERVAL`, and `SGLANG_LVM_TIMING_SKIP_CALLS`. + +## Validation + +Focused unit coverage was added for: + +- GPU candidate compaction and CPU fallback candidate lists. +- No-op scale skipping. +- Baseline batches not initializing the in-process LenVM runner. +- Mixed value-guidance modes on the GPU path. +- Fused tree-value prefix/candidate mask construction. +- FlashInfer prefill argument generation. 
+ +Commands run: + +```bash +python -m compileall -q \ + sglang-LenVM/python/sglang/srt/layers/sampler.py \ + sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py \ + sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py \ + sglang-LenVM/python/sglang/srt/lvm/lvm_value_utils.py \ + sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py \ + sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py \ + sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py \ + sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py + +git diff --check +``` + +The focused tests were also executed directly through `.venv-infer/bin/python` +because the workspace does not currently have `pytest` installed. + +Slurm benchmark artifacts: + +- `results/timing/pr7b_q50n16_s001_84065_20260523_211347/` +- `results/timing/pr7b_prof_s001_84067_20260523_212103/` + +## Remaining optimization target + +The fused path is still all-or-nothing at the batch level: if any request in +the batch needs more than one new prefix token, the whole batch falls back to +the two-phase path. The profiling run reached only 63 fused calls out of 800 +guided calls. The next high-value optimization is to split each batch into +fusible and fallback rows: run the fused path for eligible rows and use the +two-phase path only for the rest. That should directly attack the remaining +prefix-extend and candidate-launch overhead.
diff --git a/sglang-LenVM/python/sglang/srt/layers/sampler.py b/sglang-LenVM/python/sglang/srt/layers/sampler.py index ccd4bfd..afd2d3c 100644 --- a/sglang-LenVM/python/sglang/srt/layers/sampler.py +++ b/sglang-LenVM/python/sglang/srt/layers/sampler.py @@ -1,4 +1,6 @@ import logging +import os +import time from typing import Callable, Dict, List, Optional, Tuple import torch @@ -31,6 +33,103 @@ logger = logging.getLogger(__name__) +_SAMPLER_TIMING_TOTALS: Dict[str, float] = {} +_SAMPLER_TIMING_COUNTS: Dict[str, int] = {} +_SAMPLER_TIMING_EVENTS: Dict[str, int] = {} + + +def _sampler_timing_enabled() -> bool: + return os.environ.get("SGLANG_LVM_TIMING", "").strip().lower() in { + "1", + "true", + "yes", + "on", + } + + +def _sampler_timing_interval() -> int: + raw = os.environ.get("SGLANG_LVM_TIMING_INTERVAL", "200") + try: + return max(int(raw), 1) + except ValueError: + return 200 + + +def _sampler_timing_skip_calls() -> int: + raw = os.environ.get("SGLANG_LVM_TIMING_SKIP_CALLS", "0") + try: + return max(int(raw), 0) + except ValueError: + return 0 + + +def _sampler_timing_tic() -> float: + if torch.cuda.is_available(): + torch.cuda.synchronize() + return time.perf_counter() + + +def _sampler_timing_toc(sample: Optional[Dict[str, float]], name: str, tic: float) -> None: + if sample is None: + return + if torch.cuda.is_available(): + torch.cuda.synchronize() + sample[name] = sample.get(name, 0.0) + (time.perf_counter() - tic) * 1000.0 + + +def _record_sampler_timing( + sample: Optional[Dict[str, float]], *, guided_applied: bool = False +) -> None: + if sample is None: + return + + seen = _SAMPLER_TIMING_EVENTS.get("_seen", 0) + 1 + _SAMPLER_TIMING_EVENTS["_seen"] = seen + if seen <= _sampler_timing_skip_calls(): + return + + for key, value in sample.items(): + _SAMPLER_TIMING_TOTALS[key] = _SAMPLER_TIMING_TOTALS.get(key, 0.0) + float(value) + _SAMPLER_TIMING_COUNTS[key] = _SAMPLER_TIMING_COUNTS.get(key, 0) + 1 + + _SAMPLER_TIMING_EVENTS["calls"] = 
_SAMPLER_TIMING_EVENTS.get("calls", 0) + 1 + if guided_applied: + _SAMPLER_TIMING_EVENTS["guided"] = _SAMPLER_TIMING_EVENTS.get("guided", 0) + 1 + + calls = _SAMPLER_TIMING_EVENTS["calls"] + if calls % _sampler_timing_interval() != 0: + return + + forward_avg = _SAMPLER_TIMING_TOTALS.get("sampler_forward", 0.0) / max( + _SAMPLER_TIMING_COUNTS.get("sampler_forward", 0), 1 + ) + ordered = ( + "sampler_forward", + "preprocess_logits", + "temperature", + "token_temp_scale", + "softmax", + "lvm_apply", + "sample_direct", + "sample_backend", + "logprob", + "tp_sync", + ) + pieces = [ + f"calls={calls}", + f"guided={_SAMPLER_TIMING_EVENTS.get('guided', 0)}", + ] + for name in ordered: + count = _SAMPLER_TIMING_COUNTS.get(name, 0) + if count <= 0: + continue + avg = _SAMPLER_TIMING_TOTALS[name] / count + if name == "sampler_forward" or forward_avg <= 0: + pieces.append(f"{name}={avg:.3f}ms") + else: + pieces.append(f"{name}={avg:.3f}ms/{avg / forward_avg * 100:.1f}%") + logger.info("[sampler_timing] %s", " ".join(pieces)) + SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP") SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB") _CUSTOM_SAMPLER_FACTORIES: Dict[str, Callable[[], "Sampler"]] = {} @@ -134,16 +233,26 @@ def forward( positions: The positions of the tokens in the sequence. Used for deterministic sampling to get the unique seed for each position. 
""" + timing_sample = {} if _sampler_timing_enabled() else None + t_forward = _sampler_timing_tic() if timing_sample is not None else 0.0 + guided_applied = False + logits = logits_output.next_token_logits # Preprocess logits (custom processors and NaN handling) + t = _sampler_timing_tic() if timing_sample is not None else 0.0 logits = self._preprocess_logits(logits, sampling_info) + _sampler_timing_toc(timing_sample, "preprocess_logits", t) if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling + t = _sampler_timing_tic() if timing_sample is not None else 0.0 batch_next_token_ids = torch.argmax(logits, -1) + _sampler_timing_toc(timing_sample, "sample_direct", t) if return_logprob: + t = _sampler_timing_tic() if timing_sample is not None else 0.0 logprobs = torch.nn.functional.log_softmax(logits, dim=-1) + _sampler_timing_toc(timing_sample, "logprob", t) else: can_sample_directly_from_probs = ( not sampling_info.need_top_p_sampling @@ -164,25 +273,31 @@ def forward( ) # Post process logits + t = _sampler_timing_tic() if timing_sample is not None else 0.0 logits.div_(sampling_info.temperatures) + _sampler_timing_toc(timing_sample, "temperature", t) # Per-token temperature scaling: for boosted tokens, divide their # logits by an additional divisor (equivalent to temperature T/d). # Done in-place on logits before softmax — zero overhead on the # subsequent sampling path. 
if self.lvm_guided_sampler is not None and sampling_info.reqs is not None: + t = _sampler_timing_tic() if timing_sample is not None else 0.0 self._apply_token_temp_scale(logits, sampling_info.reqs) + _sampler_timing_toc(timing_sample, "token_temp_scale", t) # For ascend backend, softmax is not needed before sampling if not get_global_server_args().sampling_backend == "ascend" or ( return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB ): + t = _sampler_timing_tic() if timing_sample is not None else 0.0 logits[:] = torch.softmax(logits, dim=-1) + _sampler_timing_toc(timing_sample, "softmax", t) probs = logits del logits - guided_applied = False if self.lvm_guided_sampler is not None and sampling_info.reqs is not None: + t = _sampler_timing_tic() if timing_sample is not None else 0.0 guided_probs = self.lvm_guided_sampler.apply( probs, sampling_info.reqs, @@ -191,6 +306,7 @@ def forward( sampling_info.top_ks, sampling_info.min_ps, ) + _sampler_timing_toc(timing_sample, "lvm_apply", t) if guided_probs is not None: probs = guided_probs guided_applied = True @@ -200,12 +316,15 @@ def forward( if can_sample_directly_from_probs: # when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs + t = _sampler_timing_tic() if timing_sample is not None else 0.0 batch_next_token_ids = sampling_from_probs_torch( probs, sampling_seed=sampling_info.sampling_seed, positions=positions, ) + _sampler_timing_toc(timing_sample, "sample_direct", t) else: + t = _sampler_timing_tic() if timing_sample is not None else 0.0 if get_global_server_args().sampling_backend == "flashinfer": if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, sampling_info.top_ks) @@ -244,8 +363,10 @@ def forward( raise ValueError( f"Invalid sampling backend: {get_global_server_args().sampling_backend}" ) + _sampler_timing_toc(timing_sample, "sample_backend", t) if return_logprob: + t = _sampler_timing_tic() if timing_sample is not None else 0.0 if 
get_global_server_args().rl_on_policy_target is not None: logprobs = logprobs_via_logsoftmax_kernel del logprobs_via_logsoftmax_kernel @@ -257,6 +378,7 @@ def forward( del probs_without_temp_scaling else: logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) + _sampler_timing_toc(timing_sample, "logprob", t) # Attach logprobs to logits_output (in-place modification) if return_logprob: @@ -285,12 +407,16 @@ def forward( # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized. # When using xgrammar, this becomes more likely so we also do the sync when grammar is used. + t = _sampler_timing_tic() if timing_sample is not None else 0.0 torch.distributed.all_reduce( batch_next_token_ids, op=dist.ReduceOp.MIN, group=self.tp_sync_group, ) + _sampler_timing_toc(timing_sample, "tp_sync", t) + _sampler_timing_toc(timing_sample, "sampler_forward", t_forward) + _record_sampler_timing(timing_sample, guided_applied=guided_applied) return batch_next_token_ids def compute_logprobs_only( diff --git a/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py b/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py index ea51495..7257079 100644 --- a/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py +++ b/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py @@ -5,6 +5,7 @@ import importlib.util import logging import os +import time from dataclasses import dataclass from typing import Any, Callable, Iterable, List, Optional @@ -26,6 +27,127 @@ logger = logging.getLogger(__name__) +_LVM_TIMING_TOTALS: dict[str, float] = {} +_LVM_TIMING_COUNTS: dict[str, int] = {} +_LVM_TIMING_EVENTS: dict[str, int] = {} + + +def _lvm_timing_enabled() -> bool: + return os.environ.get("SGLANG_LVM_TIMING", "").strip().lower() in { + "1", + "true", + "yes", + "on", + } + + +def _lvm_timing_interval() -> int: + raw = os.environ.get("SGLANG_LVM_TIMING_INTERVAL", "200") + try: + return max(int(raw), 1) + except ValueError: + return 200 
+ + +def _lvm_timing_skip_calls() -> int: + raw = os.environ.get("SGLANG_LVM_TIMING_SKIP_CALLS", "0") + try: + return max(int(raw), 0) + except ValueError: + return 0 + + +def _lvm_timing_tic() -> float: + if torch.cuda.is_available(): + torch.cuda.synchronize() + return time.perf_counter() + + +def _lvm_timing_toc(sample: Optional[dict[str, float]], name: str, tic: float) -> None: + if sample is None: + return + if torch.cuda.is_available(): + torch.cuda.synchronize() + sample[name] = sample.get(name, 0.0) + (time.perf_counter() - tic) * 1000.0 + + +def _record_lvm_timing( + sample: Optional[dict[str, float]], + *, + skipped: bool = False, + pending: Optional["PendingLvmResult"] = None, + gpu_path: bool = False, + fused_path: bool = False, +) -> None: + if sample is None: + return + + seen = _LVM_TIMING_EVENTS.get("_seen", 0) + 1 + _LVM_TIMING_EVENTS["_seen"] = seen + if seen <= _lvm_timing_skip_calls(): + return + + for key, value in sample.items(): + _LVM_TIMING_TOTALS[key] = _LVM_TIMING_TOTALS.get(key, 0.0) + float(value) + _LVM_TIMING_COUNTS[key] = _LVM_TIMING_COUNTS.get(key, 0) + 1 + + _LVM_TIMING_EVENTS["apply_calls"] = _LVM_TIMING_EVENTS.get("apply_calls", 0) + 1 + if skipped: + _LVM_TIMING_EVENTS["skipped"] = _LVM_TIMING_EVENTS.get("skipped", 0) + 1 + if gpu_path: + _LVM_TIMING_EVENTS["gpu_path"] = _LVM_TIMING_EVENTS.get("gpu_path", 0) + 1 + if fused_path: + _LVM_TIMING_EVENTS["fused_path"] = _LVM_TIMING_EVENTS.get("fused_path", 0) + 1 + + if pending is not None: + n_send = len(pending.send_batch_indices) + if pending.candidate_lens_send is not None: + n_cands = sum(int(x) for x in pending.candidate_lens_send) + else: + n_cands = sum(len(x) for x in pending.candidate_ids_send) + _LVM_TIMING_EVENTS["send_rows"] = _LVM_TIMING_EVENTS.get("send_rows", 0) + n_send + _LVM_TIMING_EVENTS["candidates"] = _LVM_TIMING_EVENTS.get("candidates", 0) + n_cands + + calls = _LVM_TIMING_EVENTS["apply_calls"] + if calls % _lvm_timing_interval() != 0: + return + + apply_avg = 
_LVM_TIMING_TOTALS.get("apply_total", 0.0) / max( + _LVM_TIMING_COUNTS.get("apply_total", 0), 1 + ) + ordered = ( + "apply_total", + "precheck", + "get_inproc_provider", + "clean_stale", + "build_pending", + "extend_launch_fused", + "extend_prefix", + "launch_candidates", + "collect_gpu", + "apply_guidance_gpu", + "post_tree_value", + "apply_guidance_cpu", + ) + pieces = [ + f"calls={calls}", + f"skipped={_LVM_TIMING_EVENTS.get('skipped', 0)}", + f"gpu_path={_LVM_TIMING_EVENTS.get('gpu_path', 0)}", + f"fused={_LVM_TIMING_EVENTS.get('fused_path', 0)}", + f"avg_send_rows={_LVM_TIMING_EVENTS.get('send_rows', 0) / max(calls, 1):.2f}", + f"avg_candidates={_LVM_TIMING_EVENTS.get('candidates', 0) / max(calls, 1):.2f}", + ] + for name in ordered: + count = _LVM_TIMING_COUNTS.get(name, 0) + if count <= 0: + continue + avg = _LVM_TIMING_TOTALS[name] / count + if name == "apply_total" or apply_avg <= 0: + pieces.append(f"{name}={avg:.3f}ms") + else: + pieces.append(f"{name}={avg:.3f}ms/{avg / apply_avg * 100:.1f}%") + logger.info("[lvm_timing] %s", " ".join(pieces)) + def _get_req_custom_params(req: Any) -> dict[str, Any]: if req is None: @@ -114,6 +236,33 @@ def _extract_value_mode(kwargs: dict, default: str = "mul") -> str: return mode +def _get_req_value_mode_and_scale(req: Any) -> tuple[str, float]: + """Return cached expectation-guidance mode/scale for a request.""" + custom_params = _get_req_custom_params(req) + cache_key = ( + id(custom_params), + custom_params.get("mode"), + custom_params.get("value_mode"), + custom_params.get("scale"), + custom_params.get("value_scale"), + ) + cached = getattr(req, "_lvm_value_mode_scale_cache", None) + if ( + isinstance(cached, tuple) + and len(cached) == 3 + and cached[0] == cache_key + ): + return cached[1], cached[2] + + mode = _extract_value_mode({"req": req}, default="mul") + scale = _extract_value_scale({"req": req}, default=1.0) + try: + setattr(req, "_lvm_value_mode_scale_cache", (cache_key, mode, scale)) + except Exception: 
+ pass + return mode, scale + + def _extract_length_gamma(kwargs: dict, default: float = 0.997) -> float: """Extract gamma used for value->length mapping. @@ -748,19 +897,22 @@ class LvmGuidedConfig: class PendingLvmResult: """Intermediate state produced by _build_pending() and consumed by apply(). - Carries filtered candidate lists, the cloned probs tensor (with deterministic + Carries filtered candidate lists, the mutable probs tensor (with deterministic rows already filled), and optional GPU tensors for the fast guidance path. """ req_list: List[Any] device: torch.device - # probs.clone() with deterministic (single-candidate) rows already zeroed/filled. + # Probs tensor with deterministic (single-candidate) rows already zeroed/filled. # None means there is nothing to do (all rows were deterministic or skipped). guided: Optional[torch.Tensor] send_batch_indices: List[int] prefix_ids_send: List[List[int]] candidate_ids_send: List[List[int]] candidate_probs_send: List[List[float]] + # Candidate lengths for the GPU fast path. This lets the in-process runner + # avoid copying every candidate id back to CPU just to recover per-row sizes. + candidate_lens_send: Optional[List[int]] = None # GPU tensors for the fast path (only set when the guidance function can use the # expectation-guidance GPU path and all send indices come from the top-k path, # not top-k-all). 
@@ -1027,7 +1179,11 @@ def tree_value_collect(self, cpu_embeddings) -> List[List[float]]: return out def tree_value_launch_gpu( - self, rids: List[str], candidate_ids: List[List[int]], gpu_candidates: Optional[tuple] = None + self, + rids: List[str], + candidate_ids: List[List[int]], + gpu_candidates: Optional[tuple] = None, + candidate_lens: Optional[List[int]] = None, ): """Like tree_value_launch() but keeps embeddings on GPU (no PCIe copy).""" self.lvm_stream.wait_stream(torch.cuda.current_stream()) @@ -1037,6 +1193,7 @@ def tree_value_launch_gpu( rids, candidate_ids, gpu_candidates=gpu_candidates, + candidate_lens_per_req=candidate_lens, mrope_deltas=self._mrope_deltas if self.is_vlm else None, ) # embeddings stay on GPU — no PCIe copy. @@ -1044,6 +1201,63 @@ def tree_value_launch_gpu( return embeddings # GPU tensor(s), not yet safe from default stream + def tree_value_extend_and_launch_gpu( + self, + rids: List[str], + prefix_ids: List[List[int]], + _reqs: List[Req], + candidate_ids: List[List[int]], + gpu_candidates: Optional[tuple] = None, + candidate_lens: Optional[List[int]] = None, + ): + """Fuse tiny prefix extension and candidate scoring into one forward. + + Returns None when fusion is not appropriate, so callers can fall + back to the two-phase path. 
+ """ + if self.is_vlm: + return None + if ( + getattr( + self.incremental_runner.runner.token_to_kv_pool_allocator, + "page_size", + 1, + ) + != 1 + ): + return None + + new_tokens_list: List[List[int]] = [] + for rid, p_ids in zip(rids, prefix_ids): + cached_len = self.incremental_runner.kv_mgr.kv_len(rid) + target_len = len(p_ids) + if target_len < cached_len: + self.incremental_runner.kv_mgr.retract(rid, target_len) + cached_len = target_len + + if target_len > cached_len: + new_tokens = p_ids[cached_len:] + else: + new_tokens = [] + if len(new_tokens) > 1: + return None + new_tokens_list.append(new_tokens) + + self.lvm_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.lvm_stream): + embeddings = ( + self.incremental_runner.extend_and_eval_candidates_batch_gpu( + rids, + new_tokens_list, + candidate_ids, + gpu_candidates=gpu_candidates, + candidate_lens_per_req=candidate_lens, + ) + ) + self.embed_ready.record(self.lvm_stream) + + return embeddings + def tree_value_collect_gpu(self, gpu_embeddings): """Insert a stream dependency so the default stream waits for lvm_stream.""" torch.cuda.current_stream().wait_event(self.embed_ready) @@ -1169,7 +1383,32 @@ def _req_wants_value_guidance(req: Any) -> bool: "cmp", "op", ) - return any(k in custom_params for k in keys) + if not any(k in custom_params for k in keys): + return False + + hard_constraint_keys = ( + "target_value", + "target_length", + "value_constraint", + "constraint", + "cmp", + "op", + ) + if any(k in custom_params for k in hard_constraint_keys): + return True + + try: + mode, scale = _get_req_value_mode_and_scale(req) + except ValueError: + # Preserve existing behavior: invalid explicit params should surface + # when guidance is built rather than silently disabling LenVM. 
+ return True + + if mode in ("centered_exp", "value_bias"): + return not math.isclose(scale, 0.0, rel_tol=0.0, abs_tol=1e-12) + if mode == "mul" and scale <= 0.0: + return False + return not math.isclose(scale, 1.0, rel_tol=0.0, abs_tol=1e-12) @staticmethod def _extract_entropy_threshold(req: Any) -> Optional[float]: @@ -1284,6 +1523,7 @@ def _build_pending( top_ps: torch.Tensor, top_ks: torch.Tensor, min_ps: torch.Tensor, + enable_gpu_candidate_compact: bool = False, ) -> Optional["PendingLvmResult"]: """Filter candidates and build a PendingLvmResult without contacting the LVM. @@ -1297,9 +1537,15 @@ def _build_pending( # Identify which rows actually want value guidance. guided_rows: List[int] = [] + entropy_threshold_by_row: dict[int, float] = {} + has_entropy_threshold = False for i, req in enumerate(req_list): if self._req_wants_value_guidance(req): guided_rows.append(i) + thr = self._extract_entropy_threshold(req) + if thr is not None: + entropy_threshold_by_row[i] = float(thr) + has_entropy_threshold = True # If nobody requested value guidance, do nothing and let normal sampling proceed. if not guided_rows: @@ -1316,6 +1562,7 @@ def _build_pending( candidate_ids_send: List[List[int]] = [] send_batch_indices: List[int] = [] candidate_probs_send: List[List[float]] = [] + candidate_lens_send: List[int] = [] deterministic_rows: List[tuple[int, int]] = [] # Split rare slow-path (top_k == ALL) from the common (top_k is small). @@ -1326,16 +1573,21 @@ def _build_pending( # without any hard target_value/target_length constraints, and all sequences # use top-k (not top-k-all), so we can keep tensors on GPU. 
has_hard_target = any( - _extract_target_value({"req": req_list[ridx]}) is not None - for ridx in range(len(req_list)) + any( + _get_req_custom_params(req_list[ridx]).get(k) is not None + for k in ("target_value", "target_length") + ) + for ridx in guided_rows + ) + _use_gpu_path = ( + self._fn in (lvm_expectation_guidance, lvm_combined_guidance) + and enable_gpu_candidate_compact + and not bool(mask_all.any().item()) + and not has_hard_target ) - _use_gpu_path = self._fn in ( - lvm_expectation_guidance, - lvm_combined_guidance, - ) and not bool(mask_all.any().item()) and not has_hard_target # Will hold (vals_send_gpu, idx_send_gpu) for the top-k send rows, on GPU. - _gpu_vals_chunks: List[torch.Tensor] = [] - _gpu_idx_chunks: List[torch.Tensor] = [] + _gpu_vals_send: Optional[torch.Tensor] = None + _gpu_idx_send: Optional[torch.Tensor] = None # --------------------------- # Fast path: batched top-k -> top-p/min-p/entropy on the top-k subset. @@ -1393,21 +1645,24 @@ def _build_pending( ).view(-1) # Optional entropy-based skip (per request, Python-sourced thresholds). 
- rows_topk_list: List[int] = rows_topk_t.detach().cpu().tolist() - thr_list: List[float] = [] - has_thr = torch.zeros(len(rows_topk_list), device=device, dtype=torch.bool) - for j, ridx in enumerate(rows_topk_list): - thr = self._extract_entropy_threshold(req_list[ridx]) - if thr is None: - thr_list.append(float("nan")) - else: - thr_list.append(float(thr)) - has_thr[j] = True + skip_entropy = torch.zeros_like(det_mask, dtype=torch.bool) + rows_topk_list: Optional[List[int]] = None + if has_entropy_threshold: + rows_topk_list = rows_topk_t.detach().cpu().tolist() + thr_list: List[float] = [] + has_thr = torch.zeros( + len(rows_topk_list), device=device, dtype=torch.bool + ) + for j, ridx in enumerate(rows_topk_list): + thr = entropy_threshold_by_row.get(ridx) + if thr is None: + thr_list.append(float("nan")) + else: + thr_list.append(float(thr)) + has_thr[j] = True - # Use float64 for value-guidance gating to avoid precision loss in entropy comparisons. - thr_t = torch.tensor(thr_list, device=device, dtype=torch.float64) - skip_entropy = torch.zeros_like(has_thr, dtype=torch.bool) - if bool(has_thr.any().item()): + # Use float64 for stable entropy comparisons. + thr_t = torch.tensor(thr_list, device=device, dtype=torch.float64) p = vals.to(torch.float64) s = p.sum(dim=-1) # Avoid division by 0; counts==0 already fixed. @@ -1420,6 +1675,8 @@ def _build_pending( # Materialize deterministic rows in Python list. if bool(det_mask.any().item()): + if rows_topk_list is None: + rows_topk_list = rows_topk_t.detach().cpu().tolist() det_token_ids_cpu = det_token_ids.detach().cpu().tolist() det_mask_cpu = det_mask.detach().cpu().tolist() for j, is_det in enumerate(det_mask_cpu): @@ -1432,38 +1689,49 @@ def _build_pending( # Keep GPU slices before moving to CPU (used by GPU guidance fast path). 
vals_send_gpu = vals[send_mask] # [B_topk_send, K_max], GPU idx_send_gpu = topk_idx[send_mask] # [B_topk_send, K_max], GPU - idx_send = idx_send_gpu.detach().cpu() - # GPU fast path: only need bool mask (4x smaller than float32 transfer). - # CPU path: need full float values for candidate_probs_send. + rows_send_list = rows_send_t.detach().cpu().tolist() + if _use_gpu_path: - valid_mask_send = (vals_send_gpu > 0).detach().cpu() + # For the GPU path, keep candidate ids/probs/masks on device. + # Only copy one integer per guided row to CPU so ForwardBatch + # can be built without materializing per-candidate Python lists. + counts_send_list = counts[send_mask].detach().cpu().tolist() + for j, ridx in enumerate(rows_send_list): + n_cands = int(counts_send_list[j]) + if n_cands <= 1: + raise RuntimeError( + "Internal LenVM candidate filtering error: " + f"send row has {n_cands} candidates" + ) + + prefix = self._get_prefix_ids_incremental(req_list[ridx]) + prefix_ids_send.append(prefix) + candidate_lens_send.append(n_cands) + send_batch_indices.append(ridx) + + _gpu_vals_send = vals_send_gpu + _gpu_idx_send = idx_send_gpu else: + idx_send = idx_send_gpu.detach().cpu() + # CPU path needs full candidate ids and probs as Python lists. vals_send = vals_send_gpu.detach().cpu() - rows_send_list = rows_send_t.detach().cpu().tolist() - for j, ridx in enumerate(rows_send_list): - # In practice (sorted desc + thresholding), non-zeros are a prefix. Still, use mask for safety. - if _use_gpu_path: - m = valid_mask_send[j] - else: + for j, ridx in enumerate(rows_send_list): + # In practice (sorted desc + thresholding), non-zeros + # are a prefix. Still, use mask for safety. m = vals_send[j] > 0 - cand_ids = idx_send[j][m].tolist() - if len(cand_ids) <= 1: - # Should have been caught by det_mask, but keep a safe fallback. 
- if len(cand_ids) == 1: - deterministic_rows.append((ridx, int(cand_ids[0]))) - continue - - prefix = self._get_prefix_ids_incremental(req_list[ridx]) - prefix_ids_send.append(prefix) - candidate_ids_send.append(cand_ids) - if not _use_gpu_path: + cand_ids = idx_send[j][m].tolist() + if len(cand_ids) <= 1: + # Should have been caught by det_mask, but keep a safe fallback. + if len(cand_ids) == 1: + deterministic_rows.append((ridx, int(cand_ids[0]))) + continue + + prefix = self._get_prefix_ids_incremental(req_list[ridx]) + prefix_ids_send.append(prefix) + candidate_ids_send.append(cand_ids) candidate_probs_send.append(vals_send[j][m].tolist()) - send_batch_indices.append(ridx) - if _use_gpu_path: - # Capture GPU row j for later gpu_candidates assembly. - _gpu_vals_chunks.append(vals_send_gpu[j].unsqueeze(0)) - _gpu_idx_chunks.append(idx_send_gpu[j].unsqueeze(0)) + send_batch_indices.append(ridx) # --------------------------- # Slow path: top_k == ALL (full vocab filtering). Rare; keep correctness-oriented CPU behavior. @@ -1512,19 +1780,36 @@ def _build_pending( if not send_batch_indices and not deterministic_rows: return None - # Build guided tensor and fill deterministic rows immediately. - guided = probs.clone() + # Reuse the caller-owned probability tensor. The sampler consumes the guided + # distribution after apply() returns, so a full [batch, vocab] clone is avoidable. + guided = probs # Fill deterministic rows (single candidate) without contacting LVM. 
for i, tok in deterministic_rows: guided[i].zero_() guided[i, tok] = 1.0 + modified_rows = set(send_batch_indices) + modified_rows.update(i for i, _tok in deterministic_rows) + if len(modified_rows) < len(req_list): + top_ks_cpu = top_ks.detach().cpu().tolist() + top_ps_cpu = top_ps.detach().cpu().tolist() + min_ps_cpu = min_ps.detach().cpu().tolist() + for i in range(len(req_list)): + if i in modified_rows: + continue + top_k_i = int(top_ks_cpu[i]) + top_p_i = float(top_ps_cpu[i]) + min_p_i = float(min_ps_cpu[i]) + if top_k_i == TOP_K_ALL and top_p_i >= 1.0 and min_p_i <= 0.0: + continue + guided[i].copy_(self._filter_probs(guided[i], top_k_i, top_p_i, min_p_i)) + # Assemble GPU candidate tensors for the fast guidance path. gpu_candidates = None - if _use_gpu_path and _gpu_vals_chunks: - gp = torch.cat(_gpu_vals_chunks, dim=0).float() # [B_send, K_max] - gi = torch.cat(_gpu_idx_chunks, dim=0) # [B_send, K_max] + if _use_gpu_path and _gpu_vals_send is not None and candidate_lens_send: + gp = _gpu_vals_send.float() # [B_send, K_max] + gi = _gpu_idx_send # [B_send, K_max] gm = gp > 0 # [B_send, K_max] bool gpu_candidates = (gp, gi, gm) @@ -1536,6 +1821,7 @@ def _build_pending( prefix_ids_send=prefix_ids_send, candidate_ids_send=candidate_ids_send, candidate_probs_send=candidate_probs_send, + candidate_lens_send=candidate_lens_send or None, gpu_candidates=gpu_candidates, ) @@ -1581,11 +1867,12 @@ def _apply_guidance_gpu(self, pending: "PendingLvmResult", gpu_embeddings) -> No modes_list: List[str] = [] for ridx in pending.send_batch_indices: req = pending.req_list[ridx] - scales_list.append(_extract_value_scale({"req": req})) - modes_list.append(_extract_value_mode({"req": req}, default="mul")) + mode, scale = _get_req_value_mode_and_scale(req) + scales_list.append(scale) + modes_list.append(mode) - # Use the most common mode; fall back to "mul" if mixed (rare). 
- mode = modes_list[0] if len(set(modes_list)) == 1 else "mul" + same_mode = len(set(modes_list)) == 1 + mode = modes_list[0] if same_mode else None scale_t = torch.tensor(scales_list, device=device, dtype=torch.float32) # [B] # -- Sigmoid of raw embeddings → values in [0, 1] --------------------------- @@ -1612,7 +1899,88 @@ def _apply_guidance_gpu(self, pending: "PendingLvmResult", gpu_embeddings) -> No max_v = values.masked_fill(~valid_mask, -1e9).max(dim=-1).values # [B] v_range = (max_v - min_v).clamp(min=1e-8) # [B] - if mode == "centered_exp": + if not same_mode: + final_probs = p_norm.clone() + for bi, mode_i in enumerate(modes_list): + valid_i = valid_mask[bi] + scale_i = scale_t[bi] + if mode_i == "centered_exp": + logits = values[bi] * scale_i + logits = logits.masked_fill(~valid_i, -1e9) + logits = logits - logits.max().view(()) + w = (p_norm[bi] * torch.exp(logits)).masked_fill(~valid_i, 0.0) + final_probs[bi] = w / w.sum().clamp(min=1e-20) + continue + + if mode_i == "value_bias": + logits = emb[bi] * scale_i + logits = logits.masked_fill(~valid_i, -1e9) + logits = logits - logits.max().view(()) + w = (p_norm[bi] * torch.exp(logits)).masked_fill(~valid_i, 0.0) + final_probs[bi] = w / w.sum().clamp(min=1e-20) + continue + + if mode_i not in ("exp", "linear", "length_mul", "mul"): + raise ValueError(f"[LVM GPU path] Unknown value_mode: {mode_i!r}") + + if mode_i == "mul" and float(scale_i.item()) <= 0.0: + continue + if mode_i == "exp": + cur_norm_i = ((cur_exp[bi] - min_v[bi]) / v_range[bi]).clamp( + 0.0, 1.0 + ) + target_i = min_v[bi] + ( + 1.0 - (1.0 - cur_norm_i) ** scale_i + ) * v_range[bi] + elif mode_i == "linear": + if bool((scale_i >= 1.0).item()): + target_i = cur_exp[bi] + (scale_i - 1.0) * ( + max_v[bi] - cur_exp[bi] + ) + else: + target_i = cur_exp[bi] - (1.0 - scale_i) * ( + cur_exp[bi] - min_v[bi] + ) + elif mode_i == "length_mul": + gamma = 0.997 + log_gamma = math.log(gamma) + mu_v = cur_exp[bi].clamp(min=1e-15, max=1.0 - 1e-15) + l_cur 
= torch.log1p(-mu_v) / log_gamma + target_i = 1.0 - torch.exp(scale_i * l_cur * log_gamma) + target_i = target_i.clamp(min_v[bi] + 1e-12, max_v[bi] - 1e-12) + else: # "mul" + target_i = cur_exp[bi] * scale_i + + eps_i = (v_range[bi] * 1e-6).clamp(min=1e-12) + target_i = torch.max( + torch.min(target_i, max_v[bi] - eps_i), min_v[bi] + eps_i + ) + if not bool( + ((v_range[bi] > 1e-8) & (torch.abs(target_i - cur_exp[bi]) > 1e-8)).item() + ): + continue + + log_p_i = torch.log(p_norm[bi].clamp(min=1e-20)) + log_p_i = log_p_i.masked_fill(~valid_i, -1e9) + dv_i = (values[bi] - min_v[bi]).masked_fill(~valid_i, 0.0) + lam_i = torch.zeros((), device=device, dtype=torch.float32) + for _ in range(20): + logits = log_p_i + dv_i * lam_i + m = logits.max() + w = torch.exp(logits - m).masked_fill(~valid_i, 0.0) + w = w / w.sum().clamp(min=1e-20) + mean = (w * values[bi]).sum() + mean2 = (w * values[bi] * values[bi]).sum() + var = (mean2 - mean * mean).clamp(min=0.0) + lam_i = (lam_i - (mean - target_i) / var.clamp(min=1e-16)).clamp( + -100.0, 100.0 + ) + + final_logits = (log_p_i + dv_i * lam_i).masked_fill(~valid_i, -1e9) + m = final_logits.max() + w = torch.exp(final_logits - m).masked_fill(~valid_i, 0.0) + final_probs[bi] = w / w.sum().clamp(min=1e-20) + elif mode == "centered_exp": # print("centered_exp") # p'(i) ∝ p(i) * exp(sigmoid(emb(i)) * scale); values = sigmoid(emb) ∈ [0, 1]. # Subtract per-row max before exp to prevent overflow (cancels in normalization). @@ -1765,13 +2133,41 @@ def apply( Returns the modified probs tensor, or None when no guidance is needed (caller should use the original probs). 
""" + timing_sample = {} if _lvm_timing_enabled() else None + t_apply = _lvm_timing_tic() if timing_sample is not None else 0.0 + + t = _lvm_timing_tic() if timing_sample is not None else 0.0 + req_list = reqs if isinstance(reqs, list) else list(reqs) + if not any(self._req_wants_value_guidance(req) for req in req_list): + _lvm_timing_toc(timing_sample, "precheck", t) + _lvm_timing_toc(timing_sample, "apply_total", t_apply) + _record_lvm_timing(timing_sample, skipped=True) + return None + _lvm_timing_toc(timing_sample, "precheck", t) + + t = _lvm_timing_tic() if timing_sample is not None else 0.0 inproc = self._get_inproc_provider() + _lvm_timing_toc(timing_sample, "get_inproc_provider", t) if inproc not in (None, False): # Free KV cache for requests that have finished or aborted - inproc.clean_stale_requests(set(r.rid for r in reqs)) - - pending = self._build_pending(probs, reqs, temperatures, top_ps, top_ks, min_ps) + t = _lvm_timing_tic() if timing_sample is not None else 0.0 + inproc.clean_stale_requests(set(r.rid for r in req_list)) + _lvm_timing_toc(timing_sample, "clean_stale", t) + + t = _lvm_timing_tic() if timing_sample is not None else 0.0 + pending = self._build_pending( + probs, + req_list, + temperatures, + top_ps, + top_ks, + min_ps, + enable_gpu_candidate_compact=inproc not in (None, False), + ) + _lvm_timing_toc(timing_sample, "build_pending", t) if pending is None: + _lvm_timing_toc(timing_sample, "apply_total", t_apply) + _record_lvm_timing(timing_sample, skipped=True) return None if pending.send_batch_indices: @@ -1781,21 +2177,61 @@ def apply( if inproc not in (None, False): try: rids_send = [req.rid for req in reqs_send] - inproc.tree_value_extend(rids_send, pending.prefix_ids_send, reqs_send) - gpu_emb = inproc.tree_value_launch_gpu( - rids_send, pending.candidate_ids_send, gpu_candidates=pending.gpu_candidates + t = _lvm_timing_tic() if timing_sample is not None else 0.0 + gpu_emb = inproc.tree_value_extend_and_launch_gpu( + rids_send, + 
pending.prefix_ids_send, + reqs_send, + pending.candidate_ids_send, + gpu_candidates=pending.gpu_candidates, + candidate_lens=pending.candidate_lens_send, ) + _lvm_timing_toc(timing_sample, "extend_launch_fused", t) + fused_path = gpu_emb is not None + if gpu_emb is None: + t = _lvm_timing_tic() if timing_sample is not None else 0.0 + inproc.tree_value_extend( + rids_send, pending.prefix_ids_send, reqs_send + ) + _lvm_timing_toc(timing_sample, "extend_prefix", t) + t = _lvm_timing_tic() if timing_sample is not None else 0.0 + gpu_emb = inproc.tree_value_launch_gpu( + rids_send, + pending.candidate_ids_send, + gpu_candidates=pending.gpu_candidates, + candidate_lens=pending.candidate_lens_send, + ) + _lvm_timing_toc(timing_sample, "launch_candidates", t) + t = _lvm_timing_tic() if timing_sample is not None else 0.0 gpu_embeddings = inproc.tree_value_collect_gpu(gpu_emb) + _lvm_timing_toc(timing_sample, "collect_gpu", t) + t = _lvm_timing_tic() if timing_sample is not None else 0.0 self._apply_guidance_gpu(pending, gpu_embeddings) + _lvm_timing_toc(timing_sample, "apply_guidance_gpu", t) + _lvm_timing_toc(timing_sample, "apply_total", t_apply) + _record_lvm_timing( + timing_sample, + pending=pending, + gpu_path=True, + fused_path=fused_path, + ) return pending.guided except Exception as exc: raise RuntimeError("LenVM GPU guidance path failed") from exc + t = _lvm_timing_tic() if timing_sample is not None else 0.0 lvm_values = self._post_tree_value( [req.rid for req in reqs_send], pending.prefix_ids_send, pending.candidate_ids_send, reqs_send ) + _lvm_timing_toc(timing_sample, "post_tree_value", t) if lvm_values is None: + _lvm_timing_toc(timing_sample, "apply_total", t_apply) + _record_lvm_timing(timing_sample, pending=pending) return None + t = _lvm_timing_tic() if timing_sample is not None else 0.0 self._apply_guidance(pending, lvm_values) + _lvm_timing_toc(timing_sample, "apply_guidance_cpu", t) + _lvm_timing_toc(timing_sample, "apply_total", t_apply) + 
_record_lvm_timing(timing_sample, pending=pending) return pending.guided diff --git a/sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py b/sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py index fe5f080..46d40e1 100644 --- a/sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py +++ b/sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py @@ -317,6 +317,7 @@ def eval_candidates_batch_gpu( rids: List[str], candidate_ids_per_req: List[List[int]], gpu_candidates: Optional[tuple] = None, + candidate_lens_per_req: Optional[List[int]] = None, mrope_deltas: Optional[dict] = None, ) -> List[torch.Tensor]: """Evaluate candidates with tree attention and return raw GPU embeddings. @@ -334,7 +335,17 @@ def eval_candidates_batch_gpu( device = self.device prefix_lens = [self.kv_mgr.kv_len(rid) for rid in rids] - cand_lens = [len(cands) for cands in candidate_ids_per_req] + if candidate_lens_per_req is None: + cand_lens = [len(cands) for cands in candidate_ids_per_req] + else: + cand_lens = [int(n) for n in candidate_lens_per_req] + if len(cand_lens) != len(rids): + raise RuntimeError( + "candidate_lens_per_req length must match rids length: " + f"{len(cand_lens)} vs {len(rids)}" + ) + if any(n <= 0 for n in cand_lens): + raise RuntimeError(f"candidate lengths must be positive, got {cand_lens!r}") seq_lens_list = [prefix_lens[i] + cand_lens[i] for i in range(len(rids))] total_cands = sum(cand_lens) pool_indices = [self.kv_mgr.pool_indices[rid] for rid in rids] @@ -375,6 +386,11 @@ def eval_candidates_batch_gpu( if gpu_candidates is not None: _, gi, gm = gpu_candidates input_ids_t = gi[gm].to(torch.int64) + if int(input_ids_t.numel()) != total_cands: + raise RuntimeError( + "GPU candidate mask does not match candidate lengths: " + f"mask_tokens={int(input_ids_t.numel())} total_cands={total_cands}" + ) else: input_ids_t = torch.tensor( [t for cands in candidate_ids_per_req for t in cands], @@ -456,6 +472,209 @@ def eval_candidates_batch_gpu( return 
logits_output.embeddings + def extend_and_eval_candidates_batch_gpu( + self, + rids: List[str], + new_tokens_per_req: List[List[int]], + candidate_ids_per_req: List[List[int]], + gpu_candidates: Optional[tuple] = None, + candidate_lens_per_req: Optional[List[int]] = None, + ) -> List[torch.Tensor]: + """Extend a tiny prefix delta and evaluate candidates in one LVM forward. + + This is intended for the steady-state decode path where each request adds + at most one accepted token per step. Large initial prompt fills should use + extend_prefix_batch() first; fusing them would require a dense L^2 custom + mask and is slower. + """ + if not rids: + return [] + if len(new_tokens_per_req) != len(rids): + raise RuntimeError( + "new_tokens_per_req length must match rids length: " + f"{len(new_tokens_per_req)} vs {len(rids)}" + ) + + runner = self.runner + device = self.device + + cached_prefix_lens = [self.kv_mgr.kv_len(rid) for rid in rids] + extend_lens = [len(toks) for toks in new_tokens_per_req] + if any(n > 1 for n in extend_lens): + raise RuntimeError( + "extend_and_eval_candidates_batch_gpu only supports tiny prefix " + f"deltas, got {extend_lens!r}" + ) + + if candidate_lens_per_req is None: + cand_lens = [len(cands) for cands in candidate_ids_per_req] + else: + cand_lens = [int(n) for n in candidate_lens_per_req] + if len(cand_lens) != len(rids): + raise RuntimeError( + "candidate_lens_per_req length must match rids length: " + f"{len(cand_lens)} vs {len(rids)}" + ) + if any(n <= 0 for n in cand_lens): + raise RuntimeError(f"candidate lengths must be positive, got {cand_lens!r}") + + prefix_lens = [ + cached_prefix_lens[i] + extend_lens[i] for i in range(len(rids)) + ] + extend_seq_lens = [ + extend_lens[i] + cand_lens[i] for i in range(len(rids)) + ] + seq_lens_list = [ + prefix_lens[i] + cand_lens[i] for i in range(len(rids)) + ] + total_extend = sum(extend_seq_lens) + total_cands = sum(cand_lens) + pool_indices = [self.kv_mgr.get_or_alloc(rid) for rid in rids] + 
+ allocator = runner.token_to_kv_pool_allocator + if getattr(allocator, "page_size", 1) != 1: + raise RuntimeError( + "fused LenVM extend+candidate path requires token-granular KV " + "allocation (page_size == 1)" + ) + out_cache_loc = allocator.alloc(total_extend) + if out_cache_loc is None: + raise RuntimeError( + f"LVM KV pool OOM: cannot allocate {total_extend} fused slots. " + "Consider --lvm-guided-inproc-mem-fraction-static." + ) + out_cache_loc = out_cache_loc.to(torch.int64) + + candidate_cache_chunks = [] + pt = 0 + for pool_idx, cached_len, e_len, n in zip( + pool_indices, cached_prefix_lens, extend_lens, cand_lens + ): + seq_tokens = e_len + n + runner.req_to_token_pool.write( + (pool_idx, slice(cached_len, cached_len + seq_tokens)), + out_cache_loc[pt : pt + seq_tokens], + ) + if n > 0: + candidate_cache_chunks.append( + out_cache_loc[pt + e_len : pt + seq_tokens] + ) + pt += seq_tokens + + if gpu_candidates is not None: + _, gi, gm = gpu_candidates + has_new_list = [len(toks) > 0 for toks in new_tokens_per_req] + if any(has_new_list): + new_ids_t = torch.tensor( + [int(toks[0]) if toks else 0 for toks in new_tokens_per_req], + dtype=torch.int64, + device=device, + ) + has_new_t = torch.tensor( + has_new_list, dtype=torch.bool, device=device + ) + B, K = gi.shape + padded_input_ids = torch.empty( + (B, K + 1), dtype=torch.int64, device=device + ) + padded_input_ids[:, 0] = new_ids_t + padded_input_ids[:, 1:] = gi + padded_mask = torch.empty((B, K + 1), dtype=torch.bool, device=device) + padded_mask[:, 0] = has_new_t + padded_mask[:, 1:] = gm + input_ids_t = padded_input_ids[padded_mask] + else: + input_ids_t = gi[gm].to(torch.int64) + else: + input_ids_list: List[int] = [] + for toks, cands in zip(new_tokens_per_req, candidate_ids_per_req): + input_ids_list.extend(int(t) for t in toks) + input_ids_list.extend(int(t) for t in cands) + input_ids_t = torch.tensor(input_ids_list, dtype=torch.int64, device=device) + if int(input_ids_t.numel()) != 
total_extend: + raise RuntimeError( + "fused LenVM input length does not match metadata: " + f"input_tokens={int(input_ids_t.numel())} total_extend={total_extend}" + ) + + req_pool_indices_t = torch.tensor(pool_indices, dtype=torch.int64, device=device) + seq_lens_t = torch.tensor(seq_lens_list, dtype=torch.int32, device=device) + seq_lens_cpu_t = torch.tensor(seq_lens_list, dtype=torch.int32) + extend_prefix_lens_t = torch.tensor( + cached_prefix_lens, dtype=torch.int32, device=device + ) + extend_seq_lens_t = torch.tensor( + extend_seq_lens, dtype=torch.int32, device=device + ) + + custom_mask, positions = build_tree_value_custom_mask_and_positions( + prefix_lens=prefix_lens, + candidate_lens=cand_lens, + cached_prefix_lens=cached_prefix_lens, + device=device, + ) + spec_info = TreeValueSpecInput( + custom_mask=custom_mask, + positions=positions, + tree_value_prefix_lens=list(prefix_lens), + tree_value_candidate_lens=list(cand_lens), + tree_value_cached_prefix_lens=list(cached_prefix_lens), + ) + + extend_start_loc = torch.zeros(len(rids), dtype=torch.int32, device=device) + if len(rids) > 1: + extend_start_loc[1:] = torch.cumsum(extend_seq_lens_t[:-1], dim=0) + + forward_batch = ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=len(rids), + input_ids=input_ids_t, + req_pool_indices=req_pool_indices_t, + seq_lens=seq_lens_t, + seq_lens_cpu=seq_lens_cpu_t, + out_cache_loc=out_cache_loc, + seq_lens_sum=sum(seq_lens_list), + positions=positions, + extend_num_tokens=total_extend, + extend_seq_lens=extend_seq_lens_t, + extend_prefix_lens=extend_prefix_lens_t, + extend_start_loc=extend_start_loc, + extend_prefix_lens_cpu=list(cached_prefix_lens), + extend_seq_lens_cpu=list(extend_seq_lens), + req_to_token_pool=runner.req_to_token_pool, + token_to_kv_pool=runner.token_to_kv_pool, + attn_backend=runner.attn_backend, + return_logprob=False, + is_extend_in_batch=True, + is_prefill_only=True, + spec_algorithm=SpeculativeAlgorithm.NONE, + spec_info=spec_info, + 
global_forward_mode=ForwardMode.EXTEND, + ) + forward_batch.num_token_non_padded_cpu = total_extend + + candidate_cache_loc = ( + torch.cat(candidate_cache_chunks, dim=0) + if candidate_cache_chunks + else out_cache_loc[:0] + ) + + success = False + try: + with self._lvm_embedding_cache_ctx(): + logits_output = runner.forward_extend(forward_batch) + success = True + finally: + if success: + runner.token_to_kv_pool_allocator.free(candidate_cache_loc) + else: + runner.token_to_kv_pool_allocator.free(out_cache_loc) + + for rid, e_len in zip(rids, extend_lens): + self.kv_mgr.kv_lens[rid] += e_len + + return logits_output.embeddings + def eval_candidates_batch( self, rids: List[str], diff --git a/sglang-LenVM/python/sglang/srt/lvm/lvm_value_utils.py b/sglang-LenVM/python/sglang/srt/lvm/lvm_value_utils.py index 0a5992f..87ddc9b 100644 --- a/sglang-LenVM/python/sglang/srt/lvm/lvm_value_utils.py +++ b/sglang-LenVM/python/sglang/srt/lvm/lvm_value_utils.py @@ -10,6 +10,10 @@ def get_eos_token_ids(req: Any) -> Set[int]: - `req.eos_token_ids`: a Set[int] from `ModelConfig.hf_eos_token_id` (preferred) - `req.tokenizer.eos_token_id`: tokenizer-defined EOS id """ + cached = getattr(req, "_lvm_eos_token_ids", None) + if cached is not None: + return set(cached) + eos: Set[int] = set() eos_ids = getattr(req, "eos_token_ids", None) @@ -32,6 +36,10 @@ def get_eos_token_ids(req: Any) -> Set[int]: except Exception as exc: raise ValueError(f"Invalid LenVM tokenizer eos_token_id: {tok_eos!r}") from exc + try: + setattr(req, "_lvm_eos_token_ids", tuple(sorted(eos))) + except Exception: + pass return eos @@ -52,4 +60,3 @@ def force_eos_value_zero(token_ids: List[int], token_values: List[float], req: A token_values[j] = 0.0 except Exception as exc: raise ValueError(f"Invalid LenVM token id while forcing EOS value to zero: {tid!r}") from exc - diff --git a/sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py b/sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py index c84101a..91dc212 
100644 --- a/sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py +++ b/sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py @@ -41,6 +41,45 @@ def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: # No token multiplier for DP buffers. return 1, 1 + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + req_to_token: torch.Tensor, + ): + """Build attention arguments for flashinfer-style prefill backends.""" + from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton + + device = req_pool_indices.device + q_lens, k_lens, _mask_offsets, _pos_offsets = self._per_req_qk_and_offsets() + bs = len(q_lens) + + q_lens_t = torch.tensor(q_lens, dtype=torch.int32, device=device) + k_lens_t = torch.tensor(k_lens, dtype=torch.int32, device=device) + + qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(q_lens_t, dim=0) + kv_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + kv_indptr[1:] = torch.cumsum(k_lens_t, dim=0) + + if paged_kernel_lens_sum is None: + paged_kernel_lens_sum = sum(k_lens) + + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device=device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + k_lens_t, + kv_indptr, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, kv_indptr, qo_indptr, self.custom_mask + def _per_req_qk_and_offsets(self) -> Tuple[List[int], List[int], List[int], List[int]]: """ Compute per-request: @@ -250,8 +289,8 @@ def build_tree_value_custom_mask_and_positions( # Candidate rows: attend all prefix tokens (0..L-1) and itself. 
cand_row_start = L - P m[cand_row_start : cand_row_start + N, :L] = True - for j in range(N): - m[cand_row_start + j, L + j] = True + cand_arange = np.arange(N) + m[cand_row_start + cand_arange, L + cand_arange] = True mask_off += q_len * k_len @@ -265,4 +304,3 @@ def build_tree_value_custom_mask_and_positions( custom_mask = torch.from_numpy(mask_buf).to(device=device, non_blocking=True) positions = torch.from_numpy(pos_buf).to(device=device, non_blocking=True) return custom_mask, positions - diff --git a/sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py b/sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py index 536042e..fe8d866 100644 --- a/sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py +++ b/sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py @@ -77,22 +77,23 @@ def forward( ) if prefix_lens is not None and cand_lens is not None: - extend_lens = forward_batch.extend_seq_lens.tolist() cached_prefix_lens = ( - forward_batch.extend_prefix_lens.tolist() - if forward_batch.extend_prefix_lens is not None - else [0] * len(extend_lens) + getattr(spec, "tree_value_cached_prefix_lens", None) + or [0] * len(prefix_lens) ) out: list[torch.Tensor] = [] offset = 0 - for i, ext_len in enumerate(extend_lens): + for i, (prefix_len, cand_len, cached_prefix_len) in enumerate( + zip(prefix_lens, cand_lens, cached_prefix_lens) + ): + prefix_len = int(prefix_len) + cand_len = int(cand_len) + cached_prefix_len = int(cached_prefix_len) + ext_len = max(prefix_len - cached_prefix_len, 0) + cand_len vals_i = token_values[offset : offset + ext_len] offset += ext_len - prefix_len = int(prefix_lens[i]) - cand_len = int(cand_lens[i]) - cached_prefix_len = int(cached_prefix_lens[i]) cand_offset = max(prefix_len - cached_prefix_len, 0) out.append(vals_i[cand_offset : cand_offset + cand_len]) return EmbeddingPoolerOutput(embeddings=out) diff --git a/sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py b/sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py index 
482f0a6..527011c 100644 --- a/sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py +++ b/sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py @@ -86,22 +86,23 @@ def forward( prefix_lens = getattr(spec, "tree_value_prefix_lens", None) if spec is not None else None cand_lens = getattr(spec, "tree_value_candidate_lens", None) if spec is not None else None if prefix_lens is not None and cand_lens is not None: - extend_lens = forward_batch.extend_seq_lens.tolist() cached_prefix_lens = ( - forward_batch.extend_prefix_lens.tolist() - if forward_batch.extend_prefix_lens is not None - else [0] * len(extend_lens) + getattr(spec, "tree_value_cached_prefix_lens", None) + or [0] * len(prefix_lens) ) out: list[torch.Tensor] = [] offset = 0 - for i, ext_len in enumerate(extend_lens): + for i, (prefix_len, cand_len, cached_prefix_len) in enumerate( + zip(prefix_lens, cand_lens, cached_prefix_lens) + ): + L = int(prefix_len) + N = int(cand_len) + P = int(cached_prefix_len) + ext_len = max(L - P, 0) + N vals_i = token_values[offset : offset + ext_len] offset += ext_len - L = int(prefix_lens[i]) - N = int(cand_lens[i]) - P = int(cached_prefix_lens[i]) cand_offset = max(L - P, 0) out.append(vals_i[cand_offset : cand_offset + N]) return EmbeddingPoolerOutput(embeddings=out) @@ -200,4 +201,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): EntryClass = [ Qwen2ForLengthValueModel, ] - diff --git a/sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py b/sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py index f9b41ba..7d8e27c 100644 --- a/sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py +++ b/sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py @@ -83,22 +83,23 @@ def forward( prefix_lens = getattr(spec, "tree_value_prefix_lens", None) if spec is not None else None cand_lens = getattr(spec, "tree_value_candidate_lens", None) if spec is not None else None if prefix_lens is not None and cand_lens is not None: - extend_lens = forward_batch.extend_seq_lens.tolist() 
cached_prefix_lens = ( - forward_batch.extend_prefix_lens.tolist() - if forward_batch.extend_prefix_lens is not None - else [0] * len(extend_lens) + getattr(spec, "tree_value_cached_prefix_lens", None) + or [0] * len(prefix_lens) ) out: list[torch.Tensor] = [] offset = 0 - for i, ext_len in enumerate(extend_lens): + for i, (prefix_len, cand_len, cached_prefix_len) in enumerate( + zip(prefix_lens, cand_lens, cached_prefix_lens) + ): + L = int(prefix_len) + N = int(cand_len) + P = int(cached_prefix_len) + ext_len = max(L - P, 0) + N vals_i = token_values[offset : offset + ext_len] offset += ext_len - L = int(prefix_lens[i]) - N = int(cand_lens[i]) - P = int(cached_prefix_lens[i]) cand_offset = max(L - P, 0) out.append(vals_i[cand_offset : cand_offset + N]) return EmbeddingPoolerOutput(embeddings=out) diff --git a/sglang-LenVM/test/srt/test_lvm_guided_sampling_fast_path.py b/sglang-LenVM/test/srt/test_lvm_guided_sampling_fast_path.py new file mode 100644 index 0000000..1f51e68 --- /dev/null +++ b/sglang-LenVM/test/srt/test_lvm_guided_sampling_fast_path.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from types import SimpleNamespace + +import torch + + +def _module(name: str, **attrs): + mod = types.ModuleType(name) + for key, value in attrs.items(): + setattr(mod, key, value) + return mod + + +def _load_lvm_guided_sampling(monkeypatch): + """Load lvm_guided_sampling.py with light stubs for heavyweight SGLang deps.""" + + stubs = { + "sglang": _module("sglang"), + "sglang.srt": _module("sglang.srt"), + "sglang.srt.sampling": _module("sglang.srt.sampling"), + "sglang.srt.sampling.sampling_params": _module( + "sglang.srt.sampling.sampling_params", TOP_K_ALL=-1 + ), + "sglang.srt.utils": _module("sglang.srt.utils"), + "sglang.srt.utils.common": _module( + "sglang.srt.utils.common", dynamic_import=lambda spec: None + ), + "sglang.srt.server_args": _module( + "sglang.srt.server_args", 
get_global_server_args=lambda: SimpleNamespace() + ), + "sglang.srt.lvm": _module("sglang.srt.lvm"), + "sglang.srt.lvm.lvm_value_utils": _module( + "sglang.srt.lvm.lvm_value_utils", + force_eos_value_zero=lambda token_ids, token_values, req: None, + get_eos_token_ids=lambda req: set(), + ), + "sglang.srt.configs": _module("sglang.srt.configs"), + "sglang.srt.configs.model_config": _module( + "sglang.srt.configs.model_config", ModelConfig=object + ), + "sglang.srt.managers": _module("sglang.srt.managers"), + "sglang.srt.managers.schedule_batch": _module( + "sglang.srt.managers.schedule_batch", Req=object + ), + } + for name, mod in stubs.items(): + monkeypatch.setitem(sys.modules, name, mod) + + source = ( + Path(__file__).resolve().parents[2] + / "python" + / "sglang" + / "srt" + / "lvm" + / "lvm_guided_sampling.py" + ) + spec = importlib.util.spec_from_file_location( + "_lvm_guided_sampling_under_test", source + ) + assert spec is not None and spec.loader is not None + mod = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, spec.name, mod) + spec.loader.exec_module(mod) + return mod + + +def _req(rid: str, custom_params: dict | None = None): + if custom_params is None: + custom_params = {"value_scale": -100.0, "value_mode": "centered_exp"} + return SimpleNamespace( + rid=rid, + origin_input_ids=[11, 12], + output_ids=[21], + sampling_params=SimpleNamespace(custom_params=custom_params), + ) + + +def test_gpu_fast_path_keeps_candidate_ids_on_device_with_top_p(monkeypatch): + mod = _load_lvm_guided_sampling(monkeypatch) + sampler = mod.LvmGuidedSampler( + mod.LvmGuidedConfig(url=None, timeout=1.0, bypass_cache=False, fn_spec=None) + ) + sampler._fn = mod.lvm_combined_guidance + + probs = torch.tensor( + [ + [0.40, 0.25, 0.20, 0.10, 0.05, 0.0, 0.0, 0.0], + [0.35, 0.25, 0.20, 0.11, 0.09, 0.0, 0.0, 0.0], + ], + dtype=torch.float32, + ) + pending = sampler._build_pending( + probs=probs, + reqs=[_req("r0"), _req("r1")], + temperatures=torch.ones(2), 
+ top_ps=torch.tensor([0.75, 0.80], dtype=torch.float32), + top_ks=torch.tensor([5, 5], dtype=torch.int64), + min_ps=torch.zeros(2), + enable_gpu_candidate_compact=True, + ) + + assert pending is not None + assert pending.gpu_candidates is not None + assert pending.candidate_ids_send == [] + assert pending.candidate_probs_send == [] + assert pending.candidate_lens_send == [3, 4] + assert pending.prefix_ids_send == [[11, 12, 21], [11, 12, 21]] + + padded_probs, padded_ids, valid_mask = pending.gpu_candidates + assert valid_mask.sum(dim=1).tolist() == pending.candidate_lens_send + assert padded_ids[valid_mask].tolist() == [0, 1, 2, 0, 1, 2, 3] + assert torch.all(padded_probs[valid_mask] > 0) + + +def test_default_path_keeps_cpu_candidate_lists_for_fallback(monkeypatch): + mod = _load_lvm_guided_sampling(monkeypatch) + sampler = mod.LvmGuidedSampler( + mod.LvmGuidedConfig(url=None, timeout=1.0, bypass_cache=False, fn_spec=None) + ) + sampler._fn = mod.lvm_combined_guidance + + probs = torch.tensor( + [ + [0.40, 0.25, 0.20, 0.10, 0.05, 0.0, 0.0, 0.0], + [0.35, 0.25, 0.20, 0.11, 0.09, 0.0, 0.0, 0.0], + ], + dtype=torch.float32, + ) + pending = sampler._build_pending( + probs=probs, + reqs=[_req("r0"), _req("r1")], + temperatures=torch.ones(2), + top_ps=torch.tensor([0.75, 0.80], dtype=torch.float32), + top_ks=torch.tensor([5, 5], dtype=torch.int64), + min_ps=torch.zeros(2), + ) + + assert pending is not None + assert pending.gpu_candidates is None + assert pending.candidate_lens_send is None + assert pending.candidate_ids_send == [[0, 1, 2], [0, 1, 2, 3]] + assert [len(p) for p in pending.candidate_probs_send] == [3, 4] + + +def test_unmodified_rows_keep_sampling_filters_when_guidance_applies(monkeypatch): + mod = _load_lvm_guided_sampling(monkeypatch) + sampler = mod.LvmGuidedSampler( + mod.LvmGuidedConfig(url=None, timeout=1.0, bypass_cache=False, fn_spec=None) + ) + sampler._fn = mod.lvm_combined_guidance + + probs = torch.tensor( + [ + [0.40, 0.30, 0.20, 0.10], + 
[0.10, 0.40, 0.30, 0.20], + ], + dtype=torch.float32, + ) + pending = sampler._build_pending( + probs=probs, + reqs=[_req("r0"), _req("r1", {})], + temperatures=torch.ones(2), + top_ps=torch.ones(2), + top_ks=torch.tensor([2, 2], dtype=torch.int64), + min_ps=torch.zeros(2), + ) + + assert pending is not None + assert pending.candidate_ids_send == [[0, 1]] + assert torch.allclose( + pending.guided[1], + torch.tensor([0.0, 0.40, 0.30, 0.0], dtype=torch.float32), + ) + + +def test_neutral_centered_exp_scale_skips_lvm(monkeypatch): + mod = _load_lvm_guided_sampling(monkeypatch) + sampler = mod.LvmGuidedSampler( + mod.LvmGuidedConfig(url=None, timeout=1.0, bypass_cache=False, fn_spec=None) + ) + sampler._fn = mod.lvm_combined_guidance + + probs = torch.tensor([[0.40, 0.25, 0.20, 0.10, 0.05]], dtype=torch.float32) + pending = sampler._build_pending( + probs=probs, + reqs=[_req("r0", {"value_scale": 0.0, "value_mode": "centered_exp"})], + temperatures=torch.ones(1), + top_ps=torch.tensor([1.0], dtype=torch.float32), + top_ks=torch.tensor([5], dtype=torch.int64), + min_ps=torch.zeros(1), + enable_gpu_candidate_compact=True, + ) + + assert pending is None + + +def test_neutral_mul_scale_zero_skips_lvm(monkeypatch): + mod = _load_lvm_guided_sampling(monkeypatch) + sampler = mod.LvmGuidedSampler( + mod.LvmGuidedConfig(url=None, timeout=1.0, bypass_cache=False, fn_spec=None) + ) + sampler._fn = mod.lvm_combined_guidance + + probs = torch.tensor([[0.40, 0.25, 0.20, 0.10, 0.05]], dtype=torch.float32) + pending = sampler._build_pending( + probs=probs, + reqs=[_req("r0", {"value_scale": 0.0, "value_mode": "mul"})], + temperatures=torch.ones(1), + top_ps=torch.tensor([1.0], dtype=torch.float32), + top_ks=torch.tensor([5], dtype=torch.int64), + min_ps=torch.zeros(1), + enable_gpu_candidate_compact=True, + ) + + assert pending is None + + +def test_apply_without_guidance_does_not_initialize_inproc(monkeypatch): + mod = _load_lvm_guided_sampling(monkeypatch) + sampler = 
mod.LvmGuidedSampler( + mod.LvmGuidedConfig(url=None, timeout=1.0, bypass_cache=False, fn_spec=None) + ) + + def fail_if_called(): + raise AssertionError("in-proc provider should not initialize for baseline rows") + + sampler._get_inproc_provider = fail_if_called + + probs = torch.tensor([[0.40, 0.25, 0.20, 0.10, 0.05]], dtype=torch.float32) + out = sampler.apply( + probs=probs, + reqs=[_req("r0", {})], + temperatures=torch.ones(1), + top_ps=torch.tensor([1.0], dtype=torch.float32), + top_ks=torch.tensor([5], dtype=torch.int64), + min_ps=torch.zeros(1), + ) + + assert out is None + + +def test_gpu_guidance_handles_mixed_modes(monkeypatch): + mod = _load_lvm_guided_sampling(monkeypatch) + sampler = mod.LvmGuidedSampler( + mod.LvmGuidedConfig(url=None, timeout=1.0, bypass_cache=False, fn_spec=None) + ) + + padded_probs = torch.tensor( + [[0.60, 0.40], [0.55, 0.45]], dtype=torch.float32 + ) + padded_ids = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64) + valid_mask = torch.ones_like(padded_probs, dtype=torch.bool) + values = torch.tensor([[0.20, 0.80], [0.10, 0.90]], dtype=torch.float32) + gpu_embeddings = torch.logit(values) + + pending = mod.PendingLvmResult( + req_list=[ + _req("r0", {"value_scale": 1.0, "value_mode": "centered_exp"}), + _req("r1", {"value_scale": 0.0, "value_mode": "mul"}), + ], + device=torch.device("cpu"), + guided=torch.zeros((2, 4), dtype=torch.float32), + send_batch_indices=[0, 1], + prefix_ids_send=[], + candidate_ids_send=[], + candidate_probs_send=[], + gpu_candidates=(padded_probs, padded_ids, valid_mask), + ) + + sampler._apply_guidance_gpu(pending, gpu_embeddings) + + expected_row0 = padded_probs[0] * torch.exp(values[0]) + expected_row0 = expected_row0 / expected_row0.sum() + assert torch.allclose(pending.guided[0, :2], expected_row0, atol=1e-6) + assert torch.allclose(pending.guided[1, :2], padded_probs[1], atol=1e-6) diff --git a/sglang-LenVM/test/srt/test_tree_value_spec.py b/sglang-LenVM/test/srt/test_tree_value_spec.py new file 
mode 100644
index 0000000..e1ccd07
--- /dev/null
+++ b/sglang-LenVM/test/srt/test_tree_value_spec.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import importlib.util
+import sys
+import types
+from pathlib import Path
+
+import torch
+
+
+def _load_tree_value_spec(monkeypatch):  # load tree_value_spec.py by file path with stub sglang modules in sys.modules
+    spec_info = types.ModuleType("sglang.srt.speculative.spec_info")
+
+    class SpecInput:  # minimal stand-in: only stores spec_input_type
+        def __init__(self, spec_input_type):
+            self.spec_input_type = spec_input_type
+
+    class SpecInputType:
+        EAGLE_VERIFY = 1
+
+    spec_info.SpecInput = SpecInput
+    spec_info.SpecInputType = SpecInputType
+
+    attention_utils = types.ModuleType("sglang.srt.layers.attention.utils")
+
+    class FakeCreateKvIndicesKernel:  # pure-PyTorch stand-in; `kernel[grid]` returns the `launch` callable
+        def __getitem__(self, _grid):  # the grid argument is ignored
+            def launch(
+                req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                _kv_start_idx,  # unused by this fake
+                kv_indices,
+                _stride,  # unused by this fake
+            ):
+                for i in range(req_pool_indices.numel()):  # one slice of kv_indices per request
+                    start = int(kv_indptr[i].item())
+                    end = int(kv_indptr[i + 1].item())
+                    pool_idx = int(req_pool_indices[i].item())
+                    kv_indices[start:end] = req_to_token[  # first paged_kernel_lens[i] entries of that pool row
+                        pool_idx, : int(paged_kernel_lens[i].item())
+                    ].to(kv_indices.dtype)
+
+            return launch
+
+    attention_utils.create_flashinfer_kv_indices_triton = FakeCreateKvIndicesKernel()
+
+    for name in (  # empty placeholder modules for the parent package names
+        "sglang",
+        "sglang.srt",
+        "sglang.srt.layers",
+        "sglang.srt.layers.attention",
+        "sglang.srt.speculative",
+    ):
+        monkeypatch.setitem(sys.modules, name, types.ModuleType(name))
+    monkeypatch.setitem(sys.modules, "sglang.srt.speculative.spec_info", spec_info)
+    monkeypatch.setitem(
+        sys.modules, "sglang.srt.layers.attention.utils", attention_utils
+    )
+
+    source = (  # resolved relative to this test file: parents[2] / python/sglang/srt/lvm/tree_value_spec.py
+        Path(__file__).resolve().parents[2]
+        / "python"
+        / "sglang"
+        / "srt"
+        / "lvm"
+        / "tree_value_spec.py"
+    )
+    spec = importlib.util.spec_from_file_location("_tree_value_spec_under_test", source)
+    assert spec is not None and spec.loader is not None
+    mod = importlib.util.module_from_spec(spec)
+    monkeypatch.setitem(sys.modules, spec.name, mod)  # registered in sys.modules before the module body executes
+    spec.loader.exec_module(mod)
+    return mod
+
+
+def test_fused_prefix_and_candidate_mask(monkeypatch):  # checks the mask/positions builder for a single request
+    mod = _load_tree_value_spec(monkeypatch)
+
+    custom_mask, positions = mod.build_tree_value_custom_mask_and_positions(
+        prefix_lens=[4],
+        candidate_lens=[3],
+        cached_prefix_lens=[3],  # NOTE(review): presumably 4 - 3 = 1 prefix token still needs extending - confirm
+        device=torch.device("cpu"),
+    )
+
+    assert positions.tolist() == [3, 4, 4, 4]
+    assert custom_mask.view(4, 7).int().tolist() == [  # 4 rows x 7 columns; rows 1-3 add only their own column (4, 5, 6)
+        [1, 1, 1, 1, 0, 0, 0],
+        [1, 1, 1, 1, 1, 0, 0],
+        [1, 1, 1, 1, 0, 1, 0],
+        [1, 1, 1, 1, 0, 0, 1],
+    ]
+
+
+def test_tree_value_spec_generates_flashinfer_args(monkeypatch):  # checks generate_attn_arg_prefill outputs with the fake kernel
+    mod = _load_tree_value_spec(monkeypatch)
+    custom_mask, positions = mod.build_tree_value_custom_mask_and_positions(
+        prefix_lens=[4],
+        candidate_lens=[3],
+        cached_prefix_lens=[3],
+        device=torch.device("cpu"),
+    )
+    spec_info = mod.TreeValueSpecInput(
+        custom_mask=custom_mask,
+        positions=positions,
+        tree_value_prefix_lens=[4],
+        tree_value_candidate_lens=[3],
+        tree_value_cached_prefix_lens=[3],
+    )
+
+    req_to_token = torch.arange(20, dtype=torch.int64).view(2, 10)  # row 0 holds 0..9, row 1 holds 10..19
+    kv_indices, kv_indptr, qo_indptr, returned_mask = (
+        spec_info.generate_attn_arg_prefill(
+            req_pool_indices=torch.tensor([1], dtype=torch.int64),  # selects row 1 of req_to_token
+            paged_kernel_lens=torch.tensor([7], dtype=torch.int32),
+            paged_kernel_lens_sum=7,
+            req_to_token=req_to_token,
+        )
+    )
+
+    assert qo_indptr.tolist() == [0, 4]
+    assert kv_indptr.tolist() == [0, 7]
+    assert kv_indices.tolist() == [10, 11, 12, 13, 14, 15, 16]  # first 7 entries of pool row 1
+    assert returned_mask is custom_mask  # identity check: the same tensor object is returned