test: add a diagnostic script for prefix caching nan #1987
Conversation
Signed-off-by: Terry Kong <terryk@nvidia.com>
📝 Walkthrough

Documentation updated to include a new diagnostic section for prefix caching NaN logprobs validation. A new Python script is added to tools/model_diagnostics/ that reproduces and validates prefix caching behavior in vLLM, including multi-iteration generation with prefix cache reuse and NaN logprob detection.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 inconclusive
Actionable comments posted: 3
🧹 Nitpick comments (2)
tools/model_diagnostics/5.prefix_caching_nan.py (2)
75-84: Unconditional `break` silently under-counts NaNs if `logprobs > 1`.

The `break` at line 84 exits after inspecting only the first token-id entry per step, regardless of whether a NaN was found. With `logprobs=1` and `temperature=0.0` (greedy), each step's dict has exactly one entry, so this is functionally correct today. However, vLLM returns up to `logprobs+1` elements per step, meaning if `logprobs` is ever bumped above 1, the counter would silently undercount NaN occurrences (at most 1 per step). Removing the `break` makes the intent clear and future-proof:

♻️ Proposed fix
```diff
 for _tid, lp_obj in step.items():
     lp = lp_obj.logprob if hasattr(lp_obj, "logprob") else lp_obj
     if isinstance(lp, float) and math.isnan(lp):
         nan_count += 1
-    break
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/model_diagnostics/5.prefix_caching_nan.py` around lines 75 - 84, The loop over out2.logprobs currently contains an unconditional break after inspecting the first token-id entry, which causes under-counting NaNs when a step contains multiple entries; remove the break so the inner loop over step.items() examines every lp_obj (keep existing hasattr(lp_obj, "logprob") check and NaN detection for lp) so nan_count increments for every NaN in all token-id entries rather than just the first one per step.
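As a quick standalone check of the fixed logic, here is a sketch against a synthetic logprobs structure mirroring vLLM's shape (a list of per-step dicts mapping token id to an object carrying `.logprob`); the `Logprob` stand-in below is a stub, not vLLM's class:

```python
import math
from types import SimpleNamespace

Logprob = SimpleNamespace  # stub standing in for vLLM's per-token logprob object

# One step whose *second* entry is NaN, as can occur with logprobs > 1.
steps = [
    {101: Logprob(logprob=-0.25), 102: Logprob(logprob=float("nan"))},
]

nan_count = 0
for step in steps:
    for _tid, lp_obj in step.items():  # no break: every entry is inspected
        lp = lp_obj.logprob if hasattr(lp_obj, "logprob") else lp_obj
        if isinstance(lp, float) and math.isnan(lp):
            nan_count += 1

print(nan_count)  # 1; with the unconditional break, only token 101 would
                  # be inspected and the NaN would be missed entirely
```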
35-41: Consider consolidating under `if __name__ == "__main__":` and moving imports to the top.

All module-level logic (argparse, LLM instantiation, generation) runs unconditionally on import. A `__main__` guard is the standard protection against accidental execution when scripts are discovered by tooling. Additionally, the `vllm` import (lines 40–41) appears mid-file after `parse_args()` — while this speeds up `--help`, it diverges from PEP 8 and can surprise readers.

♻️ Suggested structure
```diff
 import argparse
 import math
+from vllm import LLM, SamplingParams
+import vllm

 MODEL = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
 ...
-parser = argparse.ArgumentParser()
-parser.add_argument("--model", type=str, default=MODEL)
-parser.add_argument("--tp", type=int, default=TP)
-args = parser.parse_args()
-
-import vllm
-from vllm import LLM, SamplingParams
-
-print(f"vLLM version: {vllm.__version__}")
-...
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default=MODEL)
+    parser.add_argument("--tp", type=int, default=TP)
+    args = parser.parse_args()
+
+    print(f"vLLM version: {vllm.__version__}")
+    ...
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/model_diagnostics/5.prefix_caching_nan.py` around lines 35 - 41, Move all runtime logic (argparse setup and calls using parser/args, LLM instantiation, and generation) under a guarded block: wrap code that calls parser.parse_args(), creates the vllm LLM and SamplingParams, and runs generation inside if __name__ == "__main__":. Also relocate imports (import vllm and from vllm import LLM, SamplingParams) to the top of the file with other imports to follow PEP8; keep only lightweight module-level constants like MODEL and TP outside the guard and ensure no heavy side-effect code runs at import time.
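In general terms, the suggested structure boils down to the following sketch; the `main()` function and the short generation at the end are illustrative, not the script's actual workload:

```python
import argparse

import vllm
from vllm import LLM, SamplingParams

# Lightweight constants may stay at module scope.
MODEL = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
TP = 2


def main() -> None:
    # All side-effecting logic lives here, so importing the module is safe.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=MODEL)
    parser.add_argument("--tp", type=int, default=TP)
    args = parser.parse_args()

    print(f"vLLM version: {vllm.__version__}")
    llm = LLM(model=args.model, tensor_parallel_size=args.tp)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
    out = llm.generate(["Hello"], sampling_params)[0].outputs[0]
    print(out.text)


if __name__ == "__main__":
    main()
```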
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/adding-new-models.md`:
- Around line 330-344: The two fenced code blocks under "Expected pass output
(vLLM 0.13.0)" and "Expected fail output (vLLM 0.14.0)" are missing language
specifiers causing markdownlint MD040; update both fences to include a language
(e.g., use ```text) so the pass-output and fail-output blocks explicitly start
with ```text and end with ``` to satisfy MD040 and preserve formatting.
In `@tools/model_diagnostics/5.prefix_caching_nan.py`:
- Line 87: The print statement currently uses an unnecessary f-string: locate
the call print(f"\n Sample logprobs from iteration 2:") and remove the leading
"f" so it becomes a plain string literal; this eliminates the Ruff F541 spurious
f-string warning without changing behavior.
- Around line 29-75: The module defines several module-level mutable bindings
(parser, args, numbers, prompt, llm, sampling_params, out1, out2, nan_count)
which violate the global naming guideline; wrap all runtime code that creates or
mutates these symbols inside an if __name__ == "__main__": block so they become
local to main (keep MODEL, TP, MAX_TOKENS, MAX_MODEL_LEN, COUNT_UP_TO and
imports at module scope), e.g., move creation of
argparse.ArgumentParser()/parser, args = parser.parse_args(), numbers, prompt
construction, LLM() instantiation, SamplingParams(), the two generate calls that
produce out1/out2, and nan_count into that guard; alternatively if you must keep
any of them global, rename using upper snake_case with the G_ prefix (e.g.,
G_PARSER, G_LLM) to satisfy the guideline.
---
Nitpick comments:
In `@tools/model_diagnostics/5.prefix_caching_nan.py`:
- Around line 75-84: The loop over out2.logprobs currently contains an
unconditional break after inspecting the first token-id entry, which causes
under-counting NaNs when a step contains multiple entries; remove the break so
the inner loop over step.items() examines every lp_obj (keep existing
hasattr(lp_obj, "logprob") check and NaN detection for lp) so nan_count
increments for every NaN in all token-id entries rather than just the first one
per step.
- Around line 35-41: Move all runtime logic (argparse setup and calls using
parser/args, LLM instantiation, and generation) under a guarded block: wrap code
that calls parser.parse_args(), creates the vllm LLM and SamplingParams, and
runs generation inside if __name__ == "__main__":. Also relocate imports (import
vllm and from vllm import LLM, SamplingParams) to the top of the file with other
imports to follow PEP8; keep only lightweight module-level constants like MODEL
and TP outside the guard and ensure no heavy side-effect code runs at import
time.
Expected pass output (vLLM 0.13.0):

```
Iteration 1 — prompt length: 13990 chars
  tokens: 2048, finish_reason: length
  text (first 100): '3001 3002 3003 3004 3005 3006 3007 3008 3009 3010 3011 3012 3013 3014 3015 3016 3017 3018 3019 3020 '

Iteration 2 — prompt length: 16038 chars
  tokens: 2048, finish_reason: length
  text (first 100): '1 3412 3413 3414 3415 3416 3417 3418 3419 3420 3421 3422 3423 3424 3425 3426 3427 3428 3429 3430 343'

[nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16] ALL GOOD!
```

Expected fail output (vLLM 0.14.0):
Add language specifiers to the two output code blocks (markdownlint MD040).
Both the pass-output and fail-output fenced code blocks are missing a language identifier, triggering MD040.
🐛 Proposed fix
````diff
 Expected pass output (vLLM 0.13.0):
-```
+```text
 Iteration 1 — prompt length: 13990 chars
 ...
 Expected fail output (vLLM 0.14.0):
-```
+```text
 Iteration 1 — prompt length: 13990 chars
 ...
````

🧰 Tools
🪛 markdownlint-cli2 (0.21.0)

[warning] 331-331: Fenced code blocks should have a language specified (MD040, fenced-code-language)
[warning] 344-344: Fenced code blocks should have a language specified (MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/adding-new-models.md` around lines 330 - 344, The two fenced code blocks
under "Expected pass output (vLLM 0.13.0)" and "Expected fail output (vLLM
0.14.0)" are missing language specifiers causing markdownlint MD040; update both
fences to include a language (e.g., use ```text) so the pass-output and
fail-output blocks explicitly start with ```text and end with ``` to satisfy
MD040 and preserve formatting.
```python
MODEL = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
TP = 2
MAX_TOKENS = 2048
MAX_MODEL_LEN = 32768
COUNT_UP_TO = 3000

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=MODEL)
parser.add_argument("--tp", type=int, default=TP)
args = parser.parse_args()

import vllm
from vllm import LLM, SamplingParams

print(f"vLLM version: {vllm.__version__}")

numbers = " ".join(str(i) for i in range(1, COUNT_UP_TO + 1))
prompt = (
    "You are a counting assistant. Output ONLY numbers separated by spaces.\n\n"
    f"User: Continue counting: {numbers} "
)

llm = LLM(
    model=args.model,
    tensor_parallel_size=args.tp,
    enable_prefix_caching=True,
    max_model_len=MAX_MODEL_LEN,
    gpu_memory_utilization=0.90,
    trust_remote_code=True,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=MAX_TOKENS, logprobs=1)

# Iteration 1: initial generation (builds the prefix cache)
print(f"\nIteration 1 — prompt length: {len(prompt)} chars")
out1 = llm.generate([prompt], sampling_params)[0].outputs[0]
print(f"  tokens: {len(out1.token_ids)}, finish_reason: {out1.finish_reason}")
print(f"  text (first 100): {out1.text[:100]!r}")

# Iteration 2: extend prompt with prior output (triggers prefix cache reuse)
prompt += out1.text
print(f"\nIteration 2 — prompt length: {len(prompt)} chars")
out2 = llm.generate([prompt], sampling_params)[0].outputs[0]
print(f"  tokens: {len(out2.token_ids)}, finish_reason: {out2.finish_reason}")
print(f"  text (first 100): {out2.text[:100]!r}")

# Check for NaN logprobs
nan_count = 0
```
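The hunk above ends at `nan_count = 0`; judging from the proposed-fix diff in the first nitpick, the check presumably continues along these lines (a reconstruction from the review fragments, not the verbatim file):

```python
# Reconstructed from the review fragments above, not copied from the file.
for step in out2.logprobs:
    for _tid, lp_obj in step.items():
        lp = lp_obj.logprob if hasattr(lp_obj, "logprob") else lp_obj
        if isinstance(lp, float) and math.isnan(lp):
            nan_count += 1
        break  # the unconditional break the nitpick proposes removing
```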
Module-level mutable variables are missing the G_ prefix required by coding guidelines.
parser, args, numbers, prompt, llm, sampling_params, out1, out2, and nan_count are all module-scope mutable bindings. As per coding guidelines, global variables should use upper snake_case with a G_ prefix (e.g., G_PARSER, G_LLM). The constants MODEL, TP, etc. are already correctly named.
The cleanest resolution is the if __name__ == "__main__": refactor suggested above — that would demote all of these to local variables, eliminating the naming concern entirely.
As per coding guidelines: "Use upper snake_case with G prefix for global variables, e.g., G_MY_GLOBAL"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tools/model_diagnostics/5.prefix_caching_nan.py` around lines 29 - 75, The
module defines several module-level mutable bindings (parser, args, numbers,
prompt, llm, sampling_params, out1, out2, nan_count) which violate the global
naming guideline; wrap all runtime code that creates or mutates these symbols
inside an if __name__ == "__main__": block so they become local to main (keep
MODEL, TP, MAX_TOKENS, MAX_MODEL_LEN, COUNT_UP_TO and imports at module scope),
e.g., move creation of argparse.ArgumentParser()/parser, args =
parser.parse_args(), numbers, prompt construction, LLM() instantiation,
SamplingParams(), the two generate calls that produce out1/out2, and nan_count
into that guard; alternatively if you must keep any of them global, rename using
upper snake_case with the G_ prefix (e.g., G_PARSER, G_LLM) to satisfy the
guideline.
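For illustration, the `G_`-prefix alternative mentioned at the end would look roughly like this (a sketch of the naming convention only, not a suggested refactor):

```python
import argparse

MODEL = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"  # constants keep plain upper snake_case
TP = 2

# Mutable module-level bindings carry the G_ prefix per the guideline.
G_PARSER = argparse.ArgumentParser()
G_PARSER.add_argument("--model", type=str, default=MODEL)
G_PARSER.add_argument("--tp", type=int, default=TP)
G_ARGS = G_PARSER.parse_args()
```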
```python
        break

if nan_count > 0:
    print(f"\n  Sample logprobs from iteration 2:")
```
Remove spurious f prefix — f-string with no placeholders (Ruff F541).
🐛 Proposed fix
```diff
-    print(f"\n  Sample logprobs from iteration 2:")
+    print("\n  Sample logprobs from iteration 2:")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    print("\n  Sample logprobs from iteration 2:")
```
🧰 Tools

🪛 Ruff (0.15.1)

[error] 87-87: f-string without any placeholders. Remove extraneous `f` prefix (F541)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tools/model_diagnostics/5.prefix_caching_nan.py` at line 87, The print
statement currently uses an unnecessary f-string: locate the call print(f"\n
Sample logprobs from iteration 2:") and remove the leading "f" so it becomes a
plain string literal; this eliminates the Ruff F541 spurious f-string warning
without changing behavior.
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
```
# Add a code snippet demonstrating how to use this
```
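A plausible invocation, inferred from the script's `--model` and `--tp` flags rather than stated in the PR description: `python tools/model_diagnostics/5.prefix_caching_nan.py --tp 2`.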
Before your PR is "Ready for review"

Pre checks:
Additional Information
Summary by CodeRabbit

Documentation
- Added a diagnostic section covering prefix caching NaN logprobs validation.

New Features
- Added a model-diagnostics script that reproduces prefix caching behavior in vLLM and checks generated logprobs for NaNs.