perf: optimize hipBLASLt grouped GEMM with algo tuning, enable grouped_gemm autotune hipblaslt support#284
perf: optimize hipBLASLt grouped GEMM with algo tuning, enable grouped_gemm autotune hipblaslt support#284kyle-256 wants to merge 23 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
This PR updates the hipBLASLt grouped GEMM implementation to improve BF16 performance on MI3xx-class GPUs, adding a serial-vs-parallel dispatch heuristic and introducing caching/autotuning in the underlying hipBLASLt GEMM wrapper. It also extends the benchmarking scripts to compare balanced vs unbalanced routing and adds a new baseline benchmark for the external grouped_gemm (GMM) library.
Changes:
- Switch hipBLASLt grouped GEMM to copy
group_lensto host and use a host-sidegroup_lensbuffer to avoid device/host coherence pitfalls. - Add serial/parallel dispatch selection (multi-stream) in hipBLASLt grouped GEMM and increase max streams/workspace scaling.
- Add hipBLASLt descriptor + algo caching and an autotuning benchmark loop; expand benchmark scripts (balanced/unbalanced routing + new GMM baseline script).
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_impl.py | Forces group_lens onto CPU for hipBLASLt BF16 grouped GEMM calls. |
| csrc/pytorch/grouped_gemm/hipblaslt_grouped_gemm.cpp | Plumbs a group_lens_on_host flag into params. |
| csrc/kernels/grouped_gemm/hipblaslt_grouped_gemm.cu | Implements host-side group_lens copy and adds serial/parallel multi-stream dispatch. |
| csrc/kernels/gemm/hipblaslt_gemm.cu | Adds thread-local descriptor/algo caches and an autotuning loop over heuristic candidates. |
| csrc/include/primus_turbo/grouped_gemm.h | Extends hipBLASLt grouped GEMM params with group_lens_on_host. |
| benchmark/ops/config.py | Adds commented model configs for benchmarking reference. |
| benchmark/ops/bench_grouped_gemm_turbo.py | Adds balanced/unbalanced routing option and increases warmup/measurement iterations. |
| benchmark/ops/bench_grouped_gemm_gmm.py | New benchmark for grouped_gemm (GMM) baseline. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 4 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
e7aa21c to
63330c7
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 9 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| import grouped_gemm.ops as gmm_ops | ||
| import pandas as pd |
There was a problem hiding this comment.
This benchmark unconditionally imports grouped_gemm.ops. If the optional grouped_gemm dependency isn't installed, the whole script will fail immediately with ImportError. Consider wrapping the import in try/except and emitting a clear install hint (or documenting this dependency alongside the benchmark).
| std::int64_t get_hipblaslt_grouped_gemm_workspace_size() { | ||
| // Multi-stream path needs one workspace per stream. | ||
| return kMaxNumStreams * get_hipblaslt_workspace_size_in_byte(); | ||
| } |
There was a problem hiding this comment.
get_hipblaslt_grouped_gemm_workspace_size() always reserves workspace for kMaxNumStreams streams, even when dispatch_serial() is chosen (which only uses one handle/stream). If memory is a concern, consider allocating a smaller workspace for the serial path or splitting the API so callers can request serial vs parallel workspace sizes.
|
Root cause of <50T / hang on main branch hipBLASLt grouped GEMM: |
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 6 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
f21bedd to
cf3fe3e
Compare
cf3fe3e to
ed6ced1
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 4 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| char *workspace_base = static_cast<char *>(params.workspace); | ||
| for (size_t idx = 0; idx < num_gemms; ++idx) { | ||
| void *ws = workspace_base + idx * get_hipblaslt_workspace_size_in_byte(); | ||
| // clang-format off |
There was a problem hiding this comment.
dispatch_serial() uses a per-GEMM workspace slice (workspace_base + idx * workspace_bytes), which forces the caller to allocate one full hipBLASLt workspace per group. Given the current workspace size (32–64MiB), this becomes multi-GB for common expert counts. Since serial dispatch runs on a single stream, reusing a single workspace slice should be safe and would remove the need for per-expert workspace allocation.
For fwd/dgrad (transA=False), b is [G, K, N] with one K×N weight block per expert. The old code advanced b_ptr only for hot experts and skipped cold ones (len==0), causing subsequent hot experts to read from the wrong expert's weights. Fix: compute b_ptr as absolute offset b_base + i * b_expert_stride for fwd/dgrad, so cold experts do not shift the pointer. For wgrad (transA=True), b is a flat [M_total, OUT_N] tensor and still uses sequential advancement per hot group (unchanged). Result: correctness PASS for all balanced and unbalanced cases. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
hipblaslt_gemm_impl previously called hipblasLtMatmulAlgoGetHeuristic on every invocation. For grouped GEMM with N experts, this meant N sequential searches per call -- each search costs 100-500ms on first encounter of a new shape (e.g. unbalanced routing producing unseen token counts). Add a thread-local algo cache keyed by (shape, dtype, trans, handle) to skip the search on repeat calls. Descriptor creation is retained per-call (cheap); only the expensive heuristic search is cached. Result: sporadic 17T -> 818T for LFM2 G=4 M=16384 N=3584 K=2048 fwd. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ence bug Direct CPU dereference of a device pointer (params.group_lens_ptr[i]) can read stale HBM data even after hipStreamSynchronize on CDNA architectures: the GPU L2 cache may not have been flushed, so the CPU observes zeros. This caused valid_group_num=0 stochastically, skipping GEMM dispatch during warmup and leaving the algo cache unpopulated. The first *timed* iteration then triggered a 500+ ms hipblasLtMatmulAlgoGetHeuristic call, collapsing throughput to single- or double-digit TFLOPS on random runs. Fix: replace the direct pointer loop with hipMemcpy(DeviceToHost), which routes through the ROCm DMA engine with correct cache-coherence semantics. Adds a host-side group_lens_host_ member vector that is reused across calls. Also adds bench_hl_vs_gmm.py: head-to-head hipBLASLt vs GMM benchmark covering all 3 model configs × balanced/unbalanced × fwd/dgrad/wgrad. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…dispatch The old kMaxNumStreams=4 approach launched expert GEMMs across 4 concurrent HIP streams. When few hot experts each had large token counts (unbalanced MoE routing), 4 large GEMMs competed for GPU CUs, L2 cache, and register file, collapsing efficiency from ~75% to ~50%. Single-stream serial dispatch lets each expert GEMM use 100% of GPU resources. Workspace reduced from 4×64 MiB to 1×64 MiB. Result: 98/144 shapes beat GMM (was 26/144). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Use 8 parallel streams for many small expert GEMMs (per-expert tokens < 512) to overlap kernel launch overhead. Fall back to serial single-stream dispatch for few large GEMMs to avoid GPU resource contention. Threshold tuned via sweep: kMaxNumStreams=8, kSerialThreshold=512. Result: 101-105/144 shapes beat GMM (was 98/144 with pure serial). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ates On first encounter of each GEMM shape, request 8 candidate algorithms from hipBLASLt heuristic, benchmark each (1 warm-up + 5 timed iterations), and cache the fastest. Subsequent calls for the same shape use the cached winner. This particularly helps non-power-of-2 shapes (7168, 3584, 2880, 1792) where the heuristic's top-1 pick is not always optimal. Result: LFM2 balanced improved from 12/24 to 15/24 BEAT. Overall: 101-102/144 BEAT GMM. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add three balance modes for grouped GEMM benchmarks: - balanced: uniform token distribution (unchanged) - mild: random distribution where all experts get tokens (from main branch) - extreme: only topk experts get tokens, rest get 0 (previous "unbalanced") Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Increase autotune warmup from 1→10 and bench iterations from 5→30 for more stable algorithm selection across runs. - Fix hipStreamSynchronize in grouped GEMM: was unconditionally called on every run() regardless of pre_sync flag. Now only syncs when pre_sync=true, removing an unnecessary CPU-GPU barrier. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When group_lens tensor is on CPU, use direct memcpy instead of hipMemcpy(DeviceToHost), avoiding a blocking GPU-CPU sync on every grouped GEMM call. This removes the most impactful latency bottleneck identified by profiling. The benchmark now passes CPU group_lens to hipBLASLt, matching how GMM receives its batch_sizes (always on CPU). Result: 153-160/216 BEAT GMM (was 145-149/216). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…rhead Cache matrix layouts and matmul descriptors keyed by shape inside hipblaslt_gemm_impl. Same-shape calls skip create/set/destroy cycle (11 API calls saved per cache hit). dispatch_serial unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The algo cache key now excludes M-dependent dimensions (rows_a, lda, rows_d, ldd). hipBLASLt kernel tile selection depends on N/K/dtype/trans, not M — M only affects grid launch size. Before: every new per-expert token count (which changes every training iteration due to dynamic MoE routing) triggered a full autotune (warmup=10 + bench=30 × 8 candidates = 320 matmuls). This made the first call to each unique M prohibitively slow in training. After: autotune runs once per (N, K, dtype, trans) combination. Different M values reuse the cached algo immediately. Descriptor cache remains M-dependent (layouts must match exact dims). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Revert group_offs parameter removal from hipblaslt_grouped_gemm to maintain API consistency (param is accepted but unused internally; CPU-side group_lens optimization is preserved) - Delete temporary bench_hl_vs_gmm.py benchmark script - Restore original model configs in config.py deleted during development; add LFM2-8B-A1B and gpt_oss_20B as commented-out entries Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Support balanced/unbalanced routing via command line instead of relying on test case config. Default: balanced. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Re-enable hipBLASLt as an autotune candidate for BF16 grouped GEMM (both fwd/dgrad and wgrad paths). This allows the autotune framework to pick the best backend per shape — hipBLASLt wins on large M and dgrad, Triton wins on wgrad with small B. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Re-enable hipBLASLt as autotune candidate for BF16 grouped GEMM - Move group_lens D2H copy from Python .cpu() to C++ hipMemcpy, reducing per-call overhead from ~19us to ~9us - C++ compute_args now handles both host and device group_lens via group_lens_on_host flag Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When algo cache misses, hipblaslt_gemm_impl benchmarks 8 algo candidates using the current call's M. In unbalanced MoE routing, the first expert dispatched may have a very small or large M, biasing the algo selection. Fix: before dispatching individual expert GEMMs, call hipblaslt_gemm_impl once with balanced M (total_tokens / num_experts) to seed the algo cache with a representative workload. Subsequent calls hit the cache (M-invariant key) and skip tuning. The warm call only happens once per (N,K,trans) shape. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Retry BF16 hipBLASLt matmuls with an exact-shape autotuned algo when the shared M-invariant cache returns an invalid configuration. This keeps grouped GEMM on the fast shared-cache path while preserving correctness on gfx942 edge cases. Made-with: Cursor
cdbc4ee to
b17b585
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| // Warm the hipBLASLt algo cache with balanced M (total_M / group_num) | ||
| // so the tuned algo is representative of the average expert size, | ||
| // not biased by whichever expert happens to be dispatched first. | ||
| warm_algo_cache(params, num_gemms); | ||
|
|
There was a problem hiding this comment.
warm_algo_cache() is always called before deciding serial vs parallel dispatch. However, the warm call uses handle_, while the parallel path uses par_handles_[s], and hipblaslt_gemm_impl’s algo cache key includes the handle. This means the warmup autotune work may not benefit the subsequent parallel dispatch and can add unnecessary synchronization/overhead. Consider moving the warm call after dispatch-strategy selection and warming the handle(s) that will actually be used.
| // Compute balanced M = total_tokens / group_num. | ||
| int64_t total_tokens = 0; | ||
| for (size_t i = 0; i < num_gemms; ++i) { | ||
| total_tokens += params.transA ? cols_b_[i] : cols_c_[i]; | ||
| } | ||
| const int64_t balanced_m = total_tokens / params.group_num; |
There was a problem hiding this comment.
balanced_m is computed as total_tokens / params.group_num, but num_gemms only counts non-zero-length groups (you skip len==0). If there are many zero-length experts, dividing by group_num can produce a much smaller (or even zero) balanced_m, which undermines the intent of warming with a representative active-expert M. Consider dividing by the number of active groups (e.g., num_gemms) instead.
| // Compute balanced M = total_tokens / group_num. | |
| int64_t total_tokens = 0; | |
| for (size_t i = 0; i < num_gemms; ++i) { | |
| total_tokens += params.transA ? cols_b_[i] : cols_c_[i]; | |
| } | |
| const int64_t balanced_m = total_tokens / params.group_num; | |
| // Compute balanced M over active groups only. | |
| if (num_gemms == 0) | |
| return; | |
| int64_t total_tokens = 0; | |
| for (size_t i = 0; i < num_gemms; ++i) { | |
| total_tokens += params.transA ? cols_b_[i] : cols_c_[i]; | |
| } | |
| const int64_t balanced_m = total_tokens / static_cast<int64_t>(num_gemms); |
| float best_ms = 1e30f; | ||
| int best_idx = 0; | ||
| for (int c = 0; c < returnedAlgoCount; ++c) { | ||
| // warm-up | ||
| for (int w = 0; w < kWarmupIters; ++w) { | ||
| (void) hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, &b0, | ||
| D, D_desc, D, D_desc, &candidates[c].algo, workspace, | ||
| workspace_size, stream); | ||
| } | ||
|
|
||
| PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_start, stream)); | ||
| for (int i = 0; i < kBenchIters; ++i) { | ||
| (void) hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, &b0, | ||
| D, D_desc, D, D_desc, &candidates[c].algo, workspace, | ||
| workspace_size, stream); | ||
| } | ||
| PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_stop, stream)); | ||
| PRIMUS_TURBO_CHECK_HIP(hipEventSynchronize(ev_stop)); | ||
|
|
||
| float ms = 0; | ||
| PRIMUS_TURBO_CHECK_HIP(hipEventElapsedTime(&ms, ev_start, ev_stop)); | ||
| if (ms < best_ms) { | ||
| best_ms = ms; | ||
| best_idx = c; |
There was a problem hiding this comment.
In the autotune loop, the return status of hipblasLtMatmul is ignored during warmup/benchmark iterations. If an algo candidate intermittently fails, it can still be selected as “fastest” and later cause runtime failures. Capture the hipblasStatus_t for each call and skip/penalize candidates that don’t return HIPBLAS_STATUS_SUCCESS (and consider checking errors during warmup too).
| float best_ms = 1e30f; | |
| int best_idx = 0; | |
| for (int c = 0; c < returnedAlgoCount; ++c) { | |
| // warm-up | |
| for (int w = 0; w < kWarmupIters; ++w) { | |
| (void) hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, &b0, | |
| D, D_desc, D, D_desc, &candidates[c].algo, workspace, | |
| workspace_size, stream); | |
| } | |
| PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_start, stream)); | |
| for (int i = 0; i < kBenchIters; ++i) { | |
| (void) hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, &b0, | |
| D, D_desc, D, D_desc, &candidates[c].algo, workspace, | |
| workspace_size, stream); | |
| } | |
| PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_stop, stream)); | |
| PRIMUS_TURBO_CHECK_HIP(hipEventSynchronize(ev_stop)); | |
| float ms = 0; | |
| PRIMUS_TURBO_CHECK_HIP(hipEventElapsedTime(&ms, ev_start, ev_stop)); | |
| if (ms < best_ms) { | |
| best_ms = ms; | |
| best_idx = c; | |
| float best_ms = 1e30f; | |
| int best_idx = 0; | |
| bool found_valid_algo = false; | |
| for (int c = 0; c < returnedAlgoCount; ++c) { | |
| bool candidate_ok = true; | |
| hipblasStatus_t matmul_status = HIPBLAS_STATUS_SUCCESS; | |
| // warm-up | |
| for (int w = 0; w < kWarmupIters; ++w) { | |
| matmul_status = hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, | |
| &b0, D, D_desc, D, D_desc, &candidates[c].algo, | |
| workspace, workspace_size, stream); | |
| if (matmul_status != HIPBLAS_STATUS_SUCCESS) { | |
| candidate_ok = false; | |
| break; | |
| } | |
| } | |
| if (!candidate_ok) { | |
| continue; | |
| } | |
| PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_start, stream)); | |
| for (int i = 0; i < kBenchIters; ++i) { | |
| matmul_status = hipblasLtMatmul(handle, operation_desc, &a1, A, A_desc, B, B_desc, | |
| &b0, D, D_desc, D, D_desc, &candidates[c].algo, | |
| workspace, workspace_size, stream); | |
| if (matmul_status != HIPBLAS_STATUS_SUCCESS) { | |
| candidate_ok = false; | |
| break; | |
| } | |
| } | |
| if (!candidate_ok) { | |
| continue; | |
| } | |
| PRIMUS_TURBO_CHECK_HIP(hipEventRecord(ev_stop, stream)); | |
| PRIMUS_TURBO_CHECK_HIP(hipEventSynchronize(ev_stop)); | |
| float ms = 0; | |
| PRIMUS_TURBO_CHECK_HIP(hipEventElapsedTime(&ms, ev_start, ev_stop)); | |
| if (!found_valid_algo || ms < best_ms) { | |
| best_ms = ms; | |
| best_idx = c; | |
| found_valid_algo = true; |
Release cached hipBLASLt descriptors with the thread-local cache and skip autotune candidates that fail during warmup or benchmarking. This avoids leaking descriptor handles and prevents the cache from selecting unusable algos. Made-with: Cursor
The serial grouped dispatch reused one shared hipBLASLt workspace across experts, which caused intermittent GPU stalls and the recurring low-TFLOPS outliers in balanced dgrad. Allocate a dedicated workspace slice per expert and size the PyTorch wrapper allocation to cover the serial path. Made-with: Cursor
Clear cached hipBLASLt GEMM/grouped GEMM runtime state during backend resets and avoid timing-based float32 autotune picks under shared CI load. This reduces flaky small-shape failures when xdist workers reuse stale state. Made-with: Cursor
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 11 out of 12 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| void dispatch_serial(const HipblasltGroupedGemmParams ¶ms, size_t num_gemms) { | ||
| char *workspace_base = static_cast<char *>(params.workspace); | ||
| for (size_t idx = 0; idx < num_gemms; ++idx) { | ||
| void *ws = workspace_base + idx * get_hipblaslt_workspace_size_in_byte(); | ||
| // clang-format off | ||
| hipblaslt_gemm_impl( | ||
| gemm_ptrs_[idx].b_ptr, params.b_type, rows_b_[idx], cols_b_[idx], ld_b_[idx], | ||
| gemm_ptrs_[idx].b_scale_ptr, | ||
| params.transB ? HIPBLAS_OP_T : HIPBLAS_OP_N, | ||
| gemm_ptrs_[idx].a_ptr, params.a_type, rows_a_[idx], cols_a_[idx], ld_a_[idx], | ||
| gemm_ptrs_[idx].a_scale_ptr, | ||
| params.transA ? HIPBLAS_OP_T : HIPBLAS_OP_N, | ||
| gemm_ptrs_[idx].c_ptr, params.c_type, rows_c_[idx], cols_c_[idx], ld_c_[idx], | ||
| ws, get_hipblaslt_workspace_size_in_byte(), | ||
| params.use_low_precision, |
There was a problem hiding this comment.
dispatch_serial() offsets the workspace by idx * get_hipblaslt_workspace_size_in_byte(), which forces callers to allocate one full hipBLASLt workspace per expert. For serial execution on a single stream, the workspace can be reused for every GEMM (or at most one per stream), which drastically reduces memory pressure and avoids OOM.
| params.stream | ||
| ); | ||
| // clang-format on | ||
| PRIMUS_TURBO_CHECK_HIP(hipStreamSynchronize(params.stream)); |
There was a problem hiding this comment.
warm_algo_cache() does a host-side hipStreamSynchronize(params.stream) after enqueueing the warmup GEMM. This introduces a global stall and can break CUDA/HIP graph capture. The warmup call is already ordered with subsequent work on the same stream, so the explicit host synchronize should be removed (or gated off during capture).
| PRIMUS_TURBO_CHECK_HIP(hipStreamSynchronize(params.stream)); | |
| // No explicit host-side synchronize here: the warmup GEMM is already | |
| // ordered before subsequent work submitted to params.stream, and | |
| // synchronizing would introduce an unnecessary stall and can break | |
| // HIP graph capture. |
| runtime = importlib.import_module("primus_turbo.pytorch._C.runtime") | ||
|
|
||
| runtime.clear_hipblaslt_gemm_runtime_caches() | ||
| runtime.clear_hipblaslt_grouped_gemm_runtime_state() |
There was a problem hiding this comment.
GlobalBackendManager.reset() now unconditionally imports the compiled extension submodule primus_turbo.pytorch._C.runtime. In environments where the extension isn’t built/available (e.g., CPU-only tooling, docs, type-checking), this will raise and break reset. Consider wrapping the import and cache-clears in a try/except ImportError (or checking for module availability) so reset still works without the extension.
| runtime = importlib.import_module("primus_turbo.pytorch._C.runtime") | |
| runtime.clear_hipblaslt_gemm_runtime_caches() | |
| runtime.clear_hipblaslt_grouped_gemm_runtime_state() | |
| try: | |
| runtime = importlib.import_module("primus_turbo.pytorch._C.runtime") | |
| except ImportError: | |
| runtime = None | |
| if runtime is not None: | |
| runtime.clear_hipblaslt_gemm_runtime_caches() | |
| runtime.clear_hipblaslt_grouped_gemm_runtime_state() |
hipBLASLt Grouped GEMM: 4-Way Performance Comparison
BF16, MI355X — Balanced & Unbalanced Routing
Backends: hipBLASLt, Triton, Autotune (best of hipBLASLt + Triton + CK per shape), grouped_gemm (GMM)
DeepSeek-V3 — Balanced
DeepSeek-V3 — Unbalanced
gpt_oss_20B — Balanced
gpt_oss_20B — Unbalanced
Overall Summary
Per-op breakdown (all models, all routing):