Skip to content

Commit 37800da

Browse files
committed
Create the gpu outline pass
1 parent 7e1f45c commit 37800da

File tree

6 files changed

+386
-10
lines changed

6 files changed

+386
-10
lines changed

mlir/cuda-tile/Toy/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ add_executable(
2828
mlir/LowerToAffineLoops.cpp
2929
mlir/LowerToLLVM.cpp
3030
mlir/ShapeInferencePass.cpp
31-
mlir/ToyCombine.cpp)
31+
mlir/ToyCombine.cpp
32+
mlir/LowerToGpu.cpp
33+
)
3234

3335
add_dependencies(toy-cuda
3436
ToyCudaShapeInferenceInterfaceIncGen

mlir/cuda-tile/Toy/include/toy/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define TOY_PASSES_H
1515

1616
#include <memory>
17+
#include <string>
1718

1819
namespace mlir {
1920
class Pass;
@@ -29,6 +30,8 @@ std::unique_ptr<mlir::Pass> createLowerToAffinePass();
2930
/// well as `Affine` and `Std`, to the LLVM dialect for codegen.
3031
std::unique_ptr<mlir::Pass> createLowerToLLVMPass();
3132

33+
std::unique_ptr<mlir::Pass> createGpuOutlinePass(std::string grid="1,1,1");
34+
3235
} // namespace toy
3336
} // namespace mlir
3437

mlir/cuda-tile/Toy/mlir/Dialect.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ llvm::LogicalResult ReturnOp::verify() {
379379
if (!function)
380380
return emitOpError() << "must be enclosed in a function-like op";
381381

382-
383382
/// ReturnOps can only have a single optional operand.
384383
if (getNumOperands() > 1)
385384
return emitOpError() << "expects at most 1 return operand";
@@ -498,7 +497,7 @@ llvm::LogicalResult MatMulOp::verify() {
498497
//===----------------------------------------------------------------------===//
499498

500499
void LaunchGpuOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
501-
StringRef callee, ArrayRef<mlir::Value> arguments) {
500+
StringRef callee, ArrayRef<mlir::Value> arguments) {
502501
// Generic call always returns an unranked Tensor initially.
503502
state.addTypes(UnrankedTensorType::get(builder.getF32Type()));
504503
state.addOperands(arguments);
@@ -529,21 +528,20 @@ MutableOperandRange LaunchGpuOp::getArgOperandsMutable() {
529528
return getInputsMutable();
530529
}
531530

532-
533531
//===----------------------------------------------------------------------===//
534532
// GPUFuncOp
535533
//===----------------------------------------------------------------------===//
536534

537535
void GPUFuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
538-
llvm::StringRef name, mlir::FunctionType type,
539-
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
536+
llvm::StringRef name, mlir::FunctionType type,
537+
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
540538
// FunctionOpInterface provides a convenient `build` method that will populate
541539
// the state of our GPUFuncOp, and create an entry block.
542540
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
543541
}
544542

545543
mlir::ParseResult GPUFuncOp::parse(mlir::OpAsmParser &parser,
546-
mlir::OperationState &result) {
544+
mlir::OperationState &result) {
547545
// Dispatch to the FunctionOpInterface provided utility method that parses the
548546
// function operation.
549547
auto buildFuncType =
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
#include "mlir/IR/Attributes.h"
2+
#include "mlir/IR/Block.h"
3+
#include "mlir/IR/Builders.h"
4+
#include "mlir/IR/BuiltinOps.h"
5+
#include "mlir/IR/BuiltinTypes.h"
6+
#include "mlir/IR/IRMapping.h"
7+
#include "mlir/IR/Operation.h"
8+
#include "mlir/IR/SymbolTable.h"
9+
#include "mlir/IR/Types.h"
10+
#include "mlir/IR/Value.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Support/LLVM.h"
13+
#include "mlir/Support/TypeID.h"
14+
#include "toy/Dialect.h"
15+
#include "toy/Passes.h"
16+
#include "llvm/ADT/STLExtras.h"
17+
#include "llvm/ADT/SmallPtrSet.h"
18+
#include "llvm/ADT/SmallSet.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
#include "llvm/ADT/StringExtras.h"
21+
#include "llvm/ADT/StringRef.h"
22+
#include "llvm/Support/Casting.h"
23+
#include "llvm/Support/DebugLog.h"
24+
25+
#include <memory>
26+
#include <string>
27+
28+
#define DEBUG_TYPE "toy-gpu-outline"
29+
30+
namespace {
31+
32+
static bool isGpuOperation(mlir::Operation *op,
33+
const llvm::SmallSet<llvm::StringRef, 4> &gpuOps) {
34+
llvm::StringRef opName = op->getName().getStringRef().split('.').second;
35+
return gpuOps.contains(opName);
36+
}
37+
38+
static llvm::SmallVector<int64_t, 3> parseGrid(llvm::StringRef gridStr) {
39+
llvm::SmallVector<int64_t, 3> dims;
40+
llvm::SmallVector<llvm::StringRef, 4> pieces;
41+
gridStr.split(pieces, ',');
42+
for (llvm::StringRef piece : pieces) {
43+
int64_t value = 0;
44+
if (!piece.empty() && llvm::to_integer(piece.trim(), value))
45+
dims.push_back(value);
46+
}
47+
if (dims.size() != 3)
48+
dims = {1, 1, 1};
49+
return dims;
50+
}
51+
52+
struct GpuOutlinePass
53+
: public mlir::PassWrapper<GpuOutlinePass,
54+
mlir::OperationPass<mlir::toy::FuncOp>> {
55+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GpuOutlinePass)
56+
57+
std::string grid{"1,1,1"};
58+
59+
llvm::StringRef getArgument() const override { return "toy-gpu-outline"; }
60+
61+
void initializeOptions(std::string grid) {
62+
this->grid = grid;
63+
}
64+
65+
void runOnOperation() override {
66+
auto func = getOperation();
67+
if (func.getName() != "main")
68+
return;
69+
70+
llvm::SmallSet<llvm::StringRef, 4> gpuOperations = {"matmul", "add", "mul",
71+
"transpose"};
72+
73+
// // Collect GPU-eligible ops in block order for deterministic cloning.
74+
// llvm::SmallDenseSet<mlir::Operation *, 8> gpuOpSet;
75+
// llvm::SmallVector<mlir::Operation *> gpuOps;
76+
77+
// for (mlir::Operation &op : func.front()) {
78+
// if (isGpuOperation(&op, gpuOperations)) {
79+
// gpuOpSet.insert(&op);
80+
// gpuOps.push_back(&op);
81+
// }
82+
// }
83+
84+
// if (gpuOps.empty())
85+
// return;
86+
87+
llvm::SmallVector<int64_t, 3> gridDims = parseGrid(grid);
88+
89+
llvm::SmallVector<llvm::SmallVector<mlir::Operation *>> gpuSubgraphs;
90+
91+
// Find a gpu subgraph like
92+
// [[gpuOps, ...], [gpuOps, ...], ...]
93+
// original sequence:
94+
// [..., non-gpu-op, [gpu-op, gpu-op], non-gpu-op, [gpu-op, ...]]
95+
func.walk([&](mlir::Operation *op) {
96+
if (isGpuOperation(op, gpuOperations)) {
97+
if (gpuSubgraphs.empty()) {
98+
gpuSubgraphs.push_back({op});
99+
} else {
100+
gpuSubgraphs.back().push_back(op);
101+
}
102+
} else {
103+
if (gpuSubgraphs.empty()) {
104+
gpuSubgraphs.push_back({});
105+
} else if (!gpuSubgraphs.back().empty()) {
106+
gpuSubgraphs.push_back({});
107+
}
108+
}
109+
});
110+
111+
if (gpuSubgraphs.empty())
112+
return;
113+
114+
bool allEmpty = llvm::all_of(
115+
gpuSubgraphs, [](const llvm::SmallVector<mlir::Operation *> &sg) {
116+
return sg.empty();
117+
});
118+
119+
if (allEmpty)
120+
return;
121+
122+
if (gpuSubgraphs.back().empty()) {
123+
gpuSubgraphs.pop_back();
124+
}
125+
126+
for (const auto &gpuSubgraph : gpuSubgraphs) {
127+
LDBG() << "----GPU subgraph----\n";
128+
for (const auto &op : gpuSubgraph) {
129+
LDBG() << *op << "\n";
130+
}
131+
LDBG() << "--------------------\n";
132+
}
133+
134+
llvm::SmallVector<std::string> outlinedFuncNames;
135+
llvm::SmallVector<mlir::Operation *> insertPoints;
136+
137+
// the logic to outline each gpu subgraph
138+
// 1. find operands or input for the subgraph (exclude the input inside
139+
// subgraph).
140+
// 2. find results or output for the subgraph (exclude the output inside
141+
// subgraph).
142+
// 3. create a new function with operands as input and results as output.
143+
// 4. insert a LaunchGpuOp to call the outlined function at the insert point
144+
145+
for (const auto &[index, gpuSubgraph] : llvm::enumerate(gpuSubgraphs)) {
146+
if (!gpuSubgraph.empty()) {
147+
LDBG() << "----GPU subgraph----\n";
148+
for (const auto &op : gpuSubgraph) {
149+
LDBG() << *op << "\n";
150+
}
151+
152+
// Identify its operands.
153+
llvm::SmallVector<mlir::Value, 8> Operands;
154+
llvm::SmallPtrSet<mlir::Value, 8> OperandSet;
155+
for (mlir::Operation *op : gpuSubgraph) {
156+
for (mlir::Value operand : op->getOperands()) {
157+
auto *def = operand.getDefiningOp();
158+
if (!def || !isGpuOperation(def, gpuOperations)) {
159+
if (OperandSet.insert(operand).second)
160+
Operands.push_back(operand);
161+
}
162+
}
163+
}
164+
165+
LDBG() << "Operands:\n";
166+
for (mlir::Value &operand : Operands) {
167+
LDBG() << " " << operand << "\n";
168+
}
169+
170+
llvm::SmallVector<mlir::Value, 2> Results;
171+
llvm::SmallPtrSet<mlir::Value, 2> ResultSet;
172+
173+
for (mlir::Operation *op : gpuSubgraph) {
174+
for (mlir::Value result : op->getResults()) {
175+
bool escapes =
176+
llvm::any_of(result.getUsers(), [&](mlir::Operation *user) {
177+
return !isGpuOperation(user, gpuOperations);
178+
});
179+
if (escapes && ResultSet.insert(result).second)
180+
Results.push_back(result);
181+
}
182+
}
183+
184+
LDBG() << "Results:\n";
185+
for (mlir::Value &result : Results) {
186+
LDBG() << " " << result << "\n";
187+
}
188+
189+
if (Results.size() != 1) {
190+
llvm::errs()
191+
<< "Currently only support single result GPU kernel "
192+
<< "Since the toy return op only supports single return value "
193+
<< "Found " << Results.size() << " results\n";
194+
return signalPassFailure();
195+
}
196+
197+
// buid the kernel for each subgraph
198+
llvm::SmallVector<mlir::Type, 8> argTypes;
199+
argTypes.reserve(Operands.size());
200+
for (mlir::Value v : Operands)
201+
argTypes.push_back(v.getType());
202+
203+
llvm::SmallVector<mlir::Type> resultTypes;
204+
resultTypes.reserve(Results.size());
205+
for (mlir::Value v : Results)
206+
resultTypes.push_back(v.getType());
207+
208+
mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
209+
mlir::SymbolTable symbolTable(module);
210+
std::string outline_func_name =
211+
"outlined_gpu_kernel_" + std::to_string(index);
212+
213+
unsigned suffix = 0;
214+
while (symbolTable.lookup(outline_func_name))
215+
outline_func_name =
216+
outline_func_name + "_" + std::to_string(++suffix);
217+
218+
insertPoints.push_back(gpuSubgraph.front());
219+
220+
{
221+
mlir::OpBuilder moduleBuilder(module.getContext());
222+
mlir::OpBuilder::InsertionGuard guard(moduleBuilder);
223+
moduleBuilder.setInsertionPointToEnd(module.getBody());
224+
auto funcType = moduleBuilder.getFunctionType(argTypes, resultTypes);
225+
auto gpuFunc = mlir::toy::GPUFuncOp::create(
226+
moduleBuilder, func.getLoc(), outline_func_name, funcType);
227+
228+
mlir::Block &kernelEntry = gpuFunc.getBody().front();
229+
mlir::OpBuilder kernelBuilder =
230+
mlir::OpBuilder::atBlockEnd(&kernelEntry);
231+
232+
mlir::IRMapping mapping;
233+
for (auto [blockArg, captured] :
234+
llvm::zip(kernelEntry.getArguments(), Operands))
235+
mapping.map(captured, blockArg);
236+
237+
for (mlir::Operation *op : gpuSubgraph) {
238+
kernelBuilder.clone(*op, mapping);
239+
}
240+
llvm::SmallVector<mlir::Value> mappedResults;
241+
mappedResults.reserve(Results.size());
242+
for (mlir::Value res : Results)
243+
mappedResults.push_back(mapping.lookup(res));
244+
mlir::toy::ReturnOp::create(kernelBuilder, func.getLoc(),
245+
mappedResults);
246+
247+
LDBG() << "Created GPU kernel: " << gpuFunc << "\n";
248+
}
249+
250+
outlinedFuncNames.push_back(outline_func_name);
251+
252+
{
253+
mlir::OpBuilder hostBuilder(func.getContext());
254+
mlir::OpBuilder::InsertionGuard guard(hostBuilder);
255+
// Insert the host launch in place of the first outlined op.
256+
hostBuilder.setInsertionPoint(gpuSubgraph.back()->getNextNode());
257+
258+
auto calleeAttr = mlir::SymbolRefAttr::get(
259+
func.getContext(), llvm::StringRef(outline_func_name));
260+
261+
auto gridAttr = hostBuilder.getDenseI64ArrayAttr(gridDims);
262+
263+
auto launch = mlir::toy::LaunchGpuOp::create(
264+
hostBuilder, func.getLoc(), resultTypes, Operands,
265+
{{"callee", calleeAttr}, {"grid", gridAttr}});
266+
267+
for (auto [idx, res] : llvm::enumerate(Results))
268+
res.replaceAllUsesWith(launch.getResult(idx));
269+
270+
for (mlir::Operation *op : llvm::reverse(gpuSubgraph))
271+
op->erase();
272+
LDBG() << "Inserted LaunchGpuOp: " << launch << "\n";
273+
}
274+
LDBG() << "--------------------\n";
275+
}
276+
}
277+
};
278+
};
279+
}; // namespace
280+
281+
namespace mlir::toy {
282+
283+
std::unique_ptr<mlir::Pass> createGpuOutlinePass(std::string grid) {
284+
auto pass = std::make_unique<GpuOutlinePass>();
285+
pass->initializeOptions(grid); // You can change the grid dimensions here
286+
return pass;
287+
};
288+
289+
}; // namespace mlir::toy

0 commit comments

Comments
 (0)