From 89f64de2f2c4a52526cea0de11edc1f05366a6ab Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sat, 7 Feb 2026 03:32:28 +0100 Subject: [PATCH] perf: optimize from_str_radix --- CHANGELOG.md | 6 ++ src/string.rs | 205 +++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 183 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4cd83e55..b4289589 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Optimize `from_str_radix` ([#557]) + +[#557]: https://github.com/recmo/uint/pull/557 + ## [1.17.2] - 2025-12-28 ### Fixed diff --git a/src/string.rs b/src/string.rs index cddfbcf2..d4c046e4 100644 --- a/src/string.rs +++ b/src/string.rs @@ -1,6 +1,6 @@ #![allow(clippy::missing_inline_in_public_items)] // allow format functions -use crate::{Uint, base_convert::BaseConvertError}; +use crate::{Uint, algorithms::DoubleWord, base_convert::BaseConvertError}; use core::{fmt, str::FromStr}; /// Error for [`from_str_radix`](Uint::from_str_radix). @@ -44,6 +44,42 @@ impl fmt::Display for ParseError { } } +/// Returns `(base, power)` where `base = radix^power` is the largest power of +/// `radix` that fits in a `u64`. +const fn radix_base(radix: u64) -> (u64, usize) { + debug_assert!(radix >= 2); + let mut power: usize = 1; + let mut base = radix; + loop { + match base.checked_mul(radix) { + Some(n) => { + base = n; + power += 1; + } + None => return (base, power), + } + } +} + +/// Decode an ASCII byte as a digit for radix <= 36. +/// Case-insensitive 0-9, a-z. Underscores are skipped. +#[inline(always)] +fn decode_digit(b: u8, radix: u64) -> Result, ParseError> { + let digit = match b { + b'0'..=b'9' => b - b'0', + b'a'..=b'z' => b - b'a' + 10, + b'A'..=b'Z' => b - b'A' + 10, + b'_' => return Ok(None), + _ => return Err(ParseError::InvalidDigit(b as char)), + }; + let digit = u64::from(digit); + if digit < radix { + Ok(Some(digit)) + } else { + Err(ParseError::InvalidDigit(b as char)) + } +} + impl Uint { /// Parse a string into a [`Uint`]. /// @@ -59,40 +95,40 @@ impl Uint { /// * [`ParseError::BaseConvertError`] if [`Uint::from_base_be`] fails. // FEATURE: Support proper unicode. Ignore zero-width spaces, joiners, etc. // Recognize digits from other alphabets. + #[inline] pub fn from_str_radix(src: &str, radix: u64) -> Result { - if radix > 64 { - return Err(ParseError::InvalidRadix(radix)); + match radix { + // Specialize for the common cases. + 2 => Self::from_str_radix_pow2(src, 2), + 8 => Self::from_str_radix_pow2(src, 8), + 10 => Self::from_str_radix_chunked(src, 10), + 16 => Self::from_str_radix_pow2(src, 16), + + 65.. => Err(ParseError::InvalidRadix(radix)), + 37.. => Self::from_str_radix_slow(src, radix), + r if r.is_power_of_two() => Self::from_str_radix_pow2(src, radix), + _ => Self::from_str_radix_chunked(src, radix), } + } + + /// Fallback for radix > 36 (base-64 alphabet). Not perf-critical. + #[cold] + fn from_str_radix_slow(src: &str, radix: u64) -> Result { let mut err = None; let digits = src.chars().filter_map(|c| { if err.is_some() { return None; } - let digit = if radix <= 36 { - // Case insensitive 0—9, a—z. - match c { - '0'..='9' => u64::from(c) - u64::from('0'), - 'a'..='z' => u64::from(c) - u64::from('a') + 10, - 'A'..='Z' => u64::from(c) - u64::from('A') + 10, - '_' => return None, // Ignored character. - _ => { - err = Some(ParseError::InvalidDigit(c)); - return None; - } - } - } else { - // The Base-64 alphabets - match c { - 'A'..='Z' => u64::from(c) - u64::from('A'), - 'a'..='f' => u64::from(c) - u64::from('a') + 26, - '0'..='9' => u64::from(c) - u64::from('0') + 52, - '+' | '-' => 62, - '/' | ',' | '_' => 63, - '=' | '\r' | '\n' => return None, // Ignored characters. - _ => { - err = Some(ParseError::InvalidDigit(c)); - return None; - } + let digit = match c { + 'A'..='Z' => u64::from(c) - u64::from('A'), + 'a'..='f' => u64::from(c) - u64::from('a') + 26, + '0'..='9' => u64::from(c) - u64::from('0') + 52, + '+' | '-' => 62, + '/' | ',' | '_' => 63, + '=' | '\r' | '\n' => return None, + _ => { + err = Some(ParseError::InvalidDigit(c)); + return None; } }; Some(digit) @@ -100,6 +136,90 @@ impl Uint { let value = Self::from_base_be(radix, digits)?; err.map_or(Ok(value), Err) } + + /// Power-of-2 radix: shift digits directly into limbs, no multiplication. + #[inline] + fn from_str_radix_pow2(src: &str, radix: u64) -> Result { + debug_assert!(radix.is_power_of_two()); + let bits_per_digit = radix.trailing_zeros() as usize; + let mut result = Self::ZERO; + let mut total_bits = 0usize; + for &b in src.as_bytes().iter().rev() { + let digit = match decode_digit(b, radix) { + Ok(None) => continue, + Ok(Some(d)) => d, + Err(e) => return Err(e), + }; + if total_bits >= BITS { + if digit != 0 { + return Err(BaseConvertError::Overflow.into()); + } + continue; + } + let limb_idx = total_bits / 64; + let bit_idx = total_bits % 64; + result.limbs[limb_idx] |= digit << bit_idx; + if bit_idx + bits_per_digit > 64 { + let hi = digit >> (64 - bit_idx); + if limb_idx + 1 < LIMBS { + result.limbs[limb_idx + 1] |= hi; + } else if hi != 0 { + return Err(BaseConvertError::Overflow.into()); + } + } + total_bits += bits_per_digit; + } + if LIMBS > 0 && result.limbs[LIMBS - 1] > Self::MASK { + return Err(BaseConvertError::Overflow.into()); + } + Ok(result) + } + + /// Non-power-of-2 radix: accumulate chunks of digits into a u64, then do + /// one widening multiply per chunk instead of per digit. + #[allow(clippy::cast_possible_truncation)] + #[inline] + fn from_str_radix_chunked(src: &str, radix: u64) -> Result { + let (base, power) = radix_base(radix); + let mut result = Self::ZERO; + let mut chunk_val: u64 = 0; + let mut chunk_digits: usize = 0; + for &b in src.as_bytes() { + let digit = match decode_digit(b, radix) { + Ok(None) => continue, + Ok(Some(d)) => d, + Err(e) => return Err(e), + }; + chunk_val = chunk_val * radix + digit; + chunk_digits += 1; + if chunk_digits == power { + Self::muladd_limbs(&mut result.limbs, base, chunk_val)?; + chunk_val = 0; + chunk_digits = 0; + } + } + if chunk_digits > 0 { + let mut tail_base = radix; + for _ in 1..chunk_digits { + tail_base *= radix; + } + Self::muladd_limbs(&mut result.limbs, tail_base, chunk_val)?; + } + Ok(result) + } + + /// `limbs = limbs * factor + addend`, returning overflow error. + #[inline(always)] + fn muladd_limbs(limbs: &mut [u64; LIMBS], factor: u64, addend: u64) -> Result<(), ParseError> { + let mut carry = addend; + for limb in limbs.iter_mut() { + (*limb, carry) = u128::muladd(*limb, factor, carry).split(); + } + if carry > 0 || (LIMBS != 0 && limbs[LIMBS - 1] > Self::MASK) { + return Err(BaseConvertError::Overflow.into()); + } + Ok(()) + } } impl FromStr for Uint { @@ -125,6 +245,35 @@ mod tests { use super::*; use proptest::{prop_assert_eq, proptest}; + #[test] + fn test_pow2_overflow() { + type U8 = Uint<8, 1>; + assert_eq!(U8::from_str("0xff"), Ok(U8::from(255))); + assert_eq!( + U8::from_str("0x1ff"), + Err(ParseError::BaseConvertError(BaseConvertError::Overflow)) + ); + assert_eq!( + U8::from_str("0x100"), + Err(ParseError::BaseConvertError(BaseConvertError::Overflow)) + ); + + type U7 = Uint<7, 1>; + assert_eq!(U7::from_str("0x7f"), Ok(U7::from(127))); + assert_eq!( + U7::from_str("0xff"), + Err(ParseError::BaseConvertError(BaseConvertError::Overflow)) + ); + + // Octal: 0o777 = 511, which overflows U8 (max 255). + assert_eq!( + U8::from_str("0o777"), + Err(ParseError::BaseConvertError(BaseConvertError::Overflow)) + ); + // Octal: 0o377 = 255, fits U8. + assert_eq!(U8::from_str("0o377"), Ok(U8::from(255))); + } + #[test] fn test_parse() { proptest!(|(value: u128)| {