Skip to content

Commit

Permalink
Add support for Netlink socket addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmehall committed Feb 8, 2025
1 parent 880bbd3 commit 53b4738
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/backend/libc/net/read_sockaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::backend::c;
#[cfg(not(windows))]
use crate::ffi::CStr;
use crate::io::Errno;
#[cfg(linux_kernel)]
use crate::net::netlink::SocketAddrNetlink;
#[cfg(target_os = "linux")]
use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
use crate::net::{AddressFamily, Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrV4, SocketAddrV6};
Expand Down Expand Up @@ -239,3 +241,14 @@ pub(crate) fn read_sockaddr_xdp(addr: &SocketAddrAny) -> Result<SocketAddrXdp, E
u32::from_be(decode.sxdp_shared_umem_fd),
))
}

#[cfg(linux_kernel)]
#[inline]
pub(crate) fn read_sockaddr_netlink(addr: &SocketAddrAny) -> Result<SocketAddrNetlink, Errno> {
if addr.address_family() != AddressFamily::NETLINK {
return Err(Errno::AFNOSUPPORT);
}
assert!(addr.len() >= size_of::<c::sockaddr_nl>());
let decode = unsafe { &*addr.as_ptr().cast::<c::sockaddr_nl>() };
Ok(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups))
}
11 changes: 11 additions & 0 deletions src/backend/linux_raw/net/read_sockaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use crate::backend::c;
use crate::io::Errno;
use crate::net::netlink::SocketAddrNetlink;
#[cfg(target_os = "linux")]
use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
use crate::net::{AddressFamily, SocketAddrAny};
Expand Down Expand Up @@ -133,3 +134,13 @@ pub(crate) fn read_sockaddr_xdp(addr: &SocketAddrAny) -> Result<SocketAddrXdp, E
u32::from_be(decode.sxdp_shared_umem_fd),
))
}

#[inline]
pub(crate) fn read_sockaddr_netlink(addr: &SocketAddrAny) -> Result<SocketAddrNetlink, Errno> {
if addr.address_family() != AddressFamily::NETLINK {
return Err(Errno::AFNOSUPPORT);
}
assert!(addr.len() >= size_of::<c::sockaddr_nl>());
let decode = unsafe { &*addr.as_ptr().cast::<c::sockaddr_nl>() };
Ok(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups))
}
6 changes: 6 additions & 0 deletions src/net/socket_addr_any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ impl fmt::Debug for SocketAddrAny {
return addr.fmt(f);
}
}
#[cfg(linux_kernel)]
AddressFamily::NETLINK => {
if let Ok(addr) = crate::net::netlink::SocketAddrNetlink::try_from(self.clone()) {
return addr.fmt(f);
}
}
_ => {}
}

Expand Down
83 changes: 83 additions & 0 deletions src/net/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,12 @@ pub mod netlink {
use {
super::{new_raw_protocol, Protocol},
crate::backend::c,
crate::backend::net::read_sockaddr::read_sockaddr_netlink,
crate::net::{
addr::{call_with_sockaddr, SocketAddrArg, SocketAddrOpaque},
SocketAddrAny,
},
core::mem,
};

/// `NETLINK_UNUSED`
Expand Down Expand Up @@ -1112,6 +1118,83 @@ pub mod netlink {
/// `NETLINK_GET_STRICT_CHK`
#[cfg(linux_kernel)]
pub const GET_STRICT_CHK: Protocol = Protocol(new_raw_protocol(c::NETLINK_GET_STRICT_CHK as _));

/// A Netlink socket address.
///
/// Used to bind to a Netlink socket.
///
/// Not ABI compatible with `struct sockaddr_nl`
#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)]
#[cfg(linux_kernel)]
pub struct SocketAddrNetlink {
/// Port ID
pid: u32,

/// Multicast groups mask
groups: u32,
}

#[cfg(linux_kernel)]
impl SocketAddrNetlink {
/// Construct a netlink address
#[inline]
pub fn new(pid: u32, groups: u32) -> Self {
Self { pid, groups }
}

/// Return port id.
#[inline]
pub fn pid(&self) -> u32 {
self.pid
}

/// Set port id.
#[inline]
pub fn set_pid(&mut self, pid: u32) {
self.pid = pid;
}

/// Return multicast groups mask.
#[inline]
pub fn groups(&self) -> u32 {
self.groups
}

/// Set multicast groups mask.
#[inline]
pub fn set_groups(&mut self, groups: u32) {
self.groups = groups;
}
}

#[cfg(linux_kernel)]
#[allow(unsafe_code)]
unsafe impl SocketAddrArg for SocketAddrNetlink {
fn with_sockaddr<R>(&self, f: impl FnOnce(*const SocketAddrOpaque, usize) -> R) -> R {
let mut addr: c::sockaddr_nl = unsafe { mem::zeroed() };
addr.nl_family = c::AF_NETLINK as _;
addr.nl_pid = self.pid;
addr.nl_groups = self.groups;
call_with_sockaddr(&addr, f)
}
}

#[cfg(linux_kernel)]
impl From<SocketAddrNetlink> for SocketAddrAny {
#[inline]
fn from(from: SocketAddrNetlink) -> Self {
from.as_any()
}
}

#[cfg(linux_kernel)]
impl TryFrom<SocketAddrAny> for SocketAddrNetlink {
type Error = crate::io::Errno;

fn try_from(addr: SocketAddrAny) -> Result<Self, Self::Error> {
read_sockaddr_netlink(&addr)
}
}
}

/// `ETH_P_*` constants.
Expand Down
2 changes: 2 additions & 0 deletions tests/net/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ mod addr;
mod cmsg;
mod connect_bind_send;
mod dgram;
#[cfg(linux_kernel)]
mod netlink;
#[cfg(feature = "event")]
mod poll;
#[cfg(unix)]
Expand Down
64 changes: 64 additions & 0 deletions tests/net/netlink.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use rustix::net::netlink::{self, SocketAddrNetlink};
use rustix::net::{
bind, getsockname, recvfrom, sendto, socket_with, AddressFamily, RecvFlags, SendFlags,
SocketAddrAny, SocketFlags, SocketType,
};

#[test]
fn encode_decode() {
let orig = SocketAddrNetlink::new(0x12345678, 0x9abcdef0);
let encoded = SocketAddrAny::from(orig);
let decoded = SocketAddrNetlink::try_from(encoded).unwrap();
assert_eq!(decoded, orig);
}

#[test]
fn test_bind_kobject_uevent() {
let server = socket_with(
AddressFamily::NETLINK,
SocketType::RAW,
SocketFlags::CLOEXEC,
Some(netlink::KOBJECT_UEVENT),
)
.unwrap();

bind(&server, &SocketAddrNetlink::new(0, 1)).unwrap();
}

#[test]
#[cfg_attr(
not(any(target_arch = "x86", target_arch = "x86_64")),
ignore = "qemu used in CI does not support NETLINK_USERSOCK"
)]
fn test_usersock() {
let server = socket_with(
AddressFamily::NETLINK,
SocketType::RAW,
SocketFlags::CLOEXEC,
Some(netlink::USERSOCK),
)
.unwrap();

bind(&server, &SocketAddrNetlink::new(0, 0)).unwrap();
let addr = getsockname(&server).unwrap();
let addr = SocketAddrNetlink::try_from(addr).unwrap();

let client = socket_with(
AddressFamily::NETLINK,
SocketType::RAW,
SocketFlags::CLOEXEC,
Some(netlink::USERSOCK),
)
.unwrap();

let data = b"ABCDEF";

sendto(client, data, SendFlags::empty(), &addr).unwrap();

let mut buffer = [0u8; 4096];
let (len, src) = recvfrom(&server, &mut buffer, RecvFlags::empty()).unwrap();

assert_eq!(&buffer[..len], data);
let src = SocketAddrNetlink::try_from(src.unwrap()).unwrap();
assert_eq!(src.groups(), 0);
}

0 comments on commit 53b4738

Please sign in to comment.