Add FP8/BF8 support for LDS transpose load #2210
stefankoncarevic wants to merge 15 commits into develop
Pull request overview
Adds FP8/BF8 enablement for Rock’s LDS transpose-load path on gfx950, including support for scaled FP8 MFMA geometries that use ds_read_tr8_b64, and expands testing coverage accordingly.
Changes:
- Extend LDS transpose-load utility/lowering to handle FP8/BF8 element types, including quad-rate scaled FP8 geometries (16x128, 32x64).
- Update LDS transpose MFMA-geometry validation and rock.lds_transpose_load op verification for 8-bit types (vector length = 8).
- Add new e2e + MLIR tests covering FP8/BF8 and mixed fp8/bf8 combinations.
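The vector-length rule mentioned in the change list can be illustrated with a small sketch. It assumes the transpose load delivers 64 bits per lane (as the `_b64` suffix of ds_read_tr8_b64 suggests), so an 8-bit element type yields 8 elements; the 16-bit case returning 4 is an assumption based on the same arithmetic, and `expectedTransposeLoadVectorLen` is a hypothetical name, not a function from the PR.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: number of elements per lane for a 64-bit transpose load.
// Assumption: every supported load returns exactly 64 bits per lane.
inline int64_t expectedTransposeLoadVectorLen(int64_t elemBitWidth) {
  // This sketch only covers the two element widths discussed in the review.
  assert((elemBitWidth == 8 || elemBitWidth == 16) &&
         "sketch only covers 8- and 16-bit elements");
  return 64 / elemBitWidth;
}
```

A verifier in this style would compare the result vector's length against the value returned for the memref element type and emit an error on mismatch.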
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp | Implements FP8/BF8 lane→offset mapping, quad-rate load sequencing, and type-compatibility updates. |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Extends rock.lds_transpose_load verifier to accept FP8/BF8 and enforce expected vector length. |
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Broadens op type constraints to allow FP8/BF8 memrefs and non-fixed vector results (checked in verifier). |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Updates LDS transpose config geometry validation error message. |
| mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h | Updates public header docs to reflect new types/geometries. |
| mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td | Updates KDim documentation to include 64/128. |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Updates LDS transpose config verification to allow (16,128) and (32,64). |
| mlir/test/Dialect/Rock/ops.mlir | Adds FP8 rock.lds_transpose_load + new config-attr test cases. |
| mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir | Adds lowering checks for FP8 transpose-load legalization. |
| mlir/test/Dialect/Rock/lds_transpose_error.mlir | Updates expected error text for expanded valid geometry set. |
| mlir/test/e2e/PrLdsTransposeLoadFp8.toml | New e2e suite for FP8/BF8 LDS transpose-load (including mixed-type cases). |
| mlir/test/e2e/PrLdsTransposeLoadFp8.cfg | Gating for the new e2e suite based on lds_transpose_load feature. |
| mlir/test/e2e/CMakeLists.txt | Wires the new e2e suite into CMake. |
Pull request overview
Adds gfx950 FP8/BF8 enablement for the LDS transpose-load optimization path (ds_read_tr8_b64), including support for scaled FP8 MFMA geometries and mixed fp8/bf8 GEMM type combinations, plus accompanying IR + e2e tests.
Changes:
- Extend LDS transpose load utility to support FP8/BF8 (vector length 8) and scaled FP8 MFMA geometries (16x128, 32x64) with quad-rate load sequencing.
- Update Rock dialect verification/constraints and tests to cover new supported geometries and element types.
- Add new FP8/BF8 e2e regression suite and hook it into the e2e CMake target list.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp | Core implementation: FP8/BF8 support, new geometry handling, quad-rate offset/load logic |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Verifier updates: LDS transpose config geometry list; LDS transpose load result vector length checks |
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Expand operand element types; loosen result type to be enforced by verifier |
| mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h | Header/docs updated for new types and geometries |
| mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td | Doc update: KDim now includes 64/128 |
| mlir/test/Dialect/Rock/ops.mlir | New IR tests covering fp8 lds_transpose_load and new config attrs |
| mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir | Lowering checks for fp8 transpose_load generation |
| mlir/test/Dialect/Rock/lds_transpose_error.mlir | Update expected error messages for expanded valid geometry list |
| mlir/test/e2e/PrLdsTransposeLoadFp8.toml | New e2e suite covering fp8/bf8 + mixed combos |
| mlir/test/e2e/PrLdsTransposeLoadFp8.cfg | Feature-gate for the new e2e suite |
| mlir/test/e2e/CMakeLists.txt | Register new e2e suite in the PR test list |
Comments suppressed due to low confidence (1)
mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp:170
isValidMfmaGeometry() now treats scaled FP8 geometries (16x128, 32x64) as generally valid, but the non-FP8 lane mapping path in getBasePanelOffsets() does not handle these and will hit llvm_unreachable if they ever occur with f16/bf16 types. To avoid potential compiler crashes on inconsistent inputs, add an element-type-dependent geometry check in makeDecision() (or before calling getBasePanelOffsets) so 16x128/32x64 are only accepted when the element type is FP8/BF8.
// Validate MFMA geometry
if (!isValidMfmaGeometry(shape.mnMfma, shape.kMfma)) {
return dec;
}
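The guard suggested in this comment could look roughly like the sketch below. `ElemKind`, `isQuadRateFp8Geometry`, and `isGeometryAllowedForType` are hypothetical names used for illustration, not identifiers from the PR; the only facts taken from the review are that 16x128 and 32x64 are scaled FP8 geometries and should be rejected for f16/bf16.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical element-kind tag; the real code would inspect MLIR types instead.
enum class ElemKind { F16, BF16, FP8, BF8 };

// True for the scaled FP8 geometries named in the review: 16x128 and 32x64.
inline bool isQuadRateFp8Geometry(int64_t mnMfma, int64_t kMfma) {
  return (mnMfma == 16 && kMfma == 128) || (mnMfma == 32 && kMfma == 64);
}

// Reject the scaled FP8 geometries unless the element type is an 8-bit float.
// Assumes the geometry already passed the general validity check.
inline bool isGeometryAllowedForType(int64_t mnMfma, int64_t kMfma,
                                     ElemKind elem) {
  assert(mnMfma > 0 && kMfma > 0 && "geometry dimensions must be positive");
  bool is8Bit = elem == ElemKind::FP8 || elem == ElemKind::BF8;
  return is8Bit || !isQuadRateFp8Geometry(mnMfma, kMfma);
}
```

Calling such a check before the lane-mapping code turns an inconsistent (geometry, type) pair into an ordinary "do not use LDS transpose" decision rather than an unreachable branch.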
- Remove duplicate entries in getMfmaInsnInfoMap
- Clarify neutral scale creation comment in AccelEmitter.cpp
- Rename zeroAttr to neutralScaleAttr for clarity
Implement ds_read_tr8_b64 offset formulas for FP8/BF8 MFMA (16x32, 32x16). Enable mixed fp8/bf8 type combinations for GEMM operations on gfx950.
Add FP8 GEMM heuristic to selectively disable LDS transpose
Disable LDS transpose for FP8 GEMM when K >= 1280 or for small square matrices (K == N < 512) to avoid performance regressions while preserving compile-time benefits.
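As a minimal sketch of the heuristic described in this commit message: the thresholds (K >= 1280, K == N < 512) come from the message itself, while the function name and signature are hypothetical and the real heuristic may consider additional inputs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: returns true when LDS transpose should be disabled
// for an FP8 GEMM, using the thresholds quoted in the commit message.
inline bool shouldDisableLdsTransposeForFp8(int64_t k, int64_t n) {
  assert(k > 0 && n > 0 && "GEMM dimensions must be positive");
  bool largeK = k >= 1280;                   // large reduction dimension
  bool smallSquare = (k == n) && (k < 512);  // small square case, K == N < 512
  return largeK || smallSquare;
}
```

For example, K = 256, N = 256 falls in the small-square case, whereas K = 512, N = 512 does not.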
- Add (32,64) geometry support in LdsTransposeLoad.cpp:
  - New getBasePanelOffsets() branch for 32x64 quad-rate formula
  - k_block = block_id / 2, m_block = block_id % 2
  - kOffsetBase = k_local + k_block * 32
  - mOffsetBase = m_parity * 8 + m_block * 16
  - Update isQuadRate detection to include 32x64
- Add validation for (32,64) in RockDialect.cpp
- Extend tuning ranges for scaled FP8 testing:
  - kPackPerBlock: added 64
  - kPack: added 32 (for k_base=32)

Co-authored-by: Cursor <cursoragent@cursor.com>
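The 32x64 quad-rate formulas in this commit message can be written out as the sketch below. The commit does not say how block_id, k_local, and m_parity are derived from the lane id, so they are taken as plain inputs here; `PanelOffsets` and `quadRate32x64Offsets` are hypothetical names, and the real code builds MLIR values, not host integers.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical result pair for the two base offsets.
struct PanelOffsets {
  int64_t kOffsetBase;
  int64_t mOffsetBase;
};

// blockId, kLocal, and mParity correspond to block_id, k_local, and m_parity
// in the commit message; their derivation from the lane id is not shown there.
inline PanelOffsets quadRate32x64Offsets(int64_t blockId, int64_t kLocal,
                                         int64_t mParity) {
  assert(blockId >= 0 && kLocal >= 0 && mParity >= 0 &&
         "inputs must be non-negative");
  int64_t kBlock = blockId / 2;  // k_block = block_id / 2
  int64_t mBlock = blockId % 2;  // m_block = block_id % 2
  return {kLocal + kBlock * 32,      // kOffsetBase = k_local + k_block * 32
          mParity * 8 + mBlock * 16}; // mOffsetBase = m_parity * 8 + m_block * 16
}
```

With blockId = 3, kLocal = 5, and mParity = 1, both kBlock and mBlock are 1, giving kOffsetBase = 37 and mOffsetBase = 24.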
… check in LDSTransposeLoadOp::verify, update suite 4 test count comment
exact set of valid vector types and drop the now-redundant verifier checks. Add ldsTransposeConfig structural checks to ThreadwiseReadIntoOp::verify (rank-1 + static dest, supported element type, (geometry, type) consistency). Replace the assert + .value() in emitThreadwiseHWTranspose with emitOpError to avoid UB in release builds and reject non-rank-1 / dynamic destinations up-front. Use AmdArchInfo::hasLdsTransposeLoad for arch gating, share a single isValidLdsTransposeMfmaGeometry helper, align the numWaves formula with computeWaveGridLayout, drop the dead (tuning only emits power-of-2 wave-tile factors), and refresh doc comments. Add four negative ODS-coverage tests for the result-type constraint.
* Refactor computeWaveGridLayout to a closed-form algorithm:
iterate power-of-2 factorizations and pick the most balanced one
that fits the available wave tiles. Replaces ~110 lines of hand-
rolled switch cases (with dead fallback branches) with ~25 lines.
* Collapse the four FP8 geometry branches in getBasePanelOffsets
into a single parameterized branch.
* Extract isFp8OnlyLdsTransposeGeometry / isF16OnlyLdsTransposeGeometry
helpers into LdsTransposeLoad.h and reuse them in
decideLDSTransposeForOperands and ThreadwiseReadIntoOp::verify.
* Unify isHighHalf and readIdx into a single extraKOffset parameter
in computePanelFinalOffset.
* Add a shared computeKBlockTimesStride helper used by both
getDoubleRateKOffsetBase and the FP8 branch in getBasePanelOffsets.
* Tighten the numWaves check in decideLDSTransposeForOperands to an
exact {1, 2, 4, 8, 16} match and move it before result population.
* Add 4 missing negative Lit tests for ThreadwiseReadIntoOp::verify
(FP8 + F16-only geom, F16 + quad-rate geom, unsupported dest elem
type, rank-2 dest).
* Add a 32x16 unscaled FP8 e2e config and a new FileCheck test
(lds_transpose_load_fp8_panels.mlir) verifying the number of
amdgpu.transpose_load ops emitted for all 4 FP8/BF8 paths.
* AccelEmitter: add defense-in-depth asserts around the numBlksInD =
(waveSize / inputSpanLen) / numBlksInK computation and move the
Site 2 computation inside the branch where the values are used.
No functional change for supported configurations.
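The numBlksInD computation mentioned in the AccelEmitter item can be sketched with defensive asserts as follows. The formula is taken from the commit message; the specific asserts shown (positive divisors, exact divisibility) are assumptions about what "defense-in-depth" might mean here, and `computeNumBlksInD` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// numBlksInD = (waveSize / inputSpanLen) / numBlksInK, guarded by asserts.
// The particular conditions checked are assumptions, not copied from the PR.
inline int64_t computeNumBlksInD(int64_t waveSize, int64_t inputSpanLen,
                                 int64_t numBlksInK) {
  assert(inputSpanLen > 0 && numBlksInK > 0 && "divisors must be positive");
  assert(waveSize % inputSpanLen == 0 && "inputSpanLen must divide waveSize");
  int64_t blksPerWave = waveSize / inputSpanLen;
  assert(blksPerWave % numBlksInK == 0 &&
         "numBlksInK must divide the number of blocks per wave");
  return blksPerWave / numBlksInK;
}
```

For a 64-lane wave with inputSpanLen = 16 and numBlksInK = 2, this yields 2.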
* Add UNSCALED_16x32 stanza to lds_transpose_load_fp8_panels.mlir so all four FP8/BF8 paths have a fast Lit panel-count guard.
* Trim the Decision struct down to and drop the now-unused mPerWave / nPerWave / doubleBuffering parameters from makeDecision. The public API of decideLDSTransposeForOperands is unchanged.
* Extract isSupportedLdsTransposeNumWaves and isFp8Type into the public header and reuse them in decideLDSTransposeForOperands, computeWaveGridLayout, and ThreadwiseReadIntoOp::verify.
* Replace the inline FP8 geometry guard in getBasePanelOffsets with The verifier already enforces this on well-formed IR.
* Smaller cleanups: move c4 into the F16/BF16 branch of getBasePanelOffsets, assert extraKOffset >= 0 and switch the guard to , and replace with .
No functional change for supported configurations.
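The exact numWaves match and the "most balanced power-of-2 factorization" idea described in these commits can be sketched as below. This is not the PR's implementation: `isSupportedNumWaves` and `pickWaveGrid` are hypothetical names, the balance metric (absolute difference between the two factors) and the tie-break (first candidate found, i.e. fewer waves in M) are assumptions, and the real function may take different parameters.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// Exact-match check for the wave counts listed in the commit message.
inline bool isSupportedNumWaves(int64_t numWaves) {
  return numWaves == 1 || numWaves == 2 || numWaves == 4 || numWaves == 8 ||
         numWaves == 16;
}

// Returns (wavesInM, wavesInN) with wavesInM * wavesInN == numWaves, choosing
// the most balanced power-of-2 split that fits the available wave tiles, or
// std::nullopt when numWaves is unsupported or no split fits.
inline std::optional<std::pair<int64_t, int64_t>>
pickWaveGrid(int64_t numWaves, int64_t maxWavesM, int64_t maxWavesN) {
  assert(maxWavesM > 0 && maxWavesN > 0 && "tile limits must be positive");
  if (!isSupportedNumWaves(numWaves))
    return std::nullopt;
  std::optional<std::pair<int64_t, int64_t>> best;
  int64_t bestImbalance = 0;
  for (int64_t m = 1; m <= numWaves; m *= 2) {
    int64_t n = numWaves / m;
    if (m > maxWavesM || n > maxWavesN)
      continue;  // this split does not fit the available wave tiles
    int64_t imbalance = m > n ? m - n : n - m;
    if (!best || imbalance < bestImbalance) {
      best = std::make_pair(m, n);
      bestImbalance = imbalance;
    }
  }
  return best;
}
```

Under these assumptions, 8 waves with room for 8 tiles in each dimension gives a 2x4 grid, and 16 waves with at most 2 tiles in M and 8 in N gives 2x8.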
Resolves:
Implement ds_read_tr8_b64 offset formulas for FP8/BF8 MFMA (16x32, 32x16, 16x128, 32x64). Enable mixed fp8/bf8 type combinations for GEMM operations on gfx950.
Motivation
Add FP8 and BF8 data type support for LDS transpose load optimization on gfx950.
This enables efficient matrix loads using the ds_read_tr8_b64 hardware instruction.
Technical Details
- LdsTransposeLoad.cpp: Implemented FP8/BF8 offset formulas in getBasePanelOffsets()
- LdsTransposeLoad.cpp: Updated type compatibility check in makeDecision()
- areBothFp8Types() check to allow mixed fp8/bf8 combinations
Test Plan
Test Result