Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,9 @@ pub use crate::simd_amx::{amx_report, cpu_model, CpuModel};
// Elementwise slice ops — polyfill-dispatched (F32x16/F64x8 chunks + scalar tail).
#[cfg(feature = "std")]
pub use crate::simd_ops::{
add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_mul_f32, add_mul_f64, add_scalar_f32, div_f32,
div_f32_inplace, mul_f32, mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32, sub_f32_inplace,
add_f32, add_f32_inplace, add_f64, add_f64_inplace, add_mul_f32, add_mul_f64, add_scalar_f32, bf16_tile_gemm_16x16,
div_f32, div_f32_inplace, mul_f32, mul_f32_inplace, mul_f64, scale_f32, scale_f32_inplace, sub_f32,
sub_f32_inplace,
};

// ============================================================================
Expand Down
132 changes: 131 additions & 1 deletion src/simd_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
//! delete them without verifying every BLAS-graph kernel still compiles
//! AND that the JIT alternative has been re-evaluated.**

use crate::simd::{F32x16, F64x8};
use crate::simd::{bf16_to_f32_batch, F32x16, F64x8};

// ═══════════════════════════════════════════════════════════════════
// f32 binary ops (out-of-place)
Expand Down Expand Up @@ -556,6 +556,136 @@ pub fn array_windows_checked<T, const N: usize>(data: &[T]) -> Result<impl Itera
Ok(array_windows::<T, N>(data))
}

/// BF16 16×16 tile GEMM: `C[16,16] += A[16,K] · B[K,16]` where `A`, `B` are
/// BF16 row-major (`u16` bit patterns) and `C` is f32 row-major. `K` must be a
/// multiple of 32.
///
/// Pure SIMD-polyfill kernel: decodes BF16→f32 and accumulates with
/// [`F32x16::mul_add`] (FMA), so it rides the polyfill's per-arch escalation —
/// AVX-512 `VFMADD231PS` where available, the emulated `(F32x8, F32x8)` pair on
/// AVX2, NEON on aarch64, scalar otherwise. The `F32x16` wrapper owns the
/// dispatch; there is no AMX intrinsic and no external BLAS backend.
///
/// Accumulates into `c` (`+=`), matching BLAS `C := A·B + C`. Inputs are
/// decoded BF16→f32, so the result matches an f32-accumulated scalar reference
/// up to BF16 input precision (~2⁻⁸ per multiply, O(√K) accumulated).
///
/// # Panics
/// Panics if `k % 32 != 0`, `a_bf16.len() != 16 * k`, `b_bf16.len() != k * 16`,
/// or `c.len() != 256`.
///
/// # Examples
/// ```
/// use ndarray::simd::{bf16_tile_gemm_16x16, f32_to_bf16_scalar};
/// let k = 32;
/// let a = vec![f32_to_bf16_scalar(1.0); 16 * k];
/// let b = vec![f32_to_bf16_scalar(1.0); k * 16];
/// let mut c = vec![0.0f32; 256];
/// bf16_tile_gemm_16x16(&a, &b, &mut c, k);
/// assert!((c[0] - k as f32).abs() < 1e-3); // Σ_{p<k} 1·1 = k
/// ```
pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) {
assert_eq!(k % 32, 0, "K must be a multiple of 32");
assert_eq!(a_bf16.len(), 16 * k, "A must be 16xK BF16 row-major");
assert_eq!(b_bf16.len(), k * 16, "B must be Kx16 BF16 row-major");
assert_eq!(c.len(), 16 * 16, "C must be 16x16 f32 row-major");

// Decode BF16 -> f32 once; the batch decode rides the polyfill too.
let mut a_f32 = vec![0.0f32; a_bf16.len()];
let mut b_f32 = vec![0.0f32; b_bf16.len()];
bf16_to_f32_batch(a_bf16, &mut a_f32);
bf16_to_f32_batch(b_bf16, &mut b_f32);

// C[i,j] += dot(A row i, B col j) via F32x16 + FMA on 16-wide chunks.
// K is a multiple of 32, so chunks_exact(16) tiles A's row / B's column
// with no remainder.
for i in 0..16 {
let a_row = &a_f32[i * k..i * k + k];
for j in 0..16 {
// Gather B column j (row-major B is strided by 16) into a buffer so
// the dot product hits the contiguous chunks_exact(16) fast path.
let mut col = vec![0.0f32; k];
for (kk, slot) in col.iter_mut().enumerate() {
*slot = b_f32[kk * 16 + j];
}
Comment on lines +607 to +610

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reuse B column buffers outside the row loop

For each output element (i, j), this allocates and fills a k-element Vec, so one 16×16 tile does 256 heap allocations and repeats the same B-column gather once for every row. In hot tiled GEMM use, especially for small or moderate k, that allocator and memory traffic can dominate the SIMD FMA work; pretranspose/gather the 16 B columns once per call or at least once per j and reuse them across all 16 rows.

Useful? React with 👍 / 👎.

let mut acc = F32x16::splat(0.0);
for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) {
acc = F32x16::from_slice(ra).mul_add(F32x16::from_slice(rb), acc);
}
c[i * 16 + j] += acc.reduce_sum();
}
}
}

#[cfg(test)]
mod bf16_tile_gemm_tests {
    use super::*;
    use crate::simd::{bf16_to_f32_batch, f32_to_bf16_batch, f32_to_bf16_scalar};

    /// Scalar BF16 reference (f32-accumulated, `+=`) — the correctness anchor.
    fn ref_gemm(a: &[f32], b: &[f32], c: &mut [f32], k: usize) {
        for i in 0..16 {
            for j in 0..16 {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += a[i * k + kk] * b[kk * 16 + j];
                }
                c[i * 16 + j] += s;
            }
        }
    }

    /// Deterministic LCG-style value in `[-1, 1]` derived from an index.
    ///
    /// The full 32-bit result is divided by 2³¹. An earlier version shifted the
    /// value right by 8 first, which confined every input to |v| < 2⁻⁸ and kept
    /// every K=64 dot product below 2⁻¹⁰ — smaller than the test tolerance.
    fn lcg_unit(i: usize, mul: i32, add: i32) -> f32 {
        let bits = (i as i32).wrapping_mul(mul).wrapping_add(add);
        (bits as f32 / 2147483648.0).clamp(-1.0, 1.0)
    }

    #[test]
    fn matches_scalar_reference_k64() {
        let k = 64;
        let a_f32: Vec<f32> = (0..16 * k).map(|i| lcg_unit(i, 1103515245, 12345)).collect();
        let b_f32: Vec<f32> = (0..k * 16).map(|i| lcg_unit(i, 69069, 1)).collect();

        let mut a_bf16 = vec![0u16; a_f32.len()];
        let mut b_bf16 = vec![0u16; b_f32.len()];
        f32_to_bf16_batch(&a_f32, &mut a_bf16);
        f32_to_bf16_batch(&b_f32, &mut b_bf16);

        // Reference consumes the same BF16-truncated inputs the kernel sees.
        let mut a_back = vec![0.0f32; a_f32.len()];
        let mut b_back = vec![0.0f32; b_f32.len()];
        bf16_to_f32_batch(&a_bf16, &mut a_back);
        bf16_to_f32_batch(&b_bf16, &mut b_back);
        let mut c_ref = vec![0.0f32; 256];
        ref_gemm(&a_back, &b_back, &mut c_ref, k);

        let mut c = vec![0.0f32; 256];
        bf16_tile_gemm_16x16(&a_bf16, &b_bf16, &mut c, k);

        // Scale the tolerance by the reference magnitude so that a kernel which
        // leaves C untouched (all zeros) cannot pass.
        let max_ref = c_ref.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        assert!(max_ref > 0.0, "reference output is identically zero; test would be vacuous");

        let max_err = c
            .iter()
            .zip(&c_ref)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_err <= 1e-3 * max_ref,
            "polyfill vs scalar ref max_err = {max_err} (max |ref| = {max_ref})"
        );
    }

    #[test]
    fn accumulates_into_c() {
        // C is preloaded with 5.0; the kernel must ADD, not overwrite.
        let k = 32;
        let a = vec![f32_to_bf16_scalar(1.0); 16 * k];
        let b = vec![f32_to_bf16_scalar(1.0); k * 16];
        let mut c = vec![5.0f32; 256];
        bf16_tile_gemm_16x16(&a, &b, &mut c, k);
        // 5 (preload) + Σ_{p<32} 1·1 = 37
        for v in &c {
            assert!((v - 37.0).abs() < 1e-3, "got {v}");
        }
    }
}

#[cfg(test)]
mod array_chunks_tests {
use super::*;
Expand Down
Loading