Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 5 additions & 33 deletions src/hpc/bf16_tile_gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use crate::hpc::amx_matmul::{
amx_available, tile_dpbf16ps, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_bf16,
TileConfig,
};
use crate::simd::{bf16_to_f32_batch, F32x16};

// ═════════════════════════════════════════════════════════════════════
// Public API — safe dispatching wrapper
Expand Down Expand Up @@ -104,39 +103,12 @@ unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) {
// AVX-512 fallback (F32x16 + mul_add FMA)
// ═════════════════════════════════════════════════════════════════════

/// Fallback: decode BF16→f32 and run a tight F32x16 GEMM with mul_add FMA.
/// When AVX-512 is the compile-time baseline, this uses native __m512 FMA;
/// on AVX2 it uses the emulated F32x16 = (F32x8, F32x8) pair — same logic.
/// Fallback: delegate to the single source-of-truth SIMD-polyfill kernel
/// [`crate::simd::bf16_tile_gemm_16x16`] (BF16→f32 decode + `F32x16` FMA). The
/// `F32x16` wrapper owns the AVX-512 / AVX2 / NEON / scalar dispatch, so this
/// AMX wrapper only adds the TDPBF16PS tile path on top of the same kernel.
fn fallback_path(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) {
// Decode BF16 → f32 (batch via SIMD when avx512bf16 / avx2 available)
let mut a_f32 = vec![0.0f32; a_bf16.len()];
let mut b_f32 = vec![0.0f32; b_bf16.len()];
bf16_to_f32_batch(a_bf16, &mut a_f32);
bf16_to_f32_batch(b_bf16, &mut b_f32);

// Tight GEMM: for each output (i,j), dot row-of-A with col-of-B via F32x16+FMA.
// B is row-major [K, 16]; j-th column is b_f32[kk*16 + j] over kk=0..K.
// We gather the column into a stack-sized buffer once per (i,j) pair to hit
// the chunks_exact(16) + mul_add fast path on contiguous memory.
for i in 0..16 {
let a_row = &a_f32[i * k..i * k + k];
for j in 0..16 {
// Stream the column into a contiguous buffer
let mut col = vec![0.0f32; k];
for kk in 0..k {
col[kk] = b_f32[kk * 16 + j];
}

// Accumulate via F32x16::mul_add (FMA)
let mut acc = F32x16::splat(0.0);
for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) {
let av = F32x16::from_slice(ra);
let bv = F32x16::from_slice(rb);
acc = av.mul_add(bv, acc);
}
c[i * 16 + j] += acc.reduce_sum();
}
}
crate::simd::bf16_tile_gemm_16x16(a_bf16, b_bf16, c, k);
}

// ═════════════════════════════════════════════════════════════════════
Expand Down
Loading