Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ json = ["serde", "serde_json"]

# for conditional compilation
_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"]
_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration
_rt-async-io = ["async-io", "async-fs", "socket2"] # see note at async-fs declaration
_rt-async-std = ["async-std", "_rt-async-io"]
_rt-async-task = ["async-task"]
_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"]
_rt-tokio = ["tokio", "tokio-stream"]

_rt-tokio = ["tokio", "tokio-stream", "socket2"]
_tls-native-tls = ["native-tls"]
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
_tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"]
Expand Down Expand Up @@ -102,6 +101,7 @@ hashlink = "0.11.0"
indexmap = "2.0"
event-listener = "5.2.0"
hashbrown = "0.16.0"
socket2 = { version = "0.5", features = ["all"], optional = true }

thiserror.workspace = true

Expand Down
3 changes: 2 additions & 1 deletion sqlx-core/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ mod socket;
pub mod tls;

pub use socket::{
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
connect_tcp, connect_uds, BufferedSocket, KeepaliveConfig, Socket, SocketIntoBox, WithSocket,
WriteBuffer,
};
94 changes: 86 additions & 8 deletions sqlx-core/src/net/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::io;
use std::path::Path;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::time::Duration;

pub use buffered::{BufferedSocket, WriteBuffer};
use bytes::BufMut;
Expand All @@ -12,6 +13,25 @@ use crate::io::ReadBuf;

mod buffered;

/// Configuration for TCP keepalive probes on a connection.
///
/// All fields default to `None`, meaning the OS default is used.
/// Constructing a `KeepaliveConfig::default()` and passing it enables keepalive
/// with OS defaults for all parameters.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct KeepaliveConfig {
/// Time the connection must be idle before keepalive probes begin.
/// `None` means the OS default.
pub idle: Option<Duration>,
/// Interval between keepalive probes.
/// `None` means the OS default.
pub interval: Option<Duration>,
/// Maximum number of failed probes before the connection is dropped.
/// Only supported on Unix; ignored on other platforms.
/// `None` means the OS default.
pub retries: Option<u32>,
}

pub trait Socket: Send + Sync + Unpin + 'static {
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize>;

Expand Down Expand Up @@ -181,23 +201,63 @@ impl<S: Socket + ?Sized> Socket for Box<S> {
}
}

#[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))]
fn build_tcp_keepalive(config: &KeepaliveConfig) -> socket2::TcpKeepalive {
let mut ka = socket2::TcpKeepalive::new();

if let Some(idle) = config.idle {
ka = ka.with_time(idle);
}

// socket2's `with_interval` is unavailable on these platforms.
#[cfg(not(any(
target_os = "haiku",
target_os = "openbsd",
target_os = "redox",
target_os = "solaris",
)))]
if let Some(interval) = config.interval {
ka = ka.with_interval(interval);
}

// socket2's `with_retries` is unavailable on these platforms.
#[cfg(not(any(
target_os = "haiku",
target_os = "openbsd",
target_os = "redox",
target_os = "solaris",
target_os = "windows",
)))]
if let Some(retries) = config.retries {
ka = ka.with_retries(retries);
}

ka
}

pub async fn connect_tcp<Ws: WithSocket>(
host: &str,
port: u16,
keepalive: Option<&KeepaliveConfig>,
with_socket: Ws,
) -> crate::Result<Ws::Output> {
#[cfg(feature = "_rt-tokio")]
if crate::rt::rt_tokio::available() {
return Ok(with_socket
.with_socket(tokio::net::TcpStream::connect((host, port)).await?)
.await);
let stream = tokio::net::TcpStream::connect((host, port)).await?;

if let Some(ka) = keepalive {
let sock = socket2::SockRef::from(&stream);
sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?;
}

return Ok(with_socket.with_socket(stream).await);
}

cfg_if! {
if #[cfg(feature = "_rt-async-io")] {
Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)
Ok(with_socket.with_socket(connect_tcp_async_io(host, port, keepalive).await?).await)
} else {
crate::rt::missing_rt((host, port, with_socket))
crate::rt::missing_rt((host, port, keepalive, with_socket))
}
}
}
Expand All @@ -208,15 +268,26 @@ pub async fn connect_tcp<Ws: WithSocket>(
///
/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].
#[cfg(feature = "_rt-async-io")]
async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {
async fn connect_tcp_async_io(
host: &str,
port: u16,
keepalive: Option<&KeepaliveConfig>,
) -> crate::Result<impl Socket> {
use async_io::Async;
use std::net::{IpAddr, TcpStream, ToSocketAddrs};

// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
let host = host.trim_matches(&['[', ']'][..]);

if let Ok(addr) = host.parse::<IpAddr>() {
return Ok(Async::<TcpStream>::connect((addr, port)).await?);
let stream = Async::<TcpStream>::connect((addr, port)).await?;

if let Some(ka) = keepalive {
let sock = socket2::SockRef::from(stream.get_ref());
sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?;
}

return Ok(stream);
}

let host = host.to_string();
Expand All @@ -232,7 +303,14 @@ async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socke
// Loop through all the Socket Addresses that the hostname resolves to
for socket_addr in addresses {
match Async::<TcpStream>::connect(socket_addr).await {
Ok(stream) => return Ok(stream),
Ok(stream) => {
if let Some(ka) = keepalive {
let sock = socket2::SockRef::from(stream.get_ref());
sock.set_tcp_keepalive(&build_tcp_keepalive(ka))?;
}

return Ok(stream);
}
Err(e) => last_err = Some(e),
}
}
Expand Down
4 changes: 3 additions & 1 deletion sqlx-mysql/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ impl MySqlConnection {

let handshake = match &options.socket {
Some(path) => crate::net::connect_uds(path, do_handshake).await?,
None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
None => {
crate::net::connect_tcp(&options.host, options.port, None, do_handshake).await?
}
};

let stream = handshake?;
Expand Down
15 changes: 14 additions & 1 deletion sqlx-postgres/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,20 @@ impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
let socket_result = match options.fetch_socket() {
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
None => {
let keepalive = if options.keepalives {
Some(&options.keepalive_config)
} else {
None
};
net::connect_tcp(
&options.host,
options.port,
keepalive,
MaybeUpgradeTls(options),
)
.await?
}
};

let socket = socket_result?;
Expand Down
114 changes: 114 additions & 0 deletions sqlx-postgres/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ use std::borrow::Cow;
use std::env::var;
use std::fmt::{self, Display, Write};
use std::path::{Path, PathBuf};
use std::time::Duration;

pub use ssl_mode::PgSslMode;

use crate::net::KeepaliveConfig;
use crate::{connection::LogSettings, net::tls::CertificateInput};

mod connect;
Expand All @@ -30,6 +32,8 @@ pub struct PgConnectOptions {
pub(crate) log_settings: LogSettings,
pub(crate) extra_float_digits: Option<Cow<'static, str>>,
pub(crate) options: Option<String>,
pub(crate) keepalives: bool,
pub(crate) keepalive_config: KeepaliveConfig,
}

impl Default for PgConnectOptions {
Expand Down Expand Up @@ -97,6 +101,9 @@ impl PgConnectOptions {
extra_float_digits: Some("2".into()),
log_settings: Default::default(),
options: var("PGOPTIONS").ok(),
// Matches libpq default: keepalives=1 with OS defaults for timers.
keepalives: true,
keepalive_config: KeepaliveConfig::default(),
}
}

Expand Down Expand Up @@ -452,6 +459,85 @@ impl PgConnectOptions {
self
}

/// Enables or disables TCP keepalive on the connection.
///
/// This option is ignored for Unix domain sockets.
///
/// Keepalive is enabled by default.
///
/// When enabled, OS defaults are used for all timer parameters unless
/// overridden by [`keepalives_idle`][Self::keepalives_idle],
/// [`keepalives_interval`][Self::keepalives_interval], or
/// [`keepalives_retries`][Self::keepalives_retries].
///
/// # Example
///
/// ```rust
/// # use sqlx_postgres::PgConnectOptions;
/// let options = PgConnectOptions::new()
/// .keepalives(false);
/// ```
pub fn keepalives(mut self, enable: bool) -> Self {
self.keepalives = enable;
self
}

/// Sets the idle time before TCP keepalive probes begin.
///
/// This is ignored for Unix domain sockets, or if the `keepalives`
/// option is disabled.
///
/// # Example
///
/// ```rust
/// # use std::time::Duration;
/// # use sqlx_postgres::PgConnectOptions;
/// let options = PgConnectOptions::new()
/// .keepalives_idle(Duration::from_secs(60));
/// ```
pub fn keepalives_idle(mut self, idle: Duration) -> Self {
self.keepalive_config.idle = Some(idle);
self
}

/// Sets the interval between TCP keepalive probes.
///
/// This is ignored for Unix domain sockets, or if the `keepalives`
/// option is disabled.
///
/// # Example
///
/// ```rust
/// # use std::time::Duration;
/// # use sqlx_postgres::PgConnectOptions;
/// let options = PgConnectOptions::new()
/// .keepalives_interval(Duration::from_secs(5));
/// ```
pub fn keepalives_interval(mut self, interval: Duration) -> Self {
self.keepalive_config.interval = Some(interval);
self
}

/// Sets the maximum number of TCP keepalive probes before the connection is dropped.
///
/// This is ignored for Unix domain sockets, or if the `keepalives`
/// option is disabled.
///
/// Only supported on Unix platforms; ignored on other platforms.
///
/// # Example
///
/// ```rust
/// # use sqlx_postgres::PgConnectOptions;
/// let options = PgConnectOptions::new()
/// .keepalives_retries(3);
/// ```
#[cfg(unix)]
pub fn keepalives_retries(mut self, retries: u32) -> Self {
self.keepalive_config.retries = Some(retries);
self
}

/// We try using a socket if hostname starts with `/` or if socket parameter
/// is specified.
pub(crate) fn fetch_socket(&self) -> Option<String> {
Expand Down Expand Up @@ -580,6 +666,34 @@ impl PgConnectOptions {
pub fn get_options(&self) -> Option<&str> {
self.options.as_deref()
}

/// Get whether TCP keepalives are enabled.
///
/// # Example
///
/// ```rust
/// # use sqlx_postgres::PgConnectOptions;
/// let options = PgConnectOptions::new();
/// assert!(options.get_keepalives());
/// ```
pub fn get_keepalives(&self) -> bool {
self.keepalives
}

/// Get the idle time before TCP keepalive probes begin.
pub fn get_keepalives_idle(&self) -> Option<Duration> {
self.keepalive_config.idle
}

/// Get the interval between TCP keepalive probes.
pub fn get_keepalives_interval(&self) -> Option<Duration> {
self.keepalive_config.interval
}

/// Get the maximum number of TCP keepalive probes.
pub fn get_keepalives_retries(&self) -> Option<u32> {
self.keepalive_config.retries
}
}

fn default_host(port: u16) -> String {
Expand Down
Loading