|
20 | 20 | //! delete them without verifying every BLAS-graph kernel still compiles |
21 | 21 | //! AND that the JIT alternative has been re-evaluated.** |
22 | 22 |
|
23 | | -use crate::simd::{F32x16, F64x8}; |
| 23 | +use crate::simd::{bf16_to_f32_batch, F32x16, F64x8}; |
24 | 24 |
|
25 | 25 | // ═══════════════════════════════════════════════════════════════════ |
26 | 26 | // f32 binary ops (out-of-place) |
@@ -556,6 +556,136 @@ pub fn array_windows_checked<T, const N: usize>(data: &[T]) -> Result<impl Itera |
556 | 556 | Ok(array_windows::<T, N>(data)) |
557 | 557 | } |
558 | 558 |
|
| 559 | +/// BF16 16×16 tile GEMM: `C[16,16] += A[16,K] · B[K,16]` where `A`, `B` are |
| 560 | +/// BF16 row-major (`u16` bit patterns) and `C` is f32 row-major. `K` must be a |
| 561 | +/// multiple of 32. |
| 562 | +/// |
| 563 | +/// Pure SIMD-polyfill kernel: decodes BF16→f32 and accumulates with |
| 564 | +/// [`F32x16::mul_add`] (FMA), so it rides the polyfill's per-arch escalation — |
| 565 | +/// AVX-512 `VFMADD231PS` where available, the emulated `(F32x8, F32x8)` pair on |
| 566 | +/// AVX2, NEON on aarch64, scalar otherwise. The `F32x16` wrapper owns the |
| 567 | +/// dispatch; there is no AMX intrinsic and no external BLAS backend. |
| 568 | +/// |
| 569 | +/// Accumulates into `c` (`+=`), matching BLAS `C := A·B + C`. Inputs are |
| 570 | +/// decoded BF16→f32, so the result matches an f32-accumulated scalar reference |
| 571 | +/// up to BF16 input precision (~2⁻⁸ per multiply, O(√K) accumulated). |
| 572 | +/// |
| 573 | +/// # Panics |
| 574 | +/// Panics if `k % 32 != 0`, `a_bf16.len() != 16 * k`, `b_bf16.len() != k * 16`, |
| 575 | +/// or `c.len() != 256`. |
| 576 | +/// |
| 577 | +/// # Examples |
| 578 | +/// ``` |
| 579 | +/// use ndarray::simd::{bf16_tile_gemm_16x16, f32_to_bf16_scalar}; |
| 580 | +/// let k = 32; |
| 581 | +/// let a = vec![f32_to_bf16_scalar(1.0); 16 * k]; |
| 582 | +/// let b = vec![f32_to_bf16_scalar(1.0); k * 16]; |
| 583 | +/// let mut c = vec![0.0f32; 256]; |
| 584 | +/// bf16_tile_gemm_16x16(&a, &b, &mut c, k); |
| 585 | +/// assert!((c[0] - k as f32).abs() < 1e-3); // Σ_{p<k} 1·1 = k |
| 586 | +/// ``` |
| 587 | +pub fn bf16_tile_gemm_16x16(a_bf16: &[u16], b_bf16: &[u16], c: &mut [f32], k: usize) { |
| 588 | + assert_eq!(k % 32, 0, "K must be a multiple of 32"); |
| 589 | + assert_eq!(a_bf16.len(), 16 * k, "A must be 16xK BF16 row-major"); |
| 590 | + assert_eq!(b_bf16.len(), k * 16, "B must be Kx16 BF16 row-major"); |
| 591 | + assert_eq!(c.len(), 16 * 16, "C must be 16x16 f32 row-major"); |
| 592 | + |
| 593 | + // Decode BF16 -> f32 once; the batch decode rides the polyfill too. |
| 594 | + let mut a_f32 = vec![0.0f32; a_bf16.len()]; |
| 595 | + let mut b_f32 = vec![0.0f32; b_bf16.len()]; |
| 596 | + bf16_to_f32_batch(a_bf16, &mut a_f32); |
| 597 | + bf16_to_f32_batch(b_bf16, &mut b_f32); |
| 598 | + |
| 599 | + // C[i,j] += dot(A row i, B col j) via F32x16 + FMA on 16-wide chunks. |
| 600 | + // K is a multiple of 32, so chunks_exact(16) tiles A's row / B's column |
| 601 | + // with no remainder. |
| 602 | + for i in 0..16 { |
| 603 | + let a_row = &a_f32[i * k..i * k + k]; |
| 604 | + for j in 0..16 { |
| 605 | + // Gather B column j (row-major B is strided by 16) into a buffer so |
| 606 | + // the dot product hits the contiguous chunks_exact(16) fast path. |
| 607 | + let mut col = vec![0.0f32; k]; |
| 608 | + for (kk, slot) in col.iter_mut().enumerate() { |
| 609 | + *slot = b_f32[kk * 16 + j]; |
| 610 | + } |
| 611 | + let mut acc = F32x16::splat(0.0); |
| 612 | + for (ra, rb) in a_row.chunks_exact(16).zip(col.chunks_exact(16)) { |
| 613 | + acc = F32x16::from_slice(ra).mul_add(F32x16::from_slice(rb), acc); |
| 614 | + } |
| 615 | + c[i * 16 + j] += acc.reduce_sum(); |
| 616 | + } |
| 617 | + } |
| 618 | +} |
| 619 | + |
#[cfg(test)]
mod bf16_tile_gemm_tests {
    use super::*;
    use crate::simd::{bf16_to_f32_batch, f32_to_bf16_batch, f32_to_bf16_scalar};

    /// Scalar BF16 reference (f32-accumulated, `+=`) — the correctness anchor.
    ///
    /// `a` is 16×K row-major, `b` is K×16 row-major, `c` is 16×16 row-major.
    fn ref_gemm(a: &[f32], b: &[f32], c: &mut [f32], k: usize) {
        for i in 0..16 {
            for j in 0..16 {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += a[i * k + kk] * b[kk * 16 + j];
                }
                c[i * 16 + j] += s;
            }
        }
    }

    /// Deterministic LCG step returning a value in `[-1, 1)`.
    ///
    /// The top 24 bits of the state are mapped to `[0, 2)` (exact in f32) and
    /// shifted down by one, so the inputs actually span the unit range instead
    /// of collapsing near zero.
    fn next_unit(state: &mut u32) -> f32 {
        *state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
        (*state >> 8) as f32 / 8_388_608.0 - 1.0
    }

    #[test]
    fn matches_scalar_reference_k64() {
        let k = 64;
        let mut seed_a = 0x1234_5678u32;
        let mut seed_b = 0x9E37_79B9u32;
        let a_f32: Vec<f32> = (0..16 * k).map(|_| next_unit(&mut seed_a)).collect();
        let b_f32: Vec<f32> = (0..k * 16).map(|_| next_unit(&mut seed_b)).collect();

        let mut a_bf16 = vec![0u16; a_f32.len()];
        let mut b_bf16 = vec![0u16; b_f32.len()];
        f32_to_bf16_batch(&a_f32, &mut a_bf16);
        f32_to_bf16_batch(&b_f32, &mut b_bf16);

        // Reference consumes the same BF16-truncated inputs the kernel sees.
        let mut a_back = vec![0.0f32; a_f32.len()];
        let mut b_back = vec![0.0f32; b_f32.len()];
        bf16_to_f32_batch(&a_bf16, &mut a_back);
        bf16_to_f32_batch(&b_bf16, &mut b_back);
        let mut c_ref = vec![0.0f32; 256];
        ref_gemm(&a_back, &b_back, &mut c_ref, k);

        // Guard against a vacuous comparison: if every reference entry were
        // below the tolerance, a kernel that writes nothing would also pass.
        let ref_peak = c_ref.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        assert!(
            ref_peak > 0.1,
            "reference output too small to be meaningful: peak = {ref_peak}"
        );

        let mut c = vec![0.0f32; 256];
        bf16_tile_gemm_16x16(&a_bf16, &b_bf16, &mut c, k);

        let max_err = c
            .iter()
            .zip(&c_ref)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 1e-3, "polyfill vs scalar ref max_err = {max_err}");
    }

    #[test]
    fn accumulates_into_c() {
        // C is preloaded with 5.0; the kernel must ADD, not overwrite.
        let k = 32;
        let a = vec![f32_to_bf16_scalar(1.0); 16 * k];
        let b = vec![f32_to_bf16_scalar(1.0); k * 16];
        let mut c = vec![5.0f32; 256];
        bf16_tile_gemm_16x16(&a, &b, &mut c, k);
        // 5 (preload) + Σ_{p<32} 1·1 = 37
        for v in &c {
            assert!((v - 37.0).abs() < 1e-3, "got {v}");
        }
    }
}
| 688 | + |
559 | 689 | #[cfg(test)] |
560 | 690 | mod array_chunks_tests { |
561 | 691 | use super::*; |
|
0 commit comments