From afb53c28c88ef5e70b8a5ae788fa3bc34b7177cb Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 20 Jun 2026 09:31:08 +0000 Subject: [PATCH] feat(simd): add ndarray::simd::bf16_tile_gemm_16x16 polyfill primitive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A 16x16 BF16 tile GEMM (`C[16,16] += A[16,K]·B[K,16]`, K multiple of 32) built purely from the SIMD polyfill: BF16->f32 decode + `F32x16::mul_add`. The `F32x16` wrapper owns the per-arch dispatch (AVX-512 VFMADD231PS where available -> AVX2 pair -> NEON -> scalar), so the kernel rides AMX/AVX-512 hosts automatically. No `hpc` reference, no AMX intrinsic, no external BLAS. Lives in src/simd_ops.rs, re-exported via `ndarray::simd`. Parity test vs an f32-accumulated scalar reference + a `+=` accumulation test + doctest all pass on AVX-512; clippy -D warnings + fmt clean. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_01GJ4NVBSjq1w5h7RmTbVafb --- src/simd.rs | 5 +- src/simd_ops.rs | 132 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 134 insertions(+), 3 deletions(-) diff --git a/src/simd.rs b/src/simd.rs index 6541fd85..78ae65c9 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -587,8 +587,9 @@ pub use crate::simd_amx::{amx_report, cpu_model, CpuModel}; // Elementwise slice ops — polyfill-dispatched (F32x16/F64x8 chunks + scalar tail). 
#[cfg(feature = "std")] pub use crate::simd_ops::{ - add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_mul_f32, add_mul_f64, add_scalar_f32, div_f32, - div_f32_inplace, mul_f32, mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32, sub_f32_inplace, + add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_mul_f32, add_mul_f64, add_scalar_f32, bf16_tile_gemm_16x16, + div_f32, div_f32_inplace, mul_f32, mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32, + sub_f32_inplace, }; // ============================================================================ diff --git a/src/simd_ops.rs b/src/simd_ops.rs index 6c223b61..455f915e 100644 --- a/src/simd_ops.rs +++ b/src/simd_ops.rs @@ -20,7 +20,7 @@ //! delete them without verifying every BLAS-graph kernel still compiles //! AND that the JIT alternative has been re-evaluated.** -use crate::simd::{F32x16, F64x8}; +use crate::simd::{bf16_to_f32_batch, F32x16, F64x8}; // ═══════════════════════════════════════════════════════════════════ // f32 binary ops (out-of-place) @@ -556,6 +556,136 @@ pub fn array_windows_checked(data: &[T]) -> Result(data)) } +/// BF16 16×16 tile GEMM: `C[16,16] += A[16,K] · B[K,16]` where `A`, `B` are +/// BF16 row-major (`u16` bit patterns) and `C` is f32 row-major. `K` must be a +/// multiple of 32. +/// +/// Pure SIMD-polyfill kernel: decodes BF16→f32 and accumulates with +/// [`F32x16::mul_add`] (FMA), so it rides the polyfill's per-arch escalation — +/// AVX-512 `VFMADD231PS` where available, the emulated `(F32x8, F32x8)` pair on +/// AVX2, NEON on aarch64, scalar otherwise. The `F32x16` wrapper owns the +/// dispatch; there is no AMX intrinsic and no external BLAS backend. +/// +/// Accumulates into `c` (`+=`), matching BLAS `C := A·B + C`. Inputs are +/// decoded BF16→f32, so the result matches an f32-accumulated scalar reference +/// up to BF16 input precision (~2⁻⁸ per multiply, O(√K) accumulated). 
+/// +/// # Panics +/// Panics if `k % 32 != 0`, `a_bf16.len() != 16 * k`, `b_bf16.len() != k * 16`, +/// or `c.len() != 256`. +/// +/// # Examples +/// ``` +/// use ndarray::simd::{bf16_tile_gemm_16x16, f32_to_bf16_scalar}; +/// let k = 32; +/// let a = vec![f32_to_bf16_scalar(1.0); 16 * k]; +/// let b = vec![f32_to_bf16_scalar(1.0); k * 16]; +/// let mut c = vec![0.0f32; 256]; +/// bf16_tile_gemm_16x16(&a, &b, &mut c, k); +/// assert!((c[0] - k as f32).abs() < 1e-3); // Σ_{p<32} 1·1 = 32 +/// ``` +pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { + assert!(k % 32 == 0); + assert_eq!(a_bf16.len(), 16 * k); + assert_eq!(b_bf16.len(), k * 16); + assert_eq!(c.len(), 256); + + // Decode BF16 -> f32 once; the batch decode rides the polyfill too. + let mut a_f32 = vec![0.0f32; a_bf16.len()]; + let mut b_f32 = vec![0.0f32; b_bf16.len()]; + bf16_to_f32_batch(a_bf16, &mut a_f32); + bf16_to_f32_batch(b_bf16, &mut b_f32); + + // C[i,j] += dot(A row i, B col j) via F32x16 + FMA on 16-wide chunks. + // K is a multiple of 32, so chunks_exact(16) tiles A's row / B's column + // with no remainder. + for i in 0..16 { + let a_row = &a_f32[i * k..i * k + k]; + for j in 0..16 { + // Gather B column j (row-major B is strided by 16) into a buffer so + // the dot product hits the contiguous chunks_exact(16) fast path. + let mut col = vec![0.0f32; k]; + for (kk, slot) in col.iter_mut().enumerate() { + *slot = b_f32[kk * 16 + j]; + } + let mut acc = F32x16::splat(0.0); + for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) { + acc = F32x16::from_slice(ra).mul_add(F32x16::from_slice(rb), acc); + } + c[i * 16 + j] += acc.reduce_sum(); + } + } +} + +#[cfg(test)] +mod bf16_tile_gemm_tests { + use super::*; + use crate::simd::{bf16_to_f32_batch, f32_to_bf16_batch, f32_to_bf16_scalar}; + + /// Scalar BF16 reference (f32-accumulated, `+=`) — the correctness anchor. 
+ fn ref_gemm(a: &[f32], b: &[f32], c: &mut [f32], k: usize) { + for i in 0..16 { + for j in 0..16 { + let mut s = 0.0f32; + for kk in 0..k { + s += a[i * k + kk] * b[kk * 16 + j]; + } + c[i * 16 + j] += s; + } + } + } + + #[test] + fn matches_scalar_reference_k64() { + let k = 64; + let mut a_f32 = vec![0.0f32; 16 * k]; + let mut b_f32 = vec![0.0f32; k * 16]; + for (i, v) in a_f32.iter_mut().enumerate() { + *v = + (((i as i32).wrapping_mul(1103515245).wrapping_add(12345) >> 8) as f32 / 2147483648.0).clamp(-1.0, 1.0); + } + for (i, v) in b_f32.iter_mut().enumerate() { + *v = (((i as i32).wrapping_mul(69069).wrapping_add(1) >> 8) as f32 / 2147483648.0).clamp(-1.0, 1.0); + } + let mut a_bf16 = vec![0u16; a_f32.len()]; + let mut b_bf16 = vec![0u16; b_f32.len()]; + f32_to_bf16_batch(&a_f32, &mut a_bf16); + f32_to_bf16_batch(&b_f32, &mut b_bf16); + + // Reference consumes the same BF16-truncated inputs the kernel sees. + let mut a_back = vec![0.0f32; a_f32.len()]; + let mut b_back = vec![0.0f32; b_f32.len()]; + bf16_to_f32_batch(&a_bf16, &mut a_back); + bf16_to_f32_batch(&b_bf16, &mut b_back); + let mut c_ref = vec![0.0f32; 256]; + ref_gemm(&a_back, &b_back, &mut c_ref, k); + + let mut c = vec![0.0f32; 256]; + bf16_tile_gemm_16x16(&a_bf16, &b_bf16, &mut c, k); + + let max_err = c + .iter() + .zip(&c_ref) + .map(|(x, y)| (x - y).abs()) + .fold(0.0f32, f32::max); + assert!(max_err < 1e-3, "polyfill vs scalar ref max_err = {max_err}"); + } + + #[test] + fn accumulates_into_c() { + // C is preloaded with 5.0; the kernel must ADD, not overwrite. + let k = 32; + let a = vec![f32_to_bf16_scalar(1.0); 16 * k]; + let b = vec![f32_to_bf16_scalar(1.0); k * 16]; + let mut c = vec![5.0f32; 256]; + bf16_tile_gemm_16x16(&a, &b, &mut c, k); + // 5 (preload) + Σ_{p<32} 1·1 = 37 + for v in &c { + assert!((v - 37.0).abs() < 1e-3, "got {v}"); + } + } +} + #[cfg(test)] mod array_chunks_tests { use super::*;