diff --git a/crates/braintrust-llm-router/src/client.rs b/crates/braintrust-llm-router/src/client.rs index 4c155ec5..0004c52f 100644 --- a/crates/braintrust-llm-router/src/client.rs +++ b/crates/braintrust-llm-router/src/client.rs @@ -1,11 +1,12 @@ use std::error::Error as StdError; use std::io::ErrorKind; +use std::net::SocketAddr; use std::time::Duration; use dashmap::DashMap; use once_cell::sync::Lazy; use parking_lot::RwLock; -use reqwest::{Client, ClientBuilder}; +use reqwest::{redirect::Policy, Client, ClientBuilder}; use reqwest_middleware::ClientWithMiddleware; use reqwest_retry::{ default_on_request_failure, policies::ExponentialBackoff, RetryTransientMiddleware, Retryable, @@ -17,6 +18,12 @@ use crate::error::{Error, Result}; // The default number of retries for transient transport failures. const DEFAULT_MAX_RETRIES: u32 = 2; +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct DnsOverride { + pub domain: String, + pub addrs: Vec, +} + // Shared reqwest clients are cached by these settings. Keep this key // low-cardinality and effectively process-wide; request-scoped values do not // belong here because they fragment client reuse and connection pooling. @@ -27,6 +34,8 @@ pub struct ClientSettings { pub pool_idle_timeout: Duration, pub pool_max_idle_per_host: usize, pub user_agent: String, + pub dns_overrides: Vec, + pub follow_redirects: bool, } impl Default for ClientSettings { @@ -37,19 +46,29 @@ impl Default for ClientSettings { pool_idle_timeout: Duration::from_secs(90), pool_max_idle_per_host: 16, user_agent: format!("braintrust-llm-router/{}", env!("CARGO_PKG_VERSION")), + dns_overrides: Vec::new(), + follow_redirects: true, } } } pub fn build_client(settings: &ClientSettings) -> Result { - ClientBuilder::new() + let mut builder = ClientBuilder::new() .connect_timeout(settings.connect_timeout) .timeout(settings.request_timeout) .pool_idle_timeout(settings.pool_idle_timeout) .pool_max_idle_per_host(settings.pool_max_idle_per_host) - .user_agent(&settings.user_agent) - .build() - .map_err(Error::from) + .user_agent(&settings.user_agent); + + if !settings.follow_redirects { + builder = builder.redirect(Policy::none()); + } + + for override_entry in &settings.dns_overrides { + builder = builder.resolve_to_addrs(&override_entry.domain, &override_entry.addrs); + } + + builder.build().map_err(Error::from) } pub fn build_middleware_client(settings: &ClientSettings) -> Result { @@ -279,6 +298,27 @@ mod tests { assert_eq!(cached_client_count(), 2); } + #[test] + #[serial] + fn build_middleware_client_reuses_cached_client_for_same_dns_overrides() { + clear_override_client(); + clear_cached_clients(); + + let settings = ClientSettings { + dns_overrides: vec![DnsOverride { + domain: "example.com".to_string(), + addrs: vec!["127.0.0.1:443".parse().expect("socket addr")], + }], + ..ClientSettings::default() + }; + + let first = build_middleware_client(&settings).expect("first client"); + let second = build_middleware_client(&settings).expect("second client"); + + assert_eq!(cached_client_count(), 1); + assert_eq!(format!("{first:?}"), format!("{second:?}")); + } + #[test] fn default_request_timeout_is_600_seconds() { assert_eq!( diff --git a/crates/braintrust-llm-router/src/lib.rs b/crates/braintrust-llm-router/src/lib.rs index 81475a6e..5df809c8 100644 --- a/crates/braintrust-llm-router/src/lib.rs +++ b/crates/braintrust-llm-router/src/lib.rs @@ -1,7 +1,7 @@ mod auth; mod catalog; mod client; -pub use client::{clear_override_client, set_override_client}; +pub use client::{clear_override_client, set_override_client, ClientSettings, DnsOverride}; pub use reqwest_middleware::ClientWithMiddleware; mod error; mod providers; diff --git a/crates/braintrust-llm-router/src/providers/anthropic.rs b/crates/braintrust-llm-router/src/providers/anthropic.rs index d7db328e..88635901 100644 --- a/crates/braintrust-llm-router/src/providers/anthropic.rs +++ b/crates/braintrust-llm-router/src/providers/anthropic.rs @@ -44,7 +44,13 @@ pub struct AnthropicProvider { impl AnthropicProvider { pub fn new(config: AnthropicConfig) -> Result { - let mut settings = ClientSettings::default(); + Self::new_with_client_settings(config, ClientSettings::default()) + } + + pub fn new_with_client_settings( + config: AnthropicConfig, + mut settings: ClientSettings, + ) -> Result { if let Some(timeout) = config.timeout { settings.request_timeout = timeout; } @@ -60,6 +66,7 @@ impl AnthropicProvider { endpoint: Option<&Url>, timeout: Option, metadata: &std::collections::HashMap, + client_settings: Option, ) -> Result { use lingua::serde_json::Value; let mut config = AnthropicConfig::default(); @@ -74,7 +81,7 @@ impl AnthropicProvider { config.version = version.to_string(); } - Self::new(config) + Self::new_with_client_settings(config, client_settings.unwrap_or_default()) } fn messages_url(&self) -> Url { diff --git a/crates/braintrust-llm-router/src/providers/azure.rs b/crates/braintrust-llm-router/src/providers/azure.rs index 54af47e7..4c8b2320 100644 --- a/crates/braintrust-llm-router/src/providers/azure.rs +++ b/crates/braintrust-llm-router/src/providers/azure.rs @@ -46,7 +46,13 @@ pub struct AzureProvider { impl AzureProvider { pub fn new(config: AzureConfig) -> Result { - let mut settings = ClientSettings::default(); + Self::new_with_client_settings(config, ClientSettings::default()) + } + + pub fn new_with_client_settings( + config: AzureConfig, + mut settings: ClientSettings, + ) -> Result { if let Some(timeout) = config.timeout { settings.request_timeout = timeout; } @@ -64,6 +70,7 @@ impl AzureProvider { endpoint: Option<&Url>, timeout: Option, metadata: &std::collections::HashMap, + client_settings: Option, ) -> Result { let endpoint = endpoint .cloned() @@ -94,7 +101,7 @@ impl AzureProvider { config.no_named_deployment = no_named; } - Self::new(config) + Self::new_with_client_settings(config, client_settings.unwrap_or_default()) } fn deployment_for_request(&self, model: &str) -> Result { @@ -368,7 +375,7 @@ mod tests { } fn make_provider(metadata: HashMap) -> AzureProvider { - AzureProvider::from_config(Some(&endpoint()), None, &metadata).unwrap() + AzureProvider::from_config(Some(&endpoint()), None, &metadata, None).unwrap() } #[test] diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index b5534a1d..55e9825a 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -185,7 +185,13 @@ pub struct BedrockProvider { impl BedrockProvider { pub fn new(config: BedrockConfig) -> Result { - let mut settings = ClientSettings::default(); + Self::new_with_client_settings(config, ClientSettings::default()) + } + + pub fn new_with_client_settings( + config: BedrockConfig, + mut settings: ClientSettings, + ) -> Result { if let Some(timeout) = config.timeout { settings.request_timeout = timeout; } @@ -202,6 +208,7 @@ impl BedrockProvider { endpoint: Option<&Url>, timeout: Option, metadata: &std::collections::HashMap, + client_settings: Option, ) -> Result { let mut config = BedrockConfig::default(); @@ -221,7 +228,7 @@ impl BedrockProvider { config.timeout = Some(t); } - Self::new(config) + Self::new_with_client_settings(config, client_settings.unwrap_or_default()) } fn converse_url(&self, model: &str, stream: bool) -> Result { diff --git a/crates/braintrust-llm-router/src/providers/databricks.rs b/crates/braintrust-llm-router/src/providers/databricks.rs index 3604cfe5..08a0b8bc 100644 --- a/crates/braintrust-llm-router/src/providers/databricks.rs +++ b/crates/braintrust-llm-router/src/providers/databricks.rs @@ -27,7 +27,13 @@ pub struct DatabricksProvider { impl DatabricksProvider { pub fn new(config: DatabricksConfig) -> Result { - let mut settings = ClientSettings::default(); + Self::new_with_client_settings(config, ClientSettings::default()) + } + + pub fn new_with_client_settings( + config: DatabricksConfig, + mut settings: ClientSettings, + ) -> Result { if let Some(timeout) = config.timeout { settings.request_timeout = timeout; } @@ -35,11 +41,18 @@ impl DatabricksProvider { Ok(Self { client, config }) } - pub fn from_config(api_base: Option<&Url>, timeout: Option) -> Result { + pub fn from_config( + api_base: Option<&Url>, + timeout: Option, + client_settings: Option, + ) -> Result { let api_base = api_base .cloned() .ok_or_else(|| Error::InvalidRequest("databricks provider requires api_base".into()))?; - Self::new(DatabricksConfig { api_base, timeout }) + Self::new_with_client_settings( + DatabricksConfig { api_base, timeout }, + client_settings.unwrap_or_default(), + ) } // This does not support Databrick's new AI gateway URL format yet, only @@ -230,7 +243,7 @@ mod tests { #[test] fn from_config_requires_api_base() { - let err = DatabricksProvider::from_config(None, None).unwrap_err(); + let err = DatabricksProvider::from_config(None, None, None).unwrap_err(); assert!(matches!(err, Error::InvalidRequest(_))); } } diff --git a/crates/braintrust-llm-router/src/providers/google.rs b/crates/braintrust-llm-router/src/providers/google.rs index c6be383c..3850e0a6 100644 --- a/crates/braintrust-llm-router/src/providers/google.rs +++ b/crates/braintrust-llm-router/src/providers/google.rs @@ -37,7 +37,13 @@ pub struct GoogleProvider { impl GoogleProvider { pub fn new(config: GoogleConfig) -> Result { - let mut settings = ClientSettings::default(); + Self::new_with_client_settings(config, ClientSettings::default()) + } + + pub fn new_with_client_settings( + config: GoogleConfig, + mut settings: ClientSettings, + ) -> Result { if let Some(timeout) = config.timeout { settings.request_timeout = timeout; } @@ -46,7 +52,11 @@ impl GoogleProvider { } /// Create a Google provider from configuration parameters. - pub fn from_config(endpoint: Option<&Url>, timeout: Option) -> Result { + pub fn from_config( + endpoint: Option<&Url>, + timeout: Option, + client_settings: Option, + ) -> Result { let mut config = GoogleConfig::default(); if let Some(ep) = endpoint { @@ -56,7 +66,7 @@ impl GoogleProvider { config.timeout = Some(t); } - Self::new(config) + Self::new_with_client_settings(config, client_settings.unwrap_or_default()) } fn generate_url(&self, model: &str, stream: bool) -> Result { diff --git a/crates/braintrust-llm-router/src/providers/mistral.rs b/crates/braintrust-llm-router/src/providers/mistral.rs index 5c9c6f89..d1038fd1 100644 --- a/crates/braintrust-llm-router/src/providers/mistral.rs +++ b/crates/braintrust-llm-router/src/providers/mistral.rs @@ -36,7 +36,13 @@ pub struct MistralProvider { impl MistralProvider { pub fn new(config: MistralConfig) -> Result { - let mut settings = ClientSettings::default(); + Self::new_with_client_settings(config, ClientSettings::default()) + } + + pub fn new_with_client_settings( + config: MistralConfig, + mut settings: ClientSettings, + ) -> Result { if let Some(timeout) = config.timeout { settings.request_timeout = timeout; } @@ -45,7 +51,11 @@ impl MistralProvider { } /// Create a Mistral provider from configuration parameters. - pub fn from_config(endpoint: Option<&Url>, timeout: Option) -> Result { + pub fn from_config( + endpoint: Option<&Url>, + timeout: Option, + client_settings: Option, + ) -> Result { let mut config = MistralConfig::default(); if let Some(ep) = endpoint { @@ -55,7 +65,7 @@ impl MistralProvider { config.timeout = Some(t); } - Self::new(config) + Self::new_with_client_settings(config, client_settings.unwrap_or_default()) } fn chat_url(&self) -> Result { diff --git a/crates/braintrust-llm-router/src/providers/openai.rs b/crates/braintrust-llm-router/src/providers/openai.rs index 6c8818c4..21c3fba2 100644 --- a/crates/braintrust-llm-router/src/providers/openai.rs +++ b/crates/braintrust-llm-router/src/providers/openai.rs @@ -47,7 +47,13 @@ pub struct OpenAIProvider { impl OpenAIProvider { pub fn new(config: OpenAIConfig) -> Result { - let mut settings = ClientSettings::default(); + Self::new_with_client_settings(config, ClientSettings::default()) + } + + pub fn new_with_client_settings( + config: OpenAIConfig, + mut settings: ClientSettings, + ) -> Result { if let Some(timeout) = config.timeout { settings.request_timeout = timeout; } @@ -71,6 +77,7 @@ impl OpenAIProvider { endpoint_template: Option<&str>, timeout: Option, metadata: &std::collections::HashMap, + client_settings: Option, ) -> Result { let mut config = OpenAIConfig::default(); @@ -101,7 +108,7 @@ impl OpenAIProvider { config.api_version = Some(version.to_string()); } - Self::new(config) + Self::new_with_client_settings(config, client_settings.unwrap_or_default()) } fn resolve_base(&self, model: Option<&str>) -> Result { diff --git a/crates/braintrust-llm-router/src/providers/vertex.rs b/crates/braintrust-llm-router/src/providers/vertex.rs index a484f5b2..a233ab7d 100644 --- a/crates/braintrust-llm-router/src/providers/vertex.rs +++ b/crates/braintrust-llm-router/src/providers/vertex.rs @@ -52,7 +52,13 @@ struct VertexModelExtra { impl VertexProvider { pub fn new(config: VertexConfig) -> Result { - let mut settings = ClientSettings::default(); + Self::new_with_client_settings(config, ClientSettings::default()) + } + + pub fn new_with_client_settings( + config: VertexConfig, + mut settings: ClientSettings, + ) -> Result { if let Some(timeout) = config.timeout { settings.request_timeout = timeout; } @@ -70,6 +76,7 @@ impl VertexProvider { endpoint: Option<&Url>, timeout: Option, metadata: &std::collections::HashMap, + client_settings: Option, ) -> Result { let project = metadata .get("project") @@ -108,7 +115,7 @@ impl VertexProvider { timeout, }; - Self::new(config) + Self::new_with_client_settings(config, client_settings.unwrap_or_default()) } fn resolve_location(&self, spec: &ModelSpec) -> String { @@ -611,7 +618,7 @@ mod tests { metadata.insert("location".into(), Value::String("us-east5".into())); let provider = - VertexProvider::from_config(Some(&global_endpoint), None, &metadata).unwrap(); + VertexProvider::from_config(Some(&global_endpoint), None, &metadata, None).unwrap(); assert!(provider.config.api_base.is_none()); assert_eq!(provider.config.location, "us-east5"); } @@ -624,7 +631,7 @@ mod tests { metadata.insert("location".into(), Value::String("us-east5".into())); let provider = - VertexProvider::from_config(Some(&custom_endpoint), None, &metadata).unwrap(); + VertexProvider::from_config(Some(&custom_endpoint), None, &metadata, None).unwrap(); assert_eq!( provider.config.api_base.as_deref(), Some("https://my-proxy.example.com/") @@ -640,7 +647,7 @@ mod tests { Value::String("https://custom-proxy.example.com".into()), ); - let provider = VertexProvider::from_config(None, None, &metadata).unwrap(); + let provider = VertexProvider::from_config(None, None, &metadata, None).unwrap(); assert_eq!( provider.config.api_base.as_deref(), Some("https://custom-proxy.example.com") @@ -653,7 +660,7 @@ mod tests { metadata.insert("project".into(), Value::String("my-project".into())); metadata.insert("api_base".into(), Value::String("".into())); - let provider = VertexProvider::from_config(None, None, &metadata).unwrap(); + let provider = VertexProvider::from_config(None, None, &metadata, None).unwrap(); assert!(provider.config.api_base.is_none()); } @@ -663,7 +670,7 @@ mod tests { metadata.insert("project".into(), Value::String("my-project".into())); metadata.insert("location".into(), Value::String(" ".into())); - let provider = VertexProvider::from_config(None, None, &metadata).unwrap(); + let provider = VertexProvider::from_config(None, None, &metadata, None).unwrap(); assert_eq!(provider.config.location, "global"); } } diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index f2a9fedc..a3516176 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -10,6 +10,7 @@ use bytes::Bytes; use crate::auth::AuthConfig; use crate::catalog::{load_catalog_from_disk, ModelCatalog, ModelResolver, ModelSpec}; +use crate::client::ClientSettings; use crate::error::{Error, Result}; use crate::providers::{ enable_streaming_payload, prepare_bedrock_request, requires_bedrock_request_preparation, @@ -42,6 +43,7 @@ use crate::providers::{ /// * `endpoint_template` - Endpoint template with `` placeholder (optional, OpenAI only) /// * `timeout` - Request timeout (optional) /// * `metadata` - Provider-specific options (organization_id, project, api_version, etc.) +/// * `client_settings` - HTTP client settings, including optional DNS pins. /// /// # Example /// @@ -50,7 +52,7 @@ use crate::providers::{ /// use std::collections::HashMap; /// /// let metadata = HashMap::new(); -/// let provider = create_provider("openai", None, None, None, &metadata)?; +/// let provider = create_provider("openai", None, None, None, &metadata, None)?; /// let router = Router::builder() /// .with_catalog(catalog) /// .add_provider("openai", provider) @@ -63,6 +65,7 @@ pub fn create_provider( endpoint_template: Option<&str>, timeout: Option, metadata: &HashMap, + client_settings: Option, ) -> Result> { match kind { "openai" => Ok(Arc::new(OpenAIProvider::from_config( @@ -70,29 +73,53 @@ pub fn create_provider( endpoint_template, timeout, metadata, + client_settings, )?)), "anthropic" => Ok(Arc::new(AnthropicProvider::from_config( - endpoint, timeout, metadata, + endpoint, + timeout, + metadata, + client_settings, )?)), "azure" => Ok(Arc::new(AzureProvider::from_config( - endpoint, timeout, metadata, + endpoint, + timeout, + metadata, + client_settings, + )?)), + "google" => Ok(Arc::new(GoogleProvider::from_config( + endpoint, + timeout, + client_settings, )?)), - "google" => Ok(Arc::new(GoogleProvider::from_config(endpoint, timeout)?)), "vertex" => Ok(Arc::new(VertexProvider::from_config( - endpoint, timeout, metadata, + endpoint, + timeout, + metadata, + client_settings, )?)), "bedrock" => Ok(Arc::new(BedrockProvider::from_config( - endpoint, timeout, metadata, + endpoint, + timeout, + metadata, + client_settings, )?)), "databricks" => Ok(Arc::new(DatabricksProvider::from_config( - endpoint, timeout, + endpoint, + timeout, + client_settings, + )?)), + "mistral" => Ok(Arc::new(MistralProvider::from_config( + endpoint, + timeout, + client_settings, )?)), - "mistral" => Ok(Arc::new(MistralProvider::from_config(endpoint, timeout)?)), kind if is_openai_compatible(kind) => Ok(Arc::new(OpenAIProvider::from_config( endpoint, endpoint_template, timeout, metadata, + client_settings, )?)), other => Err(Error::InvalidRequest(format!( "unsupported provider kind: {other}"