From c31dd4019df5fddd82a054578a11a6d91880b800 Mon Sep 17 00:00:00 2001 From: Rohit Date: Sun, 17 May 2026 19:04:45 +0530 Subject: [PATCH] add validations for connect and tests --- src/algorithms/least_connections.rs | 3 +- src/config/default.rs | 3 +- src/config/types.rs | 1 - src/config/validator.rs | 8 +++++ src/proxy/tcp.rs | 54 ++++++++++++++++------------- tests/connect_timeout.rs | 36 +++++++++++++++++++ 6 files changed, 76 insertions(+), 29 deletions(-) create mode 100644 tests/connect_timeout.rs diff --git a/src/algorithms/least_connections.rs b/src/algorithms/least_connections.rs index e5a1482..fa58f7a 100644 --- a/src/algorithms/least_connections.rs +++ b/src/algorithms/least_connections.rs @@ -1,6 +1,5 @@ -use std::sync::{Arc, atomic::Ordering}; - use crate::state::backend::BackendState; +use std::sync::{Arc, atomic::Ordering}; pub fn select_backend(backends: &[Arc]) -> Option> { // prev and curr approach diff --git a/src/config/default.rs b/src/config/default.rs index 2bcc3e2..e7af80b 100644 --- a/src/config/default.rs +++ b/src/config/default.rs @@ -9,7 +9,8 @@ load_balancer: retry_attempts: 2 sticky_sessions: false health_check_interval_secs: 5 - + connect_timeout_secs: 3 + idle_timeout_secs: 30 upstreams: - id: "main" algorithm: "round_robin" diff --git a/src/config/types.rs b/src/config/types.rs index 54e6ca8..795ef80 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -26,7 +26,6 @@ pub struct ServerConfig { pub struct LoadBalancerConfig { pub retry_attempts: usize, pub sticky_sessions: bool, - pub health_check_interval_secs: u64, pub connect_timeout_secs: u64, pub idle_timeout_secs: u64, diff --git a/src/config/validator.rs b/src/config/validator.rs index 3dc5702..6b765d0 100644 --- a/src/config/validator.rs +++ b/src/config/validator.rs @@ -7,6 +7,14 @@ use std::collections::HashSet; // rather than causing runtime instability. pub fn validate_config(config: &Config) -> Result<()> { + if config.load_balancer.connect_timeout_secs == 0 { + bail!("connect_timeout_secs must be greater than 0"); + } + + if config.load_balancer.idle_timeout_secs == 0 { + bail!("idle_timeout_secs must be greater than 0"); + } + let mut upstream_ids = HashSet::new(); for upstream in &config.upstreams { if upstream.servers.is_empty() { diff --git a/src/proxy/tcp.rs b/src/proxy/tcp.rs index 11074f6..b5bea43 100644 --- a/src/proxy/tcp.rs +++ b/src/proxy/tcp.rs @@ -1,3 +1,6 @@ +use crate::state::app::SharedAppState; +use crate::state::backend::ConnectionGuard; +use std::time::Duration; use tokio::{ io::copy_bidirectional, net::{TcpListener, TcpStream}, @@ -5,9 +8,6 @@ use tokio::{ }; use tracing::{error, info}; -use crate::state::app::SharedAppState; -use crate::state::backend::ConnectionGuard; - pub async fn start_tcp_proxy(address: &str, state: SharedAppState) -> anyhow::Result<()> { let listener = TcpListener::bind(address).await?; @@ -18,7 +18,6 @@ pub async fn start_tcp_proxy(address: &str, state: SharedAppState) -> anyhow::Re info!("new client connected {}", client_address); let state = state.clone(); - tokio::spawn(async move { if let Err(error) = handle_connection(client_stream, state).await { error!("connection handling failed {:?}", error) @@ -59,30 +58,13 @@ async fn handle_connection(mut stream: TcpStream, state: SharedAppState) -> anyh info!("forwarding traffic to {}", backend_address); - match timeout(connect_timeout, TcpStream::connect(&backend_address)).await { - Ok(Ok(mut backend_stream)) => { - if timeout(idle_timeout, copy_bidirectional(&mut stream, &mut backend_stream)) - .await - .is_err() - { - error!("connection with {} timed out (idle)", backend_address); - } + match proxy_connection(&mut stream, &backend_address, connect_timeout, idle_timeout).await { + Ok(_) => { return Ok(()); } - Ok(Err(error)) => { - // this is very important .. - guard.mark_backend_unhealthy(); - error!("failed to connect to backend {}: {:?}", backend_address, error); - - // retry another - continue; - } - Err(_) => { - // connection attempt timed out + Err(error) => { guard.mark_backend_unhealthy(); - error!("connection attempt to {} timed out", backend_address); - - // retry another + error!("backend {} failed: {:?}", backend_address, error); continue; } } @@ -90,3 +72,25 @@ async fn handle_connection(mut stream: TcpStream, state: SharedAppState) -> anyh error!("all backend retry attempts failed"); Ok(()) } + +async fn proxy_connection( + client_stream: &mut TcpStream, + backend_address: &str, + connect_timeout: Duration, + idle_timeout: Duration, +) -> anyhow::Result<()> { + let mut backend_stream = + timeout(connect_timeout, TcpStream::connect(backend_address)).await??; + + match timeout(idle_timeout, copy_bidirectional(client_stream, &mut backend_stream)).await { + Ok(Ok(_)) => Ok(()), + + Ok(Err(error)) => { + anyhow::bail!("proxy IO error: {error}"); + } + + Err(_) => { + anyhow::bail!("connection idle timeout"); + } + } +} diff --git a/tests/connect_timeout.rs b/tests/connect_timeout.rs new file mode 100644 index 0000000..6f28774 --- /dev/null +++ b/tests/connect_timeout.rs @@ -0,0 +1,36 @@ +use laminar::{ + config::types::BackendServerConfig, + state::backend::{BackendState, ConnectionGuard}, +}; +use std::{ + sync::{ + Arc, + atomic::{AtomicBool, AtomicUsize, Ordering}, + }, + time::Duration, +}; +use tokio::{net::TcpStream, time::timeout}; + +#[tokio::test] +async fn marks_backend_unhealthy_on_connect_timeout() { + let backend = Arc::new(BackendState { + config: BackendServerConfig { + id: "dead-backend".into(), + host: "10.255.255.1".into(), + port: 1234, + weight: 1, + }, + + healthy: AtomicBool::new(true), + active_connections: AtomicUsize::new(0), + failed_health_checks: 0, + }); + + let guard = ConnectionGuard::new(backend.clone()); + let address = guard.address(); + let result = timeout(Duration::from_millis(100), TcpStream::connect(address)).await; + assert!(result.is_err()); + + guard.mark_backend_unhealthy(); + assert!(!backend.healthy.load(Ordering::Relaxed)); +}