1515#include " mlir/Support/TypeID.h"
1616#include " toy/Dialect.h"
1717#include " toy/Passes.h"
18+ #include " llvm/ADT/ArrayRef.h"
1819#include " llvm/ADT/STLExtras.h"
1920#include " llvm/ADT/SmallPtrSet.h"
2021#include " llvm/ADT/SmallSet.h"
2728#include " cuda_tile/Dialect/CudaTile/IR/Dialect.h"
2829#include " cuda_tile/Dialect/CudaTile/IR/Ops.h"
2930
31+ #include < cstdint>
3032#include < memory>
3133#include < string>
3234
@@ -60,7 +62,7 @@ mlir::cuda_tile::ModuleOp createCudaModuleOp(mlir::OpBuilder &builder,
6062 auto cudaTileModuleOp = mlir::cuda_tile::ModuleOp::create (
6163 builder, moduleOp.getLoc (), " cuda_tile_module" );
6264
63- LDBG () << " Created CudaTile Module: \n " << cudaTileModuleOp << " \n " ;
65+ LDBG () << " Created CudaTile Module: \n " << cudaTileModuleOp;
6466 return cudaTileModuleOp;
6567}
6668
@@ -70,7 +72,7 @@ void ToyToCudaTileLoweringPass::runOnOperation() {
7072 // Here we would implement the actual lowering logic from Toy GPUFuncOp
7173 // to CudaTile operations. For now, we just log that the pass is running.
7274 // LDBG() << "Running Toy to CudaTile lowering on GPUFuncOp: " << moduleOp
73- // << "\n" ;
75+ // ;
7476
7577 mlir::OpBuilder builder (moduleOp.getContext ());
7678 // 1. Create new cuda_tile.module Op in the last section.
@@ -86,28 +88,33 @@ void ToyToCudaTileLoweringPass::runOnOperation() {
8688 gfunOp->getAttrOfType <mlir::StringAttr>(" sym_name" ).getValue ();
8789 llvm::SmallVector<mlir::Type, 8 > newArgTypes;
8890
89- LDBG () << " Lowering GPU function: " << gfunc_name << " \n " ;
90- LDBG () << " Converting input type into cuda tile type" << " \n " ;
91+ LDBG () << " Lowering GPU function: " << gfunc_name;
92+ LDBG () << " Converting input type into cuda tile type" ;
93+
94+ llvm::SmallVector<llvm::ArrayRef<int64_t >, 4 > inputShapes;
95+ // llvm::SmallVector<llvm::ArrayRef<int64_t>, 4> resultShapes;
9196
9297 for (mlir::Type t : gfunOp.getFunctionType ().getInputs ()) {
93- LDBG () << " Original arg type: " << t << " \n " ;
98+ LDBG () << " Original arg type: " << t;
9499 auto tt = llvm::dyn_cast<mlir::TensorType>(t);
95100 auto elemType = tt.getElementType ();
96101 auto ptrElem = mlir::cuda_tile::PointerType::get (elemType);
97102 auto newType = mlir::cuda_tile::TileType::get ({}, ptrElem);
98- LDBG () << " The new arg type for cuda tile: " << newType << " \n " ;
103+ LDBG () << " The new arg type for cuda tile: " << newType;
99104 newArgTypes.push_back (newType);
105+ inputShapes.push_back (tt.getShape ());
100106 }
101107
102- LDBG () << " Converting result type into cuda tile type" << " \n " ;
108+ LDBG () << " Converting result type into cuda tile type" ;
103109 for (mlir::Type t : gfunOp.getFunctionType ().getResults ()) {
104- LDBG () << " Original result type: " << t << " \n " ;
110+ LDBG () << " Original result type: " << t;
105111 auto tt = llvm::dyn_cast<mlir::TensorType>(t);
106112 auto elemType = tt.getElementType ();
107113 auto ptrElem = mlir::cuda_tile::PointerType::get (elemType);
108114 auto newType = mlir::cuda_tile::TileType::get ({}, ptrElem);
109- LDBG () << " The new arg type for cuda tile: " << newType << " \n " ;
115+ LDBG () << " The new arg type for cuda tile: " << newType;
110116 newArgTypes.push_back (newType);
117+ inputShapes.push_back (tt.getShape ());
111118 }
112119
113120 auto newFnType = builder.getFunctionType (newArgTypes, {});
@@ -118,9 +125,35 @@ void ToyToCudaTileLoweringPass::runOnOperation() {
118125 /* arg_attrs=*/ {}, /* res_attrs=*/ {}, {});
119126 auto bb = cudaEntryOp.addEntryBlock ();
120127 builder.setInsertionPointToStart (bb);
128+ // 1. create a get_tile_block_id op
129+ auto tileBlockId = mlir::cuda_tile::GetTileBlockIdOp::create (
130+ builder, gfunOp->getLoc (),
131+ {mlir::cuda_tile::TileType::get ({}, builder.getI32Type ()),
132+ mlir::cuda_tile::TileType::get ({}, builder.getI32Type ()),
133+ mlir::cuda_tile::TileType::get ({}, builder.getI32Type ())});
134+ for (auto [idx, arg] : llvm::enumerate (bb->getArguments ())) {
135+ // 2. create a make_tensor_view op
136+ auto resultType = builder.getI64ArrayAttr (inputShapes[idx]);
137+ LDBG () << " Argument " << idx << " : " << arg << " , shape: " << resultType;
138+ auto ptrElem = llvm::dyn_cast<mlir::cuda_tile::TileType>(arg.getType ())
139+ .getElementType ();
140+ auto eleType = llvm::dyn_cast<mlir::cuda_tile::PointerType>(ptrElem)
141+ .getPointeeType ();
142+ mlir::cuda_tile::TensorViewType tensorViewType =
143+ mlir::cuda_tile::TensorViewType::get (
144+ builder.getContext (), eleType, inputShapes[idx],
145+ /* strides=*/ {inputShapes[idx].back (), 1 });
146+ // LDBG() << "Creating TensorViewType: " << tensorViewType;
147+ auto make_tensor_view = mlir::cuda_tile::MakeTensorViewOp::create (
148+ builder, gfunOp->getLoc (), tensorViewType, arg,
149+ /* dynamicShape=*/ mlir::ValueRange{},
150+ /* dynamicStrides=*/ mlir::ValueRange{});
151+ // LDBG() << "Created MakeTensorViewOp: \n" << make_tensor_view ;
152+ }
153+
121154 auto retOp = mlir::cuda_tile::ReturnOp::create (builder, gfunOp.getLoc ());
122155
123- LDBG () << " Created CudaTile Entry Op: \n " << cudaEntryOp << " \n " ;
156+ LDBG () << " Created CudaTile Entry Op: \n " << cudaEntryOp;
124157 });
125158}
126159
0 commit comments