Fix singleton-batch matmul accumulator layout #1237
zhaozhaozz wants to merge 2 commits into hw-native-sys:main
📝 Walkthrough
Detects singleton-batch loop-carried accumulators used by […]
Code Review
This pull request optimizes the flattening of ND tiles to 2D by ensuring that singleton-batch accumulator initializations are kept in accumulator memory rather than being moved to vector memory. This avoids unnecessary memory round-trips when lowering tile.batch_matmul_acc to tile.matmul_acc. The changes include new utility functions for identifying singleton batch dimensions and collecting matmul operands, as well as updated logic in the transformation pass to handle loop-carried iterators. Feedback is provided regarding the manual IR traversal logic, which should use a visitor for better maintainability, and the duplication of logic between ForStmt and WhileStmt processing.
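To make "singleton batch dimensions" concrete, here is a minimal sketch of the kind of check the summary describes. It is not the PR's `HasSingletonBatchDims`: the name `AllBatchDimsAreOne` is hypothetical, the shape is assumed to be a plain vector of static extents whose last two entries are rows and columns, and treating a 2D shape as "no batch dims" is also an assumption.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the PR's helper. Assumes a fully static shape
// laid out as [batch..., rows, cols].
bool AllBatchDimsAreOne(const std::vector<int64_t>& shape) {
  // Assumption: a tile with fewer than three dims has no batch dims to check.
  if (shape.size() < 3) return false;
  // Every dim before the trailing (rows, cols) pair must have extent 1.
  for (std::size_t i = 0; i + 2 < shape.size(); ++i) {
    if (shape[i] != 1) return false;
  }
  return true;
}
```

Under these assumptions a `[1, 16, 32]` tile qualifies for the batch=1 fast path, while `[2, 16, 32]` does not.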

    void CollectBatchMatmulAccOperands(const std::vector<StmtPtr>& stmts,
                                       std::unordered_set<const Var*>& acc_operands) {
      for (const auto& stmt : stmts) {
        if (auto assign = As<AssignStmt>(stmt)) {
          if (auto call = As<Call>(assign->value_)) {
            if (call->op_ && call->op_->name_ == "tile.batch_matmul_acc" && !call->args_.empty()) {
              if (auto acc = AsVarLike(call->args_[0])) {
                acc_operands.insert(acc.get());
              }
            }
          }
          continue;
        }
        if (auto seq = As<SeqStmts>(stmt)) {
          CollectBatchMatmulAccOperands(seq->stmts_, acc_operands);
          continue;
        }
        if (auto scope = As<ScopeStmt>(stmt)) {
          CollectBatchMatmulAccOperands(FlattenToStmts(scope->body_), acc_operands);
          continue;
        }
        if (auto if_stmt = As<IfStmt>(stmt)) {
          CollectBatchMatmulAccOperands(FlattenToStmts(if_stmt->then_body_), acc_operands);
          if (if_stmt->else_body_.has_value()) {
            CollectBatchMatmulAccOperands(FlattenToStmts(*if_stmt->else_body_), acc_operands);
          }
          continue;
        }
        if (auto for_stmt = As<ForStmt>(stmt)) {
          CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), acc_operands);
          continue;
        }
        if (auto while_stmt = As<WhileStmt>(stmt)) {
          CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), acc_operands);
          continue;
        }
      }
    }

This function manually traverses different statement types to find tile.batch_matmul_acc calls. This logic should be simplified and made more maintainable by using an IRVisitor. Using a visitor ensures that the traversal is recursive, which is necessary to correctly handle nested control flow structures like SeqStmts and ScopeStmt.
References
- Helper functions that traverse IR statements to find specific nodes must be recursive to handle nested control flow structures like SeqStmts and ScopeStmt.
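A visitor-based version of this collection could look roughly like the sketch below. Everything here is a toy: `Stmt`, `ToyVisitor`, `MakeCall`, and `MakeBlock` are hypothetical stand-ins invented for illustration, not the project's `IRVisitor` or IR node types. The point is only that the recursion lives in one place (the base class) and the collector overrides a single hook.

```cpp
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

struct Stmt;
using StmtPtr = std::shared_ptr<Stmt>;

// Toy node: a call assignment (op + first operand name) or a container
// (sequence, scope, if, loop) that only holds child statements.
struct Stmt {
  std::string op;
  std::string first_arg;
  std::vector<StmtPtr> children;
};

// Hypothetical base visitor: owns the recursive walk so subclasses cannot
// forget a nesting construct.
class ToyVisitor {
 public:
  virtual ~ToyVisitor() = default;
  void Visit(const StmtPtr& stmt) {
    if (!stmt) return;
    VisitStmt(*stmt);
    for (const auto& child : stmt->children) Visit(child);
  }

 protected:
  virtual void VisitStmt(const Stmt& stmt) = 0;
};

// Collects the first operand of every tile.batch_matmul_acc call.
class BatchMatmulAccCollector : public ToyVisitor {
 public:
  std::unordered_set<std::string> acc_operands;

 protected:
  void VisitStmt(const Stmt& stmt) override {
    if (stmt.op == "tile.batch_matmul_acc" && !stmt.first_arg.empty()) {
      acc_operands.insert(stmt.first_arg);
    }
  }
};

StmtPtr MakeCall(std::string op, std::string first_arg) {
  auto stmt = std::make_shared<Stmt>();
  stmt->op = std::move(op);
  stmt->first_arg = std::move(first_arg);
  return stmt;
}

StmtPtr MakeBlock(std::vector<StmtPtr> children) {
  auto stmt = std::make_shared<Stmt>();
  stmt->children = std::move(children);
  return stmt;
}

// Usage: the accumulating call sits two container levels below the root.
std::unordered_set<std::string> CollectFromNestedExample() {
  auto inner_loop = MakeBlock({MakeCall("tile.batch_matmul_acc", "acc")});
  auto if_body = MakeBlock({inner_loop, MakeCall("tile.add", "x")});
  auto root = MakeBlock({MakeCall("tile.create", "init"), if_body});
  BatchMatmulAccCollector collector;
  collector.Visit(root);
  return collector.acc_operands;
}
```

In the toy tree, the collector finds `acc` regardless of how deeply the call is nested, without a per-statement-type branch in the collector itself.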

      std::unordered_set<const Var*> batch_matmul_acc_operands;
      CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands);
      for (const auto& ia : for_stmt->iter_args_) {
        CountVarRefs(ia->initValue_);
        auto init_var = As<Var>(ia->initValue_);
        auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
        if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
            batch_matmul_acc_operands.count(ia.get())) {
          singleton_batch_acc_init_vars.insert(init_var.get());
        }
      }
      continue;
    }
    // WhileStmt: count condition and iter_arg init Var refs.
    if (auto while_stmt = As<WhileStmt>(s)) {
      CountVarRefs(while_stmt->condition_);
      for (const auto& ia : while_stmt->iter_args_) CountVarRefs(ia->initValue_);
      std::unordered_set<const Var*> batch_matmul_acc_operands;
      CollectBatchMatmulAccOperands(FlattenToStmts(while_stmt->body_), batch_matmul_acc_operands);
      for (const auto& ia : while_stmt->iter_args_) {
        CountVarRefs(ia->initValue_);
        auto init_var = As<Var>(ia->initValue_);
        auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
        if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
            batch_matmul_acc_operands.count(ia.get())) {
          singleton_batch_acc_init_vars.insert(init_var.get());
        }
      }

Pull request overview
Updates FlattenTileNdTo2D to keep singleton-batch tile.batch_matmul_acc loop accumulator initializers in Mem.Acc, avoiding unnecessary Vec↔Acc moves on the batch=1 fast path and adding a regression test for the loop-carried accumulator pattern.
Changes:
- Add analysis in `FlattenTileNdTo2D` to detect singleton-batch loop accumulator init vars used by `tile.batch_matmul_acc`.
- Force flattened `tile.create`/`tile.full` for those accumulator inits to use `target_memory=Acc`.
- Add a unit test ensuring singleton-batch loop accumulator init flattens to `Acc` with no `tile.move` emitted.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` | Detects singleton-batch loop accumulator init vars and preserves Acc memory to avoid Vec/Acc round-trips. |
| `tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py` | Adds regression coverage for singleton-batch loop-carried accumulator initialization staying in Acc. |

    std::unordered_set<const Var*> batch_matmul_acc_operands;
    CollectBatchMatmulAccOperands(FlattenToStmts(for_stmt->body_), batch_matmul_acc_operands);
    for (const auto& ia : for_stmt->iter_args_) {
      CountVarRefs(ia->initValue_);
      auto init_var = As<Var>(ia->initValue_);
      auto init_tile = init_var ? As<TileType>(init_var->GetType()) : nullptr;
      if (init_var && HasSingletonBatchDims(init_tile, "tile.batch_matmul_acc init") &&
          batch_matmul_acc_operands.count(ia.get())) {
        singleton_batch_acc_init_vars.insert(init_var.get());
      }
    }

The singleton-batch accumulator init detection only runs on ForStmt/WhileStmt nodes that are directly present in the current stmts vector. If the tile.create([1,...]) init is defined in an outer block while the loop is nested under a ScopeStmt/IfStmt/SeqStmts, this set won’t be populated in the outer TransformBody invocation, so the initializer won’t be forced to MemorySpace::Acc and the original Vec/Acc round-trip can remain. Consider performing this analysis via a recursive walk over the full statement tree (similar to CollectBatchMatmulAccOperands) and sharing the result across recursive TransformBody calls.
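One way to picture the suggested whole-tree analysis is the sketch below. It uses a toy `Node` type and string names instead of the pass's real IR classes, and the function names `CollectAccOperands` and `CollectAccInitVars` are hypothetical. The set of singleton-batch variables is passed in as an assumption rather than derived from tile types. Because loops are a single node kind here, the for/while duplication also disappears.

```cpp
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Toy IR node (requires C++17 for a vector of the enclosing type).
struct Node {
  enum class Kind { kCall, kBlock, kLoop };
  Kind kind = Kind::kBlock;
  std::string op;         // kCall: operator name.
  std::string first_arg;  // kCall: name of the first operand.
  // kLoop: (iter_arg name, init variable name); for and while look the same.
  std::vector<std::pair<std::string, std::string>> iter_args;
  std::vector<Node> children;  // Body of a block or loop.
};

using NameSet = std::unordered_set<std::string>;

// Recursively gathers the accumulator operand of every batch_matmul_acc call.
void CollectAccOperands(const Node& node, NameSet& out) {
  if (node.kind == Node::Kind::kCall && node.op == "tile.batch_matmul_acc") {
    out.insert(node.first_arg);
  }
  for (const auto& child : node.children) CollectAccOperands(child, out);
}

// Walks the full tree once, so a loop nested under any block still
// contributes its accumulator init variable to the shared result.
void CollectAccInitVars(const Node& node, const NameSet& singleton_batch_vars,
                        NameSet& out) {
  if (node.kind == Node::Kind::kLoop) {
    NameSet acc_operands;
    CollectAccOperands(node, acc_operands);
    for (const auto& arg : node.iter_args) {
      if (acc_operands.count(arg.first) && singleton_batch_vars.count(arg.second)) {
        out.insert(arg.second);
      }
    }
  }
  for (const auto& child : node.children) {
    CollectAccInitVars(child, singleton_batch_vars, out);
  }
}

// Usage: the loop is nested under an inner block, not directly under root.
NameSet AnalyzeNestedLoopExample() {
  Node call;
  call.kind = Node::Kind::kCall;
  call.op = "tile.batch_matmul_acc";
  call.first_arg = "acc_iter";

  Node loop;
  loop.kind = Node::Kind::kLoop;
  loop.iter_args.emplace_back("acc_iter", "acc_init");
  loop.children.push_back(call);

  Node inner_block;
  inner_block.children.push_back(loop);

  Node root;
  root.children.push_back(inner_block);

  NameSet singleton_batch_vars;
  singleton_batch_vars.insert("acc_init");
  NameSet result;
  CollectAccInitVars(root, singleton_batch_vars, result);
  return result;
}
```

Computing the set once at the root and passing it down by reference would let every recursive transform call consult the same result, which is the sharing the comment asks for.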
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py (1)
1783-1790: ⚡ Quick win

Prefer structural IR assertions over pretty-string matching.

The `str(function)` substring checks are fragile to printer formatting changes. Consider asserting on `ir.Call` op names/kwargs directly for memory-space expectations.

♻️ Suggested direction

    - ir_str = str(after.get_function("main_incore_0"))
    -
    - assert "tile.batch_matmul" not in ir_str
    - assert "tile.batch_matmul_acc" not in ir_str
    - assert "tile.slice" not in ir_str
    - assert "tile.move" not in ir_str
    - assert "target_memory=pl.Mem.Acc" in ir_str
    - assert "pl.Mem.Acc, pl.TileView(blayout=pl.TileLayout.row_major" not in ir_str
    + fn = after.get_function("main_incore_0")
    + assert fn is not None
    + body = cast(ir.SeqStmts, fn.body)
    + calls = [
    +     s.value
    +     for s in body.stmts
    +     if isinstance(s, ir.AssignStmt) and isinstance(s.value, ir.Call)
    + ]
    + names = [c.op.name for c in calls]
    + assert "tile.batch_matmul" not in names
    + assert "tile.batch_matmul_acc" not in names
    + assert "tile.slice" not in names
    + assert "tile.move" not in names
    + create_calls = [c for c in calls if c.op.name == "tile.create"]
    + assert any(c.kwargs.get("target_memory") == pl.MemorySpace.Acc for c in create_calls)

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py` around lines 1783 - 1790, Replace fragile pretty-string assertions in the test that call str(after.get_function("main_incore_0")) with structural checks over the IR Call nodes: iterate the calls from after.get_function("main_incore_0") (e.g., inspect its body or use the helper that yields ir.Call nodes), assert no call has op name "tile.batch_matmul", "tile.batch_matmul_acc", "tile.slice", or "tile.move", and assert that at least one Call has kwargs["target_memory"] == pl.Mem.Acc while none have a tile view kwarg like kwargs.get("view") whose layout equals pl.TileLayout.row_major (or similar field used for tile views); use ir.Call op.name and kwargs access rather than substring matching to validate memory-space and tile-view properties.
📒 Files selected for processing (4)
- `src/backend/common/pto_ops_common.cpp`
- `src/codegen/pto/pto_type_utils.cpp`
- `src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp`
- `tests/ut/ir/transforms/test_flatten_tile_nd_to_2d.py`

    auto flat_rows = As<ConstInt>(operand_type->shape_[0]);
    auto flat_cols = As<ConstInt>(operand_type->shape_[1]);
    if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
        flat_cols->value_ == source_cols) {
      // Singleton-batch operands are already the exact 2D page. Avoid a full-tile
      // slice, which would lower to an unsupported Mat->Mat tmov on a2a3.
      current = AsVarLike(operand);
      CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like";
    } else {
      auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
      auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
      auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
      current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
      page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
    }

Avoid hard-failing when the 2D singleton-page operand is not Var-like
Lines 698–699 assume the operand is always a Var/IterArg and CHECK otherwise. This can abort the pass for valid IR where the operand is an inline expression. Please fall back to the existing tile.slice branch when AsVarLike(operand) is null, instead of crashing.
Suggested fix

    - if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
    -     flat_cols->value_ == source_cols) {
    + if (batch_index == 0 && flat_rows && flat_cols && flat_rows->value_ == source_rows &&
    +     flat_cols->value_ == source_cols) {
        // Singleton-batch operands are already the exact 2D page. Avoid a full-tile
        // slice, which would lower to an unsupported Mat->Mat tmov on a2a3.
    -   current = AsVarLike(operand);
    -   CHECK(current) << "FlattenTileNdTo2D: expected 2D batch_matmul operand to be Var-like";
    +   if (auto operand_var = AsVarLike(operand)) {
    +     current = operand_var;
    +   } else {
    +     auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
    +     auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
    +     auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
    +     current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
    +     page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
    +   }
      } else {
        auto offset = MakeShapeTupleFromInts({batch_index * source_rows, 0}, span);
        auto shape = MakeShapeTupleFromInts({source_rows, source_cols}, span);
        auto slice = op_registry.Create("tile.slice", {operand, shape, offset}, span);
        current = std::make_shared<Var>(base_name + "_slice_" + suffix, slice->GetType(), span);
        page.stmts.push_back(std::make_shared<AssignStmt>(current, slice, span));
      }

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp` around lines 692 - 706, The
pass currently CHECKs that AsVarLike(operand) is non-null for the 2D
singleton-page fast path (in FlattenTileNdTo2D) which aborts on valid inline
expressions; change the logic to try AsVarLike(operand) and if it returns null
fall back to creating the same tile.slice + Var + AssignStmt sequence used in
the else branch (use the same offset/shape and page.stmts push) instead of
calling CHECK, so inline expressions are handled by emitting a slice-backed
Var-like temporary rather than crashing.
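The fallback described above repeats the slice-emission sequence in two branches. A rough sketch of the same control flow with that sequence factored into one helper follows. All types and names (`Expr`, `Page`, `AsVarLikeToy`, `EmitSliceToy`, `SelectPageOperand`) are hypothetical toys, with statements modeled as strings. Note that the original code comment says a full-tile slice lowers to an unsupported Mat->Mat tmov on a2a3, so whether the fallback is acceptable on that target is left open here.

```cpp
#include <memory>
#include <string>
#include <vector>

// Toy expression: either a named variable or an inline expression.
struct Expr {
  std::string text;
  bool is_var = false;
};
using ExprPtr = std::shared_ptr<Expr>;

// Toy stand-in for AsVarLike: returns the operand only if it is a variable.
ExprPtr AsVarLikeToy(const ExprPtr& expr) {
  return (expr && expr->is_var) ? expr : nullptr;
}

// Statements emitted for one 2D page, modeled as plain strings.
struct Page {
  std::vector<std::string> stmts;
};

// Emits "<name> = tile.slice(<operand>)" and returns the new temporary.
ExprPtr EmitSliceToy(const ExprPtr& operand, const std::string& name, Page& page) {
  auto temp = std::make_shared<Expr>();
  temp->text = name;
  temp->is_var = true;
  page.stmts.push_back(name + " = tile.slice(" + operand->text + ")");
  return temp;
}

// Fast path: reuse a Var-like operand that is already the exact 2D page.
// Otherwise (not an exact page, or not Var-like) fall back to a slice.
ExprPtr SelectPageOperand(const ExprPtr& operand, bool is_exact_2d_page,
                          const std::string& slice_name, Page& page) {
  if (is_exact_2d_page) {
    if (auto var = AsVarLikeToy(operand)) return var;
  }
  return EmitSliceToy(operand, slice_name, page);
}
```

With this shape, an inline-expression operand produces one slice statement instead of aborting, and a Var-like exact page emits nothing.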
Summary
Validation
Fixes #1235