diff --git a/crates/primitives/benches/m31_sbox.rs b/crates/primitives/benches/m31_sbox.rs index fb10512..7a6d214 100644 --- a/crates/primitives/benches/m31_sbox.rs +++ b/crates/primitives/benches/m31_sbox.rs @@ -4,7 +4,7 @@ //! difference between M31 (which requires x^5 or higher) and fields like //! Koalabear/BabyBear that could use cheaper sboxes like x^3. -use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use zp1_primitives::M31; /// Compute x^3 (simulating cheaper sbox, e.g. Koalabear if valid) @@ -37,21 +37,13 @@ fn bench_field_ops(c: &mut Criterion) { let a = M31::new(0x12345678); let b = M31::new(0x7654321); - group.bench_function("add", |bench| { - bench.iter(|| black_box(a) + black_box(b)) - }); + group.bench_function("add", |bench| bench.iter(|| black_box(a) + black_box(b))); - group.bench_function("mul", |bench| { - bench.iter(|| black_box(a) * black_box(b)) - }); + group.bench_function("mul", |bench| bench.iter(|| black_box(a) * black_box(b))); - group.bench_function("square", |bench| { - bench.iter(|| black_box(a).square()) - }); + group.bench_function("square", |bench| bench.iter(|| black_box(a).square())); - group.bench_function("inverse", |bench| { - bench.iter(|| black_box(a).inv()) - }); + group.bench_function("inverse", |bench| bench.iter(|| black_box(a).inv())); group.finish(); } @@ -89,7 +81,9 @@ fn bench_batch_sbox(c: &mut Criterion) { ] { group.bench_with_input(BenchmarkId::new(name, "1000"), &batch, |bench, data| { bench.iter(|| { - data.iter().map(|&x| sbox_fn(black_box(x))).collect::>() + data.iter() + .map(|&x| sbox_fn(black_box(x))) + .collect::>() }) }); } diff --git a/crates/primitives/src/circle.rs b/crates/primitives/src/circle.rs index 3286a46..3380ab4 100644 --- a/crates/primitives/src/circle.rs +++ b/crates/primitives/src/circle.rs @@ -49,10 +49,10 @@ pub fn sqrt_m31(a: M31) -> Option { if a.is_zero() { return Some(M31::ZERO); } - + // For p ≡ 3 (mod 4): sqrt(a) = a^((p+1)/4) = a^(2^29) let r = a.pow_u64(1u64 << 29); - + // Verify: r² = a if r * r == a { Some(r) @@ -78,13 +78,22 @@ pub struct CirclePoint { impl CirclePoint { /// The identity element (1, 0) - corresponds to angle 0. - pub const IDENTITY: Self = Self { x: M31::ONE, y: M31::ZERO }; + pub const IDENTITY: Self = Self { + x: M31::ONE, + y: M31::ZERO, + }; /// The point (0, 1) - corresponds to angle π/2, has order 4. - pub const I: Self = Self { x: M31::ZERO, y: M31::ONE }; + pub const I: Self = Self { + x: M31::ZERO, + y: M31::ONE, + }; /// The point (-1, 0) - corresponds to angle π, has order 2. - pub const NEG_ONE: Self = Self { x: M31(M31::P - 1), y: M31::ZERO }; + pub const NEG_ONE: Self = Self { + x: M31(M31::P - 1), + y: M31::ZERO, + }; /// Create a new circle point (does not verify it's on the circle). #[inline] @@ -135,7 +144,10 @@ impl CirclePoint { /// (x, y)⁻¹ = (x, -y) #[inline] pub fn inv(self) -> Self { - Self { x: self.x, y: -self.y } + Self { + x: self.x, + y: -self.y, + } } /// Conjugate - same as inverse for unit circle. @@ -148,14 +160,17 @@ impl CirclePoint { /// This is the "other" point at the same angle + π. #[inline] pub fn antipodal(self) -> Self { - Self { x: -self.x, y: -self.y } + Self { + x: -self.x, + y: -self.y, + } } /// Compute self^n using repeated squaring. pub fn pow(self, mut n: u64) -> Self { let mut result = Self::IDENTITY; let mut base = self; - + while n > 0 { if n & 1 == 1 { result = result.mul(base); @@ -163,7 +178,7 @@ impl CirclePoint { base = base.double(); n >>= 1; } - + result } @@ -189,17 +204,17 @@ impl CirclePoint { /// This returns a generator for the unique subgroup of order 2^log_order. pub fn generator(log_order: usize) -> Self { assert!(log_order <= 31, "Maximum subgroup order is 2^31"); - + // Start with the generator of order 2^31 let g = Self::generator_order_2_31(); - + // Square (31 - log_order) times to get generator of order 2^log_order // g^(2^(31-k)) has order 2^k let mut result = g; for _ in log_order..31 { result = result.double(); } - + result } @@ -224,9 +239,9 @@ impl CirclePoint { let x = M31::new(2); let y = M31::new(1268011823); - + debug_assert!(x * x + y * y == M31::ONE, "Generator not on circle"); - + Self { x, y } } @@ -234,7 +249,7 @@ impl CirclePoint { #[allow(dead_code)] fn generator_order_2_31_computed() -> Self { let x = M31::new(2); - let y_squared = M31::ONE - x * x; // 1 - 4 = -3 = p - 3 + let y_squared = M31::ONE - x * x; // 1 - 4 = -3 = p - 3 let y = sqrt_m31(y_squared).expect("y² should be a QR"); Self { x, y } } @@ -270,10 +285,10 @@ impl CircleDomain { /// Create a circle domain of size 2^log_size. pub fn new(log_size: usize) -> Self { assert!(log_size <= 31, "Domain size exceeds circle group order"); - + let size = 1usize << log_size; let generator = CirclePoint::generator(log_size); - + // Precompute all domain points: [g^0, g^1, ..., g^(n-1)] let mut points = Vec::with_capacity(size); let mut current = CirclePoint::IDENTITY; @@ -281,11 +296,16 @@ impl CircleDomain { points.push(current); current = current.mul(generator); } - + // Verify: the last multiplication should give identity debug_assert!(current.is_identity(), "Domain points don't form a cycle"); - - Self { log_size, size, generator, points } + + Self { + log_size, + size, + generator, + points, + } } /// Get the i-th domain point (g^i). @@ -313,14 +333,14 @@ impl CircleDomain { pub fn verify(&self) -> bool { self.points.iter().all(|p| p.is_valid()) } - + /// Get unique x-coordinates (for polynomial evaluation). /// Returns (unique_xs, mapping) where mapping[i] gives the index in unique_xs /// for domain point i. pub fn unique_x_coords(&self) -> (Vec, Vec) { let mut unique_xs = Vec::new(); let mut mapping = Vec::with_capacity(self.size); - + for p in &self.points { if let Some(idx) = unique_xs.iter().position(|&x| x == p.x) { mapping.push(idx); @@ -329,7 +349,7 @@ impl CircleDomain { unique_xs.push(p.x); } } - + (unique_xs, mapping) } } @@ -355,11 +375,13 @@ pub struct Coset { impl Coset { /// Create a coset by shifting a domain. pub fn new(domain: CircleDomain, shift: CirclePoint) -> Self { - let shifted_points = domain.points.iter() - .map(|p| shift.mul(*p)) - .collect(); - - Self { domain, shift, shifted_points } + let shifted_points = domain.points.iter().map(|p| shift.mul(*p)).collect(); + + Self { + domain, + shift, + shifted_points, + } } /// Create the standard LDE coset. @@ -369,11 +391,11 @@ impl Coset { /// disjoint from the original domain. pub fn lde_coset(log_size: usize) -> Self { let domain = CircleDomain::new(log_size); - + // Shift by generator of order 2n (one step up in the subgroup chain) // This gives a coset h·D that is disjoint from D let shift = CirclePoint::generator(log_size + 1); - + Self::new(domain, shift) } @@ -439,23 +461,23 @@ impl CircleFFT { pub fn fft(&self, coeffs: &[M31]) -> Vec { let n = self.domain.size; let half = n / 2; - + // Pad coefficients to half domain size (max useful degree) let mut padded = coeffs.to_vec(); if padded.len() > half { padded.truncate(half); } padded.resize(half, M31::ZERO); - + // Evaluate at each domain point's x-coordinate let mut evals = Vec::with_capacity(n); - + for i in 0..n { let x = self.domain.get_point(i).x; let val = evaluate_poly(&padded, x); evals.push(val); } - + evals } @@ -469,13 +491,13 @@ impl CircleFFT { pub fn ifft(&self, evals: &[M31]) -> Vec { let n = self.domain.size; let half = n / 2; - + assert_eq!(evals.len(), n, "Evaluation count must match domain size"); - + // Get x-coordinates of first half (should be unique) let xs: Vec = (0..half).map(|i| self.domain.get_point(i).x).collect(); let ys: Vec = (0..half).map(|i| evals[i]).collect(); - + // Lagrange interpolation on the unique x-coordinates interpolate_lagrange(&xs, &ys) } @@ -487,7 +509,7 @@ impl CircleFFT { pub fn extend(&self, evals: &[M31], log_extension: usize) -> Vec { // Recover coefficients let coeffs = self.ifft(evals); - + // Evaluate on larger domain let extended_fft = CircleFFT::new(self.domain.log_size + log_extension); extended_fft.fft(&coeffs) @@ -520,7 +542,7 @@ impl CircleFFT { // Based on Stwo's proven implementation (Apache 2.0 licensed) /// Butterfly operation for forward FFT. -/// +/// /// Given v0, v1 and twiddle factor t, computes: /// - v0_new = v0 + v1 * t /// - v1_new = v0 - v1 * t @@ -532,11 +554,11 @@ pub fn butterfly(v0: &mut M31, v1: &mut M31, twid: M31) { } /// Inverse butterfly operation for inverse FFT. -/// +/// /// Given v0, v1 and inverse twiddle factor it, computes: /// - v0_new = v0 + v1 /// - v1_new = (v0 - v1) * it -#[inline] +#[inline] pub fn ibutterfly(v0: &mut M31, v1: &mut M31, itwid: M31) { let tmp = *v0; *v0 = tmp + *v1; @@ -544,7 +566,7 @@ pub fn ibutterfly(v0: &mut M31, v1: &mut M31, itwid: M31) { } /// Precomputed twiddle factors for efficient FFT. -/// +/// /// Twiddles are the x-coordinates of domain points, bit-reversed for /// efficient access during the butterfly passes. #[derive(Clone, Debug)] @@ -559,7 +581,7 @@ pub struct CircleTwiddles { impl CircleTwiddles { /// Precompute twiddle factors for a domain of size 2^log_size. - /// + /// /// Follows Stwo's algorithm: for each layer, store x-coordinates /// of coset points in bit-reversed order. pub fn new(log_size: usize) -> Self { @@ -570,7 +592,7 @@ impl CircleTwiddles { log_size, }; } - + if log_size == 1 { // For size 2, we just need the y-coordinate of the generator let gen = CirclePoint::generator(1); @@ -580,20 +602,20 @@ impl CircleTwiddles { log_size, }; } - + // Start with a coset that generates the domain // Use generator of order 2^log_size let mut coset = CirclePoint::generator(log_size); let mut coset_size = 1usize << log_size; - + let mut twiddles = Vec::with_capacity(coset_size); - + // For each layer, compute and store twiddles // The twiddles are the x-coordinates of coset points for layer in 0..log_size { let start_idx = twiddles.len(); let half_size = coset_size / 2; - + // For each layer, collect x-coordinates of the first half of coset points // Start from identity and step by generator let mut point = CirclePoint::IDENTITY; @@ -601,16 +623,16 @@ impl CircleTwiddles { twiddles.push(point.x); point = point.mul(coset); } - + // Bit-reverse this layer's twiddles if half_size > 1 { bit_reverse_permutation(&mut twiddles[start_idx..]); } - + // Double the coset generator for next layer coset = coset.double(); coset_size /= 2; - + // After first layer, x-coordinates should all be non-zero // The identity point has x=1, and we step by a generator that // produces points with different x-coords @@ -618,28 +640,35 @@ impl CircleTwiddles { // First layer contains identity (x=1), which is fine } } - + // Pad to power of 2 for alignment twiddles.push(M31::ONE); - + // Compute inverse twiddles with safe fallback for any zeros - let itwiddles: Vec = twiddles.iter().map(|t| { - if t.is_zero() { - M31::ONE // Fallback for zero (should not happen in well-formed domains) - } else { - t.inv() - } - }).collect(); - - Self { twiddles, itwiddles, log_size } + let itwiddles: Vec = twiddles + .iter() + .map(|t| { + if t.is_zero() { + M31::ONE // Fallback for zero (should not happen in well-formed domains) + } else { + t.inv() + } + }) + .collect(); + + Self { + twiddles, + itwiddles, + log_size, + } } - + /// Get twiddles for a specific layer. fn layer_twiddles(&self, layer: usize) -> &[M31] { if layer >= self.log_size { return &[]; } - + // Calculate start index for this layer let mut start = 0; let mut layer_size = 1 << (self.log_size - 1); @@ -647,16 +676,16 @@ impl CircleTwiddles { start += layer_size; layer_size /= 2; } - + &self.twiddles[start..(start + layer_size.max(1))] } } /// Fast Circle FFT implementation. -/// +/// /// NOTE: Currently delegates to the O(n²) CircleFFT for correctness. /// The butterfly operations above are ready for O(n log n) implementation. -/// +/// /// TODO: Implement proper O(n log n) butterfly-based Circle FFT. /// See: https://github.com/starkware-libs/stwo #[derive(Clone, Debug)] @@ -670,37 +699,37 @@ pub struct FastCircleFFT { impl FastCircleFFT { /// Create a Fast Circle FFT for domain size 2^log_size. pub fn new(log_size: usize) -> Self { - Self { + Self { inner: CircleFFT::new(log_size), twiddles: CircleTwiddles::new(log_size), } } - + /// Forward FFT: polynomial coefficients → evaluations. pub fn fft(&self, coeffs: &[M31]) -> Vec { self.inner.fft(coeffs) } - + /// Inverse FFT: evaluations → polynomial coefficients. pub fn ifft(&self, evals: &[M31]) -> Vec { self.inner.ifft(evals) } - + /// Low-degree extension using FFT. pub fn extend(&self, evals: &[M31], log_extension: usize) -> Vec { self.inner.extend(evals, log_extension) } - + /// Get domain size. pub fn size(&self) -> usize { self.inner.size() } - + /// Get log domain size. pub fn log_size(&self) -> usize { self.inner.log_size() } - + /// Get the domain. pub fn domain(&self) -> &CircleDomain { self.inner.domain() @@ -708,16 +737,13 @@ impl FastCircleFFT { } /// Execute one layer of the FFT butterfly algorithm. -/// +/// /// This processes all butterflies at a given layer with the same twiddle factor. #[inline] -fn fft_layer_loop( - values: &mut [M31], - layer: usize, - h: usize, - twid: M31, - butterfly_fn: F -) where F: Fn(&mut M31, &mut M31, M31) { +fn fft_layer_loop(values: &mut [M31], layer: usize, h: usize, twid: M31, butterfly_fn: F) +where + F: Fn(&mut M31, &mut M31, M31), +{ let layer_size = 1 << layer; for l in 0..layer_size { let idx0 = (h << (layer + 1)) + l; @@ -732,7 +758,7 @@ fn fft_layer_loop( } /// Compute circle twiddles (layer 0) from line twiddles (layer 1). -/// +/// /// The relationship between consecutive domain points allows us to derive /// the y-coordinate twiddles from the x-coordinate twiddles. fn circle_twiddles_from_line(line_twiddles: &[M31]) -> impl Iterator + '_ { @@ -768,27 +794,32 @@ pub fn evaluate_poly(coeffs: &[M31], point: M31) -> M31 { pub fn interpolate_lagrange(xs: &[M31], ys: &[M31]) -> Vec { let n = xs.len(); assert_eq!(n, ys.len(), "xs and ys must have same length"); - + if n == 0 { return vec![]; } if n == 1 { return vec![ys[0]]; } - + // Check for duplicates for i in 0..n { - for j in (i+1)..n { - assert!(xs[i] != xs[j], "Duplicate x values in interpolation: x[{}] = x[{}] = {}", - i, j, xs[i].value()); + for j in (i + 1)..n { + assert!( + xs[i] != xs[j], + "Duplicate x values in interpolation: x[{}] = x[{}] = {}", + i, + j, + xs[i].value() + ); } } - + let mut coeffs = vec![M31::ZERO; n]; - + for i in 0..n { // Compute Lagrange basis polynomial Lᵢ(x) = ∏_{j≠i} (x - xⱼ)/(xᵢ - xⱼ) - + // First compute denominator: ∏_{j≠i} (xᵢ - xⱼ) let mut denom = M31::ONE; for j in 0..n { @@ -796,10 +827,10 @@ pub fn interpolate_lagrange(xs: &[M31], ys: &[M31]) -> Vec { denom = denom * (xs[i] - xs[j]); } } - + // Scale factor: yᵢ / denom let scale = ys[i] * denom.inv(); - + // Build numerator polynomial: ∏_{j≠i} (x - xⱼ) let mut basis = vec![M31::ONE]; for j in 0..n { @@ -807,13 +838,13 @@ pub fn interpolate_lagrange(xs: &[M31], ys: &[M31]) -> Vec { // Multiply by (x - xⱼ) let mut new_basis = vec![M31::ZERO; basis.len() + 1]; for (k, &b) in basis.iter().enumerate() { - new_basis[k + 1] = new_basis[k + 1] + b; // +b·x - new_basis[k] = new_basis[k] - b * xs[j]; // -b·xⱼ + new_basis[k + 1] = new_basis[k + 1] + b; // +b·x + new_basis[k] = new_basis[k] - b * xs[j]; // -b·xⱼ } basis = new_basis; } } - + // Add scaled basis to result for (k, &b) in basis.iter().enumerate() { if k < n { @@ -821,7 +852,7 @@ pub fn interpolate_lagrange(xs: &[M31], ys: &[M31]) -> Vec { } } } - + coeffs } @@ -833,15 +864,15 @@ pub fn poly_mul(f: &[M31], g: &[M31]) -> Vec { if f.is_empty() || g.is_empty() { return vec![]; } - + let mut result = vec![M31::ZERO; f.len() + g.len() - 1]; - + for (i, &fi) in f.iter().enumerate() { for (j, &gj) in g.iter().enumerate() { result[i + j] = result[i + j] + fi * gj; } } - + result } @@ -849,14 +880,14 @@ pub fn poly_mul(f: &[M31], g: &[M31]) -> Vec { pub fn poly_add(f: &[M31], g: &[M31]) -> Vec { let max_len = f.len().max(g.len()); let mut result = vec![M31::ZERO; max_len]; - + for (i, &fi) in f.iter().enumerate() { result[i] = result[i] + fi; } for (i, &gi) in g.iter().enumerate() { result[i] = result[i] + gi; } - + result } @@ -864,14 +895,14 @@ pub fn poly_add(f: &[M31], g: &[M31]) -> Vec { pub fn poly_sub(f: &[M31], g: &[M31]) -> Vec { let max_len = f.len().max(g.len()); let mut result = vec![M31::ZERO; max_len]; - + for (i, &fi) in f.iter().enumerate() { result[i] = result[i] + fi; } for (i, &gi) in g.iter().enumerate() { result[i] = result[i] - gi; } - + result } @@ -898,35 +929,35 @@ pub fn poly_divmod(f: &[M31], g: &[M31]) -> (Vec, Vec) { Some(d) => d, None => panic!("Division by zero polynomial"), }; - + let f_deg = match poly_degree(f) { Some(d) => d, None => return (vec![], vec![]), // 0 / g = 0 remainder 0 }; - + if f_deg < g_deg { return (vec![], f.to_vec()); } - + let mut remainder = f.to_vec(); let mut quotient = vec![M31::ZERO; f_deg - g_deg + 1]; - + let lead_g_inv = g[g_deg].inv(); - + for i in (0..=f_deg - g_deg).rev() { let coeff = remainder[i + g_deg] * lead_g_inv; quotient[i] = coeff; - + for j in 0..=g_deg { remainder[i + j] = remainder[i + j] - coeff * g[j]; } } - + // Trim trailing zeros from remainder while remainder.len() > 1 && remainder.last() == Some(&M31::ZERO) { remainder.pop(); } - + (quotient, remainder) } @@ -941,9 +972,9 @@ pub fn bit_reverse_permutation(data: &mut [T]) { if n <= 1 { return; } - + let log_n = n.trailing_zeros() as usize; - + for i in 0..n { let j = bit_reverse(i, log_n); if i < j { @@ -972,10 +1003,10 @@ mod tests { let four = M31::new(4); let r = sqrt_m31(four).unwrap(); assert!(r * r == four); - + // sqrt(1) = 1 assert!(sqrt_m31(M31::ONE).unwrap() * sqrt_m31(M31::ONE).unwrap() == M31::ONE); - + // sqrt(0) = 0 assert_eq!(sqrt_m31(M31::ZERO), Some(M31::ZERO)); } @@ -992,13 +1023,13 @@ mod tests { #[test] fn test_circle_point_mul() { let p = CirclePoint::generator(4); - + // p * identity = p assert_eq!(p.mul(CirclePoint::IDENTITY), p); - - // identity * p = p + + // identity * p = p assert_eq!(CirclePoint::IDENTITY.mul(p), p); - + // p * p^(-1) = identity let p_inv = p.inv(); assert!(p.mul(p_inv).is_identity()); @@ -1008,10 +1039,10 @@ mod tests { fn test_circle_point_double() { let g = CirclePoint::generator(4); assert!(g.is_valid()); - + let g2 = g.double(); assert!(g2.is_valid()); - + // g.double() should equal g.mul(g) assert_eq!(g2, g.mul(g)); } @@ -1021,7 +1052,7 @@ mod tests { // Check generator is on circle let g = CirclePoint::generator_order_2_31(); assert!(g.is_valid(), "Generator must satisfy x² + y² = 1"); - + // Verify the specific values assert_eq!(g.x, M31::new(2)); let y_sq = g.y * g.y; @@ -1033,15 +1064,15 @@ mod tests { fn test_circle_point_order() { // Generator of order 2^4 = 16 let g = CirclePoint::generator(4); - + // g^16 should be identity let g16 = g.pow(16); assert!(g16.is_identity(), "g^16 should be identity"); - + // g^8 should NOT be identity let g8 = g.pow(8); assert!(!g8.is_identity(), "g^8 should not be identity"); - + // g^4 should NOT be identity let g4 = g.pow(4); assert!(!g4.is_identity(), "g^4 should not be identity"); @@ -1052,13 +1083,13 @@ mod tests { let domain = CircleDomain::new(3); assert_eq!(domain.size, 8); assert_eq!(domain.log_size, 3); - + // First point is identity assert!(domain.get_point(0).is_identity()); - + // All points are valid assert!(domain.verify()); - + // All points are distinct for i in 0..domain.size { for j in (i + 1)..domain.size { @@ -1083,16 +1114,16 @@ mod tests { fn test_evaluate_poly() { // f(x) = 1 + 2x + 3x² let coeffs = vec![M31::new(1), M31::new(2), M31::new(3)]; - + // f(0) = 1 assert_eq!(evaluate_poly(&coeffs, M31::ZERO), M31::new(1)); - + // f(1) = 1 + 2 + 3 = 6 assert_eq!(evaluate_poly(&coeffs, M31::ONE), M31::new(6)); - + // f(2) = 1 + 4 + 12 = 17 assert_eq!(evaluate_poly(&coeffs, M31::new(2)), M31::new(17)); - + // f(10) = 1 + 20 + 300 = 321 assert_eq!(evaluate_poly(&coeffs, M31::new(10)), M31::new(321)); } @@ -1103,9 +1134,9 @@ mod tests { // These lie on f(x) = 1 + 2x + 3x² let xs = vec![M31::new(0), M31::new(1), M31::new(2), M31::new(3)]; let ys = vec![M31::new(1), M31::new(6), M31::new(17), M31::new(34)]; - + let coeffs = interpolate_lagrange(&xs, &ys); - + // Verify interpolation for i in 0..4 { let y = evaluate_poly(&coeffs, xs[i]); @@ -1118,9 +1149,9 @@ mod tests { // FFT of constant polynomial f(x) = 42 let fft = CircleFFT::new(3); let coeffs = vec![M31::new(42)]; - + let evals = fft.fft(&coeffs); - + // All evaluations should be 42 for (i, &e) in evals.iter().enumerate() { assert_eq!(e, M31::new(42), "Eval at {} should be 42", i); @@ -1132,9 +1163,9 @@ mod tests { // FFT of f(x) = 1 + 2x let fft = CircleFFT::new(2); let coeffs = vec![M31::new(1), M31::new(2)]; - + let evals = fft.fft(&coeffs); - + // Verify each evaluation manually for i in 0..4 { let x = fft.get_domain_point(i).x; @@ -1147,21 +1178,24 @@ mod tests { fn test_fft_ifft_roundtrip() { let fft = CircleFFT::new(3); let half = fft.size() / 2; - + // Original polynomial of degree < n/2 - let original = vec![ - M31::new(1), M31::new(2), M31::new(3), M31::new(4), - ]; + let original = vec![M31::new(1), M31::new(2), M31::new(3), M31::new(4)]; assert!(original.len() <= half); - + let evals = fft.fft(&original); let recovered = fft.ifft(&evals); - + // Should recover original coefficients for i in 0..original.len() { - assert_eq!(recovered[i], original[i], - "Roundtrip failed at {}: got {}, expected {}", - i, recovered[i].value(), original[i].value()); + assert_eq!( + recovered[i], + original[i], + "Roundtrip failed at {}: got {}, expected {}", + i, + recovered[i].value(), + original[i].value() + ); } } @@ -1171,7 +1205,7 @@ mod tests { let f = vec![M31::new(1), M31::new(1)]; let g = vec![M31::new(1), M31::new(2)]; let h = poly_mul(&f, &g); - + assert_eq!(h.len(), 3); assert_eq!(h[0], M31::new(1)); assert_eq!(h[1], M31::new(3)); @@ -1183,13 +1217,13 @@ mod tests { // (2x² + 3x + 1) / (x + 1) = (2x + 1) remainder 0 let f = vec![M31::new(1), M31::new(3), M31::new(2)]; let g = vec![M31::new(1), M31::new(1)]; - + let (q, r) = poly_divmod(&f, &g); - + // Verify: f = q*g + r let qg = poly_mul(&q, &g); let reconstructed = poly_add(&qg, &r); - + for i in 0..f.len() { assert_eq!(reconstructed[i], f[i], "Division check failed at {}", i); } @@ -1198,16 +1232,16 @@ mod tests { #[test] fn test_lde_extension() { let fft = CircleFFT::new(2); - + // Polynomial: f(x) = 1 + 2x (degree 1, fits in domain of size 4) let coeffs = vec![M31::new(1), M31::new(2)]; let evals = fft.fft(&coeffs); - + // Extend to domain of size 8 let extended = fft.extend(&evals, 1); - + assert_eq!(extended.len(), 8); - + // Verify extended evaluations are correct let extended_fft = CircleFFT::new(3); for i in 0..8 { @@ -1216,57 +1250,57 @@ mod tests { assert_eq!(extended[i], expected, "LDE mismatch at {}", i); } } - + #[test] fn test_domain_x_coords_first_half_unique() { // For a domain of size n, verify the first n/2 x-coordinates are unique - let domain = CircleDomain::new(4); // size 16 + let domain = CircleDomain::new(4); // size 16 let half = domain.size / 2; - + let xs: Vec = (0..half).map(|i| domain.get_point(i).x).collect(); - + // Check uniqueness for i in 0..half { - for j in (i+1)..half { + for j in (i + 1)..half { assert_ne!(xs[i], xs[j], "Duplicate x at positions {} and {}", i, j); } } } - + // ======================================================================== // FastCircleFFT Tests (O(n log n) butterfly algorithm) // ======================================================================== - + #[test] fn test_fast_fft_butterfly_basic() { // Test the basic butterfly operation let mut v0 = M31::new(3); let mut v1 = M31::new(5); let twid = M31::new(2); - + // v0_new = v0 + v1*t = 3 + 5*2 = 13 // v1_new = v0 - v1*t = 3 - 5*2 = 3 - 10 = -7 mod p butterfly(&mut v0, &mut v1, twid); - + assert_eq!(v0, M31::new(13)); - assert_eq!(v1, M31::ZERO - M31::new(7)); // -7 mod p + assert_eq!(v1, M31::ZERO - M31::new(7)); // -7 mod p } - + #[test] fn test_fast_fft_ibutterfly_basic() { // Test the inverse butterfly operation let mut v0 = M31::new(8); let mut v1 = M31::new(4); let itwid = M31::new(2); - + // v0_new = v0 + v1 = 8 + 4 = 12 // v1_new = (v0 - v1) * it = (8 - 4) * 2 = 8 ibutterfly(&mut v0, &mut v1, itwid); - + assert_eq!(v0, M31::new(12)); assert_eq!(v1, M31::new(8)); } - + #[test] fn test_fast_fft_small_sizes() { // Test size 4 (log_size 2) - smallest working size @@ -1275,7 +1309,7 @@ mod tests { let coeffs4 = vec![M31::new(1), M31::new(2)]; // f(x) = 1 + 2x let evals4 = fft4.fft(&coeffs4); assert_eq!(evals4.len(), 4, "FFT should produce 4 evaluations"); - + // Test size 8 (log_size 3) let fft8 = FastCircleFFT::new(3); // For size 8, we provide 4 coefficients (n/2 = 4) @@ -1283,7 +1317,7 @@ mod tests { let evals8 = fft8.fft(&coeffs8); assert_eq!(evals8.len(), 8, "FFT should produce 8 evaluations"); } - + #[test] fn test_fast_fft_roundtrip() { // Test that fft followed by ifft preserves coefficients @@ -1292,22 +1326,34 @@ mod tests { let fast_fft = FastCircleFFT::new(log_size); let n = 1 << log_size; let half = n / 2; - + // Create test coefficients (only n/2 meaningful for degree < n/2) - let coeffs: Vec = (0..half).map(|i| M31::new((i * 7 + 13) as u32 % 1000)).collect(); - + let coeffs: Vec = (0..half) + .map(|i| M31::new((i * 7 + 13) as u32 % 1000)) + .collect(); + // Forward FFT: n/2 coeffs -> n evals let evals = fast_fft.fft(&coeffs); - assert_eq!(evals.len(), n, "FFT output size mismatch for log_size {}", log_size); - + assert_eq!( + evals.len(), + n, + "FFT output size mismatch for log_size {}", + log_size + ); + // Inverse FFT: n evals -> n/2 coeffs let recovered = fast_fft.ifft(&evals); - assert_eq!(recovered.len(), half, "IFFT output size mismatch for log_size {}", log_size); - + assert_eq!( + recovered.len(), + half, + "IFFT output size mismatch for log_size {}", + log_size + ); + // Check roundtrip for the meaningful coefficients for i in 0..half { assert_eq!( - recovered[i], coeffs[i], + recovered[i], coeffs[i], "Roundtrip failed at index {} for log_size {}: got {:?}, expected {:?}", i, log_size, recovered[i], coeffs[i] ); @@ -1317,12 +1363,12 @@ mod tests { #[test] fn test_fast_fft_extend() { - let fft = FastCircleFFT::new(3); // size 8 - + let fft = FastCircleFFT::new(3); // size 8 + // Create coefficients (n/2 = 4 meaningful coefficients) let coeffs: Vec = (0..4).map(|i| M31::new(i as u32)).collect(); let evals = fft.fft(&coeffs); - + // Extend to size 16 let extended = fft.extend(&evals, 1); assert_eq!(extended.len(), 16); diff --git a/crates/primitives/src/extension.rs b/crates/primitives/src/extension.rs index 6228247..78be774 100644 --- a/crates/primitives/src/extension.rs +++ b/crates/primitives/src/extension.rs @@ -29,9 +29,9 @@ //! - i² = -1 //! - u² = 2 + i +use crate::field::M31; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use serde::{Deserialize, Serialize}; -use crate::field::M31; // ============================================================================ // CM31: Complex Extension M31[i]/(i² + 1) @@ -50,13 +50,22 @@ pub struct CM31 { impl CM31 { /// The additive identity. - pub const ZERO: Self = Self { a: M31::ZERO, b: M31::ZERO }; + pub const ZERO: Self = Self { + a: M31::ZERO, + b: M31::ZERO, + }; /// The multiplicative identity. - pub const ONE: Self = Self { a: M31::ONE, b: M31::ZERO }; + pub const ONE: Self = Self { + a: M31::ONE, + b: M31::ZERO, + }; /// The imaginary unit i where i² = -1. - pub const I: Self = Self { a: M31::ZERO, b: M31::ONE }; + pub const I: Self = Self { + a: M31::ZERO, + b: M31::ONE, + }; /// Create a new CM31 element. #[inline] @@ -67,7 +76,10 @@ impl CM31 { /// Embed an M31 element into CM31. #[inline] pub const fn from_base(val: M31) -> Self { - Self { a: val, b: M31::ZERO } + Self { + a: val, + b: M31::ZERO, + } } /// Check if this is a real element (b = 0). @@ -85,7 +97,10 @@ impl CM31 { /// Complex conjugate: conj(a + bi) = a - bi. #[inline] pub fn conjugate(self) -> Self { - Self { a: self.a, b: -self.b } + Self { + a: self.a, + b: -self.b, + } } /// Norm: N(z) = z * conj(z) = a² + b² (result is in M31). @@ -125,26 +140,36 @@ impl Add for CM31 { type Output = Self; #[inline] fn add(self, rhs: Self) -> Self { - Self { a: self.a + rhs.a, b: self.b + rhs.b } + Self { + a: self.a + rhs.a, + b: self.b + rhs.b, + } } } impl AddAssign for CM31 { #[inline] - fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } } impl Sub for CM31 { type Output = Self; #[inline] fn sub(self, rhs: Self) -> Self { - Self { a: self.a - rhs.a, b: self.b - rhs.b } + Self { + a: self.a - rhs.a, + b: self.b - rhs.b, + } } } impl SubAssign for CM31 { #[inline] - fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } } impl Mul for CM31 { @@ -161,20 +186,27 @@ impl Mul for CM31 { impl MulAssign for CM31 { #[inline] - fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } } impl Neg for CM31 { type Output = Self; #[inline] fn neg(self) -> Self { - Self { a: -self.a, b: -self.b } + Self { + a: -self.a, + b: -self.b, + } } } impl From for CM31 { #[inline] - fn from(val: M31) -> Self { Self::from_base(val) } + fn from(val: M31) -> Self { + Self::from_base(val) + } } // ============================================================================ @@ -182,7 +214,10 @@ impl From for CM31 { // ============================================================================ /// The non-residue in CM31 used to define the extension: u² = 2 + i. -pub const U_SQUARED: CM31 = CM31 { a: M31(2), b: M31(1) }; +pub const U_SQUARED: CM31 = CM31 { + a: M31(2), + b: M31(1), +}; /// An element of QM31 = CM31[u]/(u² - (2+i)). /// @@ -201,14 +236,18 @@ pub struct QM31 { impl QM31 { /// The additive identity. pub const ZERO: Self = Self { - c0: M31::ZERO, c1: M31::ZERO, - c2: M31::ZERO, c3: M31::ZERO, + c0: M31::ZERO, + c1: M31::ZERO, + c2: M31::ZERO, + c3: M31::ZERO, }; /// The multiplicative identity. pub const ONE: Self = Self { - c0: M31::ONE, c1: M31::ZERO, - c2: M31::ZERO, c3: M31::ZERO, + c0: M31::ONE, + c1: M31::ZERO, + c2: M31::ZERO, + c3: M31::ZERO, }; /// Create from four M31 coefficients: a + bi + cu + diu. @@ -220,34 +259,52 @@ impl QM31 { /// Create from two CM31 elements: z₀ + z₁u. #[inline] pub const fn from_cm31(z0: CM31, z1: CM31) -> Self { - Self { c0: z0.a, c1: z0.b, c2: z1.a, c3: z1.b } + Self { + c0: z0.a, + c1: z0.b, + c2: z1.a, + c3: z1.b, + } } /// Get the z₀ component (constant part). #[inline] pub const fn z0(&self) -> CM31 { - CM31 { a: self.c0, b: self.c1 } + CM31 { + a: self.c0, + b: self.c1, + } } /// Get the z₁ component (coefficient of u). #[inline] pub const fn z1(&self) -> CM31 { - CM31 { a: self.c2, b: self.c3 } + CM31 { + a: self.c2, + b: self.c3, + } } /// Embed an M31 element into QM31. #[inline] pub const fn from_base(val: M31) -> Self { Self { - c0: val, c1: M31::ZERO, - c2: M31::ZERO, c3: M31::ZERO, + c0: val, + c1: M31::ZERO, + c2: M31::ZERO, + c3: M31::ZERO, } } /// Embed a CM31 element into QM31. #[inline] pub const fn from_cm31_base(val: CM31) -> Self { - Self { c0: val.a, c1: val.b, c2: M31::ZERO, c3: M31::ZERO } + Self { + c0: val.a, + c1: val.b, + c2: M31::ZERO, + c3: M31::ZERO, + } } /// Check if this element is in the base field M31. @@ -272,8 +329,10 @@ impl QM31 { #[inline] pub fn conjugate(self) -> Self { Self { - c0: self.c0, c1: self.c1, - c2: -self.c2, c3: -self.c3, + c0: self.c0, + c1: self.c1, + c2: -self.c2, + c3: -self.c3, } } @@ -344,19 +403,27 @@ impl QM31 { /// Get the c0 coefficient. #[inline] - pub const fn c0(&self) -> M31 { self.c0 } + pub const fn c0(&self) -> M31 { + self.c0 + } /// Get the c1 coefficient. #[inline] - pub const fn c1(&self) -> M31 { self.c1 } + pub const fn c1(&self) -> M31 { + self.c1 + } /// Get the c2 coefficient. #[inline] - pub const fn c2(&self) -> M31 { self.c2 } + pub const fn c2(&self) -> M31 { + self.c2 + } /// Get the c3 coefficient. #[inline] - pub const fn c3(&self) -> M31 { self.c3 } + pub const fn c3(&self) -> M31 { + self.c3 + } } // ============================================================================ @@ -379,7 +446,9 @@ impl Add for QM31 { impl AddAssign for QM31 { #[inline] - fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } } impl Sub for QM31 { @@ -398,7 +467,9 @@ impl Sub for QM31 { impl SubAssign for QM31 { #[inline] - fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } } impl Mul for QM31 { @@ -426,7 +497,9 @@ impl Mul for QM31 { impl MulAssign for QM31 { #[inline] - fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } } impl Neg for QM31 { @@ -454,17 +527,23 @@ impl Div for QM31 { impl DivAssign for QM31 { #[inline] - fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } } impl From for QM31 { #[inline] - fn from(val: M31) -> Self { Self::from_base(val) } + fn from(val: M31) -> Self { + Self::from_base(val) + } } impl From for QM31 { #[inline] - fn from(val: CM31) -> Self { Self::from_cm31_base(val) } + fn from(val: CM31) -> Self { + Self::from_cm31_base(val) + } } // ============================================================================ @@ -485,7 +564,11 @@ impl core::fmt::Display for CM31 { impl core::fmt::Display for QM31 { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "({} + {}i) + ({} + {}i)u", self.c0, self.c1, self.c2, self.c3) + write!( + f, + "({} + {}i) + ({} + {}i)u", + self.c0, self.c1, self.c2, self.c3 + ) } } @@ -573,7 +656,12 @@ mod tests { #[test] fn test_qm31_mul_identity() { - let a = QM31::new(M31::new(123), M31::new(456), M31::new(789), M31::new(101112)); + let a = QM31::new( + M31::new(123), + M31::new(456), + M31::new(789), + M31::new(101112), + ); assert_eq!(a * QM31::ONE, a); assert_eq!(QM31::ONE * a, a); } diff --git a/crates/primitives/src/field.rs b/crates/primitives/src/field.rs index bcb331b..8fda71a 100644 --- a/crates/primitives/src/field.rs +++ b/crates/primitives/src/field.rs @@ -4,8 +4,8 @@ //! This field is efficient for STARK proving due to fast reduction //! and good NTT-friendly properties via Circle STARKs or extension towers. -use core::ops::{Add, AddAssign, Sub, SubAssign, Mul, MulAssign, Neg}; use bytemuck::{Pod, Zeroable}; +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use serde::{Deserialize, Serialize}; /// The Mersenne31 prime: 2^31 - 1 @@ -15,7 +15,9 @@ pub const P: u32 = (1 << 31) - 1; /// /// Internally stored as a u32 in the range [0, P). /// All arithmetic operations maintain this invariant. -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Pod, Zeroable, Serialize, Deserialize)] +#[derive( + Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Pod, Zeroable, Serialize, Deserialize, +)] #[repr(transparent)] pub struct M31(pub u32); @@ -46,7 +48,11 @@ impl M31 { #[inline] const fn reduce(val: u32) -> u32 { let reduced = val.wrapping_sub(P); - if reduced < P { reduced } else { val } + if reduced < P { + reduced + } else { + val + } } /// Reduce a u64 product to M31. diff --git a/crates/primitives/src/lib.rs b/crates/primitives/src/lib.rs index a08dce8..976844d 100644 --- a/crates/primitives/src/lib.rs +++ b/crates/primitives/src/lib.rs @@ -8,14 +8,14 @@ //! - Range-check helpers //! - Plonky3 interoperability for SIMD-optimized operations -pub mod field; +pub mod circle; pub mod extension; +pub mod field; pub mod limbs; -pub mod circle; pub mod p3_interop; -pub use field::M31; +pub use circle::{CircleDomain, CircleFFT, CirclePoint, Coset, FastCircleFFT}; pub use extension::{CM31, QM31, U_SQUARED}; -pub use limbs::{to_limbs, from_limbs}; -pub use circle::{CirclePoint, CircleDomain, CircleFFT, Coset, FastCircleFFT}; -pub use p3_interop::{to_p3, from_p3, P3M31}; +pub use field::M31; +pub use limbs::{from_limbs, to_limbs}; +pub use p3_interop::{from_p3, to_p3, P3M31}; diff --git a/crates/primitives/src/p3_interop.rs b/crates/primitives/src/p3_interop.rs index 7208609..39863ec 100644 --- a/crates/primitives/src/p3_interop.rs +++ b/crates/primitives/src/p3_interop.rs @@ -13,13 +13,13 @@ //! Using these can provide 2-8x speedups for field-heavy operations. //! //! # DFT Support -//! +//! //! The `p3_fast_dft` function provides O(n log n) FFT using Plonky3's Radix2Dit. use crate::field::M31; -use p3_field::{PrimeCharacteristicRing, PrimeField32}; -use p3_field::extension::Complex; pub use p3_dft::TwoAdicSubgroupDft; +use p3_field::extension::Complex; +use p3_field::{PrimeCharacteristicRing, PrimeField32}; pub use p3_mersenne_31::{Mersenne31 as P3M31, Mersenne31ComplexRadix2Dit}; /// Convert ZP1 M31 to Plonky3 Mersenne31. @@ -48,7 +48,7 @@ pub fn from_p3_vec(v: &[P3M31]) -> Vec { pub type P3Complex = Complex; /// Perform O(n log n) DFT using Plonky3's optimized Radix-2 DIT. -/// +/// /// This function takes a slice of ZP1 M31 values, converts them to Plonky3's /// Complex, runs the DFT, and converts back. /// @@ -59,27 +59,30 @@ pub type P3Complex = Complex; /// Evaluations on a 2-adic subgroup of the complex extension pub fn p3_dft(coeffs: &[M31]) -> Vec<(M31, M31)> { use p3_matrix::dense::RowMajorMatrix; - + if coeffs.is_empty() { return vec![]; } - + // Pad to power of 2 let len = coeffs.len().next_power_of_two(); - + // Convert to Complex - embed M31 as real part - let mut complex_coeffs: Vec = coeffs.iter() + let mut complex_coeffs: Vec = coeffs + .iter() .map(|&m| P3Complex::new_real(to_p3(m))) .collect(); complex_coeffs.resize(len, P3Complex::ZERO); - + // Create matrix and run DFT let mat = RowMajorMatrix::new_col(complex_coeffs); let dft = Mersenne31ComplexRadix2Dit; let result = dft.dft_batch(mat); - + // Convert back to ZP1 (real, imag) pairs - result.values.iter() + result + .values + .iter() .map(|c| (from_p3(c.real()), from_p3(c.imag()))) .collect() } @@ -87,23 +90,26 @@ pub fn p3_dft(coeffs: &[M31]) -> Vec<(M31, M31)> { /// Perform O(n log n) inverse DFT using Plonky3. pub fn p3_idft(evals: &[(M31, M31)]) -> Vec<(M31, M31)> { use p3_matrix::dense::RowMajorMatrix; - + if evals.is_empty() { return vec![]; } - + // Convert to Complex - let complex_evals: Vec = evals.iter() + let complex_evals: Vec = evals + .iter() .map(|&(r, i)| P3Complex::new_complex(to_p3(r), to_p3(i))) .collect(); - + // Create matrix and run IDFT let mat = RowMajorMatrix::new_col(complex_evals); let dft = Mersenne31ComplexRadix2Dit; let result = dft.idft_batch(mat); - + // Convert back - result.values.iter() + result + .values + .iter() .map(|c| (from_p3(c.real()), from_p3(c.imag()))) .collect() } @@ -122,26 +128,26 @@ mod tests { assert_eq!(zp1, back, "Roundtrip failed for {}", i); } } - + #[test] fn test_arithmetic_compatibility() { let a = M31::new(12345); let b = M31::new(67890); - + // ZP1 arithmetic let zp1_sum = a + b; let zp1_prod = a * b; - + // P3 arithmetic let p3_a = to_p3(a); let p3_b = to_p3(b); let p3_sum = p3_a + p3_b; let p3_prod = p3_a * p3_b; - + assert_eq!(zp1_sum, from_p3(p3_sum), "Sum mismatch"); assert_eq!(zp1_prod, from_p3(p3_prod), "Product mismatch"); } - + #[test] fn test_p3_has_simd() { // This test verifies Plonky3 is correctly configured @@ -150,20 +156,20 @@ mod tests { assert!(!one.is_zero()); assert!(!gen.is_zero()); } - + #[test] fn test_p3_dft_roundtrip() { // Test O(n log n) DFT roundtrip let coeffs: Vec = (0..8).map(|i| M31::new(i * 10 + 1)).collect(); - + // Forward DFT let evals = p3_dft(&coeffs); assert_eq!(evals.len(), 8, "DFT should return 8 evaluations"); - + // Inverse DFT let recovered = p3_idft(&evals); assert_eq!(recovered.len(), 8, "IDFT should return 8 coefficients"); - + // Check roundtrip (real parts should match, imaginary should be ~0) for (i, ((r, _), orig)) in recovered.iter().zip(coeffs.iter()).enumerate() { assert_eq!(*r, *orig, "DFT roundtrip failed at index {}", i); diff --git a/crates/prover/src/gpu/backend.rs b/crates/prover/src/gpu/backend.rs index 2477c63..7731a9d 100644 --- a/crates/prover/src/gpu/backend.rs +++ b/crates/prover/src/gpu/backend.rs @@ -2,8 +2,8 @@ #![allow(dead_code)] -use std::sync::Arc; use crate::gpu::DeviceType; +use std::sync::Arc; /// Error type for GPU operations. #[derive(Debug, Clone)] @@ -26,12 +26,23 @@ impl std::fmt::Display for GpuError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { GpuError::DeviceNotAvailable(msg) => write!(f, "Device not available: {}", msg), - GpuError::OutOfMemory { requested, available } => { - write!(f, "Out of GPU memory: requested {} bytes, {} available", requested, available) + GpuError::OutOfMemory { + requested, + available, + } => { + write!( + f, + "Out of GPU memory: requested {} bytes, {} available", + requested, available + ) } GpuError::KernelError(msg) => write!(f, "Kernel execution failed: {}", msg), GpuError::InvalidBufferSize { expected, actual } => { - write!(f, "Invalid buffer size: expected {}, got {}", expected, actual) + write!( + f, + "Invalid buffer size: expected {}, got {}", + expected, actual + ) } GpuError::SyncError(msg) => write!(f, "Synchronization error: {}", msg), GpuError::NotSupported(msg) => write!(f, "Feature not supported: {}", msg), @@ -45,16 +56,16 @@ impl std::error::Error for GpuError {} pub trait GpuMemory: Send + Sync { /// Get the size of the allocated memory in bytes. fn size(&self) -> usize; - + /// Copy data from host to device. fn copy_from_host(&mut self, data: &[u8]) -> Result<(), GpuError>; - + /// Copy data from device to host. fn copy_to_host(&self, data: &mut [u8]) -> Result<(), GpuError>; - + /// Get raw pointer (for internal use). fn as_ptr(&self) -> *const u8; - + /// Get mutable raw pointer (for internal use). fn as_mut_ptr(&mut self) -> *mut u8; } @@ -63,16 +74,16 @@ pub trait GpuMemory: Send + Sync { pub trait GpuDevice: Send + Sync { /// Get device type. fn device_type(&self) -> DeviceType; - + /// Get device name. fn name(&self) -> &str; - + /// Allocate memory on device. fn allocate(&self, size: usize) -> Result, GpuError>; - + /// Synchronize all pending operations. fn synchronize(&self) -> Result<(), GpuError>; - + /// Get available memory in bytes. fn available_memory(&self) -> usize; } @@ -81,13 +92,13 @@ pub trait GpuDevice: Send + Sync { pub trait GpuBackend: Send + Sync { /// Get the underlying device. fn device(&self) -> &dyn GpuDevice; - + /// Perform Number Theoretic Transform (NTT) on M31 elements. fn ntt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError>; - + /// Perform inverse NTT on M31 elements. fn intt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError>; - + /// Batch polynomial evaluation at multiple points. fn batch_evaluate( &self, @@ -95,10 +106,10 @@ pub trait GpuBackend: Send + Sync { points: &[u32], results: &mut [u32], ) -> Result<(), GpuError>; - + /// Compute Merkle tree from leaf hashes. fn merkle_tree(&self, leaves: &[[u8; 32]]) -> Result, GpuError>; - + /// Low Degree Extension (LDE) of polynomial. fn lde(&self, coeffs: &[u32], blowup_factor: usize) -> Result, GpuError>; } @@ -120,7 +131,7 @@ impl GpuMemory for CpuMemory { fn size(&self) -> usize { self.data.len() } - + fn copy_from_host(&mut self, data: &[u8]) -> Result<(), GpuError> { if data.len() != self.data.len() { return Err(GpuError::InvalidBufferSize { @@ -131,7 +142,7 @@ impl GpuMemory for CpuMemory { self.data.copy_from_slice(data); Ok(()) } - + fn copy_to_host(&self, data: &mut [u8]) -> Result<(), GpuError> { if data.len() != self.data.len() { return Err(GpuError::InvalidBufferSize { @@ -142,11 +153,11 @@ impl GpuMemory for CpuMemory { data.copy_from_slice(&self.data); Ok(()) } - + fn as_ptr(&self) -> *const u8 { self.data.as_ptr() } - + fn as_mut_ptr(&mut self) -> *mut u8 { self.data.as_mut_ptr() } @@ -175,20 +186,20 @@ impl GpuDevice for CpuDevice { fn device_type(&self) -> DeviceType { DeviceType::Cpu } - + fn name(&self) -> &str { &self.name } - + fn allocate(&self, size: usize) -> Result, GpuError> { Ok(Box::new(CpuMemory::new(size))) } - + fn synchronize(&self) -> Result<(), GpuError> { // CPU operations are synchronous Ok(()) } - + fn available_memory(&self) -> usize { // Return a large value for CPU usize::MAX @@ -218,7 +229,7 @@ impl GpuBackend for CpuBackend { fn device(&self) -> &dyn GpuDevice { self.device.as_ref() } - + fn ntt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { let n = 1usize << log_n; if values.len() != n { @@ -260,7 +271,7 @@ impl GpuBackend for CpuBackend { Ok(()) } - + fn intt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { let n = 1usize << log_n; if values.len() != n { @@ -308,7 +319,7 @@ impl GpuBackend for CpuBackend { Ok(()) } - + fn batch_evaluate( &self, coeffs: &[u32], @@ -321,26 +332,26 @@ impl GpuBackend for CpuBackend { actual: results.len(), }); } - + use zp1_primitives::field::M31; - + // Evaluate polynomial at each point for (i, &point) in points.iter().enumerate() { let x = M31::new(point); let mut result = M31::ZERO; let mut x_pow = M31::ONE; - + for &coeff in coeffs { result = result + M31::new(coeff) * x_pow; x_pow = x_pow * x; } - + results[i] = result.value(); } - + Ok(()) } - + fn merkle_tree(&self, leaves: &[[u8; 32]]) -> Result, GpuError> { let n = leaves.len(); if n == 0 || !n.is_power_of_two() { @@ -370,7 +381,7 @@ impl GpuBackend for CpuBackend { Ok(tree) } - + fn lde(&self, coeffs: &[u32], blowup_factor: usize) -> Result, GpuError> { let n = coeffs.len(); let extended_n = n * blowup_factor; @@ -406,12 +417,20 @@ const M31_P: u32 = (1u32 << 31) - 1; #[inline] fn m31_add(a: u32, b: u32) -> u32 { let sum = a.wrapping_add(b); - if sum >= M31_P { sum - M31_P } else { sum } + if sum >= M31_P { + sum - M31_P + } else { + sum + } } #[inline] fn m31_sub(a: u32, b: u32) -> u32 { - if a >= b { a - b } else { M31_P - b + a } + if a >= b { + a - b + } else { + M31_P - b + a + } } #[inline] @@ -420,7 +439,11 @@ fn m31_mul(a: u32, b: u32) -> u32 { let lo = (prod & (M31_P as u64)) as u32; let hi = (prod >> 31) as u32; let sum = lo.wrapping_add(hi); - if sum >= M31_P { sum - M31_P } else { sum } + if sum >= M31_P { + sum - M31_P + } else { + sum + } } #[inline] @@ -433,13 +456,13 @@ fn mod_inverse(a: u32, m: u32) -> u32 { let mut r = a as i64; let mut old_s = 0i64; let mut s = 1i64; - + while r != 0 { let q = old_r / r; (old_r, r) = (r, old_r - q * r); (old_s, s) = (s, old_s - q * s); } - + if old_s < 0 { (old_s + m as i64) as u32 } else { @@ -450,44 +473,46 @@ fn mod_inverse(a: u32, m: u32) -> u32 { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_cpu_memory() { let mut mem = CpuMemory::new(32); assert_eq!(mem.size(), 32); - + let data = vec![1u8; 32]; mem.copy_from_host(&data).unwrap(); - + let mut output = vec![0u8; 32]; mem.copy_to_host(&mut output).unwrap(); - + assert_eq!(data, output); } - + #[test] fn test_cpu_device() { let device = CpuDevice::new(); assert_eq!(device.device_type(), DeviceType::Cpu); assert!(device.name().contains("CPU")); - + let mem = device.allocate(64).unwrap(); assert_eq!(mem.size(), 64); - + device.synchronize().unwrap(); } - + #[test] fn test_cpu_backend_batch_evaluate() { let backend = CpuBackend::new(); - + // Polynomial: 1 + 2x + 3x^2 let coeffs = vec![1, 2, 3]; let points = vec![0, 1, 2]; let mut results = vec![0u32; 3]; - - backend.batch_evaluate(&coeffs, &points, &mut results).unwrap(); - + + backend + .batch_evaluate(&coeffs, &points, &mut results) + .unwrap(); + // At x=0: 1 + 0 + 0 = 1 assert_eq!(results[0], 1); // At x=1: 1 + 2 + 3 = 6 @@ -495,10 +520,13 @@ mod tests { // At x=2: 1 + 4 + 12 = 17 assert_eq!(results[2], 17); } - + #[test] fn test_gpu_error_display() { - let err = GpuError::OutOfMemory { requested: 1000, available: 500 }; + let err = GpuError::OutOfMemory { + requested: 1000, + available: 500, + }; let msg = format!("{}", err); assert!(msg.contains("1000")); assert!(msg.contains("500")); diff --git a/crates/prover/src/gpu/cuda.rs b/crates/prover/src/gpu/cuda.rs index 16d2864..ec98631 100644 --- a/crates/prover/src/gpu/cuda.rs +++ b/crates/prover/src/gpu/cuda.rs @@ -314,21 +314,21 @@ impl CudaBackend { // 2. Query device properties // 3. Compile PTX kernels // 4. Create streams - + // Check if CUDA is available (placeholder) #[cfg(not(feature = "cuda"))] { return Err(GpuError::DeviceNotAvailable( - "CUDA support not compiled. Enable 'cuda' feature.".to_string() + "CUDA support not compiled. Enable 'cuda' feature.".to_string(), )); } - + #[cfg(feature = "cuda")] { Ok(Self { device_index, device_name: format!("NVIDIA GPU {}", device_index), - sm_count: 80, // Would query from device + sm_count: 80, // Would query from device available_memory: 16 * 1024 * 1024 * 1024, // 16GB typical twiddles: Vec::new(), inv_twiddles: Vec::new(), @@ -340,30 +340,31 @@ impl CudaBackend { /// Precompute twiddle factors for given size. pub fn precompute_twiddles(&mut self, log_n: usize) -> Result<(), GpuError> { if log_n > self.max_log_n { - return Err(GpuError::NotSupported( - format!("log_n {} exceeds maximum {}", log_n, self.max_log_n) - )); + return Err(GpuError::NotSupported(format!( + "log_n {} exceeds maximum {}", + log_n, self.max_log_n + ))); } let n = 1usize << log_n; const M31_P: u64 = (1u64 << 31) - 1; - + // Using a simplified generator let generator = 5u64; let order = M31_P - 1; let step = order / (n as u64); - + self.twiddles = Vec::with_capacity(n); self.inv_twiddles = Vec::with_capacity(n); - + let mut w = 1u64; for _ in 0..n { self.twiddles.push(w as u32); w = (w * pow_mod(generator, step, M31_P)) % M31_P; } - + self.inv_twiddles = self.twiddles.iter().rev().cloned().collect(); - + Ok(()) } @@ -383,7 +384,7 @@ fn pow_mod(base: u64, exp: u64, modulus: u64) -> u64 { let mut result = 1u64; let mut base = base % modulus; let mut exp = exp; - + while exp > 0 { if exp & 1 == 1 { result = (result * base) % modulus; @@ -391,7 +392,7 @@ fn pow_mod(base: u64, exp: u64, modulus: u64) -> u64 { exp >>= 1; base = (base * base) % modulus; } - + result } @@ -415,7 +416,7 @@ impl GpuMemory for CudaMemory { fn size(&self) -> usize { self.data.len() } - + fn copy_from_host(&mut self, data: &[u8]) -> Result<(), GpuError> { // In real implementation: cudaMemcpy(device, host, size, cudaMemcpyHostToDevice) if data.len() > self.data.len() { @@ -427,7 +428,7 @@ impl GpuMemory for CudaMemory { self.data[..data.len()].copy_from_slice(data); Ok(()) } - + fn copy_to_host(&self, data: &mut [u8]) -> Result<(), GpuError> { // In real implementation: cudaMemcpy(host, device, size, cudaMemcpyDeviceToHost) if data.len() > self.data.len() { @@ -439,11 +440,11 @@ impl GpuMemory for CudaMemory { data.copy_from_slice(&self.data[..data.len()]); Ok(()) } - + fn as_ptr(&self) -> *const u8 { self.data.as_ptr() } - + fn as_mut_ptr(&mut self) -> *mut u8 { self.data.as_mut_ptr() } @@ -474,20 +475,20 @@ impl GpuDevice for CudaDevice { fn device_type(&self) -> DeviceType { DeviceType::Cuda } - + fn name(&self) -> &str { &self.name } - + fn allocate(&self, size: usize) -> Result, GpuError> { Ok(Box::new(CudaMemory::new(size)?)) } - + fn synchronize(&self) -> Result<(), GpuError> { // In real implementation: cudaDeviceSynchronize Ok(()) } - + fn available_memory(&self) -> usize { self.memory_bytes } @@ -499,7 +500,7 @@ impl GpuBackend for CudaBackend { static DEVICE: std::sync::OnceLock = std::sync::OnceLock::new(); DEVICE.get_or_init(|| CudaDevice::new(0).unwrap()) } - + fn ntt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { let n = 1usize << log_n; if values.len() != n { @@ -508,9 +509,9 @@ impl GpuBackend for CudaBackend { actual: values.len(), }); } - + // CPU fallback (real impl would launch CUDA kernels) - + // Bit-reversal permutation for i in 0..n { let j = bit_reverse(i, log_n); @@ -518,35 +519,35 @@ impl GpuBackend for CudaBackend { values.swap(i, j); } } - + // Cooley-Tukey butterfly stages for stage in 0..log_n { let half_step = 1usize << stage; let step = half_step << 1; - + for group in (0..n).step_by(step) { for pos in 0..half_step { let i = group + pos; let j = i + half_step; - + let w = if self.twiddles.is_empty() { 1u32 } else { self.twiddles.get(pos * (n / step)).copied().unwrap_or(1) }; - + let u = values[i]; let v = m31_mul(values[j], w); - + values[i] = m31_add(u, v); values[j] = m31_sub(u, v); } } } - + Ok(()) } - + fn intt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { let n = 1usize << log_n; if values.len() != n { @@ -555,34 +556,37 @@ impl GpuBackend for CudaBackend { actual: values.len(), }); } - + const M31_P: u32 = (1u32 << 31) - 1; - + // Gentleman-Sande butterfly stages for stage in (0..log_n).rev() { let half_step = 1usize << stage; let step = half_step << 1; - + for group in (0..n).step_by(step) { for pos in 0..half_step { let i = group + pos; let j = i + half_step; - + let w = if self.inv_twiddles.is_empty() { 1u32 } else { - self.inv_twiddles.get(pos * (n / step)).copied().unwrap_or(1) + self.inv_twiddles + .get(pos * (n / step)) + .copied() + .unwrap_or(1) }; - + let u = values[i]; let v = values[j]; - + values[i] = m31_add(u, v); values[j] = m31_mul(m31_sub(u, v), w); } } } - + // Bit-reversal permutation for i in 0..n { let j = bit_reverse(i, log_n); @@ -590,16 +594,16 @@ impl GpuBackend for CudaBackend { values.swap(i, j); } } - + // Scale by 1/n let inv_n = mod_inverse(n as u32, M31_P); for v in values.iter_mut() { *v = m31_mul(*v, inv_n); } - + Ok(()) } - + fn batch_evaluate( &self, coeffs: &[u32], @@ -612,7 +616,7 @@ impl GpuBackend for CudaBackend { actual: results.len(), }); } - + // CPU fallback: Horner's method // Real impl would launch poly_eval_batch kernel for (i, &point) in points.iter().enumerate() { @@ -622,10 +626,10 @@ impl GpuBackend for CudaBackend { } results[i] = result; } - + Ok(()) } - + fn merkle_tree(&self, leaves: &[[u8; 32]]) -> Result, GpuError> { let n = leaves.len(); if n == 0 || !n.is_power_of_two() { @@ -634,14 +638,14 @@ impl GpuBackend for CudaBackend { actual: n, }); } - + // Build tree bottom-up // Real impl would launch merkle_layer kernel for each level let tree_size = 2 * n - 1; let mut tree = vec![[0u8; 32]; tree_size]; - + tree[n - 1..].copy_from_slice(leaves); - + for i in (0..n - 1).rev() { // Compute hash of children let mut hash = [0u8; 32]; @@ -654,28 +658,28 @@ impl GpuBackend for CudaBackend { } tree[i] = hash; } - + Ok(tree) } - + fn lde(&self, coeffs: &[u32], blowup_factor: usize) -> Result, GpuError> { let n = coeffs.len(); let extended_n = n * blowup_factor; - + if !n.is_power_of_two() || !blowup_factor.is_power_of_two() { return Err(GpuError::InvalidBufferSize { expected: n.next_power_of_two(), actual: n, }); } - + // CPU fallback // Real impl would launch lde_evaluate kernel let mut results = vec![0u32; extended_n]; - + let generator = 3u32; let mut point = generator; - + for i in 0..extended_n { let mut result = 0u32; for &coeff in coeffs.iter().rev() { @@ -684,7 +688,7 @@ impl GpuBackend for CudaBackend { results[i] = result; point = m31_mul(point, generator); } - + Ok(results) } } @@ -696,12 +700,20 @@ const M31_P: u32 = (1u32 << 31) - 1; #[inline] fn m31_add(a: u32, b: u32) -> u32 { let sum = a.wrapping_add(b); - if sum >= M31_P { sum - M31_P } else { sum } + if sum >= M31_P { + sum - M31_P + } else { + sum + } } #[inline] fn m31_sub(a: u32, b: u32) -> u32 { - if a >= b { a - b } else { M31_P - b + a } + if a >= b { + a - b + } else { + M31_P - b + a + } } #[inline] @@ -710,7 +722,11 @@ fn m31_mul(a: u32, b: u32) -> u32 { let lo = (prod & (M31_P as u64)) as u32; let hi = (prod >> 31) as u32; let sum = lo.wrapping_add(hi); - if sum >= M31_P { sum - M31_P } else { sum } + if sum >= M31_P { + sum - M31_P + } else { + sum + } } #[inline] @@ -723,13 +739,13 @@ fn mod_inverse(a: u32, m: u32) -> u32 { let mut r = a as i64; let mut old_s = 0i64; let mut s = 1i64; - + while r != 0 { let q = old_r / r; (old_r, r) = (r, old_r - q * r); (old_s, s) = (s, old_s - q * s); } - + if old_s < 0 { (old_s + m as i64) as u32 } else { @@ -770,46 +786,46 @@ pub struct CudaDeviceInfo { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_m31_arithmetic() { assert_eq!(m31_add(100, 200), 300); assert_eq!(m31_add(M31_P - 1, 2), 1); - + assert_eq!(m31_sub(200, 100), 100); assert_eq!(m31_sub(100, 200), M31_P - 100); - + assert_eq!(m31_mul(2, 3), 6); assert_eq!(m31_mul(M31_P - 1, 2), M31_P - 2); } - + #[test] fn test_bit_reverse() { assert_eq!(bit_reverse(0b000, 3), 0b000); assert_eq!(bit_reverse(0b001, 3), 0b100); assert_eq!(bit_reverse(0b010, 3), 0b010); } - + #[test] fn test_mod_inverse() { let inv = mod_inverse(3, M31_P); assert_eq!(m31_mul(3, inv), 1); } - + #[test] fn test_query_cuda_devices() { let devices = query_cuda_devices(); assert!(!devices.is_empty()); } - + #[test] fn test_cuda_memory() { let mut mem = CudaMemory::new(1024).unwrap(); assert_eq!(mem.size(), 1024); - + let data = vec![1u8; 512]; mem.copy_from_host(&data).unwrap(); - + let mut output = vec![0u8; 512]; mem.copy_to_host(&mut output).unwrap(); assert_eq!(data, output); diff --git a/crates/prover/src/gpu/metal.rs b/crates/prover/src/gpu/metal.rs index f5047d1..6ccf9c6 100644 --- a/crates/prover/src/gpu/metal.rs +++ b/crates/prover/src/gpu/metal.rs @@ -315,7 +315,7 @@ kernel void merkle_layer( #[cfg(all(target_os = "macos", feature = "gpu-metal"))] mod native { use super::*; - use metal::{Device, CommandQueue, Library, ComputePipelineState, Buffer, MTLResourceOptions}; + use metal::{Buffer, CommandQueue, ComputePipelineState, Device, Library, MTLResourceOptions}; use std::sync::Arc; /// Native Metal backend using metal-rs crate. @@ -344,22 +344,25 @@ mod native { pub fn new() -> Result { let device = Device::system_default() .ok_or_else(|| GpuError::DeviceNotAvailable("No Metal device found".to_string()))?; - + let command_queue = device.new_command_queue(); - + // Compile shaders - let library = device.new_library_with_source(METAL_M31_SHADERS, &metal::CompileOptions::new()) + let library = device + .new_library_with_source(METAL_M31_SHADERS, &metal::CompileOptions::new()) .map_err(|e| GpuError::KernelError(format!("Shader compilation failed: {}", e)))?; - + // Create pipeline states for each kernel let ntt_butterfly_pipeline = Self::create_pipeline(&device, &library, "ntt_butterfly")?; - let intt_butterfly_pipeline = Self::create_pipeline(&device, &library, "intt_butterfly")?; + let intt_butterfly_pipeline = + Self::create_pipeline(&device, &library, "intt_butterfly")?; let intt_scale_pipeline = Self::create_pipeline(&device, &library, "intt_scale")?; - let bit_reverse_pipeline = Self::create_pipeline(&device, &library, "bit_reverse_permute")?; + let bit_reverse_pipeline = + Self::create_pipeline(&device, &library, "bit_reverse_permute")?; let poly_eval_pipeline = Self::create_pipeline(&device, &library, "poly_eval_batch")?; let lde_pipeline = Self::create_pipeline(&device, &library, "lde_evaluate")?; let merkle_layer_pipeline = Self::create_pipeline(&device, &library, "merkle_layer")?; - + Ok(Self { device, command_queue, @@ -378,39 +381,48 @@ mod native { max_log_n: 24, }) } - - fn create_pipeline(device: &Device, library: &Library, name: &str) -> Result { - let function = library.get_function(name, None) - .map_err(|e| GpuError::KernelError(format!("Function '{}' not found: {}", name, e)))?; - - device.new_compute_pipeline_state_with_function(&function) - .map_err(|e| GpuError::KernelError(format!("Pipeline creation failed for '{}': {}", name, e))) + + fn create_pipeline( + device: &Device, + library: &Library, + name: &str, + ) -> Result { + let function = library.get_function(name, None).map_err(|e| { + GpuError::KernelError(format!("Function '{}' not found: {}", name, e)) + })?; + + device + .new_compute_pipeline_state_with_function(&function) + .map_err(|e| { + GpuError::KernelError(format!("Pipeline creation failed for '{}': {}", name, e)) + }) } - + /// Precompute twiddle factors for NTT. pub fn precompute_twiddles(&mut self, log_n: usize) -> Result<(), GpuError> { if log_n > self.max_log_n { - return Err(GpuError::NotSupported( - format!("log_n {} exceeds maximum {}", log_n, self.max_log_n) - )); + return Err(GpuError::NotSupported(format!( + "log_n {} exceeds maximum {}", + log_n, self.max_log_n + ))); } - + let n = 1usize << log_n; const M31_P: u64 = (1u64 << 31) - 1; - + let generator = 5u64; let order = M31_P - 1; let step = order / (n as u64); - + self.twiddles = Vec::with_capacity(n); let mut w = 1u64; for _ in 0..n { self.twiddles.push(w as u32); w = (w * pow_mod(generator, step, M31_P)) % M31_P; } - + self.inv_twiddles = self.twiddles.iter().rev().cloned().collect(); - + // Create GPU buffers for twiddles let twiddle_bytes = bytemuck::cast_slice::(&self.twiddles); self.twiddle_buffer = Some(self.device.new_buffer_with_data( @@ -418,27 +430,33 @@ mod native { twiddle_bytes.len() as u64, MTLResourceOptions::StorageModeShared, )); - + let inv_twiddle_bytes = bytemuck::cast_slice::(&self.inv_twiddles); self.inv_twiddle_buffer = Some(self.device.new_buffer_with_data( inv_twiddle_bytes.as_ptr() as *const _, inv_twiddle_bytes.len() as u64, MTLResourceOptions::StorageModeShared, )); - + Ok(()) } - - fn execute_ntt_gpu(&self, values: &mut [u32], log_n: usize, inverse: bool) -> Result<(), GpuError> { + + fn execute_ntt_gpu( + &self, + values: &mut [u32], + log_n: usize, + inverse: bool, + ) -> Result<(), GpuError> { let n = 1usize << log_n; - + // Ensure twiddles are precomputed let twiddle_buffer = if inverse { self.inv_twiddle_buffer.as_ref() } else { self.twiddle_buffer.as_ref() - }.ok_or_else(|| GpuError::NotSupported("Twiddles not precomputed".to_string()))?; - + } + .ok_or_else(|| GpuError::NotSupported("Twiddles not precomputed".to_string()))?; + // Create data buffer let data_bytes = bytemuck::cast_slice::(values); let data_buffer = self.device.new_buffer_with_data( @@ -446,93 +464,118 @@ mod native { data_bytes.len() as u64, MTLResourceOptions::StorageModeShared, ); - + let command_buffer = self.command_queue.new_command_buffer(); - + // Bit-reverse permutation { let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&self.bit_reverse_pipeline); encoder.set_buffer(0, Some(&data_buffer), 0); - encoder.set_bytes(1, std::mem::size_of::() as u64, &(n as u32) as *const u32 as *const _); - encoder.set_bytes(2, std::mem::size_of::() as u64, &(log_n as u32) as *const u32 as *const _); - + encoder.set_bytes( + 1, + std::mem::size_of::() as u64, + &(n as u32) as *const u32 as *const _, + ); + encoder.set_bytes( + 2, + std::mem::size_of::() as u64, + &(log_n as u32) as *const u32 as *const _, + ); + let thread_group_size = metal::MTLSize::new(256, 1, 1); let grid_size = metal::MTLSize::new(n as u64, 1, 1); encoder.dispatch_threads(grid_size, thread_group_size); encoder.end_encoding(); } - + // Butterfly stages let pipeline = if inverse { &self.intt_butterfly_pipeline } else { &self.ntt_butterfly_pipeline }; - + let stages: Box> = if inverse { Box::new((0..log_n).rev()) } else { Box::new(0..log_n) }; - + for stage in stages { let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(pipeline); encoder.set_buffer(0, Some(&data_buffer), 0); - encoder.set_bytes(1, std::mem::size_of::() as u64, &(n as u32) as *const u32 as *const _); - encoder.set_bytes(2, std::mem::size_of::() as u64, &(stage as u32) as *const u32 as *const _); + encoder.set_bytes( + 1, + std::mem::size_of::() as u64, + &(n as u32) as *const u32 as *const _, + ); + encoder.set_bytes( + 2, + std::mem::size_of::() as u64, + &(stage as u32) as *const u32 as *const _, + ); encoder.set_buffer(3, Some(twiddle_buffer), 0); - + let threads_per_stage = n / 2; - let thread_group_size = metal::MTLSize::new(256.min(threads_per_stage as u64), 1, 1); + let thread_group_size = + metal::MTLSize::new(256.min(threads_per_stage as u64), 1, 1); let grid_size = metal::MTLSize::new(threads_per_stage as u64, 1, 1); encoder.dispatch_threads(grid_size, thread_group_size); encoder.end_encoding(); } - + // Scale for inverse NTT if inverse { let inv_n = mod_inverse(n as u32, M31_P as u32); let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&self.intt_scale_pipeline); encoder.set_buffer(0, Some(&data_buffer), 0); - encoder.set_bytes(1, std::mem::size_of::() as u64, &(n as u32) as *const u32 as *const _); - encoder.set_bytes(2, std::mem::size_of::() as u64, &inv_n as *const u32 as *const _); - + encoder.set_bytes( + 1, + std::mem::size_of::() as u64, + &(n as u32) as *const u32 as *const _, + ); + encoder.set_bytes( + 2, + std::mem::size_of::() as u64, + &inv_n as *const u32 as *const _, + ); + let thread_group_size = metal::MTLSize::new(256, 1, 1); let grid_size = metal::MTLSize::new(n as u64, 1, 1); encoder.dispatch_threads(grid_size, thread_group_size); encoder.end_encoding(); } - + command_buffer.commit(); command_buffer.wait_until_completed(); - + // Copy results back let result_ptr = data_buffer.contents() as *const u32; unsafe { std::ptr::copy_nonoverlapping(result_ptr, values.as_mut_ptr(), n); } - + Ok(()) } } - + impl GpuBackend for MetalBackend { fn device(&self) -> &dyn GpuDevice { static DEVICE: std::sync::OnceLock = std::sync::OnceLock::new(); DEVICE.get_or_init(|| MetalDeviceWrapper::new().unwrap()) } - + fn ntt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { self.execute_ntt_gpu(values, log_n, false) } - + fn intt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { self.execute_ntt_gpu(values, log_n, true) } - + fn batch_evaluate( &self, coeffs: &[u32], @@ -541,7 +584,7 @@ mod native { ) -> Result<(), GpuError> { let num_points = points.len(); let num_coeffs = coeffs.len(); - + // Create buffers let coeffs_buffer = self.device.new_buffer_with_data( coeffs.as_ptr() as *const _, @@ -557,42 +600,42 @@ mod native { (num_points * 4) as u64, MTLResourceOptions::StorageModeShared, ); - + let command_buffer = self.command_queue.new_command_buffer(); let encoder = command_buffer.new_compute_command_encoder(); - + encoder.set_compute_pipeline_state(&self.poly_eval_pipeline); encoder.set_buffer(0, Some(&coeffs_buffer), 0); encoder.set_bytes(1, 4, &(num_coeffs as u32) as *const u32 as *const _); encoder.set_buffer(2, Some(&points_buffer), 0); encoder.set_buffer(3, Some(&results_buffer), 0); - + let thread_group_size = metal::MTLSize::new(256.min(num_points as u64), 1, 1); let grid_size = metal::MTLSize::new(num_points as u64, 1, 1); encoder.dispatch_threads(grid_size, thread_group_size); encoder.end_encoding(); - + command_buffer.commit(); command_buffer.wait_until_completed(); - + // Copy results let result_ptr = results_buffer.contents() as *const u32; unsafe { std::ptr::copy_nonoverlapping(result_ptr, results.as_mut_ptr(), num_points); } - + Ok(()) } - + fn merkle_tree(&self, leaves: &[[u8; 32]]) -> Result, GpuError> { // Use CPU fallback for now - Blake3 GPU implementation is complex cpu_merkle_tree(leaves) } - + fn lde(&self, coeffs: &[u32], blowup_factor: usize) -> Result, GpuError> { let n = coeffs.len(); let extended_n = n * blowup_factor; - + // Generate extended domain let mut domain = Vec::with_capacity(extended_n); let generator = 3u32; @@ -601,18 +644,18 @@ mod native { domain.push(point); point = m31_mul(point, generator); } - + let mut results = vec![0u32; extended_n]; self.batch_evaluate(coeffs, &domain, &mut results)?; Ok(results) } } - + pub struct MetalDeviceWrapper { name: String, memory_bytes: usize, } - + impl MetalDeviceWrapper { pub fn new() -> Result { let device = Device::system_default() @@ -623,29 +666,41 @@ mod native { }) } } - + impl GpuDevice for MetalDeviceWrapper { - fn device_type(&self) -> DeviceType { DeviceType::Metal } - fn name(&self) -> &str { &self.name } + fn device_type(&self) -> DeviceType { + DeviceType::Metal + } + fn name(&self) -> &str { + &self.name + } fn allocate(&self, size: usize) -> Result, GpuError> { Ok(Box::new(MetalMemory::new(size))) } - fn synchronize(&self) -> Result<(), GpuError> { Ok(()) } - fn available_memory(&self) -> usize { self.memory_bytes } + fn synchronize(&self) -> Result<(), GpuError> { + Ok(()) + } + fn available_memory(&self) -> usize { + self.memory_bytes + } } - + pub struct MetalMemory { data: Vec, } - + impl MetalMemory { pub fn new(size: usize) -> Self { - Self { data: vec![0u8; size] } + Self { + data: vec![0u8; size], + } } } - + impl GpuMemory for MetalMemory { - fn size(&self) -> usize { self.data.len() } + fn size(&self) -> usize { + self.data.len() + } fn copy_from_host(&mut self, data: &[u8]) -> Result<(), GpuError> { self.data[..data.len()].copy_from_slice(data); Ok(()) @@ -654,8 +709,12 @@ mod native { data.copy_from_slice(&self.data[..data.len()]); Ok(()) } - fn as_ptr(&self) -> *const u8 { self.data.as_ptr() } - fn as_mut_ptr(&mut self) -> *mut u8 { self.data.as_mut_ptr() } + fn as_ptr(&self) -> *const u8 { + self.data.as_ptr() + } + fn as_mut_ptr(&mut self) -> *mut u8 { + self.data.as_mut_ptr() + } } } @@ -684,28 +743,29 @@ mod fallback { max_log_n: 24, }) } - + pub fn precompute_twiddles(&mut self, log_n: usize) -> Result<(), GpuError> { if log_n > self.max_log_n { - return Err(GpuError::NotSupported( - format!("log_n {} exceeds maximum {}", log_n, self.max_log_n) - )); + return Err(GpuError::NotSupported(format!( + "log_n {} exceeds maximum {}", + log_n, self.max_log_n + ))); } - + let n = 1usize << log_n; const M31_P: u64 = (1u64 << 31) - 1; - + let generator = 5u64; let order = M31_P - 1; let step = order / (n as u64); - + self.twiddles = Vec::with_capacity(n); let mut w = 1u64; for _ in 0..n { self.twiddles.push(w as u32); w = (w * pow_mod(generator, step, M31_P)) % M31_P; } - + self.inv_twiddles = self.twiddles.iter().rev().cloned().collect(); Ok(()) } @@ -716,15 +776,15 @@ mod fallback { static DEVICE: std::sync::OnceLock = std::sync::OnceLock::new(); DEVICE.get_or_init(|| MetalDevice::new().unwrap()) } - + fn ntt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { cpu_ntt(values, log_n, &self.twiddles, false) } - + fn intt_m31(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { cpu_ntt(values, log_n, &self.inv_twiddles, true) } - + fn batch_evaluate( &self, coeffs: &[u32], @@ -733,11 +793,11 @@ mod fallback { ) -> Result<(), GpuError> { cpu_batch_evaluate(coeffs, points, results) } - + fn merkle_tree(&self, leaves: &[[u8; 32]]) -> Result, GpuError> { cpu_merkle_tree(leaves) } - + fn lde(&self, coeffs: &[u32], blowup_factor: usize) -> Result, GpuError> { cpu_lde(coeffs, blowup_factor) } @@ -770,13 +830,21 @@ mod fallback { } impl GpuDevice for MetalDevice { - fn device_type(&self) -> DeviceType { DeviceType::Metal } - fn name(&self) -> &str { &self.name } + fn device_type(&self) -> DeviceType { + DeviceType::Metal + } + fn name(&self) -> &str { + &self.name + } fn allocate(&self, size: usize) -> Result, GpuError> { Ok(Box::new(MetalMemory::new(size))) } - fn synchronize(&self) -> Result<(), GpuError> { Ok(()) } - fn available_memory(&self) -> usize { self.memory_bytes } + fn synchronize(&self) -> Result<(), GpuError> { + Ok(()) + } + fn available_memory(&self) -> usize { + self.memory_bytes + } } pub struct MetalMemory { @@ -785,12 +853,16 @@ mod fallback { impl MetalMemory { pub fn new(size: usize) -> Self { - Self { data: vec![0u8; size] } + Self { + data: vec![0u8; size], + } } } impl GpuMemory for MetalMemory { - fn size(&self) -> usize { self.data.len() } + fn size(&self) -> usize { + self.data.len() + } fn copy_from_host(&mut self, data: &[u8]) -> Result<(), GpuError> { self.data[..data.len()].copy_from_slice(data); Ok(()) @@ -799,8 +871,12 @@ mod fallback { data.copy_from_slice(&self.data[..data.len()]); Ok(()) } - fn as_ptr(&self) -> *const u8 { self.data.as_ptr() } - fn as_mut_ptr(&mut self) -> *mut u8 { self.data.as_mut_ptr() } + fn as_ptr(&self) -> *const u8 { + self.data.as_ptr() + } + fn as_mut_ptr(&mut self) -> *mut u8 { + self.data.as_mut_ptr() + } } } @@ -823,12 +899,20 @@ const M31_P: u32 = (1u32 << 31) - 1; #[inline] fn m31_add(a: u32, b: u32) -> u32 { let sum = a.wrapping_add(b); - if sum >= M31_P { sum - M31_P } else { sum } + if sum >= M31_P { + sum - M31_P + } else { + sum + } } #[inline] fn m31_sub(a: u32, b: u32) -> u32 { - if a >= b { a - b } else { M31_P - b + a } + if a >= b { + a - b + } else { + M31_P - b + a + } } #[inline] @@ -837,7 +921,11 @@ fn m31_mul(a: u32, b: u32) -> u32 { let lo = (prod & (M31_P as u64)) as u32; let hi = (prod >> 31) as u32; let sum = lo.wrapping_add(hi); - if sum >= M31_P { sum - M31_P } else { sum } + if sum >= M31_P { + sum - M31_P + } else { + sum + } } #[inline] @@ -849,7 +937,7 @@ fn pow_mod(base: u64, exp: u64, modulus: u64) -> u64 { let mut result = 1u64; let mut base = base % modulus; let mut exp = exp; - + while exp > 0 { if exp & 1 == 1 { result = (result * base) % modulus; @@ -865,13 +953,13 @@ fn mod_inverse(a: u32, m: u32) -> u32 { let mut r = a as i64; let mut old_s = 0i64; let mut s = 1i64; - + while r != 0 { let q = old_r / r; (old_r, r) = (r, old_r - q * r); (old_s, s) = (s, old_s - q * s); } - + if old_s < 0 { (old_s + m as i64) as u32 } else { @@ -879,7 +967,12 @@ fn mod_inverse(a: u32, m: u32) -> u32 { } } -fn cpu_ntt(values: &mut [u32], log_n: usize, twiddles: &[u32], inverse: bool) -> Result<(), GpuError> { +fn cpu_ntt( + values: &mut [u32], + log_n: usize, + twiddles: &[u32], + inverse: bool, +) -> Result<(), GpuError> { let n = 1usize << log_n; if values.len() != n { return Err(GpuError::InvalidBufferSize { @@ -887,7 +980,7 @@ fn cpu_ntt(values: &mut [u32], log_n: usize, twiddles: &[u32], inverse: bool) -> actual: values.len(), }); } - + // Bit-reversal permutation for i in 0..n { let j = bit_reverse(i, log_n); @@ -895,27 +988,27 @@ fn cpu_ntt(values: &mut [u32], log_n: usize, twiddles: &[u32], inverse: bool) -> values.swap(i, j); } } - + // Butterfly stages for stage in 0..log_n { let half_step = 1usize << stage; let step = half_step << 1; - + for group in (0..n).step_by(step) { for pos in 0..half_step { let i = group + pos; let j = i + half_step; - + let w = twiddles.get(pos * (n / step)).copied().unwrap_or(1); let u = values[i]; let v = m31_mul(values[j], w); - + values[i] = m31_add(u, v); values[j] = m31_sub(u, v); } } } - + // Scale for inverse if inverse { let inv_n = mod_inverse(n as u32, M31_P); @@ -923,7 +1016,7 @@ fn cpu_ntt(values: &mut [u32], log_n: usize, twiddles: &[u32], inverse: bool) -> *v = m31_mul(*v, inv_n); } } - + Ok(()) } @@ -946,11 +1039,11 @@ fn cpu_merkle_tree(leaves: &[[u8; 32]]) -> Result, GpuError> { actual: n, }); } - + let tree_size = 2 * n - 1; let mut tree = vec![[0u8; 32]; tree_size]; tree[n - 1..].copy_from_slice(leaves); - + for i in (0..n - 1).rev() { let mut hash = [0u8; 32]; let left_idx = 2 * i + 1; @@ -962,7 +1055,7 @@ fn cpu_merkle_tree(leaves: &[[u8; 32]]) -> Result, GpuError> { } tree[i] = hash; } - + Ok(tree) } @@ -970,10 +1063,10 @@ fn cpu_lde(coeffs: &[u32], blowup_factor: usize) -> Result, GpuError> { let n = coeffs.len(); let extended_n = n * blowup_factor; let mut results = vec![0u32; extended_n]; - + let generator = 3u32; let mut point = generator; - + for i in 0..extended_n { let mut result = 0u32; for &coeff in coeffs.iter().rev() { @@ -982,7 +1075,7 @@ fn cpu_lde(coeffs: &[u32], blowup_factor: usize) -> Result, GpuError> { results[i] = result; point = m31_mul(point, generator); } - + Ok(results) } @@ -996,49 +1089,51 @@ impl Default for native::MetalBackend { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_m31_arithmetic() { assert_eq!(m31_add(100, 200), 300); assert_eq!(m31_add(M31_P - 1, 2), 1); - + assert_eq!(m31_sub(200, 100), 100); - + assert_eq!(m31_mul(2, 3), 6); } - + #[test] fn test_bit_reverse() { assert_eq!(bit_reverse(0b000, 3), 0b000); assert_eq!(bit_reverse(0b001, 3), 0b100); } - + #[test] fn test_mod_inverse() { let inv = mod_inverse(3, M31_P); assert_eq!(m31_mul(3, inv), 1); } - + #[cfg(target_os = "macos")] #[test] fn test_metal_backend_creation() { let backend = MetalBackend::new(); assert!(backend.is_ok()); } - + #[cfg(target_os = "macos")] #[test] fn test_batch_evaluate() { let backend = MetalBackend::new().unwrap(); - + let coeffs = vec![1, 2, 3]; let points = vec![0, 1, 2]; let mut results = vec![0u32; 3]; - - backend.batch_evaluate(&coeffs, &points, &mut results).unwrap(); - - assert_eq!(results[0], 1); // 1 + 0 + 0 - assert_eq!(results[1], 6); // 1 + 2 + 3 + + backend + .batch_evaluate(&coeffs, &points, &mut results) + .unwrap(); + + assert_eq!(results[0], 1); // 1 + 0 + 0 + assert_eq!(results[1], 6); // 1 + 2 + 3 assert_eq!(results[2], 17); // 1 + 4 + 12 } } diff --git a/crates/prover/src/gpu/mod.rs b/crates/prover/src/gpu/mod.rs index b4cce00..a4e269b 100644 --- a/crates/prover/src/gpu/mod.rs +++ b/crates/prover/src/gpu/mod.rs @@ -33,13 +33,15 @@ pub mod metal; pub mod cuda; -pub use backend::{GpuBackend, GpuDevice, GpuError, GpuMemory, CpuBackend}; -pub use operations::{GpuNtt, GpuPolynomial, GpuMerkle}; +pub use backend::{CpuBackend, GpuBackend, GpuDevice, GpuError, GpuMemory}; +pub use operations::{GpuMerkle, GpuNtt, GpuPolynomial}; #[cfg(target_os = "macos")] pub use metal::{MetalBackend, MetalDevice, MetalMemory, METAL_M31_SHADERS}; -pub use cuda::{CudaBackend, CudaDevice, CudaMemory, CudaDeviceInfo, CUDA_M31_KERNELS, query_cuda_devices}; +pub use cuda::{ + query_cuda_devices, CudaBackend, CudaDevice, CudaDeviceInfo, CudaMemory, CUDA_M31_KERNELS, +}; /// GPU device type enumeration. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -65,7 +67,7 @@ impl std::fmt::Display for DeviceType { /// Detect available GPU devices on the system. pub fn detect_devices() -> Vec { let mut devices = Vec::new(); - + // Check for Metal (macOS) #[cfg(target_os = "macos")] { @@ -77,7 +79,7 @@ pub fn detect_devices() -> Vec { available: true, }); } - + // CPU fallback is always available devices.push(DeviceInfo { device_type: DeviceType::Cpu, @@ -86,7 +88,7 @@ pub fn detect_devices() -> Vec { memory_bytes: 0, available: true, }); - + devices } @@ -112,9 +114,9 @@ fn num_cpus() -> usize { } /// Get the best available GPU backend. -/// +/// /// Priority: Metal (macOS) > CUDA > CPU -/// +/// /// # Returns /// A boxed GpuBackend implementation. pub fn get_backend() -> Result, GpuError> { @@ -126,7 +128,7 @@ pub fn get_backend() -> Result, GpuError> { Err(_) => {} // Fall through to next option } } - + // Try CUDA (would fail without cuda feature) #[cfg(feature = "cuda")] { @@ -135,7 +137,7 @@ pub fn get_backend() -> Result, GpuError> { Err(_) => {} // Fall through to CPU } } - + // CPU fallback Ok(Box::new(CpuBackend::default())) } @@ -145,12 +147,12 @@ pub fn get_backend_for_device(device_type: DeviceType) -> Result Ok(Box::new(MetalBackend::new()?)), - + #[cfg(not(target_os = "macos"))] DeviceType::Metal => Err(GpuError::DeviceNotAvailable( - "Metal is only available on macOS".to_string() + "Metal is only available on macOS".to_string(), )), - + DeviceType::Cuda => { #[cfg(feature = "cuda")] { @@ -159,11 +161,11 @@ pub fn get_backend_for_device(device_type: DeviceType) -> Result Ok(Box::new(CpuBackend::default())), } } @@ -171,33 +173,33 @@ pub fn get_backend_for_device(device_type: DeviceType) -> Result Result<(), GpuError>; - + /// Perform inverse NTT in-place. fn intt_inplace(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError>; - + /// Perform forward NTT, returning new array. fn ntt(&self, values: &[u32], log_n: usize) -> Result, GpuError> { let mut result = values.to_vec(); self.ntt_inplace(&mut result, log_n)?; Ok(result) } - + /// Perform inverse NTT, returning new array. fn intt(&self, values: &[u32], log_n: usize) -> Result, GpuError> { let mut result = values.to_vec(); self.intt_inplace(&mut result, log_n)?; Ok(result) } - + /// Batch NTT on multiple polynomials. fn batch_ntt(&self, polys: &mut [Vec], log_n: usize) -> Result<(), GpuError> { for poly in polys.iter_mut() { @@ -37,23 +37,19 @@ pub trait GpuNtt { pub trait GpuPolynomial { /// Multiply two polynomials using NTT. fn poly_mul(&self, a: &[u32], b: &[u32]) -> Result, GpuError>; - + /// Add two polynomials. fn poly_add(&self, a: &[u32], b: &[u32]) -> Result, GpuError>; - + /// Evaluate polynomial at a single point. fn poly_eval(&self, coeffs: &[u32], point: u32) -> Result; - + /// Evaluate polynomial at multiple points. - fn poly_eval_batch( - &self, - coeffs: &[u32], - points: &[u32], - ) -> Result, GpuError>; - + fn poly_eval_batch(&self, coeffs: &[u32], points: &[u32]) -> Result, GpuError>; + /// Low Degree Extension. fn lde(&self, coeffs: &[u32], blowup_factor: usize) -> Result, GpuError>; - + /// Interpolate polynomial from evaluations. fn interpolate(&self, evaluations: &[u32]) -> Result, GpuError>; } @@ -62,13 +58,13 @@ pub trait GpuPolynomial { pub trait GpuMerkle { /// Compute Merkle root from leaves. fn merkle_root(&self, leaves: &[[u8; 32]]) -> Result<[u8; 32], GpuError>; - + /// Build full Merkle tree from leaves. fn merkle_tree(&self, leaves: &[[u8; 32]]) -> Result, GpuError>; - + /// Compute Merkle path for a leaf. fn merkle_path(&self, tree: &[[u8; 32]], leaf_index: usize) -> Result, GpuError>; - + /// Verify Merkle path. fn verify_merkle_path( &self, @@ -85,7 +81,7 @@ impl GpuNtt for T { fn ntt_inplace(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { self.ntt_m31(values, log_n) } - + fn intt_inplace(&self, values: &mut [u32], log_n: usize) -> Result<(), GpuError> { self.intt_m31(values, log_n) } @@ -94,67 +90,71 @@ impl GpuNtt for T { impl GpuPolynomial for T { fn poly_mul(&self, a: &[u32], b: &[u32]) -> Result, GpuError> { use zp1_primitives::field::M31; - + let n = (a.len() + b.len() - 1).next_power_of_two(); let log_n = n.trailing_zeros() as usize; - + let mut a_ext = vec![0u32; n]; let mut b_ext = vec![0u32; n]; a_ext[..a.len()].copy_from_slice(a); b_ext[..b.len()].copy_from_slice(b); - + // Transform to NTT domain self.ntt_m31(&mut a_ext, log_n)?; self.ntt_m31(&mut b_ext, log_n)?; - + // Pointwise multiply for i in 0..n { let av = M31::new(a_ext[i]); let bv = M31::new(b_ext[i]); a_ext[i] = (av * bv).value(); } - + // Transform back self.intt_m31(&mut a_ext, log_n)?; - + Ok(a_ext) } - + fn poly_add(&self, a: &[u32], b: &[u32]) -> Result, GpuError> { use zp1_primitives::field::M31; - + let n = a.len().max(b.len()); let mut result = vec![0u32; n]; - + for i in 0..n { - let av = if i < a.len() { M31::new(a[i]) } else { M31::ZERO }; - let bv = if i < b.len() { M31::new(b[i]) } else { M31::ZERO }; + let av = if i < a.len() { + M31::new(a[i]) + } else { + M31::ZERO + }; + let bv = if i < b.len() { + M31::new(b[i]) + } else { + M31::ZERO + }; result[i] = (av + bv).value(); } - + Ok(result) } - + fn poly_eval(&self, coeffs: &[u32], point: u32) -> Result { let mut results = vec![0u32; 1]; self.batch_evaluate(coeffs, &[point], &mut results)?; Ok(results[0]) } - - fn poly_eval_batch( - &self, - coeffs: &[u32], - points: &[u32], - ) -> Result, GpuError> { + + fn poly_eval_batch(&self, coeffs: &[u32], points: &[u32]) -> Result, GpuError> { let mut results = vec![0u32; points.len()]; self.batch_evaluate(coeffs, points, &mut results)?; Ok(results) } - + fn lde(&self, coeffs: &[u32], blowup_factor: usize) -> Result, GpuError> { GpuBackend::lde(self, coeffs, blowup_factor) } - + fn interpolate(&self, evaluations: &[u32]) -> Result, GpuError> { // For now, use the INTT as interpolation let log_n = evaluations.len().trailing_zeros() as usize; @@ -168,15 +168,18 @@ impl GpuMerkle for T { fn merkle_root(&self, leaves: &[[u8; 32]]) -> Result<[u8; 32], GpuError> { let tree = self.merkle_tree(leaves)?; if tree.is_empty() { - return Err(GpuError::InvalidBufferSize { expected: 1, actual: 0 }); + return Err(GpuError::InvalidBufferSize { + expected: 1, + actual: 0, + }); } Ok(tree[0]) } - + fn merkle_tree(&self, leaves: &[[u8; 32]]) -> Result, GpuError> { GpuBackend::merkle_tree(self, leaves) } - + fn merkle_path(&self, tree: &[[u8; 32]], leaf_index: usize) -> Result, GpuError> { let n = (tree.len() + 1) / 2; // Number of leaves if leaf_index >= n { @@ -185,10 +188,10 @@ impl GpuMerkle for T { actual: leaf_index, }); } - + let mut path = Vec::new(); let mut idx = tree.len() - n + leaf_index; - + while idx > 0 { let sibling = if idx % 2 == 0 { idx - 1 } else { idx + 1 }; if sibling < tree.len() { @@ -196,10 +199,10 @@ impl GpuMerkle for T { } idx = (idx - 1) / 2; } - + Ok(path) } - + fn verify_merkle_path( &self, root: &[u8; 32], @@ -207,11 +210,11 @@ impl GpuMerkle for T { path: &[[u8; 32]], leaf_index: usize, ) -> Result { - use sha2::{Sha256, Digest}; - + use sha2::{Digest, Sha256}; + let mut current = *leaf; let mut idx = leaf_index; - + for sibling in path { let mut hasher = Sha256::new(); if idx % 2 == 0 { @@ -224,7 +227,7 @@ impl GpuMerkle for T { current.copy_from_slice(&hasher.finalize()); idx /= 2; } - + Ok(current == *root) } } @@ -233,32 +236,32 @@ impl GpuMerkle for T { mod tests { use super::*; use crate::gpu::backend::CpuBackend; - + #[test] fn test_poly_add() { let backend = CpuBackend::new(); - + let a = vec![1, 2, 3]; let b = vec![4, 5]; - + let result = backend.poly_add(&a, &b).unwrap(); assert_eq!(result, vec![5, 7, 3]); } - + #[test] fn test_poly_eval() { let backend = CpuBackend::new(); - + // Polynomial: 1 + 2x let coeffs = vec![1, 2]; let result = GpuPolynomial::poly_eval(&backend, &coeffs, 5).unwrap(); assert_eq!(result, 11); // 1 + 2*5 = 11 } - + #[test] fn test_merkle_operations() { let backend = CpuBackend::new(); - + // Create some leaves let leaves: Vec<[u8; 32]> = (0..4) .map(|i| { @@ -267,10 +270,10 @@ mod tests { leaf }) .collect(); - + let tree = GpuBackend::merkle_tree(&backend, &leaves).unwrap(); assert!(!tree.is_empty()); - + let root = GpuMerkle::merkle_root(&backend, &leaves).unwrap(); assert_eq!(root, tree[0]); }