Skip to content

Commit 0472b56

Browse files
committed
添加缓冲区驻留分析实现及相关测试
1 parent cd56fa2 commit 0472b56

File tree

5 files changed

+337
-1
lines changed

5 files changed

+337
-1
lines changed

mlir/optimization/scheduler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_executable(
1919
lib/LivenessAdapter.cpp
2020
lib/LocalListScheduling.cpp
2121
lib/SimpleLoopInterchange.cpp
22+
lib/BufferResidencyAnalysis.cpp
2223
)
2324

2425
# add_dependencies(lab-scheduler ToyCh6ShapeInferenceInterfaceIncGen

mlir/optimization/scheduler/include/lab/LabPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ std::unique_ptr<Pass> createLabMemrefLifetimePass();
1717
std::unique_ptr<Pass> createLabFusionFeasibilityPass();
1818
std::unique_ptr<Pass> createAsyncLocalSchedulePass();
1919
std::unique_ptr<Pass> createSimpleLoopInterchangePass();
20+
std::unique_ptr<Pass> createResidencyAnalysisPass();
2021

2122
} // namespace mlir

mlir/optimization/scheduler/lab-opt.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ int main(int argc, char **argv) {
4646
[](mlir::OpPassManager &pm) {
4747
pm.addPass(mlir::createLabFusionFeasibilityPass());
4848
});
49-
5049
mlir::PassPipelineRegistration<>(
5150
"lab-async-local-schedule", "Lab Async Local Schedule Pass",
5251
[](mlir::OpPassManager &pm) {
@@ -57,6 +56,12 @@ int main(int argc, char **argv) {
5756
[](mlir::OpPassManager &pm) {
5857
pm.addPass(mlir::createSimpleLoopInterchangePass());
5958
});
59+
mlir::PassPipelineRegistration<>(
60+
"residency-analysis", "Buffer Residency Analysis Pass",
61+
[](mlir::OpPassManager &pm) {
62+
pm.addPass(mlir::createResidencyAnalysisPass());
63+
});
64+
6065
return mlir::asMainReturnCode(
6166
mlir::MlirOptMain(argc, argv, "Lab optimizer\n", registry));
6267
}
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
15+
16+
using namespace mlir;
17+
using namespace mlir::dataflow;
18+
19+
/// Coarse classification of where a buffer lives.
///
/// NOTE(review): nothing in the visible part of this file refers to this enum;
/// the analysis itself works with ResidencyValue::Kind. Confirm whether it is
/// still needed.
enum class ResidencyKind {
  /// Not computed yet; normally represented implicitly by the dataflow
  /// framework rather than being stored explicitly.
  Uninitialized,
  /// Nothing is known about the buffer.
  Unknown,
  /// Main / external memory.
  DDR,
  /// L1 / L2 / shared / local SRAM.
  FastMem,
  /// Reached from different sources that cannot be unified.
  Conflict,
};
26+
27+
struct ResidencyValue {
28+
enum Kind { Unknown, DDR, FastMem, Conflict } kind = Unknown;
29+
30+
static ResidencyValue getPessimisticValueState(MLIRContext *) {
31+
return ResidencyValue{Unknown};
32+
}
33+
34+
static ResidencyValue getPessimisticValueState(Value value) {
35+
if (auto mt = dyn_cast<MemRefType>(value.getType())) {
36+
if (!mt.getMemorySpace())
37+
return ResidencyValue{DDR};
38+
39+
if (auto i = dyn_cast<IntegerAttr>(mt.getMemorySpace())) {
40+
if (i.getInt() == 0)
41+
return ResidencyValue{DDR};
42+
if (i.getInt() == 1)
43+
return ResidencyValue{FastMem};
44+
}
45+
}
46+
return ResidencyValue{Unknown};
47+
}
48+
49+
static ResidencyValue join(const ResidencyValue &lhs,
50+
const ResidencyValue &rhs) {
51+
if (lhs.kind == rhs.kind)
52+
return lhs;
53+
if (lhs.kind == Unknown)
54+
return rhs;
55+
if (rhs.kind == Unknown)
56+
return lhs;
57+
return ResidencyValue{Conflict};
58+
}
59+
60+
bool operator==(const ResidencyValue &rhs) const { return kind == rhs.kind; }
61+
62+
void print(raw_ostream &os) const {
63+
switch (kind) {
64+
case Unknown:
65+
os << "unknown";
66+
break;
67+
case DDR:
68+
os << "ddr";
69+
break;
70+
case FastMem:
71+
os << "fastmem";
72+
break;
73+
case Conflict:
74+
os << "conflict";
75+
break;
76+
}
77+
}
78+
};
79+
80+
using ResidencyLattice = Lattice<ResidencyValue>;
81+
82+
class ResidencyAnalysis
83+
: public SparseForwardDataFlowAnalysis<ResidencyLattice> {
84+
public:
85+
using SparseForwardDataFlowAnalysis<
86+
ResidencyLattice>::SparseForwardDataFlowAnalysis;
87+
88+
LogicalResult visitOperation(Operation *op,
89+
ArrayRef<const ResidencyLattice *> operands,
90+
ArrayRef<ResidencyLattice *> results) override {
91+
92+
// Rule 1: memref.alloc
93+
if (auto alloc = dyn_cast<memref::AllocOp>(op)) {
94+
ResidencyValue v =
95+
ResidencyValue::getPessimisticValueState(alloc.getResult());
96+
auto *lattice = getLatticeElement(alloc.getResult());
97+
propagateIfChanged(lattice, lattice->join(v));
98+
return success();
99+
}
100+
101+
// Rule 2: memref.subview / cast / reinterpret_cast
102+
if (isa<memref::SubViewOp, memref::CastOp, memref::ReinterpretCastOp>(op)) {
103+
if (op->getNumOperands() >= 1 && op->getNumResults() >= 1) {
104+
auto *srcLat = operands.front();
105+
if (srcLat)
106+
join(results.front(), *srcLat);
107+
}
108+
return success();
109+
}
110+
111+
// Rule 3: memref.copy
112+
if (auto copy = dyn_cast<memref::CopyOp>(op)) {
113+
Value target = copy.getTarget();
114+
auto *lattice = getLatticeElement(target);
115+
if (const ResidencyLattice *srcLat = operands.front())
116+
propagateIfChanged(lattice, lattice->join(srcLat->getValue()));
117+
118+
ResidencyValue dst = ResidencyValue::getPessimisticValueState(target);
119+
propagateIfChanged(lattice, lattice->join(dst));
120+
return success();
121+
}
122+
123+
// Rule 4: linalg generic / matmul
124+
if (isa<linalg::LinalgOp>(op)) {
125+
auto linalgOp = cast<linalg::LinalgOp>(op);
126+
for (Value v : linalgOp.getDpsInits()) {
127+
ResidencyValue rv = ResidencyValue::getPessimisticValueState(v);
128+
auto *lattice = getLatticeElement(v);
129+
propagateIfChanged(lattice, lattice->join(rv));
130+
}
131+
return success();
132+
}
133+
134+
if (auto msc = dyn_cast<memref::MemorySpaceCastOp>(op)) {
135+
if (!operands.empty() && operands.front() && !results.empty()) {
136+
join(results.front(), *operands.front());
137+
}
138+
return success();
139+
}
140+
141+
// 默认策略:如果 op 只是透传一个值,可传播第一个操作数状态
142+
if (op->getNumOperands() == 1 && op->getNumResults() == 1) {
143+
if (!operands.empty() && operands.front())
144+
join(results.front(), *operands.front());
145+
return success();
146+
}
147+
148+
setAllToEntryStates(results);
149+
return success();
150+
}
151+
152+
protected:
153+
void setToEntryState(ResidencyLattice *lattice) override {
154+
propagateIfChanged(lattice,
155+
lattice->join(ResidencyValue::getPessimisticValueState(
156+
lattice->getAnchor())));
157+
}
158+
};
159+
160+
struct ResidencyAnalysisPass
161+
: public PassWrapper<ResidencyAnalysisPass, OperationPass<func::FuncOp>> {
162+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResidencyAnalysisPass)
163+
164+
static void printMemrefMemoryInfo(Value value, raw_ostream &os,
165+
DataFlowSolver &solver) {
166+
os << " value=";
167+
value.print(os);
168+
os << " type=";
169+
value.getType().print(os);
170+
171+
auto memrefType = dyn_cast<MemRefType>(value.getType());
172+
if (!memrefType) {
173+
os << " memory_space=<not-memref>\n";
174+
return;
175+
}
176+
177+
os << " memory_space=";
178+
Attribute memorySpace = memrefType.getMemorySpace();
179+
if (!memorySpace) {
180+
os << "default";
181+
} else if (auto intAttr = dyn_cast<IntegerAttr>(memorySpace)) {
182+
os << intAttr.getInt();
183+
} else {
184+
memorySpace.print(os);
185+
}
186+
os << " residency=";
187+
solver.lookupState<ResidencyLattice>(value)->getValue().print(os);
188+
189+
if (Operation *defOp = value.getDefiningOp())
190+
os << " defined_by=" << defOp->getName();
191+
os << "\n";
192+
}
193+
194+
static void printLinalgOperandMemoryInfo(linalg::LinalgOp linalgOp,
195+
DataFlowSolver &solver) {
196+
llvm::errs() << "linalg op: " << linalgOp->getName() << "\n";
197+
llvm::errs() << " ins:\n";
198+
for (Value value : linalgOp.getDpsInputs())
199+
printMemrefMemoryInfo(value, llvm::errs(), solver);
200+
201+
llvm::errs() << " outs:\n";
202+
for (Value value : linalgOp.getDpsInits())
203+
printMemrefMemoryInfo(value, llvm::errs(), solver);
204+
}
205+
206+
static StringRef getResidencyTag(ResidencyValue::Kind kind) {
207+
switch (kind) {
208+
case ResidencyValue::Unknown:
209+
return "unknown";
210+
case ResidencyValue::DDR:
211+
return "ddr";
212+
case ResidencyValue::FastMem:
213+
return "fastmem";
214+
case ResidencyValue::Conflict:
215+
return "conflict";
216+
}
217+
218+
llvm_unreachable("unexpected residency kind");
219+
}
220+
221+
void runOnOperation() override {
222+
func::FuncOp func = getOperation();
223+
DataFlowSolver solver;
224+
solver.load<DeadCodeAnalysis>();
225+
solver.load<ResidencyAnalysis>();
226+
if (failed(solver.initializeAndRun(func))) {
227+
signalPassFailure();
228+
return;
229+
}
230+
231+
func.walk([&](Operation *op) {
232+
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
233+
printLinalgOperandMemoryInfo(linalgOp, solver);
234+
235+
for (Value result : op->getResults()) {
236+
if (!isa<MemRefType>(result.getType()))
237+
continue;
238+
239+
auto *lat = solver.lookupState<ResidencyLattice>(result);
240+
if (!lat)
241+
continue;
242+
243+
op->emitRemark() << "result residency = "
244+
<< getResidencyTag(lat->getValue().kind);
245+
}
246+
});
247+
248+
func.emitRemark() << "=== function args ===";
249+
for (BlockArgument arg : func.getArguments()) {
250+
if (!isa<MemRefType>(arg.getType()))
251+
continue;
252+
253+
if (auto *lat = solver.lookupState<ResidencyLattice>(arg)) {
254+
llvm::errs() << "arg: ";
255+
lat->getValue().print(llvm::errs());
256+
llvm::errs() << "\n";
257+
}
258+
}
259+
260+
func.walk([&](Operation *op) {
261+
for (Value result : op->getResults()) {
262+
if (!isa<MemRefType>(result.getType()))
263+
continue;
264+
265+
if (auto *lat = solver.lookupState<ResidencyLattice>(result)) {
266+
llvm::errs() << "op result @" << op->getName() << " : ";
267+
lat->getValue().print(llvm::errs());
268+
llvm::errs() << "\n";
269+
}
270+
}
271+
});
272+
}
273+
};
274+
275+
namespace mlir {
276+
std::unique_ptr<Pass> createResidencyAnalysisPass() {
277+
return std::make_unique<ResidencyAnalysisPass>();
278+
}
279+
} // namespace mlir
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// A memory-space-1 allocation viewed through a memory-space cast and then a
// subview of the cast result.
func.func @ex1(%A: memref<128x128xf32>) {
  %fast = memref.alloc() : memref<128x128xf32, 1>

  // The cast drops the memory space from the type.
  %as_default = memref.memory_space_cast %fast
      : memref<128x128xf32, 1> to memref<128x128xf32>

  %quadrant = memref.subview %as_default[0, 0][64, 64][1, 1]
      : memref<128x128xf32>
        to memref<64x64xf32, strided<[128, 1], offset: 0>>

  return
}
13+
14+
// arg: ddr
15+
// op result @memref.alloc : fastmem
16+
// op result @memref.memory_space_cast : fastmem
17+
// op result @memref.subview : fastmem
18+
19+
// Two memory-space-1 allocations feed a matmul, one through a memory-space
// cast and one through a subview; the output is a function argument.
func.func @ex4(%A: memref<128x128xf32>,
               %B: memref<128x128xf32>,
               %C: memref<128x128xf32>) {
  %lhs_buf = memref.alloc() : memref<128x128xf32, 1>
  %rhs_buf = memref.alloc() : memref<128x128xf32, 1>

  %lhs = memref.memory_space_cast %lhs_buf
      : memref<128x128xf32, 1> to memref<128x128xf32>
  %rhs = memref.subview %rhs_buf[0, 0][128, 128][1, 1]
      : memref<128x128xf32, 1>
        to memref<128x128xf32, strided<[128, 1], offset: 0>, 1>

  linalg.matmul
      ins(%lhs, %rhs : memref<128x128xf32>,
                       memref<128x128xf32, strided<[128, 1], offset: 0>, 1>)
      outs(%C : memref<128x128xf32>)
  return
}
37+
38+
// arg: ddr
39+
// arg: ddr
40+
// arg: ddr
41+
// op result @memref.alloc : fastmem
42+
// op result @memref.alloc : fastmem
43+
// op result @memref.memory_space_cast : fastmem
44+
// op result @memref.subview : fastmem
45+
// linalg op: linalg.matmul
46+
// ins:
47+
// value=%memspacecast = memref.memory_space_cast %alloc : memref<128x128xf32, 1> to memref<128x128xf32> type=memref<128x128xf32> memory_space=default residency=fastmem defined_by=memref.memory_space_cast
48+
// value=%subview = memref.subview %alloc_0[0, 0] [128, 128] [1, 1] : memref<128x128xf32, 1> to memref<128x128xf32, strided<[128, 1]>, 1> type=memref<128x128xf32, strided<[128, 1]>, 1> memory_space=1 residency=fastmem defined_by=memref.subview
49+
// outs:
50+
// value=<block argument> of type 'memref<128x128xf32>' at index: 2 type=memref<128x128xf32> memory_space=default residency=ddr

0 commit comments

Comments
 (0)