diff --git a/Cargo.lock b/Cargo.lock index 8e6d5b6..3e10361 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,6 +106,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + [[package]] name = "fnv" version = "1.0.7" @@ -376,6 +382,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "tempfile", "thiserror 2.0.18", "tokio", "tracing", @@ -401,6 +408,12 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + [[package]] name = "lock_api" version = "0.4.14" @@ -573,6 +586,19 @@ dependencies = [ "bitflags", ] +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -734,6 +760,19 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index c6b794f..e5a5c52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,5 +18,5 @@ tracing = "0.1.44" tracing-subscriber = {version="0.3.23", features = ["json"]} uuid = {version="1.23.2",features=["v4"]} - - +[dev-dependencies] +tempfile = "3.27.0" diff --git a/ROADMAP.md b/ROADMAP.md index bba7ab2..0799373 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -11,7 +11,7 @@ - [x] Define configuration format - [x] Create runtime AppState - [x] Define backend configuration model -- [ ] Setup graceful shutdown handling +- [x] Setup graceful shutdown handling --- @@ -140,8 +140,8 @@ ## Weighted Balancing -- [ ] Weighted Round Robin -- [ ] Backend weights in config +- [x] Weighted Round Robin +- [x] Backend weights in config - [ ] Dynamic weight updates --- @@ -180,16 +180,16 @@ - [x] Add draining backend state - [x] Stop routing new connections - [x] Wait for active connections -- [ ] Graceful backend removal +- [x] Graceful backend removal --- ## Dynamic Config Reloading -- [ ] Watch configuration file -- [ ] Reload backend configuration -- [ ] Preserve active connections -- [ ] Runtime backend updates +- [x] Watch configuration file (manual reload api semantics implemented) +- [x] Reload backend configuration +- [x] Preserve active connections +- [x] Runtime backend updates --- diff --git a/src/admin/http.rs b/src/admin/http.rs index 86aff08..8c07354 100644 --- a/src/admin/http.rs +++ b/src/admin/http.rs @@ -1,4 +1,6 @@ -use crate::{metrics::registry::gather_metrics, state::app::SharedAppState}; +use crate::{ + admin::reload::reload_config, metrics::registry::gather_metrics, state::app::SharedAppState, +}; use axum::{ Json, Router, extract::{Path, State}, @@ -31,6 +33,31 @@ struct MetricsResponse { upstreams: Vec, } +#[derive(Serialize)] +struct BackendStatus { + id: String, + healthy: bool, + draining: bool, + weight: usize, + active_connections: usize, + total_requests: usize, + failed_requests: usize, +} + +#[derive(Serialize)] +struct UpstreamStatus { + id: String, + algorithm: String, + backend_count: usize, + weighted_backend_count: usize, + backends: Vec, +} + +#[derive(Serialize)] +struct StatusResponse { + upstreams: Vec, +} + async fn prometheus_handler() -> String { gather_metrics() } @@ -76,7 +103,6 @@ async fn drain_backend_handler( for backend in &upstream.backends { if backend.config.id == id { backend.mark_draining(); - tracing::info!( backend_id = %id, "backend marked as draining" @@ -90,11 +116,62 @@ async fn drain_backend_handler( format!("backend '{id}' not found") } +async fn reload_handler(State(state): State) -> String { + match reload_config(state).await { + Ok(_) => "config reloaded".into(), + + Err(error) => { + tracing::error!( + error = %error, + "config reload failed" + ); + + format!("reload failed: {error}") + } + } +} + +async fn status_handler(State(state): State) -> Json { + let state = state.read().await; + + let upstreams = state + .upstreams + .iter() + .map(|upstream| { + let backends = upstream + .backends + .iter() + .map(|backend| BackendStatus { + id: backend.config.id.clone(), + healthy: backend.healthy.load(Ordering::Relaxed), + draining: backend.draining.load(Ordering::Relaxed), + weight: backend.config.weight, + active_connections: backend.active_connections.load(Ordering::Relaxed), + total_requests: backend.total_requests.load(Ordering::Relaxed), + failed_requests: backend.failed_requests.load(Ordering::Relaxed), + }) + .collect(); + + UpstreamStatus { + id: upstream.id.clone(), + algorithm: format!("{:?}", upstream.algorithm), + backend_count: upstream.backends.len(), + weighted_backend_count: upstream.weighted_backends.len(), + backends, + } + }) + .collect(); + + Json(StatusResponse { upstreams }) +} + pub async fn start_admin_server(address: &str, state: SharedAppState) -> anyhow::Result<()> { let app = Router::new() .route("/metrics", get(metrics_handler)) .route("/backend/{id}/drain", post(drain_backend_handler)) .route("/prometheus", get(prometheus_handler)) + .route("/reload", post(reload_handler)) + .route("/status", get(status_handler)) .with_state(state); let listener = TcpListener::bind(address).await?; axum::serve(listener, app).await?; diff --git a/src/admin/mod.rs b/src/admin/mod.rs index 3883215..2b3459e 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -1 +1,2 @@ pub mod http; +pub mod reload; diff --git a/src/admin/reload.rs b/src/admin/reload.rs new file mode 100644 index 0000000..97b57e5 --- /dev/null +++ b/src/admin/reload.rs @@ -0,0 +1,89 @@ +use std::sync::{Arc, atomic::AtomicUsize}; + +use anyhow::Result; + +use crate::{ + config::{loader::load_config, validator::validate_config}, + state::{ + app::{SharedAppState, UpstreamPool}, + backend::BackendState, + }, +}; + +pub async fn reload_config(state: SharedAppState) -> Result<()> { + let config_path = { + let state = state.read().await; + state.config_path.clone() + }; + + let config = load_config(&config_path)?; + validate_config(&config)?; + let mut state = state.write().await; + + for new_upstream in config.upstreams { + let existing_upstream = state.upstreams.iter_mut().find(|u| u.id == new_upstream.id); + match existing_upstream { + Some(upstream) => { + for server in &new_upstream.servers { + let exists = upstream.backends.iter().any(|b| b.config.id == server.id); + + if !exists { + tracing::info!( + backend_id = %server.id, + "adding new backend" + ); + upstream.backends.push(Arc::new(BackendState::new(server.clone()))); + } + } + + for backend in &upstream.backends { + let still_exists = + new_upstream.servers.iter().any(|s| s.id == backend.config.id); + + if !still_exists { + backend.mark_draining(); + tracing::info!( + backend_id = + %backend.config.id, + "backend marked draining during reload" + ); + } + } + upstream.rebuild_weighted_backends(); + } + + None => { + tracing::info!( + upstream_id = %new_upstream.id, + "adding new upstream" + ); + + let backends = new_upstream + .servers + .into_iter() + .map(|server| Arc::new(BackendState::new(server))) + .collect(); + + let mut upstream_pool = UpstreamPool { + id: new_upstream.id, + + current_index: AtomicUsize::new(0), + + algorithm: new_upstream.algorithm, + + backends, + + weighted_backends: Vec::new(), + }; + + upstream_pool.rebuild_weighted_backends(); + + state.upstreams.push(upstream_pool); + } + } + } + + tracing::info!("runtime config reloaded"); + + Ok(()) +} diff --git a/src/algorithms/least_connections.rs b/src/algorithms/least_connections.rs index 7221d80..2cedbeb 100644 --- a/src/algorithms/least_connections.rs +++ b/src/algorithms/least_connections.rs @@ -8,7 +8,8 @@ pub fn select_backend(backends: &[Arc]) -> Option], + current_index: &AtomicUsize, +) -> Option> { + let routable = weighted_backends + .iter() + .filter(|backend| backend.is_routable()) + .cloned() + .collect::>(); + + if routable.is_empty() { + return None; + } + + let index = current_index.fetch_add(1, Ordering::Relaxed); + + Some(routable[index % routable.len()].clone()) +} diff --git a/src/config/types.rs b/src/config/types.rs index f7cde8d..d3c53df 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -34,7 +34,7 @@ pub struct LoadBalancerConfig { // Static backend server definition loaded from configuration. // This only contains immutable backend metadata. // Live runtime information is tracked separately in "BackendState". -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Clone)] pub struct BackendServerConfig { pub id: String, pub host: String, diff --git a/src/health/tcp.rs b/src/health/tcp.rs index 20942f8..3ba6470 100644 --- a/src/health/tcp.rs +++ b/src/health/tcp.rs @@ -1,4 +1,4 @@ -use std::{sync::atomic::Ordering, time::Duration}; +use std::{collections::HashSet, sync::atomic::Ordering, time::Duration}; use crate::state::{app::SharedAppState, backend::BackendState}; use anyhow::Result; @@ -40,10 +40,27 @@ pub async fn start_health_checker(state: SharedAppState, interval_secs: u64) { .flat_map(|upstream| upstream.backends.clone()) .collect::>() }; + let mut removable_backend_ids = HashSet::new(); + for backend in backends { let _ = check_backend_status(&backend).await; if backend.is_draining() && backend.active_connections.load(Ordering::Relaxed) == 0 { info!(backend_id =%backend.config.id,"backend safe to remove"); + removable_backend_ids.insert(backend.config.id.clone()); + } + } + + if !removable_backend_ids.is_empty() { + let mut state = state.write().await; + + for upstream in &mut state.upstreams { + upstream.backends.retain(|backend| { + let should_remove = removable_backend_ids.contains(&backend.config.id); + if should_remove { + info!(backend_id =%backend.config.id,"backend removed from runtime"); + } + !should_remove + }); } } diff --git a/src/main.rs b/src/main.rs index 59828bc..99a16b9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,7 +31,7 @@ async fn main() -> Result<()> { let health_interval = config.load_balancer.health_check_interval_secs; - let state = AppState::build(config); + let state = AppState::build(config, path.clone()); info!("initialized {} upstream pools", state.upstreams.len()); if state.upstreams.is_empty() { bail!("no upstreams configured"); diff --git a/src/metrics/registry.rs b/src/metrics/registry.rs index 9e711ae..a0cb3fd 100644 --- a/src/metrics/registry.rs +++ b/src/metrics/registry.rs @@ -1,10 +1,17 @@ -use prometheus::{Encoder, IntCounterVec, IntGaugeVec, Registry, TextEncoder}; +use prometheus::{ + Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts, Registry, TextEncoder, +}; use std::sync::OnceLock; pub static REGISTRY: OnceLock = OnceLock::new(); pub static TOTAL_REQUESTS: OnceLock = OnceLock::new(); pub static FAILED_REQUESTS: OnceLock = OnceLock::new(); pub static ACTIVE_CONNECTIONS: OnceLock = OnceLock::new(); +pub static REQUEST_DURATION: OnceLock = OnceLock::new(); +pub static BACKEND_CONNECT_DURATION: OnceLock = OnceLock::new(); +pub static BYTES_IN: OnceLock = OnceLock::new(); +pub static BYTES_OUT: OnceLock = OnceLock::new(); + pub fn initialize_metrics() { let registry = Registry::new(); @@ -26,14 +33,46 @@ pub fn initialize_metrics() { ) .unwrap(); + let request_duration = HistogramVec::new( + HistogramOpts::new("laminar_request_duration_seconds", "Request duration in seconds"), + &["backend"], + ) + .unwrap(); + + let backend_connect_duration = HistogramVec::new( + HistogramOpts::new("laminar_backend_connect_duration_seconds", "Backend connect duration"), + &["backend"], + ) + .unwrap(); + + let bytes_in = IntCounterVec::new( + Opts::new("laminar_bytes_in_total", "Total inbound bytes"), + &["backend"], + ) + .unwrap(); + + let bytes_out = IntCounterVec::new( + Opts::new("laminar_bytes_out_total", "Total outbound bytes"), + &["backend"], + ) + .unwrap(); + registry.register(Box::new(total_requests.clone())).unwrap(); registry.register(Box::new(failed_requests.clone())).unwrap(); registry.register(Box::new(active_connections.clone())).unwrap(); + registry.register(Box::new(request_duration.clone())).unwrap(); + registry.register(Box::new(backend_connect_duration.clone())).unwrap(); + registry.register(Box::new(bytes_in.clone())).unwrap(); + registry.register(Box::new(bytes_out.clone())).unwrap(); REGISTRY.set(registry).unwrap(); TOTAL_REQUESTS.set(total_requests).unwrap(); FAILED_REQUESTS.set(failed_requests).unwrap(); ACTIVE_CONNECTIONS.set(active_connections).unwrap(); + REQUEST_DURATION.set(request_duration).unwrap(); + BACKEND_CONNECT_DURATION.set(backend_connect_duration).unwrap(); + BYTES_IN.set(bytes_in).unwrap(); + BYTES_OUT.set(bytes_out).unwrap(); } pub fn gather_metrics() -> String { diff --git a/src/proxy/tcp.rs b/src/proxy/tcp.rs index 785cb24..23eef57 100644 --- a/src/proxy/tcp.rs +++ b/src/proxy/tcp.rs @@ -1,7 +1,11 @@ use crate::common::shutdown::shutdown_signal; -use crate::metrics::registry::{FAILED_REQUESTS, TOTAL_REQUESTS}; +use crate::metrics::registry::{ + BACKEND_CONNECT_DURATION, BYTES_IN, BYTES_OUT, FAILED_REQUESTS, REQUEST_DURATION, + TOTAL_REQUESTS, +}; use crate::state::app::SharedAppState; use crate::state::backend::ConnectionGuard; +use std::time::Instant; use std::{collections::HashSet, time::Duration}; use tokio::{ io::copy_bidirectional, @@ -49,6 +53,7 @@ pub async fn start_tcp_proxy(address: &str, state: SharedAppState) -> anyhow::Re pub async fn handle_connection(mut stream: TcpStream, state: SharedAppState) -> anyhow::Result<()> { let request_id = Uuid::new_v4(); + let request_start = Instant::now(); let (retry_attempt, connect_timeout, idle_timeout) = { let state = state.read().await; (state.retry_attempts, state.connect_timeout, state.idle_timeout) @@ -65,11 +70,12 @@ pub async fn handle_connection(mut stream: TcpStream, state: SharedAppState) -> // if connection fails: // mark that backend unhealthy so future selections skip it for _ in 0..retry_attempt { - let backend_arc = { + let (backend_arc, algorithm) = { let state = state.read().await; let upstream = &state.upstreams[0]; + let algorithm = upstream.algorithm.clone(); match upstream.next_backend() { - Some(backend) => backend, + Some(backend) => (backend, algorithm), None => { error!( request_id = %request_id, @@ -91,10 +97,19 @@ pub async fn handle_connection(mut stream: TcpStream, state: SharedAppState) -> request_id = %request_id, backend_id = %guard.backend_id(), backend = %backend_address, + algorithm = ?algorithm, "proxy connection started" ); - match proxy_connection(&mut stream, &backend_address, connect_timeout, idle_timeout).await { + match proxy_connection( + &mut stream, + guard.backend_id(), + &backend_address, + connect_timeout, + idle_timeout, + ) + .await + { Ok(_) => { info!( request_id = %request_id, @@ -105,6 +120,13 @@ pub async fn handle_connection(mut stream: TcpStream, state: SharedAppState) -> metrics.with_label_values(&[guard.backend_id()]).inc(); } guard.backend().increment_total_requests(); + + if let Some(histogram) = REQUEST_DURATION.get() { + histogram + .with_label_values(&[guard.backend_id()]) + .observe(request_start.elapsed().as_secs_f64()); + } + return Ok(()); } Err(error) => { @@ -135,15 +157,31 @@ pub async fn handle_connection(mut stream: TcpStream, state: SharedAppState) -> async fn proxy_connection( client_stream: &mut TcpStream, + backend_id: &str, backend_address: &str, connect_timeout: Duration, idle_timeout: Duration, ) -> anyhow::Result<()> { + let connect_start = Instant::now(); let mut backend_stream = timeout(connect_timeout, TcpStream::connect(backend_address)).await??; + if let Some(histogram) = BACKEND_CONNECT_DURATION.get() { + histogram.with_label_values(&[backend_id]).observe(connect_start.elapsed().as_secs_f64()); + } + match timeout(idle_timeout, copy_bidirectional(client_stream, &mut backend_stream)).await { - Ok(Ok(_)) => Ok(()), + Ok(Ok((from_client, from_backend))) => { + if let Some(counter) = BYTES_IN.get() { + counter.with_label_values(&[backend_id]).inc_by(from_client); + } + + if let Some(counter) = BYTES_OUT.get() { + counter.with_label_values(&[backend_id]).inc_by(from_backend); + } + + Ok(()) + } Ok(Err(error)) => { anyhow::bail!("proxy IO error: {error}"); diff --git a/src/state/app.rs b/src/state/app.rs index 62c3a6c..76263eb 100644 --- a/src/state/app.rs +++ b/src/state/app.rs @@ -1,4 +1,4 @@ -use crate::algorithms::{least_connections, round_robin}; +use crate::algorithms::{least_connections, round_robin, weighted_round_robin}; use crate::config::LoadBalancingAlgorithm; use crate::{config::types::Config, state::backend::BackendState}; use std::sync::Arc; @@ -12,6 +12,7 @@ pub struct UpstreamPool { pub current_index: AtomicUsize, pub algorithm: LoadBalancingAlgorithm, pub backends: Vec>, + pub weighted_backends: Vec>, } impl UpstreamPool { @@ -24,11 +25,24 @@ impl UpstreamPool { LoadBalancingAlgorithm::LeastConnections => { least_connections::select_backend(&self.backends) } + LoadBalancingAlgorithm::WeightedRoundRobin => { + weighted_round_robin::select_backend(&self.weighted_backends, &self.current_index) + } _ => { unimplemented!("algorithm not implemented yet") } } } + + pub fn rebuild_weighted_backends(&mut self) { + self.weighted_backends.clear(); + + for backend in &self.backends { + for _ in 0..backend.config.weight { + self.weighted_backends.push(backend.clone()); + } + } + } } // Central shared runtime state for the entire load balancer. // Most subsystems eventually interact with this: @@ -43,12 +57,13 @@ pub struct AppState { pub upstreams: Vec, pub connect_timeout: Duration, pub idle_timeout: Duration, + pub config_path: String, } pub type SharedAppState = Arc>; impl AppState { - pub fn build(config: Config) -> Self { + pub fn build(config: Config, config_path: String) -> Self { // config.upstreams is a grouped collection of upstreams // each upstream has an id, algorithm and servers( yes group of servers) // each server has id, host, port, weight @@ -63,13 +78,16 @@ impl AppState { let backends = upstream.servers.into_iter().map(|s| Arc::new(BackendState::new(s))).collect(); - UpstreamPool { + let mut upstream_pool = UpstreamPool { id: upstream.id, current_index: AtomicUsize::new(0), algorithm: upstream.algorithm, + backends, + weighted_backends: Vec::new(), + }; - backends, // all backends belonging to a single upstream type ( single logical service) - } + upstream_pool.rebuild_weighted_backends(); + upstream_pool }) .collect(); @@ -78,6 +96,7 @@ impl AppState { retry_attempts: config.load_balancer.retry_attempts, connect_timeout: Duration::from_secs(config.load_balancer.connect_timeout_secs), idle_timeout: Duration::from_secs(config.load_balancer.idle_timeout_secs), + config_path, } } } diff --git a/src/state/backend.rs b/src/state/backend.rs index e58a5e4..9143ae0 100644 --- a/src/state/backend.rs +++ b/src/state/backend.rs @@ -103,4 +103,8 @@ impl BackendState { pub fn is_healthy(&self) -> bool { self.healthy.load(Ordering::Relaxed) } + + pub fn is_routable(&self) -> bool { + self.healthy.load(Ordering::Relaxed) && !self.draining.load(Ordering::Relaxed) + } } diff --git a/tests/health_selection.rs b/tests/health_selection.rs index f08b830..ef8a0b0 100644 --- a/tests/health_selection.rs +++ b/tests/health_selection.rs @@ -33,6 +33,7 @@ fn unhealthy_backend_is_skipped() { create_backend("dead", 9001, false).into(), create_backend("healthy", 9002, true).into(), ], + weighted_backends: Vec::new(), }; let backend = upstream.next_backend().unwrap(); @@ -50,6 +51,7 @@ fn returns_none_when_all_backends_dead() { create_backend("dead-1", 9001, false).into(), create_backend("dead-2", 9002, false).into(), ], + weighted_backends: Vec::new(), }; let backend = upstream.next_backend(); diff --git a/tests/reload_runtime.rs b/tests/reload_runtime.rs new file mode 100644 index 0000000..a792626 --- /dev/null +++ b/tests/reload_runtime.rs @@ -0,0 +1,279 @@ +use std::{fs, sync::Arc}; + +use tempfile::NamedTempFile; + +use tokio::sync::RwLock; + +use laminar::{admin::reload::reload_config, config::loader::load_config, state::app::AppState}; + +#[tokio::test] +async fn reload_adds_new_backend() { + let initial_config = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 2 + sticky_sessions: false + health_check_interval_secs: 5 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-1" + host: "127.0.0.1" + port: 9001 + weight: 1 +"#; + + let updated_config = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 2 + sticky_sessions: false + health_check_interval_secs: 5 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-1" + host: "127.0.0.1" + port: 9001 + weight: 1 + + - id: "server-2" + host: "127.0.0.1" + port: 9002 + weight: 1 +"#; + + let temp_file = NamedTempFile::new().unwrap(); + fs::write(temp_file.path(), initial_config).unwrap(); + + let config = load_config(temp_file.path().to_str().unwrap()).unwrap(); + let state = Arc::new(RwLock::new(AppState::build( + config, + temp_file.path().to_str().unwrap().to_string(), + ))); + + { + let state = state.read().await; + assert_eq!(state.upstreams[0].backends.len(), 1); + } + + fs::write(temp_file.path(), updated_config).unwrap(); + reload_config(state.clone()).await.unwrap(); + + { + let state = state.read().await; + assert_eq!(state.upstreams[0].backends.len(), 2); + assert!(state.upstreams[0].backends.iter().any(|b| { b.config.id == "server-2" })); + } +} + +#[tokio::test] +async fn reload_marks_removed_backend_draining() { + let initial_config = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 2 + sticky_sessions: false + health_check_interval_secs: 5 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-1" + host: "127.0.0.1" + port: 9001 + weight: 1 + + - id: "server-2" + host: "127.0.0.1" + port: 9002 + weight: 1 +"#; + + let updated_config = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 2 + sticky_sessions: false + health_check_interval_secs: 5 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-1" + host: "127.0.0.1" + port: 9001 + weight: 1 +"#; + + let temp_file = NamedTempFile::new().unwrap(); + fs::write(temp_file.path(), initial_config).unwrap(); + let config = load_config(temp_file.path().to_str().unwrap()).unwrap(); + let state = Arc::new(RwLock::new(AppState::build( + config, + temp_file.path().to_str().unwrap().to_string(), + ))); + + fs::write(temp_file.path(), updated_config).unwrap(); + reload_config(state.clone()).await.unwrap(); + let state = state.read().await; + let backend = state.upstreams[0].backends.iter().find(|b| b.config.id == "server-2").unwrap(); + + assert!(backend.is_draining()); +} + +#[tokio::test] +async fn draining_backend_is_removed_from_runtime() { + let initial_config = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 2 + sticky_sessions: false + health_check_interval_secs: 1 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-1" + host: "127.0.0.1" + port: 9001 + weight: 1 + + - id: "server-2" + host: "127.0.0.1" + port: 9002 + weight: 1 +"#; + + let updated_config = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 2 + sticky_sessions: false + health_check_interval_secs: 1 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-1" + host: "127.0.0.1" + port: 9001 + weight: 1 +"#; + + let temp_file = NamedTempFile::new().unwrap(); + fs::write(temp_file.path(), initial_config).unwrap(); + let config = load_config(temp_file.path().to_str().unwrap()).unwrap(); + let state = Arc::new(RwLock::new(AppState::build( + config, + temp_file.path().to_str().unwrap().to_string(), + ))); + + fs::write(temp_file.path(), updated_config).unwrap(); + reload_config(state.clone()).await.unwrap(); + + { + let state = state.read().await; + let backend = + state.upstreams[0].backends.iter().find(|b| b.config.id == "server-2").unwrap(); + + assert!(backend.is_draining()); + } + let state1 = state.clone(); + tokio::spawn(async move { + laminar::health::tcp::start_health_checker(state1, 1).await; + }); + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + let state = state.read().await; + assert!(state.upstreams[0].backends.iter().all(|b| { b.config.id != "server-2" })); +} + +#[tokio::test] +async fn reload_is_idempotent() { + let config_text = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 2 + sticky_sessions: false + health_check_interval_secs: 5 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-1" + host: "127.0.0.1" + port: 9001 + weight: 1 +"#; + + let temp_file = NamedTempFile::new().unwrap(); + fs::write(temp_file.path(), config_text).unwrap(); + let config = load_config(temp_file.path().to_str().unwrap()).unwrap(); + let state = Arc::new(RwLock::new(AppState::build( + config, + temp_file.path().to_str().unwrap().to_string(), + ))); + reload_config(state.clone()).await.unwrap(); + reload_config(state.clone()).await.unwrap(); + let state = state.read().await; + + assert_eq!(state.upstreams[0].backends.len(), 1); +} diff --git a/tests/round_robin.rs b/tests/round_robin.rs index c2c350b..0fcfda9 100644 --- a/tests/round_robin.rs +++ b/tests/round_robin.rs @@ -33,6 +33,7 @@ fn round_robin_rotates_backends() { create_backend("server-1", 9001).into(), create_backend("server-2", 9002).into(), ], + weighted_backends: Vec::new(), }; let first = upstream.next_backend().unwrap(); diff --git a/tests/routing_invariants.rs b/tests/routing_invariants.rs new file mode 100644 index 0000000..da6bd47 --- /dev/null +++ b/tests/routing_invariants.rs @@ -0,0 +1,45 @@ +use std::sync::{ + Arc, + atomic::{AtomicBool, AtomicUsize}, +}; + +use laminar::{ + algorithms::{least_connections, round_robin, weighted_round_robin}, + config::types::BackendServerConfig, + state::backend::BackendState, +}; + +fn backend(id: &str, healthy: bool, draining: bool) -> Arc { + Arc::new(BackendState { + config: BackendServerConfig { + id: id.into(), + host: "127.0.0.1".into(), + port: 8080, + weight: 5, + }, + + healthy: AtomicBool::new(healthy), + draining: AtomicBool::new(draining), + active_connections: AtomicUsize::new(0), + total_requests: AtomicUsize::new(0), + failed_requests: AtomicUsize::new(0), + failed_health_checks: 0, + }) +} + +#[test] +fn draining_backend_never_routed() { + let draining = backend("draining", true, true); + let healthy = backend("healthy", true, false); + let backends = vec![draining, healthy.clone()]; + let counter = AtomicUsize::new(0); + + for _ in 0..50 { + let rr = round_robin::select_backend(&backends, &counter).unwrap(); + assert_eq!(rr.config.id, "healthy"); + let lc = least_connections::select_backend(&backends).unwrap(); + assert_eq!(lc.config.id, "healthy"); + let wrr = weighted_round_robin::select_backend(&backends, &counter).unwrap(); + assert_eq!(wrr.config.id, "healthy"); + } +} diff --git a/tests/runtime_concurrency.rs b/tests/runtime_concurrency.rs new file mode 100644 index 0000000..ee17355 --- /dev/null +++ b/tests/runtime_concurrency.rs @@ -0,0 +1,149 @@ +use std::{fs, sync::Arc, time::Duration}; + +use tempfile::NamedTempFile; + +use tokio::{ + io::AsyncWriteExt, + net::{TcpListener, TcpStream}, + sync::RwLock, +}; + +use laminar::{ + admin::reload::reload_config, config::loader::load_config, health::tcp::start_health_checker, + proxy::tcp::handle_connection, state::app::AppState, +}; + +#[tokio::test] +async fn reload_during_active_connection_survives() { + let initial_config = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 1 + sticky_sessions: false + health_check_interval_secs: 1 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-1" + host: "127.0.0.1" + port: 9001 + weight: 1 +"#; + + let updated_config = r#" +server: + host: "127.0.0.1" + port: 8080 + +load_balancer: + retry_attempts: 1 + sticky_sessions: false + health_check_interval_secs: 1 + connect_timeout_secs: 5 + idle_timeout_secs: 30 + +upstreams: + - id: "main" + + algorithm: "round_robin" + + servers: + - id: "server-2" + host: "127.0.0.1" + port: 9002 + weight: 1 +"#; + + let temp_file = NamedTempFile::new().unwrap(); + + fs::write(temp_file.path(), initial_config).unwrap(); + + let config = load_config(temp_file.path().to_str().unwrap()).unwrap(); + + let state = Arc::new(RwLock::new(AppState::build( + config, + temp_file.path().to_str().unwrap().to_string(), + ))); + + // START HEALTH CHECKER + { + let health_state = state.clone(); + + tokio::spawn(async move { + start_health_checker(health_state, 1).await; + }); + } + + // BACKEND SERVER + let backend_listener = TcpListener::bind("127.0.0.1:9001").await.unwrap(); + + tokio::spawn(async move { + let (_socket, _) = backend_listener.accept().await.unwrap(); + + // Keep connection alive long enough + // for reload/draining semantics + tokio::time::sleep(Duration::from_secs(2)).await; + }); + + // PROXY ENTRY + let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + + let proxy_addr = proxy_listener.local_addr().unwrap(); + + let proxy_state = state.clone(); + + tokio::spawn(async move { + let (stream, _) = proxy_listener.accept().await.unwrap(); + + handle_connection(stream, proxy_state).await.unwrap(); + }); + + // CLIENT + let mut client = TcpStream::connect(proxy_addr).await.unwrap(); + + // Trigger active traffic + client.write_all(b"ping").await.unwrap(); + + // RELOAD DURING ACTIVE CONNECTION + fs::write(temp_file.path(), updated_config).unwrap(); + + reload_config(state.clone()).await.unwrap(); + + // Give runtime time to mark draining + tokio::time::sleep(Duration::from_millis(200)).await; + + { + let state = state.read().await; + + let exists = state.upstreams[0].backends.iter().any(|b| b.config.id == "server-1"); + + // backend may already be removed + // depending on async timing + + if exists { + let backend = + state.upstreams[0].backends.iter().find(|b| b.config.id == "server-1").unwrap(); + + assert!(backend.is_draining()); + } + } + + // Wait for connection completion + // + health cleanup loop + tokio::time::sleep(Duration::from_secs(4)).await; + + { + let state = state.read().await; + + assert!(state.upstreams[0].backends.iter().all(|b| { b.config.id != "server-1" })); + } +} diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 75f9755..ff33c44 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -36,7 +36,7 @@ async fn test_connect_timeout() { }], }; - let state = Arc::new(RwLock::new(AppState::build(config))); + let state = Arc::new(RwLock::new(AppState::build(config, "laminar_config.yaml".to_string()))); // Create a local listener to act as the "client" entry point let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -89,7 +89,7 @@ async fn test_idle_timeout() { }], }; - let state = Arc::new(RwLock::new(AppState::build(config))); + let state = Arc::new(RwLock::new(AppState::build(config, "laminar_config.yaml".to_string()))); // 2. Start a local listener to act as the "client" entry point let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/tests/weighted_round_robin.rs b/tests/weighted_round_robin.rs new file mode 100644 index 0000000..a0fdcd9 --- /dev/null +++ b/tests/weighted_round_robin.rs @@ -0,0 +1,93 @@ +use std::sync::{ + Arc, + atomic::{AtomicBool, AtomicUsize}, +}; + +use laminar::{ + algorithms::weighted_round_robin, config::types::BackendServerConfig, + state::backend::BackendState, +}; + +fn create_backend(id: &str, weight: usize, healthy: bool, draining: bool) -> Arc { + Arc::new(BackendState { + config: BackendServerConfig { id: id.into(), host: "127.0.0.1".into(), port: 8080, weight }, + + healthy: AtomicBool::new(healthy), + draining: AtomicBool::new(draining), + active_connections: AtomicUsize::new(0), + total_requests: AtomicUsize::new(0), + failed_requests: AtomicUsize::new(0), + failed_health_checks: 0, + }) +} + +#[test] +fn weighted_distribution_is_respected() { + let backend_1 = create_backend("server-1", 5, true, false); + let backend_2 = create_backend("server-2", 1, true, false); + let weighted_backends = vec![ + backend_1.clone(), + backend_1.clone(), + backend_1.clone(), + backend_1.clone(), + backend_1.clone(), + backend_2.clone(), + ]; + let counter = AtomicUsize::new(0); + + let mut server_1_hits = 0; + let mut server_2_hits = 0; + + for _ in 0..600 { + let backend = weighted_round_robin::select_backend(&weighted_backends, &counter).unwrap(); + match backend.config.id.as_str() { + "server-1" => { + server_1_hits += 1; + } + "server-2" => { + server_2_hits += 1; + } + _ => {} + } + } + + assert!(server_1_hits > 450); + assert!(server_2_hits < 150); +} + +#[test] +fn unhealthy_backend_is_skipped() { + let backend_1 = create_backend("dead", 5, false, false); + let backend_2 = create_backend("healthy", 1, true, false); + let backends = vec![backend_1, backend_2]; + let counter = AtomicUsize::new(0); + + for _ in 0..20 { + let backend = weighted_round_robin::select_backend(&backends, &counter).unwrap(); + assert_eq!(backend.config.id, "healthy"); + } +} + +#[test] +fn draining_backend_is_skipped() { + let backend_1 = create_backend("draining", 5, true, true); + let backend_2 = create_backend("healthy", 1, true, false); + let backends = vec![backend_1, backend_2]; + let counter = AtomicUsize::new(0); + + for _ in 0..20 { + let backend = weighted_round_robin::select_backend(&backends, &counter).unwrap(); + assert_eq!(backend.config.id, "healthy"); + } +} + +#[test] +fn returns_none_when_all_backends_invalid() { + let backend_1 = create_backend("dead", 5, false, false); + let backend_2 = create_backend("draining", 1, true, true); + let backends = vec![backend_1, backend_2]; + let counter = AtomicUsize::new(0); + let backend = weighted_round_robin::select_backend(&backends, &counter); + + assert!(backend.is_none()); +}