#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/StringRef.h"
5+
6+ using namespace mlir ;
7+
8+ namespace {
9+ struct LabOpStatsPass
10+ : public PassWrapper<LabOpStatsPass, OperationPass<func::FuncOp>> {
11+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (LabOpStatsPass)
12+
13+ void runOnOperation () override {
14+ func::FuncOp func = getOperation ();
15+
16+ func.walk ([&](Operation *op) {
17+ if (auto matmul = dyn_cast<linalg::MatmulOp>(op)) {
18+ analyzeMatmul (matmul);
19+ } else if (auto generic = dyn_cast<linalg::GenericOp>(op)) {
20+ analyzeGeneric (generic);
21+ }
22+ });
23+ }
24+
25+ static int64_t getElementBytes (Type t) {
26+ if (auto ft = dyn_cast<FloatType>(t))
27+ return ft.getWidth () / 8 ;
28+ if (auto it = dyn_cast<IntegerType>(t))
29+ return it.getWidth () / 8 ;
30+ return 0 ;
31+ }
32+
33+ void analyzeMatmul (linalg::MatmulOp op) {
34+ auto aType =
35+ dyn_cast<ShapedType>(op.getDpsInputOperand (0 )->get ().getType ());
36+ auto bType =
37+ dyn_cast<ShapedType>(op.getDpsInputOperand (1 )->get ().getType ());
38+ auto cType = dyn_cast<ShapedType>(op.getDpsInitOperand (0 )->get ().getType ());
39+
40+ if (!aType || !bType || !cType || !aType.hasStaticShape () ||
41+ !bType.hasStaticShape () || !cType.hasStaticShape ()) {
42+ op.emitRemark () << " [lab-op-stats] dynamic shape matmul, skip" ;
43+ return ;
44+ }
45+
46+ int64_t M = aType.getShape ()[0 ];
47+ int64_t K = aType.getShape ()[1 ];
48+ int64_t N = bType.getShape ()[1 ];
49+
50+ int64_t elemBytes = getElementBytes (aType.getElementType ());
51+ if (elemBytes == 0 ) {
52+ op.emitRemark () << " [lab-op-stats] unsupported element type" ;
53+ return ;
54+ }
55+
56+ int64_t flops = 2 * M * N * K;
57+ int64_t aBytes = aType.getNumElements () * elemBytes;
58+ int64_t bBytes = bType.getNumElements () * elemBytes;
59+ int64_t cBytes = cType.getNumElements () * elemBytes;
60+ int64_t totalBytes = aBytes + bBytes + cBytes;
61+
62+ double intensity = totalBytes > 0 ? static_cast <double >(flops) /
63+ static_cast <double >(totalBytes)
64+ : 0.0 ;
65+
66+ op.emitRemark () << " [lab-op-stats] matmul "
67+ << " M=" << M << " N=" << N << " K=" << K
68+ << " flops=" << flops << " bytes=" << totalBytes
69+ << " intensity=" << intensity;
70+ }
71+
72+ void analyzeGeneric (linalg::GenericOp op) {
73+ unsigned numLoops = op.getNumLoops ();
74+ unsigned numParallel = op.getNumParallelLoops ();
75+ unsigned numReduction = numLoops - numParallel;
76+
77+ op.emitRemark () << " [lab-op-stats] generic "
78+ << " loops=" << numLoops << " parallel=" << numParallel
79+ << " reduction=" << numReduction;
80+ }
81+ };
82+ } // namespace
83+
84+ namespace mlir {
85+ std::unique_ptr<Pass> createLabOpStatsPass () {
86+ return std::make_unique<LabOpStatsPass>();
87+ }
88+ } // namespace mlir