Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions docs/designs/ptoas-auto-sync-design.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@
- `eventIds`(`SmallVector<int>`)+ `eventIdNum`:分配后填入的 event id 列表,
长度 = `eventIdNum`(多缓冲场景大于 1)。
- `depRootBuffers`:构造该同步对的依赖链所涉及的 root buffer 集合,
用于 `RemoveRedundantSync` 中的 `hasSameSyncDepRoots` 判定,避免误删
pipe 对相同但语义无关」的另一组同步
用于分配/widening 阶段的启发式和调试信息;`RemoveRedundantSync` 删除
set/wait 冗余时按 pipe pair 语义判断,不把 root buffer 相等作为必要条件
- `uselessSync`:`RemoveRedundantSync` 命中后置为 true,最终从 `pipeBefore/After` 中移除。
- `isCompensation`:是否为 loop 回边补偿同步(头/尾 comp_set / comp_wait)。
- `isCompensation`:预留给分析阶段提前生成的 synthetic compensation sync。
当前 loop 回边的 head/tail 配对同步由 `SyncEventIdAllocation` 在冗余删除之后生成。
- `lowestCommonAncestorBuffer`、`reuseCntForWiden`、`reallocatedLoopHeadTailSync`:
分配阶段的辅助状态(widen/reallocate 用)。

Expand Down Expand Up @@ -266,13 +267,14 @@ MLIR op,最终生成 `pto::SetFlagOp` / `pto::WaitFlagOp` / `pto::BarrierOp`
- 主要逻辑:
1. 收集所有 `syncOperations_[k].size()==2` 的 set/wait 对,按 `forEndIndex` /
`kSyncIndex` 排序(内层优先)。
2. 对每对 `(setFlag, waitFlag)`:跳过多缓冲(`eventIdNum != 1`)、跳过补偿同步、
跳过 `depRootBuffers` 不一致的对,再走 `CheckAllSync`。
2. 对每对 `(setFlag, waitFlag)`:跳过 `isCompensation` 标记的预生成补偿同步,
再走 `CheckAllSync`。
3. `CheckAllSync` → `CheckRepeatSync` 在 `[setIRIndex, waitIRIndex]` 区间扫描
`pipeBefore/pipeAfter`,遇到分支调 `CheckBranchBetween`、遇到循环调
`CheckLoopBetween`。`CanMatchedSync` 用一个 `SmallVector<bool>` 充当
`syncFinder` 状态机:先看见相同 `kSyncIndex` 之外的 set,置位;再看见对应
wait 时确认配对,认为外层同步对被覆盖。
wait 时确认配对;只要内部完整同步对的 pipe pair 相同且 `eventIdNum`
不大于外部同步对,就认为外层同步对被覆盖。
4. 命中后置 `uselessSync=true`,并通过 `InstanceElement::RemoveSync` 从挂载队列摘除。
- 输出:被标记为冗余的 `SyncOperation` 仍保留在 `syncOperations_` 中
(便于打印/调试),但不再出现在任何 `SyncIR` 节点的 `pipeBefore/After`。
Expand Down
4 changes: 2 additions & 2 deletions include/PTO/Transforms/InsertSync/SyncCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ class SyncOperation {
public:
SmallVector<int> eventIds;
// Root buffers participating in the dependency that created this sync pair.
// Used by redundant-sync pruning to avoid removing syncs from unrelated
// producer/consumer chains that happen to share the same pipe pair.
// These are kept for allocation/widening heuristics and debug printing; set/
// wait redundancy pruning is based on the pipe pair semantics.
SmallVector<Value> depRootBuffers;
bool uselessSync{false};
int eventIdNum{1};
Expand Down
23 changes: 4 additions & 19 deletions lib/PTO/Transforms/InsertSync/RemoveRedundantSync.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,9 @@ void RemoveRedundantSync::Run() {

// 3. 逐个检查并移除冗余
for (auto [setFlag, waitFlag] : syncOps) {
// Conservative mode:
// 1) keep multi-buffer and compensation syncs
// 2) only prune syncs that carry concrete dependency signatures
if (setFlag->eventIdNum != 1 || waitFlag->eventIdNum != 1) {
continue;
}
if (setFlag->isCompensation || waitFlag->isCompensation) {
continue;
}
if (!hasSameSyncDepRoots(setFlag, waitFlag)) {
continue;
}

bool useless = CheckAllSync(setFlag, waitFlag);
if (useless) {
Expand Down Expand Up @@ -225,12 +216,9 @@ bool RemoveRedundantSync::CheckLoopBetween(LoopInstanceElement *loopElement,
bool RemoveRedundantSync::CanMatchedSync(SmallVector<bool> &syncFinder,
SyncOperation *relatedSync,
SyncOperation *setFlag) {
// 1. 过滤不相关的同步
// - 类型必须匹配 (Wait/Set)
// - 不能是自己 (Index 相同)
// - Pipe 必须完全一致 (Src->Dst)
// - EventIdNum: 内部的同步能力必须强于外部 (related.eventIdNum >= set.eventIdNum ???)
// 这里暂时假设 Single Buffer (eventIdNum=1) 场景即可覆盖主流程
// Set/wait flags serialize a pipe pair, not a particular root buffer. A
// complete inner pair on the same pipe pair can cover an outer pair even when
// the memory dependency roots differ.

bool isWait = (relatedSync->GetType() == SyncOperation::TYPE::WAIT_EVENT);
bool isSet = (relatedSync->GetType() == SyncOperation::TYPE::SET_EVENT);
Expand All @@ -243,11 +231,8 @@ bool RemoveRedundantSync::CanMatchedSync(SmallVector<bool> &syncFinder,

if (!isWait && !isSet) return false;
if (relatedSync->GetSyncIndex() == setFlag->GetSyncIndex()) return false;
if (relatedSync->eventIdNum != setFlag->eventIdNum) return false;
if (relatedSync->GetForEndIndex() != setFlag->GetForEndIndex()) return false;
if (relatedSync->eventIdNum > setFlag->eventIdNum) return false;
if (relatedSync->isCompensation || setFlag->isCompensation) return false;
if (!hasSameSyncDepRoots(relatedSync, setFlag))
return false;

// Pipe 检查:内部同步必须也是解决同样的 Src -> Dst 依赖
if (relatedSync->GetSrcPipe() != setFlag->GetSrcPipe()) return false;
Expand Down
56 changes: 56 additions & 0 deletions test/lit/pto/issue226_remove_redundant_pipe_pair.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: ptoas --pto-arch=a3 --enable-insert-sync %s | FileCheck %s
//
// Regression guard for issue #226: RemoveRedundantSync should prune an outer
// set/wait pair when the interval contains a complete inner set/wait pair on
// the same pipe pair, even if those syncs came from different dep roots.
//
// CHECK-LABEL: __global__ AICORE void remove_redundant_pipe_pair
// CHECK: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[M2MTE1:[0-9]+]]);
// CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[M2MTE1]]);
// CHECK-NOT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[M2MTE1]]);
// CHECK: set_flag(PIPE_MTE1, PIPE_M, EVENT_ID[[MTE12M:[0-9]+]]);
// CHECK: wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID[[MTE12M]]);
// CHECK-NOT: wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID[[MTE12M]]);
// CHECK: ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll);

module {
func.func @remove_redundant_pipe_pair(
%arg0: memref<64x1xf16, strided<[1, 1]>, #pto.address_space<vec>>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c64 = arith.constant 64 : index
%vbuf0 = pto.bind_tile %arg0, %c64, %c1
{config = #pto.tile_buf_config<blayout=1 : i32, slayout=2 : i32, s_fractal_size=512, pad=0 : i32>}
: memref<64x1xf16, strided<[1, 1]>, #pto.address_space<vec>>
-> memref<64x1xf16, strided<[1, 1], offset: ?>, #pto.address_space<vec>>

pto.section.cube {
%mat_a = pto.alloc_tile : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>
%mat_b = pto.alloc_tile : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>
%left = pto.alloc_tile : !pto.tile_buf<loc=left, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=row_major, fractal=512, pad=0>
%right = pto.alloc_tile : !pto.tile_buf<loc=right, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=col_major, fractal=512, pad=0>
%acc = pto.alloc_tile : !pto.tile_buf<loc=acc, dtype=f32, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=1024, pad=0>

scf.for %i = %c0 to %c2 step %c1 {
pto.tload ins(%vbuf0 : memref<64x1xf16, strided<[1, 1], offset: ?>, #pto.address_space<vec>>)
outs(%mat_a : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
pto.tmov ins(%mat_a : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
outs(%left : !pto.tile_buf<loc=left, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=row_major, fractal=512, pad=0>)
pto.tmov ins(%mat_b : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
outs(%right : !pto.tile_buf<loc=right, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=col_major, fractal=512, pad=0>)
pto.tmatmul ins(%left, %right : !pto.tile_buf<loc=left, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=row_major, fractal=512, pad=0>, !pto.tile_buf<loc=right, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=row_major, slayout=col_major, fractal=512, pad=0>)
outs(%acc : !pto.tile_buf<loc=acc, dtype=f32, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)
%is_first = arith.cmpi eq, %i, %c0 : index
scf.if %is_first {
pto.tmov ins(%acc : !pto.tile_buf<loc=acc, dtype=f32, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)
outs(%mat_a : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
} else {
pto.tmov ins(%acc : !pto.tile_buf<loc=acc, dtype=f32, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=1024, pad=0>)
outs(%mat_b : !pto.tile_buf<loc=mat, dtype=f16, rows=32, cols=32, v_row=32, v_col=32, blayout=col_major, slayout=row_major, fractal=512, pad=0>)
}
}
}
return
}
}
7 changes: 3 additions & 4 deletions test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,21 @@
// CHECK-LABEL: AICORE void scope3_incore_0_aic(
// CHECK: TMATMUL_ACC(
// CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY:[0-9]+]]);
// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE0:[0-9]+]]);
// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE1:[0-9]+]]);
// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE:[0-9]+]]);
// CHECK-NEXT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY]]);
// CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0:[0-9]+]]);
// CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1:[0-9]+]]);
// CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD2:[0-9]+]]);
// CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD3:[0-9]+]]);
// CHECK-NEXT: for (size_t
// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1]]);
// CHECK-NEXT: TLOAD(
// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD2]]);
// CHECK-NEXT: TLOAD(
// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD3]]);
// CHECK-NEXT: TLOAD(
// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0]]);
// CHECK-NEXT: TLOAD(
// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1]]);
// CHECK-NEXT: TLOAD(
// CHECK: set_flag(PIPE_M, PIPE_FIX, EVENT_ID[[PUSH:[0-9]+]]);
// CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[POST:[0-9]+]]);
// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0]]);
Expand Down
Loading