diff --git a/crates/shadowsocks/src/net/option.rs b/crates/shadowsocks/src/net/option.rs index 4cca34e94c16..b53b6bace314 100644 --- a/crates/shadowsocks/src/net/option.rs +++ b/crates/shadowsocks/src/net/option.rs @@ -16,6 +16,9 @@ pub struct TcpSocketOpts { /// `TCP_FASTOPEN`, enables TFO pub fastopen: bool, + + /// `TCP_KEEPALIVE`, enables keep-alive messages on connection-oriented sockets + pub keepalive: bool, } impl Default for TcpSocketOpts { @@ -25,6 +28,7 @@ impl Default for TcpSocketOpts { recv_buffer_size: None, nodelay: false, fastopen: false, + keepalive: true, } } } diff --git a/crates/shadowsocks/src/net/sys/mod.rs b/crates/shadowsocks/src/net/sys/mod.rs index 7c9fcb79e491..864603008439 100644 --- a/crates/shadowsocks/src/net/sys/mod.rs +++ b/crates/shadowsocks/src/net/sys/mod.rs @@ -46,9 +46,66 @@ fn set_common_sockopt_for_connect(addr: SocketAddr, socket: &TcpSocket, opts: &C } fn set_common_sockopt_after_connect(stream: &tokio::net::TcpStream, opts: &ConnectOpts) -> io::Result<()> { - if opts.tcp.nodelay { - stream.set_nodelay(true)?; - } + stream.set_nodelay(opts.tcp.nodelay)?; + set_common_sockopt_after_connect_sys(stream, opts)?; + + Ok(()) +} + +#[cfg(unix)] +#[inline] +fn set_common_sockopt_after_connect_sys(stream: &tokio::net::TcpStream, opts: &ConnectOpts) -> io::Result<()> { + use socket2::Socket; + use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; + + let socket = unsafe { Socket::from_raw_fd(stream.as_raw_fd()) }; + + macro_rules! try_sockopt { + ($socket:ident . $func:ident ($($arg:expr),*)) => { + match $socket . $func ($($arg),*) { + Ok(e) => e, + Err(err) => { + let _ = socket.into_raw_fd(); + return Err(err); + } + } + }; + } + + try_sockopt!(socket.set_keepalive(opts.tcp.keepalive)); + + let _ = socket.into_raw_fd(); + Ok(()) +} + +#[cfg(windows)] +#[inline] +fn set_common_sockopt_after_connect_sys(stream: &tokio::net::TcpStream, opts: &ConnectOpts) -> io::Result<()> { + use socket2::Socket; + use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket}; + + let socket = unsafe { Socket::from_raw_socket(stream.as_raw_socket()) }; + + macro_rules! try_sockopt { + ($socket:ident . $func:ident ($($arg:expr),*)) => { + match $socket . $func ($($arg),*) { + Ok(e) => e, + Err(err) => { + let _ = socket.into_raw_socket(); + return Err(err); + } + } + }; + } + + try_sockopt!(socket.set_keepalive(opts.tcp.keepalive)); + + let _ = socket.into_raw_socket(); + Ok(()) +} +#[cfg(all(not(windows), not(unix)))] +#[inline] +fn set_common_sockopt_after_connect_sys(_: &tokio::net::TcpStream, _: &ConnectOpts) -> io::Result<()> { Ok(()) } diff --git a/crates/shadowsocks/src/net/tcp.rs b/crates/shadowsocks/src/net/tcp.rs index e210553743fb..c8c9dbb26ea9 100644 --- a/crates/shadowsocks/src/net/tcp.rs +++ b/crates/shadowsocks/src/net/tcp.rs @@ -241,41 +241,69 @@ impl From for TokioTcpListener { } #[cfg(unix)] -fn setsockopt_with_opt(f: &F, opts: &AcceptOpts) -> io::Result<()> { +fn setsockopt_with_opt(f: &tokio::net::TcpStream, opts: &AcceptOpts) -> io::Result<()> { let socket = unsafe { Socket::from_raw_fd(f.as_raw_fd()) }; + macro_rules! try_sockopt { + ($socket:ident . $func:ident ($($arg:expr),*)) => { + match $socket . $func ($($arg),*) { + Ok(e) => e, + Err(err) => { + let _ = socket.into_raw_fd(); + return Err(err); + } + } + }; + } + if let Some(buf_size) = opts.tcp.send_buffer_size { - socket.set_send_buffer_size(buf_size as usize)?; + try_sockopt!(socket.set_send_buffer_size(buf_size as usize)); } if let Some(buf_size) = opts.tcp.recv_buffer_size { - socket.set_recv_buffer_size(buf_size as usize)?; + try_sockopt!(socket.set_recv_buffer_size(buf_size as usize)); } - if opts.tcp.nodelay { - socket.set_nodelay(true)?; - } + try_sockopt!(socket.set_nodelay(opts.tcp.nodelay)); + try_sockopt!(socket.set_keepalive(opts.tcp.keepalive)); let _ = socket.into_raw_fd(); Ok(()) } #[cfg(windows)] -fn setsockopt_with_opt(f: &F, opts: &AcceptOpts) -> io::Result<()> { +fn setsockopt_with_opt(f: &tokio::net::TcpStream, opts: &AcceptOpts) -> io::Result<()> { let socket = unsafe { Socket::from_raw_socket(f.as_raw_socket()) }; + macro_rules! try_sockopt { + ($socket:ident . $func:ident ($($arg:expr),*)) => { + match $socket . $func ($($arg),*) { + Ok(e) => e, + Err(err) => { + let _ = socket.into_raw_socket(); + return Err(err); + } + } + }; + } + if let Some(buf_size) = opts.tcp.send_buffer_size { - socket.set_send_buffer_size(buf_size as usize)?; + try_sockopt!(socket.set_send_buffer_size(buf_size as usize)); } if let Some(buf_size) = opts.tcp.recv_buffer_size { - socket.set_recv_buffer_size(buf_size as usize)?; + try_sockopt!(socket.set_recv_buffer_size(buf_size as usize)); } - if opts.tcp.nodelay { - socket.set_nodelay(true)?; - } + try_sockopt!(socket.set_nodelay(opts.tcp.nodelay)); + try_sockopt!(socket.set_keepalive(opts.tcp.keepalive)); let _ = socket.into_raw_socket(); Ok(()) } + +#[cfg(all(not(windows), not(unix)))] +fn setsockopt_with_opt(f: &tokio::net::TcpStream, opts: &AcceptOpts) -> io::Result<()> { + f.set_nodelay(opts.tcp.nodelay)?; + Ok(()) +}