diff --git a/Cargo.lock b/Cargo.lock index 67e9d12531..a321adbbe7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1758,7 +1758,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.5.6", + "socket2 0.4.10", "tokio", "tower-service", "tracing", @@ -3274,9 +3274,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.6" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" dependencies = [ "libc", "windows-sys 0.52.0", @@ -3399,6 +3399,7 @@ dependencies = [ "serde_json", "sha2", "smallvec", + "socket2 0.5.7", "sqlx", "thiserror", "time", @@ -3935,7 +3936,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.6", + "socket2 0.5.7", "tokio-macros", "windows-sys 0.48.0", ] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 789d30fb1c..918abf002c 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -80,6 +80,7 @@ hashlink = "0.9.0" indexmap = "2.0" event-listener = "5.2.0" hashbrown = "0.14.5" +socket2 = { version = "0.5.7", features = ["all"] } [dev-dependencies] sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] } diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs index f9c43668ab..cd2af8baec 100644 --- a/sqlx-core/src/net/mod.rs +++ b/sqlx-core/src/net/mod.rs @@ -2,5 +2,6 @@ mod socket; pub mod tls; pub use socket::{ - connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer, + connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, TcpKeepalive, WithSocket, + WriteBuffer, }; diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 0470abb5ec..7e9a9ca463 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -8,10 +8,12 @@ use bytes::BufMut; use futures_core::ready; pub use buffered::{BufferedSocket, WriteBuffer}; +pub use tcp_keepalive::TcpKeepalive; use crate::io::ReadBuf; mod buffered; +mod tcp_keepalive; pub trait Socket: Send + Sync + Unpin + 'static { fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result; @@ -186,6 +188,7 @@ pub async fn connect_tcp( host: &str, port: u16, with_socket: Ws, + keepalive: Option<&TcpKeepalive>, ) -> crate::Result { // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those. let host = host.trim_matches(&['[', ']'][..]); @@ -197,6 +200,13 @@ pub async fn connect_tcp( let stream = TcpStream::connect((host, port)).await?; stream.set_nodelay(true)?; + // set tcp keepalive + if let Some(keepalive) = keepalive { + let keepalive = keepalive.socket2(); + let sock_ref = socket2::SockRef::from(&stream); + sock_ref.set_tcp_keepalive(&keepalive)?; + } + return Ok(with_socket.with_socket(stream)); } @@ -216,9 +226,21 @@ pub async fn connect_tcp( s.get_ref().set_nodelay(true)?; Ok(s) }); - match stream { - Ok(stream) => return Ok(with_socket.with_socket(stream)), - Err(e) => last_err = Some(e), + let stream = match stream { + Ok(stream) => stream, + Err(e) => { + last_err = Some(e); + continue; + } + }; + // set tcp keepalive + if let Some(keepalive) = keepalive { + let keepalive = keepalive.socket2(); + let sock_ref = socket2::SockRef::from(&stream); + match sock_ref.set_tcp_keepalive(&keepalive) { + Ok(_) => return Ok(with_socket.with_socket(stream)), + Err(e) => last_err = Some(e), + } } } diff --git a/sqlx-core/src/net/socket/tcp_keepalive.rs b/sqlx-core/src/net/socket/tcp_keepalive.rs new file mode 100644 index 0000000000..9b381f21de --- /dev/null +++ b/sqlx-core/src/net/socket/tcp_keepalive.rs @@ -0,0 +1,244 @@ +use std::time::Duration; + +/// Configures a socket's TCP keepalive parameters. +#[derive(Debug, Clone, Copy)] +pub struct TcpKeepalive { + #[cfg_attr( + any(target_os = "openbsd", target_os = "haiku", target_os = "vita"), + allow(dead_code) + )] + time: Option, + #[cfg(not(any( + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "nto", + target_os = "espidf", + target_os = "vita", + target_os = "haiku", + )))] + interval: Option, + #[cfg(not(any( + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "windows", + target_os = "nto", + target_os = "espidf", + target_os = "vita", + target_os = "haiku", + )))] + retries: Option, +} + +impl TcpKeepalive { + /// Returns a new, empty set of TCP keepalive parameters. + /// The unset parameters will use OS-defined defaults. + pub const fn new() -> TcpKeepalive { + TcpKeepalive { + time: None, + #[cfg(not(any( + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "nto", + target_os = "espidf", + target_os = "vita", + target_os = "haiku", + )))] + interval: None, + #[cfg(not(any( + target_os = "openbsd", + target_os = "redox", + target_os = "solaris", + target_os = "windows", + target_os = "nto", + target_os = "espidf", + target_os = "vita", + target_os = "haiku", + )))] + retries: None, + } + } + + /// Set the amount of time after which TCP keepalive probes will be sent on + /// idle connections. + /// + /// This will set `TCP_KEEPALIVE` on macOS and iOS, and + /// `TCP_KEEPIDLE` on all other Unix operating systems, except + /// OpenBSD and Haiku which don't support any way to set this + /// option. On Windows, this sets the value of the `tcp_keepalive` + /// struct's `keepalivetime` field. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + pub const fn with_time(self, time: Duration) -> Self { + Self { + time: Some(time), + ..self + } + } + + /// Set the value of the `TCP_KEEPINTVL` option. On Windows, this sets the + /// value of the `tcp_keepalive` struct's `keepaliveinterval` field. + /// + /// Sets the time interval between TCP keepalive probes. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + ))] + #[cfg_attr( + docsrs, + doc(cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + ))) + )] + pub const fn with_interval(self, interval: Duration) -> Self { + Self { + interval: Some(interval), + ..self + } + } + + /// Set the value of the `TCP_KEEPCNT` option. + /// + /// Set the maximum number of TCP keepalive probes that will be sent before + /// dropping a connection, if TCP keepalive is enabled on this socket. + /// + /// This setter has no effect on Windows. + #[cfg(all(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + )))] + #[cfg_attr( + docsrs, + doc(cfg(all(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + )))) + )] + pub const fn with_retries(self, retries: u32) -> Self { + Self { + retries: Some(retries), + ..self + } + } + + /// Convert `TcpKeepalive` to `socket2::TcpKeepalive`. + #[doc(hidden)] + pub(super) const fn socket2(self) -> socket2::TcpKeepalive { + let mut ka = socket2::TcpKeepalive::new(); + if let Some(time) = self.time { + ka = ka.with_time(time); + } + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + ))] + #[cfg_attr( + docsrs, + doc(cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + ))) + )] + if let Some(interval) = self.interval { + ka = ka.with_interval(interval); + } + #[cfg(all(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + )))] + #[cfg_attr( + docsrs, + doc(cfg(all(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + )))) + )] + if let Some(retries) = self.retries { + ka = ka.with_retries(retries); + } + ka + } +} diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 468478e550..786c0f6aea 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -19,7 +19,15 @@ impl MySqlConnection { let handshake = match &options.socket { Some(path) => crate::net::connect_uds(path, do_handshake).await?, - None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?, + None => { + crate::net::connect_tcp( + &options.host, + options.port, + do_handshake, + options.tcp_keep_alive.as_ref(), + ) + .await? + } }; let stream = handshake.await?; diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index db2b20c19d..c37b47d58d 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -1,10 +1,11 @@ use std::path::{Path, PathBuf}; +use std::time::Duration; mod connect; mod parse; mod ssl_mode; -use crate::{connection::LogSettings, net::tls::CertificateInput}; +use crate::{connection::LogSettings, net::tls::CertificateInput, net::TcpKeepalive}; pub use ssl_mode::MySqlSslMode; /// Options and flags which can be used to configure a MySQL connection. @@ -80,6 +81,7 @@ pub struct MySqlConnectOptions { pub(crate) no_engine_substitution: bool, pub(crate) timezone: Option, pub(crate) set_names: bool, + pub(crate) tcp_keep_alive: Option, } impl Default for MySqlConnectOptions { @@ -111,6 +113,7 @@ impl MySqlConnectOptions { no_engine_substitution: true, timezone: Some(String::from("+00:00")), set_names: true, + tcp_keep_alive: None, } } @@ -403,6 +406,16 @@ impl MySqlConnectOptions { self.set_names = flag_val; self } + + /// Sets the TCP keepalive time for the connection. + pub fn tcp_keepalive_time(mut self, time: Duration) -> Self { + self.tcp_keep_alive = Some(if self.tcp_keep_alive.is_none() { + TcpKeepalive::new().with_time(time) + } else { + self.tcp_keep_alive.unwrap().with_time(time) + }); + self + } } impl MySqlConnectOptions { diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index f165899248..151b548e79 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -44,7 +44,15 @@ impl PgStream { pub(super) async fn connect(options: &PgConnectOptions) -> Result { let socket_future = match options.fetch_socket() { Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?, - None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?, + None => { + net::connect_tcp( + &options.host, + options.port, + MaybeUpgradeTls(options), + options.tcp_keep_alive.as_ref(), + ) + .await? + } }; let socket = socket_future.await?; diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index a0b222606a..14e8dbc247 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -2,10 +2,11 @@ use std::borrow::Cow; use std::env::var; use std::fmt::{Display, Write}; use std::path::{Path, PathBuf}; +use std::time::Duration; pub use ssl_mode::PgSslMode; -use crate::{connection::LogSettings, net::tls::CertificateInput}; +use crate::{connection::LogSettings, net::tls::CertificateInput, net::TcpKeepalive}; mod connect; mod parse; @@ -102,6 +103,7 @@ pub struct PgConnectOptions { pub(crate) application_name: Option, pub(crate) log_settings: LogSettings, pub(crate) extra_float_digits: Option>, + pub(crate) tcp_keep_alive: Option, pub(crate) options: Option, } @@ -168,6 +170,7 @@ impl PgConnectOptions { application_name: var("PGAPPNAME").ok(), extra_float_digits: Some("2".into()), log_settings: Default::default(), + tcp_keep_alive: None, options: var("PGOPTIONS").ok(), } } @@ -493,6 +496,16 @@ impl PgConnectOptions { self } + /// Sets the TCP keepalive time for the connection. + pub fn tcp_keepalive_time(mut self, time: Duration) -> Self { + self.tcp_keep_alive = Some(if self.tcp_keep_alive.is_none() { + TcpKeepalive::new().with_time(time) + } else { + self.tcp_keep_alive.unwrap().with_time(time) + }); + self + } + /// Set additional startup options for the connection as a list of key-value pairs. /// /// # Example