perf: optimize hipBLASLt grouped GEMM with algo tuning, enable grouped_gemm autotune hipblaslt support #284

Open
kyle-256 wants to merge 23 commits into main from dev/kyle/improve_bf16_triton_gg

@kyle-256
Collaborator

hipBLASLt Grouped GEMM: 4-Way Performance Comparison

BF16, MI355X — Balanced & Unbalanced Routing

Backends: hipBLASLt, Triton, Autotune (best of hipBLASLt + Triton + CK per shape), grouped_gemm (GMM)

DeepSeek-V3 — Balanced

| Layer | B | M | N | K | Op | hipBLASLt | Triton | GMM | Autotune | Autotune/GMM |
|---|---|---|---|---|---|---|---|---|---|---|
| GateUP | 16 | 2048 | 4096 | 7168 | fwd | 992 | 1271 | 1063 | 1279 | 1.203x |
| GateUP | 16 | 2048 | 4096 | 7168 | dgrad | 1091 | 1230 | 1139 | 1231 | 1.081x |
| GateUP | 16 | 2048 | 4096 | 7168 | wgrad | 1256 | 1186 | 1296 | 1248 | 0.963x |
| GateUP | 16 | 4096 | 4096 | 7168 | fwd | 1534 | 1292 | 1507 | 1524 | 1.011x |
| GateUP | 16 | 4096 | 4096 | 7168 | dgrad | 1554 | 1257 | 1526 | 1543 | 1.011x |
| GateUP | 16 | 4096 | 4096 | 7168 | wgrad | 1385 | 1307 | 1413 | 1368 | 0.968x |
| Down | 16 | 2048 | 7168 | 2048 | fwd | 1203 | 1092 | 1219 | 1185 | 0.972x |
| Down | 16 | 2048 | 7168 | 2048 | dgrad | 1208 | 1080 | 1261 | 1203 | 0.954x |
| Down | 16 | 2048 | 7168 | 2048 | wgrad | 1182 | 1136 | 1202 | 1182 | 0.983x |
| Down | 16 | 4096 | 7168 | 2048 | fwd | 1268 | 1139 | 1298 | 1252 | 0.965x |
| Down | 16 | 4096 | 7168 | 2048 | dgrad | 1287 | 1143 | 1317 | 1267 | 0.962x |
| Down | 16 | 4096 | 7168 | 2048 | wgrad | 1320 | 1254 | 1343 | 1302 | 0.969x |
| GateUP | 32 | 2048 | 4096 | 7168 | fwd | 993 | 1269 | 1053 | 1272 | 1.208x |
| GateUP | 32 | 2048 | 4096 | 7168 | dgrad | 1097 | 1238 | 1132 | 1237 | 1.093x |
| GateUP | 32 | 2048 | 4096 | 7168 | wgrad | 1260 | 1186 | 1271 | 1258 | 0.990x |
| GateUP | 32 | 4096 | 4096 | 7168 | fwd | 1516 | 1276 | 1497 | 1529 | 1.021x |
| GateUP | 32 | 4096 | 4096 | 7168 | dgrad | 1553 | 1254 | 1514 | 1551 | 1.024x |
| GateUP | 32 | 4096 | 4096 | 7168 | wgrad | 1387 | 1317 | 1390 | 1382 | 0.994x |
| Down | 32 | 2048 | 7168 | 2048 | fwd | 1210 | 1105 | 1205 | 1194 | 0.991x |
| Down | 32 | 2048 | 7168 | 2048 | dgrad | 1238 | 1080 | 1219 | 1214 | 0.996x |
| Down | 32 | 2048 | 7168 | 2048 | wgrad | 1187 | 1153 | 1203 | 1184 | 0.984x |
| Down | 32 | 4096 | 7168 | 2048 | fwd | 1273 | 1124 | 1277 | 1263 | 0.989x |
| Down | 32 | 4096 | 7168 | 2048 | dgrad | 1293 | 1130 | 1282 | 1286 | 1.003x |
| Down | 32 | 4096 | 7168 | 2048 | wgrad | 1325 | 1260 | 1343 | 1316 | 0.980x |
| avg | | | | | | 1276 | 1199 | 1290 | 1303 | 1.010x |

DeepSeek-V3 — Unbalanced

| Layer | B | M | N | K | Op | hipBLASLt | Triton | GMM | Autotune | Autotune/GMM |
|---|---|---|---|---|---|---|---|---|---|---|
| GateUP | 16 | 2048 | 4096 | 7168 | fwd | 1048 | 1124 | 1127 | 1142 | 1.013x |
| GateUP | 16 | 2048 | 4096 | 7168 | dgrad | 1093 | 1112 | 1166 | 1111 | 0.953x |
| GateUP | 16 | 2048 | 4096 | 7168 | wgrad | 1098 | 1159 | 1145 | 1090 | 0.952x |
| GateUP | 16 | 4096 | 4096 | 7168 | fwd | 1158 | 1214 | 1071 | 1155 | 1.078x |
| GateUP | 16 | 4096 | 4096 | 7168 | dgrad | 1243 | 1197 | 1144 | 1238 | 1.082x |
| GateUP | 16 | 4096 | 4096 | 7168 | wgrad | 1288 | 1291 | 1280 | 1284 | 1.003x |
| Down | 16 | 2048 | 7168 | 2048 | fwd | 969 | 1062 | 960 | 957 | 0.997x |
| Down | 16 | 2048 | 7168 | 2048 | dgrad | 1002 | 1031 | 989 | 991 | 1.002x |
| Down | 16 | 2048 | 7168 | 2048 | wgrad | 1047 | 1107 | 1069 | 1030 | 0.964x |
| Down | 16 | 4096 | 7168 | 2048 | fwd | 1175 | 1092 | 1148 | 1173 | 1.022x |
| Down | 16 | 4096 | 7168 | 2048 | dgrad | 1206 | 1092 | 1178 | 1196 | 1.015x |
| Down | 16 | 4096 | 7168 | 2048 | wgrad | 1230 | 1241 | 1239 | 1231 | 0.994x |
| GateUP | 32 | 2048 | 4096 | 7168 | fwd | 1058 | 1117 | 1037 | 1113 | 1.073x |
| GateUP | 32 | 2048 | 4096 | 7168 | dgrad | 1078 | 1115 | 1084 | 1114 | 1.028x |
| GateUP | 32 | 2048 | 4096 | 7168 | wgrad | 1115 | 1182 | 1112 | 1113 | 1.001x |
| GateUP | 32 | 4096 | 4096 | 7168 | fwd | 1255 | 1248 | 1178 | 1252 | 1.063x |
| GateUP | 32 | 4096 | 4096 | 7168 | dgrad | 1297 | 1223 | 1238 | 1293 | 1.044x |
| GateUP | 32 | 4096 | 4096 | 7168 | wgrad | 1285 | 1274 | 1283 | 1285 | 1.002x |
| Down | 32 | 2048 | 7168 | 2048 | fwd | 986 | 1033 | 888 | 975 | 1.098x |
| Down | 32 | 2048 | 7168 | 2048 | dgrad | 1031 | 1015 | 961 | 1025 | 1.067x |
| Down | 32 | 2048 | 7168 | 2048 | wgrad | 1051 | 1151 | 1046 | 1052 | 1.006x |
| Down | 32 | 4096 | 7168 | 2048 | fwd | 1142 | 1092 | 1133 | 1132 | 0.999x |
| Down | 32 | 4096 | 7168 | 2048 | dgrad | 1177 | 1094 | 1172 | 1172 | 1.000x |
| Down | 32 | 4096 | 7168 | 2048 | wgrad | 1238 | 1237 | 1250 | 1231 | 0.985x |
| avg | | | | | | 1136 | 1146 | 1121 | 1140 | 1.017x |

gpt_oss_20B — Balanced

| Layer | B | M | N | K | Op | hipBLASLt | Triton | GMM | Autotune | Autotune/GMM |
|---|---|---|---|---|---|---|---|---|---|---|
| GateUP | 4 | 2048 | 5760 | 2880 | fwd | 1066 | 980 | 972 | 1011 | 1.040x |
| GateUP | 4 | 2048 | 5760 | 2880 | dgrad | 1134 | 1023 | 1011 | 1019 | 1.008x |
| GateUP | 4 | 2048 | 5760 | 2880 | wgrad | 796 | 899 | 720 | 894 | 1.242x |
| GateUP | 4 | 4096 | 5760 | 2880 | fwd | 1165 | 1062 | 1114 | 1124 | 1.009x |
| GateUP | 4 | 4096 | 5760 | 2880 | dgrad | 1142 | 1047 | 1091 | 1117 | 1.024x |
| GateUP | 4 | 4096 | 5760 | 2880 | wgrad | 925 | 1048 | 840 | 1039 | 1.237x |
| Down | 4 | 2048 | 2880 | 2880 | fwd | 499 | 810 | 653 | 795 | 1.217x |
| Down | 4 | 2048 | 2880 | 2880 | dgrad | 881 | 733 | 751 | 782 | 1.041x |
| Down | 4 | 2048 | 2880 | 2880 | wgrad | 633 | 699 | 628 | 686 | 1.092x |
| Down | 4 | 4096 | 2880 | 2880 | fwd | 1114 | 1010 | 897 | 1014 | 1.130x |
| Down | 4 | 4096 | 2880 | 2880 | dgrad | 1141 | 1039 | 1009 | 1016 | 1.007x |
| Down | 4 | 4096 | 2880 | 2880 | wgrad | 767 | 860 | 777 | 848 | 1.091x |
| GateUP | 32 | 2048 | 5760 | 2880 | fwd | 1058 | 1087 | 1042 | 1087 | 1.043x |
| GateUP | 32 | 2048 | 5760 | 2880 | dgrad | 1139 | 1112 | 1107 | 1125 | 1.016x |
| GateUP | 32 | 2048 | 5760 | 2880 | wgrad | 812 | 1082 | 828 | 1080 | 1.304x |
| GateUP | 32 | 4096 | 5760 | 2880 | fwd | 1100 | 1152 | 1069 | 1163 | 1.088x |
| GateUP | 32 | 4096 | 5760 | 2880 | dgrad | 1180 | 1148 | 1147 | 1175 | 1.024x |
| GateUP | 32 | 4096 | 5760 | 2880 | wgrad | 931 | 1145 | 872 | 1146 | 1.314x |
| Down | 32 | 2048 | 2880 | 2880 | fwd | 506 | 838 | 853 | 835 | 0.979x |
| Down | 32 | 2048 | 2880 | 2880 | dgrad | 846 | 1027 | 973 | 1026 | 1.054x |
| Down | 32 | 2048 | 2880 | 2880 | wgrad | 687 | 1025 | 782 | 1007 | 1.288x |
| Down | 32 | 4096 | 2880 | 2880 | fwd | 1074 | 1067 | 1087 | 1068 | 0.983x |
| Down | 32 | 4096 | 2880 | 2880 | dgrad | 1131 | 1071 | 1112 | 1138 | 1.023x |
| Down | 32 | 4096 | 2880 | 2880 | wgrad | 809 | 1106 | 838 | 1104 | 1.317x |
| avg | | | | | | 939 | 1003 | 924 | 1012 | 1.096x |

gpt_oss_20B — Unbalanced

| Layer | B | M | N | K | Op | hipBLASLt | Triton | GMM | Autotune | Autotune/GMM |
|---|---|---|---|---|---|---|---|---|---|---|
| GateUP | 4 | 2048 | 5760 | 2880 | fwd | 976 | 858 | 858 | 847 | 0.987x |
| GateUP | 4 | 2048 | 5760 | 2880 | dgrad | 1021 | 863 | 879 | 864 | 0.983x |
| GateUP | 4 | 2048 | 5760 | 2880 | wgrad | 720 | 869 | 694 | 861 | 1.241x |
| GateUP | 4 | 4096 | 5760 | 2880 | fwd | 1075 | 1056 | 1004 | 1036 | 1.032x |
| GateUP | 4 | 4096 | 5760 | 2880 | dgrad | 1094 | 1045 | 1043 | 1060 | 1.016x |
| GateUP | 4 | 4096 | 5760 | 2880 | wgrad | 898 | 1034 | 823 | 1026 | 1.247x |
| Down | 4 | 2048 | 2880 | 2880 | fwd | 738 | 672 | 572 | 660 | 1.154x |
| Down | 4 | 2048 | 2880 | 2880 | dgrad | 816 | 735 | 679 | 729 | 1.074x |
| Down | 4 | 2048 | 2880 | 2880 | wgrad | 677 | 656 | 576 | 644 | 1.118x |
| Down | 4 | 4096 | 2880 | 2880 | fwd | 916 | 837 | 718 | 825 | 1.149x |
| Down | 4 | 4096 | 2880 | 2880 | dgrad | 995 | 870 | 883 | 864 | 0.978x |
| Down | 4 | 4096 | 2880 | 2880 | wgrad | 800 | 776 | 637 | 768 | 1.206x |
| GateUP | 32 | 2048 | 5760 | 2880 | fwd | 919 | 1025 | 876 | 1018 | 1.162x |
| GateUP | 32 | 2048 | 5760 | 2880 | dgrad | 988 | 1025 | 955 | 983 | 1.029x |
| GateUP | 32 | 2048 | 5760 | 2880 | wgrad | 772 | 1061 | 791 | 1053 | 1.331x |
| GateUP | 32 | 4096 | 5760 | 2880 | fwd | 1068 | 1109 | 1024 | 1112 | 1.086x |
| GateUP | 32 | 4096 | 5760 | 2880 | dgrad | 1140 | 1094 | 1079 | 1131 | 1.048x |
| GateUP | 32 | 4096 | 5760 | 2880 | wgrad | 904 | 1116 | 833 | 1113 | 1.336x |
| Down | 32 | 2048 | 2880 | 2880 | fwd | 744 | 793 | 724 | 790 | 1.091x |
| Down | 32 | 2048 | 2880 | 2880 | dgrad | 821 | 947 | 829 | 943 | 1.138x |
| Down | 32 | 2048 | 2880 | 2880 | wgrad | 733 | 987 | 757 | 999 | 1.320x |
| Down | 32 | 4096 | 2880 | 2880 | fwd | 928 | 1052 | 870 | 1054 | 1.211x |
| Down | 32 | 4096 | 2880 | 2880 | dgrad | 1031 | 1032 | 952 | 1019 | 1.070x |
| Down | 32 | 4096 | 2880 | 2880 | wgrad | 844 | 1079 | 840 | 1063 | 1.265x |
| avg | | | | | | 901 | 941 | 829 | 936 | 1.129x |

Overall Summary

| Scenario | hipBLASLt (avg) | Triton (avg) | GMM (avg) | Autotune (avg) | Autotune/GMM | Autotune wins vs GMM |
|---|---|---|---|---|---|---|
| DeepSeek-V3 bal | 1276 | 1199 | 1290 | 1303 | 1.010x | 9/24 |
| DeepSeek-V3 unbal | 1136 | 1146 | 1121 | 1140 | 1.017x | 17/24 |
| gpt_oss_20B bal | 939 | 1003 | 924 | 1012 | 1.096x | 22/24 |
| gpt_oss_20B unbal | 901 | 941 | 829 | 936 | 1.129x | 21/24 |
| TOTAL | 1063 | 1072 | 1041 | 1098 | 1.055x | 69/96 |
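
As a sanity check, the TOTAL row can be recomputed from the four scenario rows. This is a sketch under two assumptions that the PR does not state explicitly: TOTAL is the unweighted mean of the scenario averages, and the TOTAL ratio is Autotune (avg) divided by GMM (avg). The variable and function names are illustrative.

```python
# Scenario averages copied from the summary table:
# (hipBLASLt, Triton, GMM, Autotune, Autotune wins vs GMM out of 24)
scenarios = {
    "DeepSeek-V3 bal": (1276, 1199, 1290, 1303, 9),
    "DeepSeek-V3 unbal": (1136, 1146, 1121, 1140, 17),
    "gpt_oss_20B bal": (939, 1003, 924, 1012, 22),
    "gpt_oss_20B unbal": (901, 941, 829, 936, 21),
}


def column_mean(idx):
    """Unweighted mean of one backend column across the four scenarios."""
    values = [row[idx] for row in scenarios.values()]
    return sum(values) / len(values)


total = [column_mean(i) for i in range(4)]
total_wins = sum(row[4] for row in scenarios.values())
ratio = total[3] / total[2]  # Autotune (avg) / GMM (avg)

print([round(v) for v in total], f"{ratio:.3f}x", f"{total_wins}/96")
```

Under those assumptions the recomputed values match the TOTAL row as reported.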

Per-op breakdown (all models, all routing):

| Op | hipBLASLt (avg) | Triton (avg) | GMM (avg) | Autotune (avg) | Autotune/GMM |
|---|---|---|---|---|---|
| fwd | 1054 | 1061 | 1031 | 1089 | 1.056x |
| dgrad | 1123 | 1066 | 1088 | 1114 | 1.024x |
| wgrad | 1011 | 1090 | 1004 | 1090 | 1.086x |

Copilot AI review requested due to automatic review settings April 14, 2026 06:42
Contributor

Copilot AI left a comment

Pull request overview

This PR updates the hipBLASLt grouped GEMM implementation to improve BF16 performance on MI3xx-class GPUs, adding a serial-vs-parallel dispatch heuristic and introducing caching/autotuning in the underlying hipBLASLt GEMM wrapper. It also extends the benchmarking scripts to compare balanced vs unbalanced routing and adds a new baseline benchmark for the external grouped_gemm (GMM) library.

Changes:

  • Switch hipBLASLt grouped GEMM to copy group_lens to host and use a host-side group_lens buffer to avoid device/host coherence pitfalls.
  • Add serial/parallel dispatch selection (multi-stream) in hipBLASLt grouped GEMM and increase max streams/workspace scaling.
  • Add hipBLASLt descriptor + algo caching and an autotuning benchmark loop; expand benchmark scripts (balanced/unbalanced routing + new GMM baseline script).

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

| File | Description |
|---|---|
| primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_impl.py | Forces group_lens onto CPU for hipBLASLt BF16 grouped GEMM calls. |
| csrc/pytorch/grouped_gemm/hipblaslt_grouped_gemm.cpp | Plumbs a group_lens_on_host flag into params. |
| csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu | Implements host-side group_lens copy and adds serial/parallel multi-stream dispatch. |
| csrc/kernels/gemm/hipblaslt_gemm.cu | Adds thread-local descriptor/algo caches and an autotuning loop over heuristic candidates. |
| csrc/include/primus_turbo/grouped_gemm.h | Extends hipBLASLt grouped GEMM params with group_lens_on_host. |
| benchmark/ops/config.py | Adds commented model configs for benchmarking reference. |
| benchmark/ops/bench_grouped_gemm_turbo.py | Adds balanced/unbalanced routing option and increases warmup/measurement iterations. |
| benchmark/ops/bench_grouped_gemm_gmm.py | New benchmark for grouped_gemm (GMM) baseline. |


Comment thread benchmark/ops/bench_grouped_gemm_turbo.py
Comment thread benchmark/ops/bench_grouped_gemm_gmm.py
Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu Outdated
Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu Outdated
Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
Comment thread benchmark/ops/bench_grouped_gemm_turbo.py
kyle-256 changed the title from "[WIP] improve bf16 triton gg" to "[WIP] improve bf16 gg" on Apr 14, 2026
Copilot AI review requested due to automatic review settings April 14, 2026 07:49
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 4 comments.



Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu
Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu
Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
Comment thread benchmark/ops/bench_grouped_gemm_turbo.py
kyle-256 force-pushed the dev/kyle/improve_bf16_triton_gg branch from e7aa21c to 63330c7 on April 14, 2026 09:09
Copilot AI review requested due to automatic review settings April 15, 2026 06:58
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 9 comments.



Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu Outdated
Comment on lines +15 to +16
import grouped_gemm.ops as gmm_ops
import pandas as pd
Copilot AI Apr 15, 2026

This benchmark unconditionally imports grouped_gemm.ops. If the optional grouped_gemm dependency isn't installed, the whole script will fail immediately with ImportError. Consider wrapping the import in try/except and emitting a clear install hint (or documenting this dependency alongside the benchmark).
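
A minimal sketch of the guarded import described in this comment. The helper name `optional_import` and the wording of the hint are hypothetical and not part of the PR; only the module path `grouped_gemm.ops` comes from the benchmark script.

```python
import importlib


def optional_import(module_name, install_hint):
    """Return the module if it can be imported, else print a hint and return None."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        print(f"[skip] '{module_name}' is not installed. {install_hint}")
        return None


# grouped_gemm is an optional dependency of the GMM baseline benchmark.
gmm_ops = optional_import(
    "grouped_gemm.ops",
    "Install the optional grouped_gemm package to run the GMM baseline.",
)

if gmm_ops is None:
    print("GMM baseline benchmark skipped.")
```

With this shape, the script can exit cleanly or skip the GMM columns instead of failing at import time.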

Comment thread benchmark/ops/bench_grouped_gemm_turbo.py
Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu
Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu
Comment on lines 27 to 30
std::int64_t get_hipblaslt_grouped_gemm_workspace_size() {
    // Multi-stream path needs one workspace per stream.
    return kMaxNumStreams * get_hipblaslt_workspace_size_in_byte();
}
Copilot AI Apr 15, 2026

get_hipblaslt_grouped_gemm_workspace_size() always reserves workspace for kMaxNumStreams streams, even when dispatch_serial() is chosen (which only uses one handle/stream). If memory is a concern, consider allocating a smaller workspace for the serial path or splitting the API so callers can request serial vs parallel workspace sizes.

Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
@kyle-256
Collaborator Author

kyle-256 commented Apr 15, 2026

Root cause of the <50 TFLOPS results and hangs in hipBLASLt grouped GEMM on the main branch:
The serial grouped dispatch reused one shared hipBLASLt workspace across experts, which caused intermittent GPU stalls and the recurring low-TFLOPS outliers in balanced dgrad.

kyle-256 changed the title from "[WIP] improve bf16 gg" to "perf: optimize hipBLASLt grouped GEMM with algo tuning, enable grouped_gemm autotune hipblaslt support" on Apr 15, 2026
Copilot AI review requested due to automatic review settings April 15, 2026 13:03
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 6 comments.



Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu Outdated
Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu
Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu
Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu
Comment thread benchmark/ops/bench_grouped_gemm_gmm.py
kyle-256 force-pushed the dev/kyle/improve_bf16_triton_gg branch from f21bedd to cf3fe3e on April 16, 2026 06:47
Copilot AI review requested due to automatic review settings April 17, 2026 02:50
kyle-256 force-pushed the dev/kyle/improve_bf16_triton_gg branch from cf3fe3e to ed6ced1 on April 17, 2026 02:50
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 3 comments.



Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu Outdated
Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
Copilot AI review requested due to automatic review settings April 17, 2026 06:22
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 4 comments.



Comment thread csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu
Comment thread csrc/pytorch/grouped_gemm/hipblaslt_grouped_gemm.cpp Outdated
Comment on lines +154 to +157
char *workspace_base = static_cast<char *>(params.workspace);
for (size_t idx = 0; idx < num_gemms; ++idx) {
    void *ws = workspace_base + idx * get_hipblaslt_workspace_size_in_byte();
    // clang-format off
Copilot AI Apr 17, 2026

dispatch_serial() uses a per-GEMM workspace slice (workspace_base + idx * workspace_bytes), which forces the caller to allocate one full hipBLASLt workspace per group. Given the current workspace size (32–64MiB), this becomes multi-GB for common expert counts. Since serial dispatch runs on a single stream, reusing a single workspace slice should be safe and would remove the need for per-expert workspace allocation.
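
The memory cost behind this comment is simple arithmetic. The sketch below assumes the upper 64 MiB workspace size mentioned in the PR; the expert counts are illustrative and the function name is hypothetical. It only quantifies allocation size and says nothing about whether reusing a single slice is safe, which the author's root-cause note disputes.

```python
MIB = 1024 * 1024


def workspace_bytes(per_gemm_mib, num_slices):
    """Total workspace when each slice is one full hipBLASLt workspace."""
    return per_gemm_mib * MIB * num_slices


for experts in (16, 32, 128):  # illustrative expert counts
    per_expert = workspace_bytes(64, experts)  # one slice per expert
    shared = workspace_bytes(64, 1)            # one slice reused by all
    print(
        f"{experts:>3} experts: per-expert = {per_expert / 2**30:.2f} GiB, "
        f"shared = {shared / 2**20:.0f} MiB"
    )
```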

Comment thread csrc/kernels/gemm/hipblaslt_gemm.cu
kyle-256 and others added 4 commits April 20, 2026 02:09
For fwd/dgrad (transA=False), b is [G, K, N] with one K×N weight
block per expert. The old code advanced b_ptr only for hot experts
and skipped cold ones (len==0), causing subsequent hot experts to
read from the wrong expert's weights.

Fix: compute b_ptr as absolute offset b_base + i * b_expert_stride
for fwd/dgrad, so cold experts do not shift the pointer. For wgrad
(transA=True), b is a flat [M_total, OUT_N] tensor and still uses
sequential advancement per hot group (unchanged).

Result: correctness PASS for all balanced and unbalanced cases.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
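
The pointer fix described in this commit can be illustrated with plain offsets. This is a toy model, not the C++ code: the function names are hypothetical, offsets are counted in elements rather than bytes, and the shapes are tiny on purpose.

```python
def b_offsets_absolute(group_lens, b_expert_stride):
    """fwd/dgrad after the fix: expert i always reads at i * stride.

    Cold experts (len == 0) are skipped but do not shift later offsets.
    """
    return [i * b_expert_stride for i, n in enumerate(group_lens) if n > 0]


def b_offsets_sequential(group_lens, b_expert_stride):
    """The old pattern: the pointer advances only when an expert is hot."""
    offsets, cursor = [], 0
    for n in group_lens:
        if n > 0:
            offsets.append(cursor)
            cursor += b_expert_stride
    return offsets


K, N = 4, 3                   # toy sizes; real weight blocks are far larger
group_lens = [5, 0, 7, 0, 2]  # experts 1 and 3 are cold
stride = K * N                # elements in one K x N weight block

print(b_offsets_absolute(group_lens, stride))    # [0, 24, 48]
print(b_offsets_sequential(group_lens, stride))  # [0, 12, 24]
```

In the sequential variant the second hot expert (index 2) reads at offset 12, which is expert 1's weight block, matching the bug the commit message describes.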
hipblaslt_gemm_impl previously called hipblasLtMatmulAlgoGetHeuristic on
every invocation. For grouped GEMM with N experts, this meant N sequential
searches per call -- each search costs 100-500ms on first encounter of a
new shape (e.g. unbalanced routing producing unseen token counts).

Add a thread-local algo cache keyed by (shape, dtype, trans, handle) to
skip the search on repeat calls. Descriptor creation is retained per-call
(cheap); only the expensive heuristic search is cached.

Result: the sporadic 17 TFLOPS case recovers to 818 TFLOPS for LFM2 G=4 M=16384 N=3584 K=2048 fwd.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ence bug

Direct CPU dereference of a device pointer (params.group_lens_ptr[i]) can
read stale HBM data even after hipStreamSynchronize on CDNA architectures:
the GPU L2 cache may not have been flushed, so the CPU observes zeros.
This caused valid_group_num=0 stochastically, skipping GEMM dispatch during
warmup and leaving the algo cache unpopulated.  The first *timed* iteration
then triggered a 500+ ms hipblasLtMatmulAlgoGetHeuristic call, collapsing
throughput to single- or double-digit TFLOPS on random runs.

Fix: replace the direct pointer loop with hipMemcpy(DeviceToHost), which
routes through the ROCm DMA engine with correct cache-coherence semantics.
Adds a host-side group_lens_host_ member vector that is reused across calls.

Also adds bench_hl_vs_gmm.py: head-to-head hipBLASLt vs GMM benchmark
covering all 3 model configs × balanced/unbalanced × fwd/dgrad/wgrad.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…dispatch

The old kMaxNumStreams=4 approach launched expert GEMMs across 4 concurrent
HIP streams. When few hot experts each had large token counts (unbalanced
MoE routing), 4 large GEMMs competed for GPU CUs, L2 cache, and register
file, collapsing efficiency from ~75% to ~50%.

Single-stream serial dispatch lets each expert GEMM use 100% of GPU
resources. Workspace reduced from 4×64 MiB to 1×64 MiB.

Result: 98/144 shapes beat GMM (was 26/144).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
kyle-256 and others added 15 commits April 20, 2026 02:09
Use 8 parallel streams for many small expert GEMMs (per-expert tokens < 512)
to overlap kernel launch overhead. Fall back to serial single-stream dispatch
for few large GEMMs to avoid GPU resource contention.

Threshold tuned via sweep: kMaxNumStreams=8, kSerialThreshold=512.

Result: 101-105/144 shapes beat GMM (was 98/144 with pure serial).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
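
A rough sketch of the dispatch choice described in this commit. The two constants come from the commit message; the rest is assumed. In particular, the PR text does not say whether the threshold is compared against the largest, the mean, or some other per-expert token count (the largest is used here), and the round-robin stream assignment is also a guess.

```python
K_MAX_NUM_STREAMS = 8     # from the commit message
K_SERIAL_THRESHOLD = 512  # from the commit message


def choose_dispatch(group_lens):
    """Return ("serial", 1) or ("parallel", n_streams).

    Assumption: the decision uses the largest active expert.
    """
    active = [n for n in group_lens if n > 0]
    if not active:
        return ("serial", 1)
    if max(active) < K_SERIAL_THRESHOLD:
        return ("parallel", min(K_MAX_NUM_STREAMS, len(active)))
    return ("serial", 1)


def stream_for(gemm_idx, n_streams):
    """Assumed round-robin mapping from GEMM index to stream slot."""
    return gemm_idx % n_streams


print(choose_dispatch([128] * 32))        # many small GEMMs
print(choose_dispatch([4096, 0, 8192]))   # few large GEMMs
```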
…ates

On first encounter of each GEMM shape, request 8 candidate algorithms from
hipBLASLt heuristic, benchmark each (1 warm-up + 5 timed iterations), and
cache the fastest. Subsequent calls for the same shape use the cached winner.

This particularly helps non-power-of-2 shapes (7168, 3584, 2880, 1792) where
the heuristic's top-1 pick is not always optimal.

Result: LFM2 balanced improved from 12/24 to 15/24 BEAT.
Overall: 101-102/144 BEAT GMM.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
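
The tune-once-then-cache pattern from this commit, as a small Python sketch. The function names are hypothetical. The clock is injectable so the demo can use simulated costs instead of real GPU timings, and the three candidate names and their costs are made up.

```python
import time


def autotune(candidates, run, warmup=1, iters=5, clock=time.perf_counter):
    """Time each candidate and return (best_candidate, best_elapsed)."""
    best, best_elapsed = None, float("inf")
    for cand in candidates:
        for _ in range(warmup):
            run(cand)
        start = clock()
        for _ in range(iters):
            run(cand)
        elapsed = clock() - start
        if elapsed < best_elapsed:
            best, best_elapsed = cand, elapsed
    return best, best_elapsed


_winners = {}


def tuned_algo(shape_key, candidates, run, **kwargs):
    """Autotune on the first encounter of shape_key, then reuse the winner."""
    if shape_key not in _winners:
        _winners[shape_key], _ = autotune(candidates, run, **kwargs)
    return _winners[shape_key]


# Demo with a simulated clock: each "run" advances time by a made-up cost.
fake_now = [0.0]
cost = {"algo0": 3.0, "algo1": 1.0, "algo2": 2.0}


def fake_run(algo):
    fake_now[0] += cost[algo]


best = tuned_algo(("N=4096", "K=7168"), list(cost), fake_run,
                  clock=lambda: fake_now[0])
print(best)  # algo1
```

A second call to `tuned_algo` with the same key returns the cached winner without running any candidate.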
Add three balance modes for grouped GEMM benchmarks:
- balanced: uniform token distribution (unchanged)
- mild: random distribution where all experts get tokens (from main branch)
- extreme: only topk experts get tokens, rest get 0 (previous "unbalanced")

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
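
One way the three balance modes could be generated, as a hedged sketch. The function name, the `topk` default, and the exact random scheme are assumptions; the benchmark script may distribute tokens differently. The `mild` branch assumes `total_tokens >= num_experts`.

```python
import random


def make_group_lens(total_tokens, num_experts, mode, topk=2, seed=0):
    """Return per-expert token counts that sum to total_tokens."""
    rng = random.Random(seed)
    if mode == "balanced":
        base, rem = divmod(total_tokens, num_experts)
        return [base + (1 if i < rem else 0) for i in range(num_experts)]
    if mode == "mild":
        # Every expert gets at least one token; the rest land at random.
        lens = [1] * num_experts
        for _ in range(total_tokens - num_experts):
            lens[rng.randrange(num_experts)] += 1
        return lens
    if mode == "extreme":
        # Only topk experts receive tokens; all others get 0.
        hot = rng.sample(range(num_experts), topk)
        base, rem = divmod(total_tokens, topk)
        lens = [0] * num_experts
        for j, expert in enumerate(hot):
            lens[expert] = base + (1 if j < rem else 0)
        return lens
    raise ValueError(f"unknown mode: {mode}")


for mode in ("balanced", "mild", "extreme"):
    print(mode, make_group_lens(1024, 8, mode))
```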
- Increase autotune warmup from 1→10 and bench iterations from 5→30
  for more stable algorithm selection across runs.
- Fix hipStreamSynchronize in grouped GEMM: was unconditionally called
  on every run() regardless of pre_sync flag. Now only syncs when
  pre_sync=true, removing an unnecessary CPU-GPU barrier.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When group_lens tensor is on CPU, use direct memcpy instead of
hipMemcpy(DeviceToHost), avoiding a blocking GPU-CPU sync on every
grouped GEMM call. This removes the most impactful latency bottleneck
identified by profiling.

The benchmark now passes CPU group_lens to hipBLASLt, matching how
GMM receives its batch_sizes (always on CPU).

Result: 153-160/216 BEAT GMM (was 145-149/216).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…rhead

Cache matrix layouts and matmul descriptors keyed by shape inside
hipblaslt_gemm_impl. Same-shape calls skip create/set/destroy cycle
(11 API calls saved per cache hit). dispatch_serial unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The algo cache key now excludes M-dependent dimensions (rows_a, lda,
rows_d, ldd). hipBLASLt kernel tile selection depends on N/K/dtype/trans,
not M — M only affects grid launch size.

Before: every new per-expert token count (which changes every training
iteration due to dynamic MoE routing) triggered a full autotune
(warmup=10 + bench=30 × 8 candidates = 320 matmuls). This made the
first call to each unique M prohibitively slow in training.

After: autotune runs once per (N, K, dtype, trans) combination. Different
M values reuse the cached algo immediately.

Descriptor cache remains M-dependent (layouts must match exact dims).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
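
The effect of dropping M from the algo key can be shown with two key functions. This is a simplified sketch: the real key excludes rows_a, lda, rows_d and ldd, which are collapsed here into a single `m` field, and the `Gemm` tuple and M values are illustrative.

```python
from collections import namedtuple

Gemm = namedtuple("Gemm", "m n k dtype trans_a trans_b")


def algo_key(g):
    """Algo cache key: deliberately omits M."""
    return (g.n, g.k, g.dtype, g.trans_a, g.trans_b)


def descriptor_key(g):
    """Descriptor cache key: layouts need exact dims, so M is included."""
    return (g.m, g.n, g.k, g.dtype, g.trans_a, g.trans_b)


# Cost of one autotune pass, using the numbers in the commit message.
WARMUP, BENCH, CANDIDATES = 10, 30, 8
matmuls_per_autotune = (WARMUP + BENCH) * CANDIDATES

# Three per-expert token counts, as dynamic routing might produce.
calls = [Gemm(m, 4096, 7168, "bf16", False, False) for m in (1900, 2048, 2211)]

print(len({algo_key(g) for g in calls}))        # distinct autotune runs
print(len({descriptor_key(g) for g in calls}))  # distinct descriptor sets
print(matmuls_per_autotune)
```

Three different M values share one autotune result but still need three descriptor entries.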
- Revert group_offs parameter removal from hipblaslt_grouped_gemm to
  maintain API consistency (param is accepted but unused internally;
  CPU-side group_lens optimization is preserved)
- Delete temporary bench_hl_vs_gmm.py benchmark script
- Restore original model configs in config.py deleted during development;
  add LFM2-8B-A1B and gpt_oss_20B as commented-out entries

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Support balanced/unbalanced routing via command line instead of
relying on test case config. Default: balanced.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Re-enable hipBLASLt as an autotune candidate for BF16 grouped GEMM
(both fwd/dgrad and wgrad paths). This allows the autotune framework
to pick the best backend per shape — hipBLASLt wins on large M and
dgrad, Triton wins on wgrad with small B.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Re-enable hipBLASLt as autotune candidate for BF16 grouped GEMM
- Move group_lens D2H copy from Python .cpu() to C++ hipMemcpy,
  reducing per-call overhead from ~19us to ~9us
- C++ compute_args now handles both host and device group_lens
  via group_lens_on_host flag

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When algo cache misses, hipblaslt_gemm_impl benchmarks 8 algo candidates
using the current call's M. In unbalanced MoE routing, the first expert
dispatched may have a very small or large M, biasing the algo selection.

Fix: before dispatching individual expert GEMMs, call hipblaslt_gemm_impl
once with balanced M (total_tokens / num_experts) to seed the algo cache
with a representative workload. Subsequent calls hit the cache (M-invariant
key) and skip tuning. The warm call only happens once per (N,K,trans) shape.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
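
A sketch of how the representative M for the warm-up call could be computed. The function name and the `active_only` switch are hypothetical; the switch contrasts dividing by all groups, as the commit message describes, with dividing by non-empty groups only, as a later review comment in this PR suggests.

```python
def balanced_m(group_lens, active_only=True):
    """Representative M for the one-off warm-up GEMM.

    Returns 0 when no group has tokens, so the caller can skip the warm-up.
    """
    active = [n for n in group_lens if n > 0]
    if not active:
        return 0
    denom = len(active) if active_only else len(group_lens)
    return sum(active) // denom


lens = [0, 0, 6000, 0, 0, 0, 2192, 0]  # toy routing: 2 of 8 experts are hot

print(balanced_m(lens, active_only=False))  # 1024
print(balanced_m(lens, active_only=True))   # 4096
```

With many cold experts the two denominators give very different warm-up shapes, which is the concern raised in review.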
Retry BF16 hipBLASLt matmuls with an exact-shape autotuned algo when the shared M-invariant cache returns an invalid configuration. This keeps grouped GEMM on the fast shared-cache path while preserving correctness on gfx942 edge cases.

Made-with: Cursor
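
The retry described in this commit might look roughly like the following. All names are hypothetical, the "tile" algos and the rule for when one is invalid are made up for the demo, and the real code works on hipBLASLt algo handles rather than strings.

```python
class InvalidAlgo(Exception):
    """Raised by the stand-in matmul when an algo does not fit the shape."""


def matmul_with_fallback(gemm, shared_cache, exact_cache, tune, run):
    """Try the shared (M-invariant) algo first; retune for the exact shape on failure."""
    m, n, k = gemm
    if (m, n, k) in exact_cache:
        return run(gemm, exact_cache[(m, n, k)])
    if (n, k) not in shared_cache:
        shared_cache[(n, k)] = tune(gemm)
    try:
        return run(gemm, shared_cache[(n, k)])
    except InvalidAlgo:
        exact_cache[(m, n, k)] = tune(gemm)
        return run(gemm, exact_cache[(m, n, k)])


def tune(gemm):
    """Toy tuner: picks a big tile for large M, a small tile otherwise."""
    m, _, _ = gemm
    return "tile256" if m >= 256 else "tile32"


def run(gemm, algo):
    """Toy matmul: the big tile is rejected for small M."""
    m, _, _ = gemm
    if algo == "tile256" and m < 256:
        raise InvalidAlgo(f"{algo} rejected for M={m}")
    return algo


shared, exact = {}, {}
print(matmul_with_fallback((4096, 4096, 7168), shared, exact, tune, run))  # tile256
print(matmul_with_fallback((64, 4096, 7168), shared, exact, tune, run))    # tile32
```

The first call seeds the shared cache; the second hits an invalid shared algo, retunes for its exact shape, and stores the result under the exact-shape key only.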
Copilot AI review requested due to automatic review settings April 20, 2026 02:09
kyle-256 force-pushed the dev/kyle/improve_bf16_triton_gg branch from cdbc4ee to b17b585 on April 20, 2026 02:10
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 3 comments.



Comment on lines +89 to +93
// Warm the hipBLASLt algo cache with balanced M (total_M / group_num)
// so the tuned algo is representative of the average expert size,
// not biased by whichever expert happens to be dispatched first.
warm_algo_cache(params, num_gemms);

Copilot AI Apr 20, 2026

warm_algo_cache() is always called before deciding serial vs parallel dispatch. However, the warm call uses handle_, while the parallel path uses par_handles_[s], and hipblaslt_gemm_impl’s algo cache key includes the handle. This means the warmup autotune work may not benefit the subsequent parallel dispatch and can add unnecessary synchronization/overhead. Consider moving the warm call after dispatch-strategy selection and warming the handle(s) that will actually be used.

Comment on lines +117 to +122
// Compute balanced M = total_tokens / group_num.
int64_t total_tokens = 0;
for (size_t i = 0; i < num_gemms; ++i) {
    total_tokens += params.transA ? cols_b_[i] : cols_c_[i];
}
const int64_t balanced_m = total_tokens / params.group_num;
Copilot AI Apr 20, 2026

balanced_m is computed as total_tokens / params.group_num, but num_gemms only counts non-zero-length groups (you skip len==0). If there are many zero-length experts, dividing by group_num can produce a much smaller (or even zero) balanced_m, which undermines the intent of warming with a representative active-expert M. Consider dividing by the number of active groups (e.g., num_gemms) instead.

Suggested change
- // Compute balanced M = total_tokens / group_num.
- int64_t total_tokens = 0;
- for (size_t i = 0; i < num_gemms; ++i) {
-     total_tokens += params.transA ? cols_b_[i] : cols_c_[i];
- }
- const int64_t balanced_m = total_tokens / params.group_num;
+ // Compute balanced M over active groups only.
+ if (num_gemms == 0)
+     return;
+ int64_t total_tokens = 0;
+ for (size_t i = 0; i < num_gemms; ++i) {
+     total_tokens += params.transA ? cols_b_[i] : cols_c_[i];
+ }
+ const int64_t balanced_m = total_tokens / static_cast<int64_t>(num_gemms);

Comment on lines +234 to +257
float best_ms = 1e30f;
int best_idx = 0;
for (int c = 0; c < returnedAlgoCount; ++c) {
    // warm-up
    for (int w = 0; w < kWarmupIters; ++w) {
        (void) hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, &b0,
                               D, D_desc, D, D_desc, &candidates[c].algo, workspace,
                               workspace_size, stream);
    }

    PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_start, stream));
    for (int i = 0; i < kBenchIters; ++i) {
        (void) hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, &b0,
                               D, D_desc, D, D_desc, &candidates[c].algo, workspace,
                               workspace_size, stream);
    }
    PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_stop, stream));
    PRIMUS_TURBO_CHECK_HIP(hipEventSynchronize(ev_stop));

    float ms = 0;
    PRIMUS_TURBO_CHECK_HIP(hipEventElapsedTime(&ms, ev_start, ev_stop));
    if (ms < best_ms) {
        best_ms = ms;
        best_idx = c;
Copilot AI Apr 20, 2026

In the autotune loop, the return status of hipblasLtMatmul is ignored during warmup/benchmark iterations. If an algo candidate intermittently fails, it can still be selected as “fastest” and later cause runtime failures. Capture the hipblasStatus_t for each call and skip/penalize candidates that don’t return HIPBLAS_STATUS_SUCCESS (and consider checking errors during warmup too).

Suggested change
- float best_ms = 1e30f;
- int best_idx = 0;
- for (int c = 0; c < returnedAlgoCount; ++c) {
-     // warm-up
-     for (int w = 0; w < kWarmupIters; ++w) {
-         (void) hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, &b0,
-                                D, D_desc, D, D_desc, &candidates[c].algo, workspace,
-                                workspace_size, stream);
-     }
-     PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_start, stream));
-     for (int i = 0; i < kBenchIters; ++i) {
-         (void) hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, &b0,
-                                D, D_desc, D, D_desc, &candidates[c].algo, workspace,
-                                workspace_size, stream);
-     }
-     PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_stop, stream));
-     PRIMUS_TURBO_CHECK_HIP(hipEventSynchronize(ev_stop));
-     float ms = 0;
-     PRIMUS_TURBO_CHECK_HIP(hipEventElapsedTime(&ms, ev_start, ev_stop));
-     if (ms < best_ms) {
-         best_ms = ms;
-         best_idx = c;
+ float best_ms = 1e30f;
+ int best_idx = 0;
+ bool found_valid_algo = false;
+ for (int c = 0; c < returnedAlgoCount; ++c) {
+     bool candidate_ok = true;
+     hipblasStatus_t matmul_status = HIPBLAS_STATUS_SUCCESS;
+     // warm-up
+     for (int w = 0; w < kWarmupIters; ++w) {
+         matmul_status = hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc,
+                                         &b0, D, D_desc, D, D_desc, &candidates[c].algo,
+                                         workspace, workspace_size, stream);
+         if (matmul_status != HIPBLAS_STATUS_SUCCESS) {
+             candidate_ok = false;
+             break;
+         }
+     }
+     if (!candidate_ok) {
+         continue;
+     }
+     PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_start, stream));
+     for (int i = 0; i < kBenchIters; ++i) {
+         matmul_status = hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc,
+                                         &b0, D, D_desc, D, D_desc, &candidates[c].algo,
+                                         workspace, workspace_size, stream);
+         if (matmul_status != HIPBLAS_STATUS_SUCCESS) {
+             candidate_ok = false;
+             break;
+         }
+     }
+     if (!candidate_ok) {
+         continue;
+     }
+     PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_stop, stream));
+     PRIMUS_TURBO_CHECK_HIP(hipEventSynchronize(ev_stop));
+     float ms = 0;
+     PRIMUS_TURBO_CHECK_HIP(hipEventElapsedTime(&ms, ev_start, ev_stop));
+     if (!found_valid_algo || ms < best_ms) {
+         best_ms = ms;
+         best_idx = c;
+         found_valid_algo = true;

Release cached hipBLASLt descriptors with the thread-local cache and skip autotune candidates that fail during warmup or benchmarking. This avoids leaking descriptor handles and prevents the cache from selecting unusable algos.

Made-with: Cursor
The serial grouped dispatch reused one shared hipBLASLt workspace across experts, which caused intermittent GPU stalls and the recurring low-TFLOPS outliers in balanced dgrad. Allocate a dedicated workspace slice per expert and size the PyTorch wrapper allocation to cover the serial path.

Made-with: Cursor
Clear cached hipBLASLt GEMM/grouped GEMM runtime state during backend resets and avoid timing-based float32 autotune picks under shared CI load. This reduces flaky small-shape failures when xdist workers reuse stale state.

Made-with: Cursor
Copilot AI review requested due to automatic review settings April 20, 2026 09:39
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 11 out of 12 changed files in this pull request and generated 3 comments.



Comment on lines +155 to +169
void dispatch_serial(const HipblasltGroupedGemmParams &params, size_t num_gemms) {
    char *workspace_base = static_cast<char *>(params.workspace);
    for (size_t idx = 0; idx < num_gemms; ++idx) {
        void *ws = workspace_base + idx * get_hipblaslt_workspace_size_in_byte();
        // clang-format off
        hipblaslt_gemm_impl(
            gemm_ptrs_[idx].b_ptr, params.b_type, rows_b_[idx], cols_b_[idx], ld_b_[idx],
            gemm_ptrs_[idx].b_scale_ptr,
            params.transB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
            gemm_ptrs_[idx].a_ptr, params.a_type, rows_a_[idx], cols_a_[idx], ld_a_[idx],
            gemm_ptrs_[idx].a_scale_ptr,
            params.transA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
            gemm_ptrs_[idx].c_ptr, params.c_type, rows_c_[idx], cols_c_[idx], ld_c_[idx],
            ws, get_hipblaslt_workspace_size_in_byte(),
            params.use_low_precision,
Copilot AI Apr 20, 2026

dispatch_serial() offsets the workspace by idx * get_hipblaslt_workspace_size_in_byte(), which forces callers to allocate one full hipBLASLt workspace per expert. For serial execution on a single stream, the workspace can be reused for every GEMM (or at most one per stream), which drastically reduces memory pressure and avoids OOM.

params.stream
);
// clang-format on
PRIMUS_TURBO_CHECK_HIP(hipStreamSynchronize(params.stream));
Copilot AI Apr 20, 2026

warm_algo_cache() does a host-side hipStreamSynchronize(params.stream) after enqueueing the warmup GEMM. This introduces a global stall and can break CUDA/HIP graph capture. The warmup call is already ordered with subsequent work on the same stream, so the explicit host synchronize should be removed (or gated off during capture).

Suggested change
- PRIMUS_TURBO_CHECK_HIP(hipStreamSynchronize(params.stream));
+ // No explicit host-side synchronize here: the warmup GEMM is already
+ // ordered before subsequent work submitted to params.stream, and
+ // synchronizing would introduce an unnecessary stall and can break
+ // HIP graph capture.

Comment on lines +244 to +247
runtime = importlib.import_module("primus_turbo.pytorch._C.runtime")

runtime.clear_hipblaslt_gemm_runtime_caches()
runtime.clear_hipblaslt_grouped_gemm_runtime_state()
Copilot AI Apr 20, 2026

GlobalBackendManager.reset() now unconditionally imports the compiled extension submodule primus_turbo.pytorch._C.runtime. In environments where the extension isn’t built/available (e.g., CPU-only tooling, docs, type-checking), this will raise and break reset. Consider wrapping the import and cache-clears in a try/except ImportError (or checking for module availability) so reset still works without the extension.

Suggested change
- runtime = importlib.import_module("primus_turbo.pytorch._C.runtime")
- runtime.clear_hipblaslt_gemm_runtime_caches()
- runtime.clear_hipblaslt_grouped_gemm_runtime_state()
+ try:
+     runtime = importlib.import_module("primus_turbo.pytorch._C.runtime")
+ except ImportError:
+     runtime = None
+ if runtime is not None:
+     runtime.clear_hipblaslt_gemm_runtime_caches()
+     runtime.clear_hipblaslt_grouped_gemm_runtime_state()
