Skip to content

Commit e43098e

Browse files
committed
Change to analysis implementation
1 parent a6a56d5 commit e43098e

File tree

1 file changed

+54
-12
lines changed

1 file changed

+54
-12
lines changed

mlir/optimization/scheduler/lib/OpStatsPass.cpp

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,29 @@
22
#include "mlir/Dialect/Linalg/IR/Linalg.h"
33
#include "mlir/IR/BuiltinTypes.h"
44
#include "mlir/Pass/Pass.h"
5+
#include <cstdint>
56

67
using namespace mlir;
78

89
namespace {
9-
struct LabOpStatsPass
10-
: public PassWrapper<LabOpStatsPass, OperationPass<func::FuncOp>> {
11-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LabOpStatsPass)
1210

13-
void runOnOperation() override {
14-
func::FuncOp func = getOperation();
11+
class LabOpStateAnalysis {
12+
public:
13+
struct States {
14+
int64_t M;
15+
int64_t N;
16+
int64_t K;
17+
int64_t flops;
18+
int64_t bytes;
19+
int64_t macs;
20+
double intensity;
21+
unsigned numLoops;
22+
unsigned numParallel;
23+
unsigned numReduction;
24+
};
1525

26+
explicit LabOpStateAnalysis(Operation *op) {
27+
auto func = cast<func::FuncOp>(op);
1628
func.walk([&](Operation *op) {
1729
if (auto matmul = dyn_cast<linalg::MatmulOp>(op)) {
1830
analyzeMatmul(matmul);
@@ -54,6 +66,7 @@ struct LabOpStatsPass
5466
}
5567

5668
int64_t flops = 2 * M * N * K;
69+
int64_t totalMacs = M * N * K;
5770
int64_t aBytes = aType.getNumElements() * elemBytes;
5871
int64_t bBytes = bType.getNumElements() * elemBytes;
5972
int64_t cBytes = cType.getNumElements() * elemBytes;
@@ -63,20 +76,49 @@ struct LabOpStatsPass
6376
static_cast<double>(totalBytes)
6477
: 0.0;
6578

66-
op.emitRemark() << "[lab-op-stats] matmul "
67-
<< "M=" << M << " N=" << N << " K=" << K
68-
<< " flops=" << flops << " bytes=" << totalBytes
69-
<< " intensity=" << intensity;
79+
states.M = M;
80+
states.N = N;
81+
states.K = K;
82+
states.flops = flops;
83+
states.bytes = totalBytes;
84+
states.macs = totalMacs;
85+
states.intensity = intensity;
7086
}
7187

7288
void analyzeGeneric(linalg::GenericOp op) {
7389
unsigned numLoops = op.getNumLoops();
7490
unsigned numParallel = op.getNumParallelLoops();
7591
unsigned numReduction = numLoops - numParallel;
7692

77-
op.emitRemark() << "[lab-op-stats] generic "
78-
<< "loops=" << numLoops << " parallel=" << numParallel
79-
<< " reduction=" << numReduction;
93+
states.numLoops = numLoops;
94+
states.numParallel = numParallel;
95+
states.numReduction = numReduction;
96+
}
97+
98+
const States &getStates() const { return states; }
99+
100+
private:
101+
States states;
102+
};
103+
104+
struct LabOpStatsPass
105+
: public PassWrapper<LabOpStatsPass, OperationPass<func::FuncOp>> {
106+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LabOpStatsPass)
107+
108+
StringRef getArgument() const override { return "lab-op-stats"; }
109+
110+
void runOnOperation() override {
111+
auto func = getOperation();
112+
auto &analysis = getAnalysis<LabOpStateAnalysis>();
113+
const auto &states = analysis.getStates();
114+
115+
func.emitRemark() << "[lab-op-stats] M=" << states.M << " N=" << states.N
116+
<< " K=" << states.K << " FLOPs=" << states.flops
117+
<< " Bytes=" << states.bytes << " MACs=" << states.macs
118+
<< " Intensity=" << states.intensity
119+
<< " Loops=" << states.numLoops
120+
<< " Parallel=" << states.numParallel
121+
<< " Reduction=" << states.numReduction;
80122
}
81123
};
82124
} // namespace

0 commit comments

Comments
 (0)