Skip to content

Commit e563fdc

Browse files
committed
fix(pr217): address review (fmt + codex/coderabbit) — v3-clean
CI fmt: ran `cargo fmt --all` (edge_residue_probe / golden_helix_probe were committed unformatted). Correctness (codex / coderabbit): * simd_int_ops::gemm_u8_i8 — VNNI dispatch was compile-time `#[cfg(target_feature)]`, so the default x86-64-v3 GitHub build stripped both VNNI arms → scalar on Ice Lake / SPR / Zen 4 silicon (codex P2 regression). Now RUNTIME `is_x86_feature_detected!` (avx512vnni → avxvnni → scalar); compiles + reaches VNNI under v3, and removes the pre-existing `needless_return` clippy warning. * simd_avx2.rs U16x16 `shr`/`shl` — returned ZERO for any shift ∉{1,2,4,8}; now `_mm256_srl_epi16`/`_mm256_sll_epi16` with a runtime lane count (all shifts). * amx_matmul::for_dpbusd — tile 1/2 shapes now match the operand contract (tmm1 = VNNI kb/4×64, tmm2 = plain 16×kb); identical at kb=64 (tests unaffected), correct for kb<64. * backend::native gemv_f32/f64 — early-return on m==0 (don't slice `x[..n]` when there are no rows; matches the scalar reference no-op). * test_tile_zero_and_release — minimal config rewritten on the corrected XTILECFG offsets (colsb=4 @16 / rows=1 @48), with an explanatory note. Probes / docs: * amx_probe matmul_f32 validator — true relative-L2 + max-abs (the old `|e|.max(1.0)` denominator was an absolute test for |e|<1). * amx_rb_probe rb_32 — assert K % 64 == 0 (was silently truncating the tail). * doc `# Examples` (ignore) on the new public APIs: TileConfig::for_dpbusd_8, tile_dpbusd_2x2, F32x8::mul_add, F32x8::cmp_gt_mask. Validated under x86-64-v3 (GitHub target): clippy clean, `cargo build --examples` Finished; native AMX probes still all CORRECT. https://claude.ai/code/session_01D2WSmezQBNC3bUdHuGfGmo
1 parent d33b983 commit e563fdc

9 files changed

Lines changed: 124 additions & 62 deletions

File tree

examples/amx_probe.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,21 @@ fn test_matmul_f32(m: usize, n: usize, k: usize) {
101101
ArrayViewMut2::from_shape((m, n), &mut got[..]).unwrap(),
102102
)
103103
.unwrap();
104-
let mut max_rel = 0.0f32;
104+
// True relative metric: L2 relative error ‖got−exp‖ / ‖exp‖ (robust to
105+
// small individual outputs — the previous `|e|.max(1.0)` denominator turned
106+
// every |e|<1 cell into an absolute-error test) plus the max absolute error.
107+
let mut sq_err = 0.0f64;
108+
let mut sq_ref = 0.0f64;
109+
let mut max_abs = 0.0f32;
105110
for (g, e) in got.iter().zip(&exp) {
106-
let denom = e.abs().max(1.0);
107-
max_rel = max_rel.max((g - e).abs() / denom);
111+
let d = g - e;
112+
sq_err += (d as f64) * (d as f64);
113+
sq_ref += (*e as f64) * (*e as f64);
114+
max_abs = max_abs.max(d.abs());
108115
}
109-
let verdict = if max_rel < 0.05 { "CORRECT" } else { "WRONG " };
110-
println!(" matmul_f32 {m:>4}x{k:>4}x{n:>4} {verdict} max_rel_err = {max_rel:.4}");
116+
let rel_l2 = (sq_err.sqrt() / sq_ref.sqrt().max(1e-12)) as f32;
117+
let verdict = if rel_l2 < 0.02 { "CORRECT" } else { "WRONG " };
118+
println!(" matmul_f32 {m:>4}x{k:>4}x{n:>4} {verdict} rel-L2 {rel_l2:.4} max-abs {max_abs:.4}");
111119
}
112120

113121
fn main() {

examples/amx_rb_probe.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ fn ref_32(a: &[u8], b: &[i8], k: usize) -> Vec<i32> {
2727

2828
/// 32×32 = A(32×k u8) · B(k×32 i8) via the 2×2 register-blocked AMX kernel.
2929
fn rb_32(a: &[u8], b: &[i8], k: usize) -> Vec<i32> {
30+
assert_eq!(k % 64, 0, "rb_32: K must be a multiple of 64 (TDPBUSD tile depth)");
3031
// Pack the two 16-wide B column bands into VNNI quads.
3132
let mut b0 = vec![0i8; k * 16];
3233
let mut b1 = vec![0i8; k * 16];

examples/edge_residue_probe.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,20 @@ fn splitmix(s: &mut u64) -> f32 {
3737
fn quantize_i8(x: &[f32]) -> (Vec<i8>, f32) {
3838
let amax = x.iter().fold(0.0f32, |a, &v| a.max(v.abs())).max(1e-12);
3939
let scale = 127.0 / amax;
40-
(x.iter().map(|&v| (v * scale).round().clamp(-127.0, 127.0) as i8).collect(), scale)
40+
(
41+
x.iter()
42+
.map(|&v| (v * scale).round().clamp(-127.0, 127.0) as i8)
43+
.collect(),
44+
scale,
45+
)
4146
}
4247

4348
fn l2(a: &[f32], b: &[f32]) -> f32 {
44-
a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>().sqrt()
49+
a.iter()
50+
.zip(b)
51+
.map(|(x, y)| (x - y) * (x - y))
52+
.sum::<f32>()
53+
.sqrt()
4554
}
4655

4756
fn run(n: usize, d: usize, k: usize, noise: f32) {
@@ -78,7 +87,9 @@ fn run(n: usize, d: usize, k: usize, noise: f32) {
7887
let amx_ns = t0.elapsed().as_nanos() as f64;
7988

8089
// ||c_j||² in the i8 domain (same scale as v) for the argmin.
81-
let cnorm: Vec<i32> = (0..k).map(|c| (0..d).map(|j| (cb_i8[c * d + j] as i32).pow(2)).sum()).collect();
90+
let cnorm: Vec<i32> = (0..k)
91+
.map(|c| (0..d).map(|j| (cb_i8[c * d + j] as i32).pow(2)).sum())
92+
.collect();
8293
// idx[i] = argmax_j (2·G[i][j] − ||c_j||²) ≡ argmin_j ||v_i − c_j||².
8394
let mut idx = vec![0u32; n];
8495
for i in 0..n {

examples/golden_helix_probe.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,17 @@ fn main() {
119119

120120
println!("\n[2] Fisher-z percentile rank as a no-cosine normalised key:");
121121
// A deterministic spread of cosine similarities in (−1, 1).
122-
let mut sims: Vec<f64> = (0..1000).map(|i| -0.999 + 1.998 * (i as f64 + 0.5) / 1000.0).collect();
122+
let mut sims: Vec<f64> = (0..1000)
123+
.map(|i| -0.999 + 1.998 * (i as f64 + 0.5) / 1000.0)
124+
.collect();
123125
// Percentile rank of fisher_z(s). Both fisher_z and ranking are monotone in s,
124126
// so the rank order must equal the cosine order — verify (Spearman == 1).
125127
let mut idx: Vec<usize> = (0..sims.len()).collect();
126128
idx.sort_by(|&a, &b| fisher_z(sims[a]).partial_cmp(&fisher_z(sims[b])).unwrap());
127129
let inversions = idx.windows(2).filter(|w| sims[w[0]] > sims[w[1]]).count();
128-
println!(" rank-order vs cosine-order inversions: {inversions} (0 ⇒ ordering fully preserved, no cosine needed)");
130+
println!(
131+
" rank-order vs cosine-order inversions: {inversions} (0 ⇒ ordering fully preserved, no cosine needed)"
132+
);
129133

130134
// Rim-stretch: resolution (Δz per unit Δs) near the rim vs the centre.
131135
sims.sort_by(|a, b| a.partial_cmp(b).unwrap());

src/backend/native.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,9 @@ pub fn gemm_f64(
287287
/// SIMD tiers compute each row via [`dot_f32`]; the scalar tier uses
288288
/// the byte-stable [`scalar::gemv_f32`] reference.
289289
pub fn gemv_f32(m: usize, n: usize, alpha: f32, a: &[f32], lda: usize, x: &[f32], beta: f32, y: &mut [f32]) {
290+
if m == 0 {
291+
return; // no rows ⇒ no-op; must not slice `x[..n]` (scalar ref returns too)
292+
}
290293
match tier() {
291294
Tier::Scalar => scalar::gemv_f32(m, n, alpha, a, lda, x, beta, y),
292295
// Avx512 + Avx2: per-row SIMD dot product. `dot_f32` itself
@@ -307,6 +310,9 @@ pub fn gemv_f32(m: usize, n: usize, alpha: f32, a: &[f32], lda: usize, x: &[f32]
307310
/// SIMD tiers compute each row via [`dot_f64`]; the scalar tier uses
308311
/// the byte-stable [`scalar::gemv_f64`] reference.
309312
pub fn gemv_f64(m: usize, n: usize, alpha: f64, a: &[f64], lda: usize, x: &[f64], beta: f64, y: &mut [f64]) {
313+
if m == 0 {
314+
return; // no rows ⇒ no-op; must not slice `x[..n]` (scalar ref returns too)
315+
}
310316
match tier() {
311317
Tier::Scalar => scalar::gemv_f64(m, n, alpha, a, lda, x, beta, y),
312318
_ => {

src/hpc/amx_matmul.rs

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,16 @@ impl TileConfig {
6363
cfg.data[16] = 64; // colsb[0] low (u16 @ 16); high byte @17 stays 0
6464
cfg.data[48] = 16; // rows[0] (u8 @ 48)
6565

66-
// Tile 1 (A): 16 rows × kb colbytes (u8 left operand).
67-
cfg.data[18] = kb as u8; // colsb[1] low (u16 @ 18); high byte @19 stays 0
68-
cfg.data[49] = 16; // rows[1] (u8 @ 49)
66+
// Tile 1 (B, VNNI K×N → VEX.vvvv): kb/4 rows × 64 colbytes. The kernel
67+
// loads the VNNI operand into tmm1, so tile 1 must carry the VNNI shape.
68+
// (Was the plain 16×kb shape — equal to this only at kb=64; backwards
69+
// for kb<64, which would mis-shape a tail kernel / external caller.)
70+
cfg.data[18] = 64; // colsb[1] low (u16 @ 18); high byte @19 stays 0
71+
cfg.data[49] = (kb / 4) as u8; // rows[1] (u8 @ 49)
6972

70-
// Tile 2 (B): kb/4 rows × 64 colbytes (VNNI-packed right operand).
71-
cfg.data[20] = 64; // colsb[2] low (u16 @ 20); high byte @21 stays 0
72-
cfg.data[50] = (kb / 4) as u8; // rows[2] (u8 @ 50)
73+
// Tile 2 (A, plain M×K → ModRM.rm): 16 rows × kb colbytes.
74+
cfg.data[20] = kb as u8; // colsb[2] low (u16 @ 20); high byte @21 stays 0
75+
cfg.data[50] = 16; // rows[2] (u8 @ 50)
7376

7477
cfg
7578
}
@@ -80,6 +83,17 @@ impl TileConfig {
8083
/// (vvvv/signed). Every tile is 16×64 so one config serves all roles. Same
8184
/// XTILECFG layout as [`Self::for_dpbusd`]: colsb[t] u16 @ 16+2t, rows[t]
8285
/// u8 @ 48+t.
86+
///
87+
/// # Examples
88+
/// ```ignore
89+
/// use ndarray::hpc::amx_matmul::{tile_loadconfig, tile_release, TileConfig};
90+
/// // SAFETY: requires AMX (gate on `amx_available()`); all 8 tiles are 16×64.
91+
/// unsafe {
92+
/// tile_loadconfig(&TileConfig::for_dpbusd_8());
93+
/// // load A→tmm4/tmm5, B-VNNI→tmm6/tmm7, zero tmm0-3, then tile_dpbusd_2x2()
94+
/// tile_release();
95+
/// }
96+
/// ```
8397
pub fn for_dpbusd_8() -> Self {
8498
let mut cfg = TileConfig { data: [0u8; 64] };
8599
cfg.data[0] = 1; // palette 1
@@ -291,6 +305,21 @@ pub unsafe fn tile_dpbusd() {
291305
/// C10 dst2 rm5 vvvv6 → C4 E2 49 5E D5 C11 dst3 rm5 vvvv7 → C4 E2 41 5E DD
292306
/// All eight operand tiles (0/1/2/3 dst, 4/5 A, 6/7 B) are distinct → no #UD.
293307
///
308+
/// # Examples
309+
/// ```ignore
310+
/// use ndarray::hpc::amx_matmul::*;
311+
/// // SAFETY: requires AMX; full 32×32 register-blocked tile contract.
312+
/// unsafe {
313+
/// tile_loadconfig(&TileConfig::for_dpbusd_8());
314+
/// tile_zero(0); tile_zero(1); tile_zero(2); tile_zero(3); // C accumulators
315+
/// tile_load(4, a0_ptr, k); tile_load(5, a1_ptr, k); // A rows (rm)
316+
/// tile_load(6, b0_vnni, 64); tile_load(7, b1_vnni, 64); // B cols (vvvv)
317+
/// tile_dpbusd_2x2(); // 4 TDPBUSDs
318+
/// tile_store(0, c00, n * 4); /* … tmm1/2/3 → other quadrants … */
319+
/// tile_release();
320+
/// }
321+
/// ```
322+
///
294323
/// # Safety
295324
/// Tiles 0-7 configured (`TileConfig::for_dpbusd_8`) and 4/5/6/7 loaded.
296325
#[inline]
@@ -842,11 +871,14 @@ mod tests {
842871
return;
843872
}
844873
unsafe {
845-
// Minimal config: just tile 0, 1 row × 4 bytes
874+
// Minimal valid tile 0: 1 row × 4 colbytes, using the CORRECTED
875+
// XTILECFG offsets (colsb[t] u16 @ 16+2t, rows[t] u8 @ 48+t). The
876+
// old code wrote data[16]=1/data[48]=4 which under the fixed layout
877+
// means colsb=1/rows=4 — still valid, but mislabeled; now explicit.
846878
let mut cfg = TileConfig { data: [0u8; 64] };
847879
cfg.data[0] = 1; // palette 1
848-
cfg.data[16] = 1; // tile 0: 1 row
849-
cfg.data[48] = 4; // tile 0: 4 colbytes
880+
cfg.data[16] = 4; // colsb[0] = 4 bytes (u16 @ 16)
881+
cfg.data[48] = 1; // rows[0] = 1 row (u8 @ 48)
850882

851883
tile_loadconfig(&cfg);
852884
// TILEZERO tmm0

src/simd_avx2.rs

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,29 +1603,18 @@ impl U16x16 {
16031603
/// Logical right shift each 16-bit lane by `imm` (matches `U16x32::shr`).
16041604
#[inline(always)]
16051605
pub fn shr(self, imm: u32) -> Self {
1606-
Self(unsafe {
1607-
match imm {
1608-
1 => _mm256_srli_epi16(self.0, 1),
1609-
2 => _mm256_srli_epi16(self.0, 2),
1610-
4 => _mm256_srli_epi16(self.0, 4),
1611-
8 => _mm256_srli_epi16(self.0, 8),
1612-
_ => _mm256_setzero_si256(),
1613-
}
1614-
})
1606+
// SAFETY: AVX2 baseline; `_mm256_srl_epi16` takes a runtime lane count
1607+
// from the low 64 bits of an xmm, so every shift amount works (the
1608+
// earlier `match {1,2,4,8}` returned zero for all other amounts).
1609+
Self(unsafe { _mm256_srl_epi16(self.0, _mm_cvtsi32_si128(imm as i32)) })
16151610
}
16161611

16171612
/// Logical left shift each 16-bit lane by `imm` (matches `U16x32::shl`).
16181613
#[inline(always)]
16191614
pub fn shl(self, imm: u32) -> Self {
1620-
Self(unsafe {
1621-
match imm {
1622-
1 => _mm256_slli_epi16(self.0, 1),
1623-
2 => _mm256_slli_epi16(self.0, 2),
1624-
4 => _mm256_slli_epi16(self.0, 4),
1625-
8 => _mm256_slli_epi16(self.0, 8),
1626-
_ => _mm256_setzero_si256(),
1627-
}
1628-
})
1615+
// SAFETY: AVX2 baseline; `_mm256_sll_epi16` takes a runtime lane count
1616+
// (same fix as `shr` — the `match {1,2,4,8}` zeroed all other amounts).
1617+
Self(unsafe { _mm256_sll_epi16(self.0, _mm_cvtsi32_si128(imm as i32)) })
16291618
}
16301619

16311620
/// Multiply, keep low 16 bits (wrapping) — `_mm256_mullo_epi16`.

src/simd_avx512.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,14 @@ impl PartialEq for U16x32 {
13511351
// reduction needs an 8-wide FMA.
13521352
impl F32x8 {
13531353
/// Fused multiply-add: `self * a + b`, single rounding (`_mm256_fmadd_ps`).
1354+
///
1355+
/// # Examples
1356+
/// ```ignore
1357+
/// let a = F32x8::splat(0.5);
1358+
/// let b = F32x8::splat(2.0);
1359+
/// let c = F32x8::splat(1.0);
1360+
/// assert_eq!(a.mul_add(b, c).to_array(), [2.0; 8]); // 0.5*2.0 + 1.0
1361+
/// ```
13541362
#[inline(always)]
13551363
pub fn mul_add(self, a: Self, b: Self) -> Self {
13561364
// SAFETY: FMA3 intrinsic; reached only on FMA-capable targets via the
@@ -1363,6 +1371,14 @@ impl F32x8 {
13631371
/// + `_mm256_movemask_ps`. The FastScan heap threshold-prune uses it to skip
13641372
/// an 8-lane score chunk that holds no candidate above the current heap-min
13651373
/// in a single instruction — the SIMD early-out the scalar `>hmin` scan loses.
1374+
///
1375+
/// # Examples
1376+
/// ```ignore
1377+
/// let a = F32x8::from_array([3.0, 0.0, 5.0, 0.0, 3.0, 0.0, 5.0, 0.0]);
1378+
/// let b = F32x8::splat(1.0);
1379+
/// // lanes 0,2,4,6 are > 1.0 ⇒ bits 0,2,4,6 set = 0b0101_0101 = 0x55.
1380+
/// assert_eq!(a.cmp_gt_mask(b), 0x55);
1381+
/// ```
13661382
#[inline(always)]
13671383
pub fn cmp_gt_mask(self, other: Self) -> u32 {
13681384
// SAFETY: AVX `vcmpps` + `vmovmskps`; available wherever this 256-bit

src/simd_int_ops.rs

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -277,32 +277,27 @@ pub fn gemm_u8_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usiz
277277
}
278278
}
279279

280-
// Compile-time dispatch chain (tiers 1-3). Exactly one arm survives
281-
// per build; the others are stripped by `#[cfg]` so the compiler
282-
// emits a direct call to the chosen kernel with no runtime branch.
283-
284-
#[cfg(all(target_arch = "x86_64", target_feature = "avx512vnni"))]
285-
{
286-
// SAFETY: `target_feature = "avx512vnni"` at this site guarantees
287-
// AVX-512F + VNNI + BW (the kernel's `#[target_feature(enable)]`
288-
// set). The dispatcher is the safety invariant the kernel relies on.
289-
unsafe { crate::hpc::vnni_gemm::int8_gemm_vnni_avx512(a, b, c, m, n, k) };
290-
return;
291-
}
292-
293-
#[cfg(all(
294-
target_arch = "x86_64",
295-
target_feature = "avxvnni",
296-
not(target_feature = "avx512vnni"),
297-
))]
280+
// RUNTIME VNNI dispatch (tiers 1-2, after the AMX check above). This MUST
281+
// be runtime `is_x86_feature_detected!`, NOT compile-time
282+
// `#[cfg(target_feature)]`: the default x86-64-v3 build has neither
283+
// avx512vnni nor avxvnni as a *compile* feature, so a cfg chain would strip
284+
// both arms and fall through to scalar even on Ice Lake / Sapphire Rapids /
285+
// Zen 4 silicon that supports VNNI at runtime (the regression codex flagged
286+
// on PR #217). Runtime detection keeps the VNNI kernels reachable on the
287+
// baseline build, matching the pre-consolidation `simd_caps()` behaviour.
288+
#[cfg(target_arch = "x86_64")]
298289
{
299-
// SAFETY: `target_feature = "avxvnni"` at this site guarantees
300-
// AVX + AVX2 + AVX-VNNI (the kernel's `#[target_feature(enable)]`
301-
// set). Arm only fires when AVX-512 VNNI is *not* present —
302-
// Alder Lake / Arrow Lake without AVX-512, or Zen 4 builds that
303-
// pinned a ymm-only target. The dispatcher is the safety invariant.
304-
unsafe { crate::hpc::vnni_gemm::int8_gemm_avxvnni_ymm(a, b, c, m, n, k) };
305-
return;
290+
if std::is_x86_feature_detected!("avx512vnni") {
291+
// SAFETY: avx512vnni detected ⇒ AVX-512F + VNNI + BW present, the
292+
// kernel's `#[target_feature(enable)]` set.
293+
unsafe { crate::hpc::vnni_gemm::int8_gemm_vnni_avx512(a, b, c, m, n, k) };
294+
return;
295+
}
296+
if std::is_x86_feature_detected!("avxvnni") {
297+
// SAFETY: avxvnni detected ⇒ AVX + AVX2 + AVX-VNNI present.
298+
unsafe { crate::hpc::vnni_gemm::int8_gemm_avxvnni_ymm(a, b, c, m, n, k) };
299+
return;
300+
}
306301
}
307302

308303
// Fallback: scalar reference kernel. Always correct; same result the

0 commit comments

Comments
 (0)