diff --git a/docs/designs/ptoas-auto-sync-design.md b/docs/designs/ptoas-auto-sync-design.md index ffab92f70..b8f5c0a2d 100644 --- a/docs/designs/ptoas-auto-sync-design.md +++ b/docs/designs/ptoas-auto-sync-design.md @@ -113,10 +113,11 @@ - `eventIds`(`SmallVector`)+ `eventIdNum`:分配后填入的 event id 列表, 长度 = `eventIdNum`(多缓冲场景大于 1)。 - `depRootBuffers`:构造该同步对的依赖链所涉及的 root buffer 集合, - 用于 `RemoveRedundantSync` 中的 `hasSameSyncDepRoots` 判定,避免误删 - 「pipe 对相同但语义无关」的另一组同步。 + 用于分配/widening 阶段的启发式和调试信息;`RemoveRedundantSync` 删除 + set/wait 冗余时按 pipe pair 语义判断,不把 root buffer 相等作为必要条件。 - `uselessSync`:`RemoveRedundantSync` 命中后置为 true,最终从 `pipeBefore/After` 中移除。 -- `isCompensation`:是否为 loop 回边补偿同步(头/尾 comp_set / comp_wait)。 +- `isCompensation`:预留给分析阶段提前生成的 synthetic compensation sync。 + 当前 loop 回边的 head/tail 配对同步由 `SyncEventIdAllocation` 在冗余删除之后生成。 - `lowestCommonAncestorBuffer`、`reuseCntForWiden`、`reallocatedLoopHeadTailSync`: 分配阶段的辅助状态(widen/reallocate 用)。 @@ -266,13 +267,14 @@ MLIR op,最终生成 `pto::SetFlagOp` / `pto::WaitFlagOp` / `pto::BarrierOp` - 主要逻辑: 1. 收集所有 `syncOperations_[k].size()==2` 的 set/wait 对,按 `forEndIndex` / `kSyncIndex` 排序(内层优先)。 - 2. 对每对 `(setFlag, waitFlag)`:跳过多缓冲(`eventIdNum != 1`)、跳过补偿同步、 - 跳过 `depRootBuffers` 不一致的对,再走 `CheckAllSync`。 + 2. 对每对 `(setFlag, waitFlag)`:跳过 `isCompensation` 标记的预生成补偿同步, + 再走 `CheckAllSync`。 3. `CheckAllSync` → `CheckRepeatSync` 在 `[setIRIndex, waitIRIndex]` 区间扫描 `pipeBefore/pipeAfter`,遇到分支调 `CheckBranchBetween`、遇到循环调 `CheckLoopBetween`。`CanMatchedSync` 用一个 `SmallVector` 充当 `syncFinder` 状态机:先看见相同 `kSyncIndex` 之外的 set,置位;再看见对应 - wait 时确认配对,认为外层同步对被覆盖。 + wait 时确认配对;只要内部完整同步对的 pipe pair 相同且 `eventIdNum` + 不大于外部同步对,就认为外层同步对被覆盖。 4. 
命中后置 `uselessSync=true`,并通过 `InstanceElement::RemoveSync` 从挂载队列摘除。 - 输出:被标记为冗余的 `SyncOperation` 仍保留在 `syncOperations_` 中 (便于打印/调试),但不再出现在任何 `SyncIR` 节点的 `pipeBefore/After`。 diff --git a/include/PTO/Transforms/InsertSync/SyncCommon.h b/include/PTO/Transforms/InsertSync/SyncCommon.h index dc1acd225..09aa4dc9d 100644 --- a/include/PTO/Transforms/InsertSync/SyncCommon.h +++ b/include/PTO/Transforms/InsertSync/SyncCommon.h @@ -154,8 +154,8 @@ class SyncOperation { public: SmallVector eventIds; // Root buffers participating in the dependency that created this sync pair. - // Used by redundant-sync pruning to avoid removing syncs from unrelated - // producer/consumer chains that happen to share the same pipe pair. + // These are kept for allocation/widening heuristics and debug printing; + // set/wait redundancy pruning is based on the pipe pair semantics. SmallVector depRootBuffers; bool uselessSync{false}; int eventIdNum{1}; diff --git a/lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp b/lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp index 7ba44ad59..fe93c6bf2 100644 --- a/lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp +++ b/lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp @@ -60,18 +60,9 @@ void RemoveRedundantSync::Run() { // 3. 逐个检查并移除冗余 for (auto [setFlag, waitFlag] : syncOps) { - // Conservative mode: - // 1) keep multi-buffer and compensation syncs - // 2) only prune syncs that carry concrete dependency signatures - if (setFlag->eventIdNum != 1 || waitFlag->eventIdNum != 1) { - continue; - } if (setFlag->isCompensation || waitFlag->isCompensation) { continue; } - if (!hasSameSyncDepRoots(setFlag, waitFlag)) { - continue; - } bool useless = CheckAllSync(setFlag, waitFlag); if (useless) { @@ -225,12 +216,9 @@ bool RemoveRedundantSync::CheckLoopBetween(LoopInstanceElement *loopElement, bool RemoveRedundantSync::CanMatchedSync(SmallVector &syncFinder, SyncOperation *relatedSync, SyncOperation *setFlag) { - // 1. 
过滤不相关的同步 - // - 类型必须匹配 (Wait/Set) - // - 不能是自己 (Index 相同) - // - Pipe 必须完全一致 (Src->Dst) - // - EventIdNum: 内部的同步能力必须强于外部 (related.eventIdNum >= set.eventIdNum ???) - // 这里暂时假设 Single Buffer (eventIdNum=1) 场景即可覆盖主流程 + // Set/wait flags serialize a pipe pair, not a particular root buffer. A + // complete inner pair on the same pipe pair can cover an outer pair even when + // the memory dependency roots differ. bool isWait = (relatedSync->GetType() == SyncOperation::TYPE::WAIT_EVENT); bool isSet = (relatedSync->GetType() == SyncOperation::TYPE::SET_EVENT); @@ -243,11 +231,8 @@ bool RemoveRedundantSync::CanMatchedSync(SmallVector &syncFinder, if (!isWait && !isSet) return false; if (relatedSync->GetSyncIndex() == setFlag->GetSyncIndex()) return false; - if (relatedSync->eventIdNum != setFlag->eventIdNum) return false; - if (relatedSync->GetForEndIndex() != setFlag->GetForEndIndex()) return false; + if (relatedSync->eventIdNum > setFlag->eventIdNum) return false; if (relatedSync->isCompensation || setFlag->isCompensation) return false; - if (!hasSameSyncDepRoots(relatedSync, setFlag)) - return false; // Pipe 检查:内部同步必须也是解决同样的 Src -> Dst 依赖 if (relatedSync->GetSrcPipe() != setFlag->GetSrcPipe()) return false; diff --git a/test/lit/pto/issue226_remove_redundant_pipe_pair.pto b/test/lit/pto/issue226_remove_redundant_pipe_pair.pto new file mode 100644 index 000000000..7b8f292f1 --- /dev/null +++ b/test/lit/pto/issue226_remove_redundant_pipe_pair.pto @@ -0,0 +1,56 @@ +// RUN: ptoas --pto-arch=a3 --enable-insert-sync %s | FileCheck %s +// +// Regression guard for issue #226: RemoveRedundantSync should prune an outer +// set/wait pair when the interval contains a complete inner set/wait pair on +// the same pipe pair, even if those syncs came from different dep roots. 
+// +// CHECK-LABEL: __global__ AICORE void remove_redundant_pipe_pair +// CHECK: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[M2MTE1:[0-9]+]]); +// CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[M2MTE1]]); +// CHECK-NOT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[M2MTE1]]); +// CHECK: set_flag(PIPE_MTE1, PIPE_M, EVENT_ID[[MTE12M:[0-9]+]]); +// CHECK: wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID[[MTE12M]]); +// CHECK-NOT: wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID[[MTE12M]]); +// CHECK: ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + +module { + func.func @remove_redundant_pipe_pair( + %arg0: memref<64x1xf16, strided<[1, 1]>, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %vbuf0 = pto.bind_tile %arg0, %c64, %c1 + {config = #pto.tile_buf_config} + : memref<64x1xf16, strided<[1, 1]>, #pto.address_space> + -> memref<64x1xf16, strided<[1, 1], offset: ?>, #pto.address_space> + + pto.section.cube { + %mat_a = pto.alloc_tile : !pto.tile_buf + %mat_b = pto.alloc_tile : !pto.tile_buf + %left = pto.alloc_tile : !pto.tile_buf + %right = pto.alloc_tile : !pto.tile_buf + %acc = pto.alloc_tile : !pto.tile_buf + + scf.for %i = %c0 to %c2 step %c1 { + pto.tload ins(%vbuf0 : memref<64x1xf16, strided<[1, 1], offset: ?>, #pto.address_space>) + outs(%mat_a : !pto.tile_buf) + pto.tmov ins(%mat_a : !pto.tile_buf) + outs(%left : !pto.tile_buf) + pto.tmov ins(%mat_b : !pto.tile_buf) + outs(%right : !pto.tile_buf) + pto.tmatmul ins(%left, %right : !pto.tile_buf, !pto.tile_buf) + outs(%acc : !pto.tile_buf) + %is_first = arith.cmpi eq, %i, %c0 : index + scf.if %is_first { + pto.tmov ins(%acc : !pto.tile_buf) + outs(%mat_a : !pto.tile_buf) + } else { + pto.tmov ins(%acc : !pto.tile_buf) + outs(%mat_b : !pto.tile_buf) + } + } + } + return + } +} diff --git a/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto b/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto index 
f7367d179..ca02e3840 100644 --- a/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto +++ b/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto @@ -8,22 +8,21 @@ // CHECK-LABEL: AICORE void scope3_incore_0_aic( // CHECK: TMATMUL_ACC( // CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY:[0-9]+]]); -// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE0:[0-9]+]]); -// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE1:[0-9]+]]); +// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY]]); // CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0:[0-9]+]]); // CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1:[0-9]+]]); // CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD2:[0-9]+]]); // CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD3:[0-9]+]]); // CHECK-NEXT: for (size_t +// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1]]); +// CHECK-NEXT: TLOAD( // CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD2]]); // CHECK-NEXT: TLOAD( // CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD3]]); // CHECK-NEXT: TLOAD( // CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0]]); // CHECK-NEXT: TLOAD( -// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1]]); -// CHECK-NEXT: TLOAD( // CHECK: set_flag(PIPE_M, PIPE_FIX, EVENT_ID[[PUSH:[0-9]+]]); // CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[POST:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0]]);