From 2ab2cc4212685c6e801c89a2c00a37968907844d Mon Sep 17 00:00:00 2001 From: Todd Eisenberger Date: Mon, 7 Sep 2015 13:36:03 -0700 Subject: [PATCH 1/5] Implement safe sendmsg/recvmsg abstractions Supports SCM_CREDENTIALS and SCM_RIGHTS control messages --- Cargo.toml | 5 + build.rs | 5 + examples/socket_send.rs | 76 +++++++++++ src/cmsg_manip/cmsg.c | 35 +++++ src/lib.rs | 107 +++++++++++++++ src/sendmsg_impl.rs | 279 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 507 insertions(+) create mode 100644 build.rs create mode 100644 examples/socket_send.rs create mode 100644 src/cmsg_manip/cmsg.c create mode 100644 src/sendmsg_impl.rs diff --git a/Cargo.toml b/Cargo.toml index 87be173..46da985 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,16 @@ repository = "https://github.com/sfackler/rust-unix-socket" documentation = "https://sfackler.github.io/rust-unix-socket/doc/v0.4.5/unix_socket" readme = "README.md" keywords = ["posix", "unix", "socket", "domain"] +links = "cmsg_manip" +build = "build.rs" [dependencies] libc = "0.1" debug-builders = "0.1" +[build-dependencies] +gcc = "0.3" + [dev-dependencies] tempdir = "0.3" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..0605860 --- /dev/null +++ b/build.rs @@ -0,0 +1,5 @@ +extern crate gcc; + +fn main() { + gcc::compile_library("libcmsg_manip.a", &["src/cmsg_manip/cmsg.c"]); +} diff --git a/examples/socket_send.rs b/examples/socket_send.rs new file mode 100644 index 0000000..ba8b739 --- /dev/null +++ b/examples/socket_send.rs @@ -0,0 +1,76 @@ +extern crate libc; +extern crate unix_socket; + +use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::path::Path; + +use unix_socket::{ControlMsg, UCred, UnixDatagram}; + +fn handle_parent(sock: UnixDatagram) { + let (parent2, child2) = UnixDatagram::pair().unwrap(); + + let cmsg = ControlMsg::Rights(vec![child2.as_raw_fd()]); + let cmsg2 = unsafe { ControlMsg::Credentials(UCred{ + pid: libc::getpid(), + uid: libc::getuid(), + gid: libc::getgid(), + }) }; + println!("cmsg {:?}", cmsg2); + let sent_bytes = sock.sendmsg::<&Path>(None, &[&[]], &[cmsg, cmsg2], 0).unwrap(); + assert_eq!(sent_bytes, 0); + drop(child2); + println!("Parent sent child SCM_RIGHTS fd"); + + let mut buf = &mut [0u8; 4096]; + let read = parent2.recv(buf).unwrap(); + assert_eq!(&buf[..read], "Hello, world!".as_bytes()); + println!("Parent received message from child via SCM_RIGHTS fd"); +} + +fn handle_child(sock: UnixDatagram) { + sock.set_passcred(true).unwrap(); + let mut cmsg_buf = &mut [0u8; 4096]; + let result = sock.recvmsg(&[&mut[]], cmsg_buf, 0).unwrap(); + assert_eq!(result.control_msgs.len(), 2); + + let mut new_sock = None; + let mut creds = None; + for cmsg in result.control_msgs { + match cmsg.clone() { + ControlMsg::Rights(fds) => { + assert!(new_sock.is_none()); + assert_eq!(fds.len(), 1); + unsafe { + new_sock = Some(UnixDatagram::from_raw_fd(fds[0])); + } + println!("Child received SCM_RIGHTS fd"); + }, + ControlMsg::Credentials(ucred) => { + assert!(creds.is_none()); + creds = Some(ucred); + println!("Child received SCM_CREDENTIALS"); + }, + _ => unreachable!(), + } + } + + let creds = creds.unwrap(); + unsafe { + assert_eq!(creds.uid, libc::getuid()); + assert_eq!(creds.gid, libc::getgid()); + assert!(creds.pid != 0); + } + let sent = new_sock.unwrap().send("Hello, world!".as_bytes()).unwrap(); + println!("Child sent message to parent via SCM_RIGHTS fd"); + assert_eq!(sent, 13); +} + +fn main() { + let (parent_sock, child_sock) = UnixDatagram::pair().unwrap(); + let pid = unsafe { libc::fork() }; + if pid == 0 { + handle_child(child_sock); + } else { + handle_parent(parent_sock); + } +} diff --git a/src/cmsg_manip/cmsg.c b/src/cmsg_manip/cmsg.c new file mode 100644 index 0000000..b74b65f --- /dev/null +++ b/src/cmsg_manip/cmsg.c @@ -0,0 +1,35 @@ +#define _GNU_SOURCE +#include + +size_t cmsghdr_size = sizeof(struct cmsghdr); +size_t iovec_size = sizeof(struct iovec); +size_t msghdr_size = sizeof(struct msghdr); +size_t ucred_size = sizeof(struct ucred); + +int scm_credentials = SCM_CREDENTIALS; +int scm_rights = SCM_RIGHTS; +int so_passcred = SO_PASSCRED; + +struct cmsghdr * cmsg_firsthdr(struct msghdr *msgh) { + return CMSG_FIRSTHDR(msgh); +} + +struct cmsghdr * cmsg_nxthdr(struct msghdr *msgh, struct cmsghdr *cmsg) { + return CMSG_NXTHDR(msgh, cmsg); +} + +size_t cmsg_align(size_t length) { + return CMSG_ALIGN(length); +} + +size_t cmsg_space(size_t length) { + return CMSG_SPACE(length); +} + +size_t cmsg_len(size_t length) { + return CMSG_LEN(length); +} + +unsigned char * cmsg_data(struct cmsghdr *cmsg) { + return CMSG_DATA(cmsg); +} diff --git a/src/lib.rs b/src/lib.rs index 89c4ad8..b393d4a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,9 @@ use std::fmt; use std::path::Path; use std::mem::size_of; +mod sendmsg_impl; +pub use sendmsg_impl::{ControlMsg, UCred}; + extern "C" { fn socketpair(domain: libc::c_int, ty: libc::c_int, @@ -168,6 +171,18 @@ impl Inner { .map(|_| ()) } } + + fn set_passcred(&self, receive_creds: bool) -> io::Result<()> { + unsafe { + let v: libc::c_int = if receive_creds { 1 } else { 0 }; + cvt(libc::setsockopt(self.0, + libc::SOL_SOCKET, + sendmsg_impl::SO_PASSCRED, + &v as *const libc::c_int as *const libc::c_void, + mem::size_of::() as libc::socklen_t)) + .map(|_| ()) + } + } } unsafe fn sockaddr_un>(path: P) @@ -641,6 +656,18 @@ impl<'a> Iterator for Incoming<'a> { } } +/// The return value from a call to recvmsg +pub struct RecvMsgResult { + /// Number of bytes received + pub data_bytes: usize, + /// Address of the sender + pub sender: SocketAddr, + /// List of all control messages received during this call + pub control_msgs: Vec, + /// Flags returned by recvmsg, see the recv(2) man page for a list + pub flags: libc::c_int, +} + /// A Unix datagram socket. /// /// # Examples @@ -771,6 +798,47 @@ impl UnixDatagram { } } + /// Receives data on the socket. + /// + /// If path is None, the peer address set by the `connect` method will be used. If it has not + /// been set, then this method will return an error. + /// + /// This interface allows sending data from multiple buffers. This acts as if the buffers had been + /// concatenated in the order they were given. + /// + /// ctrl_msgs are special ancillary data that can be sent, such as file descriptors and Unix credentials + /// + /// flags is a pass-through of the flags specified in the sendmsg(2) man page + /// + /// On success, returns the number of bytes written. + pub fn recvmsg(&self, buffers: &[&mut[u8]], cmsg_buffer: &mut [u8], flags: libc::c_int) -> io::Result { + let mut result = Err(io::Error::new(io::ErrorKind::Other, "programming error")); + let addr = try!(SocketAddr::new(|addr, len| { + unsafe { + result = sendmsg_impl::recvmsg( + self.inner.0, + buffers, + cmsg_buffer, + flags, + addr, + len); + } + if let Err(ref e) = result { + -(e.raw_os_error().unwrap() as libc::c_int) + } else { + 0 + } + })); + + let result = try!(result); + Ok(RecvMsgResult { + data_bytes: result.data_bytes, + sender: addr, + control_msgs: result.control_msgs, + flags: result.flags, + }) + } + /// Sends data on the socket to the specified address. /// /// On success, returns the number of bytes written. @@ -804,6 +872,38 @@ impl UnixDatagram { } } + /// Sends data on the socket to the specified address. + /// + /// If path is None, the peer address set by the `connect` method will be used. If it has not + /// been set, then this method will return an error. + /// + /// This interface allows sending data from multiple buffers. This acts as if the buffers had been + /// concatenated in the order they were given. + /// + /// ctrl_msgs are special ancillary data that can be sent, such as file descriptors and Unix credentials + /// + /// flags is a pass-through of the flags specified in the sendmsg(2) man page + /// + /// On success, returns the number of bytes written. + pub fn sendmsg>(&self, path: Option

, buffers: &[&[u8]], ctrl_msgs: &[ControlMsg], flags: libc::c_int) -> io::Result { + unsafe { + let dst = match path { + None => None, + Some(p) => { + let v = try!(sockaddr_un(p)); + Some(v) + }, + }; + + sendmsg_impl::sendmsg( + self.inner.0, + dst, + buffers, + ctrl_msgs, + flags) + } + } + /// Sets the read timeout for the socket. /// /// If the provided value is `None`, then `recv` and `recv_from` calls will @@ -852,6 +952,11 @@ impl UnixDatagram { pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.inner.shutdown(how) } + + /// Enable or disable receiving SCM_CREDENTIALS messages + pub fn set_passcred(&self, receive_creds: bool) -> io::Result<()> { + self.inner.set_passcred(receive_creds) + } } impl AsRawFd for UnixDatagram { @@ -1215,4 +1320,6 @@ mod test { thread.join().unwrap(); } + // TODO: Add tests for sending credentials without calling set_passcred and with + // TODO: Add tests for sending fds } diff --git a/src/sendmsg_impl.rs b/src/sendmsg_impl.rs new file mode 100644 index 0000000..0e31334 --- /dev/null +++ b/src/sendmsg_impl.rs @@ -0,0 +1,279 @@ +use std::io; +use std::ptr; +use std::mem; +use std::os::unix::io::RawFd; +use std::slice; + +use libc; + +mod raw { + use libc; + extern "system" { + pub fn sendmsg(socket: libc::c_int, msg: *const libc::c_void, flags: libc::c_int) -> libc::ssize_t; + pub fn recvmsg(socket: libc::c_int, msg: *mut libc::c_void, flags: libc::c_int) -> libc::ssize_t; + } + + #[allow(dead_code)] + #[link(name = "cmsg_manip")] + extern { + pub static cmsghdr_size: libc::size_t; + pub static iovec_size: libc::size_t; + pub static msghdr_size: libc::size_t; + pub static ucred_size: libc::size_t; + + pub static scm_credentials: libc::c_int; + pub static scm_rights: libc::c_int; + + pub static so_passcred: libc::c_int; + + pub fn cmsg_firsthdr(msgh: *const libc::c_void) -> *const libc::c_void; + pub fn cmsg_nxthdr(msgh: *const libc::c_void, cmsg: *const libc::c_void) -> *const libc::c_void; + pub fn cmsg_align(len: libc::size_t) -> libc::size_t; + pub fn cmsg_space(len: libc::size_t) -> libc::size_t; + pub fn cmsg_len(len: libc::size_t) -> libc::size_t; + pub fn cmsg_data(cmsg: *const libc::c_void) -> *const libc::c_void; + } +} + +pub use self::raw::so_passcred as SO_PASSCRED; + +pub use self::raw::scm_credentials as SCM_CREDENTIALS; +pub use self::raw::scm_rights as SCM_RIGHTS; + +pub unsafe fn sendmsg( + socket: libc::c_int, + dst: Option<(libc::sockaddr_un, libc::socklen_t)>, + buffers: &[&[u8]], + ctrl_msgs: &[ControlMsg], + flags: libc::c_int) -> io::Result { + + let mut msg: MsgHdr = mem::zeroed(); + + // Initialize destination field + if let Some((addr, len)) = dst { + msg.msg_name = (&addr as *const libc::sockaddr_un) as *const libc::c_void; + msg.msg_namelen = len; + } + + // Initialize scatter/gather vector + let mut iovecs = Vec::with_capacity(buffers.len()); + for buf in buffers { + iovecs.push(IoVec::new(buf)); + } + msg.msg_iov = iovecs.as_mut_ptr() as *mut libc::c_void; + msg.msg_iovlen = (mem::size_of::() * iovecs.len()) as libc::size_t; + + // Initialize control message struct + + let mut total_space: usize = 0; + for ctrl_msg in ctrl_msgs.iter().cloned() { + let size = match ctrl_msg { + ControlMsg::Rights(fds) => (mem::size_of::() * fds.len()) as libc::size_t, + ControlMsg::Credentials(..) => mem::size_of::() as libc::size_t, + _ => unimplemented!(), + }; + total_space += raw::cmsg_space(size) as usize; + } + + let mut ctrl_buf = &mut Vec::::with_capacity(total_space)[..]; + msg.msg_control = ctrl_buf.as_mut_ptr() as *mut libc::c_void; + msg.msg_controllen = total_space as libc::size_t; + + let msg_addr = (&msg as *const MsgHdr) as *const libc::c_void; + let mut cur_cmsg = raw::cmsg_firsthdr(msg_addr); + for ctrl_msg in ctrl_msgs.iter().cloned() { + if cur_cmsg == ptr::null() { + panic!("programming error: buffer too small"); + } + + let cmsg = cur_cmsg as *mut CmsgHdr; + match ctrl_msg { + // NOTE: Add handlers for new messages here + ControlMsg::Rights(fds) => { + (*cmsg).cmsg_len = raw::cmsg_len((mem::size_of::() * fds.len()) as libc::size_t) as libc::size_t; + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = SCM_RIGHTS; + let data = raw::cmsg_data(cur_cmsg) as *mut libc::c_int; + ptr::copy_nonoverlapping(fds.as_ptr(), data, fds.len()); + }, + ControlMsg::Credentials(ucred) => { + (*cmsg).cmsg_len = raw::cmsg_len(mem::size_of::() as libc::size_t) as libc::size_t; + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = SCM_CREDENTIALS; + let data = raw::cmsg_data(cur_cmsg) as *mut UCred; + ptr::write(data, ucred); + } + _ => unreachable!(), + } + + cur_cmsg = raw::cmsg_nxthdr(msg_addr, cur_cmsg); + } + + let res = raw::sendmsg(socket, msg_addr, flags); + if res < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(res as usize) + } +} + +pub struct InternalRecvMsgResult { + pub data_bytes: usize, + pub control_msgs: Vec, + pub flags: libc::c_int, +} + +pub unsafe fn recvmsg( + socket: libc::c_int, + buffers: &[&mut [u8]], + cmsg_buffer: &mut [u8], + flags: libc::c_int, + sender_addr: *mut libc::sockaddr, + sender_len: *mut libc::socklen_t) -> io::Result { + + let mut msg: MsgHdr = mem::zeroed(); + + msg.msg_name = sender_addr as *const libc::c_void; + msg.msg_namelen = *sender_len; + + // Initialize scatter/gather vector + let mut iovecs = Vec::with_capacity(buffers.len()); + for buf in buffers { + iovecs.push(IoVec::new(buf)); + } + msg.msg_iov = iovecs.as_mut_ptr() as *mut libc::c_void; + msg.msg_iovlen = (mem::size_of::() * iovecs.len()) as libc::size_t; + + // Initialize control message struct + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut libc::c_void; + msg.msg_controllen = cmsg_buffer.len() as libc::size_t; + + let msg_addr = (&mut msg as *mut MsgHdr) as *mut libc::c_void; + let recvmsg_res = raw::recvmsg(socket, msg_addr, flags); + if recvmsg_res < 0 { + return Err(io::Error::last_os_error()); + } + + let mut cmsgs = vec![]; + + let mut cur_cmsg = raw::cmsg_firsthdr(msg_addr); + while cur_cmsg != ptr::null() { + // NOTE: Add handlers for new messages here + let cmsg = cur_cmsg as *mut CmsgHdr; + if (*cmsg).cmsg_level == libc::SOL_SOCKET { + if (*cmsg).cmsg_type == SCM_CREDENTIALS { + let ucred = raw::cmsg_data(cur_cmsg) as *mut UCred; + assert_eq!((ucred as i64) + mem::size_of::() as i64 - cur_cmsg as i64, (*cmsg).cmsg_len as i64); + cmsgs.push(ControlMsg::Credentials((*ucred).clone())); + } else if (*cmsg).cmsg_type == SCM_RIGHTS { + let mut fds = vec![]; + let data = raw::cmsg_data(cur_cmsg) as *mut libc::c_int; + let length = ((*cmsg).cmsg_len as i64 - (data as i64 - cur_cmsg as i64)) as usize; + assert_eq!(length % mem::size_of::(), 0); + let passed_fds = slice::from_raw_parts(data, length / mem::size_of::()); + for &fd in passed_fds { + fds.push(fd); + } + cmsgs.push(ControlMsg::Rights(fds)); + } else { + cmsgs.push(ControlMsg::Unknown{ level: (*cmsg).cmsg_level, typ: (*cmsg).cmsg_type }); + } + } else { + cmsgs.push(ControlMsg::Unknown{ level: (*cmsg).cmsg_level, typ: (*cmsg).cmsg_type }); + } + + cur_cmsg = raw::cmsg_nxthdr(msg_addr, cur_cmsg); + } + + + *sender_len = msg.msg_namelen; + Ok(InternalRecvMsgResult { + data_bytes: recvmsg_res as usize, + control_msgs: cmsgs, + flags: msg.msg_flags, + }) +} + +#[repr(C)] +struct MsgHdr { + pub msg_name: *const libc::c_void, + pub msg_namelen: libc::socklen_t, + pub msg_iov: *mut libc::c_void, + pub msg_iovlen: libc::size_t, + pub msg_control: *mut libc::c_void, + pub msg_controllen: libc::size_t, + pub msg_flags: libc::c_int, +} + +#[test] +fn msghdr_size_correctness() { + assert_eq!(raw::msghdr_size as usize, mem::size_of::()); +} + +#[repr(C)] +struct IoVec { + base: *const libc::c_void, + len: libc::size_t, +} + +impl IoVec { + fn new(buf: &[u8]) -> IoVec { + IoVec { + base: buf.as_ptr() as *const libc::c_void, + len: buf.len() as libc::size_t, + } + } +} + +#[test] +fn iovec_size_correctness() { + assert_eq!(raw::iovec_size as usize, mem::size_of::()); +} + +#[repr(C)] +struct CmsgHdr { + cmsg_len: libc::size_t, + cmsg_level: libc::c_int, + cmsg_type: libc::c_int, +} + +#[test] +fn cmsghdr_size_correctness() { + assert_eq!(raw::cmsghdr_size as usize, mem::size_of::()); +} + +/// Unix credential that can be sent/received over Unix sockets using `ControlMsg::Credential` +/// +/// This is a Rust version of `struct ucred` from sys/socket.h +#[derive(Clone, Debug)] +pub struct UCred{ + /// The sender's process id + pub pid: libc::pid_t, + /// The sender's user id + pub uid: libc::uid_t, + /// The sender's group id + pub gid: libc::gid_t, +} + +#[test] +fn ucred_size_correctness() { + assert_eq!(raw::ucred_size as usize, mem::size_of::()); +} + +/// Ancillary messages that can be sent/received over Unix sockets using `sendmsg`/`recvmsg`. +#[derive(Clone, Debug)] +pub enum ControlMsg { + /// Message used to transfer file descriptors + Rights(Vec), + /// Message used to provide kernel-verified Unix credentials of the sender + Credentials(UCred), + /// Any unimplemented message + Unknown { + /// cmsg_level of the unimplemented message + level: libc::c_int, + /// cmsg_type of the unimplemented message + typ: libc::c_int, + }, + // To add support for more messages, define the message in ControlMsg, + // and near the relevant NOTE comments above. +} From 54eb9fa738506d0b50c7a04bdd55fef0296aba8d Mon Sep 17 00:00:00 2001 From: Todd Eisenberger Date: Mon, 7 Sep 2015 21:46:46 -0700 Subject: [PATCH 2/5] Add tests and expose MSG_* flags --- src/cmsg_manip/cmsg.c | 6 ++ src/lib.rs | 186 +++++++++++++++++++++++++++++++++++++++++- src/sendmsg_impl.rs | 14 +++- 3 files changed, 202 insertions(+), 4 deletions(-) diff --git a/src/cmsg_manip/cmsg.c b/src/cmsg_manip/cmsg.c index b74b65f..ae1b0f4 100644 --- a/src/cmsg_manip/cmsg.c +++ b/src/cmsg_manip/cmsg.c @@ -10,6 +10,12 @@ int scm_credentials = SCM_CREDENTIALS; int scm_rights = SCM_RIGHTS; int so_passcred = SO_PASSCRED; +int msg_eor = MSG_EOR; +int msg_trunc = MSG_TRUNC; +int msg_ctrunc = MSG_CTRUNC; +int msg_oob = MSG_OOB; +int msg_errqueue = MSG_ERRQUEUE; + struct cmsghdr * cmsg_firsthdr(struct msghdr *msgh) { return CMSG_FIRSTHDR(msgh); } diff --git a/src/lib.rs b/src/lib.rs index b393d4a..f496ae1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,7 @@ use std::mem::size_of; mod sendmsg_impl; pub use sendmsg_impl::{ControlMsg, UCred}; +pub use sendmsg_impl::{MSG_EOR, MSG_TRUNC, MSG_CTRUNC, MSG_OOB, MSG_ERRQUEUE}; extern "C" { fn socketpair(domain: libc::c_int, @@ -977,14 +978,17 @@ impl std::os::unix::io::FromRawFd for UnixDatagram { #[cfg(test)] mod test { + extern crate libc; extern crate tempdir; use std::thread; use std::io; use std::io::prelude::*; + use std::os::unix::io::{AsRawFd, FromRawFd}; + use std::path::Path; use self::tempdir::TempDir; - use {UnixListener, UnixStream, UnixDatagram, AddressKind}; + use {UnixListener, UnixStream, UnixDatagram, AddressKind, ControlMsg, RecvMsgResult, UCred, MSG_CTRUNC}; macro_rules! or_panic { ($e:expr) => { @@ -1320,6 +1324,182 @@ mod test { thread.join().unwrap(); } - // TODO: Add tests for sending credentials without calling set_passcred and with - // TODO: Add tests for sending fds + + /// Sends "hello" on the data channel and the specified cmsgs on the control channel + fn sendmsg_helper>(s: &UnixDatagram, dst: Option

, cmsgs: &[ControlMsg]) { + let msg = b"he"; + let msg2 = b"llo"; + let sent_bytes = or_panic!(s.sendmsg(dst, &[&msg[..], &msg2[..]], cmsgs, 0)); + assert_eq!(sent_bytes, 5); + } + + /// Expects to receive "hello" on the data channel, and uses the given buf for cmsgs + fn recvmsg_helper(s: &UnixDatagram, cmsg_buf: &mut [u8]) -> RecvMsgResult { + let mut buf = [0; 3]; + let mut buf2 = [0; 3]; + let result = or_panic!(s.recvmsg(&[&mut buf[..], &mut buf2[..]], cmsg_buf, 0)); + assert_eq!(result.data_bytes, 5); + assert_eq!(&buf[..], b"hel"); + assert_eq!(&buf2[..2], b"lo"); + result + } + + #[test] + fn test_sendmsg_to() { + let dir = or_panic!(TempDir::new("unix_socket")); + let path1 = dir.path().join("sock1"); + + let sock1 = or_panic!(UnixDatagram::bind(&path1)); + let sock2 = or_panic!(UnixDatagram::unbound()); + + // Make sure the path-specified form of sendmsg works + sendmsg_helper(&sock2, Some(&path1), &[]); + let mut buf = [0; 6]; + let size = or_panic!(sock1.recv(&mut buf)); + assert_eq!(size, 5); + assert_eq!(&buf[..5], b"hello"); + } + + #[test] + fn test_recvmsg_sender() { + let dir = or_panic!(TempDir::new("unix_socket")); + let path1 = dir.path().join("sock1"); + let path2 = dir.path().join("sock2"); + + let sock1 = or_panic!(UnixDatagram::bind(&path1)); + let sock2 = or_panic!(UnixDatagram::bind(&path2)); + + assert_eq!(or_panic!(sock1.send_to(b"hello", &path2)), 5); + let result = recvmsg_helper(&sock2, &mut []); + match result.sender.address() { + AddressKind::Pathname(p) => assert_eq!(p, path1.as_path()), + _ => unreachable!(), + } + } + + #[cfg(feature = "from_raw_fd")] + #[test] + fn test_ctrunc() { + let (s1, s2) = or_panic!(UnixDatagram::pair()); + let thread = thread::spawn(move || { + let mut cmsg_buf = [0; 1]; + let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); + assert_eq!(result.control_msgs.len(), 0); + // Make sure the control messages were reported truncated + assert_eq!(result.flags & MSG_CTRUNC, MSG_CTRUNC); + }); + + let (_, theirs) = or_panic!(UnixDatagram::pair()); + let cmsg = ControlMsg::Rights(vec![theirs.as_raw_fd()]); + sendmsg_helper::<&Path>(&s2, None, &[cmsg]); + drop(s2); + + thread.join().unwrap(); + } + + + #[test] + fn test_send_credentials_without_passcred() { + // Without passcred, the ucred should be dropped + let (s1, s2) = or_panic!(UnixDatagram::pair()); + let thread = thread::spawn(move || { + let mut cmsg_buf = [0; 4096]; + let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); + drop(cmsg_buf); + assert_eq!(result.control_msgs.len(), 0); + // Make sure the control messages weren't truncated + assert_eq!(result.flags & MSG_CTRUNC, 0); + }); + + let cmsg = unsafe { ControlMsg::Credentials(UCred{ + pid: libc::getpid(), + uid: libc::getuid(), + gid: libc::getgid(), + }) }; + sendmsg_helper::<&Path>(&s2, None, &[cmsg]); + drop(s2); + + thread.join().unwrap(); + } + + #[test] + fn test_send_credentials_with_passcred() { + // With passcred, the ucred should be sent. + // Note: SO_PASSCRED will cause a credential to always be sent. Unfortunately, + // without additional capabilities, we cannot properly test the SCM_CREDENTIALS + // message. We pass one through to sendmsg below just to exercise that codepath, + // but it will not demonstrate that we are correctly sending it (other than not + // triggering an EINVAL). + + let (s1, s2) = or_panic!(UnixDatagram::pair()); + let thread = thread::spawn(move || { + or_panic!(s1.set_passcred(true)); + let mut cmsg_buf = [0; 4096]; + let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); + drop(cmsg_buf); + assert_eq!(result.control_msgs.len(), 1); + // Make sure the control messages weren't truncated + assert_eq!(result.flags & MSG_CTRUNC, 0); + + for cmsg in result.control_msgs { + match cmsg { + ControlMsg::Credentials(ucred) => { + unsafe { + assert_eq!(ucred.pid, libc::getpid()); + assert_eq!(ucred.uid, libc::getuid()); + assert_eq!(ucred.gid, libc::getgid()); + } + }, + _ => panic!("Unexpected control message"), + } + } + }); + + let cmsg = unsafe { ControlMsg::Credentials(UCred{ + pid: libc::getpid(), + uid: libc::getuid(), + gid: libc::getgid(), + }) }; + sendmsg_helper::<&Path>(&s2, None, &[cmsg]); + drop(s2); + + thread.join().unwrap(); + } + + #[cfg(feature = "from_raw_fd")] + #[test] + fn test_send_fds() { + let (s1, s2) = or_panic!(UnixDatagram::pair()); + let thread = thread::spawn(move || { + let mut cmsg_buf = [0; 4096]; + let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); + drop(cmsg_buf); + assert_eq!(result.control_msgs.len(), 1); + // Make sure the control messages weren't truncated + assert_eq!(result.flags & MSG_CTRUNC, 0); + + for cmsg in result.control_msgs { + match cmsg { + ControlMsg::Rights(fds) => { + assert_eq!(fds.len(), 1); + let new_s = unsafe { UnixDatagram::from_raw_fd(fds[0]) }; + let mut buf = [0; 4]; + assert_eq!(or_panic!(new_s.recv(&mut buf[..])), 4); + assert_eq!(&buf[..], b"Test"); + }, + _ => panic!("Unexpected control message"), + } + } + }); + + let (my, theirs) = or_panic!(UnixDatagram::pair()); + let cmsg = ControlMsg::Rights(vec![theirs.as_raw_fd()]); + sendmsg_helper::<&Path>(&s2, None, &[cmsg]); + drop(s2); + drop(theirs); + + assert_eq!(or_panic!(my.send(b"Test")), 4); + + thread.join().unwrap(); + } } diff --git a/src/sendmsg_impl.rs b/src/sendmsg_impl.rs index 0e31334..db3e183 100644 --- a/src/sendmsg_impl.rs +++ b/src/sendmsg_impl.rs @@ -26,6 +26,12 @@ mod raw { pub static so_passcred: libc::c_int; + pub static msg_eor: libc::c_int; + pub static msg_trunc: libc::c_int; + pub static msg_ctrunc: libc::c_int; + pub static msg_oob: libc::c_int; + pub static msg_errqueue: libc::c_int; + pub fn cmsg_firsthdr(msgh: *const libc::c_void) -> *const libc::c_void; pub fn cmsg_nxthdr(msgh: *const libc::c_void, cmsg: *const libc::c_void) -> *const libc::c_void; pub fn cmsg_align(len: libc::size_t) -> libc::size_t; @@ -40,6 +46,12 @@ pub use self::raw::so_passcred as SO_PASSCRED; pub use self::raw::scm_credentials as SCM_CREDENTIALS; pub use self::raw::scm_rights as SCM_RIGHTS; +pub use self::raw::msg_eor as MSG_EOR; +pub use self::raw::msg_trunc as MSG_TRUNC; +pub use self::raw::msg_ctrunc as MSG_CTRUNC; +pub use self::raw::msg_oob as MSG_OOB; +pub use self::raw::msg_errqueue as MSG_ERRQUEUE; + pub unsafe fn sendmsg( socket: libc::c_int, dst: Option<(libc::sockaddr_un, libc::socklen_t)>, @@ -61,7 +73,7 @@ pub unsafe fn sendmsg( iovecs.push(IoVec::new(buf)); } msg.msg_iov = iovecs.as_mut_ptr() as *mut libc::c_void; - msg.msg_iovlen = (mem::size_of::() * iovecs.len()) as libc::size_t; + msg.msg_iovlen = iovecs.len() as libc::size_t; // Initialize control message struct From c5d5715caf420b8728247cc52a9cf1c2dcb10af6 Mon Sep 17 00:00:00 2001 From: Todd Eisenberger Date: Mon, 21 Sep 2015 16:24:33 -0700 Subject: [PATCH 3/5] Redo flags and address requested changes --- Cargo.toml | 4 +- examples/socket_send.rs | 14 ++++-- src/cmsg_manip/cmsg.c | 7 ++- src/lib.rs | 69 ++++++++++++++++----------- src/sendmsg_impl.rs | 100 ++++++++++++++++++++++++++++++++++------ 5 files changed, 147 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 46da985..9ab6acd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,6 @@ repository = "https://github.com/sfackler/rust-unix-socket" documentation = "https://sfackler.github.io/rust-unix-socket/doc/v0.4.5/unix_socket" readme = "README.md" keywords = ["posix", "unix", "socket", "domain"] -links = "cmsg_manip" build = "build.rs" [dependencies] @@ -22,7 +21,8 @@ gcc = "0.3" tempdir = "0.3" [features] -default = ["from_raw_fd"] +default = ["from_raw_fd", "sendmsg"] from_raw_fd = [] socket_timeout = [] +sendmsg = [] diff --git a/examples/socket_send.rs b/examples/socket_send.rs index ba8b739..57eb990 100644 --- a/examples/socket_send.rs +++ b/examples/socket_send.rs @@ -4,8 +4,10 @@ extern crate unix_socket; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::path::Path; -use unix_socket::{ControlMsg, UCred, UnixDatagram}; +#[cfg(feature = "sendmsg")] +use unix_socket::{ControlMsg, UCred, UnixDatagram, RecvMsgFlags, SendMsgFlags}; +#[cfg(feature = "sendmsg")] fn handle_parent(sock: UnixDatagram) { let (parent2, child2) = UnixDatagram::pair().unwrap(); @@ -16,7 +18,7 @@ fn handle_parent(sock: UnixDatagram) { gid: libc::getgid(), }) }; println!("cmsg {:?}", cmsg2); - let sent_bytes = sock.sendmsg::<&Path>(None, &[&[]], &[cmsg, cmsg2], 0).unwrap(); + let sent_bytes = sock.sendmsg::<&Path>(None, &[&[]], &[cmsg, cmsg2], SendMsgFlags::default()).unwrap(); assert_eq!(sent_bytes, 0); drop(child2); println!("Parent sent child SCM_RIGHTS fd"); @@ -27,10 +29,11 @@ fn handle_parent(sock: UnixDatagram) { println!("Parent received message from child via SCM_RIGHTS fd"); } +#[cfg(feature = "sendmsg")] fn handle_child(sock: UnixDatagram) { sock.set_passcred(true).unwrap(); let mut cmsg_buf = &mut [0u8; 4096]; - let result = sock.recvmsg(&[&mut[]], cmsg_buf, 0).unwrap(); + let result = sock.recvmsg(&[&mut[]], cmsg_buf, RecvMsgFlags::default()).unwrap(); assert_eq!(result.control_msgs.len(), 2); let mut new_sock = None; @@ -65,6 +68,7 @@ fn handle_child(sock: UnixDatagram) { assert_eq!(sent, 13); } +#[cfg(feature = "sendmsg")] fn main() { let (parent_sock, child_sock) = UnixDatagram::pair().unwrap(); let pid = unsafe { libc::fork() }; @@ -74,3 +78,7 @@ fn main() { handle_parent(parent_sock); } } + +#[cfg(not(feature = "sendmsg"))] +fn main() { +} diff --git a/src/cmsg_manip/cmsg.c b/src/cmsg_manip/cmsg.c index ae1b0f4..3a0221c 100644 --- a/src/cmsg_manip/cmsg.c +++ b/src/cmsg_manip/cmsg.c @@ -1,3 +1,4 @@ +// Need to use GNU_SOURCE for ucred struct #define _GNU_SOURCE #include @@ -13,8 +14,12 @@ int so_passcred = SO_PASSCRED; int msg_eor = MSG_EOR; int msg_trunc = MSG_TRUNC; int msg_ctrunc = MSG_CTRUNC; -int msg_oob = MSG_OOB; int msg_errqueue = MSG_ERRQUEUE; +int msg_dontwait = MSG_DONTWAIT; +int msg_cmsg_cloexec = MSG_CMSG_CLOEXEC; +int msg_nosignal = MSG_NOSIGNAL; +int msg_peek = MSG_PEEK; +int msg_waitall = MSG_WAITALL; struct cmsghdr * cmsg_firsthdr(struct msghdr *msgh) { return CMSG_FIRSTHDR(msgh); diff --git a/src/lib.rs b/src/lib.rs index f496ae1..db22056 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,9 +22,10 @@ use std::fmt; use std::path::Path; use std::mem::size_of; +#[cfg(feature = "sendmsg")] mod sendmsg_impl; -pub use sendmsg_impl::{ControlMsg, UCred}; -pub use sendmsg_impl::{MSG_EOR, MSG_TRUNC, MSG_CTRUNC, MSG_OOB, MSG_ERRQUEUE}; +#[cfg(feature = "sendmsg")] +pub use sendmsg_impl::{ControlMsg, UCred, SendMsgFlags, RecvMsgFlags, RecvMsgResultFlags}; extern "C" { fn socketpair(domain: libc::c_int, @@ -173,9 +174,10 @@ impl Inner { } } + #[cfg(feature = "sendmsg")] fn set_passcred(&self, receive_creds: bool) -> io::Result<()> { unsafe { - let v: libc::c_int = if receive_creds { 1 } else { 0 }; + let v: libc::c_int = receive_creds as libc::c_int; cvt(libc::setsockopt(self.0, libc::SOL_SOCKET, sendmsg_impl::SO_PASSCRED, @@ -658,6 +660,7 @@ impl<'a> Iterator for Incoming<'a> { } /// The return value from a call to recvmsg +#[cfg(feature = "sendmsg")] pub struct RecvMsgResult { /// Number of bytes received pub data_bytes: usize, @@ -665,8 +668,8 @@ pub struct RecvMsgResult { pub sender: SocketAddr, /// List of all control messages received during this call pub control_msgs: Vec, - /// Flags returned by recvmsg, see the recv(2) man page for a list - pub flags: libc::c_int, + /// Flags returned by recvmsg, see the struct definition for more details + pub flags: RecvMsgResultFlags, } /// A Unix datagram socket. @@ -804,15 +807,14 @@ impl UnixDatagram { /// If path is None, the peer address set by the `connect` method will be used. If it has not /// been set, then this method will return an error. /// - /// This interface allows sending data from multiple buffers. This acts as if the buffers had been + /// This interface allows receiving data into multiple buffers. This acts as if the buffers had been /// concatenated in the order they were given. /// - /// ctrl_msgs are special ancillary data that can be sent, such as file descriptors and Unix credentials - /// - /// flags is a pass-through of the flags specified in the sendmsg(2) man page + /// cmsg_buffer is space to use for storing control messages. /// /// On success, returns the number of bytes written. - pub fn recvmsg(&self, buffers: &[&mut[u8]], cmsg_buffer: &mut [u8], flags: libc::c_int) -> io::Result { + #[cfg(feature = "sendmsg")] + pub fn recvmsg(&self, buffers: &[&mut[u8]], cmsg_buffer: &mut [u8], flags: RecvMsgFlags) -> io::Result { let mut result = Err(io::Error::new(io::ErrorKind::Other, "programming error")); let addr = try!(SocketAddr::new(|addr, len| { unsafe { @@ -883,10 +885,9 @@ impl UnixDatagram { /// /// ctrl_msgs are special ancillary data that can be sent, such as file descriptors and Unix credentials /// - /// flags is a pass-through of the flags specified in the sendmsg(2) man page - /// /// On success, returns the number of bytes written. - pub fn sendmsg>(&self, path: Option

, buffers: &[&[u8]], ctrl_msgs: &[ControlMsg], flags: libc::c_int) -> io::Result { + #[cfg(feature = "sendmsg")] + pub fn sendmsg>(&self, path: Option

, buffers: &[&[u8]], ctrl_msgs: &[ControlMsg], flags: SendMsgFlags) -> io::Result { unsafe { let dst = match path { None => None, @@ -955,6 +956,7 @@ impl UnixDatagram { } /// Enable or disable receiving SCM_CREDENTIALS messages + #[cfg(feature = "sendmsg")] pub fn set_passcred(&self, receive_creds: bool) -> io::Result<()> { self.inner.set_passcred(receive_creds) } @@ -988,7 +990,7 @@ mod test { use std::path::Path; use self::tempdir::TempDir; - use {UnixListener, UnixStream, UnixDatagram, AddressKind, ControlMsg, RecvMsgResult, UCred, MSG_CTRUNC}; + use {UnixListener, UnixStream, UnixDatagram, AddressKind}; macro_rules! or_panic { ($e:expr) => { @@ -1326,24 +1328,29 @@ mod test { } /// Sends "hello" on the data channel and the specified cmsgs on the control channel - fn sendmsg_helper>(s: &UnixDatagram, dst: Option

, cmsgs: &[ControlMsg]) { + #[cfg(feature = "sendmsg")] + fn sendmsg_helper>(s: &UnixDatagram, dst: Option

, cmsgs: &[super::ControlMsg]) { + use SendMsgFlags; let msg = b"he"; let msg2 = b"llo"; - let sent_bytes = or_panic!(s.sendmsg(dst, &[&msg[..], &msg2[..]], cmsgs, 0)); + let sent_bytes = or_panic!(s.sendmsg(dst, &[&msg[..], &msg2[..]], cmsgs, SendMsgFlags::default())); assert_eq!(sent_bytes, 5); } /// Expects to receive "hello" on the data channel, and uses the given buf for cmsgs - fn recvmsg_helper(s: &UnixDatagram, cmsg_buf: &mut [u8]) -> RecvMsgResult { + #[cfg(feature = "sendmsg")] + fn recvmsg_helper(s: &UnixDatagram, cmsg_buf: &mut [u8]) -> super::RecvMsgResult { + use RecvMsgFlags; let mut buf = [0; 3]; let mut buf2 = [0; 3]; - let result = or_panic!(s.recvmsg(&[&mut buf[..], &mut buf2[..]], cmsg_buf, 0)); + let result = or_panic!(s.recvmsg(&[&mut buf[..], &mut buf2[..]], cmsg_buf, RecvMsgFlags::default())); assert_eq!(result.data_bytes, 5); assert_eq!(&buf[..], b"hel"); assert_eq!(&buf2[..2], b"lo"); result } + #[cfg(feature = "sendmsg")] #[test] fn test_sendmsg_to() { let dir = or_panic!(TempDir::new("unix_socket")); @@ -1360,6 +1367,7 @@ mod test { assert_eq!(&buf[..5], b"hello"); } + #[cfg(feature = "sendmsg")] #[test] fn test_recvmsg_sender() { let dir = or_panic!(TempDir::new("unix_socket")); @@ -1377,16 +1385,17 @@ mod test { } } - #[cfg(feature = "from_raw_fd")] + #[cfg(all(feature = "from_raw_fd", feature = "sendmsg"))] #[test] fn test_ctrunc() { + use ControlMsg; + let (s1, s2) = or_panic!(UnixDatagram::pair()); let thread = thread::spawn(move || { let mut cmsg_buf = [0; 1]; let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); assert_eq!(result.control_msgs.len(), 0); - // Make sure the control messages were reported truncated - assert_eq!(result.flags & MSG_CTRUNC, MSG_CTRUNC); + assert!(result.flags.control_truncated); }); let (_, theirs) = or_panic!(UnixDatagram::pair()); @@ -1398,8 +1407,11 @@ mod test { } + #[cfg(feature = "sendmsg")] #[test] fn test_send_credentials_without_passcred() { + use {ControlMsg, UCred}; + // Without passcred, the ucred should be dropped let (s1, s2) = or_panic!(UnixDatagram::pair()); let thread = thread::spawn(move || { @@ -1407,8 +1419,7 @@ mod test { let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); drop(cmsg_buf); assert_eq!(result.control_msgs.len(), 0); - // Make sure the control messages weren't truncated - assert_eq!(result.flags & MSG_CTRUNC, 0); + assert!(!result.flags.control_truncated); }); let cmsg = unsafe { ControlMsg::Credentials(UCred{ @@ -1422,6 +1433,7 @@ mod test { thread.join().unwrap(); } + #[cfg(feature = "sendmsg")] #[test] fn test_send_credentials_with_passcred() { // With passcred, the ucred should be sent. @@ -1431,6 +1443,8 @@ mod test { // but it will not demonstrate that we are correctly sending it (other than not // triggering an EINVAL). + use {ControlMsg, UCred}; + let (s1, s2) = or_panic!(UnixDatagram::pair()); let thread = thread::spawn(move || { or_panic!(s1.set_passcred(true)); @@ -1438,8 +1452,7 @@ mod test { let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); drop(cmsg_buf); assert_eq!(result.control_msgs.len(), 1); - // Make sure the control messages weren't truncated - assert_eq!(result.flags & MSG_CTRUNC, 0); + assert!(!result.flags.control_truncated); for cmsg in result.control_msgs { match cmsg { @@ -1466,17 +1479,17 @@ mod test { thread.join().unwrap(); } - #[cfg(feature = "from_raw_fd")] + #[cfg(all(feature = "from_raw_fd", feature = "sendmsg"))] #[test] fn test_send_fds() { + use ControlMsg; let (s1, s2) = or_panic!(UnixDatagram::pair()); let thread = thread::spawn(move || { let mut cmsg_buf = [0; 4096]; let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); drop(cmsg_buf); assert_eq!(result.control_msgs.len(), 1); - // Make sure the control messages weren't truncated - assert_eq!(result.flags & MSG_CTRUNC, 0); + assert!(!result.flags.control_truncated); for cmsg in result.control_msgs { match cmsg { diff --git a/src/sendmsg_impl.rs b/src/sendmsg_impl.rs index db3e183..ba15208 100644 --- a/src/sendmsg_impl.rs +++ b/src/sendmsg_impl.rs @@ -14,7 +14,6 @@ mod raw { } #[allow(dead_code)] - #[link(name = "cmsg_manip")] extern { pub static cmsghdr_size: libc::size_t; pub static iovec_size: libc::size_t; @@ -29,8 +28,12 @@ mod raw { pub static msg_eor: libc::c_int; pub static msg_trunc: libc::c_int; pub static msg_ctrunc: libc::c_int; - pub static msg_oob: libc::c_int; pub static msg_errqueue: libc::c_int; + pub static msg_dontwait: libc::c_int; + pub static msg_cmsg_cloexec: libc::c_int; + pub static msg_nosignal: libc::c_int; + pub static msg_peek: libc::c_int; + pub static msg_waitall: libc::c_int; pub fn cmsg_firsthdr(msgh: *const libc::c_void) -> *const libc::c_void; pub fn cmsg_nxthdr(msgh: *const libc::c_void, cmsg: *const libc::c_void) -> *const libc::c_void; @@ -46,18 +49,22 @@ pub use self::raw::so_passcred as SO_PASSCRED; pub use self::raw::scm_credentials as SCM_CREDENTIALS; pub use self::raw::scm_rights as SCM_RIGHTS; -pub use self::raw::msg_eor as MSG_EOR; -pub use self::raw::msg_trunc as MSG_TRUNC; -pub use self::raw::msg_ctrunc as MSG_CTRUNC; -pub use self::raw::msg_oob as MSG_OOB; -pub use self::raw::msg_errqueue as MSG_ERRQUEUE; +use self::raw::msg_eor as MSG_EOR; +use self::raw::msg_trunc as MSG_TRUNC; +use self::raw::msg_ctrunc as MSG_CTRUNC; +use self::raw::msg_errqueue as MSG_ERRQUEUE; +use self::raw::msg_dontwait as MSG_DONTWAIT; +use self::raw::msg_cmsg_cloexec as MSG_CMSG_CLOEXEC; +use self::raw::msg_nosignal as MSG_NOSIGNAL; +use self::raw::msg_peek as MSG_PEEK; +use self::raw::msg_waitall as MSG_WAITALL; pub unsafe fn sendmsg( socket: libc::c_int, dst: Option<(libc::sockaddr_un, libc::socklen_t)>, buffers: &[&[u8]], ctrl_msgs: &[ControlMsg], - flags: libc::c_int) -> io::Result { + flags: SendMsgFlags) -> io::Result { let mut msg: MsgHdr = mem::zeroed(); @@ -121,7 +128,7 @@ pub unsafe fn sendmsg( cur_cmsg = raw::cmsg_nxthdr(msg_addr, cur_cmsg); } - let res = raw::sendmsg(socket, msg_addr, flags); + let res = raw::sendmsg(socket, msg_addr, flags.as_cint()); if res < 0 { Err(io::Error::last_os_error()) } else { @@ -132,14 +139,14 @@ pub unsafe fn sendmsg( pub struct InternalRecvMsgResult { pub data_bytes: usize, pub control_msgs: Vec, - pub flags: libc::c_int, + pub flags: RecvMsgResultFlags, } pub unsafe fn recvmsg( socket: libc::c_int, buffers: &[&mut [u8]], cmsg_buffer: &mut [u8], - flags: libc::c_int, + flags: RecvMsgFlags, sender_addr: *mut libc::sockaddr, sender_len: *mut libc::socklen_t) -> io::Result { @@ -161,7 +168,7 @@ pub unsafe fn recvmsg( msg.msg_controllen = cmsg_buffer.len() as libc::size_t; let msg_addr = (&mut msg as *mut MsgHdr) as *mut libc::c_void; - let recvmsg_res = raw::recvmsg(socket, msg_addr, flags); + let recvmsg_res = raw::recvmsg(socket, msg_addr, flags.as_cint()); if recvmsg_res < 0 { return Err(io::Error::last_os_error()); } @@ -202,10 +209,77 @@ pub unsafe fn recvmsg( Ok(InternalRecvMsgResult { data_bytes: recvmsg_res as usize, control_msgs: cmsgs, - flags: msg.msg_flags, + flags: RecvMsgResultFlags::from_cint(msg.msg_flags), }) } +#[derive(Clone, Copy, Debug, Default)] +/// Flags given to sendmsg. See sendmsg(2) for more details. +pub struct SendMsgFlags { + /// Do not block (MSG_DONTWAIT) + pub dont_wait: bool, + /// Mark this packet as the end of a record (used for SOCK_SEQPACKET connections) (MSG_EOR) + pub end_of_record: bool, + /// Do not receive SIGPIPE if the other end breaks the connection (MSG_NOSIGNAL) + pub no_signal: bool, +} + +#[derive(Clone, Copy, Debug, Default)] +/// Flags given to recvmsg. See recvmsg(2) for more details. +pub struct RecvMsgFlags { + /// Sets the close-on-exec flag for any file descriptors received via SCM_RIGHTS (MSG_CMSG_CLOEXEC) + pub cmsg_cloexec: bool, + /// Do not block (MSG_DONTWAIT) + pub dont_wait: bool, + /// Do not remove the retrieved data from the receive queue (the next call will return the same data) (MSG_PEEK) + pub peek: bool, + /// Wait for the buffers to be filled (may still be interrupted by a signal or the socket hanging up) (MSG_WAITALL) + pub wait_all: bool, + // TODO: Add support for MSG_ERRQUEUE (need to support more cmsgs) +} + +#[derive(Clone, Copy, Debug, Default)] +/// Flags returned by recvmsg. See recvmsg(2) for more details. +pub struct RecvMsgResultFlags { + /// The returned data marks the end of a record (used for SOCK_SEQPACKET) (MSG_EOR) + pub end_of_record: bool, + /// Some data was discarded due to the provided buffers being too short (MSG_TRUNC) + pub truncated: bool, + /// Some control data was discarded (MSG_CTRUNC) + pub control_truncated: bool, +} + +impl SendMsgFlags { + fn as_cint(&self) -> libc::c_int { + let mut result = 0; + if self.dont_wait { result |= MSG_DONTWAIT; } + if self.end_of_record { result |= MSG_EOR; } + if self.no_signal { result |= MSG_NOSIGNAL; } + result + } +} + +impl RecvMsgFlags { + fn as_cint(&self) -> libc::c_int { + let mut result = 0; + if self.cmsg_cloexec { result |= MSG_CMSG_CLOEXEC; } + if self.dont_wait { result |= MSG_DONTWAIT; } + if self.peek { result |= MSG_PEEK; } + if self.wait_all { result |= MSG_WAITALL; } + result + } +} + +impl RecvMsgResultFlags { + fn from_cint(flags: libc::c_int) -> RecvMsgResultFlags { + RecvMsgResultFlags { + end_of_record: (flags & MSG_EOR) != 0, + truncated: (flags & MSG_TRUNC) != 0, + control_truncated: (flags & MSG_CTRUNC) != 0, + } + } +} + #[repr(C)] struct MsgHdr { pub msg_name: *const libc::c_void, From eaebccd526ad425cc023ecd682121aeb8978007a Mon Sep 17 00:00:00 2001 From: Todd Eisenberger Date: Fri, 25 Sep 2015 00:46:14 -0700 Subject: [PATCH 4/5] Address comments --- examples/socket_send.rs | 5 +- src/lib.rs | 12 ++--- src/sendmsg_impl.rs | 112 +++++++++++++++++++++++++++++++--------- 3 files changed, 98 insertions(+), 31 deletions(-) diff --git a/examples/socket_send.rs b/examples/socket_send.rs index 57eb990..9cbc433 100644 --- a/examples/socket_send.rs +++ b/examples/socket_send.rs @@ -18,7 +18,7 @@ fn handle_parent(sock: UnixDatagram) { gid: libc::getgid(), }) }; println!("cmsg {:?}", cmsg2); - let sent_bytes = sock.sendmsg::<&Path>(None, &[&[]], &[cmsg, cmsg2], SendMsgFlags::default()).unwrap(); + let sent_bytes = sock.sendmsg::<&Path>(None, &[&[]], &[cmsg, cmsg2], SendMsgFlags::new()).unwrap(); assert_eq!(sent_bytes, 0); drop(child2); println!("Parent sent child SCM_RIGHTS fd"); @@ -33,7 +33,8 @@ fn handle_parent(sock: UnixDatagram) { fn handle_child(sock: UnixDatagram) { sock.set_passcred(true).unwrap(); let mut cmsg_buf = &mut [0u8; 4096]; - let result = sock.recvmsg(&[&mut[]], cmsg_buf, RecvMsgFlags::default()).unwrap(); + let flags = RecvMsgFlags::new().cmsg_cloexec(true); + let result = sock.recvmsg(&[&mut[]], cmsg_buf, flags).unwrap(); assert_eq!(result.control_msgs.len(), 2); let mut new_sock = None; diff --git a/src/lib.rs b/src/lib.rs index db22056..927697b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1333,7 +1333,7 @@ mod test { use SendMsgFlags; let msg = b"he"; let msg2 = b"llo"; - let sent_bytes = or_panic!(s.sendmsg(dst, &[&msg[..], &msg2[..]], cmsgs, SendMsgFlags::default())); + let sent_bytes = or_panic!(s.sendmsg(dst, &[&msg[..], &msg2[..]], cmsgs, SendMsgFlags::new())); assert_eq!(sent_bytes, 5); } @@ -1343,7 +1343,7 @@ mod test { use RecvMsgFlags; let mut buf = [0; 3]; let mut buf2 = [0; 3]; - let result = or_panic!(s.recvmsg(&[&mut buf[..], &mut buf2[..]], cmsg_buf, RecvMsgFlags::default())); + let result = or_panic!(s.recvmsg(&[&mut buf[..], &mut buf2[..]], cmsg_buf, RecvMsgFlags::new())); assert_eq!(result.data_bytes, 5); assert_eq!(&buf[..], b"hel"); assert_eq!(&buf2[..2], b"lo"); @@ -1395,7 +1395,7 @@ mod test { let mut cmsg_buf = [0; 1]; let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); assert_eq!(result.control_msgs.len(), 0); - assert!(result.flags.control_truncated); + assert!(result.flags.control_truncated()); }); let (_, theirs) = or_panic!(UnixDatagram::pair()); @@ -1419,7 +1419,7 @@ mod test { let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); drop(cmsg_buf); assert_eq!(result.control_msgs.len(), 0); - assert!(!result.flags.control_truncated); + assert!(!result.flags.control_truncated()); }); let cmsg = unsafe { ControlMsg::Credentials(UCred{ @@ -1452,7 +1452,7 @@ mod test { let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); drop(cmsg_buf); assert_eq!(result.control_msgs.len(), 1); - assert!(!result.flags.control_truncated); + assert!(!result.flags.control_truncated()); for cmsg in result.control_msgs { match cmsg { @@ -1489,7 +1489,7 @@ mod test { let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); drop(cmsg_buf); assert_eq!(result.control_msgs.len(), 1); - assert!(!result.flags.control_truncated); + assert!(!result.flags.control_truncated()); for cmsg in result.control_msgs { match cmsg { diff --git a/src/sendmsg_impl.rs b/src/sendmsg_impl.rs index ba15208..df2840f 100644 --- a/src/sendmsg_impl.rs +++ b/src/sendmsg_impl.rs @@ -213,43 +213,60 @@ pub unsafe fn recvmsg( }) } -#[derive(Clone, Copy, Debug, Default)] +#[derive(Clone, Copy, Debug)] /// Flags given to sendmsg. See sendmsg(2) for more details. pub struct SendMsgFlags { - /// Do not block (MSG_DONTWAIT) - pub dont_wait: bool, - /// Mark this packet as the end of a record (used for SOCK_SEQPACKET connections) (MSG_EOR) - pub end_of_record: bool, - /// Do not receive SIGPIPE if the other end breaks the connection (MSG_NOSIGNAL) - pub no_signal: bool, + dont_wait: bool, + end_of_record: bool, + no_signal: bool, } -#[derive(Clone, Copy, Debug, Default)] +#[derive(Clone, Copy, Debug)] /// Flags given to recvmsg. See recvmsg(2) for more details. pub struct RecvMsgFlags { - /// Sets the close-on-exec flag for any file descriptors received via SCM_RIGHTS (MSG_CMSG_CLOEXEC) - pub cmsg_cloexec: bool, - /// Do not block (MSG_DONTWAIT) - pub dont_wait: bool, - /// Do not remove the retrieved data from the receive queue (the next call will return the same data) (MSG_PEEK) - pub peek: bool, - /// Wait for the buffers to be filled (may still be interrupted by a signal or the socket hanging up) (MSG_WAITALL) - pub wait_all: bool, + cmsg_cloexec: bool, + dont_wait: bool, + peek: bool, + wait_all: bool, // TODO: Add support for MSG_ERRQUEUE (need to support more cmsgs) } -#[derive(Clone, Copy, Debug, Default)] +#[derive(Clone, Copy, Debug)] /// Flags returned by recvmsg. See recvmsg(2) for more details. pub struct RecvMsgResultFlags { - /// The returned data marks the end of a record (used for SOCK_SEQPACKET) (MSG_EOR) - pub end_of_record: bool, - /// Some data was discarded due to the provided buffers being too short (MSG_TRUNC) - pub truncated: bool, - /// Some control data was discarded (MSG_CTRUNC) - pub control_truncated: bool, + end_of_record: bool, + truncated: bool, + control_truncated: bool, } impl SendMsgFlags { + /// Create a default SendMsgFlags + pub fn new() -> SendMsgFlags { + SendMsgFlags { + dont_wait: false, + end_of_record: false, + no_signal: false, + } + } + + /// Do not block (MSG_DONTWAIT) + pub fn dont_wait(mut self, v: bool) -> SendMsgFlags { + self.dont_wait = v; + self + } + + /// Mark this packet as the end of a record (used for SOCK_SEQPACKET connections) (MSG_EOR) + pub fn end_of_record(mut self, v: bool) -> SendMsgFlags { + self.end_of_record = v; + self + } + + /// Do not receive SIGPIPE if the other end breaks the connection (MSG_NOSIGNAL) + pub fn no_signal(mut self, v: bool) -> SendMsgFlags { + self.no_signal = v; + self + } + fn as_cint(&self) -> libc::c_int { let mut result = 0; if self.dont_wait { result |= MSG_DONTWAIT; } @@ -260,6 +277,40 @@ impl SendMsgFlags { } impl RecvMsgFlags { + /// Create a default RecvMsgFlags + pub fn new() -> RecvMsgFlags { + RecvMsgFlags { + cmsg_cloexec: false, + dont_wait: false, + peek: false, + wait_all: false, + } + } + + /// Sets the close-on-exec flag for any file descriptors received via SCM_RIGHTS (MSG_CMSG_CLOEXEC) + pub fn cmsg_cloexec(mut self, v: bool) -> RecvMsgFlags { + self.cmsg_cloexec = v; + self + } + + /// Do not block (MSG_DONTWAIT) + pub fn dont_wait(mut self, v: bool) -> RecvMsgFlags { + self.dont_wait = v; + self + } + + /// Do not remove the retrieved data from the receive queue (the next call will return the same data) (MSG_PEEK) + pub fn peek(mut self, v: bool) -> RecvMsgFlags { + self.peek = v; + self + } + + /// Wait for the buffers to be filled (may still be interrupted by a signal or the socket hanging up) (MSG_WAITALL) + pub fn wait_all(mut self, v: bool) -> RecvMsgFlags { + self.wait_all = v; + self + } + fn as_cint(&self) -> libc::c_int { let mut result = 0; if self.cmsg_cloexec { result |= MSG_CMSG_CLOEXEC; } @@ -271,6 +322,21 @@ impl RecvMsgFlags { } impl RecvMsgResultFlags { + /// The returned data marks the end of a record (used for SOCK_SEQPACKET) (MSG_EOR) + pub fn end_of_record(&self) -> bool { + self.end_of_record + } + + /// Some data was discarded due to the provided buffers being too short (MSG_TRUNC) + pub fn truncated(&self) -> bool { + self.truncated + } + + /// Some control data was discarded (MSG_CTRUNC) + pub fn control_truncated(&self) -> bool { + self.control_truncated + } + fn from_cint(flags: libc::c_int) -> RecvMsgResultFlags { RecvMsgResultFlags { end_of_record: (flags & MSG_EOR) != 0, From 4d0fd136b447a1cfb36f6f20cfd1e0dd230b290a Mon Sep 17 00:00:00 2001 From: Todd Eisenberger Date: Wed, 14 Oct 2015 23:04:15 -0700 Subject: [PATCH 5/5] Remove cmsg_buffer and fix stale comments --- examples/socket_send.rs | 3 +-- src/lib.rs | 51 ++++++++--------------------------------- src/sendmsg_impl.rs | 1 - 3 files changed, 11 insertions(+), 44 deletions(-) diff --git a/examples/socket_send.rs b/examples/socket_send.rs index 9cbc433..5b3478a 100644 --- a/examples/socket_send.rs +++ b/examples/socket_send.rs @@ -32,9 +32,8 @@ fn handle_parent(sock: UnixDatagram) { #[cfg(feature = "sendmsg")] fn handle_child(sock: UnixDatagram) { sock.set_passcred(true).unwrap(); - let mut cmsg_buf = &mut [0u8; 4096]; let flags = RecvMsgFlags::new().cmsg_cloexec(true); - let result = sock.recvmsg(&[&mut[]], cmsg_buf, flags).unwrap(); + let result = sock.recvmsg(&[&mut[]], flags).unwrap(); assert_eq!(result.control_msgs.len(), 2); let mut new_sock = None; diff --git a/src/lib.rs b/src/lib.rs index 927697b..97b68fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -804,24 +804,21 @@ impl UnixDatagram { /// Receives data on the socket. /// - /// If path is None, the peer address set by the `connect` method will be used. If it has not - /// been set, then this method will return an error. - /// /// This interface allows receiving data into multiple buffers. This acts as if the buffers had been /// concatenated in the order they were given. /// - /// cmsg_buffer is space to use for storing control messages. - /// /// On success, returns the number of bytes written. #[cfg(feature = "sendmsg")] - pub fn recvmsg(&self, buffers: &[&mut[u8]], cmsg_buffer: &mut [u8], flags: RecvMsgFlags) -> io::Result { + pub fn recvmsg(&self, buffers: &[&mut[u8]], flags: RecvMsgFlags) -> io::Result { let mut result = Err(io::Error::new(io::ErrorKind::Other, "programming error")); let addr = try!(SocketAddr::new(|addr, len| { + const CMSG_BUFFER_SIZE: usize = 4096; + let mut cmsg_buffer = [0u8; CMSG_BUFFER_SIZE]; unsafe { result = sendmsg_impl::recvmsg( self.inner.0, buffers, - cmsg_buffer, + &mut cmsg_buffer, flags, addr, len); @@ -1339,11 +1336,11 @@ mod test { /// Expects to receive "hello" on the data channel, and uses the given buf for cmsgs #[cfg(feature = "sendmsg")] - fn recvmsg_helper(s: &UnixDatagram, cmsg_buf: &mut [u8]) -> super::RecvMsgResult { + fn recvmsg_helper(s: &UnixDatagram) -> super::RecvMsgResult { use RecvMsgFlags; let mut buf = [0; 3]; let mut buf2 = [0; 3]; - let result = or_panic!(s.recvmsg(&[&mut buf[..], &mut buf2[..]], cmsg_buf, RecvMsgFlags::new())); + let result = or_panic!(s.recvmsg(&[&mut buf[..], &mut buf2[..]], RecvMsgFlags::new())); assert_eq!(result.data_bytes, 5); assert_eq!(&buf[..], b"hel"); assert_eq!(&buf2[..2], b"lo"); @@ -1378,35 +1375,13 @@ mod test { let sock2 = or_panic!(UnixDatagram::bind(&path2)); assert_eq!(or_panic!(sock1.send_to(b"hello", &path2)), 5); - let result = recvmsg_helper(&sock2, &mut []); + let result = recvmsg_helper(&sock2); match result.sender.address() { AddressKind::Pathname(p) => assert_eq!(p, path1.as_path()), _ => unreachable!(), } } - #[cfg(all(feature = "from_raw_fd", feature = "sendmsg"))] - #[test] - fn test_ctrunc() { - use ControlMsg; - - let (s1, s2) = or_panic!(UnixDatagram::pair()); - let thread = thread::spawn(move || { - let mut cmsg_buf = [0; 1]; - let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); - assert_eq!(result.control_msgs.len(), 0); - assert!(result.flags.control_truncated()); - }); - - let (_, theirs) = or_panic!(UnixDatagram::pair()); - let cmsg = ControlMsg::Rights(vec![theirs.as_raw_fd()]); - sendmsg_helper::<&Path>(&s2, None, &[cmsg]); - drop(s2); - - thread.join().unwrap(); - } - - #[cfg(feature = "sendmsg")] #[test] fn test_send_credentials_without_passcred() { @@ -1415,9 +1390,7 @@ mod test { // Without passcred, the ucred should be dropped let (s1, s2) = or_panic!(UnixDatagram::pair()); let thread = thread::spawn(move || { - let mut cmsg_buf = [0; 4096]; - let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); - drop(cmsg_buf); + let result = recvmsg_helper(&s1); assert_eq!(result.control_msgs.len(), 0); assert!(!result.flags.control_truncated()); }); @@ -1448,9 +1421,7 @@ mod test { let (s1, s2) = or_panic!(UnixDatagram::pair()); let thread = thread::spawn(move || { or_panic!(s1.set_passcred(true)); - let mut cmsg_buf = [0; 4096]; - let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); - drop(cmsg_buf); + let result = recvmsg_helper(&s1); assert_eq!(result.control_msgs.len(), 1); assert!(!result.flags.control_truncated()); @@ -1485,9 +1456,7 @@ mod test { use ControlMsg; let (s1, s2) = or_panic!(UnixDatagram::pair()); let thread = thread::spawn(move || { - let mut cmsg_buf = [0; 4096]; - let result = recvmsg_helper(&s1, &mut cmsg_buf[..]); - drop(cmsg_buf); + let result = recvmsg_helper(&s1); assert_eq!(result.control_msgs.len(), 1); assert!(!result.flags.control_truncated()); diff --git a/src/sendmsg_impl.rs b/src/sendmsg_impl.rs index df2840f..f7bb188 100644 --- a/src/sendmsg_impl.rs +++ b/src/sendmsg_impl.rs @@ -52,7 +52,6 @@ pub use self::raw::scm_rights as SCM_RIGHTS; use self::raw::msg_eor as MSG_EOR; use self::raw::msg_trunc as MSG_TRUNC; use self::raw::msg_ctrunc as MSG_CTRUNC; -use self::raw::msg_errqueue as MSG_ERRQUEUE; use self::raw::msg_dontwait as MSG_DONTWAIT; use self::raw::msg_cmsg_cloexec as MSG_CMSG_CLOEXEC; use self::raw::msg_nosignal as MSG_NOSIGNAL;