From 0181e86fa7c890cb1ae328ae536380783e8801e9 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 20 Oct 2023 18:07:45 +1100 Subject: [PATCH] fix(mdns): move IO off main task Resolves: #2591. Pull-Request: #4623. --- protocols/mdns/CHANGELOG.md | 2 + protocols/mdns/Cargo.toml | 3 +- protocols/mdns/src/behaviour.rs | 136 +++++++++++++++++--------- protocols/mdns/src/behaviour/iface.rs | 123 ++++++++++++++--------- 4 files changed, 171 insertions(+), 93 deletions(-) diff --git a/protocols/mdns/CHANGELOG.md b/protocols/mdns/CHANGELOG.md index 0c06bb26b3d..060fac8c51c 100644 --- a/protocols/mdns/CHANGELOG.md +++ b/protocols/mdns/CHANGELOG.md @@ -1,5 +1,7 @@ ## 0.45.0 - unreleased +- Don't perform IO in `Behaviour::poll`. + See [PR 4623](https://github.com/libp2p/rust-libp2p/pull/4623). ## 0.44.0 diff --git a/protocols/mdns/Cargo.toml b/protocols/mdns/Cargo.toml index b4a285448da..a8b6a7da1b9 100644 --- a/protocols/mdns/Cargo.toml +++ b/protocols/mdns/Cargo.toml @@ -11,6 +11,7 @@ keywords = ["peer-to-peer", "libp2p", "networking"] categories = ["network-programming", "asynchronous"] [dependencies] +async-std = { version = "1.12.0", optional = true } async-io = { version = "1.13.0", optional = true } data-encoding = "2.4.0" futures = "0.3.28" @@ -28,7 +29,7 @@ void = "1.0.2" [features] tokio = ["dep:tokio", "if-watch/tokio"] -async-io = ["dep:async-io", "if-watch/smol"] +async-io = ["dep:async-io", "dep:async-std", "if-watch/smol"] [dev-dependencies] async-std = { version = "1.9.0", features = ["attributes"] } diff --git a/protocols/mdns/src/behaviour.rs b/protocols/mdns/src/behaviour.rs index bc102f832df..9e937272e8c 100644 --- a/protocols/mdns/src/behaviour.rs +++ b/protocols/mdns/src/behaviour.rs @@ -25,7 +25,8 @@ mod timer; use self::iface::InterfaceState; use crate::behaviour::{socket::AsyncSocket, timer::Builder}; use crate::Config; -use futures::Stream; +use futures::channel::mpsc; +use futures::{Stream, StreamExt}; use if_watch::IfEvent; use 
libp2p_core::{Endpoint, Multiaddr}; use libp2p_identity::PeerId; @@ -36,6 +37,8 @@ use libp2p_swarm::{ }; use smallvec::SmallVec; use std::collections::hash_map::{Entry, HashMap}; +use std::future::Future; +use std::sync::{Arc, RwLock}; use std::{cmp, fmt, io, net::IpAddr, pin::Pin, task::Context, task::Poll, time::Instant}; /// An abstraction to allow for compatibility with various async runtimes. @@ -47,16 +50,27 @@ pub trait Provider: 'static { /// The IfWatcher type. type Watcher: Stream> + fmt::Debug + Unpin; + type TaskHandle: Abort; + /// Create a new instance of the `IfWatcher` type. fn new_watcher() -> Result; + + fn spawn(task: impl Future + Send + 'static) -> Self::TaskHandle; +} + +#[allow(unreachable_pub)] // Not re-exported. +pub trait Abort { + fn abort(self); } /// The type of a [`Behaviour`] using the `async-io` implementation. #[cfg(feature = "async-io")] pub mod async_io { use super::Provider; - use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer}; + use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort}; + use async_std::task::JoinHandle; use if_watch::smol::IfWatcher; + use std::future::Future; #[doc(hidden)] pub enum AsyncIo {} @@ -65,10 +79,21 @@ pub mod async_io { type Socket = AsyncUdpSocket; type Timer = AsyncTimer; type Watcher = IfWatcher; + type TaskHandle = JoinHandle<()>; fn new_watcher() -> Result { IfWatcher::new() } + + fn spawn(task: impl Future + Send + 'static) -> JoinHandle<()> { + async_std::task::spawn(task) + } + } + + impl Abort for JoinHandle<()> { + fn abort(self) { + async_std::task::spawn(self.cancel()); + } } pub type Behaviour = super::Behaviour; @@ -78,8 +103,10 @@ pub mod async_io { #[cfg(feature = "tokio")] pub mod tokio { use super::Provider; - use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer}; + use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort}; use if_watch::tokio::IfWatcher; + use 
std::future::Future; + use tokio::task::JoinHandle; #[doc(hidden)] pub enum Tokio {} @@ -88,10 +115,21 @@ pub mod tokio { type Socket = TokioUdpSocket; type Timer = TokioTimer; type Watcher = IfWatcher; + type TaskHandle = JoinHandle<()>; fn new_watcher() -> Result { IfWatcher::new() } + + fn spawn(task: impl Future + Send + 'static) -> Self::TaskHandle { + tokio::spawn(task) + } + } + + impl Abort for JoinHandle<()> { + fn abort(self) { + JoinHandle::abort(&self) + } } pub type Behaviour = super::Behaviour; @@ -110,8 +148,11 @@ where /// Iface watcher. if_watch: P::Watcher, - /// Mdns interface states. - iface_states: HashMap>, + /// Handles to tasks running the mDNS queries. + if_tasks: HashMap, + + query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>, + query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>, /// List of nodes that we have discovered, the address, and when their TTL expires. /// @@ -124,7 +165,11 @@ where /// `None` if `discovered_nodes` is empty. closest_expiration: Option, - listen_addresses: ListenAddresses, + /// The current set of listen addresses. + /// + /// This is shared across all interface tasks using an [`RwLock`]. + /// The [`Behaviour`] updates this upon new [`FromSwarm`] events whereas [`InterfaceState`]s read from it to answer inbound mDNS queries. + listen_addresses: Arc>, local_peer_id: PeerId, } @@ -135,10 +180,14 @@ where { /// Builds a new `Mdns` behaviour. pub fn new(config: Config, local_peer_id: PeerId) -> io::Result { + let (tx, rx) = mpsc::channel(10); // Chosen arbitrarily. + Ok(Self { config, if_watch: P::new_watcher()?, - iface_states: Default::default(), + if_tasks: Default::default(), + query_response_receiver: rx, + query_response_sender: tx, discovered_nodes: Default::default(), closest_expiration: Default::default(), listen_addresses: Default::default(), @@ -147,6 +196,7 @@ where } /// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
+ #[deprecated(note = "Use `discovered_nodes` iterator instead.")] pub fn has_node(&self, peer_id: &PeerId) -> bool { self.discovered_nodes().any(|p| p == peer_id) } @@ -157,6 +207,7 @@ where } /// Expires a node before the ttl. + #[deprecated(note = "Unused API. Will be removed in the next release.")] pub fn expire_node(&mut self, peer_id: &PeerId) { let now = Instant::now(); for (peer, _addr, expires) in &mut self.discovered_nodes { @@ -225,28 +276,10 @@ where } fn on_swarm_event(&mut self, event: FromSwarm) { - self.listen_addresses.on_swarm_event(&event); - - match event { - FromSwarm::NewListener(_) => { - log::trace!("waking interface state because listening address changed"); - for iface in self.iface_states.values_mut() { - iface.fire_timer(); - } - } - FromSwarm::ConnectionClosed(_) - | FromSwarm::ConnectionEstablished(_) - | FromSwarm::DialFailure(_) - | FromSwarm::AddressChange(_) - | FromSwarm::ListenFailure(_) - | FromSwarm::NewListenAddr(_) - | FromSwarm::ExpiredListenAddr(_) - | FromSwarm::ListenerError(_) - | FromSwarm::ListenerClosed(_) - | FromSwarm::NewExternalAddrCandidate(_) - | FromSwarm::ExternalAddrExpired(_) - | FromSwarm::ExternalAddrConfirmed(_) => {} - } + self.listen_addresses + .write() + .unwrap_or_else(|e| e.into_inner()) + .on_swarm_event(&event); } fn poll( @@ -267,19 +300,26 @@ where { continue; } - if let Entry::Vacant(e) = self.iface_states.entry(addr) { - match InterfaceState::new(addr, self.config.clone(), self.local_peer_id) { + if let Entry::Vacant(e) = self.if_tasks.entry(addr) { + match InterfaceState::::new( + addr, + self.config.clone(), + self.local_peer_id, + self.listen_addresses.clone(), + self.query_response_sender.clone(), + ) { Ok(iface_state) => { - e.insert(iface_state); + e.insert(P::spawn(iface_state)); } Err(err) => log::error!("failed to create `InterfaceState`: {}", err), } } } Ok(IfEvent::Down(inet)) => { - if self.iface_states.contains_key(&inet.addr()) { + if let Some(handle) = 
self.if_tasks.remove(&inet.addr()) { log::info!("dropping instance {}", inet.addr()); - self.iface_states.remove(&inet.addr()); + + handle.abort(); } } Err(err) => log::error!("if watch returned an error: {}", err), @@ -287,23 +327,23 @@ where } // Emit discovered event. let mut discovered = Vec::new(); - for iface_state in self.iface_states.values_mut() { - while let Poll::Ready((peer, addr, expiration)) = - iface_state.poll(cx, &self.listen_addresses) + + while let Poll::Ready(Some((peer, addr, expiration))) = + self.query_response_receiver.poll_next_unpin(cx) + { + if let Some((_, _, cur_expires)) = self + .discovered_nodes + .iter_mut() + .find(|(p, a, _)| *p == peer && *a == addr) { - if let Some((_, _, cur_expires)) = self - .discovered_nodes - .iter_mut() - .find(|(p, a, _)| *p == peer && *a == addr) - { - *cur_expires = cmp::max(*cur_expires, expiration); - } else { - log::info!("discovered: {} {}", peer, addr); - self.discovered_nodes.push((peer, addr.clone(), expiration)); - discovered.push((peer, addr)); - } + *cur_expires = cmp::max(*cur_expires, expiration); + } else { + log::info!("discovered: {} {}", peer, addr); + self.discovered_nodes.push((peer, addr.clone(), expiration)); + discovered.push((peer, addr)); } } + if !discovered.is_empty() { let event = Event::Discovered(discovered); return Poll::Ready(ToSwarm::GenerateEvent(event)); diff --git a/protocols/mdns/src/behaviour/iface.rs b/protocols/mdns/src/behaviour/iface.rs index 54d6c657380..47601088fdc 100644 --- a/protocols/mdns/src/behaviour/iface.rs +++ b/protocols/mdns/src/behaviour/iface.rs @@ -25,10 +25,14 @@ use self::dns::{build_query, build_query_response, build_service_discovery_respo use self::query::MdnsPacket; use crate::behaviour::{socket::AsyncSocket, timer::Builder}; use crate::Config; +use futures::channel::mpsc; +use futures::{SinkExt, StreamExt}; use libp2p_core::Multiaddr; use libp2p_identity::PeerId; use libp2p_swarm::ListenAddresses; use socket2::{Domain, Socket, Type}; +use 
std::future::Future; +use std::sync::{Arc, RwLock}; use std::{ collections::VecDeque, io, @@ -72,6 +76,11 @@ pub(crate) struct InterfaceState { recv_socket: U, /// Send socket. send_socket: U, + + listen_addresses: Arc>, + + query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>, + /// Buffer used for receiving data from the main socket. /// RFC6762 discourages packets larger than the interface MTU, but allows sizes of up to 9000 /// bytes, if it can be ensured that all participating devices can handle such large packets. @@ -101,7 +110,13 @@ where T: Builder + futures::Stream, { /// Builds a new [`InterfaceState`]. - pub(crate) fn new(addr: IpAddr, config: Config, local_peer_id: PeerId) -> io::Result { + pub(crate) fn new( + addr: IpAddr, + config: Config, + local_peer_id: PeerId, + listen_addresses: Arc>, + query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>, + ) -> io::Result { log::info!("creating instance on iface {}", addr); let recv_socket = match addr { IpAddr::V4(addr) => { @@ -154,6 +169,8 @@ where addr, recv_socket, send_socket, + listen_addresses, + query_response_sender, recv_buffer: [0; 4096], send_buffer: Default::default(), discovered: Default::default(), @@ -172,78 +189,96 @@ where self.timeout = T::interval(interval); } - pub(crate) fn fire_timer(&mut self) { - self.timeout = T::interval_at(Instant::now(), INITIAL_TIMEOUT_INTERVAL); + fn mdns_socket(&self) -> SocketAddr { + SocketAddr::new(self.multicast_addr, 5353) } +} + +impl Future for InterfaceState +where + U: AsyncSocket, + T: Builder + futures::Stream, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); - pub(crate) fn poll( - &mut self, - cx: &mut Context, - listen_addresses: &ListenAddresses, - ) -> Poll<(PeerId, Multiaddr, Instant)> { loop { // 1st priority: Low latency: Create packet ASAP after timeout. 
- if Pin::new(&mut self.timeout).poll_next(cx).is_ready() { - log::trace!("sending query on iface {}", self.addr); - self.send_buffer.push_back(build_query()); - log::trace!("tick on {:#?} {:#?}", self.addr, self.probe_state); + if this.timeout.poll_next_unpin(cx).is_ready() { + log::trace!("sending query on iface {}", this.addr); + this.send_buffer.push_back(build_query()); + log::trace!("tick on {:#?} {:#?}", this.addr, this.probe_state); // Stop to probe when the initial interval reach the query interval - if let ProbeState::Probing(interval) = self.probe_state { + if let ProbeState::Probing(interval) = this.probe_state { let interval = interval * 2; - self.probe_state = if interval >= self.query_interval { - ProbeState::Finished(self.query_interval) + this.probe_state = if interval >= this.query_interval { + ProbeState::Finished(this.query_interval) } else { ProbeState::Probing(interval) }; } - self.reset_timer(); + this.reset_timer(); } // 2nd priority: Keep local buffers small: Send packets to remote. - if let Some(packet) = self.send_buffer.pop_front() { - match Pin::new(&mut self.send_socket).poll_write( - cx, - &packet, - SocketAddr::new(self.multicast_addr, 5353), - ) { + if let Some(packet) = this.send_buffer.pop_front() { + match this.send_socket.poll_write(cx, &packet, this.mdns_socket()) { Poll::Ready(Ok(_)) => { - log::trace!("sent packet on iface {}", self.addr); + log::trace!("sent packet on iface {}", this.addr); continue; } Poll::Ready(Err(err)) => { - log::error!("error sending packet on iface {} {}", self.addr, err); + log::error!("error sending packet on iface {} {}", this.addr, err); continue; } Poll::Pending => { - self.send_buffer.push_front(packet); + this.send_buffer.push_front(packet); } } } // 3rd priority: Keep local buffers small: Return discovered addresses. 
- if let Some(discovered) = self.discovered.pop_front() { - return Poll::Ready(discovered); + if this.query_response_sender.poll_ready_unpin(cx).is_ready() { + if let Some(discovered) = this.discovered.pop_front() { + match this.query_response_sender.try_send(discovered) { + Ok(()) => {} + Err(e) if e.is_disconnected() => { + return Poll::Ready(()); + } + Err(e) => { + this.discovered.push_front(e.into_inner()); + } + } + + continue; + } } // 4th priority: Remote work: Answer incoming requests. - match Pin::new(&mut self.recv_socket) - .poll_read(cx, &mut self.recv_buffer) - .map_ok(|(len, from)| MdnsPacket::new_from_bytes(&self.recv_buffer[..len], from)) + match this + .recv_socket + .poll_read(cx, &mut this.recv_buffer) + .map_ok(|(len, from)| MdnsPacket::new_from_bytes(&this.recv_buffer[..len], from)) { Poll::Ready(Ok(Ok(Some(MdnsPacket::Query(query))))) => { log::trace!( "received query from {} on {}", query.remote_addr(), - self.addr + this.addr ); - self.send_buffer.extend(build_query_response( + this.send_buffer.extend(build_query_response( query.query_id(), - self.local_peer_id, - listen_addresses.iter(), - self.ttl, + this.local_peer_id, + this.listen_addresses + .read() + .unwrap_or_else(|e| e.into_inner()) + .iter(), + this.ttl, )); continue; } @@ -251,16 +286,16 @@ where log::trace!( "received response from {} on {}", response.remote_addr(), - self.addr + this.addr ); - self.discovered - .extend(response.extract_discovered(Instant::now(), self.local_peer_id)); + this.discovered + .extend(response.extract_discovered(Instant::now(), this.local_peer_id)); // Stop probing when we have a valid response - if !self.discovered.is_empty() { - self.probe_state = ProbeState::Finished(self.query_interval); - self.reset_timer(); + if !this.discovered.is_empty() { + this.probe_state = ProbeState::Finished(this.query_interval); + this.reset_timer(); } continue; } @@ -268,11 +303,11 @@ where log::trace!( "received service discovery from {} on {}", disc.remote_addr(), - 
self.addr + this.addr ); - self.send_buffer - .push_back(build_service_discovery_response(disc.query_id(), self.ttl)); + this.send_buffer + .push_back(build_service_discovery_response(disc.query_id(), this.ttl)); continue; } Poll::Ready(Err(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {