Skip to content

Commit afb53c2

Browse files
committed
feat(simd): add ndarray::simd::bf16_tile_gemm_16x16 polyfill primitive
A 16x16 BF16 tile GEMM (`C[16,16] += A[16,K]·B[K,16]`, K multiple of 32) built purely from the SIMD polyfill: BF16->f32 decode + `F32x16::mul_add`. The `F32x16` wrapper owns the per-arch dispatch (AVX-512 VFMADD231PS where available -> AVX2 pair -> NEON -> scalar), so the kernel rides AMX/AVX-512 hosts automatically. No `hpc` reference, no AMX intrinsic, no external BLAS. Lives in src/simd_ops.rs, re-exported via `ndarray::simd`. Parity test vs an f32-accumulated scalar reference + a `+=` accumulation test + doctest all pass on AVX-512; clippy -D warnings + fmt clean. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01GJ4NVBSjq1w5h7RmTbVafb
1 parent 2d5c9bb commit afb53c2

2 files changed

Lines changed: 134 additions & 3 deletions

File tree

src/simd.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,8 +587,9 @@ pub use crate::simd_amx::{amx_report, cpu_model, CpuModel};
587587
// Elementwise slice ops — polyfill-dispatched (F32x16/F64x8 chunks + scalar tail).
588588
#[cfg(feature = "std")]
589589
pub use crate::simd_ops::{
590-
add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_mul_f32, add_mul_f64, add_scalar_f32, div_f32,
591-
div_f32_inplace, mul_f32, mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32, sub_f32_inplace,
590+
add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_mul_f32, add_mul_f64, add_scalar_f32, bf16_tile_gemm_16x16,
591+
div_f32, div_f32_inplace, mul_f32, mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32,
592+
sub_f32_inplace,
592593
};
593594

594595
// ============================================================================

src/simd_ops.rs

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
//! delete them without verifying every BLAS-graph kernel still compiles
2121
//! AND that the JIT alternative has been re-evaluated.**
2222
23-
use crate::simd::{F32x16, F64x8};
23+
use crate::simd::{bf16_to_f32_batch, F32x16, F64x8};
2424

2525
// ═══════════════════════════════════════════════════════════════════
2626
// f32 binary ops (out-of-place)
@@ -556,6 +556,136 @@ pub fn array_windows_checked<T, const N: usize>(data: &[T]) -> Result<impl Itera
556556
Ok(array_windows::<T, N>(data))
557557
}
558558

559+
/// BF16 16×16 tile GEMM: `C[16,16] += A[16,K] · B[K,16]` where `A`, `B` are
560+
/// BF16 row-major (`u16` bit patterns) and `C` is f32 row-major. `K` must be a
561+
/// multiple of 32.
562+
///
563+
/// Pure SIMD-polyfill kernel: decodes BF16→f32 and accumulates with
564+
/// [`F32x16::mul_add`] (FMA), so it rides the polyfill's per-arch escalation —
565+
/// AVX-512 `VFMADD231PS` where available, the emulated `(F32x8, F32x8)` pair on
566+
/// AVX2, NEON on aarch64, scalar otherwise. The `F32x16` wrapper owns the
567+
/// dispatch; there is no AMX intrinsic and no external BLAS backend.
568+
///
569+
/// Accumulates into `c` (`+=`), matching BLAS `C := A·B + C`. Inputs are
570+
/// decoded BF16→f32, so the result matches an f32-accumulated scalar reference
571+
/// up to BF16 input precision (~2⁻⁸ per multiply, O(√K) accumulated).
572+
///
573+
/// # Panics
574+
/// Panics if `k % 32 != 0`, `a_bf16.len() != 16 * k`, `b_bf16.len() != k * 16`,
575+
/// or `c.len() != 256`.
576+
///
577+
/// # Examples
578+
/// ```
579+
/// use ndarray::simd::{bf16_tile_gemm_16x16, f32_to_bf16_scalar};
580+
/// let k = 32;
581+
/// let a = vec![f32_to_bf16_scalar(1.0); 16 * k];
582+
/// let b = vec![f32_to_bf16_scalar(1.0); k * 16];
583+
/// let mut c = vec![0.0f32; 256];
584+
/// bf16_tile_gemm_16x16(&a, &b, &mut c, k);
585+
/// assert!((c[0] - k as f32).abs() < 1e-3); // Σ_{p<k} 1·1 = k
586+
/// ```
587+
pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) {
588+
assert_eq!(k % 32, 0, "K must be a multiple of 32");
589+
assert_eq!(a_bf16.len(), 16 * k, "A must be 16xK BF16 row-major");
590+
assert_eq!(b_bf16.len(), k * 16, "B must be Kx16 BF16 row-major");
591+
assert_eq!(c.len(), 16 * 16, "C must be 16x16 f32 row-major");
592+
593+
// Decode BF16 -> f32 once; the batch decode rides the polyfill too.
594+
let mut a_f32 = vec![0.0f32; a_bf16.len()];
595+
let mut b_f32 = vec![0.0f32; b_bf16.len()];
596+
bf16_to_f32_batch(a_bf16, &mut a_f32);
597+
bf16_to_f32_batch(b_bf16, &mut b_f32);
598+
599+
// C[i,j] += dot(A row i, B col j) via F32x16 + FMA on 16-wide chunks.
600+
// K is a multiple of 32, so chunks_exact(16) tiles A's row / B's column
601+
// with no remainder.
602+
for i in 0..16 {
603+
let a_row = &a_f32[i * k..i * k + k];
604+
for j in 0..16 {
605+
// Gather B column j (row-major B is strided by 16) into a buffer so
606+
// the dot product hits the contiguous chunks_exact(16) fast path.
607+
let mut col = vec![0.0f32; k];
608+
for (kk, slot) in col.iter_mut().enumerate() {
609+
*slot = b_f32[kk * 16 + j];
610+
}
611+
let mut acc = F32x16::splat(0.0);
612+
for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) {
613+
acc = F32x16::from_slice(ra).mul_add(F32x16::from_slice(rb), acc);
614+
}
615+
c[i * 16 + j] += acc.reduce_sum();
616+
}
617+
}
618+
}
619+
620+
#[cfg(test)]
mod bf16_tile_gemm_tests {
    use super::*;
    use crate::simd::{bf16_to_f32_batch, f32_to_bf16_batch, f32_to_bf16_scalar};

    /// Scalar BF16 reference (f32-accumulated, `+=`) — the correctness anchor.
    ///
    /// `a` is 16×K row-major, `b` is K×16 row-major, `c` is 16×16 row-major.
    fn ref_gemm(a: &[f32], b: &[f32], c: &mut [f32], k: usize) {
        for i in 0..16 {
            for j in 0..16 {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += a[i * k + kk] * b[kk * 16 + j];
                }
                c[i * 16 + j] += s;
            }
        }
    }

    /// Fills `out` with deterministic pseudo-random values in `[-1, 1)`.
    ///
    /// Uses a 32-bit LCG and keeps the top 24 bits of the state, which are
    /// exactly representable in f32; dividing by 2^23 maps them to `[0, 2)`.
    fn fill_unit_range(seed: u32, out: &mut [f32]) {
        let mut state = seed;
        for v in out.iter_mut() {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            *v = (state >> 8) as f32 / 8_388_608.0 - 1.0;
        }
    }

    #[test]
    fn matches_scalar_reference_k64() {
        let k = 64;
        let mut a_f32 = vec![0.0f32; 16 * k];
        let mut b_f32 = vec![0.0f32; k * 16];
        // Inputs must span a meaningful range: with |a|,|b| <= 2^-8 every
        // K=64 dot product is <= 2^-10 < 1e-3, and the tolerance check below
        // would pass even for a kernel that writes nothing.
        fill_unit_range(0x1234_5678, &mut a_f32);
        fill_unit_range(0x9E37_79B9, &mut b_f32);

        let mut a_bf16 = vec![0u16; a_f32.len()];
        let mut b_bf16 = vec![0u16; b_f32.len()];
        f32_to_bf16_batch(&a_f32, &mut a_bf16);
        f32_to_bf16_batch(&b_f32, &mut b_bf16);

        // Reference consumes the same BF16-truncated inputs the kernel sees.
        let mut a_back = vec![0.0f32; a_f32.len()];
        let mut b_back = vec![0.0f32; b_f32.len()];
        bf16_to_f32_batch(&a_bf16, &mut a_back);
        bf16_to_f32_batch(&b_bf16, &mut b_back);
        let mut c_ref = vec![0.0f32; 256];
        ref_gemm(&a_back, &b_back, &mut c_ref, k);

        // Guard against a vacuous comparison: the reference output has to be
        // far larger than the tolerance for the parity check to mean anything.
        let ref_max = c_ref.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        assert!(ref_max > 0.1, "reference output too small to be a meaningful check: {ref_max}");

        let mut c = vec![0.0f32; 256];
        bf16_tile_gemm_16x16(&a_bf16, &b_bf16, &mut c, k);

        let max_err = c
            .iter()
            .zip(&c_ref)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1e-3, "polyfill vs scalar ref max_err = {max_err}");
    }

    #[test]
    fn accumulates_into_c() {
        // C is preloaded with 5.0; the kernel must ADD, not overwrite.
        let k = 32;
        let a = vec![f32_to_bf16_scalar(1.0); 16 * k];
        let b = vec![f32_to_bf16_scalar(1.0); k * 16];
        let mut c = vec![5.0f32; 256];
        bf16_tile_gemm_16x16(&a, &b, &mut c, k);
        // 5 (preload) + Σ_{p<32} 1·1 = 37
        for v in &c {
            assert!((v - 37.0).abs() < 1e-3, "got {v}");
        }
    }
}
688+
559689
#[cfg(test)]
560690
mod array_chunks_tests {
561691
use super::*;

0 commit comments

Comments
 (0)