fix(v4): triton-3.6+ MoE SMEM OOM + empty-batch indexer prefill#710
Merged
Conversation
Two V4-Pro CI failures surfaced after PR #703 unblocked the loader path:

1. **Triton MoE OOM on AMD CDNA4 with triton 3.6+/3.7+.** `triton_kernels.matmul_ogs_details.opt_flags_amd` has a CDNA4 special case `if cdna4 and block_m == 128: block_n = 512`, giving BLOCK_M*BLOCK_N = 64K FP32 acc entries. triton 3.6+ spills the accumulator to LDS more aggressively than 3.5, exceeding the MI355X 160 KiB LDS budget (observed 269 KiB).
   - Fix: wrap matmul_ogs calls with a CDNA4-only context manager that pins block_m=64 / block_n=256 (BLOCK_M*BLOCK_N = 16K, fits in registers). Tunable via `ATOM_TRITON_MOE_BLOCK_{M,N}` env vars.
   - Other GPU families and triton ≤3.5 paths are unaffected.
2. **`cp_gather_indexer_k_quant_cache` HIP "invalid configuration argument"** when `cu_committed_cpu[-1] == 0` (fresh prefill with a prompt shorter than the CSA `ratio`). The kernel grid is computed as `(num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE`, so `num_tokens=0` makes `grid.x = 0` and the HIP launcher rejects it.
   - Fix: clamp `cu_committed_cpu[-1]` to ≥1 in the indexer-meta builder. The dummy +1 row is gathered from the last seq's first cache block but never read downstream, because `fp8_mqa_logits` and `top_k_per_row_prefill` honor per-token `cu_starts`/`cu_ends` derived from `cu_committed_gpu[:-1]` and `n_committed_per_seq`, both of which remain 0. Output stays all -1 sentinels, matching the all-empty semantics.
   - Pure host-side scalar arithmetic on a value already host-synced; no CG/torch.compile graph branch added.

Verified locally with triton 3.7.0+amd.rocm7.1.0:
- DeepSeek-V4-Pro server starts (no OOM)
- 1-token "Hi" curl returns successfully (was crashing pre-fix)
- GSM8K-50 fewshot=5 = 0.94 (matches pre-PR-703 baseline)
Contributor
Pull request overview
This PR addresses two ROCm/CDNA4 regressions surfaced in CI for DeepSeek-V4: (1) Triton 3.6+ MoE matmul tile choices causing LDS/SMEM OOM on gfx950, and (2) an empty-committed prefill case producing a zero-sized HIP grid for the indexer FP8 gather path.
Changes:
- Add a CDNA4 (gfx950)-scoped context manager to constrain `triton_kernels.matmul_ogs` tiling during MoE matmuls to avoid LDS overflow on Triton 3.6+.
- Clamp the V4 indexer prefill metadata's `cu_committed_cpu[-1]` to at least 1 to prevent `cp_gather_indexer_k_quant_cache` from launching with `grid.x = 0`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `atom/model_ops/fused_moe_triton.py` | Adds a gfx950-only context manager wrapping matmul_ogs calls to constrain MoE tiling via env-tunable parameters. |
| `atom/model_ops/attentions/deepseek_v4_attn.py` | Adds an empty-committed guard by bumping the final cu_committed cumsum to avoid a zero-grid kernel launch. |
Comment on lines +57 to +59:

> Pin block_n ≤ ATOM_TRITON_MOE_MAX_BLOCK_N (default 256) so BLOCK_M*BLOCK_N
> stays at 32K. Default block_n in compute_block_nk is already capped at
> 256 except for that single cdna4 branch, so this only sidesteps the bad
Comment on lines +74 to +78:

```python
# acc), comfortably fitting MI355X's register file. Override via env if
# a future compiler/kernel update relaxes the budget.
block_m = int(os.getenv("ATOM_TRITON_MOE_BLOCK_M", "64"))
block_n = int(os.getenv("ATOM_TRITON_MOE_BLOCK_N", "256"))
update_opt_flags_constraints({"block_m": block_m, "block_n": block_n})
```
Summary
Two CI failures surfaced after #703 unblocked the V4-Pro accuracy step.
1. Triton MoE SMEM OOM on AMD CDNA4 (triton 3.6+)
`triton_kernels.matmul_ogs_details.opt_flags_amd` has a CDNA4 special case:
```python
if get_cdna_version() == 4 and block_m == 128:
block_n = 512
```
This produces BLOCK_M*BLOCK_N = 64K FP32 accumulator entries. triton 3.6+/3.7+ spills the accumulator to LDS more aggressively than 3.5, exceeding MI355X's 160 KiB LDS budget (observed 269 KiB required vs 163 KiB hardware limit on V4-Pro FP8 MoE).
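Quick arithmetic on the accumulator alone shows the overflow (a lower bound only: the 269 KiB observed includes other LDS buffers on top of the accumulator):

```python
# FP32 accumulator footprint if fully spilled to LDS, for the two tile
# choices discussed here; other LDS buffers add on top of this.
def acc_bytes(block_m: int, block_n: int) -> int:
    return block_m * block_n * 4  # 4 bytes per FP32 entry

print(acc_bytes(128, 512) // 1024)  # 256 KiB -- already over the 160 KiB budget
print(acc_bytes(64, 256) // 1024)   # 64 KiB -- fits
```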
Fix: wrap matmul_ogs calls with a CDNA4-only context manager that pins `block_m=64` / `block_n=256` (BLOCK_M*BLOCK_N = 16K, fits comfortably). Tunable via `ATOM_TRITON_MOE_BLOCK_{M,N}` env vars. Other GPU families and triton ≤3.5 paths are unaffected.
2. `cp_gather_indexer_k_quant_cache` HIP "invalid configuration argument"
When the indexer prefill builder produces `cu_committed_cpu[-1] == 0` (fresh prefill with prompt shorter than the CSA `ratio`), the kernel grid:
```cpp
dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE, ...)
```
becomes `grid.x = 0` and the HIP launcher rejects it.
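The ceil-division behind that launch makes the failure mode easy to see (sketch; the `block_y_size` value is assumed for illustration):

```python
def grid_x(num_tokens: int, block_y_size: int = 32) -> int:
    # Same ceil-division as the kernel launch above.
    return (num_tokens + block_y_size - 1) // block_y_size

print(grid_x(0))  # 0 -> HIP rejects the launch
print(grid_x(1))  # 1 -> a valid grid once at least one token is gathered
```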
Fix: clamp `cu_committed_cpu[-1]` to ≥1 in `_build_v4_indexer_meta`. The dummy +1 row gets gathered from the last seq's first cache block but is never read downstream — `fp8_mqa_logits` and `top_k_per_row_prefill` honor per-token `cu_starts`/`cu_ends` derived from `cu_committed_gpu[:-1]` and `n_committed_per_seq`, which both remain 0. Output stays all -1 sentinels, matching the all-empty semantics.
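The clamp itself is a one-line host-side guard; a sketch on a plain Python list standing in for the host-synced cumsum tensor (helper name hypothetical):

```python
def clamp_committed(cu_committed_cpu: list) -> list:
    """Bump the final cumsum entry to 1 when nothing is committed, so the
    gather kernel launches with a nonzero grid; downstream consumers still
    see zero committed tokens per seq and ignore the dummy row."""
    if cu_committed_cpu[-1] == 0:
        cu_committed_cpu = cu_committed_cpu[:]
        cu_committed_cpu[-1] = 1
    return cu_committed_cpu

print(clamp_committed([0, 0, 0]))  # [0, 0, 1]
print(clamp_committed([0, 3, 5]))  # unchanged: [0, 3, 5]
```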
This is pure host-side scalar arithmetic on a value already host-synced (`int(cu_committed_cpu[-1])`); no new CG/torch.compile graph branch is introduced.
Local Verification
triton 3.7.0+amd.rocm7.1.0 (closest to CI's 3.6.0+rocm7.2.3, both reproduce the OOM):
- DeepSeek-V4-Pro server starts (no OOM)
- 1-token "Hi" curl returns successfully (was crashing pre-fix)
- GSM8K-50 fewshot=5 = 0.94 (matches pre-PR-703 baseline)
Test plan