Skip to content

Commit

Permalink
Replace SocketAddrAny enum with a safe wrapper for SocketAddrStorage
Browse files Browse the repository at this point in the history
To support extensibility over address types, use sockaddr_storage and
Into / TryInto conversions.
  • Loading branch information
kevinmehall committed Feb 2, 2025
1 parent 28a1f52 commit 1b326b4
Show file tree
Hide file tree
Showing 19 changed files with 542 additions and 750 deletions.
1 change: 1 addition & 0 deletions src/backend/libc/net/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ impl fmt::Debug for SocketAddrUnix {

/// `struct sockaddr_storage`.
#[repr(transparent)]
#[derive(Copy, Clone)]
pub struct SocketAddrStorage(c::sockaddr_storage);

impl SocketAddrStorage {
Expand Down
334 changes: 111 additions & 223 deletions src/backend/libc/net/read_sockaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use super::ext::{in6_addr_s6_addr, in_addr_s_addr, sockaddr_in6_sin6_scope_id};
use crate::backend::c;
#[cfg(not(windows))]
use crate::ffi::CStr;
use crate::io;
use crate::io::Errno;
#[cfg(target_os = "linux")]
use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
use crate::net::{Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrV4, SocketAddrV6};
use crate::net::{AddressFamily, Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrV4, SocketAddrV6};
use core::mem::size_of;

// This must match the header of `sockaddr`.
Expand Down Expand Up @@ -97,260 +97,148 @@ pub(crate) unsafe fn read_sa_family(storage: *const c::sockaddr) -> u16 {
/// socket address.
#[cfg(apple)]
#[inline]
unsafe fn read_sun_path0(storage: *const c::sockaddr_storage) -> u8 {
unsafe fn read_sun_path0(storage: *const c::sockaddr) -> u8 {
// In `read_ss_family` we assert that we know the layout of `sockaddr`.
storage
.cast::<u8>()
.add(super::addr::offsetof_sun_path())
.read()
}

/// Set the `sa_family` field of a socket address to `AF_UNSPEC`, so that we
/// can test for `AF_UNSPEC` to test whether it was stored to.
pub(crate) unsafe fn initialize_family_to_unspec(storage: *mut c::sockaddr) {
(*storage.cast::<sockaddr_header>()).sa_family = c::AF_UNSPEC as _;
}

/// Read a socket address encoded in a platform-specific format.
///
/// # Safety
///
/// `storage` must point to valid socket address storage.
pub(crate) unsafe fn read_sockaddr(
storage: *const c::sockaddr,
len: usize,
) -> io::Result<SocketAddrAny> {
#[cfg(unix)]
let offsetof_sun_path = super::addr::offsetof_sun_path();

if len < size_of::<c::sa_family_t>() {
return Err(io::Errno::INVAL);
}
match read_sa_family(storage).into() {
c::AF_INET => {
if len < size_of::<c::sockaddr_in>() {
return Err(io::Errno::INVAL);
}
let decode = &*storage.cast::<c::sockaddr_in>();
Ok(SocketAddrAny::V4(SocketAddrV4::new(
Ipv4Addr::from(u32::from_be(in_addr_s_addr(decode.sin_addr))),
u16::from_be(decode.sin_port),
)))
}
c::AF_INET6 => {
if len < size_of::<c::sockaddr_in6>() {
return Err(io::Errno::INVAL);
}
let decode = &*storage.cast::<c::sockaddr_in6>();
#[cfg(not(windows))]
let s6_addr = decode.sin6_addr.s6_addr;
#[cfg(windows)]
let s6_addr = decode.sin6_addr.u.Byte;
#[cfg(not(windows))]
let sin6_scope_id = decode.sin6_scope_id;
#[cfg(windows)]
let sin6_scope_id = decode.Anonymous.sin6_scope_id;
Ok(SocketAddrAny::V6(SocketAddrV6::new(
Ipv6Addr::from(s6_addr),
u16::from_be(decode.sin6_port),
u32::from_be(decode.sin6_flowinfo),
sin6_scope_id,
)))
}
#[cfg(unix)]
c::AF_UNIX => {
if len < offsetof_sun_path {
return Err(io::Errno::INVAL);
}
if len == offsetof_sun_path {
SocketAddrUnix::new(&[][..]).map(SocketAddrAny::Unix)
} else {
let decode = &*storage.cast::<c::sockaddr_un>();

// On Linux check for Linux's [abstract namespace].
//
// [abstract namespace]: https://man7.org/linux/man-pages/man7/unix.7.html
#[cfg(linux_kernel)]
if decode.sun_path[0] == 0 {
return SocketAddrUnix::new_abstract_name(core::mem::transmute::<
&[c::c_char],
&[u8],
>(
&decode.sun_path[1..len - offsetof_sun_path],
))
.map(SocketAddrAny::Unix);
}

// Otherwise we expect a NUL-terminated filesystem path.

// Trim off unused bytes from the end of `path_bytes`.
let path_bytes = if cfg!(any(solarish, target_os = "freebsd")) {
// FreeBSD and illumos sometimes set the length to longer
// than the length of the NUL-terminated string. Find the
// NUL and truncate the string accordingly.
&decode.sun_path[..decode
.sun_path
.iter()
.position(|b| *b == 0)
.ok_or(io::Errno::INVAL)?]
} else {
// Otherwise, use the provided length.
let provided_len = len - 1 - offsetof_sun_path;
if decode.sun_path[provided_len] != 0 {
return Err(io::Errno::INVAL);
}
debug_assert_eq!(
CStr::from_ptr(decode.sun_path.as_ptr().cast())
.to_bytes()
.len(),
provided_len
);
&decode.sun_path[..provided_len]
};

SocketAddrUnix::new(core::mem::transmute::<&[c::c_char], &[u8]>(path_bytes))
.map(SocketAddrAny::Unix)
}
}
#[cfg(target_os = "linux")]
c::AF_XDP => {
if len < size_of::<c::sockaddr_xdp>() {
return Err(io::Errno::INVAL);
}
let decode = &*storage.cast::<c::sockaddr_xdp>();
Ok(SocketAddrAny::Xdp(SocketAddrXdp::new(
SockaddrXdpFlags::from_bits_retain(decode.sxdp_flags),
u32::from_be(decode.sxdp_ifindex),
u32::from_be(decode.sxdp_queue_id),
u32::from_be(decode.sxdp_shared_umem_fd),
)))
}
_ => Err(io::Errno::INVAL),
}
}

/// Read an optional socket address returned from the OS.
/// Check if a socket address returned from the OS is considered non-empty.
///
/// # Safety
///
/// `storage` must point to a valid socket address returned from the OS.
pub(crate) unsafe fn maybe_read_sockaddr_os(
storage: *const c::sockaddr_storage,
len: usize,
) -> Option<SocketAddrAny> {
#[inline]
pub(crate) unsafe fn sockaddr_nonempty(storage: *const c::sockaddr, len: usize) -> bool {
if len == 0 {
return None;
return false;
}

assert!(len >= size_of::<c::sa_family_t>());
let family = read_sa_family(storage.cast::<c::sockaddr>()).into();
let family: c::c_int = read_sa_family(storage.cast::<c::sockaddr>()).into();
if family == c::AF_UNSPEC {
return None;
return false;
}

// On macOS, if we get an `AF_UNIX` with an empty path, treat it as
// an absent address.
#[cfg(apple)]
if family == c::AF_UNIX && read_sun_path0(storage) == 0 {
return None;
return false;
}

Some(inner_read_sockaddr_os(family, storage, len))
true
}

/// Read a socket address returned from the OS.
///
/// # Safety
///
/// `storage` must point to a valid socket address returned from the OS.
pub(crate) unsafe fn read_sockaddr_os(
storage: *const c::sockaddr_storage,
len: usize,
) -> SocketAddrAny {
assert!(len >= size_of::<c::sa_family_t>());
let family = read_sa_family(storage.cast::<c::sockaddr>()).into();
inner_read_sockaddr_os(family, storage, len)
/// Set the `sa_family` field of a socket address to `AF_UNSPEC`, so that we
/// can test for `AF_UNSPEC` to test whether it was stored to.
pub(crate) unsafe fn initialize_family_to_unspec(storage: *mut c::sockaddr) {
(*storage.cast::<sockaddr_header>()).sa_family = c::AF_UNSPEC as _;
}

unsafe fn inner_read_sockaddr_os(
family: c::c_int,
storage: *const c::sockaddr_storage,
len: usize,
) -> SocketAddrAny {
#[cfg(unix)]
let offsetof_sun_path = super::addr::offsetof_sun_path();
pub(crate) fn read_sockaddr_v4(addr: &SocketAddrAny) -> Result<SocketAddrV4, Errno> {
if addr.address_family() != AddressFamily::INET {
return Err(Errno::AFNOSUPPORT);
}
if addr.len() < size_of::<c::sockaddr_in>() {
return Err(Errno::INVAL);
}
let decode = unsafe { &*addr.as_ptr().cast::<c::sockaddr_in>() };
Ok(SocketAddrV4::new(
Ipv4Addr::from(u32::from_be(in_addr_s_addr(decode.sin_addr))),
u16::from_be(decode.sin_port),
))
}

assert!(len >= size_of::<c::sa_family_t>());
match family {
c::AF_INET => {
assert!(len >= size_of::<c::sockaddr_in>());
let decode = &*storage.cast::<c::sockaddr_in>();
SocketAddrAny::V4(SocketAddrV4::new(
Ipv4Addr::from(u32::from_be(in_addr_s_addr(decode.sin_addr))),
u16::from_be(decode.sin_port),
))
}
c::AF_INET6 => {
assert!(len >= size_of::<c::sockaddr_in6>());
let decode = &*storage.cast::<c::sockaddr_in6>();
SocketAddrAny::V6(SocketAddrV6::new(
Ipv6Addr::from(in6_addr_s6_addr(decode.sin6_addr)),
u16::from_be(decode.sin6_port),
u32::from_be(decode.sin6_flowinfo),
sockaddr_in6_sin6_scope_id(decode),
))
}
#[cfg(unix)]
c::AF_UNIX => {
assert!(len >= offsetof_sun_path);
if len == offsetof_sun_path {
SocketAddrAny::Unix(SocketAddrUnix::new(&[][..]).unwrap())
} else {
let decode = &*storage.cast::<c::sockaddr_un>();
pub(crate) fn read_sockaddr_v6(addr: &SocketAddrAny) -> Result<SocketAddrV6, Errno> {
if addr.address_family() != AddressFamily::INET6 {
return Err(Errno::AFNOSUPPORT);
}
if addr.len() < size_of::<c::sockaddr_in6>() {
return Err(Errno::INVAL);
}
let decode = unsafe { &*addr.as_ptr().cast::<c::sockaddr_in6>() };
Ok(SocketAddrV6::new(
Ipv6Addr::from(in6_addr_s6_addr(decode.sin6_addr)),
u16::from_be(decode.sin6_port),
u32::from_be(decode.sin6_flowinfo),
sockaddr_in6_sin6_scope_id(decode),
))
}

// On Linux check for Linux's [abstract namespace].
//
// [abstract namespace]: https://man7.org/linux/man-pages/man7/unix.7.html
#[cfg(linux_kernel)]
if decode.sun_path[0] == 0 {
return SocketAddrAny::Unix(
SocketAddrUnix::new_abstract_name(core::mem::transmute::<
&[c::c_char],
&[u8],
>(
&decode.sun_path[1..len - offsetof_sun_path],
))
.unwrap(),
);
}
#[cfg(unix)]
pub(crate) fn read_sockaddr_unix(addr: &SocketAddrAny) -> Result<SocketAddrUnix, Errno> {
if addr.address_family() != AddressFamily::UNIX {
return Err(Errno::AFNOSUPPORT);
}

// Otherwise we expect a NUL-terminated filesystem path.
assert_eq!(decode.sun_path[len - 1 - offsetof_sun_path], 0);
let path_bytes = &decode.sun_path[..len - 1 - offsetof_sun_path];
let offsetof_sun_path = super::addr::offsetof_sun_path();
let len = addr.len();

// FreeBSD and illumos sometimes set the length to longer than
// the length of the NUL-terminated string. Find the NUL and
// truncate the string accordingly.
#[cfg(any(solarish, target_os = "freebsd"))]
let path_bytes = &path_bytes[..path_bytes.iter().position(|b| *b == 0).unwrap()];
if len < offsetof_sun_path {
return Err(Errno::INVAL);
}
if len == offsetof_sun_path {
SocketAddrUnix::new(&[][..])
} else {
let decode = unsafe { &*addr.as_ptr().cast::<c::sockaddr_un>() };

// On Linux check for Linux's [abstract namespace].
//
// [abstract namespace]: https://man7.org/linux/man-pages/man7/unix.7.html
#[cfg(linux_kernel)]
if decode.sun_path[0] == 0 {
let name = &decode.sun_path[1..len - offsetof_sun_path];
let name = unsafe { core::mem::transmute::<&[c::c_char], &[u8]>(name) };
return SocketAddrUnix::new_abstract_name(name);
}

SocketAddrAny::Unix(
SocketAddrUnix::new(core::mem::transmute::<&[c::c_char], &[u8]>(path_bytes))
.unwrap(),
)
// Otherwise we expect a NUL-terminated filesystem path.

// Trim off unused bytes from the end of `path_bytes`.
let path_bytes = if cfg!(any(solarish, target_os = "freebsd")) {
// FreeBSD and illumos sometimes set the length to longer
// than the length of the NUL-terminated string. Find the
// NUL and truncate the string accordingly.
&decode.sun_path[..decode
.sun_path
.iter()
.position(|b| *b == 0)
.ok_or(Errno::INVAL)?]
} else {
// Otherwise, use the provided length.
let provided_len = len - 1 - offsetof_sun_path;
if decode.sun_path[provided_len] != 0 {
return Err(Errno::INVAL);
}
}
#[cfg(target_os = "linux")]
c::AF_XDP => {
assert!(len >= size_of::<c::sockaddr_xdp>());
let decode = &*storage.cast::<c::sockaddr_xdp>();
SocketAddrAny::Xdp(SocketAddrXdp::new(
SockaddrXdpFlags::from_bits_retain(decode.sxdp_flags),
u32::from_be(decode.sxdp_ifindex),
u32::from_be(decode.sxdp_queue_id),
u32::from_be(decode.sxdp_shared_umem_fd),
))
}
other => unimplemented!("{:?}", other),
debug_assert_eq!(
unsafe { CStr::from_ptr(decode.sun_path.as_ptr().cast()) }
.to_bytes()
.len(),
provided_len
);
&decode.sun_path[..provided_len]
};

SocketAddrUnix::new(unsafe { core::mem::transmute::<&[c::c_char], &[u8]>(path_bytes) })
}
}

#[cfg(target_os = "linux")]
pub(crate) fn read_sockaddr_xdp(addr: &SocketAddrAny) -> Result<SocketAddrXdp, Errno> {
if addr.address_family() != AddressFamily::XDP {
return Err(Errno::AFNOSUPPORT);
}
if addr.len() < size_of::<c::sockaddr_xdp>() {
return Err(Errno::INVAL);
}
let decode = unsafe { &*addr.as_ptr().cast::<c::sockaddr_xdp>() };
Ok(SocketAddrXdp::new(
SockaddrXdpFlags::from_bits_retain(decode.sxdp_flags),
u32::from_be(decode.sxdp_ifindex),
u32::from_be(decode.sxdp_queue_id),
u32::from_be(decode.sxdp_shared_umem_fd),
))
}
Loading

0 comments on commit 1b326b4

Please sign in to comment.