Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions crates/braintrust-llm-router/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::error::Error as StdError;
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::time::Duration;

use dashmap::DashMap;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use reqwest::{Client, ClientBuilder};
use reqwest::{redirect::Policy, Client, ClientBuilder};
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::{
default_on_request_failure, policies::ExponentialBackoff, RetryTransientMiddleware, Retryable,
Expand All @@ -17,6 +18,12 @@ use crate::error::{Error, Result};
// The default number of retries for transient transport failures.
const DEFAULT_MAX_RETRIES: u32 = 2;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct DnsOverride {
pub domain: String,
pub addrs: Vec<SocketAddr>,
}

// Shared reqwest clients are cached by these settings. Keep this key
// low-cardinality and effectively process-wide; request-scoped values do not
// belong here because they fragment client reuse and connection pooling.
Expand All @@ -27,6 +34,8 @@ pub struct ClientSettings {
pub pool_idle_timeout: Duration,
pub pool_max_idle_per_host: usize,
pub user_agent: String,
pub dns_overrides: Vec<DnsOverride>,
pub follow_redirects: bool,
}

impl Default for ClientSettings {
Expand All @@ -37,19 +46,29 @@ impl Default for ClientSettings {
pool_idle_timeout: Duration::from_secs(90),
pool_max_idle_per_host: 16,
user_agent: format!("braintrust-llm-router/{}", env!("CARGO_PKG_VERSION")),
dns_overrides: Vec::new(),
follow_redirects: true,
}
}
}

pub fn build_client(settings: &ClientSettings) -> Result<Client> {
ClientBuilder::new()
let mut builder = ClientBuilder::new()
.connect_timeout(settings.connect_timeout)
.timeout(settings.request_timeout)
.pool_idle_timeout(settings.pool_idle_timeout)
.pool_max_idle_per_host(settings.pool_max_idle_per_host)
.user_agent(&settings.user_agent)
.build()
.map_err(Error::from)
.user_agent(&settings.user_agent);

if !settings.follow_redirects {
builder = builder.redirect(Policy::none());
}

for override_entry in &settings.dns_overrides {
builder = builder.resolve_to_addrs(&override_entry.domain, &override_entry.addrs);
}

builder.build().map_err(Error::from)
}

pub fn build_middleware_client(settings: &ClientSettings) -> Result<ClientWithMiddleware> {
Expand Down Expand Up @@ -279,6 +298,27 @@ mod tests {
assert_eq!(cached_client_count(), 2);
}

#[test]
#[serial]
fn build_middleware_client_reuses_cached_client_for_same_dns_overrides() {
clear_override_client();
clear_cached_clients();

let settings = ClientSettings {
dns_overrides: vec![DnsOverride {
domain: "example.com".to_string(),
addrs: vec!["127.0.0.1:443".parse().expect("socket addr")],
}],
..ClientSettings::default()
};

let first = build_middleware_client(&settings).expect("first client");
let second = build_middleware_client(&settings).expect("second client");

assert_eq!(cached_client_count(), 1);
assert_eq!(format!("{first:?}"), format!("{second:?}"));
}

#[test]
fn default_request_timeout_is_600_seconds() {
assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion crates/braintrust-llm-router/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod auth;
mod catalog;
mod client;
pub use client::{clear_override_client, set_override_client};
pub use client::{clear_override_client, set_override_client, ClientSettings, DnsOverride};
pub use reqwest_middleware::ClientWithMiddleware;
mod error;
mod providers;
Expand Down
11 changes: 9 additions & 2 deletions crates/braintrust-llm-router/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ pub struct AnthropicProvider {

impl AnthropicProvider {
pub fn new(config: AnthropicConfig) -> Result<Self> {
let mut settings = ClientSettings::default();
Self::new_with_client_settings(config, ClientSettings::default())
}

pub fn new_with_client_settings(
config: AnthropicConfig,
mut settings: ClientSettings,
) -> Result<Self> {
if let Some(timeout) = config.timeout {
settings.request_timeout = timeout;
}
Expand All @@ -60,6 +66,7 @@ impl AnthropicProvider {
endpoint: Option<&Url>,
timeout: Option<Duration>,
metadata: &std::collections::HashMap<String, lingua::serde_json::Value>,
client_settings: Option<ClientSettings>,
) -> Result<Self> {
use lingua::serde_json::Value;
let mut config = AnthropicConfig::default();
Expand All @@ -74,7 +81,7 @@ impl AnthropicProvider {
config.version = version.to_string();
}

Self::new(config)
Self::new_with_client_settings(config, client_settings.unwrap_or_default())
}

fn messages_url(&self) -> Url {
Expand Down
13 changes: 10 additions & 3 deletions crates/braintrust-llm-router/src/providers/azure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ pub struct AzureProvider {

impl AzureProvider {
pub fn new(config: AzureConfig) -> Result<Self> {
let mut settings = ClientSettings::default();
Self::new_with_client_settings(config, ClientSettings::default())
}

pub fn new_with_client_settings(
config: AzureConfig,
mut settings: ClientSettings,
) -> Result<Self> {
if let Some(timeout) = config.timeout {
settings.request_timeout = timeout;
}
Expand All @@ -64,6 +70,7 @@ impl AzureProvider {
endpoint: Option<&Url>,
timeout: Option<Duration>,
metadata: &std::collections::HashMap<String, MetadataValue>,
client_settings: Option<ClientSettings>,
) -> Result<Self> {
let endpoint = endpoint
.cloned()
Expand Down Expand Up @@ -94,7 +101,7 @@ impl AzureProvider {
config.no_named_deployment = no_named;
}

Self::new(config)
Self::new_with_client_settings(config, client_settings.unwrap_or_default())
}

fn deployment_for_request(&self, model: &str) -> Result<String> {
Expand Down Expand Up @@ -368,7 +375,7 @@ mod tests {
}

fn make_provider(metadata: HashMap<String, MetadataValue>) -> AzureProvider {
AzureProvider::from_config(Some(&endpoint()), None, &metadata).unwrap()
AzureProvider::from_config(Some(&endpoint()), None, &metadata, None).unwrap()
}

#[test]
Expand Down
11 changes: 9 additions & 2 deletions crates/braintrust-llm-router/src/providers/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,13 @@ pub struct BedrockProvider {

impl BedrockProvider {
pub fn new(config: BedrockConfig) -> Result<Self> {
let mut settings = ClientSettings::default();
Self::new_with_client_settings(config, ClientSettings::default())
}

pub fn new_with_client_settings(
config: BedrockConfig,
mut settings: ClientSettings,
) -> Result<Self> {
if let Some(timeout) = config.timeout {
settings.request_timeout = timeout;
}
Expand All @@ -202,6 +208,7 @@ impl BedrockProvider {
endpoint: Option<&Url>,
timeout: Option<Duration>,
metadata: &std::collections::HashMap<String, Value>,
client_settings: Option<ClientSettings>,
) -> Result<Self> {
let mut config = BedrockConfig::default();

Expand All @@ -221,7 +228,7 @@ impl BedrockProvider {
config.timeout = Some(t);
}

Self::new(config)
Self::new_with_client_settings(config, client_settings.unwrap_or_default())
}

fn converse_url(&self, model: &str, stream: bool) -> Result<Url> {
Expand Down
21 changes: 17 additions & 4 deletions crates/braintrust-llm-router/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,32 @@ pub struct DatabricksProvider {

impl DatabricksProvider {
pub fn new(config: DatabricksConfig) -> Result<Self> {
let mut settings = ClientSettings::default();
Self::new_with_client_settings(config, ClientSettings::default())
}

pub fn new_with_client_settings(
config: DatabricksConfig,
mut settings: ClientSettings,
) -> Result<Self> {
if let Some(timeout) = config.timeout {
settings.request_timeout = timeout;
}
let client = build_middleware_client(&settings)?;
Ok(Self { client, config })
}

pub fn from_config(api_base: Option<&Url>, timeout: Option<Duration>) -> Result<Self> {
pub fn from_config(
api_base: Option<&Url>,
timeout: Option<Duration>,
client_settings: Option<ClientSettings>,
) -> Result<Self> {
let api_base = api_base
.cloned()
.ok_or_else(|| Error::InvalidRequest("databricks provider requires api_base".into()))?;
Self::new(DatabricksConfig { api_base, timeout })
Self::new_with_client_settings(
DatabricksConfig { api_base, timeout },
client_settings.unwrap_or_default(),
)
}

// This does not support Databrick's new AI gateway URL format yet, only
Expand Down Expand Up @@ -230,7 +243,7 @@ mod tests {

#[test]
fn from_config_requires_api_base() {
let err = DatabricksProvider::from_config(None, None).unwrap_err();
let err = DatabricksProvider::from_config(None, None, None).unwrap_err();
assert!(matches!(err, Error::InvalidRequest(_)));
}
}
16 changes: 13 additions & 3 deletions crates/braintrust-llm-router/src/providers/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ pub struct GoogleProvider {

impl GoogleProvider {
pub fn new(config: GoogleConfig) -> Result<Self> {
let mut settings = ClientSettings::default();
Self::new_with_client_settings(config, ClientSettings::default())
}

pub fn new_with_client_settings(
config: GoogleConfig,
mut settings: ClientSettings,
) -> Result<Self> {
if let Some(timeout) = config.timeout {
settings.request_timeout = timeout;
}
Expand All @@ -46,7 +52,11 @@ impl GoogleProvider {
}

/// Create a Google provider from configuration parameters.
pub fn from_config(endpoint: Option<&Url>, timeout: Option<Duration>) -> Result<Self> {
pub fn from_config(
endpoint: Option<&Url>,
timeout: Option<Duration>,
client_settings: Option<ClientSettings>,
) -> Result<Self> {
let mut config = GoogleConfig::default();

if let Some(ep) = endpoint {
Expand All @@ -56,7 +66,7 @@ impl GoogleProvider {
config.timeout = Some(t);
}

Self::new(config)
Self::new_with_client_settings(config, client_settings.unwrap_or_default())
}

fn generate_url(&self, model: &str, stream: bool) -> Result<Url> {
Expand Down
16 changes: 13 additions & 3 deletions crates/braintrust-llm-router/src/providers/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ pub struct MistralProvider {

impl MistralProvider {
pub fn new(config: MistralConfig) -> Result<Self> {
let mut settings = ClientSettings::default();
Self::new_with_client_settings(config, ClientSettings::default())
}

pub fn new_with_client_settings(
config: MistralConfig,
mut settings: ClientSettings,
) -> Result<Self> {
if let Some(timeout) = config.timeout {
settings.request_timeout = timeout;
}
Expand All @@ -45,7 +51,11 @@ impl MistralProvider {
}

/// Create a Mistral provider from configuration parameters.
pub fn from_config(endpoint: Option<&Url>, timeout: Option<Duration>) -> Result<Self> {
pub fn from_config(
endpoint: Option<&Url>,
timeout: Option<Duration>,
client_settings: Option<ClientSettings>,
) -> Result<Self> {
let mut config = MistralConfig::default();

if let Some(ep) = endpoint {
Expand All @@ -55,7 +65,7 @@ impl MistralProvider {
config.timeout = Some(t);
}

Self::new(config)
Self::new_with_client_settings(config, client_settings.unwrap_or_default())
}

fn chat_url(&self) -> Result<Url> {
Expand Down
11 changes: 9 additions & 2 deletions crates/braintrust-llm-router/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@ pub struct OpenAIProvider {

impl OpenAIProvider {
pub fn new(config: OpenAIConfig) -> Result<Self> {
let mut settings = ClientSettings::default();
Self::new_with_client_settings(config, ClientSettings::default())
}

pub fn new_with_client_settings(
config: OpenAIConfig,
mut settings: ClientSettings,
) -> Result<Self> {
if let Some(timeout) = config.timeout {
settings.request_timeout = timeout;
}
Expand All @@ -71,6 +77,7 @@ impl OpenAIProvider {
endpoint_template: Option<&str>,
timeout: Option<Duration>,
metadata: &std::collections::HashMap<String, Value>,
client_settings: Option<ClientSettings>,
) -> Result<Self> {
let mut config = OpenAIConfig::default();

Expand Down Expand Up @@ -101,7 +108,7 @@ impl OpenAIProvider {
config.api_version = Some(version.to_string());
}

Self::new(config)
Self::new_with_client_settings(config, client_settings.unwrap_or_default())
}

fn resolve_base(&self, model: Option<&str>) -> Result<Url> {
Expand Down
Loading
Loading