diff --git a/Cargo.lock b/Cargo.lock index fe01e11720..2c9dd8bb85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3596,6 +3596,7 @@ dependencies = [ "sha2", "smallvec", "smol", + "socket2", "sqlx", "thiserror 2.0.17", "time", diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 9ccb441e45..0deeff16f2 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -21,12 +21,11 @@ json = ["serde", "serde_json"] # for conditional compilation _rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"] -_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration +_rt-async-io = ["async-io", "async-fs", "socket2"] # see note at async-fs declaration _rt-async-std = ["async-std", "_rt-async-io"] _rt-async-task = ["async-task"] _rt-smol = ["smol", "_rt-async-io", "_rt-async-task"] -_rt-tokio = ["tokio", "tokio-stream"] - +_rt-tokio = ["tokio", "tokio-stream", "socket2"] _tls-native-tls = ["native-tls"] _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"] _tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"] @@ -102,6 +101,7 @@ hashlink = "0.11.0" indexmap = "2.0" event-listener = "5.2.0" hashbrown = "0.16.0" +socket2 = { version = "0.5", features = ["all"], optional = true } thiserror.workspace = true diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs index f9c43668ab..265f59cadf 100644 --- a/sqlx-core/src/net/mod.rs +++ b/sqlx-core/src/net/mod.rs @@ -2,5 +2,6 @@ mod socket; pub mod tls; pub use socket::{ - connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer, + connect_tcp, connect_uds, BufferedSocket, KeepaliveConfig, Socket, SocketIntoBox, WithSocket, + WriteBuffer, }; diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 0f9aae61b4..06e23966c7 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -3,6 +3,7 @@ use std::io; use std::path::Path; use std::pin::Pin; use std::task::{ready, Context, Poll}; +use std::time::Duration; pub use buffered::{BufferedSocket, WriteBuffer}; use bytes::BufMut; @@ -12,6 +13,25 @@ use crate::io::ReadBuf; mod buffered; +/// Configuration for TCP keepalive probes on a connection. +/// +/// All fields default to `None`, meaning the OS default is used. +/// Constructing a `KeepaliveConfig::default()` and passing it enables keepalive +/// with OS defaults for all parameters. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct KeepaliveConfig { + /// Time the connection must be idle before keepalive probes begin. + /// `None` means the OS default. + pub idle: Option, + /// Interval between keepalive probes. + /// `None` means the OS default. + pub interval: Option, + /// Maximum number of failed probes before the connection is dropped. + /// Only supported on Unix; ignored on other platforms. + /// `None` means the OS default. + pub retries: Option, +} + pub trait Socket: Send + Sync + Unpin + 'static { fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result; @@ -181,23 +201,63 @@ impl Socket for Box { } } +#[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))] +fn build_tcp_keepalive(config: &KeepaliveConfig) -> socket2::TcpKeepalive { + let mut ka = socket2::TcpKeepalive::new(); + + if let Some(idle) = config.idle { + ka = ka.with_time(idle); + } + + // socket2's `with_interval` is unavailable on these platforms. + #[cfg(not(any( + target_os = "haiku", + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + )))] + if let Some(interval) = config.interval { + ka = ka.with_interval(interval); + } + + // socket2's `with_retries` is unavailable on these platforms. + #[cfg(not(any( + target_os = "haiku", + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "windows", + )))] + if let Some(retries) = config.retries { + ka = ka.with_retries(retries); + } + + ka +} + pub async fn connect_tcp( host: &str, port: u16, + keepalive: Option<&KeepaliveConfig>, with_socket: Ws, ) -> crate::Result { #[cfg(feature = "_rt-tokio")] if crate::rt::rt_tokio::available() { - return Ok(with_socket - .with_socket(tokio::net::TcpStream::connect((host, port)).await?) - .await); + let stream = tokio::net::TcpStream::connect((host, port)).await?; + + if let Some(ka) = keepalive { + let sock = socket2::SockRef::from(&stream); + sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?; + } + + return Ok(with_socket.with_socket(stream).await); } cfg_if! { if #[cfg(feature = "_rt-async-io")] { - Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await) + Ok(with_socket.with_socket(connect_tcp_async_io(host, port, keepalive).await?).await) } else { - crate::rt::missing_rt((host, port, with_socket)) + crate::rt::missing_rt((host, port, keepalive, with_socket)) } } } @@ -208,7 +268,11 @@ pub async fn connect_tcp( /// /// This implements the same behavior as [`tokio::net::TcpStream::connect()`]. #[cfg(feature = "_rt-async-io")] -async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result { +async fn connect_tcp_async_io( + host: &str, + port: u16, + keepalive: Option<&KeepaliveConfig>, +) -> crate::Result { use async_io::Async; use std::net::{IpAddr, TcpStream, ToSocketAddrs}; @@ -216,7 +280,14 @@ async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result() { - return Ok(Async::::connect((addr, port)).await?); + let stream = Async::::connect((addr, port)).await?; + + if let Some(ka) = keepalive { + let sock = socket2::SockRef::from(stream.get_ref()); + sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?; + } + + return Ok(stream); } let host = host.to_string(); @@ -232,7 +303,14 @@ async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result::connect(socket_addr).await { - Ok(stream) => return Ok(stream), + Ok(stream) => { + if let Some(ka) = keepalive { + let sock = socket2::SockRef::from(stream.get_ref()); + sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?; + } + + return Ok(stream); + } Err(e) => last_err = Some(e), } } diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index f61654d876..5b6032834d 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -17,7 +17,9 @@ impl MySqlConnection { let handshake = match &options.socket { Some(path) => crate::net::connect_uds(path, do_handshake).await?, - None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?, + None => { + crate::net::connect_tcp(&options.host, options.port, None, do_handshake).await? + } }; let stream = handshake?; diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index e8a1aedc47..df7c09367f 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -44,7 +44,20 @@ impl PgStream { pub(super) async fn connect(options: &PgConnectOptions) -> Result { let socket_result = match options.fetch_socket() { Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?, - None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?, + None => { + let keepalive = if options.keepalives { + Some(&options.keepalive_config) + } else { + None + }; + net::connect_tcp( + &options.host, + options.port, + keepalive, + MaybeUpgradeTls(options), + ) + .await? + } }; let socket = socket_result?; diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index 21e6628cae..0ba234de49 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -2,9 +2,11 @@ use std::borrow::Cow; use std::env::var; use std::fmt::{self, Display, Write}; use std::path::{Path, PathBuf}; +use std::time::Duration; pub use ssl_mode::PgSslMode; +use crate::net::KeepaliveConfig; use crate::{connection::LogSettings, net::tls::CertificateInput}; mod connect; @@ -30,6 +32,8 @@ pub struct PgConnectOptions { pub(crate) log_settings: LogSettings, pub(crate) extra_float_digits: Option>, pub(crate) options: Option, + pub(crate) keepalives: bool, + pub(crate) keepalive_config: KeepaliveConfig, } impl Default for PgConnectOptions { @@ -97,6 +101,9 @@ impl PgConnectOptions { extra_float_digits: Some("2".into()), log_settings: Default::default(), options: var("PGOPTIONS").ok(), + // Matches libpq default: keepalives=1 with OS defaults for timers. + keepalives: true, + keepalive_config: KeepaliveConfig::default(), } } @@ -452,6 +459,85 @@ impl PgConnectOptions { self } + /// Enables or disables TCP keepalive on the connection. + /// + /// This option is ignored for Unix domain sockets. + /// + /// Keepalive is enabled by default. + /// + /// When enabled, OS defaults are used for all timer parameters unless + /// overridden by [`keepalives_idle`][Self::keepalives_idle], + /// [`keepalives_interval`][Self::keepalives_interval], or + /// [`keepalives_retries`][Self::keepalives_retries]. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_postgres::PgConnectOptions; + /// let options = PgConnectOptions::new() + /// .keepalives(false); + /// ``` + pub fn keepalives(mut self, enable: bool) -> Self { + self.keepalives = enable; + self + } + + /// Sets the idle time before TCP keepalive probes begin. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` + /// option is disabled. + /// + /// # Example + /// + /// ```rust + /// # use std::time::Duration; + /// # use sqlx_postgres::PgConnectOptions; + /// let options = PgConnectOptions::new() + /// .keepalives_idle(Duration::from_secs(60)); + /// ``` + pub fn keepalives_idle(mut self, idle: Duration) -> Self { + self.keepalive_config.idle = Some(idle); + self + } + + /// Sets the interval between TCP keepalive probes. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` + /// option is disabled. + /// + /// # Example + /// + /// ```rust + /// # use std::time::Duration; + /// # use sqlx_postgres::PgConnectOptions; + /// let options = PgConnectOptions::new() + /// .keepalives_interval(Duration::from_secs(5)); + /// ``` + pub fn keepalives_interval(mut self, interval: Duration) -> Self { + self.keepalive_config.interval = Some(interval); + self + } + + /// Sets the maximum number of TCP keepalive probes before the connection is dropped. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` + /// option is disabled. + /// + /// Only supported on Unix platforms; ignored on other platforms. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_postgres::PgConnectOptions; + /// let options = PgConnectOptions::new() + /// .keepalives_retries(3); + /// ``` + #[cfg(unix)] + pub fn keepalives_retries(mut self, retries: u32) -> Self { + self.keepalive_config.retries = Some(retries); + self + } + /// We try using a socket if hostname starts with `/` or if socket parameter /// is specified. pub(crate) fn fetch_socket(&self) -> Option { @@ -580,6 +666,34 @@ impl PgConnectOptions { pub fn get_options(&self) -> Option<&str> { self.options.as_deref() } + + /// Get whether TCP keepalives are enabled. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_postgres::PgConnectOptions; + /// let options = PgConnectOptions::new(); + /// assert!(options.get_keepalives()); + /// ``` + pub fn get_keepalives(&self) -> bool { + self.keepalives + } + + /// Get the idle time before TCP keepalive probes begin. + pub fn get_keepalives_idle(&self) -> Option { + self.keepalive_config.idle + } + + /// Get the interval between TCP keepalive probes. + pub fn get_keepalives_interval(&self) -> Option { + self.keepalive_config.interval + } + + /// Get the maximum number of TCP keepalive probes. + pub fn get_keepalives_retries(&self) -> Option { + self.keepalive_config.retries + } } fn default_host(port: u16) -> String { diff --git a/sqlx-postgres/src/options/parse.rs b/sqlx-postgres/src/options/parse.rs index e911305698..741d14f79e 100644 --- a/sqlx-postgres/src/options/parse.rs +++ b/sqlx-postgres/src/options/parse.rs @@ -1,9 +1,12 @@ -use crate::error::Error; -use crate::{PgConnectOptions, PgSslMode}; -use sqlx_core::percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; -use sqlx_core::Url; use std::net::IpAddr; use std::str::FromStr; +use std::time::Duration; + +use sqlx_core::percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; +use sqlx_core::Url; + +use crate::error::Error; +use crate::{PgConnectOptions, PgSslMode}; impl PgConnectOptions { pub(crate) fn parse_from_url(url: &Url) -> Result { @@ -104,6 +107,41 @@ impl PgConnectOptions { } } + "keepalives" => match value.as_ref() { + "0" | "1" => { + options = options.keepalives(value.as_ref() == "1"); + } + _ => { + return Err(Error::Configuration( + format!("keepalives must be 0 or 1, got: {value}").into(), + )); + } + }, + + "keepalives_idle" => { + let secs: u64 = value.parse().map_err(Error::config)?; + if secs > 0 { + options = options.keepalives_idle(Duration::from_secs(secs)); + } + } + + "keepalives_interval" => { + let secs: u64 = value.parse().map_err(Error::config)?; + if secs > 0 { + options = options.keepalives_interval(Duration::from_secs(secs)); + } + } + + "keepalives_count" => { + let _count: u32 = value.parse().map_err(Error::config)?; + // On non-Unix, TCP_KEEPCNT is not supported; the value is + // parsed for validation only (see `keepalives_retries` docs). + #[cfg(unix)] + { + options = options.keepalives_retries(_count); + } + } + _ => tracing::warn!(%key, %value, "ignoring unrecognized connect parameter"), } } @@ -166,6 +204,26 @@ impl PgConnectOptions { &self.statement_cache_capacity.to_string(), ); + url.query_pairs_mut() + .append_pair("keepalives", if self.keepalives { "1" } else { "0" }); + + if self.keepalives { + if let Some(idle) = self.keepalive_config.idle { + url.query_pairs_mut() + .append_pair("keepalives_idle", &idle.as_secs().to_string()); + } + + if let Some(interval) = self.keepalive_config.interval { + url.query_pairs_mut() + .append_pair("keepalives_interval", &interval.as_secs().to_string()); + } + + if let Some(retries) = self.keepalive_config.retries { + url.query_pairs_mut() + .append_pair("keepalives_count", &retries.to_string()); + } + } + url } } @@ -309,8 +367,8 @@ fn it_returns_the_parsed_url_when_socket() { let opts = PgConnectOptions::from_str(url).unwrap(); let mut expected_url = Url::parse(url).unwrap(); - // PgConnectOptions defaults - let query_string = "sslmode=prefer&statement-cache-capacity=100"; + // PgConnectOptions defaults (keepalives=1 is enabled by default) + let query_string = "sslmode=prefer&statement-cache-capacity=100&keepalives=1"; let port = 5432; expected_url.set_query(Some(query_string)); let _ = expected_url.set_port(Some(port)); @@ -324,8 +382,8 @@ fn it_returns_the_parsed_url_when_host() { let opts = PgConnectOptions::from_str(url).unwrap(); let mut expected_url = Url::parse(url).unwrap(); - // PgConnectOptions defaults - let query_string = "sslmode=prefer&statement-cache-capacity=100"; + // PgConnectOptions defaults (keepalives=1 is enabled by default) + let query_string = "sslmode=prefer&statement-cache-capacity=100&keepalives=1"; expected_url.set_query(Some(query_string)); assert_eq!(expected_url, opts.build_url()); @@ -340,3 +398,85 @@ fn built_url_can_be_parsed() { assert!(parsed.is_ok()); } + +#[test] +fn it_parses_keepalives_enabled() { + let url = "postgres://localhost/db?keepalives=1"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + assert!(opts.keepalives); + assert_eq!(opts.keepalive_config.idle, None); + assert_eq!(opts.keepalive_config.interval, None); + assert_eq!(opts.keepalive_config.retries, None); +} + +#[test] +fn it_parses_keepalives_disabled() { + let url = "postgres://localhost/db?keepalives_idle=60&keepalives=0"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + assert!(!opts.keepalives); + // timer values are preserved even when keepalives is disabled + assert_eq!(opts.keepalive_config.idle, Some(Duration::from_secs(60))); +} + +#[test] +fn it_parses_keepalives_idle() { + let url = "postgres://localhost/db?keepalives_idle=60"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + assert_eq!(opts.keepalive_config.idle, Some(Duration::from_secs(60))); +} + +#[test] +fn it_parses_keepalives_interval_before_idle() { + // interval appears before idle — must not be silently lost + let url = "postgres://localhost/db?keepalives_interval=5&keepalives_idle=60"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + assert_eq!(opts.keepalive_config.idle, Some(Duration::from_secs(60))); + assert_eq!(opts.keepalive_config.interval, Some(Duration::from_secs(5))); +} + +#[test] +fn it_parses_keepalives_all_params() { + let url = "postgres://localhost/db?keepalives_count=3&keepalives_interval=5&keepalives_idle=60"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + assert_eq!(opts.keepalive_config.idle, Some(Duration::from_secs(60))); + assert_eq!(opts.keepalive_config.interval, Some(Duration::from_secs(5))); + #[cfg(unix)] + assert_eq!(opts.keepalive_config.retries, Some(3)); +} + +#[test] +fn it_treats_zero_keepalive_timers_as_os_default() { + let url = "postgres://localhost/db?keepalives_idle=0&keepalives_interval=0"; + let opts = PgConnectOptions::from_str(url).unwrap(); + + // 0 means "use OS default" — the value should not be stored + assert_eq!(opts.keepalive_config.idle, None); + assert_eq!(opts.keepalive_config.interval, None); +} + +#[test] +fn it_rejects_invalid_keepalives_value() { + let url = "postgres://localhost/db?keepalives=2"; + assert!(PgConnectOptions::from_str(url).is_err()); +} + +#[test] +fn it_roundtrips_keepalive_through_build_url() { + let url = "postgres://localhost/db?keepalives_idle=60&keepalives_interval=5&keepalives_count=3"; + let opts = PgConnectOptions::from_str(url).unwrap(); + let rebuilt = PgConnectOptions::from_str(&opts.build_url().to_string()).unwrap(); + + assert!(rebuilt.keepalives); + assert_eq!(rebuilt.keepalive_config.idle, Some(Duration::from_secs(60))); + assert_eq!( + rebuilt.keepalive_config.interval, + Some(Duration::from_secs(5)) + ); + #[cfg(unix)] + assert_eq!(rebuilt.keepalive_config.retries, Some(3)); +}