diff --git a/rt/src/net/udp.rs b/rt/src/net/udp.rs index 838e3c12b..9708f0f79 100644 --- a/rt/src/net/udp.rs +++ b/rt/src/net/udp.rs @@ -141,6 +141,24 @@ impl UdpSocket { }) } + /// Converts a [`std::net::UdpSocket`] to a [`heph_rt::net::UdpSocket`]. + /// + /// [`heph_rt::net::UdpSocket`]: UdpSocket + /// + /// # Notes + /// + /// It's up to the caller to ensure that the socket's mode is correctly set + /// to [`Connected`] or [`Unconnected`]. + pub fn from_std(rt: &RT, socket: std::net::UdpSocket) -> UdpSocket + where + RT: Access, + { + UdpSocket { + fd: AsyncFd::new(socket.into(), rt.submission_queue()), + mode: PhantomData, + } + } + /// Returns the sockets peer address. pub fn peer_addr(&self) -> io::Result { self.with_ref(|socket| socket.peer_addr().and_then(convert_address)) diff --git a/rt/tests/functional/udp.rs b/rt/tests/functional/udp.rs index 676116773..c9ea3eaef 100644 --- a/rt/tests/functional/udp.rs +++ b/rt/tests/functional/udp.rs @@ -5,9 +5,9 @@ use std::net::SocketAddr; use std::time::Duration; use heph::actor::{self, actor_fn, Actor, NewActor}; -use heph_rt::net::udp::UdpSocket; +use heph_rt::net::udp::{UdpSocket, Unconnected}; use heph_rt::spawn::ActorOptions; -use heph_rt::test::{join, try_spawn_local, PanicSupervisor}; +use heph_rt::test::{block_on_local_actor, join, try_spawn_local, PanicSupervisor}; use heph_rt::ThreadLocal; use crate::util::{any_local_address, any_local_ipv6_address}; @@ -299,3 +299,28 @@ fn assert_read(mut got: &[u8], expected: &[&[u8]]) { got = g; } } + +#[test] +fn socket_from_std() { + async fn actor(ctx: actor::Context) -> io::Result<()> { + let socket = std::net::UdpSocket::bind(any_local_address())?; + let socket = UdpSocket::::from_std(ctx.runtime_ref(), socket); + let local_address = socket.local_addr()?; + + let peer = std::net::UdpSocket::bind(any_local_address())?; + let peer_address = peer.local_addr()?; + + let (_, bytes_written) = socket.send_to(DATA, peer_address).await?; + assert_eq!(bytes_written, DATA.len()); + + let mut buf = vec![0; DATA.len() + 2]; + let (n, address) = peer.recv_from(&mut buf)?; + assert_eq!(n, DATA.len()); + assert_eq!(&buf[..n], DATA); + assert_eq!(address, local_address); + + Ok(()) + } + + block_on_local_actor(actor_fn(actor), ()).unwrap(); +}