
fix: Skip cudagraph capture at prefillWarmup stage #805

Open

bppps wants to merge 1 commit into alibaba:main from bppps:fix/cudagraph_with_linear_attn

Conversation

@bppps
Collaborator

@bppps bppps commented Mar 20, 2026

Problems:

NormalEngine initialization goes through the warmup -> initCacheManager -> initExecutor procedure shown below.

// excerpt from NormalEngine initialization; the first statements run inside the warmup branch
    warm_up_result = warmUp(params);
    RTP_LLM_LOG_INFO(
        "warm up done, max runtime used memory: %ld bytes (%ld MiB), device reserved memory: %ld bytes (%ld MiB)",
        warm_up_result->max_used_memory,
        warm_up_result->max_used_memory / 1024 / 1024,
        warm_up_result->device_reserved_bytes,
        warm_up_result->device_reserved_bytes / 1024 / 1024);
} else {
    RTP_LLM_LOG_INFO("skip warm up.");
}
initCacheManager(warm_up_result);
RTP_LLM_LOG_INFO("create cache manager done");
initExecutor(params, propose_params_);

At the warmup stage, the KVCacheManager is set to nullptr during NormalExecutor initialization, so PyModelWrapper is created with no KVCache enabled. If rtp_llm is launched with cudagraph enabled, it will then try to capture the decode graph during prefill warmup.

executor_.reset(new NormalExecutor(params, nullptr, device_, nullptr, true));

// kv_cache_layer_layout argument passed when constructing the Python model wrapper:
cache_manager ? std::make_optional(is_propose_ ? cache_manager->getMTPModuleCacheLayerLayout(propose_model_index_)
                                                : cache_manager->getMainModelCacheLayerLayout())
              : std::nullopt,

Moreover, for hybrid attention models such as Qwen35, the KV cache is now strictly required by default:

assert kv_cache is not None, "kv_cache is required for decode"

launch command:

python3 -m rtp_llm.start_server --checkpoint_path=/home/admin/qwen35_2b --model_type=qwen35_dense --max_seq_len=10240  --enable_cuda_graph=1

This results in the stack trace below:
what(): AssertionError: kv_cache is required for decode

At:
/home/admin/rtp-llm/rtp_llm/models_py/model_desc/qwen3_next.py(399): forward
/opt/conda310/lib/python3.10/site-packages/torch/nn/modules/module.py(1750): _call_impl
/opt/conda310/lib/python3.10/site-packages/torch/nn/modules/module.py(1739): _wrapped_call_impl
/home/admin/rtp-llm/rtp_llm/models_py/model_desc/qwen3_next.py(572): forward
/opt/conda310/lib/python3.10/site-packages/torch/nn/modules/module.py(1750): _call_impl
/opt/conda310/lib/python3.10/site-packages/torch/nn/modules/module.py(1739): _wrapped_call_impl
/home/admin/rtp-llm/rtp_llm/models_py/model_desc/qwen3_next.py(654): forward
/opt/conda310/lib/python3.10/site-packages/torch/nn/modules/module.py(1750): _call_impl
/opt/conda310/lib/python3.10/site-packages/torch/nn/modules/module.py(1739): _wrapped_call_impl
/home/admin/rtp-llm/rtp_llm/models_py/model_desc/qwen3_next.py(749): forward

Modifications:

This change skips cudagraph capture when kv_cache_layer is nullptr, which is the default during warmup. Note that this may introduce some inaccuracy into the warmup VRAM measurement; the error should be within an acceptable tolerance, and if the KVCacheManager ends up requiring more space, the problem will surface during engine setup.
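For reference, a minimal sketch of what the constructor-side skip could look like, based on the guard condition and warning log described in the reviews below; captureCudaGraphs and the surrounding control flow are placeholders, not the actual PyWrappedModel code:

// Sketch only: captureCudaGraphs and the surrounding structure are assumed;
// enable_cuda_graph_, cuda_graph_ready_, is_prefill_cuda_graph_mode and
// params.kv_cache_layer_layout are the names discussed in this PR.
bool can_capture = enable_cuda_graph_
                   && (is_prefill_cuda_graph_mode || params.kv_cache_layer_layout.has_value());
if (can_capture) {
    captureCudaGraphs(params);   // capture decode graphs as before
    cuda_graph_ready_ = true;    // mark the graphs usable only after capture succeeds
} else if (enable_cuda_graph_) {
    // warmup path: cache_manager is nullptr, so no kv_cache_layer_layout is available
    RTP_LLM_LOG_WARNING("skip cuda graph capture during warmup; "
                        "warmup VRAM measurement will not include cuda graph memory");
}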

@bppps bppps requested a review from LLLLKKKK as a code owner March 20, 2026 02:05
@LLLLKKKK
Collaborator

🤖 AI Code Review

Overview

Fixes a crash during the warmup stage for models such as Qwen35 when --enable_cuda_graph=1 is set. During warmup, cache_manager=nullptr leaves kv_cache_layer_layout as nullopt, and CUDA graph capture fails when it attempts a decode forward without a kv_cache.

Strengths

  • Precise fix: only 2 lines of core logic plus 1 defensive check are changed
  • Behavior verified correct on all call paths (NormalExecutor/EmbeddingExecutor/MtpExecutor)

Suggested improvements

P1 - Important

  1. Inaccurate warmup VRAM measurement may cause OOM: skipping CUDA graph capture means the peak memory measured during warmup is too low, so initCacheManager may over-allocate KV cache blocks and the real capture later may OOM. At minimum, add a warning log.
  2. Broken invariant when enable_cuda_graph_ is true but graph_runner_ is nullptr: future code that checks enable_cuda_graph_ and assumes graph_runner_ is valid would segfault. Either set enable_cuda_graph_ to false when skipping capture, or introduce a graph_ready_ flag.

P2 - Suggestions

  • The comment "required by linear attention" is inaccurate; this is a generic warmup-stage issue
  • The null check in forward() is correct, but a warning log would help debugging

Summary

The crash fix is correct and necessary. Recommend fixing the invariant issue (set enable_cuda_graph_=false) and adding a warning log about the VRAM measurement before merging.

@bppps bppps force-pushed the fix/cudagraph_with_linear_attn branch 3 times, most recently from 4a70e5d to c3343ee on March 20, 2026 03:30
…is passed nullptr but strongly required by linear attention in qwen35 based models

Signed-off-by: bppps <bpppsaka@gmail.com>
@bppps bppps force-pushed the fix/cudagraph_with_linear_attn branch from c3343ee to 071aae0 on March 20, 2026 03:34
@bppps
Collaborator Author

bppps commented Mar 20, 2026

Fixed per the review comments above.

@LLLLKKKK
Collaborator

🤖 AI Code Review (v2 — incremental review of PR #805 fix: Skip cudagraph capture at prefillWarmup stage)

Verdict: Approve — v2 changes address both P1 findings from the initial review.


Delta Summary (3d5afb7 → 071aae0)

Changed files: PyWrappedModel.h, PyWrappedModel.cc (+11/−3)

  • Introduces a cuda_graph_ready_ flag, cleanly separating "user requested CUDA graph" (enable_cuda_graph_) from "CUDA graph is actually usable" (cuda_graph_ready_)
  • forward() now branches on cuda_graph_ready_, eliminating the risk of dereferencing a null graph_runner_
  • Adds an RTP_LLM_LOG_WARNING explicitly noting that skipping CUDA graph capture during warmup may make the VRAM measurement too low

Status of v1 findings

  • P1 — enable_cuda_graph_=true with graph_runner_=nullptr causes a null-pointer dereference: ✅ Fixed — cuda_graph_ready_ guarantees graph_runner_ is non-null whenever it is true
  • P1 — Warmup VRAM measurement excludes CUDA graph memory and may lead to OOM: ⚠️ Mitigated — a warning log was added; a fundamental fix (reserving the memory) is out of scope for this PR
  • P2 — Misleading comment ("required by linear attention"): ✅ Fixed
  • P2 — Silent fallback in forward() without diagnostics: ✅ Fixed
  • P3 — Constructor uses the parameter instead of the member variable (style issue): ℹ️ Not addressed, non-blocking

New issues in v2

None.


Correctness verification

  • Warmup (cache_manager=nullptr): cuda_graph_ready_=false, graph_runner_=nullptr — ✅ correct (the case this PR fixes)
  • Normal inference (cache_manager valid): cuda_graph_ready_=true, graph_runner_ valid — ✅ correct
  • enable_cuda_graph=false: cuda_graph_ready_=false, graph_runner_=nullptr — ✅ correct

Suggested follow-up

  • Add unit-test coverage for the warmup + cuda_graph scenario

Generated by AI Code Review Bot | Review Detail

@LLLLKKKK
Collaborator

LLLLKKKK commented Apr 9, 2026

🤖 AI Code Review — PR #805
Head SHA: 071aae029fde34a92a6187a1b5c06ecd443ac33f | Verdict: LGTM

Summary

Fixes a crash/incorrect behavior during the prefill warmup phase where CUDA graph capture was attempted without a valid kv_cache_layer_layout (because cache_manager is nullptr during warmup). The fix introduces a cuda_graph_ready_ flag that is only set to true after successful graph capture, and guards the forward path with this flag instead of the config-level enable_cuda_graph_.

Findings

No significant issues found. The change is well-scoped:

  1. The new guard condition enable_cuda_graph_ && (is_prefill_cuda_graph_mode || params.kv_cache_layer_layout.has_value()) correctly skips capture when the layout is unavailable.
  2. The warning log about VRAM measurement being potentially lower (leading to over-allocation of KV cache blocks) is a helpful operational note.
  3. The cuda_graph_ready_ flag cleanly separates "user wants cuda graph" from "cuda graph is actually captured and usable."
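
For illustration, a sketch of the forward-path guard described in finding 3; the member names cuda_graph_ready_, graph_runner_ and enable_cuda_graph_ come from the PR, while ModelInputs/ModelOutputs and forwardEager are placeholders, not the actual API:

// Sketch, not the actual PyWrappedModel::forward implementation.
ModelOutputs PyWrappedModel::forward(const ModelInputs& inputs) {
    if (cuda_graph_ready_ && graph_runner_) {
        // graphs were captured successfully: replay the captured decode graph
        return graph_runner_->forward(inputs);
    }
    if (enable_cuda_graph_) {
        // requested but never captured (e.g. during warmup): make the fallback visible
        RTP_LLM_LOG_WARNING("cuda graph requested but not captured, using eager forward");
    }
    return forwardEager(inputs);  // eager fallback for warmup / prefill
}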

Clean fix, no concerns.

@LLLLKKKK
Collaborator

🤖 AI Code Review — PR #805

PR Overview

Title: fix: Skip cudagraph capture at prefillWarmup stage
Author: bppps
Scope: 2 files, +11/-3

Core goal

Fixes a startup crash for hybrid attention models such as Qwen3.5 when enable_cuda_graph=1 is set. During warmup the KVCacheManager is nullptr, so the PyWrappedModel constructor crashes when it tries to capture the CUDA graph without a KV cache. This PR skips CUDA graph capture during warmup by checking whether kv_cache_layer_layout is available, and introduces a cuda_graph_ready_ flag so that forward() never uses a graph that was not captured.


Review comments

Issues

  1. Warmup VRAM measurement excludes CUDA graph memory and may lead to over-allocation of KV cache blocks [P2]

    The author already calls out this risk in the WARNING log. With CUDA graph capture skipped during warmup, the measured peak VRAM does not include the memory the CUDA graph will occupy, so the number of available KV cache blocks that initCacheManager derives from that measurement may be too high. Consider reserving an estimate of the CUDA graph memory in initCacheManager as a follow-up.

  2. Missing test coverage [P2]

    This PR fixes a startup crash but adds no corresponding test. Consider adding at least one Qwen3.5 smoke-test case with enable_cuda_graph=1.

Overall assessment

A clean, focused bug-fix PR. The root-cause analysis is clear, the fix is reasonable, and the WARNING log documents the remaining risk well. The PR description is high quality.

LGTM ready to ci — this review found no blocking or major issues, so the PR can proceed to CI verification and merge; the P2 suggestions are non-blocking follow-ups.

@LLLLKKKK
Collaborator

AI Code Review — PR #805

Summary: P0/0 · P1/0 · P2/0 · P3/0

Review status: LGTM

lgtm ready to ci

Strengths

  • Correctly prevents null dereference on graph_runner_ when CUDA graph capture is skipped during warmup phase
  • Warning log clearly documents the VRAM measurement limitation so operators know warmup numbers undercount peak usage
  • Minimal, focused change — new boolean flag cleanly separates 'user wants graphs' from 'graphs are actually ready'
