Skip to content

Commit 07c722f

Browse files
committed
Added the make_tensor_view op to the Toy-to-CudaTile lowering
1 parent bdb9a66 commit 07c722f

File tree

2 files changed

+57
-26
lines changed

2 files changed

+57
-26
lines changed

mlir/cuda-tile/Toy/mlir/LowerToCudaTile.cpp

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Support/TypeID.h"
1616
#include "toy/Dialect.h"
1717
#include "toy/Passes.h"
18+
#include "llvm/ADT/ArrayRef.h"
1819
#include "llvm/ADT/STLExtras.h"
1920
#include "llvm/ADT/SmallPtrSet.h"
2021
#include "llvm/ADT/SmallSet.h"
@@ -27,6 +28,7 @@
2728
#include "cuda_tile/Dialect/CudaTile/IR/Dialect.h"
2829
#include "cuda_tile/Dialect/CudaTile/IR/Ops.h"
2930

31+
#include <cstdint>
3032
#include <memory>
3133
#include <string>
3234

@@ -60,7 +62,7 @@ mlir::cuda_tile::ModuleOp createCudaModuleOp(mlir::OpBuilder &builder,
6062
auto cudaTileModuleOp = mlir::cuda_tile::ModuleOp::create(
6163
builder, moduleOp.getLoc(), "cuda_tile_module");
6264

63-
LDBG() << "Created CudaTile Module: \n" << cudaTileModuleOp << "\n";
65+
LDBG() << "Created CudaTile Module: \n" << cudaTileModuleOp;
6466
return cudaTileModuleOp;
6567
}
6668

@@ -70,7 +72,7 @@ void ToyToCudaTileLoweringPass::runOnOperation() {
7072
// Here we would implement the actual lowering logic from Toy GPUFuncOp
7173
// to CudaTile operations. For now, we just log that the pass is running.
7274
// LDBG() << "Running Toy to CudaTile lowering on GPUFuncOp: " << moduleOp
73-
// << "\n";
75+
// ;
7476

7577
mlir::OpBuilder builder(moduleOp.getContext());
7678
// 1. Create new cuda_tile.module Op in the last section.
@@ -86,28 +88,33 @@ void ToyToCudaTileLoweringPass::runOnOperation() {
8688
gfunOp->getAttrOfType<mlir::StringAttr>("sym_name").getValue();
8789
llvm::SmallVector<mlir::Type, 8> newArgTypes;
8890

89-
LDBG() << "Lowering GPU function: " << gfunc_name << "\n";
90-
LDBG() << "Converting input type into cuda tile type" << "\n";
91+
LDBG() << "Lowering GPU function: " << gfunc_name;
92+
LDBG() << "Converting input type into cuda tile type";
93+
94+
llvm::SmallVector<llvm::ArrayRef<int64_t>, 4> inputShapes;
95+
// llvm::SmallVector<llvm::ArrayRef<int64_t>, 4> resultShapes;
9196

9297
for (mlir::Type t : gfunOp.getFunctionType().getInputs()) {
93-
LDBG() << "Original arg type: " << t << "\n";
98+
LDBG() << "Original arg type: " << t;
9499
auto tt = llvm::dyn_cast<mlir::TensorType>(t);
95100
auto elemType = tt.getElementType();
96101
auto ptrElem = mlir::cuda_tile::PointerType::get(elemType);
97102
auto newType = mlir::cuda_tile::TileType::get({}, ptrElem);
98-
LDBG() << "The new arg type for cuda tile: " << newType << "\n";
103+
LDBG() << "The new arg type for cuda tile: " << newType;
99104
newArgTypes.push_back(newType);
105+
inputShapes.push_back(tt.getShape());
100106
}
101107

102-
LDBG() << "Converting result type into cuda tile type" << "\n";
108+
LDBG() << "Converting result type into cuda tile type";
103109
for (mlir::Type t : gfunOp.getFunctionType().getResults()) {
104-
LDBG() << "Original result type: " << t << "\n";
110+
LDBG() << "Original result type: " << t;
105111
auto tt = llvm::dyn_cast<mlir::TensorType>(t);
106112
auto elemType = tt.getElementType();
107113
auto ptrElem = mlir::cuda_tile::PointerType::get(elemType);
108114
auto newType = mlir::cuda_tile::TileType::get({}, ptrElem);
109-
LDBG() << "The new arg type for cuda tile: " << newType << "\n";
115+
LDBG() << "The new arg type for cuda tile: " << newType;
110116
newArgTypes.push_back(newType);
117+
inputShapes.push_back(tt.getShape());
111118
}
112119

113120
auto newFnType = builder.getFunctionType(newArgTypes, {});
@@ -118,9 +125,35 @@ void ToyToCudaTileLoweringPass::runOnOperation() {
118125
/*arg_attrs=*/{}, /*res_attrs=*/{}, {});
119126
auto bb = cudaEntryOp.addEntryBlock();
120127
builder.setInsertionPointToStart(bb);
128+
// 1. create a get_tile_block_id op
129+
auto tileBlockId = mlir::cuda_tile::GetTileBlockIdOp::create(
130+
builder, gfunOp->getLoc(),
131+
{mlir::cuda_tile::TileType::get({}, builder.getI32Type()),
132+
mlir::cuda_tile::TileType::get({}, builder.getI32Type()),
133+
mlir::cuda_tile::TileType::get({}, builder.getI32Type())});
134+
for (auto [idx, arg] : llvm::enumerate(bb->getArguments())) {
135+
// 2. create a make_tensor_view op
136+
auto resultType = builder.getI64ArrayAttr(inputShapes[idx]);
137+
LDBG() << "Argument " << idx << " : " << arg << ", shape: " << resultType;
138+
auto ptrElem = llvm::dyn_cast<mlir::cuda_tile::TileType>(arg.getType())
139+
.getElementType();
140+
auto eleType = llvm::dyn_cast<mlir::cuda_tile::PointerType>(ptrElem)
141+
.getPointeeType();
142+
mlir::cuda_tile::TensorViewType tensorViewType =
143+
mlir::cuda_tile::TensorViewType::get(
144+
builder.getContext(), eleType, inputShapes[idx],
145+
/*strides=*/{inputShapes[idx].back(), 1});
146+
// LDBG() << "Creating TensorViewType: " << tensorViewType;
147+
auto make_tensor_view = mlir::cuda_tile::MakeTensorViewOp::create(
148+
builder, gfunOp->getLoc(), tensorViewType, arg,
149+
/*dynamicShape=*/mlir::ValueRange{},
150+
/*dynamicStrides=*/mlir::ValueRange{});
151+
// LDBG() << "Created MakeTensorViewOp: \n" << make_tensor_view ;
152+
}
153+
121154
auto retOp = mlir::cuda_tile::ReturnOp::create(builder, gfunOp.getLoc());
122155

123-
LDBG() << "Created CudaTile Entry Op: \n" << cudaEntryOp << "\n";
156+
LDBG() << "Created CudaTile Entry Op: \n" << cudaEntryOp;
124157
});
125158
}
126159

mlir/cuda-tile/Toy/mlir/LowerToGpu.cpp

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ struct GpuOutlinePass
5858

5959
llvm::StringRef getArgument() const override { return "toy-gpu-outline"; }
6060

61-
void initializeOptions(std::string grid) {
62-
this->grid = grid;
63-
}
61+
void initializeOptions(std::string grid) { this->grid = grid; }
6462

6563
void runOnOperation() override {
6664
auto func = getOperation();
@@ -124,11 +122,11 @@ struct GpuOutlinePass
124122
}
125123

126124
for (const auto &gpuSubgraph : gpuSubgraphs) {
127-
LDBG() << "----GPU subgraph----\n";
125+
LDBG() << "----GPU subgraph----";
128126
for (const auto &op : gpuSubgraph) {
129-
LDBG() << *op << "\n";
127+
LDBG() << *op;
130128
}
131-
LDBG() << "--------------------\n";
129+
LDBG() << "--------------------";
132130
}
133131

134132
llvm::SmallVector<std::string> outlinedFuncNames;
@@ -144,9 +142,9 @@ struct GpuOutlinePass
144142

145143
for (const auto &[index, gpuSubgraph] : llvm::enumerate(gpuSubgraphs)) {
146144
if (!gpuSubgraph.empty()) {
147-
LDBG() << "----GPU subgraph----\n";
145+
LDBG() << "----GPU subgraph----";
148146
for (const auto &op : gpuSubgraph) {
149-
LDBG() << *op << "\n";
147+
LDBG() << *op;
150148
}
151149

152150
// Identify its operands.
@@ -162,9 +160,9 @@ struct GpuOutlinePass
162160
}
163161
}
164162

165-
LDBG() << "Operands:\n";
163+
LDBG() << "Operands:";
166164
for (mlir::Value &operand : Operands) {
167-
LDBG() << " " << operand << "\n";
165+
LDBG() << " " << operand;
168166
}
169167

170168
llvm::SmallVector<mlir::Value, 2> Results;
@@ -181,16 +179,16 @@ struct GpuOutlinePass
181179
}
182180
}
183181

184-
LDBG() << "Results:\n";
182+
LDBG() << "Results:";
185183
for (mlir::Value &result : Results) {
186-
LDBG() << " " << result << "\n";
184+
LDBG() << " " << result;
187185
}
188186

189187
if (Results.size() != 1) {
190188
llvm::errs()
191189
<< "Currently only support single result GPU kernel "
192190
<< "Since the toy return op only supports single return value "
193-
<< "Found " << Results.size() << " results\n";
191+
<< "Found " << Results.size() << " results";
194192
return signalPassFailure();
195193
}
196194

@@ -244,7 +242,7 @@ struct GpuOutlinePass
244242
mlir::toy::ReturnOp::create(kernelBuilder, func.getLoc(),
245243
mappedResults);
246244

247-
LDBG() << "Created GPU kernel: " << gpuFunc << "\n";
245+
LDBG() << "Created GPU kernel: " << gpuFunc;
248246
}
249247

250248
outlinedFuncNames.push_back(outline_func_name);
@@ -269,9 +267,9 @@ struct GpuOutlinePass
269267

270268
for (mlir::Operation *op : llvm::reverse(gpuSubgraph))
271269
op->erase();
272-
LDBG() << "Inserted LaunchGpuOp: " << launch << "\n";
270+
LDBG() << "Inserted LaunchGpuOp: " << launch;
273271
}
274-
LDBG() << "--------------------\n";
272+
LDBG() << "--------------------";
275273
}
276274
}
277275
};

0 commit comments

Comments
 (0)