22#include " mlir/Dialect/Linalg/IR/Linalg.h"
33#include " mlir/IR/BuiltinTypes.h"
44#include " mlir/Pass/Pass.h"
5+ #include < cstdint>
56
67using namespace mlir ;
78
89namespace {
9- struct LabOpStatsPass
10- : public PassWrapper<LabOpStatsPass, OperationPass<func::FuncOp>> {
11- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (LabOpStatsPass)
1210
13- void runOnOperation () override {
14- func::FuncOp func = getOperation ();
11+ class LabOpStateAnalysis {
12+ public:
/// Statistics collected by the analysis.
///
/// Every field is zero-initialized: the matmul fields and the generic-op
/// fields are written by different analysis routines, so whichever group is
/// never written must still hold a well-defined value when it is read.
struct States {
  // Matmul problem sizes and derived costs (written by analyzeMatmul).
  int64_t M = 0;
  int64_t N = 0;
  int64_t K = 0;
  int64_t flops = 0; // 2 * M * N * K
  int64_t bytes = 0; // total bytes of the operands considered by analyzeMatmul
  int64_t macs = 0;  // M * N * K
  double intensity = 0.0; // 0.0 when it cannot be computed

  // Loop structure of a linalg.generic op (written by analyzeGeneric).
  unsigned numLoops = 0;
  unsigned numParallel = 0;
  unsigned numReduction = 0; // numLoops - numParallel
};
1525
26+ explicit LabOpStateAnalysis (Operation *op) {
27+ auto func = cast<func::FuncOp>(op);
1628 func.walk ([&](Operation *op) {
1729 if (auto matmul = dyn_cast<linalg::MatmulOp>(op)) {
1830 analyzeMatmul (matmul);
@@ -54,6 +66,7 @@ struct LabOpStatsPass
5466 }
5567
5668 int64_t flops = 2 * M * N * K;
69+ int64_t totalMacs = M * N * K;
5770 int64_t aBytes = aType.getNumElements () * elemBytes;
5871 int64_t bBytes = bType.getNumElements () * elemBytes;
5972 int64_t cBytes = cType.getNumElements () * elemBytes;
@@ -63,20 +76,49 @@ struct LabOpStatsPass
6376 static_cast <double >(totalBytes)
6477 : 0.0 ;
6578
66- op.emitRemark () << " [lab-op-stats] matmul "
67- << " M=" << M << " N=" << N << " K=" << K
68- << " flops=" << flops << " bytes=" << totalBytes
69- << " intensity=" << intensity;
79+ states.M = M;
80+ states.N = N;
81+ states.K = K;
82+ states.flops = flops;
83+ states.bytes = totalBytes;
84+ states.macs = totalMacs;
85+ states.intensity = intensity;
7086 }
7187
7288 void analyzeGeneric (linalg::GenericOp op) {
7389 unsigned numLoops = op.getNumLoops ();
7490 unsigned numParallel = op.getNumParallelLoops ();
7591 unsigned numReduction = numLoops - numParallel;
7692
77- op.emitRemark () << " [lab-op-stats] generic "
78- << " loops=" << numLoops << " parallel=" << numParallel
79- << " reduction=" << numReduction;
93+ states.numLoops = numLoops;
94+ states.numParallel = numParallel;
95+ states.numReduction = numReduction;
96+ }
97+
98+ const States &getStates () const { return states; }
99+
100+ private:
101+ States states;
102+ };
103+
104+ struct LabOpStatsPass
105+ : public PassWrapper<LabOpStatsPass, OperationPass<func::FuncOp>> {
106+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (LabOpStatsPass)
107+
108+ StringRef getArgument () const override { return " lab-op-stats" ; }
109+
110+ void runOnOperation () override {
111+ auto func = getOperation ();
112+ auto &analysis = getAnalysis<LabOpStateAnalysis>();
113+ const auto &states = analysis.getStates ();
114+
115+ func.emitRemark () << " [lab-op-stats] M=" << states.M << " N=" << states.N
116+ << " K=" << states.K << " FLOPs=" << states.flops
117+ << " Bytes=" << states.bytes << " MACs=" << states.macs
118+ << " Intensity=" << states.intensity
119+ << " Loops=" << states.numLoops
120+ << " Parallel=" << states.numParallel
121+ << " Reduction=" << states.numReduction ;
80122 }
81123};
82124} // namespace
0 commit comments