Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(net): use exclusive &mut for socket operations #841

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/tracing/net/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::tracing::{
MultipathStrategy, PrivilegeMode, Probe, TracerChannelConfig, TracerProtocol,
};
use arrayvec::ArrayVec;
use itertools::Itertools;
use std::net::IpAddr;
use std::time::{Duration, SystemTime};
use tracing::instrument;
Expand Down Expand Up @@ -215,17 +214,23 @@ impl<S: Socket> TracerChannel<S> {
.retain(|probe| probe.start.elapsed().unwrap_or_default() < self.tcp_connect_timeout);
let found_index = self
.tcp_probes
.iter()
.find_position(|&probe| probe.socket.is_writable().unwrap_or_default())
.map(|(i, _)| i);
.iter_mut()
.enumerate()
.find_map(|(index, probe)| {
if probe.socket.is_writable().unwrap_or_default() {
Some(index)
} else {
None
}
});
if let Some(i) = found_index {
let probe = self.tcp_probes.remove(i);
let mut probe = self.tcp_probes.remove(i);
match self.dest_addr {
IpAddr::V4(_) => {
ipv4::recv_tcp_socket(&probe.socket, probe.sequence, self.dest_addr)
ipv4::recv_tcp_socket(&mut probe.socket, probe.sequence, self.dest_addr)
}
IpAddr::V6(_) => {
ipv6::recv_tcp_socket(&probe.socket, probe.sequence, self.dest_addr)
ipv6::recv_tcp_socket(&mut probe.socket, probe.sequence, self.dest_addr)
}
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/tracing/net/ipv4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ pub fn recv_icmp_probe<S: Socket>(

#[instrument(skip(tcp_socket))]
pub fn recv_tcp_socket<S: Socket>(
tcp_socket: &S,
tcp_socket: &mut S,
sequence: Sequence,
dest_addr: IpAddr,
) -> TraceResult<Option<ProbeResponse>> {
Expand Down
2 changes: 1 addition & 1 deletion src/tracing/net/ipv6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pub fn recv_icmp_probe<S: Socket>(

#[instrument(skip(tcp_socket))]
pub fn recv_tcp_socket<S: Socket>(
tcp_socket: &S,
tcp_socket: &mut S,
sequence: Sequence,
dest_addr: IpAddr,
) -> TraceResult<Option<ProbeResponse>> {
Expand Down
44 changes: 22 additions & 22 deletions src/tracing/net/platform/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ fn test_send_local_ip4_packet(src_addr: Ipv4Addr, total_length: u16) -> TraceRes
ipv4.set_destination(Ipv4Addr::LOCALHOST);
ipv4.set_total_length(total_length);
ipv4.set_payload(icmp.packet());
let probe_socket = SocketImpl::new_dgram_ipv4(Protocol::ICMPV4)
let mut probe_socket = SocketImpl::new_dgram_ipv4(Protocol::ICMPV4)
.or_else(|_| SocketImpl::new_raw_ipv4(Protocol::from(nix::libc::IPPROTO_RAW)))?;
probe_socket.set_header_included(true)?;
let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
Expand Down Expand Up @@ -141,7 +141,7 @@ pub fn is_host_unreachable_error(_code: i32) -> bool {
/// Note that no packets are transmitted by this method.
#[instrument(ret)]
pub fn discover_local_addr(target_addr: IpAddr, port: u16) -> TraceResult<IpAddr> {
let socket = match target_addr {
let mut socket = match target_addr {
IpAddr::V4(_) => SocketImpl::new_udp_dgram_socket_ipv4(),
IpAddr::V6(_) => SocketImpl::new_udp_dgram_socket_ipv6(),
}?;
Expand Down Expand Up @@ -209,12 +209,12 @@ impl Socket for SocketImpl {
#[instrument]
fn new_icmp_send_socket_ipv4(raw: bool) -> IoResult<Self> {
if raw {
let socket = Self::new_raw_ipv4(Protocol::from(nix::libc::IPPROTO_RAW))?;
let mut socket = Self::new_raw_ipv4(Protocol::from(nix::libc::IPPROTO_RAW))?;
socket.set_nonblocking(true)?;
socket.set_header_included(true)?;
Ok(socket)
} else {
let socket = Self::new(Domain::IPV4, Type::DGRAM, Protocol::ICMPV4)?;
let mut socket = Self::new(Domain::IPV4, Type::DGRAM, Protocol::ICMPV4)?;
socket.set_nonblocking(true)?;
socket.set_header_included(true)?;
Ok(socket)
Expand All @@ -235,7 +235,7 @@ impl Socket for SocketImpl {
#[instrument]
fn new_udp_send_socket_ipv4(raw: bool) -> IoResult<Self> {
if raw {
let socket = Self::new_raw_ipv4(Protocol::from(nix::libc::IPPROTO_RAW))?;
let mut socket = Self::new_raw_ipv4(Protocol::from(nix::libc::IPPROTO_RAW))?;
socket.set_nonblocking(true)?;
socket.set_header_included(true)?;
Ok(socket)
Expand All @@ -260,7 +260,7 @@ impl Socket for SocketImpl {
#[instrument]
fn new_recv_socket_ipv4(addr: Ipv4Addr, raw: bool) -> IoResult<Self> {
if raw {
let socket = Self::new_raw_ipv4(Protocol::ICMPV4)?;
let mut socket = Self::new_raw_ipv4(Protocol::ICMPV4)?;
socket.set_nonblocking(true)?;
socket.set_header_included(true)?;
Ok(socket)
Expand All @@ -284,14 +284,14 @@ impl Socket for SocketImpl {
}
#[instrument]
fn new_stream_socket_ipv4() -> IoResult<Self> {
let socket = Self::new(Domain::IPV4, Type::STREAM, Protocol::TCP)?;
let mut socket = Self::new(Domain::IPV4, Type::STREAM, Protocol::TCP)?;
socket.set_nonblocking(true)?;
socket.set_reuse_port(true)?;
Ok(socket)
}
#[instrument]
fn new_stream_socket_ipv6() -> IoResult<Self> {
let socket = Self::new(Domain::IPV6, Type::STREAM, Protocol::TCP)?;
let mut socket = Self::new(Domain::IPV6, Type::STREAM, Protocol::TCP)?;
socket.set_nonblocking(true)?;
socket.set_reuse_port(true)?;
Ok(socket)
Expand All @@ -311,52 +311,52 @@ impl Socket for SocketImpl {
.map_err(|err| IoError::Bind(err, address))
}
#[instrument(skip(self))]
fn set_tos(&self, tos: u32) -> IoResult<()> {
fn set_tos(&mut self, tos: u32) -> IoResult<()> {
self.inner
.set_tos(tos)
.map_err(|err| IoError::Other(err, IoOperation::SetTos))
}
#[instrument(skip(self))]
fn set_ttl(&self, ttl: u32) -> IoResult<()> {
fn set_ttl(&mut self, ttl: u32) -> IoResult<()> {
self.inner
.set_ttl(ttl)
.map_err(|err| IoError::Other(err, IoOperation::SetTtl))
}
#[instrument(skip(self))]
fn set_reuse_port(&self, reuse: bool) -> IoResult<()> {
fn set_reuse_port(&mut self, reuse: bool) -> IoResult<()> {
self.inner
.set_reuse_port(reuse)
.map_err(|err| IoError::Other(err, IoOperation::SetReusePort))
}
#[instrument(skip(self))]
fn set_header_included(&self, included: bool) -> IoResult<()> {
fn set_header_included(&mut self, included: bool) -> IoResult<()> {
self.inner
.set_header_included(included)
.map_err(|err| IoError::Other(err, IoOperation::SetHeaderIncluded))
}
#[instrument(skip(self))]
fn set_unicast_hops_v6(&self, hops: u8) -> IoResult<()> {
fn set_unicast_hops_v6(&mut self, hops: u8) -> IoResult<()> {
self.inner
.set_unicast_hops_v6(u32::from(hops))
.map_err(|err| IoError::Other(err, IoOperation::SetUnicastHopsV6))
}
#[instrument(skip(self))]
fn connect(&self, address: SocketAddr) -> IoResult<()> {
fn connect(&mut self, address: SocketAddr) -> IoResult<()> {
tracing::debug!(?address);
self.inner
.connect(&SockAddr::from(address))
.map_err(|err| IoError::Connect(err, address))
}
#[instrument(skip(self, buf))]
fn send_to(&self, buf: &[u8], addr: SocketAddr) -> IoResult<()> {
fn send_to(&mut self, buf: &[u8], addr: SocketAddr) -> IoResult<()> {
tracing::debug!(buf = format!("{:02x?}", buf.iter().format(" ")), ?addr);
self.inner
.send_to(buf, &SockAddr::from(addr))
.map_err(|err| IoError::SendTo(err, addr))?;
Ok(())
}
#[instrument(skip(self))]
fn is_readable(&self, timeout: Duration) -> IoResult<bool> {
fn is_readable(&mut self, timeout: Duration) -> IoResult<bool> {
let mut read = FdSet::new();
read.insert(&self.inner);
let readable = nix::sys::select::select(
Expand All @@ -376,7 +376,7 @@ impl Socket for SocketImpl {
}
}
#[instrument(skip(self))]
fn is_writable(&self) -> IoResult<bool> {
fn is_writable(&mut self) -> IoResult<bool> {
let mut write = FdSet::new();
write.insert(&self.inner);
let writable = nix::sys::select::select(
Expand Down Expand Up @@ -421,13 +421,13 @@ impl Socket for SocketImpl {
Ok(bytes_read)
}
#[instrument(skip(self))]
fn shutdown(&self) -> IoResult<()> {
fn shutdown(&mut self) -> IoResult<()> {
self.inner
.shutdown(Shutdown::Both)
.map_err(|err| IoError::Other(err, IoOperation::Shutdown))
}
#[instrument(skip(self), ret)]
fn peer_addr(&self) -> IoResult<Option<SocketAddr>> {
fn peer_addr(&mut self) -> IoResult<Option<SocketAddr>> {
let addr = self
.inner
.peer_addr()
Expand All @@ -437,19 +437,19 @@ impl Socket for SocketImpl {
Ok(addr)
}
#[instrument(skip(self), ret)]
fn take_error(&self) -> IoResult<Option<io::Error>> {
fn take_error(&mut self) -> IoResult<Option<io::Error>> {
self.inner
.take_error()
.map_err(|err| IoError::Other(err, IoOperation::TakeError))
}
#[allow(clippy::unused_self, clippy::unnecessary_wraps)]
#[instrument(skip(self), ret)]
fn icmp_error_info(&self) -> IoResult<IpAddr> {
fn icmp_error_info(&mut self) -> IoResult<IpAddr> {
Ok(IpAddr::V4(Ipv4Addr::UNSPECIFIED))
}
#[allow(clippy::unused_self, clippy::unnecessary_wraps)]
#[instrument(skip(self))]
fn close(&self) -> IoResult<()> {
fn close(&mut self) -> IoResult<()> {
Ok(())
}
}
Expand Down
38 changes: 19 additions & 19 deletions src/tracing/net/platform/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ impl SocketImpl {
}

#[instrument(skip(self))]
fn get_overlapped_result(&self) -> IoResult<()> {
fn get_overlapped_result(&mut self) -> IoResult<()> {
let mut bytes_read = 0;
let mut flags = 0;
let ol = *self.ol;
Expand Down Expand Up @@ -309,7 +309,7 @@ impl Socket for SocketImpl {
#[instrument]
fn new_icmp_send_socket_ipv4(raw: bool) -> IoResult<Self> {
if raw {
let sock = Self::new(Domain::IPV4, Type::RAW, Some(Protocol::from(IPPROTO_RAW)))?;
let mut sock = Self::new(Domain::IPV4, Type::RAW, Some(Protocol::from(IPPROTO_RAW)))?;
sock.set_non_blocking(true)?;
sock.set_header_included(true)?;
Ok(sock)
Expand All @@ -332,7 +332,7 @@ impl Socket for SocketImpl {
#[instrument]
fn new_udp_send_socket_ipv4(raw: bool) -> IoResult<Self> {
if raw {
let sock = Self::new(Domain::IPV4, Type::RAW, Some(Protocol::from(IPPROTO_RAW)))?;
let mut sock = Self::new(Domain::IPV4, Type::RAW, Some(Protocol::from(IPPROTO_RAW)))?;
sock.set_non_blocking(true)?;
sock.set_header_included(true)?;
Ok(sock)
Expand Down Expand Up @@ -381,15 +381,15 @@ impl Socket for SocketImpl {

#[instrument]
fn new_stream_socket_ipv4() -> IoResult<Self> {
let sock = Self::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
let mut sock = Self::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
sock.set_non_blocking(true)?;
sock.set_reuse_port(true)?;
Ok(sock)
}

#[instrument]
fn new_stream_socket_ipv6() -> IoResult<Self> {
let sock = Self::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))?;
let mut sock = Self::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))?;
sock.set_non_blocking(true)?;
sock.set_reuse_port(true)?;
Ok(sock)
Expand Down Expand Up @@ -422,21 +422,21 @@ impl Socket for SocketImpl {
}

#[instrument(skip(self))]
fn set_tos(&self, tos: u32) -> IoResult<()> {
fn set_tos(&mut self, tos: u32) -> IoResult<()> {
self.inner
.set_tos(tos)
.map_err(|err| IoError::Other(err, IoOperation::SetTos))
}

#[instrument(skip(self))]
fn set_ttl(&self, ttl: u32) -> IoResult<()> {
fn set_ttl(&mut self, ttl: u32) -> IoResult<()> {
self.inner
.set_ttl(ttl)
.map_err(|err| IoError::Other(err, IoOperation::SetTtl))
}

#[instrument(skip(self))]
fn set_reuse_port(&self, is_reuse_port: bool) -> IoResult<()> {
fn set_reuse_port(&mut self, is_reuse_port: bool) -> IoResult<()> {
self.setsockopt_bool(SOL_SOCKET as _, SO_REUSE_UNICASTPORT as _, is_reuse_port)
.or_else(|_| {
self.setsockopt_bool(SOL_SOCKET as _, SO_PORT_SCALABILITY as _, is_reuse_port)
Expand All @@ -445,21 +445,21 @@ impl Socket for SocketImpl {
}

#[instrument(skip(self))]
fn set_header_included(&self, is_header_included: bool) -> IoResult<()> {
fn set_header_included(&mut self, is_header_included: bool) -> IoResult<()> {
self.inner
.set_header_included(is_header_included)
.map_err(|err| IoError::Other(err, IoOperation::SetHeaderIncluded))
}

#[instrument(skip(self))]
fn set_unicast_hops_v6(&self, max_hops: u8) -> IoResult<()> {
fn set_unicast_hops_v6(&mut self, max_hops: u8) -> IoResult<()> {
self.inner
.set_unicast_hops_v6(max_hops.into())
.map_err(|err| IoError::Other(err, IoOperation::SetUnicastHopsV6))
}

#[instrument(skip(self))]
fn connect(&self, addr: SocketAddr) -> IoResult<()> {
fn connect(&mut self, addr: SocketAddr) -> IoResult<()> {
self.set_fail_connect_on_icmp_error(true)?;
syscall!(
WSAEventSelect(
Expand All @@ -479,7 +479,7 @@ impl Socket for SocketImpl {
}

#[instrument(skip(self, buf))]
fn send_to(&self, buf: &[u8], addr: SocketAddr) -> IoResult<()> {
fn send_to(&mut self, buf: &[u8], addr: SocketAddr) -> IoResult<()> {
tracing::debug!(buf = format!("{:02x?}", buf.iter().format(" ")), ?addr);
self.inner
.send_to(buf, &SockAddr::from(addr))
Expand All @@ -488,7 +488,7 @@ impl Socket for SocketImpl {
}

#[instrument(skip(self))]
fn is_readable(&self, timeout: Duration) -> IoResult<bool> {
fn is_readable(&mut self, timeout: Duration) -> IoResult<bool> {
if !self.wait_for_event(timeout)? {
return Ok(false);
};
Expand All @@ -502,7 +502,7 @@ impl Socket for SocketImpl {
}

#[instrument(skip(self))]
fn is_writable(&self) -> IoResult<bool> {
fn is_writable(&mut self) -> IoResult<bool> {
if !self.wait_for_event(Duration::ZERO)? {
return Ok(false);
};
Expand Down Expand Up @@ -538,14 +538,14 @@ impl Socket for SocketImpl {
}

#[instrument(skip(self))]
fn shutdown(&self) -> IoResult<()> {
fn shutdown(&mut self) -> IoResult<()> {
self.inner
.shutdown(std::net::Shutdown::Both)
.map_err(|err| IoError::Other(err, IoOperation::Shutdown))
}

#[instrument(skip(self), ret)]
fn peer_addr(&self) -> IoResult<Option<SocketAddr>> {
fn peer_addr(&mut self) -> IoResult<Option<SocketAddr>> {
Ok(self
.inner
.peer_addr()
Expand All @@ -554,7 +554,7 @@ impl Socket for SocketImpl {
}

#[instrument(skip(self), ret)]
fn take_error(&self) -> IoResult<Option<Error>> {
fn take_error(&mut self) -> IoResult<Option<Error>> {
match self.getsockopt(SOL_SOCKET as _, SO_ERROR as _, 0) {
Ok(0) => Ok(None),
Ok(errno) => Ok(Some(Error::from_raw_os_error(errno))),
Expand All @@ -565,7 +565,7 @@ impl Socket for SocketImpl {

#[instrument(skip(self), ret)]
#[allow(unsafe_code)]
fn icmp_error_info(&self) -> IoResult<IpAddr> {
fn icmp_error_info(&mut self) -> IoResult<IpAddr> {
let icmp_error_info = self
.getsockopt::<ICMP_ERROR_INFO>(
IPPROTO_TCP as _,
Expand All @@ -590,7 +590,7 @@ impl Socket for SocketImpl {

// Interestingly, Socket2 sockets don't seem to call closesocket on drop??
#[instrument(skip(self))]
fn close(&self) -> IoResult<()> {
fn close(&mut self) -> IoResult<()> {
syscall!(closesocket(self.inner.as_raw_socket() as _), |res| res
== SOCKET_ERROR)
.map_err(|err| IoError::Other(err, IoOperation::Close))
Expand Down
Loading
Loading