Skip to content

SwapAB + Async Load Cache#794

Closed
alibaba-miji wants to merge 5 commits intomainfrom
feature/support_qwen35_merge
Closed

SwapAB + Async Load Cache#794
alibaba-miji wants to merge 5 commits intomainfrom
feature/support_qwen35_merge

Conversation

@alibaba-miji
Copy link
Copy Markdown
Collaborator

@alibaba-miji alibaba-miji commented Mar 17, 2026

Summary

This PR includes three improvements targeting MoE inference performance and PD-separation KV cache write efficiency.

1. Support DeepGEMM Swap-AB Optimization

Add swap-AB mode for DeepGEMM FP8 GEMM kernels on SM90 (Hopper) GPUs. For small-M scenarios (common in decode phase), swapping the A/B operands yields better tiling and higher throughput.

Changes:

  • New C++ plugin ops: deep_gemm_fp8 (normal GEMM) and deep_gemm_grouped_fp8_masked (grouped masked GEMM) that handle quantization + padding + GEMM in a single call.
  • Config flag deep_gemm_use_swap_ab (default true) with runtime capability check (SM90 only).
  • Python wrappers: fp8_gemm_nt_swapab and m_grouped_fp8_gemm_nt_masked_swapab.

2. Split FakeBalanceExpert into a Standalone Op

Decouple the FakeBalanceExpert logic from SelectTopkOp into its own independent operator.

Changes:

  • New FakeBalanceExpertOp (C++ + pybind): takes expert_ids and expert_scales as input.
  • SelectTopk is now pure top-k selection; balance rewriting happens as a separate, optional post-step.
  • Cleaner separation of concerns, easier to enable/disable independently.

3. Async KV Cache Write for PD Separation

Replace synchronous per-layer writeCacheStore calls with an async thread-pool-based writer, eliminating CPU stalls on the main forward thread.

Changes:

  • New CacheStoreAsyncWriter class: a lock-free thread pool (30 threads, queue size 10000) with strict lifecycle (init → submit* → waitAllDone).
  • DeviceBase gains initCacheStoreWrite(), submitAsyncCacheStoreTask(), waitCacheStoreComplete() APIs.
  • WriteCacheStoreOp now captures torch::Tensors by value (cheap refcount bump) and submits work to the async writer instead of running inline.
  • PyWrappedModel::forward calls init before the layer loop and waitAllDone after — the main thread keeps launching CUDA kernels without blocking.

Architecture:
image

4. using torch::tensor instead of std::vectorstd::string to avoid value-copy

Test Plan

  • CacheStoreAsyncWriterTest — unit tests for lifecycle, concurrent submit, and exception propagation
  • SelectTopkOpTest — updated for the split op interface
  • DeepGemmMaskedExecutorTest — swap-AB grouped GEMM correctness
  • End-to-end PD separation inference with async cache write enabled (Use origin)

@alibaba-miji alibaba-miji requested a review from LLLLKKKK as a code owner March 17, 2026 12:00
}

torch::Tensor
deep_gemm_fp8(torch::Tensor lhs_bf16, torch::Tensor rhs_data, torch::Tensor rhs_scale, int user_deep_gemm_num_sm) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

外面已经判过了,里面不需要做这些封装了?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread rtp_llm/cpp/devices/DeviceBase.cc Outdated
@@ -148,7 +148,34 @@ void DeviceBase::updateCurrentTorchStream() {
}

void DeviceBase::setCacheStore(std::shared_ptr<rtp_llm::CacheStore> cache_store) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该不用从 device 上绕了吧

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那几个wrapper函数删掉了,但是unique_ptr还得在devicebase里,当全局变量用

{
std::lock_guard<std::mutex> ex_lock(exception_mutex_);
if (!stored_exception_) {
stored_exception_ = std::current_exception();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

线程安全

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch 3 times, most recently from 695c94d to 3d3bcfb Compare March 19, 2026 07:01
Copy link
Copy Markdown
Collaborator

@LLLLKKKK LLLLKKKK left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

两个独立特性打包:(1) SM90 小 M FP8 GEMM 的 swap_ab 优化;(2) CacheStoreAsyncWriter 异步写入线程池。

问题 1: cache_store_async_writer_ 缺少 null 检查 ⚠️

WriteCacheStoreOp.cc 中直接调用 device->cache_store_async_writer_->submit(...) 没有 null 检查。如果 pd_separation 为 true 但 setCacheStore 未被调用(或传入 nullptr),会通过 null unique_ptr 崩溃。

问题 2: 测试 MultipleTaskOrdering 存在 data race

测试向 3 线程池推送 20 个任务并断言严格顺序执行。多线程下执行顺序不确定,且多个线程对共享 std::vector 做无同步 push_back,是未定义行为。这个测试天然 flaky。

问题 3: .bazelrc 硬编码 PATH 覆盖

会清除用户/系统自定义 PATH(如自定义 CUDA toolkit 路径、ccache),可能在非标准环境下静默破坏构建。

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch 2 times, most recently from 000990d to 5f0a91c Compare March 19, 2026 15:15
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review

概述

四合一 PR:(1) DeepGEMM Swap-AB 优化(SM90 小 M 场景);(2) FakeBalanceExpert 从 SelectTopkOp 拆分为独立 Op;(3) PD 分离场景下 KV Cache 异步写入;(4) 格式化清理。+1915/-185, 38 files。

优点

  • Async CacheStoreWriter 设计清晰:状态机 IDLE↔RUNNING、异常传播、完整单元测试覆盖
  • FakeBalanceExpert 拆分提升了关注点分离,SelectTopk 变为纯 top-k 选择
  • Swap-AB 有 threshold 控制(normal GEMM M<32, grouped GEMM M<64),不会影响大 batch 路径
  • 新增 DeepGemmMaskedExecutorSwapAbTest 覆盖小 M 场景

建议改进

P1 - 重要

  1. HWKernelConfig Pickle 向后兼容性 (ConfigInit.cc)
    __setstate__if (t.size() != 18) throw 会导致旧版本序列化的 HWKernelConfig 反序列化失败。建议改为 if (t.size() < 17) throw,对 index 17 做条件判断赋默认值。

  2. SelectTopkOp pybind API 破坏性变更
    构造函数从 (ModelConfig, bool, int64, int64, int64) 变为 (ModelConfig)。外部直接调用旧 API 的代码会报错。建议在 PR description 中明确标注 breaking change。

  3. cache_store_async_writer_ 作为 public 成员暴露 (DeviceBase.h)
    PR description 提到了封装 API(initCacheStoreWrite() 等),但 forwardMicroBatched 直接访问 device_->cache_store_async_writer_->init(),而 forwarddevice_->initCacheStoreWrite()。两条路径风格不一致,建议统一使用封装方法,将成员改为 private。

  4. deepep_wrapper 边界条件变更 (deepep_wrapper.py:216)
    > 128 改为 >= 128,当 ll_num_max_token_per_rank == 128 时行为改变。请确认是否有意为之。

P2 - 建议

  1. 异步 writer 异常不中断后续任务 (CacheStoreAsyncWriter.cc)
    某个 task 抛异常后其他已提交 task 继续执行。如果某层 cache write 失败,后续层仍会写入可能导致数据不一致。建议 submit 时检查 stored_exception_ 做 fail-fast。

  2. 线程池硬编码 (CacheStoreAsyncWriter.cc)
    kThreadCount=30, kQueueSize=10000 硬编码。不同模型规模下最优值不同,建议通过配置可调。

  3. down_input_scale TMA 对齐行为变更 (deepgemm_masked_executor.py:292)
    get_mn_major_tma_aligned_tensor(down_input_scale) 在 if/else 之前调用,非 swap_ab 路径也使用了对齐后的 scale(原代码不做对齐)。请确认这不影响正确性。

  4. maga_server_manager.py 注释掉 TEST_UNDECLARED_OUTPUTS_DIR
    如果是临时调试不应提交;如果有意变更请说明原因。

总结

建议拆分为独立 PR 以降低 review 和回滚风险。最关键的是 HWKernelConfig pickle 兼容性和 SelectTopkOp API 变更需要处理。Async writer 和 Swap-AB 功能实现质量不错。

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review

概述

将 logprobs 功能从传输完整 vocab-size 概率向量优化为仅传输 top-k 结果,显著降低内存和带宽开销。改动覆盖了完整数据流路径(采样 -> stream -> RPC -> 渲染),并保留了 speculative decoding 所需的 all_probs 路径。

优点

  • 优化方向正确,将 logprobs 从 O(vocab_size) 降低到 O(top_k)
  • 数据流路径覆盖完整
  • 新增了 smoke 测试覆盖基本功能路径

建议改进

P1 - 重要

  1. selected token 的 logprob 可能为 -inf:当采样的 top_k(如 50)大于用户请求的 top_logprobs(如 3)时,被采样的 token 可能不在返回的 top-k 中,selected_logprob 会是 -inf,与 OpenAI API 规范不符。建议在 GPU 端额外检查 sampled token 是否在 topk 中,若不在则追加其 logprob。

  2. return_all_probs=true + top_logprobs=0 导致运行时崩溃NormalBatchStreamProcessor 中分配条件要求 max_top_logprobs > 0,但 NormalGenerateStream 中检查 return_all_probs 时会抛异常。通过非 OpenAI 路径设置 return_all_probs=true 但未设置 top_logprobs 时会触发。建议在 NormalBatchStreamProcessor 中当 return_all_probs=true && max_top_logprobs == 0 时自动设为 1。

  3. 非 flashinfer 采样路径缺少 topk logprobs 计算:当前非 flashinfer 路径最终都会调用 flashinferSampleGreedy 所以暂时安全,但这个隐含依赖值得注释说明。

P2 - 建议

  1. top_logprobs 缺少上界校验:用户可传入任意大值,可能导致 torch::topk 中 k 接近 vocab_size 甚至超出。OpenAI API 限制 0-20,建议添加上界校验。

  2. StreamUpdateInfo 使用位置初始化,字段顺序脆弱:初始化列表已有 11-13 个 nullptr,极易出错。建议改用命名初始化或 builder pattern。

  3. Speculative decoding 路径下 MtpBatchStreamProcessor 未适配 topk:仍只分配 all_probs,未分配 topk buffer。如果 spec decode 场景不需要 logprobs,建议添加注释说明。

总结

建议在合并前修复 selected token logprob 为 -inf 和 return_all_probs=true + top_logprobs=0 的边界条件问题。

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch from 5f0a91c to 98b6019 Compare March 20, 2026 01:52
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review (v2 — 增量审查)

概述

PR 更新后(98b6019c),上次 P0 编译阻塞问题已修复。四合一 PR 结构不变:DeepGEMM Swap-AB、FakeBalanceExpert 拆分、Async KV Cache Write、小改动。

上次问题跟踪

问题 状态
P0: initCacheStoreWrite() 不存在 已修复 — 两条 forward 路径统一使用 cache_store_async_writer_
P1: HWKernelConfig pickle 兼容性 未修复 — __setstate__ 仍硬检查 t.size() != 18
P1: SelectTopkOp API 变更 已妥善处理 — FakeBalanceExpert 拆分完整
P1: cache_store_async_writer_ public 部分修复 — 风格统一但仍 public
P1: down_input_scale TMA 对齐 未修复 — 非 swap_ab 路径行为变更

新发现

P1 - 重要

  1. WriteCacheStoreOp 异步捕获数据生命周期: PyCacheStoreInputs 按值捕获到 lambda 中提交给线程池。如果其中字段已从 vector<string> 迁移为 torch::Tensor,需确认 tensor 底层数据在异步执行期间不被释放。

P2 - 建议

  • swap_ab 路径的 m_grouped_fp8_gemm_nt_masked_swapab 缺少 disable_ue8m0_cast 参数,需确认 C++ plugin 内部处理
  • deep_gemm_fp8 padding 后 slice().contiguous() 可能产生不必要的拷贝

总结

P0 编译问题已修复,FakeBalanceExpert 拆分质量好。仍需关注 HWKernelConfig pickle 兼容性(P1)和 TMA 对齐行为变更确认。

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review (修正版 — 之前的评论误贴了其他 PR 的内容,非常抱歉)

概述

四合一 PR:(1) DeepGEMM Swap-AB 优化(SM90 小 M 场景);(2) FakeBalanceExpert 从 SelectTopkOp 拆分为独立 Op;(3) PD 分离场景下 KV Cache 异步写入;(4) 小改动。规模较大(+1915/-185, 38 files),建议拆分。

优点

  • DeepGEMM Swap-AB 功能完整,有性能测试覆盖
  • FakeBalanceExpert 拆分干净
  • CacheStoreAsyncWriter 设计合理,有完整单元测试

建议改进

P1 - 重要

  1. HWKernelConfig Pickle 向后兼容性: pickle tuple size 17→18,__setstate__ 硬检查 t.size() != 18 导致旧版本反序列化失败。建议改为 if (t.size() < 17) throw 并对 index 17 做条件判断。
  2. SelectTopkOp pybind API 破坏性变更: 构造函数签名变更,外部调用方会报错。建议在 PR description 标注 breaking change。
  3. cache_store_async_writer_ 作为 public 成员暴露: 两条 forward 路径风格不一致。建议统一使用封装方法并改为 private。
  4. deepep_wrapper 边界条件: > 128 改为 >= 128,行为变化需确认是否有意为之。

P2 - 建议

  • 异步 writer 异常不中断后续任务,建议 fail-fast
  • 线程池硬编码参数(30 线程、队列 10000),建议可配置

总结

建议修复 pickle 兼容性问题后合入。四合一 PR 建议未来拆分为独立 PR。

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

PR #794: SwapAB + Async Load Cache — Code Review v3

Head SHA: 3dda2773 (上次 v2: 98b6019c) | 规模: 39 files, +1931/-186, 4 commits


上次 P1 问题跟踪

# 问题 状态 说明
1 HWKernelConfig Pickle 向后兼容性 未修复 [P1] __setstate__ 仍用 t.size() != 18 硬判断,旧版本 tuple(size=17) 反序列化直接抛异常。建议改为 t.size() < 17 + 可选字段 fallback
2 down_input_scale TMA 对齐改变非 swap_ab 路径行为 未修复 [P1] get_mn_major_tma_aligned_tensor(down_input_scale) 在 if/else 之前调用,非 swap_ab 路径也传了 aligned 版本,与原行为不同
3 WriteCacheStoreOp 异步捕获数据生命周期 已修复 Commit 3dda2773 在主线程创建 CUDA event,tensor 按值捕获(refcount bump),runWriteCacheStore 提取为独立函数,生命周期安全

Commit 3dda2773 新增审查

WriteCacheStoreOp 主线程创建 event 的改进是正确的。 具体变更:

  • WriteCacheStoreOp.cc: 主线程调用 device->createEvent() 创建 event,通过 std::move 传入 lambda
  • OpData.h: CacheStoreInputs 新增 pre_created_event 字段
  • DeviceBase.cc: writeCacheStore 优先使用 pre_created_event,fallback 到 createEvent()(兼容同步路径)

新发现 [P2]: createEvent() 的 stream 关联 — CUDA event 本身是 stream-agnostic 的(record 时才绑定),所以主线程创建是安全的。但建议添加注释说明 event 的 record 时机。


仍需关注

  1. [P1] Pickle 兼容性 — 建议合入前修复。修复方式简单:
if (t.size() < 17) throw std::runtime_error("Invalid state!");
// ... existing 17 fields ...
if (t.size() > 17) c.deep_gemm_use_swap_ab = t[17].cast<bool>();
  1. [P1] down_input_scale TMA 对齐 — 如果对非 swap_ab 路径也是正确改进,请在代码注释中说明。否则应将 get_mn_major_tma_aligned_tensor 调用移入 swap_ab 分支内部。另外 from deep_gemm.utils.layout import ... 在循环体内部,建议移到文件顶部。

  2. [P2] swap_ab 路径缺少 disable_ue8m0_cast — Gate/Up GroupGEMM-0 的 swap_ab 路径调用 m_grouped_fp8_gemm_nt_masked_swapab 时未传 disable_ue8m0_cast,需确认 C++ plugin 内部是否默认处理。


Review v3 by CI bot. 上次 review: v2 (SHA 98b6019)

@LLLLKKKK LLLLKKKK force-pushed the feature/support_qwen35_merge branch from 3dda277 to de70370 Compare March 22, 2026 18:00
@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch from 5ca27ab to e34d812 Compare March 23, 2026 06:56
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #794 — SwapAB + Async Load Cache

概述

本 PR 包含四个改动:(1) DeepGEMM Swap-AB 优化(SM90 小 M decode 加速);(2) FakeBalanceExpert 拆分为独立 Op;(3) Async KV Cache Write(PD 分离场景消除 CPU 阻塞);(4) 格式化及小修。代码质量整体不错,测试覆盖较好。

优点

  • CacheStoreAsyncWriter 设计清晰,init/submit/waitAllDone 生命周期管理合理,异常传播机制完善
  • torch::Tensor 按值捕获到 lambda 中(引用计数 bump)是正确的异步安全模式
  • FakeBalanceExpert 解耦使 SelectTopk 职责更单一,便于独立开关
  • DeepGemm Plugin 测试覆盖全面(正确性 + 性能 + 三方对比 + decode 模拟)
  • pre_created_event 设计避免了后台线程的 cudaEventRecord 竞争

建议改进

P0 - Bug

  1. Pickle 兼容性:__setstate__ 硬检查 t.size() != 18 会拒绝旧版本数据

    ConfigInit.ccHWKernelConfig__setstate__t.size() != 17 改为 t.size() != 18,旧版本序列化的 tuple(size=17)在新版本反序列化时会直接抛异常。应改为容错式:

    if (t.size() < 17)
        throw std::runtime_error("Invalid state!");
    // ... 解析前 17 个字段 ...
    if (t.size() > 17)
        c.deep_gemm_use_swap_ab = t[17].cast<bool>();

P1 - 重要

  1. cache_store_async_writer_ 作为 public 成员暴露在 DeviceBase 上

    std::unique_ptr<CacheStoreAsyncWriter> 被声明为 public,破坏封装性。建议提供 DeviceBase 上的封装方法(initCacheStoreWrite() / submitAsyncCacheStoreTask() / waitCacheStoreComplete()),将成员移回 protected

  2. CacheStoreAsyncWriter 硬编码 30 线程 + 10000 队列

    在 CPU 核数较少的环境下 30 个线程会造成过度竞争。建议通过配置可调,或至少注释说明选择依据。

P2 - 建议

  1. maga_server_manager.py 注释掉 TEST_UNDECLARED_OUTPUTS_DIR 看起来是调试残留

    # bazel_outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", os.getcwd())
    bazel_outputs_dir = os.getcwd()

    TEST_UNDECLARED_OUTPUTS_DIR 是 bazel 测试标准环境变量,移除可能影响 CI 测试输出收集。

  2. deepep_wrapper.py 边界条件 >128>=128 行为变更未在 PR description 中说明

    ll_num_max_token_per_rank == 128 时行为发生变化,请确认是否有意为之。

  3. Mega-PR 建议拆分:SwapAB、FakeBalanceExpert 拆分、Async Cache Write 是 3 个独立功能,建议拆为独立 PR 便于回滚和 bisect。

  4. submit()state_mutex_ 期间调用 pushTask:如果未来 pushTask 改为阻塞,会导致 mutex 长时间被持有。建议确认 LockFreeThreadPool::pushTask 永远不会阻塞,或在 pushTask 前释放锁。

总结

三个核心功能(SwapAB、FakeBalanceExpert 解耦、Async Cache Write)设计合理,测试覆盖好。P0 的 pickle 兼容性问题必须修复,否则旧版本序列化数据无法在新版本反序列化。其余为封装性和可配置性改进建议。修复 P0 后可以合入。

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch from e34d812 to abac886 Compare March 24, 2026 07:50
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #794 SwapAB + Async Load Cache (v2 incremental, commit abac886)

概述

v2 修复了 PPU 环境下 WriteCacheStoreOp 因 async writer 未初始化而崩溃的 bug,改为 graceful fallback 到同步路径。v1 的 pickle 兼容性 P0 经确认为误报(内部 pybind struct),已撤回。

v2 新增问题

P2 - 建议

  1. WriteCacheStoreOp sync fallback 与 PyWrappedModel init/waitAllDone 不对称WriteCacheStoreOp 已做 nullptr 防御,但 PyWrappedModel.ccforwardMicroBatched/forward 仍有硬 CHECK。建议统一防御式编程风格。

v1 遗留问题(未修复)

P1 - 重要

  1. cache_store_async_writer_ public 成员暴露 — 建议封装方法 + 移回 protected
  2. 硬编码 30 线程 + 10000 队列 — 建议配置可调

P2 - 建议

  1. maga_server_manager.py 注释掉 TEST_UNDECLARED_OUTPUTS_DIR — 调试残留
  2. deepep_wrapper.py >128>=128 边界变更 — 未在 PR description 中说明

总结

v2 修复合理。v1 P0 已撤回。剩余 2 P1(public 成员 + 硬编码线程数)建议后续改进。

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #794 — SwapAB + Async Load Cache (v2 incremental)

概述

增量 review(v1: e34d812 → v2: abac886)。新增 1 个 commit:修复 PPU 环境下 WriteCacheStoreOp 因 async writer 未初始化而崩溃的问题,改为 graceful fallback 到同步路径。

v1 的 P0 pickle 兼容性问题经确认为误报(pybind struct 仅内部使用,不存在跨版本反序列化),已撤回。

优点

  • 新 commit 修复了一个实际 crash:PPU 环境未配置 cache store 时 WriteCacheStoreOp 会 CHECK 失败。fallback 到同步路径是正确的防御式编程。
  • CacheStoreAsyncWriter 有完整的单元测试覆盖(10 个 test case,包括异常传播、多轮 cycle、并发验证)。
  • SwapAB 和 FakeBalanceExpert 拆分都有对应的正确性测试。

建议改进

P1 - 重要

  1. cache_store_async_writer_ 作为 public 成员暴露在 DeviceBase 上
    std::unique_ptr<CacheStoreAsyncWriter> 被声明为 public,在 PyWrappedModel.ccWriteCacheStoreOp.cc 中直接访问。建议提供 DeviceBase 上的封装方法(initAsyncCacheWrite() / submitAsyncCacheWrite() / waitAsyncCacheWrite()),将成员移回 protected

  2. CacheStoreAsyncWriter 硬编码 30 线程 + 10000 队列
    kThreadCount = 30kQueueSize = 10000 在不同部署场景下可能不合适。建议通过配置可调,或至少在注释中说明选择依据。

P2 - 建议

  1. WriteCacheStoreOp sync fallback 与 PyWrappedModel init/waitAllDone 不对称
    WriteCacheStoreOp 现在优雅处理 cache_store_async_writer_ == nullptr,但 PyWrappedModel.cc 中仍有 RTP_LLM_CHECK_WITH_INFO(device_->cache_store_async_writer_ != nullptr, ...) 硬检查(被 pd_separation 守卫)。建议在 PyWrappedModel 中也加 nullptr 检查并跳过 init/waitAllDone,保持防御式编程风格一致。

  2. maga_server_manager.py 注释掉 TEST_UNDECLARED_OUTPUTS_DIR

    # bazel_outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", os.getcwd())
    bazel_outputs_dir = os.getcwd()

    看起来是调试残留,TEST_UNDECLARED_OUTPUTS_DIR 是 bazel 测试标准环境变量,移除可能影响 CI 测试输出收集。

  3. deepep_wrapper.py 边界条件 >128>=128
    ll_num_max_token_per_rank == 128 时行为变更(旧代码走 matched_tokens 匹配,新代码直接 round up 返回),PR description 未提及,请确认是否有意为之。

总结

v2 修复了一个实际 crash bug,修复方式合理。v1 的 P0 已撤回。剩余 P1 问题(public 成员暴露、硬编码线程数)是架构层面的改进建议,可在后续 PR 中处理。调试残留代码(P2 #4)建议在合并前清理。整体 LGTM with minor suggestions。

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch from abac886 to 92bf2e5 Compare March 24, 2026 08:08
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #794 — SwapAB + Async Load Cache (v3 incremental)

概述

v3 在 v2 基础上新增了大量 swap-AB 相关代码:C++ DeepGemmPlugin pybind op、Python wrapper 集成、FakeBalanceExpert 独立 Op 拆分、以及完善的 async cache writer 重构。38 个文件变更,测试覆盖显著改善。

优点

  • 测试覆盖大幅改善:新增 CacheStoreAsyncWriterTest.cpp(149 行,覆盖生命周期、异常传播、多轮循环)、DeepGemmMaskedExecutorSwapAbTest、936 行 perf+correctness 测试
  • FakeBalanceExpert 解耦干净:从 SelectTopkOp 拆出为独立 Op,职责清晰,ROCm 侧有 placeholder
  • WriteCacheStoreOp 重构合理:提取 runWriteCacheStore() 辅助函数,async/sync 路径清晰,nullptr fallback 防御式编程
  • Swap-AB dispatch 有明确阈值_SWAP_AB_M_THRESHOLD = 32(linear)和 _GROUPED_SWAP_AB_THRESHOLD = 64(masked)

建议改进

P1 - 重要

  1. SelectTopkOp 构造函数签名 breaking change
    pybind 注册从 SelectTopkOp(model_config, fake_balance_expert, dp_rank, dp_size, ep_size) 变为 SelectTopkOp(model_config)。如果有外部用户直接调用旧签名,升级后会报错。需确认无外部消费者。

  2. cache_store_async_writer_ 作为 public 成员暴露在 DeviceBase 上(v1 遗留)
    建议提供封装方法(initAsyncWriter(), submitAsyncWrite(), waitAsyncWrites()),将成员移回 protected

  3. CacheStoreAsyncWriter 硬编码 30 线程 + 10000 队列(v1 遗留)
    建议通过配置可调,或至少在注释中说明选择依据。

P2 - 建议

  1. init_swapab_once 全局状态非线程安全且不可重置
    模块级 _swap_ab_checked/_swap_ab_enabled 无锁保护,且一旦设置无法重置(影响单元测试)。建议用 threading.Lock 保护或在 docstring 中标注单线程约束。

  2. _current_deep_gemm_num_sms 在 context manager 外被读取
    m_grouped_fp8_gemm_nt_masked_swapab 直接读取全局 _current_deep_gemm_num_sms,不在 context manager 内时为 -1。需确认 C++ 侧对 -1 的处理。

  3. deepep_wrapper.pyenable_swapab() 可能在 init_swapab_once() 之前被调用
    当前启动流程中 init_swapab_onceinit_deepep_wrapper 之前,正常路径没问题。但其他入口点可能得到错误结果。建议在 enable_swapab() 中加 assert。

  4. deep_gemm.utils.layout import 在热路径循环内
    deepgemm_masked_executor.pyfrom deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor 在 per-chunk 循环体内。建议移到文件顶部或方法开头。

  5. WriteCacheStoreOp.ccPyCacheStoreInputs 按值捕获
    cache_keysstd::vector<std::string>)会 deep copy。如果 cache_keys 很大,可能有性能影响。考虑 std::shared_ptrstd::move

总结

v3 将 swap-AB 优化从概念落地到完整的 C++ plugin + Python 集成 + 测试,代码质量整体不错。P1 #1 需确认 SelectTopkOp 签名变更的兼容性,P1 #2/#3 是架构债务可后续改进。Conditional LGTM

@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #794 SwapAB + Async Load Cache (v3 incremental, commit 92bf2e5)

概述

v3 新增 commit 修复了 maga_server_manager.py 中注释掉的 TEST_UNDECLARED_OUTPUTS_DIR(v2 P2 问题),并新增 SelectTopkOp 构造函数签名变更(新增 fake_balance_expert 参数)。

v2 问题状态

问题 状态
cache_store_async_writer_ public 成员 [P1] 仍存在
硬编码 30 线程 + 10000 队列 [P1] 仍存在
maga_server_manager.py 调试残留 [P2] 已修复
deepep_wrapper.py 边界变更 [P2] 仍存在

v3 新发现

  • SelectTopkOp 构造函数签名变更 [P1] — 新增 fake_balance_expert 参数,需确认无外部消费者依赖旧签名

总结

Conditional LGTM — P1 签名变更需确认兼容性,P1 public 成员和硬编码线程数可后续改进。

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch from 92bf2e5 to cfc9bb0 Compare March 26, 2026 03:09
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #794

PR 概述

Title: SwapAB + Async Load Cache
Author: alibaba-miji
规模: 37(GitHub) files, ~+2900/-200

核心目标

三项性能优化:(1) DeepGEMM Swap-AB(SM90 小 M 场景 FP8 GEMM 吞吐提升);(2) FakeBalanceExpert 拆分为独立 Op;(3) PD 分离场景下 KV Cache 异步写入(消除主线程 CPU 阻塞)。


改动逻辑拆解

1. CacheStoreAsyncWriter(核心)

  • 基于 autil::LockFreeThreadPool(30 线程)的异步写入器,严格生命周期 init → submit* → waitAllDone
  • torch::Tensor 按值捕获(refcount bump),异常捕获并在 waitAllDone 时 rethrow
  • PyWrappedModel.cc:forward 前 init,forward 后 waitAllDone

2. DeepGEMM Swap-AB

  • 新增 deep_gemm_fp8 / deep_gemm_grouped_fp8_masked plugin ops
  • 配置 deep_gemm_use_swap_ab(默认 true,SM90 only)

3. FakeBalanceExpert 拆分

  • 独立 FakeBalanceExpertOpSelectTopk 简化为纯 top-k

Review 意见

问题

  1. cache_store_async_writer_ 作为 public 成员暴露 [P2]
    建议通过 getter 方法访问或添加注释说明。

  2. PR 包含 3 个独立功能 [P2]
    SwapAB / FakeBalanceExpert / Async Cache Write 建议后续拆分。

  3. 线程数和队列大小硬编码 [P3]
    kThreadCount=30, kQueueSize=10000 建议通过配置可调。

整体评价

三项优化都有明确性能收益。CacheStoreAsyncWriter 状态机设计严谨,torch::Tensor 按值捕获确保异步安全。PR description 含架构图,测试覆盖充分。

LGTM ready to ci

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch 2 times, most recently from 93609d4 to 7a7163e Compare March 26, 2026 14:59
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #794

PR 概述

Title: SwapAB + Async Load Cache
Author: alibaba-miji
规模: 37(GitHub) + 2(内源) files, +1917/-186
Review 类型: 检测到 force push/rebase,本次为 PR 全量 review(第 3 次)

核心目标

本 PR 包含四项改动:(1) 为 DeepGEMM FP8 GEMM 添加 swap-AB 优化(SM90 Hopper),小 M 场景下提升 decode 吞吐;(2) 将 FakeBalanceExpert 从 SelectTopkOp 拆分为独立 Op,改善关注点分离;(3) 将 PD 分离场景的 KV Cache 写入从同步改为异步线程池,消除主线程 CPU 阻塞;(4) 内源 smoke 测试 golden 数据更新。


改动逻辑拆解

GitHub 开源仓库变更(主要代码)

1. DeepGEMM Swap-AB 优化

  • 新增 C++ plugin ops:DeepGemmPluginOp.cc/h,提供 deep_gemm_fp8(normal GEMM)和 deep_gemm_grouped_fp8_masked(grouped masked GEMM),在 C++ 侧完成量化 + padding + GEMM
  • 新增配置项 deep_gemm_use_swap_ab(默认 true),通过 HWKernelConfig 传播,运行时检查 SM90
  • Python 侧新增 fp8_gemm_nt_swapabm_grouped_fp8_gemm_nt_masked_swapab wrapper
  • CudaFp8DeepGEMMLinear.forwardM < 32enable_swapab() 时走 swap-AB 路径
  • DeepGemmMaskedExecutorexpected_m < 64enable_swapab() 时走 swap-AB 路径
  • deepep_wrapper.pycalc_low_latency_max_token_per_rank 新增 swap-AB 场景的 matched_tokens 列表,并修复边界条件 > 128>= 128

2. FakeBalanceExpert 拆分

  • SelectTopkOp 中移除 fake_balance_expert 逻辑,SelectTopkOp 构造函数简化为只接受 ModelConfig
  • 新增独立的 FakeBalanceExpertOp(C++ + pybind)和 Python wrapper FakeBalanceExpert
  • generic_moe.pySelectTopkFakeBalanceExpert 分别实例化,forward 中顺序调用
  • ROCm 侧添加 FakeBalanceExpertNotImplementedOp 占位
  • 测试更新为分离调用模式

3. Async KV Cache Write

  • 新增 CacheStoreAsyncWriter 类:基于 autil::LockFreeThreadPool(30 线程,队列 10000),严格生命周期 init → submit* → waitAllDone
  • DeviceBase 构造函数中始终创建 cache_store_async_writer_(public 成员)
  • WriteCacheStoreOp.cc 重构:所有 torch::Tensor 按值捕获(引用计数 bump),创建 event 后提交到异步 writer
  • PyWrappedModel::forwardforwardMicroBatched 在 PD 分离非 warmup 场景下调用 init() / waitAllDone()
  • CacheStoreInputs 新增 pre_created_event 字段,支持主线程预创建 event 避免后台线程竞争
  • 单元测试 CacheStoreAsyncWriterTest.cpp 覆盖生命周期、并发、异常传播等场景

4. 配置与格式化

  • HWKernelConfig 新增 deep_gemm_use_swap_ab 字段(C++ 默认 true),pickle tuple size 17→18
  • hw_kernel_group_args.py 新增 --deep_gemm_use_swap_ab 命令行参数
  • backend_manager.py 在 DeepEP 初始化前调用 init_swapab_once

内源仓库变更

1. Smoke 测试 Golden 数据

  • deepseek_v2/q_r_3090_mla.jsonq_r_3090_mla_r24.json 更新了预期输出文本和 token 计数

Checklist 检查结果

通用原则

软件工程原则

检查项 结果
SRP ✅ FakeBalanceExpert 拆分改善了 SRP
OCP ✅ swap-AB 通过配置开关和阈值判断,不修改核心 GEMM 逻辑
LSP
ISP ✅ SelectTopkOp 接口精简
DIP
DRY
KISS
YAGNI

架构审视

检查项 结果
抽象边界 cache_store_async_writer_ 作为 public 成员暴露在 DeviceBase 上,见 P2-1
依赖方向
状态完整性 WriteCacheStoreOp 无条件提交到 async writer,非 PD 分离场景下 writer 处于 IDLE 状态会 crash,见 P1-1
错误语义 ✅ 异常传播到 waitAllDone 后 rethrow
可观测性
可演进性
可运维性 ✅ swap-AB 可通过 --deep_gemm_use_swap_ab=false 关闭

测试

检查项 结果
新功能有对应测试 ✅ CacheStoreAsyncWriterTest、SelectTopkOpTest 更新、DeepGemmPlugin perf test
删除的测试有等价替代
边界 case 覆盖 ✅ perf test 覆盖 prime M 值
分布式改动有多卡测试 N/A

代码质量与文档

检查项 结果
无关改动分离 hw_kernel_group_args.py 混入大量引号格式化改动,见 P2-2
mega-PR 拆分 ❌ 4 个独立功能合并在一个 PR 中,见 P2-3
Commit 原子性 ✅ 4 个 commit 各自对应一个功能
Commit message 准确性
PR description ✅ 详细的架构图和说明
日志频率控制

领域检查

A. 兼容性与配置 — 全部 ✅

B. 正确性与逻辑

检查项 结果
逻辑错误 ❌ WriteCacheStoreOp 在非 PD 分离场景下无条件提交到 IDLE 状态的 async writer,见 P1-1
其余项

C. 线程安全与并发

检查项 结果
静态/全局状态 _swap_ab_checked/_swap_ab_enabled 全局变量无线程保护,见 P2-4
其余项

D. 性能 — 全部 ✅

E. 分布式 — 全部 ✅

F. 跨平台

检查项 结果
CUDA/ROCm binding 对称 ❌ DeepGemmPluginOp 未在 ROCm 侧添加 guard,见 P2-5
其余项

G-I — 全部 ✅


Review 意见

问题

  1. WriteCacheStoreOp 在非 PD 分离场景下无条件提交到 IDLE 状态的 async writer [P1]

    WriteCacheStoreOp.cc 重构后,只要 kv_cachecache_store_member 都有值,就会无条件调用 device->cache_store_async_writer_->submit(...)。但 CacheStoreAsyncWriter::submit() 要求状态为 RUNNING(由 init() 设置),否则会 RTP_LLM_CHECK_WITH_INFO 直接 abort。

    init() 仅在 PyWrappedModel::forward / forwardMicroBatched 中当 !inputs.warmup && inputs.pd_separation 时调用。这意味着如果未来有人在非 PD 分离场景下也设置了 cache_store_inputs,就会 crash。

    当前调用链中 cache_store_inputs 仅在 PD 分离时设置,所以正常流程安全。但这个隐式依赖很脆弱。

    建议:WriteCacheStoreOp.cc 中添加防御性检查,判断 async writer 是否处于 RUNNING 状态再决定走异步还是同步路径。或者在 submit 中增加一个 trySubmit 变体,IDLE 时 fallback 到同步执行。

  2. cache_store_async_writer_ 作为 public 成员暴露 [P2]

    DeviceBase.hcache_store_async_writer_ 声明为 public,破坏了封装。PR description 中提到了 initCacheStoreWrite() / submitAsyncCacheStoreTask() / waitCacheStoreComplete() 等 API,但实际代码中未实现封装方法。

    建议: 改为 private/protected,通过 DeviceBase 方法封装。

  3. hw_kernel_group_args.py 混入大量格式化改动 [P2]

    该文件约 80% 的 diff 是单引号→双引号和缩进调整,与 deep_gemm_use_swap_ab 功能无关,应分离。

  4. PR 包含 4 个独立功能 [P2]

    SwapAB、FakeBalanceExpert 拆分、Async Cache Write、smoke golden 更新可独立回滚。当前 commit 结构已按功能分离,可接受但建议后续拆分。

  5. _swap_ab_checked/_swap_ab_enabled 全局变量无线程保护 [P2]

    deepgemm_wrapper.pyinit_swapab_once 使用普通全局变量做 once-init。实际在主线程调用,风险较低。建议使用 threading.Lockfunctools.cache

  6. DeepGemmPluginOp 未在 ROCm RegisterBaseBindings 中添加 guard [P2]

    新 Op 仅在 CUDA 侧注册。ROCm 侧未添加对应 stub。当前不影响编译,但与 CUDA 侧不同步。FakeBalanceExpert 已有 Python 侧 NotImplementedOp 占位,这点做得好。

小问题

  • deep_gemm_fp8padded_output.slice(0, 0, m).contiguous() 产生额外 GPU 拷贝,小 M 影响不大。[P3]
  • CacheStoreAsyncWriter 硬编码 30 线程和 10000 队列大小,建议通过配置可调或注释说明选择依据。[P3]
  • deepgemm_masked_executor.py Down GroupGEMM-1 路径中 down_input_scale_aligned 在非 swap-AB 分支也被应用了(替代了原来的 down_input_scale),这是一个行为变更,需确认正确性。[P3]

整体评价

PR 整体质量良好,架构设计清晰。SwapAB 优化有完整的性能测试和正确性验证,FakeBalanceExpert 拆分改善了代码结构,Async Cache Write 的线程池设计合理且有充分的单元测试。主要风险点在于 WriteCacheStoreOp 无条件走异步路径的隐式依赖——虽然当前调用链保证了安全性,但缺乏防御性检查使得代码脆弱。建议在合入前添加防御性检查或文档化这个约束。

P0: 0, P1: 1, P2: 5, P3: 3

存在重要问题,建议处理 P1 后合入

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch from 7a7163e to fa51037 Compare March 28, 2026 05:59
@LLLLKKKK
Copy link
Copy Markdown
Collaborator

🤖 AI Code Review — PR #794

PR 概述

Title: SwapAB + Async Load Cache
Author: alibaba-miji
规模: 37(GitHub) + 2(内源) files, +1917/-186
Review 类型: 检测到 force push/rebase,本次为 PR 全量 review(第 4 次)

核心目标

本 PR 包含四项改动:(1) 为 DeepGEMM FP8 GEMM 添加 swap-AB 优化(SM90 Hopper),小 M 场景下提升 decode 吞吐;(2) 将 FakeBalanceExpert 从 SelectTopkOp 拆分为独立 Op,改善关注点分离;(3) 将 PD 分离场景的 KV Cache 写入从同步改为异步线程池,消除主线程 CPU 阻塞;(4) 内源 smoke 测试 golden 数据更新。

与上次 Review 的差异

新增 1 个 commit(fa51037371c3),针对上次 P1 反馈做了以下改进:

  • CacheStoreAsyncWritersetCacheStore() 延迟创建改为在 DeviceBase 构造函数中始终创建,消除空指针风险
  • 移除了 PyWrappedModel.ccWriteCacheStoreOp.cc 中所有 cache_store_async_writer_ != nullptr 的运行时检查(因为始终非空)
  • WriteCacheStoreOpcreateEvent() 从后台线程移到主线程调用,通过 pre_created_event 传递给 writeCacheStore,避免 cudaEventRecord 在后台线程的竞争
  • CacheStoreInputs 新增 pre_created_event 字段(默认 nullptr,C++ 同步路径不受影响)

改动逻辑拆解

GitHub 开源仓库变更(主要代码)

1. DeepGEMM Swap-AB 优化

  • 新增 C++ plugin ops:DeepGemmPluginOp.cc/h,提供 deep_gemm_fp8(normal GEMM)和 deep_gemm_grouped_fp8_masked(grouped masked GEMM)
  • 新增配置项 deep_gemm_use_swap_ab(默认 true),通过 HWKernelConfig 传播,运行时检查 SM90
  • Python 侧新增 fp8_gemm_nt_swapabm_grouped_fp8_gemm_nt_masked_swapab wrapper
  • CudaFp8DeepGEMMLinear.forwardM < 32enable_swapab() 时走 swap-AB 路径
  • DeepGemmMaskedExecutorexpected_m < 64enable_swapab() 时走 swap-AB 路径
  • deepep_wrapper.pycalc_low_latency_max_token_per_rank 新增 swap-AB 场景的 matched_tokens 列表,并修复边界条件 > 128>= 128

2. FakeBalanceExpert 拆分

  • SelectTopkOp 中移除 fake_balance_expert 逻辑,构造函数简化为只接受 ModelConfig
  • 新增独立的 FakeBalanceExpertOp(C++ + pybind)和 Python wrapper FakeBalanceExpert
  • generic_moe.py 中分别实例化,forward 中顺序调用
  • ROCm 侧添加 NotImplementedOp 占位

3. Async KV Cache Write

  • 新增 CacheStoreAsyncWriter 类:基于 autil::LockFreeThreadPool(30 线程,队列 10000),严格生命周期 init → submit* → waitAllDone
  • DeviceBase 构造函数中始终创建 cache_store_async_writer_
  • WriteCacheStoreOp.cc 重构:torch::Tensor 按值捕获,主线程创建 event 后提交到异步 writer
  • PyWrappedModel::forwardforwardMicroBatched 在 PD 分离非 warmup 场景下调用 init() / waitAllDone()
  • 单元测试覆盖生命周期、并发、异常传播等场景

4. 配置与格式化

  • HWKernelConfig 新增 deep_gemm_use_swap_ab,pickle tuple size 17→18
  • hw_kernel_group_args.py 新增 --deep_gemm_use_swap_ab 命令行参数
  • backend_manager.py 在 DeepEP 初始化前调用 init_swapab_once

内源仓库变更

  • Smoke 测试 golden 数据更新

Checklist 检查结果

通用原则

软件工程原则

检查项 结果
SRP ✅ FakeBalanceExpert 拆分改善了 SRP
OCP ✅ swap-AB 通过配置开关和阈值判断
LSP
ISP ✅ SelectTopkOp 接口精简
DIP
DRY
KISS
YAGNI

架构审视

检查项 结果
抽象边界 cache_store_async_writer_ 作为 public 成员暴露,见 P2-1
依赖方向
状态完整性 ✅ 上次 P1 的空指针风险已通过构造函数始终创建解决
错误语义 ✅ 异常传播到 waitAllDone 后 rethrow
可观测性
可演进性
可运维性 ✅ swap-AB 可通过 --deep_gemm_use_swap_ab=false 关闭

测试

检查项 结果
新功能有对应测试
删除的测试有等价替代
边界 case 覆盖
分布式改动有多卡测试 N/A

代码质量与文档

检查项 结果
无关改动分离 ❌ hw_kernel_group_args.py 混入格式化改动,见 P2-3
mega-PR 拆分 ❌ 4 个独立功能合并,见 P2-4
Commit 原子性
Commit message 准确性
PR description
日志频率控制

领域检查

A. 兼容性与配置 — 全部 ✅

B. 正确性与逻辑 — 全部 ✅

C. 线程安全与并发

检查项 结果
静态/全局状态 atomic _swap_ab_checked/_swap_ab_enabled 无线程保护,见 P2-5
其余

D. 性能 — 全部 ✅

E. 分布式 — 全部 ✅

F. 跨平台 — 全部 ✅

G. 语言与框架特有 — 全部 ✅

H. 测试与 CI — 全部 ✅

I. 代码质量 — 全部 ✅


Review 意见

上次 P1 跟踪

上次 P1: WriteCacheStoreOp 无条件走异步路径降级为 P2-2

上次指出 WriteCacheStoreOp 无条件调用 submit(),若 async writer 处于 IDLE 状态会 crash。本次 commit 解决了空指针问题(构造函数始终创建 writer),但未添加 IDLE 状态的防御性检查。

经重新评估,降级为 P2,理由:

  • cache_store_inputs 仅在 pd_separation && !warmup 时通过 prepareWriteCacheParams 设置
  • init() 也仅在同一条件下调用,两者在同一个 if 块中配对
  • 调用链的隐式保证可靠,不存在绕过路径
  • 新增的 pre_created_event 机制正确地将 cudaEventRecord 从后台线程移到主线程

问题

  1. cache_store_async_writer_ 作为 public 成员暴露 [P2]

    DeviceBase.hcache_store_async_writer_ 声明为 public,破坏了封装。建议改为 private/protected,通过封装方法暴露。

  2. WriteCacheStoreOp 对 async writer RUNNING 状态的隐式依赖 [P2]

    当前调用链保证安全,但建议在 WriteCacheStoreOp.ccsubmit 调用前添加注释说明前置条件。

  3. hw_kernel_group_args.py 混入大量格式化改动 [P2]

    约 80% 的 diff 是单引号改双引号和缩进调整,与功能无关。建议后续分离。

  4. PR 包含 4 个独立功能 [P2]

    SwapAB、FakeBalanceExpert 拆分、Async Cache Write、smoke golden 更新可独立回滚。当前 commit 结构已按功能分离,可接受。

  5. _swap_ab_checked / _swap_ab_enabled 全局变量无线程保护 [P2]

    实际在主线程调用,风险低。建议使用 threading.Lockfunctools.cache 保护。

小问题

  • deep_gemm_fp8padded_output.slice(0, 0, m).contiguous() 会产生额外 GPU 拷贝。小 M 影响不大。[P3]
  • CacheStoreAsyncWriter 硬编码 30 线程和 10000 队列大小,建议注释说明选择依据。[P3]
  • deepgemm_masked_executor.py Down GroupGEMM-1 路径新增 get_mn_major_tma_aligned_tensor(down_input_scale) 处理,非 swap-AB 分支也使用了 aligned scale。需确认对非 swap-AB 路径的正确性。[P3]

整体评价

上次 P1 的核心风险——空指针——已通过构造函数始终创建 writer 解决。IDLE 状态的隐式依赖经重新评估,调用链保证可靠,降级为 P2。新增的 pre_created_event 机制正确地将 cudaEventRecord 从后台线程移到主线程。整体代码质量良好,架构设计清晰。

P0: 0, P1: 0, P2: 5, P3: 3

LGTM ready to ci

@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch 3 times, most recently from 47e4dd1 to 083563d Compare April 2, 2026 14:49
@alibaba-miji alibaba-miji force-pushed the feature/support_qwen35_merge branch from 083563d to d4eaf8b Compare April 3, 2026 07:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants