Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use tokio::runtime::Runtime;
fn main() {
let config = client::Config::builder()
.provider("deepinfra")
.routing_mode(client::RoutingMode::WRR)
.routing_mode(client::RouterMode::WRR)
.model(
client::ModelConfig::builder()
.name("deepseek-ai/DeepSeek-V3.2")
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/amrs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from amrs.config import ModelConfig, Config, RoutingMode
from amrs.config import ModelConfig, Config, RouterMode

__all__ = [
"ModelConfig",
"Config",
"RoutingMode",
"RouterMode",
]
6 changes: 3 additions & 3 deletions bindings/python/amrs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ class Message(BaseModel):
content: str = Field(description="Content of the message.")


class RoutingMode(str, Enum):
class RouterMode(str, Enum):
RANDOM = "random"
WEIGHTED = "weighted"


class Config(BasicModelConfig):
models: List[ModelConfig] = Field(description="List of model configurations")
routing_mode: RoutingMode = Field(
default=RoutingMode.RANDOM,
routing_mode: RouterMode = Field(
default=RouterMode.RANDOM,
description="Routing mode for the model, default is random.",
)
callback_funcs: Optional[List[Callable]] = Field(
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/amrs/router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def __init__(self, model_list: list[config.ModelName]):
def sample(self, content: str) -> config.ModelName:
pass

def new_router(model_cfgs: list[config.ModelConfig], mode: config.RoutingMode) -> Router:
def new_router(model_cfgs: list[config.ModelConfig], mode: config.RouterMode) -> Router:
model_list = [f"{model_cfg.provider}/{model_cfg.id}" for model_cfg in model_cfgs]

if mode == config.RoutingMode.RANDOM:
if mode == config.RouterMode.RANDOM:
from amrs.router.random import RandomRouter
return RandomRouter(model_list)
else:
Expand Down
2 changes: 1 addition & 1 deletion examples/wrr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tokio::runtime::Runtime;
fn main() {
let config = client::Config::builder()
.provider("deepinfra")
.routing_mode(client::RoutingMode::WRR)
.routing_mode(client::RouterMode::WRR)
.model(
client::ModelConfig::builder()
.name("deepseek-ai/DeepSeek-V3.2")
Expand Down
4 changes: 2 additions & 2 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl Client {
#[cfg(test)]
mod tests {
use super::*;
use crate::client::config::{Config, ModelConfig, RoutingMode};
use crate::client::config::{Config, ModelConfig, RouterMode};
use dotenvy::from_filename;

#[test]
Expand Down Expand Up @@ -81,7 +81,7 @@ mod tests {
TestCase {
name: "weighted round-robin router",
config: Config::builder()
.routing_mode(RoutingMode::WRR)
.routing_mode(RouterMode::WRR)
.models(vec![
crate::client::config::ModelConfig::builder()
.name("model_a".to_string())
Expand Down
10 changes: 5 additions & 5 deletions src/client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ lazy_static! {

// ------------------ Routing Mode ------------------
#[derive(Debug, Clone, PartialEq)]
pub enum RoutingMode {
pub enum RouterMode {
Random,
WRR, // WeightedRoundRobin,
}
Expand Down Expand Up @@ -83,8 +83,8 @@ pub struct Config {
#[builder(default = "DEFAULT_PROVIDER.to_string()", setter(custom))]
pub(crate) provider: String,

#[builder(default = "RoutingMode::Random")]
pub(crate) routing_mode: RoutingMode,
#[builder(default = "RouterMode::Random")]
pub(crate) routing_mode: RouterMode,
#[builder(default = "vec![]")]
pub(crate) models: Vec<ModelConfig>,
}
Expand Down Expand Up @@ -152,7 +152,7 @@ impl ConfigBuilder {

for model in self.models.as_ref().unwrap() {
if self.routing_mode.is_some()
&& self.routing_mode.as_ref().unwrap() == &RoutingMode::WRR
&& self.routing_mode.as_ref().unwrap() == &RouterMode::WRR
&& model.weight <= 0
{
return Err(format!(
Expand Down Expand Up @@ -218,7 +218,7 @@ mod tests {
assert!(valid_simplest_models_cfg.is_ok());
assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == DEFAULT_PROVIDER);
assert!(valid_simplest_models_cfg.as_ref().unwrap().base_url == None);
assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RoutingMode::Random);
assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RouterMode::Random);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models.len() == 1);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].base_url == None);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].provider == None);
Expand Down
2 changes: 1 addition & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ pub mod client;
pub mod config;

pub use client::Client;
pub use config::{Config, ModelConfig, ModelName, RoutingMode};
pub use config::{Config, ModelConfig, ModelName, RouterMode};
12 changes: 6 additions & 6 deletions src/router/router.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::client::config::{ModelConfig, ModelName, RoutingMode};
use crate::client::config::{ModelConfig, ModelName, RouterMode};
use crate::router::random::RandomRouter;
use crate::router::wrr::WeightedRoundRobinRouter;

Expand All @@ -8,7 +8,7 @@ pub struct ModelInfo {
pub weight: i32,
}

pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn Router> {
pub fn construct_router(mode: RouterMode, models: Vec<ModelConfig>) -> Box<dyn Router> {
let model_infos: Vec<ModelInfo> = models
.iter()
.map(|m| ModelInfo {
Expand All @@ -17,8 +17,8 @@ pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn
})
.collect();
match mode {
RoutingMode::Random => Box::new(RandomRouter::new(model_infos)),
RoutingMode::WRR => Box::new(WeightedRoundRobinRouter::new(model_infos)),
RouterMode::Random => Box::new(RandomRouter::new(model_infos)),
RouterMode::WRR => Box::new(WeightedRoundRobinRouter::new(model_infos)),
}
}

Expand Down Expand Up @@ -47,10 +47,10 @@ mod tests {
.unwrap(),
];

let random_router = construct_router(RoutingMode::Random, model_configs.clone());
let random_router = construct_router(RouterMode::Random, model_configs.clone());
assert_eq!(random_router.name(), "RandomRouter");

let weighted_router = construct_router(RoutingMode::WRR, model_configs.clone());
let weighted_router = construct_router(RouterMode::WRR, model_configs.clone());
assert_eq!(weighted_router.name(), "WeightedRoundRobinRouter");
}
}
2 changes: 1 addition & 1 deletion tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mod tests {
// case 3: multiple models with router.
let config = client::Config::builder()
.provider("faker")
.routing_mode(client::RoutingMode::WRR)
.routing_mode(client::RouterMode::WRR)
.model(
client::ModelConfig::builder()
.name("gpt-3.5-turbo")
Expand Down
Loading