From 8d295ff34192d2e1e76f92d4a888fd3baaa6b407 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 16 Nov 2023 09:40:34 +1100 Subject: [PATCH] Use oneshot's instead for `libp2p-relay` --- protocols/relay/src/priv_client/handler.rs | 220 ++++++++----------- protocols/relay/src/protocol/outbound_hop.rs | 4 +- 2 files changed, 99 insertions(+), 125 deletions(-) diff --git a/protocols/relay/src/priv_client/handler.rs b/protocols/relay/src/priv_client/handler.rs index 1925d6f6ab4..662d63cc742 100644 --- a/protocols/relay/src/priv_client/handler.rs +++ b/protocols/relay/src/priv_client/handler.rs @@ -18,9 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use crate::client::Connection; use crate::priv_client::transport; +use crate::priv_client::transport::ToListenerMsg; use crate::protocol::{self, inbound_stop, outbound_hop}; use crate::{priv_client, proto, HOP_PROTOCOL_NAME, STOP_PROTOCOL_NAME}; +use futures::channel::mpsc::Sender; use futures::channel::{mpsc, oneshot}; use futures::future::FutureExt; use futures_timer::Delay; @@ -28,17 +31,16 @@ use libp2p_core::multiaddr::Protocol; use libp2p_core::upgrade::ReadyUpgrade; use libp2p_core::Multiaddr; use libp2p_identity::PeerId; -use libp2p_swarm::handler::{ - ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, -}; +use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound}; use libp2p_swarm::{ - ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, StreamUpgradeError, + ConnectionHandler, ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol, }; use std::collections::VecDeque; use std::task::{Context, Poll}; use std::time::Duration; use std::{fmt, io}; +use void::Void; /// The maximum number of circuits being denied concurrently. /// @@ -104,8 +106,7 @@ pub struct Handler { >, >, - /// We issue a stream upgrade for each pending request. - pending_requests: VecDeque, + pending_streams: VecDeque>>>, inflight_reserve_requests: futures_bounded::FuturesTupleSet< Result, @@ -133,7 +134,7 @@ impl Handler { remote_peer_id, remote_addr, queued_events: Default::default(), - pending_requests: Default::default(), + pending_streams: Default::default(), inflight_reserve_requests: futures_bounded::FuturesTupleSet::new( STREAM_TIMEOUT, MAX_CONCURRENT_STREAMS_PER_CONNECTION, @@ -154,57 +155,6 @@ impl Handler { } } - fn on_dial_upgrade_error( - &mut self, - DialUpgradeError { error, .. }: DialUpgradeError< - ::OutboundOpenInfo, - ::OutboundProtocol, - >, - ) { - let pending_request = self - .pending_requests - .pop_front() - .expect("got a stream error without a pending request"); - - match pending_request { - PendingRequest::Reserve { mut to_listener } => { - let error = match error { - StreamUpgradeError::Timeout => { - outbound_hop::ReserveError::Io(io::ErrorKind::TimedOut.into()) - } - StreamUpgradeError::Apply(never) => void::unreachable(never), - StreamUpgradeError::NegotiationFailed => { - outbound_hop::ReserveError::Unsupported - } - StreamUpgradeError::Io(e) => outbound_hop::ReserveError::Io(e), - }; - - if let Err(e) = - to_listener.try_send(transport::ToListenerMsg::Reservation(Err(error))) - { - tracing::debug!("Unable to send error to listener: {}", e.into_send_error()) - } - self.reservation.failed(); - } - PendingRequest::Connect { - to_dial: send_back, .. - } => { - let error = match error { - StreamUpgradeError::Timeout => { - outbound_hop::ConnectError::Io(io::ErrorKind::TimedOut.into()) - } - StreamUpgradeError::NegotiationFailed => { - outbound_hop::ConnectError::Unsupported - } - StreamUpgradeError::Io(e) => outbound_hop::ConnectError::Io(e), - StreamUpgradeError::Apply(v) => void::unreachable(v), - }; - - let _ = send_back.send(Err(error)); - } - } - } - fn insert_to_deny_futs(&mut self, circuit: inbound_stop::Circuit) { let src_peer_id = circuit.src_peer_id(); @@ -219,6 +169,62 @@ impl Handler { ) } } + + fn make_new_reservation(&mut self, to_listener: Sender) { + let (sender, receiver) = oneshot::channel(); + + self.pending_streams.push_back(sender); + self.queued_events + .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), + }); + let result = self.inflight_reserve_requests.try_push( + async move { + let stream = receiver + .await + .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))? + .map_err(into_reserve_error)?; + + let reservation = outbound_hop::make_reservation(stream).await?; + + Ok(reservation) + }, + to_listener, + ); + + if result.is_err() { + tracing::warn!("Dropping in-flight reservation request because we are at capacity"); + } + } + + fn establish_new_circuit( + &mut self, + to_dial: oneshot::Sender>, + dst_peer_id: PeerId, + ) { + let (sender, receiver) = oneshot::channel(); + + self.pending_streams.push_back(sender); + self.queued_events + .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), + }); + let result = self.inflight_outbound_connect_requests.try_push( + async move { + let stream = receiver + .await + .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))? + .map_err(into_connect_error)?; + + outbound_hop::open_circuit(stream, dst_peer_id).await + }, + to_dial, + ); + + if result.is_err() { + tracing::warn!("Dropping in-flight connect request because we are at capacity") + } + } } impl ConnectionHandler for Handler { @@ -236,25 +242,13 @@ impl ConnectionHandler for Handler { fn on_behaviour_event(&mut self, event: Self::FromBehaviour) { match event { In::Reserve { to_listener } => { - self.pending_requests - .push_back(PendingRequest::Reserve { to_listener }); - self.queued_events - .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), - }); + self.make_new_reservation(to_listener); } In::EstablishCircuit { - to_dial: send_back, + to_dial, dst_peer_id, } => { - self.pending_requests.push_back(PendingRequest::Connect { - dst_peer_id, - to_dial: send_back, - }); - self.queued_events - .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), - }); + self.establish_new_circuit(to_dial, dst_peer_id); } } } @@ -402,12 +396,8 @@ impl ConnectionHandler for Handler { } if let Poll::Ready(Some(to_listener)) = self.reservation.poll(cx) { - self.pending_requests - .push_back(PendingRequest::Reserve { to_listener }); - - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), - }); + self.make_new_reservation(to_listener); + continue; } // Deny incoming circuit requests. @@ -450,42 +440,16 @@ impl ConnectionHandler for Handler { tracing::warn!("Dropping inbound stream because we are at capacity") } } - ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { - protocol: stream, - .. - }) => { - let pending_request = self.pending_requests.pop_front().expect( - "opened a stream without a pending connection command or a reserve listener", - ); - match pending_request { - PendingRequest::Reserve { to_listener } => { - if self - .inflight_reserve_requests - .try_push(outbound_hop::make_reservation(stream), to_listener) - .is_err() - { - tracing::warn!("Dropping outbound stream because we are at capacity"); - } - } - PendingRequest::Connect { - dst_peer_id, - to_dial: send_back, - } => { - if self - .inflight_outbound_connect_requests - .try_push(outbound_hop::open_circuit(stream, dst_peer_id), send_back) - .is_err() - { - tracing::warn!("Dropping outbound stream because we are at capacity"); - } - } + ConnectionEvent::FullyNegotiatedOutbound(ev) => { + if let Some(next) = self.pending_streams.pop_front() { + let _ = next.send(Ok(ev.protocol)); } } - ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { - void::unreachable(listen_upgrade_error.error) - } - ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) + ConnectionEvent::ListenUpgradeError(ev) => void::unreachable(ev.error), + ConnectionEvent::DialUpgradeError(ev) => { + if let Some(next) = self.pending_streams.pop_front() { + let _ = next.send(Err(ev.error)); + } } _ => {} } @@ -614,14 +578,24 @@ impl Reservation { } } -pub(crate) enum PendingRequest { - Reserve { - /// A channel into the [`Transport`](priv_client::Transport). - to_listener: mpsc::Sender, - }, - Connect { - dst_peer_id: PeerId, - /// A channel into the future returned by [`Transport::dial`](libp2p_core::Transport::dial). - to_dial: oneshot::Sender>, - }, +fn into_reserve_error(e: StreamUpgradeError) -> outbound_hop::ReserveError { + match e { + StreamUpgradeError::Timeout => { + outbound_hop::ReserveError::Io(io::ErrorKind::TimedOut.into()) + } + StreamUpgradeError::Apply(never) => void::unreachable(never), + StreamUpgradeError::NegotiationFailed => outbound_hop::ReserveError::Unsupported, + StreamUpgradeError::Io(e) => outbound_hop::ReserveError::Io(e), + } +} + +fn into_connect_error(e: StreamUpgradeError) -> outbound_hop::ConnectError { + match e { + StreamUpgradeError::Timeout => { + outbound_hop::ConnectError::Io(io::ErrorKind::TimedOut.into()) + } + StreamUpgradeError::Apply(never) => void::unreachable(never), + StreamUpgradeError::NegotiationFailed => outbound_hop::ConnectError::Unsupported, + StreamUpgradeError::Io(e) => outbound_hop::ConnectError::Io(e), + } } diff --git a/protocols/relay/src/protocol/outbound_hop.rs b/protocols/relay/src/protocol/outbound_hop.rs index e5f9a6a0a52..3ae824be167 100644 --- a/protocols/relay/src/protocol/outbound_hop.rs +++ b/protocols/relay/src/protocol/outbound_hop.rs @@ -47,7 +47,7 @@ pub enum ConnectError { #[error("Remote does not support the `{HOP_PROTOCOL_NAME}` protocol")] Unsupported, #[error("IO error")] - Io(#[source] io::Error), + Io(#[from] io::Error), #[error("Protocol error")] Protocol(#[from] ProtocolViolation), } @@ -61,7 +61,7 @@ pub enum ReserveError { #[error("Remote does not support the `{HOP_PROTOCOL_NAME}` protocol")] Unsupported, #[error("IO error")] - Io(#[source] io::Error), + Io(#[from] io::Error), #[error("Protocol error")] Protocol(#[from] ProtocolViolation), }