Skip to content

Commit 0555731

Browse files
pjotsclaudememmett
authored
Add FP8 subtile unit tests and generalize GPU test helpers (ROCm#7318)
## Motivation FP8 support requires verifying that the subtile-based kernel's GR/LR swizzled tile assignment and MFMA computation produce correct results on hardware. Previously there were no GPU unit tests covering FP8 data types in the subtile kernel path. ## Technical Details **New FP8 unit tests:** - `test_graTileAssignment_fp8.py` — GPU tests verifying GRA (Global Read Assignment) tile layout for FP8 (AB_B8, inst_k=128, bpe=1) with swizzled addressing - `test_lraTileAssignment_fp8.py` — GPU tests verifying LRA (Local Read Assignment) tile layout for FP8 - `test_gr_lr_roundtrip_fp8.py` — End-to-end GR→LDS→LR roundtrip tests confirming data survives the FP8 swizzled store/load pipeline - `test_mfma_fp8.py` — GPU tests executing `v_mfma_f32_16x16x128_fp8_fp8` and verifying results against a Python reference dot-product **Subtile component updates** (to support FP8 layouts): - `SubtileGREmit.py`, `SubtileLREmit.py` — FP8 swizzled layout emission - `SubtileGeometry.py`, `Kernel.py` — geometry and kernel-level FP8 support **Refactoring of shared test infrastructure** (`gpu_test_helpers.py`): - Generalized `_create_kernel` / `create_writer` to accept `geometry` / `inst_k` / `bpe` parameters so FP8 and FP16 tests share one code path - Added `collect_tile_vgprs`, `compute_expected_subtile`, `setup_roundtrip_writer`, `build_roundtrip_inner_asm`, `alloc_export_vgprs`, `generate_srd_setup` as common helpers - Eliminated ~550 lines of duplicated boilerplate across `test_graTileAssignment.py`, `test_lraTileAssignment.py`, `test_gr_lr_roundtrip.py`, `test_mfma_fp8.py`, and `test_storeD_roundtrip.py` ## Test Plan Run the full unit test suite from the tensilelite root: ```bash PYTHONPATH=<rocisa_lib_path>:. python3 -m pytest Tensile/Tests/unit/ -v -s Or run individual FP8 test files: PYTHONPATH=<rocisa_lib_path>:. python3 -m pytest Tensile/Tests/unit/test_graTileAssignment_fp8.py -v -s PYTHONPATH=<rocisa_lib_path>:. python3 -m pytest Tensile/Tests/unit/test_lraTileAssignment_fp8.py -v -s PYTHONPATH=<rocisa_lib_path>:. python3 -m pytest Tensile/Tests/unit/test_gr_lr_roundtrip_fp8.py -v -s PYTHONPATH=<rocisa_lib_path>:. python3 -m pytest Tensile/Tests/unit/test_mfma_fp8.py -v -s Test Result 124 unit tests passing, including all new FP8 tests and all pre-existing FP16 tests. No regressions in existing test files after the shared-helper refactoring. ``` ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: Matthew Emmett <matthew.emmett@amd.com>
1 parent d265f8c commit 0555731

12 files changed

Lines changed: 1837 additions & 305 deletions

projects/hipblaslt/tensilelite/Tensile/Components/Subtile/Kernel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@
7575
ABInputGeometry,
7676
ABGRGeometry,
7777
ABLRGeometry,
78-
GRTag_1x2, GRTag_2x2, GRTag_TLU1,
79-
LRTag_1x2, LRTag_TLU1,
78+
GRTag_1x1, GRTag_1x2, GRTag_2x2, GRTag_TLU1,
79+
LRTag_1x1, LRTag_1x2, LRTag_TLU1,
8080
ABTilePair,
8181
CDTileGeometry,
8282
MXScaleInputGeometry,
@@ -289,8 +289,8 @@ def emitLoadC(self, ti: 'TileInfo', writer, kernel): pass
289289
lr=ABLRGeometry(tag=LRTag_1x2(), **_B4, subtileShape=(1, 2), loadShape=LoadShape(m=1, k=32)), # 128-bit LR: 32 fp4 along K
290290
)
291291
AB_B8 = ABTilePair(
292-
gr=ABGRGeometry(tag=GRTag_1x2(), **_B8, subtileShape=(1, 2), loadShape=LoadShape(m=1, k=16)), # 128-bit GR: 16 fp8 along K
293-
lr=ABLRGeometry(tag=LRTag_1x2(), **_B8, subtileShape=(1, 2), loadShape=LoadShape(m=1, k=32), loadWidth=32), # 256-bit LR: 32 fp8 along K
292+
gr=ABGRGeometry(tag=GRTag_1x1(), **_B8, subtileShape=(1, 1), loadShape=LoadShape(m=1, k=16)), # 128-bit GR: 16 fp8 along K
293+
lr=ABLRGeometry(tag=LRTag_1x1(), **_B8, subtileShape=(1, 1), loadShape=LoadShape(m=1, k=16)), # 128-bit LR: 16 fp8 along K
294294
)
295295

296296
AB_B4_2x2 = ABTilePair(

projects/hipblaslt/tensilelite/Tensile/Components/Subtile/SubtileGREmit.py

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727
SAddCU32, SAddU32, SMovB32, SMovB64, SMulI32, SNop, SXorB32,
2828
VAddU32, VAndB32, VCmpXEqU32,
2929
VLShiftLeftB32, VLShiftRightB32, VMovB32,
30-
VMulLOU32, VReadfirstlaneB32, VSubU32,
30+
VMulLOU32, VReadfirstlaneB32, VSubU32, VXorB32,
3131
)
3232

3333
from .SubtileGeometry import (
3434
RegList,
35-
GRTag_1x2, GRTag_2x2, GRTag_TLU1,
35+
GRTag_1x1, GRTag_1x2, GRTag_2x2, GRTag_TLU1,
3636
)
3737
from .SubtileScaleEmit import emitScaleGRLDSSwap
3838

@@ -82,14 +82,15 @@ def _emitGRPtrUpdate(tag, tile, ti, writer, kernel):
8282
_emitDTLInit.register(GRTag_TLU1)(_stub)
8383
_emitGRLDSBufferSwap.register(GRTag_TLU1)(_stub)
8484
_emitGRPtrUpdate.register(GRTag_TLU1)(_stub)
85-
for _tag in (GRTag_1x2, GRTag_2x2, GRTag_TLU1):
85+
for _tag in (GRTag_1x1, GRTag_1x2, GRTag_2x2, GRTag_TLU1):
8686
_emitLocalWrite.register(_tag)(_stub)
8787

8888

8989
################################################################################
9090
# 2. Implementations — TLU=0 (shared by GRTag_1x2 and GRTag_2x2)
9191
################################################################################
9292

93+
@_emitGlobalReadOffset.register(GRTag_1x1)
9394
@_emitGlobalReadOffset.register(GRTag_1x2)
9495
@_emitGlobalReadOffset.register(GRTag_2x2)
9596
def _emitGROffset_TLU0(tag, tile, ti, writer, kernel):
@@ -262,6 +263,7 @@ def _emitGROffset_TLU0(tag, tile, ti, writer, kernel):
262263
return module
263264

264265

266+
@_allocGROffsetRegisters.register(GRTag_1x1)
265267
@_allocGROffsetRegisters.register(GRTag_1x2)
266268
@_allocGROffsetRegisters.register(GRTag_2x2)
267269
def _allocGROffsetRegs_TLU0(tag, tile, ti, writer, kernel):
@@ -323,6 +325,7 @@ def _allocGROffsetRegs_TLU0(tag, tile, ti, writer, kernel):
323325
rl.alloc(preventOverflow=False)
324326

325327

328+
@_deallocGROffsetRegisters.register(GRTag_1x1)
326329
@_deallocGROffsetRegisters.register(GRTag_1x2)
327330
@_deallocGROffsetRegisters.register(GRTag_2x2)
328331
def _deallocGROffsetRegs_TLU0(tag, tile, ti, writer, kernel):
@@ -339,6 +342,7 @@ def _deallocGROffsetRegs_TLU0(tag, tile, ti, writer, kernel):
339342

340343
# --- GR load emit (TLU=0) ---------------------------------------------------
341344

345+
@_emitGlobalRead.register(GRTag_1x1)
342346
@_emitGlobalRead.register(GRTag_1x2)
343347
@_emitGlobalRead.register(GRTag_2x2)
344348
def _emitGR_TLU0(tag, tile, ti, writer, kernel):
@@ -409,6 +413,7 @@ def _emitGR_TLU0(tag, tile, ti, writer, kernel):
409413

410414
# --- DTL init (TLU=0) -------------------------------------------------------
411415

416+
@_emitDTLInit.register(GRTag_1x1)
412417
@_emitDTLInit.register(GRTag_1x2)
413418
@_emitDTLInit.register(GRTag_2x2)
414419
def _emitDTLInit_TLU0(tag, tile, ti, writer, kernel):
@@ -478,6 +483,7 @@ def _emitDTLInit_TLU0(tag, tile, ti, writer, kernel):
478483

479484
# --- GR LDS buffer swap (TLU=0) ---------------------------------------------
480485

486+
@_emitGRLDSBufferSwap.register(GRTag_1x1)
481487
@_emitGRLDSBufferSwap.register(GRTag_1x2)
482488
@_emitGRLDSBufferSwap.register(GRTag_2x2)
483489
def _emitGRLDSSwap_TLU0(tag, tile, ti, writer, kernel):
@@ -496,6 +502,7 @@ def _emitGRLDSSwap_TLU0(tag, tile, ti, writer, kernel):
496502

497503
# --- GR pointer update (TLU=0) ----------------------------------------------
498504

505+
@_emitGRPtrUpdate.register(GRTag_1x1)
499506
@_emitGRPtrUpdate.register(GRTag_1x2)
500507
@_emitGRPtrUpdate.register(GRTag_2x2)
501508
def _emitGRPtrUpdate_TLU0(tag, tile, ti, writer, kernel):
@@ -722,48 +729,72 @@ def _grComputeAllOffsets_legacy(module, writer, tileInfo, colId, rowId, rowOffse
722729
rotatedcolId = writer.vgprPool.checkOut(1)
723730
loadWidth = tileInfo.loadWidthGR
724731
if tileInfo.loadRatioGR == 0.5:
725-
blockSize = tileInfo.subIterKBytes // loadWidth
726-
colRotation = blockSize // 2
727-
module.add(VAddU32(dst=vgpr(rotatedcolId), src0=colRotation, src1=vgpr(colId), comment="%s: rotate col for GR offset %u"%(tileInfo.tc, i)))
728-
module.add(VAndB32(dst=vgpr(rotatedcolId), src0=vgpr(rotatedcolId), src1=hex(blockSize-1), comment="(col + %d) %% block_size"%colRotation))
732+
if tileInfo.bpe == 1: # FP8: intra-block K_group +2 rotation, preserving block bit
733+
tmpBlock = writer.vgprPool.checkOut(1)
734+
module.add(VAndB32(dst=vgpr(tmpBlock), src0=vgpr(colId), src1=hex(4), comment="%s: block_bit = colId & 4"%tileInfo.tc))
735+
module.add(VAndB32(dst=vgpr(rotatedcolId), src0=vgpr(colId), src1=hex(3), comment="%s: K_group = colId & 3"%tileInfo.tc))
736+
module.add(VAddU32(dst=vgpr(rotatedcolId), src0=vgpr(rotatedcolId), src1=hex(2), comment="%s: K_group + 2"%tileInfo.tc))
737+
module.add(VAndB32(dst=vgpr(rotatedcolId), src0=vgpr(rotatedcolId), src1=hex(3), comment="%s: (K_group+2) %% 4"%tileInfo.tc))
738+
module.add(VAddU32(dst=vgpr(rotatedcolId), src0=vgpr(rotatedcolId), src1=vgpr(tmpBlock), comment="%s: K_group_rot + block_bit"%tileInfo.tc))
739+
writer.vgprPool.checkIn(tmpBlock)
740+
else: # FP4/FP16: half-block rotation
741+
blockSize = tileInfo.subIterKBytes // loadWidth
742+
colRotation = blockSize // 2
743+
module.add(VAddU32(dst=vgpr(rotatedcolId), src0=colRotation, src1=vgpr(colId), comment="%s: rotate col for GR offset %u"%(tileInfo.tc, i)))
744+
module.add(VAndB32(dst=vgpr(rotatedcolId), src0=vgpr(rotatedcolId), src1=hex(blockSize-1), comment="(col + %d) %% block_size"%colRotation))
729745
else:
730746
module.add(VMovB32(dst=vgpr(rotatedcolId), src=vgpr(colId), comment=""))
731747
_grComputeOffset_legacy(module, writer, tileInfo, rotatedcolId, rowOffset, tileInfo.sharedVgprGROffset[i])
732748
writer.vgprPool.checkIn(rotatedcolId)
733749

734750
def _grSwizzleColIds_legacy(module, writer, tileInfoA, tileInfoB, blockSize, numRowsPerLDSBanks,
735751
laneId, colIdA, colIdB, waveId):
736-
tmpVgpr = writer.vgprPool.checkOut(2)
752+
tmpVgpr = writer.vgprPool.checkOut(3)
737753
ldsRowId = tmpVgpr
738754
tmp = tmpVgpr + 1
755+
waveRotation = tmpVgpr + 2
756+
half = blockSize // 2
739757
module.addComment0("Swizzling")
740758
module.add(VLShiftRightB32(dst=vgpr(ldsRowId), shiftHex=hex(blockSize.bit_length()-1), src=vgpr(laneId), comment="row id within wave"))
741759
module.add(VLShiftRightB32(dst=vgpr(ldsRowId), shiftHex=hex(numRowsPerLDSBanks.bit_length()-1), src=vgpr(ldsRowId), comment="lds row id"))
742-
module.add(VAndB32(dst=vgpr(tmp), src0=vgpr(ldsRowId), src1=hex(1), comment="lds row id % 2"))
743-
module.add(VCmpXEqU32(dst=VCC(), src0=0, src1=vgpr(tmp), comment="lds row id % 2 == 0 ?"))
744-
module.add(VMovB32(dst=vgpr(colIdA), src=vgpr(colIdA), dpp=DPPModifiers(quad_perm=[1,0,3,2]), comment="swap colId pairs for swizzling"))
745-
module.add(SMovB64(dst=EXEC(), src=-1))
746-
module.add(VMovB32(dst=vgpr(colIdB), src=vgpr(colIdA), comment=""))
747-
module.addComment0("Rotation within a single wave")
748-
module.add(VLShiftRightB32(dst=vgpr(tmp), shiftHex=hex(1), src=vgpr(ldsRowId), comment=""))
749-
module.add(VLShiftLeftB32(dst=vgpr(tmp), shiftHex=hex(1), src=vgpr(tmp), comment="(ldsRowId //2) * 2"))
750-
module.add(VSubU32(dst=vgpr(tmp), src0=hex(blockSize), src1=vgpr(tmp), comment="rotation offset : blockSize - (ldsRowId//2)*2"))
751-
needWaveRotation = any(t.loadRatioGR != 0.5 for t, _ in [(tileInfoA, colIdA), (tileInfoB, colIdB)])
752-
if needWaveRotation:
753-
waveRotation = writer.vgprPool.checkOut(1)
754-
for tInfo, cId in [(tileInfoA, colIdA), (tileInfoB, colIdB)]:
755-
if tInfo.loadRatioGR != 0.5:
756-
module.addComment0("Rotation per wave")
757-
module.add(VAndB32(dst=vgpr(waveRotation), src0=vgpr(waveId), src1=hex(1), comment=""))
758-
module.add(VLShiftLeftB32(dst=vgpr(waveRotation), shiftHex=hex((2*numRowsPerLDSBanks).bit_length() - 1), src=vgpr(waveRotation), comment=""))
759-
module.add(VSubU32(dst=vgpr(waveRotation), src0=vgpr(tmp), src1=vgpr(waveRotation), comment=""))
760-
module.add(VAddU32(dst=vgpr(cId), src0=vgpr(waveRotation), src1=vgpr(cId), comment=""))
761-
else:
762-
module.add(VAddU32(dst=vgpr(cId), src0=vgpr(tmp), src1=vgpr(cId), comment=""))
763-
if needWaveRotation:
764-
writer.vgprPool.checkIn(waveRotation)
765-
module.add(VAndB32(dst=vgpr(colIdA), src0=vgpr(colIdA), src1=hex(blockSize-1), comment="(col + offset) % block_size"))
766-
module.add(VAndB32(dst=vgpr(colIdB), src0=vgpr(colIdB), src1=hex(blockSize-1), comment="(col + offset) % block_size"))
760+
module.add(VAndB32(dst=vgpr(tmp), src0=vgpr(ldsRowId), src1=hex(1), comment="swap_bit = ldsRowId & 1"))
761+
if tileInfoA.bpe == 1: # FP8: step1=block-swap, step2=wave K_group rotation
762+
# Step 1: block-swap (XOR blockSize//2 for odd ldsRowId)
763+
module.add(VLShiftLeftB32(dst=vgpr(tmp), shiftHex=hex(int(math.log2(half))), src=vgpr(tmp),
764+
comment=f"swap_bit * {half}"))
765+
module.add(VXorB32(dst=vgpr(colIdA), src0=vgpr(colIdA), src1=vgpr(tmp),
766+
comment="FP8 step1: block-swap colIdA"))
767+
module.add(VMovB32(dst=vgpr(colIdB), src=vgpr(colIdA), comment="colIdB = colIdA"))
768+
# Step 2: K_group rotation = (waveId & 1) * 2 (only for loadRatioGR != 0.5)
769+
module.add(VAndB32(dst=vgpr(tmp), src0=vgpr(waveId), src1=hex(1), comment="wave_half = waveId & 1"))
770+
module.add(VLShiftLeftB32(dst=vgpr(tmp), shiftHex=hex(1), src=vgpr(tmp), comment="rotation = wave_half * 2"))
771+
for tInfo, cId in [(tileInfoA, colIdA), (tileInfoB, colIdB)]:
772+
if tInfo.loadRatioGR != 0.5:
773+
module.add(VAndB32(dst=vgpr(waveRotation), src0=vgpr(cId), src1=hex(4), comment="FP8 step2: block_bit = colId & 4"))
774+
module.add(VAndB32(dst=vgpr(cId), src0=vgpr(cId), src1=hex(3), comment="K_group = colId & 3"))
775+
module.add(VAddU32(dst=vgpr(cId), src0=vgpr(cId), src1=vgpr(tmp), comment="K_group + rotation"))
776+
module.add(VAndB32(dst=vgpr(cId), src0=vgpr(cId), src1=hex(3), comment="(K_group+rotation) % 4"))
777+
module.add(VAddU32(dst=vgpr(cId), src0=vgpr(cId), src1=vgpr(waveRotation), comment="K_group_rot + block_bit"))
778+
else: # FP4/FP16: pair-swap (even ldsRowId) + intra/inter-wave rotation
779+
module.add(VCmpXEqU32(dst=VCC(), src0=0, src1=vgpr(tmp), comment="lds row id % 2 == 0 ?"))
780+
module.add(VMovB32(dst=vgpr(colIdA), src=vgpr(colIdA), dpp=DPPModifiers(quad_perm=[1,0,3,2]), comment="swap colId pairs for swizzling"))
781+
module.add(SMovB64(dst=EXEC(), src=-1))
782+
module.add(VMovB32(dst=vgpr(colIdB), src=vgpr(colIdA), comment=""))
783+
module.addComment0("Rotation within a single wave")
784+
module.add(VLShiftRightB32(dst=vgpr(tmp), shiftHex=hex(1), src=vgpr(ldsRowId), comment=""))
785+
module.add(VLShiftLeftB32(dst=vgpr(tmp), shiftHex=hex(1), src=vgpr(tmp), comment="(ldsRowId //2) * 2"))
786+
module.add(VSubU32(dst=vgpr(tmp), src0=hex(blockSize), src1=vgpr(tmp), comment="rotation offset : blockSize - (ldsRowId//2)*2"))
787+
for tInfo, cId in [(tileInfoA, colIdA), (tileInfoB, colIdB)]:
788+
if tInfo.loadRatioGR != 0.5:
789+
module.addComment0("Rotation per wave")
790+
module.add(VAndB32(dst=vgpr(waveRotation), src0=vgpr(waveId), src1=hex(1), comment=""))
791+
module.add(VLShiftLeftB32(dst=vgpr(waveRotation), shiftHex=hex((2*numRowsPerLDSBanks).bit_length() - 1), src=vgpr(waveRotation), comment=""))
792+
module.add(VSubU32(dst=vgpr(waveRotation), src0=vgpr(tmp), src1=vgpr(waveRotation), comment=""))
793+
module.add(VAddU32(dst=vgpr(cId), src0=vgpr(waveRotation), src1=vgpr(cId), comment=""))
794+
else:
795+
module.add(VAddU32(dst=vgpr(cId), src0=vgpr(tmp), src1=vgpr(cId), comment=""))
796+
module.add(VAndB32(dst=vgpr(colIdA), src0=vgpr(colIdA), src1=hex(blockSize-1), comment="(col + offset) % block_size"))
797+
module.add(VAndB32(dst=vgpr(colIdB), src0=vgpr(colIdB), src1=hex(blockSize-1), comment="(col + offset) % block_size"))
767798
writer.vgprPool.checkIn(tmpVgpr)
768799

769800
def _graTileAssignment_legacy(writer, kernel, useSwizzling=True):

projects/hipblaslt/tensilelite/Tensile/Components/Subtile/SubtileGeometry.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,10 @@ def mmaTileRegCount(self): return self.gr.mmaTileRegCount
679679
# The ABGRGeometry / ABLRGeometry classes store one tag instance as `self.tag`.
680680
################################################################################
681681

682+
@dataclass(frozen=True)
683+
class GRTag_1x1:
684+
"""GR emit strategy: row-major (TLU=0), 1×1 block shape."""
685+
682686
@dataclass(frozen=True)
683687
class GRTag_1x2:
684688
"""GR emit strategy: row-major (TLU=0), 1×2 block shape."""
@@ -691,6 +695,10 @@ class GRTag_2x2:
691695
class GRTag_TLU1:
692696
"""GR emit strategy: column-major (TLU=1), 8×1 block shape."""
693697

698+
@dataclass(frozen=True)
699+
class LRTag_1x1:
700+
"""LR emit strategy: row-major (TLU=0), 1×1 subtile shape."""
701+
694702
@dataclass(frozen=True)
695703
class LRTag_1x2:
696704
"""LR emit strategy: row-major (TLU=0), 1×2 subtile shape."""

0 commit comments

Comments
 (0)