From 341216dbd71a087f75ca7ede70a2ac7e40fa569e Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 21 Jun 2026 13:09:07 +0000 Subject: [PATCH] refactor(hpc): bf16_tile_gemm fallback delegates to the polyfill (dedup) PR #222 added ndarray::simd::bf16_tile_gemm_16x16 by copying the F32x16 kernel out of hpc::bf16_tile_gemm::fallback_path, leaving the same kernel in two places. Collapse it: the polyfill fn is the single source of truth; the hpc AMX wrapper's fallback now calls crate::simd::bf16_tile_gemm_16x16, with the AMX TDPBF16PS tile path still layered on top. Drops the now-unused F32x16 / bf16_to_f32_batch import. Both suites pass (hpc fallback + simd_ops parity); clippy -D warnings + fmt clean. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_01GJ4NVBSjq1w5h7RmTbVafb --- src/hpc/bf16_tile_gemm.rs | 38 +++++--------------------------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/src/hpc/bf16_tile_gemm.rs b/src/hpc/bf16_tile_gemm.rs index c84c8bec..cc951852 100644 --- a/src/hpc/bf16_tile_gemm.rs +++ b/src/hpc/bf16_tile_gemm.rs @@ -21,7 +21,6 @@ use crate::hpc::amx_matmul::{ amx_available, tile_dpbf16ps, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_bf16, TileConfig, }; -use crate::simd::{bf16_to_f32_batch, F32x16}; // ═════════════════════════════════════════════════════════════════════ // Public API — safe dispatching wrapper @@ -104,39 +103,12 @@ unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) { // AVX-512 fallback (F32x16 + mul_add FMA) // ═════════════════════════════════════════════════════════════════════ -/// Fallback: decode BF16→f32 and run a tight F32x16 GEMM with mul_add FMA. -/// When AVX-512 is the compile-time baseline, this uses native __m512 FMA; -/// on AVX2 it uses the emulated F32x16 = (F32x8, F32x8) pair — same logic. +/// Fallback: delegate to the single source-of-truth SIMD-polyfill kernel +/// [`crate::simd::bf16_tile_gemm_16x16`] (BF16→f32 decode + `F32x16` FMA). The +/// `F32x16` wrapper owns the AVX-512 / AVX2 / NEON / scalar dispatch, so this +/// AMX wrapper only adds the TDPBF16PS tile path on top of the same kernel. fn fallback_path(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { - // Decode BF16 → f32 (batch via SIMD when avx512bf16 / avx2 available) - let mut a_f32 = vec![0.0f32; a_bf16.len()]; - let mut b_f32 = vec![0.0f32; b_bf16.len()]; - bf16_to_f32_batch(a_bf16, &mut a_f32); - bf16_to_f32_batch(b_bf16, &mut b_f32); - - // Tight GEMM: for each output (i,j), dot row-of-A with col-of-B via F32x16+FMA. - // B is row-major [K, 16]; j-th column is b_f32[kk*16 + j] over kk=0..K. - // We gather the column into a stack-sized buffer once per (i,j) pair to hit - // the chunks_exact(16) + mul_add fast path on contiguous memory. - for i in 0..16 { - let a_row = &a_f32[i * k..i * k + k]; - for j in 0..16 { - // Stream the column into a contiguous buffer - let mut col = vec![0.0f32; k]; - for kk in 0..k { - col[kk] = b_f32[kk * 16 + j]; - } - - // Accumulate via F32x16::mul_add (FMA) - let mut acc = F32x16::splat(0.0); - for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) { - let av = F32x16::from_slice(ra); - let bv = F32x16::from_slice(rb); - acc = av.mul_add(bv, acc); - } - c[i * 16 + j] += acc.reduce_sum(); - } - } + crate::simd::bf16_tile_gemm_16x16(a_bf16, b_bf16, c, k); } // ═════════════════════════════════════════════════════════════════════