diff --git a/Cargo.lock b/Cargo.lock index 04fb22b2fc2..a7ba9e9d7f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2676,6 +2676,7 @@ dependencies = [ "libp2p-identify", "libp2p-identity", "libp2p-noise", + "libp2p-protocol-utils", "libp2p-swarm", "libp2p-swarm-test", "libp2p-yamux", diff --git a/misc/protocol-utils/src/ipd_queue.rs b/misc/protocol-utils/src/ipd_queue.rs index ade441e2d31..97f6330efa4 100644 --- a/misc/protocol-utils/src/ipd_queue.rs +++ b/misc/protocol-utils/src/ipd_queue.rs @@ -28,11 +28,6 @@ impl InflightProtocolDataQueue { pub fn enqueue_request(&mut self, request: Req, data: D) { self.pending_requests.push_back(request); self.data_of_inflight_requests.push_back(data); - - debug_assert_eq!( - self.pending_requests.len(), - self.data_of_inflight_requests.len() - ); } /// Submits a response to the queue. @@ -46,6 +41,11 @@ impl InflightProtocolDataQueue { self.received_responses.push_back(res); } + /// How many protocols are currently in-flight. + pub fn num_inflight(&self) -> usize { + self.data_of_inflight_requests.len() - self.received_responses.len() + } + pub fn next_completed(&mut self) -> Option<(Res, D)> { let res = self.received_responses.pop_front()?; let data = self.data_of_inflight_requests.pop_front()?; diff --git a/protocols/kad/Cargo.toml b/protocols/kad/Cargo.toml index 403189b1801..f16a2580766 100644 --- a/protocols/kad/Cargo.toml +++ b/protocols/kad/Cargo.toml @@ -19,6 +19,7 @@ asynchronous-codec = "0.6" futures = "0.3.29" libp2p-core = { workspace = true } libp2p-swarm = { workspace = true } +libp2p-protocol-utils = { workspace = true } quick-protobuf = "0.8" quick-protobuf-codec = { workspace = true } libp2p-identity = { workspace = true, features = ["rand"] } diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index adfb076541c..2ba5a543c5c 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -29,16 +29,15 @@ use futures::prelude::*; use futures::stream::SelectAll; use libp2p_core::{upgrade, ConnectedPoint}; use libp2p_identity::PeerId; -use libp2p_swarm::handler::{ - ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, -}; +use libp2p_protocol_utils::InflightProtocolDataQueue; +use libp2p_swarm::handler::{ConnectionEvent, FullyNegotiatedInbound}; use libp2p_swarm::{ ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol, SupportedProtocols, }; -use std::collections::VecDeque; use std::task::Waker; use std::{error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll}; +use void::Void; const MAX_NUM_SUBSTREAMS: usize = 32; @@ -62,12 +61,13 @@ pub struct Handler { /// List of active outbound substreams with the state they are in. outbound_substreams: SelectAll, - /// Number of outbound streams being upgraded right now. - num_requested_outbound_streams: usize, - /// List of outbound substreams that are waiting to become active next. /// Contains the request we want to send, and the user data if we expect an answer. - pending_messages: VecDeque<(KadRequestMsg, Option)>, + pending_streams: InflightProtocolDataQueue< + (KadRequestMsg, Option), + ProtocolConfig, + Result, StreamUpgradeError>, + >, /// List of active inbound substreams with the state they are in. inbound_substreams: SelectAll, @@ -293,7 +293,7 @@ pub enum HandlerEvent { #[derive(Debug)] pub enum HandlerQueryErr { /// Error while trying to perform the query. - Upgrade(StreamUpgradeError), + Upgrade(StreamUpgradeError), /// Received an answer that doesn't correspond to the request. UnexpectedMessage, /// I/O error in the substream. @@ -329,8 +329,8 @@ impl error::Error for HandlerQueryErr { } } -impl From> for HandlerQueryErr { - fn from(err: StreamUpgradeError) -> Self { +impl From> for HandlerQueryErr { + fn from(err: StreamUpgradeError) -> Self { HandlerQueryErr::Upgrade(err) } } @@ -481,40 +481,12 @@ impl Handler { next_connec_unique_id: UniqueConnecId(0), inbound_substreams: Default::default(), outbound_substreams: Default::default(), - num_requested_outbound_streams: 0, - pending_messages: Default::default(), + pending_streams: Default::default(), protocol_status: None, remote_supported_protocols: Default::default(), } } - fn on_fully_negotiated_outbound( - &mut self, - FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound< - ::OutboundProtocol, - ::OutboundOpenInfo, - >, - ) { - if let Some((msg, query_id)) = self.pending_messages.pop_front() { - self.outbound_substreams - .push(OutboundSubstreamState::PendingSend(protocol, msg, query_id)); - } else { - debug_assert!(false, "Requested outbound stream without message") - } - - self.num_requested_outbound_streams -= 1; - - if self.protocol_status.is_none() { - // Upon the first successfully negotiated substream, we know that the - // remote is configured with the same protocol name and we want - // the behaviour to add this peer to the routing table, if possible. - self.protocol_status = Some(ProtocolStatus { - supported: true, - reported: false, - }); - } - } - fn on_fully_negotiated_inbound( &mut self, FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound< @@ -572,26 +544,6 @@ impl Handler { substream: protocol, }); } - - fn on_dial_upgrade_error( - &mut self, - DialUpgradeError { - info: (), error, .. - }: DialUpgradeError< - ::OutboundOpenInfo, - ::OutboundProtocol, - >, - ) { - // TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't - // continue trying - - if let Some((_, Some(query_id))) = self.pending_messages.pop_front() { - self.outbound_substreams - .push(OutboundSubstreamState::ReportError(error.into(), query_id)); - } - - self.num_requested_outbound_streams -= 1; - } } impl ConnectionHandler for Handler { @@ -626,16 +578,20 @@ impl ConnectionHandler for Handler { } } HandlerIn::FindNodeReq { key, query_id } => { - let msg = KadRequestMsg::FindNode { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::FindNode { key }, Some(query_id)), + ); } HandlerIn::FindNodeRes { closer_peers, request_id, } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }), HandlerIn::GetProvidersReq { key, query_id } => { - let msg = KadRequestMsg::GetProviders { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::GetProviders { key }, Some(query_id)), + ); } HandlerIn::GetProvidersRes { closer_peers, @@ -649,16 +605,22 @@ impl ConnectionHandler for Handler { }, ), HandlerIn::AddProvider { key, provider } => { - let msg = KadRequestMsg::AddProvider { key, provider }; - self.pending_messages.push_back((msg, None)); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::AddProvider { key, provider }, None), + ); } HandlerIn::GetRecord { key, query_id } => { - let msg = KadRequestMsg::GetValue { key }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::GetValue { key }, Some(query_id)), + ); } HandlerIn::PutRecord { record, query_id } => { - let msg = KadRequestMsg::PutValue { record }; - self.pending_messages.push_back((msg, Some(query_id))); + self.pending_streams.enqueue_request( + self.protocol_config.clone(), + (KadRequestMsg::PutValue { record }, Some(query_id)), + ); } HandlerIn::GetRecordRes { record, @@ -712,44 +674,67 @@ impl ConnectionHandler for Handler { ) -> Poll< ConnectionHandlerEvent, > { - match &mut self.protocol_status { - Some(status) if !status.reported => { - status.reported = true; - let event = if status.supported { - HandlerEvent::ProtocolConfirmed { - endpoint: self.endpoint.clone(), - } - } else { - HandlerEvent::ProtocolNotSupported { - endpoint: self.endpoint.clone(), - } - }; + loop { + match &mut self.protocol_status { + Some(status) if !status.reported => { + status.reported = true; + let event = if status.supported { + HandlerEvent::ProtocolConfirmed { + endpoint: self.endpoint.clone(), + } + } else { + HandlerEvent::ProtocolNotSupported { + endpoint: self.endpoint.clone(), + } + }; - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + _ => {} } - _ => {} - } - if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) { - return Poll::Ready(event); - } + if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) { + return Poll::Ready(event); + } - if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { - return Poll::Ready(event); - } + if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { + return Poll::Ready(event); + } - let num_in_progress_outbound_substreams = - self.outbound_substreams.len() + self.num_requested_outbound_streams; - if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS - && self.num_requested_outbound_streams < self.pending_messages.len() - { - self.num_requested_outbound_streams += 1; - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()), - }); - } + match self.pending_streams.next_completed() { + Some((Ok(stream), (message, query_id))) => { + self.outbound_substreams + .push(OutboundSubstreamState::PendingSend( + stream, message, query_id, + )); + continue; + } + // TODO: Check if the remote doesn't support kademlia and stop trying if it doesn't + Some((Err(error), (_, Some(query_id)))) => { + self.outbound_substreams + .push(OutboundSubstreamState::ReportError(error.into(), query_id)); + continue; + } + Some((Err(error), (message, None))) => { + tracing::debug!(?message, "Failed to establish stream: {error}"); + continue; + } + None => {} + } + + let num_in_progress_outbound_substreams = + self.outbound_substreams.len() + self.pending_streams.num_inflight(); - Poll::Pending + if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS { + if let Some(next) = self.pending_streams.next_request() { + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(next, ()), + }); + } + } + + return Poll::Pending; + } } fn on_connection_event( @@ -762,14 +747,24 @@ impl ConnectionHandler for Handler { >, ) { match event { - ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => { - self.on_fully_negotiated_outbound(fully_negotiated_outbound) + ConnectionEvent::FullyNegotiatedOutbound(ev) => { + self.pending_streams.submit_response(Ok(ev.protocol)); + + if self.protocol_status.is_none() { + // Upon the first successfully negotiated substream, we know that the + // remote is configured with the same protocol name and we want + // the behaviour to add this peer to the routing table, if possible. + self.protocol_status = Some(ProtocolStatus { + supported: true, + reported: false, + }); + } } ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => { self.on_fully_negotiated_inbound(fully_negotiated_inbound) } - ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) + ConnectionEvent::DialUpgradeError(ev) => { + self.pending_streams.submit_response(Err(ev.error)); } ConnectionEvent::RemoteProtocolsChange(change) => { let dirty = self.remote_supported_protocols.on_protocols_change(change); diff --git a/protocols/kad/src/protocol.rs b/protocols/kad/src/protocol.rs index e3e2c09e42a..5240b60276a 100644 --- a/protocols/kad/src/protocol.rs +++ b/protocols/kad/src/protocol.rs @@ -40,6 +40,7 @@ use std::marker::PhantomData; use std::{convert::TryFrom, time::Duration}; use std::{io, iter}; use tracing::debug; +use void::Void; /// The protocol name used for negotiating with multistream-select. pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0"); @@ -220,8 +221,8 @@ where C: AsyncRead + AsyncWrite + Unpin, { type Output = KadInStreamSink; - type Future = future::Ready>; - type Error = io::Error; + type Future = future::Ready>; + type Error = Void; fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future { let codec = Codec::new(self.max_packet_size); @@ -235,8 +236,8 @@ where C: AsyncRead + AsyncWrite + Unpin, { type Output = KadOutStreamSink; - type Future = future::Ready>; - type Error = io::Error; + type Future = future::Ready>; + type Error = Void; fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future { let codec = Codec::new(self.max_packet_size);