
Commit a6a56d5

Added base scheduler example

1 parent 995cc42 · commit a6a56d5
9 files changed: 219 additions & 1 deletion

mlir/optimization/CMakeLists.txt
Lines changed: 1 addition & 1 deletion

@@ -35,4 +35,4 @@ include(AddLLVM)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fno-rtti")
 # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
 
-add_subdirectory(explore)
+add_subdirectory(scheduler)
Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
# For a better template to copy, see examples/standalone
include_directories(include)
add_subdirectory(include)

set(LLVM_LINK_COMPONENTS Core Support nativecodegen OrcJIT)

# set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
# mlir_tablegen(ToyCombine.inc -gen-rewriters)
# add_public_tablegen_target(ToyCh6CombineIncGen)

add_executable(
  lab-scheduler
  lab-opt.cpp
  lib/OpStatsPass.cpp
)

# add_dependencies(lab-scheduler ToyCh6ShapeInferenceInterfaceIncGen
#   ToyCh6OpsIncGen ToyCh6CombineIncGen)

include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(
  lab-scheduler
  PRIVATE ${dialect_libs}
          ${conversion_libs}
          ${extension_libs}
          MLIRAnalysis
          MLIRBuiltinToLLVMIRTranslation
          MLIRCallInterfaces
          MLIRCastInterfaces
          MLIRExecutionEngine
          MLIRIR
          MLIRLLVMCommonConversion
          MLIRLLVMDialect
          MLIRLLVMToLLVMIRTranslation
          MLIRMemRefDialect
          MLIRParser
          MLIRPass
          MLIRSideEffectInterfaces
          MLIRSupport
          MLIRTargetLLVMIRExport
          MLIRTransforms)
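
Building this presumably requires configuring the tree against an existing LLVM/MLIR build; the paths below are illustrative placeholders, not part of the commit:

    cmake -G Ninja -B build \
        -DMLIR_DIR=<llvm-build>/lib/cmake/mlir \
        -DLLVM_DIR=<llvm-build>/lib/cmake/llvm
    ninja -C build lab-scheduler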
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
add_subdirectory(lab)
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
(a new file consisting of a single blank line)
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
#pragma once

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {

class Pass;

std::unique_ptr<Pass> createLabOpStatsPass();
std::unique_ptr<Pass> createLabBufferStatsPass();
std::unique_ptr<Pass> createLabFusionFeasibilityPass();
std::unique_ptr<Pass> createLabMatmulTilePass();
std::unique_ptr<Pass> createLabPipelinePlanPass();

} // namespace mlir
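
Of the five factories declared here, only createLabOpStatsPass is defined in this commit (in lib/OpStatsPass.cpp below); calling any of the others would fail at link time until they are implemented. A minimal placeholder definition, following the same PassWrapper pattern the commit uses, might look like this (hypothetical sketch, not part of the commit):

    #include "lab/LabPasses.h"
    #include "mlir/Pass/Pass.h"

    using namespace mlir;

    namespace {
    // Placeholder body; the actual buffer analysis is still TODO.
    struct LabBufferStatsPass
        : public PassWrapper<LabBufferStatsPass, OperationPass<>> {
      MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LabBufferStatsPass)
      void runOnOperation() override {}
    };
    } // namespace

    namespace mlir {
    std::unique_ptr<Pass> createLabBufferStatsPass() {
      return std::make_unique<LabBufferStatsPass>();
    }
    } // namespace mlir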
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
#include "lab/LabPasses.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  registry.insert<mlir::func::FuncDialect, mlir::linalg::LinalgDialect,
                  mlir::arith::ArithDialect, mlir::tensor::TensorDialect,
                  mlir::memref::MemRefDialect, mlir::scf::SCFDialect,
                  mlir::affine::AffineDialect>();

  mlir::registerAllPasses();
  mlir::PassPipelineRegistration<>("lab-op-stats", "Lab Op Stats Pass",
                                   [](mlir::OpPassManager &pm) {
                                     pm.addPass(mlir::createLabOpStatsPass());
                                   });

  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "Lab optimizer\n", registry));
}
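
Since PassPipelineRegistration exposes the registered pipeline as a command-line option named after its first argument, the tool should be invokable in the usual mlir-opt style; a plausible run (the input file name here stands in for either of the test files added below):

    lab-scheduler --lab-op-stats input.mlir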
Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
struct LabOpStatsPass
    : public PassWrapper<LabOpStatsPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LabOpStatsPass)

  void runOnOperation() override {
    func::FuncOp func = getOperation();

    func.walk([&](Operation *op) {
      if (auto matmul = dyn_cast<linalg::MatmulOp>(op)) {
        analyzeMatmul(matmul);
      } else if (auto generic = dyn_cast<linalg::GenericOp>(op)) {
        analyzeGeneric(generic);
      }
    });
  }

  static int64_t getElementBytes(Type t) {
    if (auto ft = dyn_cast<FloatType>(t))
      return ft.getWidth() / 8;
    if (auto it = dyn_cast<IntegerType>(t))
      return it.getWidth() / 8;
    return 0;
  }

  void analyzeMatmul(linalg::MatmulOp op) {
    auto aType =
        dyn_cast<ShapedType>(op.getDpsInputOperand(0)->get().getType());
    auto bType =
        dyn_cast<ShapedType>(op.getDpsInputOperand(1)->get().getType());
    auto cType = dyn_cast<ShapedType>(op.getDpsInitOperand(0)->get().getType());

    if (!aType || !bType || !cType || !aType.hasStaticShape() ||
        !bType.hasStaticShape() || !cType.hasStaticShape()) {
      op.emitRemark() << "[lab-op-stats] dynamic shape matmul, skip";
      return;
    }

    int64_t M = aType.getShape()[0];
    int64_t K = aType.getShape()[1];
    int64_t N = bType.getShape()[1];

    int64_t elemBytes = getElementBytes(aType.getElementType());
    if (elemBytes == 0) {
      op.emitRemark() << "[lab-op-stats] unsupported element type";
      return;
    }

    int64_t flops = 2 * M * N * K;
    int64_t aBytes = aType.getNumElements() * elemBytes;
    int64_t bBytes = bType.getNumElements() * elemBytes;
    int64_t cBytes = cType.getNumElements() * elemBytes;
    int64_t totalBytes = aBytes + bBytes + cBytes;

    double intensity = totalBytes > 0 ? static_cast<double>(flops) /
                                            static_cast<double>(totalBytes)
                                      : 0.0;

    op.emitRemark() << "[lab-op-stats] matmul "
                    << "M=" << M << " N=" << N << " K=" << K
                    << " flops=" << flops << " bytes=" << totalBytes
                    << " intensity=" << intensity;
  }

  void analyzeGeneric(linalg::GenericOp op) {
    unsigned numLoops = op.getNumLoops();
    unsigned numParallel = op.getNumParallelLoops();
    unsigned numReduction = numLoops - numParallel;

    op.emitRemark() << "[lab-op-stats] generic "
                    << "loops=" << numLoops << " parallel=" << numParallel
                    << " reduction=" << numReduction;
  }
};
} // namespace

namespace mlir {
std::unique_ptr<Pass> createLabOpStatsPass() {
  return std::make_unique<LabOpStatsPass>();
}
} // namespace mlir
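
As a concrete check of the formulas, applying them by hand to the 128x256 by 256x128 f32 matmul from the test input added later in this commit gives:

    flops      = 2 * M * N * K = 2 * 128 * 128 * 256        = 8,388,608
    totalBytes = (128*256 + 256*128 + 128*128) * 4           = 327,680
    intensity  = 8,388,608 / 327,680                         = 25.6 flops/byte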
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
module {
  func.func @conv_relu(%input: tensor<1x16x32x32xf32>,
                       %filter: tensor<32x16x3x3xf32>,
                       %init: tensor<1x32x30x30xf32>) -> tensor<1x32x30x30xf32> {
    %0 = linalg.conv_2d_nchw_fchw
        ins(%input, %filter : tensor<1x16x32x32xf32>, tensor<32x16x3x3xf32>)
        outs(%init : tensor<1x32x30x30xf32>) -> tensor<1x32x30x30xf32>

    %cst = arith.constant 0.0 : f32
    %1 = linalg.generic
        {indexing_maps = [
           affine_map<(n, c, h, w) -> (n, c, h, w)>,
           affine_map<(n, c, h, w) -> ()>,
           affine_map<(n, c, h, w) -> (n, c, h, w)>
         ],
         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%0, %cst : tensor<1x32x30x30xf32>, f32)
        outs(%init : tensor<1x32x30x30xf32>) {
      ^bb0(%x: f32, %zero: f32, %out: f32):
        %cmp = arith.cmpf oge, %x, %zero : f32
        %sel = arith.select %cmp, %x, %zero : f32
        linalg.yield %sel : f32
    } -> tensor<1x32x30x30xf32>

    return %1 : tensor<1x32x30x30xf32>
  }
}
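
On this file the pass should stay silent for the convolution — linalg.conv_2d_nchw_fchw is a named op that matches neither dyn_cast in the walk — and emit one remark for the ReLU generic, presumably along the lines of:

    remark: [lab-op-stats] generic loops=4 parallel=4 reduction=0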
Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
module {
  func.func @matmul(%A: tensor<128x256xf32>,
                    %B: tensor<256x128xf32>,
                    %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
    %0 = linalg.matmul
        ins(%A, %B : tensor<128x256xf32>, tensor<256x128xf32>)
        outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
    return %0 : tensor<128x128xf32>
  }
}
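
Running the tool over this file (lab-scheduler --lab-op-stats, as sketched above) should produce a single remark on the linalg.matmul, matching the arithmetic worked out earlier; the exact float formatting of the intensity may differ:

    remark: [lab-op-stats] matmul M=128 N=128 K=256 flops=8388608 bytes=327680 intensity=25.6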
