Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 332 additions & 14 deletions crates/lingua/src/util/media.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,210 @@ mod wasm_fetch {
#[cfg(not(target_arch = "wasm32"))]
mod native_fetch {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::time::Duration;
use url::{Host, Url};

const MAX_REDIRECTS: usize = 3;
const MEDIA_FETCH_TIMEOUT: Duration = Duration::from_secs(30);

fn ipv4_in_cidr(address: Ipv4Addr, base: Ipv4Addr, prefix_len: u32) -> bool {
let address = u32::from(address);
let base = u32::from(base);
let mask = if prefix_len == 0 {
0
} else {
u32::MAX << (32 - prefix_len)
};

(address & mask) == (base & mask)
}

fn is_blocked_ipv4(address: Ipv4Addr) -> bool {
address.is_loopback()
|| address.is_private()
|| address.is_link_local()
|| address.is_multicast()
|| address.is_unspecified()
|| ipv4_in_cidr(address, Ipv4Addr::new(0, 0, 0, 0), 8)
|| ipv4_in_cidr(address, Ipv4Addr::new(100, 64, 0, 0), 10)
|| ipv4_in_cidr(address, Ipv4Addr::new(192, 0, 0, 0), 24)
|| ipv4_in_cidr(address, Ipv4Addr::new(198, 18, 0, 0), 15)
|| ipv4_in_cidr(address, Ipv4Addr::new(224, 0, 0, 0), 4)
|| ipv4_in_cidr(address, Ipv4Addr::new(240, 0, 0, 0), 4)
}

fn is_blocked_ipv6(address: Ipv6Addr) -> bool {
address.is_loopback()
|| address.is_unspecified()
|| address.is_multicast()
|| address.is_unique_local()
|| address.is_unicast_link_local()
|| address.to_ipv4_mapped().is_some_and(is_blocked_ipv4)
}

fn is_blocked_ip(address: IpAddr) -> bool {
match address {
IpAddr::V4(address) => is_blocked_ipv4(address),
IpAddr::V6(address) => is_blocked_ipv6(address),
}
}

fn is_blocked_hostname(hostname: &str) -> bool {
hostname.eq_ignore_ascii_case("localhost")
|| hostname.eq_ignore_ascii_case("metadata.amazonaws.com")
|| hostname.eq_ignore_ascii_case("metadata.google.internal")
}

struct ValidatedMediaUrl {
hostname: Option<String>,
addresses: Vec<SocketAddr>,
}

fn validate_media_url(url: &Url) -> Result<ValidatedMediaUrl, MediaError> {
if url.scheme() != "http" && url.scheme() != "https" {
return Err(MediaError::FetchError(
"media URL must use http or https".to_string(),
));
}

let host = url
.host()
.ok_or_else(|| MediaError::FetchError("media URL is missing a host".to_string()))?;
match host {
Host::Ipv4(address) => {
if is_blocked_ipv4(address) {
return Err(MediaError::FetchError(
"media URL resolves to a blocked address".to_string(),
));
}
return Ok(ValidatedMediaUrl {
hostname: None,
addresses: Vec::new(),
});
}
Host::Ipv6(address) => {
if is_blocked_ipv6(address) {
return Err(MediaError::FetchError(
"media URL resolves to a blocked address".to_string(),
));
}
return Ok(ValidatedMediaUrl {
hostname: None,
addresses: Vec::new(),
});
}
Host::Domain(host) => {
if is_blocked_hostname(host) {
return Err(MediaError::FetchError(
"media URL resolves to a blocked address".to_string(),
));
}
}
}

let hostname = url
.host_str()
.ok_or_else(|| MediaError::FetchError("media URL is missing a host".to_string()))?;
let port = url.port_or_known_default().ok_or_else(|| {
MediaError::FetchError("media URL is missing a valid port".to_string())
})?;
let addresses = (hostname, port)
.to_socket_addrs()
.map_err(|e| MediaError::FetchError(format!("failed to resolve media URL: {e}")))?;

let mut resolved_addresses = Vec::new();
for address in addresses {
if is_blocked_ip(address.ip()) {
return Err(MediaError::FetchError(
"media URL resolves to a blocked address".to_string(),
));
}
resolved_addresses.push(address);
}

if resolved_addresses.is_empty() {
return Err(MediaError::FetchError(
"media URL did not resolve to any addresses".to_string(),
));
}

Ok(ValidatedMediaUrl {
hostname: Some(hostname.to_string()),
addresses: resolved_addresses,
})
}

async fn fetch_validated_url(url: &str) -> Result<reqwest::Response, MediaError> {
let mut current_url = Url::parse(url)
.map_err(|e| MediaError::FetchError(format!("invalid media URL: {e}")))?;

for redirect_count in 0..=MAX_REDIRECTS {
let validated_url = validate_media_url(&current_url)?;
let mut client_builder = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.timeout(MEDIA_FETCH_TIMEOUT);
if let Some(hostname) = validated_url.hostname {
client_builder =
client_builder.resolve_to_addrs(&hostname, &validated_url.addresses);
}
let client = client_builder
.build()
.map_err(|e| MediaError::FetchError(e.to_string()))?;
let response = client
.get(current_url.clone())
.send()
.await
.map_err(|e| MediaError::FetchError(e.to_string()))?;

if !response.status().is_redirection() {
return Ok(response);
}

if redirect_count >= MAX_REDIRECTS {
return Err(MediaError::FetchError(
"media URL exceeded redirect limit".to_string(),
));
}

let location = response
.headers()
.get(reqwest::header::LOCATION)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
MediaError::FetchError("media URL redirect missing location header".to_string())
})?;
current_url = current_url
.join(location)
.map_err(|e| MediaError::FetchError(format!("invalid media redirect URL: {e}")))?;
validate_media_url(&current_url)?;
}

Err(MediaError::FetchError(
"media URL exceeded redirect limit".to_string(),
))
}

async fn response_bytes_with_limit(
response: &mut reqwest::Response,
max_bytes: Option<usize>,
) -> Result<Vec<u8>, MediaError> {
let mut bytes = Vec::new();
while let Some(chunk) = response
.chunk()
.await
.map_err(|e| MediaError::FetchError(e.to_string()))?
{
bytes.extend_from_slice(&chunk);
if let Some(max) = max_bytes {
if bytes.len() > max {
return Err(MediaError::SizeExceeded(max / 1024 / 1024));
}
}
}

Ok(bytes)
}

/// Fetch a URL and return its content as a MediaBlock.
///
Expand All @@ -326,9 +530,7 @@ mod native_fetch {
allowed_types: Option<&[&str]>,
max_bytes: Option<usize>,
) -> Result<MediaBlock, MediaError> {
let response = reqwest::get(url)
.await
.map_err(|e| MediaError::FetchError(e.to_string()))?;
let mut response = fetch_validated_url(url).await?;

if !response.status().is_success() {
return Err(MediaError::FetchError(format!(
Expand Down Expand Up @@ -361,17 +563,7 @@ mod native_fetch {
}

// Get bytes
let bytes = response
.bytes()
.await
.map_err(|e| MediaError::FetchError(e.to_string()))?;

// Check size
if let Some(max) = max_bytes {
if bytes.len() > max {
return Err(MediaError::SizeExceeded(max / 1024 / 1024));
}
}
let bytes = response_bytes_with_limit(&mut response, max_bytes).await?;

// Encode to base64
use base64::Engine;
Expand All @@ -382,6 +574,132 @@ mod native_fetch {
data,
})
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn validate_media_url_rejects_non_http_schemes() {
let url = Url::parse("file:///etc/passwd").unwrap();
assert!(matches!(
validate_media_url(&url),
Err(MediaError::FetchError(message))
if message.contains("http or https")
));
}

#[test]
fn validate_media_url_rejects_localhost() {
let url = Url::parse("http://localhost/image.png").unwrap();
assert!(matches!(
validate_media_url(&url),
Err(MediaError::FetchError(message))
if message.contains("blocked address")
));
}

#[test]
fn validate_media_url_rejects_dns_resolved_localhost() {
let url = Url::parse("http://localhost./image.png").unwrap();
assert!(matches!(
validate_media_url(&url),
Err(MediaError::FetchError(message))
if message.contains("blocked address")
));
}

#[test]
fn validate_media_url_rejects_metadata_ip() {
let url = Url::parse("http://169.254.169.254/latest/meta-data").unwrap();
assert!(matches!(
validate_media_url(&url),
Err(MediaError::FetchError(message))
if message.contains("blocked address")
));
}

#[test]
fn validate_media_url_rejects_cloud_metadata_hostnames() {
for url in [
"http://metadata.amazonaws.com/latest/meta-data",
"http://metadata.google.internal/computeMetadata/v1",
] {
let url = Url::parse(url).unwrap();
assert!(matches!(
validate_media_url(&url),
Err(MediaError::FetchError(message))
if message.contains("blocked address")
));
}
}

#[test]
fn redirect_from_public_url_to_localhost_or_private_ip_is_rejected() {
let current_url = Url::parse("https://example.com/image.png").unwrap();

for location in [
"http://127.0.0.1:8080/private.png",
"http://192.168.0.10/private.png",
] {
let redirect_url = current_url.join(location).unwrap();

assert!(matches!(
validate_media_url(&redirect_url),
Err(MediaError::FetchError(message))
if message.contains("blocked address")
));
}
}

#[test]
fn validate_media_url_rejects_ipv4_mapped_ipv6_localhost() {
let dotted_url = Url::parse("http://[::ffff:127.0.0.1]/image.png").unwrap();
assert!(matches!(
validate_media_url(&dotted_url),
Err(MediaError::FetchError(message))
if message.contains("blocked address")
));

let hex_url = Url::parse("http://[::ffff:7f00:1]/image.png").unwrap();
assert!(matches!(
validate_media_url(&hex_url),
Err(MediaError::FetchError(message))
if message.contains("blocked address")
));
}

#[test]
fn blocked_ip_ranges_include_metadata_local_private_and_multicast_addresses() {
for address in [
IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)),
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(127, 255, 255, 255)),
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
IpAddr::V4(Ipv4Addr::new(0, 255, 255, 255)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)),
IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255)),
IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)),
IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(239, 255, 255, 255)),
IpAddr::V6(Ipv6Addr::LOCALHOST),
IpAddr::V6("ff00::1".parse().unwrap()),
IpAddr::V6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff".parse().unwrap()),
IpAddr::V6("::ffff:127.0.0.1".parse().unwrap()),
] {
assert!(is_blocked_ip(address), "{address} should be blocked");
}

for address in [
IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)),
IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
IpAddr::V6("2606:2800:220:1:248:1893:25c8:1946".parse().unwrap()),
] {
assert!(!is_blocked_ip(address), "{address} should be allowed");
}
}
}
}

// Re-export the appropriate implementation
Expand Down
Loading