diff --git a/CHANGELOG.md b/CHANGELOG.md index 4cd83e55..b4289589 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Optimize `from_str_radix` ([#557]) + +[#557]: https://github.com/recmo/uint/pull/557 + ## [1.17.2] - 2025-12-28 ### Fixed diff --git a/src/algorithms/mul_redc.rs b/src/algorithms/mul_redc.rs index 05ba4a67..69fcfa33 100644 --- a/src/algorithms/mul_redc.rs +++ b/src/algorithms/mul_redc.rs @@ -21,36 +21,33 @@ pub fn mul_redc(a: [u64; N], b: [u64; N], modulus: [u64; N], inv // See let mut result = [0; N]; let mut carry = false; + let has_top_carry = modulus[N - 1] >= 0x7fff_ffff_ffff_ffff; for b in b { - let mut m = 0; - let mut carry_1 = 0; - let mut carry_2 = 0; - for i in 0..N { - // Add limb product + let mut carry_1; + let mut carry_2; + + // i = 0: compute initial value and reduction factor. + let (value, next_carry) = carrying_mul_add(a[0], b, result[0], 0); + carry_1 = next_carry; + let m = value.wrapping_mul(inv); + let (value, next_carry) = carrying_mul_add(modulus[0], m, value, 0); + carry_2 = next_carry; + debug_assert_eq!(value, 0); + + // i = 1..N + for i in 1..N { let (value, next_carry) = carrying_mul_add(a[i], b, result[i], carry_1); carry_1 = next_carry; - if i == 0 { - // Compute reduction factor - m = value.wrapping_mul(inv); - } - - // Add m * modulus to acc to clear next_result[0] let (value, next_carry) = carrying_mul_add(modulus[i], m, value, carry_2); carry_2 = next_carry; - // Shift result - if i > 0 { - result[i - 1] = value; - } else { - debug_assert_eq!(value, 0); - } + result[i - 1] = value; } - // Add carries let (value, next_carry) = carrying_add(carry_1, carry_2, carry); result[N - 1] = value; - if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff { + if has_top_carry { carry = next_carry; } else { debug_assert!(!next_carry); @@ -74,8 +71,8 @@ pub fn square_redc(a: [u64; N], modulus: [u64; N], inv: u64) -> 
let mut result = [0; N]; let mut carry_outer = 0; + let has_top_carry = modulus[N - 1] >= 0x3fff_ffff_ffff_ffff; for i in 0..N { - // Add limb product let (value, mut carry_lo) = carrying_mul_add(a[i], a[i], result[i], 0); let mut carry_hi = false; result[i] = value; @@ -87,7 +84,6 @@ pub fn square_redc(a: [u64; N], modulus: [u64; N], inv: u64) -> carry_hi = next_carry_hi; } - // Add m times modulus to result and shift one limb let m = result[0].wrapping_mul(inv); let (value, mut carry) = carrying_mul_add(m, modulus[0], result[0], 0); debug_assert_eq!(value, 0); @@ -97,19 +93,16 @@ pub fn square_redc(a: [u64; N], modulus: [u64; N], inv: u64) -> carry = next_carry; } - // Add carries - if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff { + if has_top_carry { let wide = (carry_outer as u128) .wrapping_add(carry_lo as u128) .wrapping_add((carry_hi as u128) << 64) .wrapping_add(carry as u128); result[N - 1] = wide as u64; - // Note carry_outer can be {0, 1, 2}. carry_outer = (wide >> 64) as u64; debug_assert!(carry_outer <= 2); } else { - // `carry_outer` and `carry_hi` are always zero. debug_assert!(!carry_hi); debug_assert_eq!(carry_outer, 0); let (value, carry) = carry_lo.overflowing_add(carry); diff --git a/src/base_convert.rs b/src/base_convert.rs index f88af451..a615e014 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -193,35 +193,64 @@ impl Uint { base: u64, digits: I, ) -> Result { - // OPT: Special handling of bases that divide 2^64, and bases that are - // powers of 2. - // OPT: Same trick as with `to_base_le`, find the largest power of base - // that fits `u64` and accumulate there first. 
if base < 2 { return Err(BaseConvertError::InvalidBase(base)); } + let chunk_base = crate::utils::max_pow_u64(base); + let mut chunk_power: usize = 1; + { + let mut p = base; + while p != chunk_base { + p *= base; + chunk_power += 1; + } + } + let mut result = Self::ZERO; + let mut chunk_val: u64 = 0; + let mut chunk_digits: usize = 0; for digit in digits { if digit >= base { return Err(BaseConvertError::InvalidDigit(digit, base)); } - // Multiply by base. - // OPT: keep track of non-zero limbs and mul the minimum. - let mut carry = u128::from(digit); - #[allow(clippy::cast_possible_truncation)] - for limb in &mut result.limbs { - carry += u128::from(*limb) * u128::from(base); - *limb = carry as u64; - carry >>= 64; + chunk_val = chunk_val * base + digit; + chunk_digits += 1; + if chunk_digits == chunk_power { + Self::from_base_muladd(&mut result, chunk_base, chunk_val)?; + chunk_val = 0; + chunk_digits = 0; } - if carry > 0 || (LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK) { - return Err(BaseConvertError::Overflow); + } + if chunk_digits > 0 { + let mut tail_base = base; + for _ in 1..chunk_digits { + tail_base *= base; } + Self::from_base_muladd(&mut result, tail_base, chunk_val)?; } Ok(result) } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn from_base_muladd( + result: &mut Self, + factor: u64, + addend: u64, + ) -> Result<(), BaseConvertError> { + let mut carry = u128::from(addend); + for limb in &mut result.limbs { + carry += u128::from(*limb) * u128::from(factor); + *limb = carry as u64; + carry >>= 64; + } + if carry > 0 || (LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK) { + return Err(BaseConvertError::Overflow); + } + Ok(()) + } } struct SpigotLittle { diff --git a/src/bits.rs b/src/bits.rs index fef5179d..d0c1adef 100644 --- a/src/bits.rs +++ b/src/bits.rs @@ -327,6 +327,38 @@ impl Uint { #[inline] #[must_use] pub const fn overflowing_shl(self, rhs: usize) -> (Self, bool) { + if LIMBS == 1 { + let (limbs, bits) = (rhs / 
64, rhs % 64); + if limbs >= 1 { + return (Self::ZERO, self.limbs[0] != 0); + } + let x = self.limbs[0]; + let carry = (x >> (63 - bits)) >> 1; + let mut r = Self::ZERO; + r.limbs[0] = (x << bits) & Self::MASK; + return (r, carry != 0); + } + if LIMBS == 2 { + let (limbs, bits) = (rhs / 64, rhs % 64); + if limbs >= 2 { + return (Self::ZERO, !self.const_is_zero()); + } + let val = self.as_double_words()[0].get(); + let shifted = val << bits; + if limbs == 0 { + let carry = (val >> (127 - bits)) >> 1; + let mut r = Self::ZERO; + r.limbs[0] = shifted as u64; + r.limbs[1] = (shifted >> 64) as u64 & Self::MASK; + return (r, carry != 0); + } + let x = self.limbs[0] as u128; + let carry = (x >> (63 - bits)) >> 1; + let mut r = Self::ZERO; + r.limbs[1] = (x << bits) as u64 & Self::MASK; + return (r, carry != 0); + } + let (limbs, bits) = (rhs / 64, rhs % 64); if limbs >= LIMBS { return (Self::ZERO, !self.const_is_zero()); @@ -410,6 +442,38 @@ impl Uint { #[inline] #[must_use] pub const fn overflowing_shr(self, rhs: usize) -> (Self, bool) { + if LIMBS == 1 { + let (limbs, bits) = (rhs / 64, rhs % 64); + if limbs >= 1 { + return (Self::ZERO, self.limbs[0] != 0); + } + let x = self.limbs[0]; + let carry = (x << (63 - bits)) << 1; + let mut r = Self::ZERO; + r.limbs[0] = x >> bits; + return (r, carry != 0); + } + if LIMBS == 2 { + let (limbs, bits) = (rhs / 64, rhs % 64); + if limbs >= 2 { + return (Self::ZERO, !self.const_is_zero()); + } + let val = self.as_double_words()[0].get(); + if limbs == 0 { + let carry = (val << (127 - bits)) << 1; + let shifted = val >> bits; + let mut r = Self::ZERO; + r.limbs[0] = shifted as u64; + r.limbs[1] = (shifted >> 64) as u64; + return (r, carry != 0); + } + let x = self.limbs[1]; + let carry = (x << (63 - bits)) << 1; + let mut r = Self::ZERO; + r.limbs[0] = x >> bits; + return (r, carry != 0); + } + let (limbs, bits) = (rhs / 64, rhs % 64); if limbs >= LIMBS { return (Self::ZERO, !self.const_is_zero()); diff --git a/src/cmp.rs 
b/src/cmp.rs index 47ede059..5f605364 100644 --- a/src/cmp.rs +++ b/src/cmp.rs @@ -11,6 +11,14 @@ impl PartialOrd for Uint { impl Ord for Uint { #[inline] fn cmp(&self, rhs: &Self) -> Ordering { + if LIMBS == 1 { + return self.limbs[0].cmp(&rhs.limbs[0]); + } + if LIMBS == 2 { + return self.as_double_words()[0] + .get() + .cmp(&rhs.as_double_words()[0].get()); + } crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs()) } } diff --git a/src/div.rs b/src/div.rs index 1f1f77b4..d4c930c9 100644 --- a/src/div.rs +++ b/src/div.rs @@ -50,6 +50,8 @@ impl Uint { let q = &mut self.limbs[0]; let r = &mut rhs.limbs[0]; (*q, *r) = algorithms::div::div_1x1(*q, *r); + } else if LIMBS <= 4 { + algorithms::div::div_inlined(&mut self.limbs, &mut rhs.limbs); } else { Self::div_rem_by_ref(&mut self, &mut rhs); } diff --git a/src/fmt.rs b/src/fmt.rs index ae267651..ce8b084b 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -83,10 +83,51 @@ impl fmt::Debug for Uint { } impl_fmt!(fmt::Display; base::Decimal, ""); -impl_fmt!(fmt::Binary; base::Binary, "b"); -impl_fmt!(fmt::Octal; base::Octal, "o"); -impl_fmt!(fmt::LowerHex; base::Hexadecimal, "x"); -impl_fmt!(fmt::UpperHex; base::Hexadecimal, "X"); + +macro_rules! 
impl_fmt_pow2 { + ($tr:path; $base:ty, $bits_per_digit:literal, $upper:literal) => { + impl $tr for Uint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Ok(small) = u64::try_from(self) { + return ::fmt(&small, f); + } + if let Ok(small) = u128::try_from(self) { + return ::fmt(&small, f); + } + + let alphabet: &[u8; 16] = if $upper { + b"0123456789ABCDEF" + } else { + b"0123456789abcdef" + }; + let mask: u64 = (1 << $bits_per_digit) - 1; + + let bit_len = self.bit_len(); + let total_digits = bit_len.div_ceil($bits_per_digit); + + let mut s = StackString::::new(); + let mut i = total_digits; + while i > 0 { + i -= 1; + let bit_offset = i * $bits_per_digit; + let limb_idx = bit_offset / 64; + let bit_idx = bit_offset % 64; + let mut digit = (self.limbs[limb_idx] >> bit_idx) & mask; + if bit_idx + $bits_per_digit > 64 && limb_idx + 1 < LIMBS { + digit |= (self.limbs[limb_idx + 1] << (64 - bit_idx)) & mask; + } + s.push_byte(alphabet[digit as usize]); + } + f.pad_integral(true, <$base>::PREFIX, s.as_str()) + } + } + }; +} + +impl_fmt_pow2!(fmt::Binary; base::Binary, 1, false); +impl_fmt_pow2!(fmt::Octal; base::Octal, 3, false); +impl_fmt_pow2!(fmt::LowerHex; base::Hexadecimal, 4, false); +impl_fmt_pow2!(fmt::UpperHex; base::Hexadecimal, 4, true); /// A stack-allocated buffer that implements [`fmt::Write`]. pub(crate) struct StackString { @@ -115,6 +156,13 @@ impl StackString { const fn as_bytes(&self) -> &[u8] { unsafe { core::slice::from_raw_parts(self.buf.as_ptr().cast(), self.len) } } + + #[inline] + fn push_byte(&mut self, b: u8) { + debug_assert!(self.len < SIZE); + unsafe { self.buf.as_mut_ptr().add(self.len).cast::().write(b) }; + self.len += 1; + } } impl fmt::Write for StackString { diff --git a/src/modular.rs b/src/modular.rs index 1beccaca..eff4bcbb 100644 --- a/src/modular.rs +++ b/src/modular.rs @@ -58,7 +58,11 @@ impl Uint { // Reuse `div_rem` if we don't need an extra limb. 
if const { crate::nlimbs(BITS + 1) == LIMBS } { let numerator = unsafe { &mut *numerator.as_mut_ptr().cast::() }; - Self::div_rem_by_ref(numerator, &mut modulus); + if LIMBS <= 4 { + algorithms::div::div_inlined(&mut numerator.limbs, &mut modulus.limbs); + } else { + Self::div_rem_by_ref(numerator, &mut modulus); + } } else { Self::div_rem_bits_plus_one(numerator.as_mut_ptr(), &mut modulus); } diff --git a/src/string.rs b/src/string.rs index cddfbcf2..d4c046e4 100644 --- a/src/string.rs +++ b/src/string.rs @@ -1,6 +1,6 @@ #![allow(clippy::missing_inline_in_public_items)] // allow format functions -use crate::{Uint, base_convert::BaseConvertError}; +use crate::{Uint, algorithms::DoubleWord, base_convert::BaseConvertError}; use core::{fmt, str::FromStr}; /// Error for [`from_str_radix`](Uint::from_str_radix). @@ -44,6 +44,42 @@ impl fmt::Display for ParseError { } } +/// Returns `(base, power)` where `base = radix^power` is the largest power of +/// `radix` that fits in a `u64`. +const fn radix_base(radix: u64) -> (u64, usize) { + debug_assert!(radix >= 2); + let mut power: usize = 1; + let mut base = radix; + loop { + match base.checked_mul(radix) { + Some(n) => { + base = n; + power += 1; + } + None => return (base, power), + } + } +} + +/// Decode an ASCII byte as a digit for radix <= 36. +/// Case-insensitive 0-9, a-z. Underscores are skipped. +#[inline(always)] +fn decode_digit(b: u8, radix: u64) -> Result, ParseError> { + let digit = match b { + b'0'..=b'9' => b - b'0', + b'a'..=b'z' => b - b'a' + 10, + b'A'..=b'Z' => b - b'A' + 10, + b'_' => return Ok(None), + _ => return Err(ParseError::InvalidDigit(b as char)), + }; + let digit = u64::from(digit); + if digit < radix { + Ok(Some(digit)) + } else { + Err(ParseError::InvalidDigit(b as char)) + } +} + impl Uint { /// Parse a string into a [`Uint`]. /// @@ -59,40 +95,40 @@ impl Uint { /// * [`ParseError::BaseConvertError`] if [`Uint::from_base_be`] fails. // FEATURE: Support proper unicode. 
Ignore zero-width spaces, joiners, etc. // Recognize digits from other alphabets. + #[inline] pub fn from_str_radix(src: &str, radix: u64) -> Result { - if radix > 64 { - return Err(ParseError::InvalidRadix(radix)); + match radix { + // Specialize for the common cases. + 2 => Self::from_str_radix_pow2(src, 2), + 8 => Self::from_str_radix_pow2(src, 8), + 10 => Self::from_str_radix_chunked(src, 10), + 16 => Self::from_str_radix_pow2(src, 16), + + 65.. => Err(ParseError::InvalidRadix(radix)), + 0 | 1 | 37.. => Self::from_str_radix_slow(src, radix), // radix < 2 is left to `from_base_be`, as before; `radix_base(0)` would never return + r if r.is_power_of_two() => Self::from_str_radix_pow2(src, radix), + _ => Self::from_str_radix_chunked(src, radix), } + } + + /// Fallback for radix > 36 (base-64 alphabet) or radix < 2 (invalid base). Not perf-critical. + #[cold] + fn from_str_radix_slow(src: &str, radix: u64) -> Result { let mut err = None; let digits = src.chars().filter_map(|c| { if err.is_some() { return None; } - let digit = if radix <= 36 { - // Case insensitive 0—9, a—z. - match c { - '0'..='9' => u64::from(c) - u64::from('0'), - 'a'..='z' => u64::from(c) - u64::from('a') + 10, - 'A'..='Z' => u64::from(c) - u64::from('A') + 10, - '_' => return None, // Ignored character. - _ => { - err = Some(ParseError::InvalidDigit(c)); - return None; - } - } - } else { - // The Base-64 alphabets - match c { - 'A'..='Z' => u64::from(c) - u64::from('A'), - 'a'..='f' => u64::from(c) - u64::from('a') + 26, - '0'..='9' => u64::from(c) - u64::from('0') + 52, - '+' | '-' => 62, - '/' | ',' | '_' => 63, - '=' | '\r' | '\n' => return None, // Ignored characters. 
- _ => { - err = Some(ParseError::InvalidDigit(c)); - return None; - } + let digit = match c { + 'A'..='Z' => u64::from(c) - u64::from('A'), + 'a'..='f' => u64::from(c) - u64::from('a') + 26, + '0'..='9' => u64::from(c) - u64::from('0') + 52, + '+' | '-' => 62, + '/' | ',' | '_' => 63, + '=' | '\r' | '\n' => return None, + _ => { + err = Some(ParseError::InvalidDigit(c)); + return None; } }; Some(digit) @@ -100,6 +136,90 @@ impl Uint { let value = Self::from_base_be(radix, digits)?; err.map_or(Ok(value), Err) } + + /// Power-of-2 radix: shift digits directly into limbs, no multiplication. + #[inline] + fn from_str_radix_pow2(src: &str, radix: u64) -> Result { + debug_assert!(radix.is_power_of_two()); + let bits_per_digit = radix.trailing_zeros() as usize; + let mut result = Self::ZERO; + let mut total_bits = 0usize; + for &b in src.as_bytes().iter().rev() { + let digit = match decode_digit(b, radix) { + Ok(None) => continue, + Ok(Some(d)) => d, + Err(e) => return Err(e), + }; + if total_bits >= BITS { + if digit != 0 { + return Err(BaseConvertError::Overflow.into()); + } + continue; + } + let limb_idx = total_bits / 64; + let bit_idx = total_bits % 64; + result.limbs[limb_idx] |= digit << bit_idx; + if bit_idx + bits_per_digit > 64 { + let hi = digit >> (64 - bit_idx); + if limb_idx + 1 < LIMBS { + result.limbs[limb_idx + 1] |= hi; + } else if hi != 0 { + return Err(BaseConvertError::Overflow.into()); + } + } + total_bits += bits_per_digit; + } + if LIMBS > 0 && result.limbs[LIMBS - 1] > Self::MASK { + return Err(BaseConvertError::Overflow.into()); + } + Ok(result) + } + + /// Non-power-of-2 radix: accumulate chunks of digits into a u64, then do + /// one widening multiply per chunk instead of per digit. 
+ #[allow(clippy::cast_possible_truncation)] + #[inline] + fn from_str_radix_chunked(src: &str, radix: u64) -> Result { + let (base, power) = radix_base(radix); + let mut result = Self::ZERO; + let mut chunk_val: u64 = 0; + let mut chunk_digits: usize = 0; + for &b in src.as_bytes() { + let digit = match decode_digit(b, radix) { + Ok(None) => continue, + Ok(Some(d)) => d, + Err(e) => return Err(e), + }; + chunk_val = chunk_val * radix + digit; + chunk_digits += 1; + if chunk_digits == power { + Self::muladd_limbs(&mut result.limbs, base, chunk_val)?; + chunk_val = 0; + chunk_digits = 0; + } + } + if chunk_digits > 0 { + let mut tail_base = radix; + for _ in 1..chunk_digits { + tail_base *= radix; + } + Self::muladd_limbs(&mut result.limbs, tail_base, chunk_val)?; + } + Ok(result) + } + + /// `limbs = limbs * factor + addend`, returning overflow error. + #[inline(always)] + fn muladd_limbs(limbs: &mut [u64; LIMBS], factor: u64, addend: u64) -> Result<(), ParseError> { + let mut carry = addend; + for limb in limbs.iter_mut() { + (*limb, carry) = u128::muladd(*limb, factor, carry).split(); + } + if carry > 0 || (LIMBS != 0 && limbs[LIMBS - 1] > Self::MASK) { + return Err(BaseConvertError::Overflow.into()); + } + Ok(()) + } } impl FromStr for Uint { @@ -125,6 +245,35 @@ mod tests { use super::*; use proptest::{prop_assert_eq, proptest}; + #[test] + fn test_pow2_overflow() { + type U8 = Uint<8, 1>; + assert_eq!(U8::from_str("0xff"), Ok(U8::from(255))); + assert_eq!( + U8::from_str("0x1ff"), + Err(ParseError::BaseConvertError(BaseConvertError::Overflow)) + ); + assert_eq!( + U8::from_str("0x100"), + Err(ParseError::BaseConvertError(BaseConvertError::Overflow)) + ); + + type U7 = Uint<7, 1>; + assert_eq!(U7::from_str("0x7f"), Ok(U7::from(127))); + assert_eq!( + U7::from_str("0xff"), + Err(ParseError::BaseConvertError(BaseConvertError::Overflow)) + ); + + // Octal: 0o777 = 511, which overflows U8 (max 255). 
+ assert_eq!( + U8::from_str("0o777"), + Err(ParseError::BaseConvertError(BaseConvertError::Overflow)) + ); + // Octal: 0o377 = 255, fits U8. + assert_eq!(U8::from_str("0o377"), Ok(U8::from(255))); + } + #[test] fn test_parse() { proptest!(|(value: u128)| {