diff --git a/src/hpc/bf16_tile_gemm.rs b/src/hpc/bf16_tile_gemm.rs index c84c8bec..cc951852 100644 --- a/src/hpc/bf16_tile_gemm.rs +++ b/src/hpc/bf16_tile_gemm.rs @@ -21,7 +21,6 @@ use crate::hpc::amx_matmul::{ amx_available, tile_dpbf16ps, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_bf16, TileConfig, }; -use crate::simd::{bf16_to_f32_batch, F32x16}; // ═════════════════════════════════════════════════════════════════════ // Public API — safe dispatching wrapper @@ -104,39 +103,12 @@ unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) { // AVX-512 fallback (F32x16 + mul_add FMA) // ═════════════════════════════════════════════════════════════════════ -/// Fallback: decode BF16→f32 and run a tight F32x16 GEMM with mul_add FMA. -/// When AVX-512 is the compile-time baseline, this uses native __m512 FMA; -/// on AVX2 it uses the emulated F32x16 = (F32x8, F32x8) pair — same logic. +/// Fallback: delegate to the single source-of-truth SIMD-polyfill kernel +/// [`crate::simd::bf16_tile_gemm_16x16`] (BF16→f32 decode + `F32x16` FMA). The +/// `F32x16` wrapper owns the AVX-512 / AVX2 / NEON / scalar dispatch, so this +/// AMX wrapper only adds the TDPBF16PS tile path on top of the same kernel. fn fallback_path(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { - // Decode BF16 → f32 (batch via SIMD when avx512bf16 / avx2 available) - let mut a_f32 = vec![0.0f32; a_bf16.len()]; - let mut b_f32 = vec![0.0f32; b_bf16.len()]; - bf16_to_f32_batch(a_bf16, &mut a_f32); - bf16_to_f32_batch(b_bf16, &mut b_f32); - - // Tight GEMM: for each output (i,j), dot row-of-A with col-of-B via F32x16+FMA. - // B is row-major [K, 16]; j-th column is b_f32[kk*16 + j] over kk=0..K. - // We gather the column into a stack-sized buffer once per (i,j) pair to hit - // the chunks_exact(16) + mul_add fast path on contiguous memory. - for i in 0..16 { - let a_row = &a_f32[i * k..i * k + k]; - for j in 0..16 { - // Stream the column into a contiguous buffer - let mut col = vec![0.0f32; k]; - for kk in 0..k { - col[kk] = b_f32[kk * 16 + j]; - } - - // Accumulate via F32x16::mul_add (FMA) - let mut acc = F32x16::splat(0.0); - for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) { - let av = F32x16::from_slice(ra); - let bv = F32x16::from_slice(rb); - acc = av.mul_add(bv, acc); - } - c[i * 16 + j] += acc.reduce_sum(); - } - } + crate::simd::bf16_tile_gemm_16x16(a_bf16, b_bf16, c, k); } // ═════════════════════════════════════════════════════════════════════