Fix singleton-batch matmul accumulator layout #1237

Open
zhaozhaozz wants to merge 2 commits into hw-native-sys:main from zhaozhaozz:fix/singleton-batch-matmul-acc

@zhaozhaozz zhaozhaozz commented Apr 30, 2026

Summary

  • keep singleton-batch flattened matmul_acc dummy initializers in Acc with an Acc-compatible TileView
  • avoid no-op singleton full-page tile.slice that lowers to unsupported Mat->Mat tmov
  • normalize implicit TileView semantics in PTO tile buffer/store codegen

Validation

  • python -m pytest tests/ut/ir/transforms/test_infer_tile_memory_space.py tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py tests/ut/codegen/test_pto_codegen.py -q
  • pre-commit run --show-diff-on-failure --color=always --all-files
  • python tests/lint/clang_tidy.py --diff-base HEAD
  • pypto-lib/temp/matmul.py --platform a2a3sim

Fixes #1235

Copilot AI review requested due to automatic review settings April 30, 2026 06:31

coderabbitai Bot commented Apr 30, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Detects singleton-batch loop-carried accumulators used by tile.batch_matmul_acc, records their init vars, and ensures their materialized tiles are placed in Acc memory; also refines flattening to avoid unnecessary tile.slice when operand is already the expected 2D page.

Changes

Cohort: FlattenTileNdTo2D Pass
File(s): src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
Summary: Recursively scans loop bodies to find iter-args used as tile.batch_matmul_acc accumulators when batch product==1, records those init vars, sets target_memory=MemorySpace::Acc when materializing tile.create/tile.full, and skips redundant tile.slice for already-flattened operands.

Cohort: Regression Test
File(s): tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py
Summary: Adds a regression that constructs a singleton-batch accumulator loop, runs FlattenTileNdTo2D (verification disabled), and asserts the flattened IR keeps the accumulator in Acc (no tile.move/batch-matmul ops and presence of target_memory=pl.Mem.Acc).

Cohort: PTO / TileView normalization
File(s): src/backend/common/pto_ops_common.cpp, src/codegen/pto/pto_type_utils.cpp
Summary: Normalize/derive an implicit TileView from tile_view_, shape_, and memory_space_ instead of requiring an explicit tile_view_; use normalized valid_shape/layout semantics for 2D validation and TileTypeComponents extraction.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • lyfne123
  • Hzfengsy

Poem

🐰 In loops where single batches sat,
I sniffed the init — no Vec-to-Acc spat.
I hop, I mark, I tuck it snug in Acc,
No needless moves, just a cleaner pass. ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning
    Explanation: Docstring coverage is 75.00%, which is insufficient. The required threshold is 80.00%.
    Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check: ✅ Passed
    Explanation: The title 'Fix singleton-batch matmul accumulator layout' is clear, concise, and directly relates to the main change: ensuring singleton-batch accumulators remain in Acc memory rather than Vec, avoiding unnecessary moves.
  • Description check: ✅ Passed
    Explanation: The pull request description is well-structured and related to the changeset, covering the summary of changes (singleton-batch accumulator layout, TileView normalization), validation steps, and issue references.
  • Linked Issues check: ✅ Passed
    Explanation: The pull request addresses all primary objectives from issue #1235: detecting singleton-batch tile.batch_matmul_acc loop initializers, flattening them directly to Mem.Acc to avoid Vec↔Acc round-trips, adding regression tests, and normalizing TileView semantics in PTO codegen.
  • Out of Scope Changes check: ✅ Passed
    Explanation: All changes are within scope: FlattenTileNdTo2D pass enhancements for singleton-batch accumulator handling, regression test coverage, and PTO codegen TileView normalization directly support the linked issue objectives.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request optimizes the flattening of ND tiles to 2D by ensuring that singleton-batch accumulator initializations are kept in accumulator memory rather than being moved to vector memory. This avoids unnecessary memory round-trips when lowering tile.batch_matmul_acc to tile.matmul_acc. The changes include new utility functions for identifying singleton batch dimensions and collecting matmul operands, as well as updated logic in the transformation pass to handle loop-carried iterators. Feedback is provided regarding the manual IR traversal logic, which should use a visitor for better maintainability, and the duplication of logic between ForStmt and WhileStmt processing.

Comment on lines +243 to +280
void CollectBatchMatmulAccOperands(const std::vector<StmtPtr>& stmts,
                                   std::unordered_set<const Var*>& acc_operands) {
  for (const auto& stmt : stmts) {
    if (auto assign = As<AssignStmt>(stmt)) {
      if (auto call = As<Call>(assign->value_)) {
        if (call->op_ && call->op_->name_ == "tile.batch_matmul_acc" && !call->args_.empty()) {
          if (auto acc = AsVarLike(call->args_[0])) {
            acc_operands.insert(acc.get());
          }
        }
      }
      continue;
    }
    if (auto seq = As<SeqStmts>(stmt)) {
      CollectBatchMatmulAccOperands(seq->stmts_, acc_operands);
      continue;
    }
    if (auto scope = As<ScopeStmt>(stmt)) {
      CollectBatchMatmulAccOperands(FlattenToStmts(scope->body_), acc_operands);
      continue;
    }
    if (auto if_stmt = As<IfStmt>(stmt)) {
      CollectBatchMatmulAccOperands(FlattenToStmts(if_stmt->then_body_), acc_operands);
      if (if_stmt->else_body_.has_value()) {
        CollectBatchMatmulAccOperands(FlattenToStmts(*if_stmt->else_body_), acc_operands);
      }
      continue;
    }
    if (auto for_stmt = As<ForStmt>(stmt)) {
      CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), acc_operands);
      continue;
    }
    if (auto while_stmt = As<WhileStmt>(stmt)) {
      CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), acc_operands);
      continue;
    }
  }
}

medium

This function manually traverses different statement types to find tile.batch_matmul_acc calls. This logic should be simplified and made more maintainable by using an IRVisitor. Using a visitor ensures that the traversal is recursive, which is necessary to correctly handle nested control flow structures like SeqStmts and ScopeStmt.

References
  1. Helper functions that traverse IR statements to find specific nodes must be recursive to handle nested control flow structures like SeqStmts and ScopeStmt.
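
The visitor-style traversal suggested above can be sketched on a toy IR. This is a minimal Python model, not the project's IRVisitor: the classes `Call`, `Assign`, `Seq`, `If`, `For`, `StmtVisitor`, and `AccOperandCollector` are hypothetical stand-ins, and operands are plain strings rather than expression nodes.

```python
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass
class Call:
    op: str
    args: List[str]  # operand names; a real IR would hold expression nodes


@dataclass
class Assign:
    target: str
    value: Call


@dataclass
class Seq:
    stmts: list


@dataclass
class If:
    then_body: object
    else_body: Optional[object] = None


@dataclass
class For:
    body: object


class StmtVisitor:
    """Generic walker: dispatches on node type and recurses into children."""

    def visit(self, node):
        method = getattr(self, "visit_" + type(node).__name__, self.generic_visit)
        method(node)

    def generic_visit(self, node):
        for child in self.children(node):
            self.visit(child)

    @staticmethod
    def children(node):
        if isinstance(node, Seq):
            return list(node.stmts)
        if isinstance(node, If):
            return [b for b in (node.then_body, node.else_body) if b is not None]
        if isinstance(node, For):
            return [node.body]
        return []


class AccOperandCollector(StmtVisitor):
    """Records the first operand of every tile.batch_matmul_acc call."""

    def __init__(self):
        self.acc_operands: Set[str] = set()

    def visit_Assign(self, node):
        call = node.value
        if call.op == "tile.batch_matmul_acc" and call.args:
            self.acc_operands.add(call.args[0])
```

With this shape, supporting a new statement kind means extending `children` once, and the collector itself only states what it is looking for.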

Comment on lines +1287 to +1313
      std::unordered_set<const Var*> batch_matmul_acc_operands;
      CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands);
      for (const auto& ia : for_stmt->iter_args_) {
        CountVarRefs(ia->initValue_);
        auto init_var = As<Var>(ia->initValue_);
        auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
        if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
            batch_matmul_acc_operands.count(ia.get())) {
          singleton_batch_acc_init_vars.insert(init_var.get());
        }
      }
      continue;
    }
    // WhileStmt: count condition and iter_arg init Var refs.
    if (auto while_stmt = As<WhileStmt>(s)) {
      CountVarRefs(while_stmt->condition_);
      for (const auto& ia : while_stmt->iter_args_) CountVarRefs(ia->initValue_);
      std::unordered_set<const Var*> batch_matmul_acc_operands;
      CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), batch_matmul_acc_operands);
      for (const auto& ia : while_stmt->iter_args_) {
        CountVarRefs(ia->initValue_);
        auto init_var = As<Var>(ia->initValue_);
        auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
        if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
            batch_matmul_acc_operands.count(ia.get())) {
          singleton_batch_acc_init_vars.insert(init_var.get());
        }
      }

medium

The logic for processing iter_args in ForStmt and WhileStmt is identical. This duplicated code could be extracted into a helper function or a lambda to improve readability and maintainability.
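
One way to picture the extraction suggested above is a single helper that both loop kinds call. The sketch below is a toy Python model under stated assumptions: `IterArg`, `Loop`, `has_singleton_batch_dims`, and `record_singleton_acc_inits` are hypothetical names, and "batch dims" is assumed to mean every dimension before the trailing two, which may differ from what `HasSingletonBatchDims` actually checks.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Set, Tuple


@dataclass
class IterArg:
    name: str                     # loop-carried variable name
    init_name: str                # variable that provides the initial value
    init_shape: Tuple[int, ...]   # static shape of the initial tile


@dataclass
class Loop:
    kind: str                     # "for" or "while"
    iter_args: List[IterArg]
    acc_operands: Set[str]        # names used as batch_matmul_acc accumulators in the body


def has_singleton_batch_dims(shape: Tuple[int, ...]) -> bool:
    # Assumption: all dims before the last two are batch dims.
    return len(shape) > 2 and prod(shape[:-2]) == 1


def record_singleton_acc_inits(loop: Loop, out: Set[str]) -> None:
    """Shared by for- and while-loops: record inits that should live in Acc."""
    for ia in loop.iter_args:
        if has_singleton_batch_dims(ia.init_shape) and ia.name in loop.acc_operands:
            out.add(ia.init_name)
```

Both loop branches would then reduce to one call each, so a later change to the detection rule only has to be made in one place.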

Copilot AI left a comment

Pull request overview

Updates FlattenTileNdTo2D to keep singleton-batch tile.batch_matmul_acc loop accumulator initializers in Mem.Acc, avoiding unnecessary Vec↔Acc moves on the batch=1 fast path and adding a regression test for the loop-carried accumulator pattern.

Changes:

  • Add analysis in FlattenTileNdTo2D to detect singleton-batch loop accumulator init vars used by tile.batch_matmul_acc.
  • Force flattened tile.create/tile.full for those accumulator inits to use target_memory=Acc.
  • Add a unit test ensuring singleton-batch loop accumulator init flattens to Acc with no tile.move emitted.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Files and descriptions:

  • src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp: Detects singleton-batch loop accumulator init vars and preserves Acc memory to avoid Vec/Acc round-trips.
  • tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py: Adds regression coverage for singleton-batch loop-carried accumulator initialization staying in Acc.

Comment on lines +1287 to +1297
      std::unordered_set<const Var*> batch_matmul_acc_operands;
      CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands);
      for (const auto& ia : for_stmt->iter_args_) {
        CountVarRefs(ia->initValue_);
        auto init_var = As<Var>(ia->initValue_);
        auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
        if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
            batch_matmul_acc_operands.count(ia.get())) {
          singleton_batch_acc_init_vars.insert(init_var.get());
        }
      }

Copilot AI Apr 30, 2026

The singleton-batch accumulator init detection only runs on ForStmt/WhileStmt nodes that are directly present in the current stmts vector. If the tile.create([1,...]) init is defined in an outer block while the loop is nested under a ScopeStmt/IfStmt/SeqStmts, this set won’t be populated in the outer TransformBody invocation, so the initializer won’t be forced to MemorySpace::Acc and the original Vec/Acc round-trip can remain. Consider performing this analysis via a recursive walk over the full statement tree (similar to CollectBatchMatmulAccOperands) and sharing the result across recursive TransformBody calls.
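
The up-front analysis described in this comment can be illustrated with a toy statement tree. This is a rough Python sketch, not the pass's code: `Stmt`, `acc_operands`, and `collect_singleton_acc_inits` are hypothetical, shapes are passed in as a plain dict, and batch dims are assumed to be all dims before the trailing two.

```python
from dataclasses import dataclass, field
from math import prod
from typing import Dict, List, Set, Tuple


@dataclass
class Stmt:
    kind: str                                  # "assign", "seq", "scope", "if", "for", "while"
    children: List["Stmt"] = field(default_factory=list)
    op: str = ""                               # for "assign": called op name
    args: Tuple[str, ...] = ()                 # for "assign": operand names
    iter_args: Dict[str, str] = field(default_factory=dict)  # loops: iter_arg -> init var


def acc_operands(stmt: Stmt, out: Set[str]) -> None:
    """Collect accumulator operands of tile.batch_matmul_acc at any depth."""
    if stmt.kind == "assign" and stmt.op == "tile.batch_matmul_acc" and stmt.args:
        out.add(stmt.args[0])
    for child in stmt.children:
        acc_operands(child, out)


def collect_singleton_acc_inits(root: Stmt, shapes: Dict[str, Tuple[int, ...]]) -> Set[str]:
    """Walk the whole tree once; return init vars that should be created in Acc."""
    result: Set[str] = set()

    def walk(stmt: Stmt) -> None:
        if stmt.kind in ("for", "while"):
            used: Set[str] = set()
            for child in stmt.children:
                acc_operands(child, used)
            for iter_arg, init in stmt.iter_args.items():
                shape = shapes.get(init, ())
                if iter_arg in used and len(shape) > 2 and prod(shape[:-2]) == 1:
                    result.add(init)
        # Recurse regardless of kind so loops nested under scopes/ifs are found.
        for child in stmt.children:
            walk(child)

    walk(root)
    return result
```

Because the set is computed once from the root, a transform that recurses into nested bodies can consult the same result when it reaches an initializer defined in an outer block.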

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@zhaozhaozz zhaozhaozz changed the title from "fix(ir): Keep singleton batch matmul accumulators in Acc" to "Fix singleton-batch matmul accumulator layout" Apr 30, 2026

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py (1)

1783-1790: ⚡ Quick win

Prefer structural IR assertions over pretty-string matching.

The str(function) substring checks are fragile to printer formatting changes. Consider asserting on ir.Call op names/kwargs directly for memory-space expectations.

♻️ Suggested direction
-        ir_str = str(after.get_function("main_incore_0"))
-
-        assert "tile.batch_matmul" not in ir_str
-        assert "tile.batch_matmul_acc" not in ir_str
-        assert "tile.slice" not in ir_str
-        assert "tile.move" not in ir_str
-        assert "target_memory=pl.Mem.Acc" in ir_str
-        assert "pl.Mem.Acc, pl.TileView(blayout=pl.TileLayout.row_major" not in ir_str
+        fn = after.get_function("main_incore_0")
+        assert fn is not None
+        body = cast(ir.SeqStmts, fn.body)
+        calls = [
+            s.value
+            for s in body.stmts
+            if isinstance(s, ir.AssignStmt) and isinstance(s.value, ir.Call)
+        ]
+        names = [c.op.name for c in calls]
+        assert "tile.batch_matmul" not in names
+        assert "tile.batch_matmul_acc" not in names
+        assert "tile.slice" not in names
+        assert "tile.move" not in names
+        create_calls = [c for c in calls if c.op.name == "tile.create"]
+        assert any(c.kwargs.get("target_memory") == pl.MemorySpace.Acc for c in create_calls)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py` around lines 1783 -
1790, Replace fragile pretty-string assertions in the test that call
str(after.get_function("main_incore_0")) with structural checks over the IR Call
nodes: iterate the calls from after.get_function("main_incore_0") (e.g., inspect
its body or use the helper that yields ir.Call nodes), assert no call has op
name "tile.batch_matmul", "tile.batch_matmul_acc", "tile.slice", or "tile.move",
and assert that at least one Call has kwargs["target_memory"] == pl.Mem.Acc
while none have a tile view kwarg like kwargs.get("view") whose layout equals
pl.TileLayout.row_major (or similar field used for tile views); use ir.Call
op.name and kwargs access rather than substring matching to validate
memory-space and tile-view properties.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: cdf798d0-71e1-42f5-8d3f-63e64429fff6

📥 Commits

Reviewing files that changed from the base of the PR and between 5d3e061 and 8387756.

📒 Files selected for processing (4)
  • src/backend/common/pto_ops_common.cpp
  • src/codegen/pto/pto_type_utils.cpp
  • src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp
  • tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py

Comment on lines +692 to +706
    auto flat_rows = As<ConstInt>(operand_type->shape_[0]);
    auto flat_cols = As<ConstInt>(operand_type->shape_[1]);
    if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
        flat_cols->value_ == source_cols) {
      // Singleton-batch operands are already the exact 2D page. Avoid a full-tile
      // slice, which would lower to an unsupported Mat->Mat tmov on a2a3.
      current = AsVarLike(operand);
      CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like";
    } else {
      auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
      auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
      auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
      current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
      page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
    }

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid hard-failing when the 2D singleton-page operand is not Var-like

Lines 698–699 assume the operand is always a Var or IterArg and CHECK otherwise. This can abort the pass on valid IR where the operand is an inline expression. Fall back to the existing tile.slice branch when AsVarLike(operand) returns null, instead of crashing.

Suggested fix
-    if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
-        flat_cols->value_ == source_cols) {
+    if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
+        flat_cols->value_ == source_cols) {
       // Singleton-batch operands are already the exact 2D page. Avoid a full-tile
       // slice, which would lower to an unsupported Mat->Mat tmov on a2a3.
-      current = AsVarLike(operand);
-      CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like";
+      if (auto operand_var = AsVarLike(operand)) {
+        current = operand_var;
+      } else {
+        auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
+        auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
+        auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
+        current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
+        page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
+      }
     } else {
       auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
       auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
       auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
       current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
       page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
     }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` around lines 692 - 706, The
pass currently CHECKs that AsVarLike(operand) is non-null for the 2D
singleton-page fast path (in FlattenTileNdTo2D) which aborts on valid inline
expressions; change the logic to try AsVarLike(operand) and if it returns null
fall back to creating the same tile.slice + Var + AssignStmt sequence used in
the else branch (use the same offset/shape and page.stmts push) instead of
calling CHECK, so inline expressions are handled by emitting a slice-backed
Var-like temporary rather than crashing.
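
The page-selection rule with the suggested fallback can be summarized as a small decision function. This is an illustrative Python model only: `select_page` and its tuple return values are hypothetical, and `operand_name=None` stands in for an operand that is not Var-like.

```python
from typing import Optional, Tuple


def select_page(operand_name: Optional[str],
                operand_shape: Tuple[int, int],
                batch_index: int,
                rows: int,
                cols: int):
    """Decide whether to reuse the operand directly or emit a tile.slice."""
    # Fast path: the operand already is the exact 2D page for batch 0.
    is_exact_page = batch_index == 0 and tuple(operand_shape) == (rows, cols)
    if is_exact_page and operand_name is not None:
        return ("reuse", operand_name)
    # Fallback (also taken for non-Var-like operands): slice out the page
    # at row offset batch_index * rows, as in the existing else branch.
    offset = (batch_index * rows, 0)
    return ("slice", offset, (rows, cols))
```

For example, `select_page("a_tile", (16, 32), 0, 16, 32)` takes the reuse path, while passing `None` for the operand name falls back to a slice instead of failing.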

Development

Successfully merging this pull request may close these issues.

[Pass Bug] FlattenTileNdTo2D inserts Vec/Acc moves for singleton batch_matmul_acc init

2 participants