diff --git a/src/cmdline.rs b/src/cmdline.rs index e2cadc9..342ac00 100644 --- a/src/cmdline.rs +++ b/src/cmdline.rs @@ -20,8 +20,9 @@ pub struct Args { #[arg(long, short)] pub cpus: u8, - /// Amount of RAM available to VM, in MiB. - #[arg(long, short)] + /// Amount of RAM for the VM. Optionally specify a unit suffix: M (MiB, default), G (GiB), + /// T (TiB), P (PiB). Examples: 2048, 2048M, 2G. + #[arg(long, short, value_parser = parse_memory)] pub memory: u32, /// Bootloader configuration. @@ -76,6 +77,42 @@ fn parse_timesync(s: &str) -> Result { .map_err(|e| format!("invalid timesync port: {e}")) } +/// Parse a memory size string with an optional unit suffix into MiB. +/// +/// Supported suffixes (powers of 1024): M (MiB, default), G (GiB), T (TiB), P (PiB). +/// Case-insensitive. With no suffix, the value is treated as MiB. +pub fn parse_memory(s: &str) -> Result { + let (num_str, unit) = if s.ends_with(|c: char| c.is_ascii_alphabetic()) { + s.split_at(s.len() - 1) + } else { + (s, "M") + }; + + let n: u64 = num_str + .parse() + .map_err(|_| format!("invalid memory value: '{s}'"))?; + + let mib: u64 = match unit.to_uppercase().as_str() { + "M" => n, + "G" => n + .checked_mul(1024) + .ok_or_else(|| format!("memory value '{s}' is too large"))?, + "T" => n + .checked_mul(1024 * 1024) + .ok_or_else(|| format!("memory value '{s}' is too large"))?, + "P" => n + .checked_mul(1024 * 1024 * 1024) + .ok_or_else(|| format!("memory value '{s}' is too large"))?, + _ => { + return Err(format!( + "unknown memory suffix '{unit}', valid suffixes are MGTP" + )) + } + }; + + u32::try_from(mib).map_err(|_| format!("memory value '{s}' is too large")) +} + /// Parse the input string into a hash map of key value pairs, associating the argument with its /// respective value. pub fn parse_args(s: String) -> Result, anyhow::Error> { @@ -250,6 +287,34 @@ mod bootloader { } mod tests { + #[test] + fn memory_parse_no_suffix() { + assert_eq!(super::parse_memory("512"), Ok(512)); + assert_eq!(super::parse_memory("2048"), Ok(2048)); + } + + #[test] + fn memory_parse_mib() { + assert_eq!(super::parse_memory("512M"), Ok(512)); + assert_eq!(super::parse_memory("512m"), Ok(512)); + } + + #[test] + fn memory_parse_gib() { + assert_eq!(super::parse_memory("2G"), Ok(2048)); + assert_eq!(super::parse_memory("2g"), Ok(2048)); + } + + #[test] + fn memory_parse_invalid() { + assert!(super::parse_memory("abc").is_err()); + assert!(super::parse_memory("1Z").is_err()); + assert!(super::parse_memory("1K").is_err()); + assert!(super::parse_memory("1B").is_err()); + assert!(super::parse_memory("1E").is_err()); + assert!(super::parse_memory("").is_err()); + } + #[test] fn virtio_blk_argument_ordering() { let in_order =