diff --git a/crates/lingua/src/util/media.rs b/crates/lingua/src/util/media.rs
index 3fd371fd..dbb8e9ea 100644
--- a/crates/lingua/src/util/media.rs
+++ b/crates/lingua/src/util/media.rs
@@ -313,6 +313,210 @@ mod wasm_fetch {
 #[cfg(not(target_arch = "wasm32"))]
 mod native_fetch {
     use super::*;
+    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
+    use std::time::Duration;
+    use url::{Host, Url};
+
+    const MAX_REDIRECTS: usize = 3;
+    const MEDIA_FETCH_TIMEOUT: Duration = Duration::from_secs(30);
+
+    fn ipv4_in_cidr(address: Ipv4Addr, base: Ipv4Addr, prefix_len: u32) -> bool {
+        let address = u32::from(address);
+        let base = u32::from(base);
+        let mask = if prefix_len == 0 {
+            0
+        } else {
+            u32::MAX << (32 - prefix_len)
+        };
+
+        (address & mask) == (base & mask)
+    }
+
+    fn is_blocked_ipv4(address: Ipv4Addr) -> bool {
+        address.is_loopback()
+            || address.is_private()
+            || address.is_link_local()
+            || address.is_multicast()
+            || address.is_unspecified()
+            || ipv4_in_cidr(address, Ipv4Addr::new(0, 0, 0, 0), 8)
+            || ipv4_in_cidr(address, Ipv4Addr::new(100, 64, 0, 0), 10)
+            || ipv4_in_cidr(address, Ipv4Addr::new(192, 0, 0, 0), 24)
+            || ipv4_in_cidr(address, Ipv4Addr::new(198, 18, 0, 0), 15)
+            || ipv4_in_cidr(address, Ipv4Addr::new(224, 0, 0, 0), 4)
+            || ipv4_in_cidr(address, Ipv4Addr::new(240, 0, 0, 0), 4)
+    }
+
+    fn is_blocked_ipv6(address: Ipv6Addr) -> bool {
+        address.is_loopback()
+            || address.is_unspecified()
+            || address.is_multicast()
+            || address.is_unique_local()
+            || address.is_unicast_link_local()
+            || address.to_ipv4_mapped().is_some_and(is_blocked_ipv4)
+    }
+
+    fn is_blocked_ip(address: IpAddr) -> bool {
+        match address {
+            IpAddr::V4(address) => is_blocked_ipv4(address),
+            IpAddr::V6(address) => is_blocked_ipv6(address),
+        }
+    }
+
+    fn is_blocked_hostname(hostname: &str) -> bool {
+        hostname.eq_ignore_ascii_case("localhost")
+            || hostname.eq_ignore_ascii_case("metadata.amazonaws.com")
+            || hostname.eq_ignore_ascii_case("metadata.google.internal")
+    }
+
+    struct ValidatedMediaUrl {
+        hostname: Option<String>,
+        addresses: Vec<SocketAddr>,
+    }
+
+    fn validate_media_url(url: &Url) -> Result<ValidatedMediaUrl, MediaError> {
+        if url.scheme() != "http" && url.scheme() != "https" {
+            return Err(MediaError::FetchError(
+                "media URL must use http or https".to_string(),
+            ));
+        }
+
+        let host = url
+            .host()
+            .ok_or_else(|| MediaError::FetchError("media URL is missing a host".to_string()))?;
+        match host {
+            Host::Ipv4(address) => {
+                if is_blocked_ipv4(address) {
+                    return Err(MediaError::FetchError(
+                        "media URL resolves to a blocked address".to_string(),
+                    ));
+                }
+                return Ok(ValidatedMediaUrl {
+                    hostname: None,
+                    addresses: Vec::new(),
+                });
+            }
+            Host::Ipv6(address) => {
+                if is_blocked_ipv6(address) {
+                    return Err(MediaError::FetchError(
+                        "media URL resolves to a blocked address".to_string(),
+                    ));
+                }
+                return Ok(ValidatedMediaUrl {
+                    hostname: None,
+                    addresses: Vec::new(),
+                });
+            }
+            Host::Domain(host) => {
+                if is_blocked_hostname(host) {
+                    return Err(MediaError::FetchError(
+                        "media URL resolves to a blocked address".to_string(),
+                    ));
+                }
+            }
+        }
+
+        let hostname = url
+            .host_str()
+            .ok_or_else(|| MediaError::FetchError("media URL is missing a host".to_string()))?;
+        let port = url.port_or_known_default().ok_or_else(|| {
+            MediaError::FetchError("media URL is missing a valid port".to_string())
+        })?;
+        let addresses = (hostname, port)
+            .to_socket_addrs()
+            .map_err(|e| MediaError::FetchError(format!("failed to resolve media URL: {e}")))?;
+
+        let mut resolved_addresses = Vec::new();
+        for address in addresses {
+            if is_blocked_ip(address.ip()) {
+                return Err(MediaError::FetchError(
+                    "media URL resolves to a blocked address".to_string(),
+                ));
+            }
+            resolved_addresses.push(address);
+        }
+
+        if resolved_addresses.is_empty() {
+            return Err(MediaError::FetchError(
+                "media URL did not resolve to any addresses".to_string(),
+            ));
+        }
+
+        Ok(ValidatedMediaUrl {
+            hostname: Some(hostname.to_string()),
+            addresses: resolved_addresses,
+        })
+    }
+
+    async fn fetch_validated_url(url: &str) -> Result<reqwest::Response, MediaError> {
+        let mut current_url = Url::parse(url)
+            .map_err(|e| MediaError::FetchError(format!("invalid media URL: {e}")))?;
+
+        for redirect_count in 0..=MAX_REDIRECTS {
+            let validated_url = validate_media_url(&current_url)?;
+            let mut client_builder = reqwest::Client::builder()
+                .redirect(reqwest::redirect::Policy::none())
+                .timeout(MEDIA_FETCH_TIMEOUT);
+            if let Some(hostname) = validated_url.hostname {
+                client_builder =
+                    client_builder.resolve_to_addrs(&hostname, &validated_url.addresses);
+            }
+            let client = client_builder
+                .build()
+                .map_err(|e| MediaError::FetchError(e.to_string()))?;
+            let response = client
+                .get(current_url.clone())
+                .send()
+                .await
+                .map_err(|e| MediaError::FetchError(e.to_string()))?;
+
+            if !response.status().is_redirection() {
+                return Ok(response);
+            }
+
+            if redirect_count >= MAX_REDIRECTS {
+                return Err(MediaError::FetchError(
+                    "media URL exceeded redirect limit".to_string(),
+                ));
+            }
+
+            let location = response
+                .headers()
+                .get(reqwest::header::LOCATION)
+                .and_then(|v| v.to_str().ok())
+                .ok_or_else(|| {
+                    MediaError::FetchError("media URL redirect missing location header".to_string())
+                })?;
+            current_url = current_url
+                .join(location)
+                .map_err(|e| MediaError::FetchError(format!("invalid media redirect URL: {e}")))?;
+            validate_media_url(&current_url)?;
+        }
+
+        Err(MediaError::FetchError(
+            "media URL exceeded redirect limit".to_string(),
+        ))
+    }
+
+    async fn response_bytes_with_limit(
+        response: &mut reqwest::Response,
+        max_bytes: Option<usize>,
+    ) -> Result<Vec<u8>, MediaError> {
+        let mut bytes = Vec::new();
+        while let Some(chunk) = response
+            .chunk()
+            .await
+            .map_err(|e| MediaError::FetchError(e.to_string()))?
+        {
+            bytes.extend_from_slice(&chunk);
+            if let Some(max) = max_bytes {
+                if bytes.len() > max {
+                    return Err(MediaError::SizeExceeded(max / 1024 / 1024));
+                }
+            }
+        }
+
+        Ok(bytes)
+    }
 
     /// Fetch a URL and return its content as a MediaBlock.
     ///
@@ -326,9 +530,7 @@ mod native_fetch {
         allowed_types: Option<&[&str]>,
         max_bytes: Option<usize>,
     ) -> Result<MediaBlock, MediaError> {
-        let response = reqwest::get(url)
-            .await
-            .map_err(|e| MediaError::FetchError(e.to_string()))?;
+        let mut response = fetch_validated_url(url).await?;
 
         if !response.status().is_success() {
             return Err(MediaError::FetchError(format!(
@@ -361,17 +563,7 @@ mod native_fetch {
         }
 
         // Get bytes
-        let bytes = response
-            .bytes()
-            .await
-            .map_err(|e| MediaError::FetchError(e.to_string()))?;
-
-        // Check size
-        if let Some(max) = max_bytes {
-            if bytes.len() > max {
-                return Err(MediaError::SizeExceeded(max / 1024 / 1024));
-            }
-        }
+        let bytes = response_bytes_with_limit(&mut response, max_bytes).await?;
 
         // Encode to base64
         use base64::Engine;
@@ -382,6 +574,132 @@ mod native_fetch {
             data,
         })
     }
+
+    #[cfg(test)]
+    mod tests {
+        use super::*;
+
+        #[test]
+        fn validate_media_url_rejects_non_http_schemes() {
+            let url = Url::parse("file:///etc/passwd").unwrap();
+            assert!(matches!(
+                validate_media_url(&url),
+                Err(MediaError::FetchError(message))
+                    if message.contains("http or https")
+            ));
+        }
+
+        #[test]
+        fn validate_media_url_rejects_localhost() {
+            let url = Url::parse("http://localhost/image.png").unwrap();
+            assert!(matches!(
+                validate_media_url(&url),
+                Err(MediaError::FetchError(message))
+                    if message.contains("blocked address")
+            ));
+        }
+
+        #[test]
+        fn validate_media_url_rejects_dns_resolved_localhost() {
+            let url = Url::parse("http://localhost./image.png").unwrap();
+            assert!(matches!(
+                validate_media_url(&url),
+                Err(MediaError::FetchError(message))
+                    if message.contains("blocked address")
+            ));
+        }
+
+        #[test]
+        fn validate_media_url_rejects_metadata_ip() {
+            let url = Url::parse("http://169.254.169.254/latest/meta-data").unwrap();
+            assert!(matches!(
+                validate_media_url(&url),
+                Err(MediaError::FetchError(message))
+                    if message.contains("blocked address")
+            ));
+        }
+
+        #[test]
+        fn validate_media_url_rejects_cloud_metadata_hostnames() {
+            for url in [
+                "http://metadata.amazonaws.com/latest/meta-data",
+                "http://metadata.google.internal/computeMetadata/v1",
+            ] {
+                let url = Url::parse(url).unwrap();
+                assert!(matches!(
+                    validate_media_url(&url),
+                    Err(MediaError::FetchError(message))
+                        if message.contains("blocked address")
+                ));
+            }
+        }
+
+        #[test]
+        fn redirect_from_public_url_to_localhost_or_private_ip_is_rejected() {
+            let current_url = Url::parse("https://example.com/image.png").unwrap();
+
+            for location in [
+                "http://127.0.0.1:8080/private.png",
+                "http://192.168.0.10/private.png",
+            ] {
+                let redirect_url = current_url.join(location).unwrap();
+
+                assert!(matches!(
+                    validate_media_url(&redirect_url),
+                    Err(MediaError::FetchError(message))
+                        if message.contains("blocked address")
+                ));
+            }
+        }
+
+        #[test]
+        fn validate_media_url_rejects_ipv4_mapped_ipv6_localhost() {
+            let dotted_url = Url::parse("http://[::ffff:127.0.0.1]/image.png").unwrap();
+            assert!(matches!(
+                validate_media_url(&dotted_url),
+                Err(MediaError::FetchError(message))
+                    if message.contains("blocked address")
+            ));
+
+            let hex_url = Url::parse("http://[::ffff:7f00:1]/image.png").unwrap();
+            assert!(matches!(
+                validate_media_url(&hex_url),
+                Err(MediaError::FetchError(message))
+                    if message.contains("blocked address")
+            ));
+        }
+
+        #[test]
+        fn blocked_ip_ranges_include_metadata_local_private_and_multicast_addresses() {
+            for address in [
+                IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)),
+                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
+                IpAddr::V4(Ipv4Addr::new(127, 255, 255, 255)),
+                IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
+                IpAddr::V4(Ipv4Addr::new(0, 255, 255, 255)),
+                IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
+                IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)),
+                IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255)),
+                IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)),
+                IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)),
+                IpAddr::V4(Ipv4Addr::new(239, 255, 255, 255)),
+                IpAddr::V6(Ipv6Addr::LOCALHOST),
+                IpAddr::V6("ff00::1".parse().unwrap()),
+                IpAddr::V6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff".parse().unwrap()),
+                IpAddr::V6("::ffff:127.0.0.1".parse().unwrap()),
+            ] {
+                assert!(is_blocked_ip(address), "{address} should be blocked");
+            }
+
+            for address in [
+                IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)),
+                IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
+                IpAddr::V6("2606:2800:220:1:248:1893:25c8:1946".parse().unwrap()),
+            ] {
+                assert!(!is_blocked_ip(address), "{address} should be allowed");
+            }
+        }
+    }
 }
 
 // Re-export the appropriate implementation