Summary
The generated GEMM kernel from `ptoas --enable-insert-sync` is consistently much slower than both:
- `torch_npu` `torch.matmul` on the same device
- the hand-written A2/A3 GEMM kernel in `pto-isa/kernels/manual/a2a3/gemm_performance`

This reproduces across several GEMM shapes, not just one. The slowdown is roughly 2.4x-2.6x vs `torch_npu`, which suggests the issue is systematic rather than shape-specific.
The current suspicion is that the auto-inserted synchronization is too conservative for this double-buffered pipeline (TLOAD -> TEXTRACT -> TMATMUL -> TSTORE) and collapses the overlap between stages; a sketch of the overlap that should survive follows.
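For context, this is roughly the overlap pattern the hand-written kernel preserves: per-buffer events instead of global barriers, so the load of tile k+1 can proceed while tile k is still computing. A minimal sketch with stubbed primitives follows; the `Pipe` enum, the `set_flag`/`wait_flag` stubs, and `gemm_k_loop` are placeholders for illustration, not the real ISA API:

```cpp
// Minimal sketch of event-based double buffering (stubs so it compiles standalone).
enum Pipe { PIPE_MTE2, PIPE_MTE1, PIPE_M };
inline void set_flag(Pipe, Pipe, int) {}   // producer signals "buffer ready/free"
inline void wait_flag(Pipe, Pipe, int) {}  // consumer blocks on that signal

void gemm_k_loop(int k_tiles) {             // hypothetical name, for illustration
  // Prime both ping-pong buffers as "free" so the first waits do not block.
  set_flag(PIPE_M, PIPE_MTE2, 0);
  set_flag(PIPE_M, PIPE_MTE2, 1);
  for (int k = 0; k < k_tiles; ++k) {
    int buf = k & 1;                        // ping-pong buffer index
    wait_flag(PIPE_M, PIPE_MTE2, buf);      // wait only until TMATMUL freed *this* buffer
    /* tload  GM -> L1[buf] */
    set_flag(PIPE_MTE2, PIPE_MTE1, buf);

    wait_flag(PIPE_MTE2, PIPE_MTE1, buf);
    /* textract  L1[buf] -> L0[buf] */
    set_flag(PIPE_MTE1, PIPE_M, buf);

    wait_flag(PIPE_MTE1, PIPE_M, buf);
    /* tmatmul on L0[buf]; tload of tile k+1 can already run in the other buffer */
    set_flag(PIPE_M, PIPE_MTE2, buf);       // release the buffer for the next load
  }
}
```

If the auto-inserted sync replaces these per-buffer waits with global barriers, every stage serializes and the second buffer buys nothing.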
Command line
From the repository root, use the existing PTO-DSL GEMM example:
```bash
export PTOAS_ROOT=<path-to-ptoas>
export PTO_ISA_ROOT=$(pwd)
python3 kernels/python/gemm_performance/run_gemm.py \
    --case a2a3_perf_3072 \
    --case a2a3_allgather_gemm \
    --case a2a3_perf_6144 \
    --case a2a3_gemm_ar_aligned \
    --benchmark --torch-npu
```
The relevant `ptoas` step inside the runner is:
```bash
ptoas --enable-insert-sync gemm_performance.pto -o gemm_performance.cpp
```
Reproduction input
Full reproduction input is the generated PTO IR from the existing example:
`kernels/python/gemm_performance/case_builds/a2a3_perf_6144/gemm_performance.pto`
The kernel is produced from:
`kernels/python/gemm_performance/gemm_performance.py`
Excerpt:
```mlir
module {
  func.func @GemmPerformance(%arg0: !pto.ptr<f32>, %arg1: !pto.ptr<f16>, %arg2: !pto.ptr<f16>) {
    pto.section.cube {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %c2 = arith.constant 2 : index
      %c6144 = arith.constant 6144 : index
      %c1536 = arith.constant 1536 : index
      %c1024 = arith.constant 1024 : index
      %c128 = arith.constant 128 : index
      %c64 = arith.constant 64 : index
      %c256 = arith.constant 256 : index
      %c4 = arith.constant 4 : index
      %0 = arith.muli %c64, %c4 : index
      %1 = pto.get_block_idx
      %2 = arith.index_cast %1 : i64 to index
      ...
      scf.for %arg5 = %c0 to %8 step %c1 {
        ...
        scf.if %37 {
          ...
          pto.tload ...
          pto.tload ...
        }
        ...
        pto.textract ...
        pto.textract ...
        ...
        pto.tmatmul ...
        ...
        pto.tstore ...
      }
    }
    return
  }
}
```
Expected performance
The generated kernel was expected to land much closer to the hand-written A2/A3 kernel and to `torch_npu` for the same tiling strategy.
For the 6144x6144x6144 case specifically:
- `torch_npu` on the same device is about 1.41 ms / 329.5 TFLOPS
- the manual A2/A3 GEMM reference in `kernels/manual/a2a3/gemm_performance/README.md` reports about 1.506 ms

So a generated kernel in the same general range (for example, within ~20-30% of the manual version) would be expected. Instead it is more than 2.5x slower.
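As a sanity check, the quoted figures are internally consistent with the standard 2MNK FLOP count for GEMM:

$$\frac{2 \cdot 6144^3\ \text{FLOP}}{1.41\ \text{ms}} \approx \frac{4.64 \times 10^{11}}{1.41 \times 10^{-3}\ \text{s}} \approx 329\ \text{TFLOPS}$$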
Actual performance
Measured with the command above on an A2/A3 environment:
```text
[3072x3072x3072] avg=0.447 ms, 129.71 TFLOPS
[3072x3072x3072] torch_npu avg=0.182 ms, 319.27 TFLOPS
speedup vs torch_npu: 0.41x

[2048x2048x1024] avg=0.104 ms, 82.45 TFLOPS
[2048x2048x1024] torch_npu avg=0.042 ms, 204.23 TFLOPS
speedup vs torch_npu: 0.40x

[6144x6144x6144] avg=3.836 ms, 120.94 TFLOPS
[6144x6144x6144] torch_npu avg=1.408 ms, 329.52 TFLOPS
speedup vs torch_npu: 0.37x

[5632x6144x1536] avg=0.811 ms, 131.00 TFLOPS
[5632x6144x1536] torch_npu avg=0.315 ms, 337.91 TFLOPS
speedup vs torch_npu: 0.39x
```
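The speedup column is the plain time ratio, so values below 1.0 mean the generated kernel is slower; for the 6144 case:

$$\text{speedup} = \frac{t_{\text{torch\_npu}}}{t_{\text{generated}}} = \frac{1.408\ \text{ms}}{3.836\ \text{ms}} \approx 0.37\times$$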
Observations:
- The generated kernel lands around 120-131 TFLOPS across multiple shapes.
- The gap vs `torch_npu` is stable across shapes (~0.37x-0.41x), which suggests a systematic pipeline/synchronization issue.
- The hand-written A2/A3 GEMM kernel in `pto-isa` is much closer to `torch_npu`, so the tiling strategy itself is probably not the main problem.
Profiling data (optional)
No full flamegraph is attached yet, but inspection of the generated C++ suggests heavy over-synchronization after auto-insert-sync:
- many `wait_flag(...)` calls in the inner K loop
- multiple `pipe_barrier(...)` calls around TLOAD, TEXTRACT, and TMATMUL
- `PIPE_ALL` barriers appear in the generated code

This is in contrast to the hand-written manual kernel, which uses a much smaller synchronization set and only synchronizes at true buffer-reuse boundaries; the schematic below shows the shape of the difference.
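Schematically, the generated inner loop resembles the following (a stubbed placeholder so it compiles standalone, not a verbatim excerpt of the generated file; `oversynced_k_loop` is a hypothetical name):

```cpp
// Over-synchronized inner-loop shape observed in the generated C++ (schematic only).
enum Pipe { PIPE_ALL };
inline void pipe_barrier(Pipe) {}  // stub: drains every pipe

void oversynced_k_loop(int k_tiles) {
  for (int k = 0; k < k_tiles; ++k) {
    pipe_barrier(PIPE_ALL);  // serializes everything still in flight
    /* tload  GM -> L1 */
    pipe_barrier(PIPE_ALL);
    /* textract  L1 -> L0 */
    pipe_barrier(PIPE_ALL);
    /* tmatmul */
  }
}
// Each PIPE_ALL barrier forces TLOAD(k), TEXTRACT(k), and TMATMUL(k) to run
// strictly back-to-back, so the second L1/L0 buffer never hides any latency.
```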
Files used for comparison:
- Generated path: `kernels/python/gemm_performance/case_builds/a2a3_perf_6144/gemm_performance.cpp`
- Manual path: `kernels/manual/a2a3/gemm_performance/gemm_performance_kernel.cpp`
The issue may be that `--enable-insert-sync` is correct but too conservative for this class of L1/L0 double-buffered GEMM pipeline, reducing overlap between:
- GM -> L1 (`TLOAD`)
- L1 -> L0 (`TEXTRACT`)
- Cube compute (`TMATMUL`)
Git commit
PTOAS: `ce40b146e828474f6ee8a9d97b3d4ac3499a5e7b`
Related reproduction tree:
pto-isa: `60c3e04d44abc1417a47501b7eb440e123eb623e`