Skip to content

Commit

Permalink
typesafe request IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
oblique committed Oct 23, 2023
1 parent 1dbb5bc commit 845b50e
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 82 deletions.
4 changes: 2 additions & 2 deletions examples/file-sharing/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use libp2p::{
identity, kad,
multiaddr::Protocol,
noise,
request_response::{self, ProtocolSupport, RequestId, ResponseChannel},
request_response::{self, OutboundRequestId, ProtocolSupport, ResponseChannel},
swarm::{NetworkBehaviour, Swarm, SwarmEvent},
tcp, yamux, PeerId,
};
Expand Down Expand Up @@ -175,7 +175,7 @@ pub(crate) struct EventLoop {
pending_start_providing: HashMap<kad::QueryId, oneshot::Sender<()>>,
pending_get_providers: HashMap<kad::QueryId, oneshot::Sender<HashSet<PeerId>>>,
pending_request_file:
HashMap<RequestId, oneshot::Sender<Result<Vec<u8>, Box<dyn Error + Send>>>>,
HashMap<OutboundRequestId, oneshot::Sender<Result<Vec<u8>, Box<dyn Error + Send>>>>,
}

impl EventLoop {
Expand Down
6 changes: 3 additions & 3 deletions protocols/autonat/src/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use instant::Instant;
use libp2p_core::{multiaddr::Protocol, ConnectedPoint, Endpoint, Multiaddr};
use libp2p_identity::PeerId;
use libp2p_request_response::{
self as request_response, ProtocolSupport, RequestId, ResponseChannel,
self as request_response, InboundRequestId, OutboundRequestId, ProtocolSupport, ResponseChannel,
};
use libp2p_swarm::{
behaviour::{
Expand Down Expand Up @@ -187,14 +187,14 @@ pub struct Behaviour {
PeerId,
(
ProbeId,
RequestId,
InboundRequestId,
Vec<Multiaddr>,
ResponseChannel<DialResponse>,
),
>,

// Ongoing outbound probes and mapped to the inner request id.
ongoing_outbound: HashMap<RequestId, ProbeId>,
ongoing_outbound: HashMap<OutboundRequestId, ProbeId>,

// Connected peers with the observed address of each connection.
// If the endpoint of a connection is relayed or not global (in case of Config::only_global_ips),
Expand Down
6 changes: 3 additions & 3 deletions protocols/autonat/src/behaviour/as_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use futures_timer::Delay;
use instant::Instant;
use libp2p_core::Multiaddr;
use libp2p_identity::PeerId;
use libp2p_request_response::{self as request_response, OutboundFailure, RequestId};
use libp2p_request_response::{self as request_response, OutboundFailure, OutboundRequestId};
use libp2p_swarm::{ConnectionId, ListenAddresses, PollParameters, ToSwarm};
use rand::{seq::SliceRandom, thread_rng};
use std::{
Expand Down Expand Up @@ -91,7 +91,7 @@ pub(crate) struct AsClient<'a> {
pub(crate) throttled_servers: &'a mut Vec<(PeerId, Instant)>,
pub(crate) nat_status: &'a mut NatStatus,
pub(crate) confidence: &'a mut usize,
pub(crate) ongoing_outbound: &'a mut HashMap<RequestId, ProbeId>,
pub(crate) ongoing_outbound: &'a mut HashMap<OutboundRequestId, ProbeId>,
pub(crate) last_probe: &'a mut Option<Instant>,
pub(crate) schedule_probe: &'a mut Delay,
pub(crate) listen_addresses: &'a ListenAddresses,
Expand All @@ -118,7 +118,7 @@ impl<'a> HandleInnerEvent for AsClient<'a> {
let probe_id = self
.ongoing_outbound
.remove(&request_id)
.expect("RequestId exists.");
.expect("OutboundRequestId exists.");

let event = match response.result.clone() {
Ok(address) => OutboundProbeEvent::Response {
Expand Down
4 changes: 2 additions & 2 deletions protocols/autonat/src/behaviour/as_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use instant::Instant;
use libp2p_core::{multiaddr::Protocol, Multiaddr};
use libp2p_identity::PeerId;
use libp2p_request_response::{
self as request_response, InboundFailure, RequestId, ResponseChannel,
self as request_response, InboundFailure, InboundRequestId, ResponseChannel,
};
use libp2p_swarm::{
dial_opts::{DialOpts, PeerCondition},
Expand Down Expand Up @@ -85,7 +85,7 @@ pub(crate) struct AsServer<'a> {
PeerId,
(
ProbeId,
RequestId,
InboundRequestId,
Vec<Multiaddr>,
ResponseChannel<DialResponse>,
),
Expand Down
6 changes: 3 additions & 3 deletions protocols/perf/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ use crate::{protocol::Response, RunDuration, RunParams};

/// Connection identifier.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct RunId(request_response::RequestId);
pub struct RunId(request_response::OutboundRequestId);

impl From<request_response::RequestId> for RunId {
fn from(value: request_response::RequestId) -> Self {
impl From<request_response::OutboundRequestId> for RunId {
fn from(value: request_response::OutboundRequestId) -> Self {
Self(value)
}
}
Expand Down
14 changes: 9 additions & 5 deletions protocols/rendezvous/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use futures::stream::FuturesUnordered;
use futures::stream::StreamExt;
use libp2p_core::{Endpoint, Multiaddr, PeerRecord};
use libp2p_identity::{Keypair, PeerId, SigningError};
use libp2p_request_response::{ProtocolSupport, RequestId};
use libp2p_request_response::{OutboundRequestId, ProtocolSupport};
use libp2p_swarm::{
ConnectionDenied, ConnectionId, ExternalAddresses, FromSwarm, NetworkBehaviour, PollParameters,
THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
Expand All @@ -41,8 +41,8 @@ pub struct Behaviour {

keypair: Keypair,

waiting_for_register: HashMap<RequestId, (PeerId, Namespace)>,
waiting_for_discovery: HashMap<RequestId, (PeerId, Option<Namespace>)>,
waiting_for_register: HashMap<OutboundRequestId, (PeerId, Namespace)>,
waiting_for_discovery: HashMap<OutboundRequestId, (PeerId, Option<Namespace>)>,

/// Hold addresses of all peers that we have discovered so far.
///
Expand Down Expand Up @@ -337,7 +337,7 @@ impl NetworkBehaviour for Behaviour {
}

impl Behaviour {
fn event_for_outbound_failure(&mut self, req_id: &RequestId) -> Option<Event> {
fn event_for_outbound_failure(&mut self, req_id: &OutboundRequestId) -> Option<Event> {
if let Some((rendezvous_node, namespace)) = self.waiting_for_register.remove(req_id) {
return Some(Event::RegisterFailed {
rendezvous_node,
Expand All @@ -357,7 +357,11 @@ impl Behaviour {
None
}

fn handle_response(&mut self, request_id: &RequestId, response: Message) -> Option<Event> {
fn handle_response(
&mut self,
request_id: &OutboundRequestId,
response: Message,
) -> Option<Event> {
match response {
RegisterResponse(Ok(ttl)) => {
if let Some((rendezvous_node, namespace)) =
Expand Down
70 changes: 37 additions & 33 deletions protocols/request-response/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use protocol::ProtocolSupport;

use crate::codec::Codec;
use crate::handler::protocol::Protocol;
use crate::{RequestId, EMPTY_QUEUE_SHRINK_THRESHOLD};
use crate::{InboundRequestId, OutboundRequestId, EMPTY_QUEUE_SHRINK_THRESHOLD};

use futures::channel::mpsc;
use futures::{channel::oneshot, prelude::*};
Expand Down Expand Up @@ -67,27 +67,26 @@ where
requested_outbound: VecDeque<OutboundMessage<TCodec>>,
/// A channel for receiving inbound requests.
inbound_receiver: mpsc::Receiver<(
RequestId,
InboundRequestId,
TCodec::Request,
oneshot::Sender<TCodec::Response>,
)>,
/// The [`mpsc::Sender`] for the above receiver. Cloned for each inbound request.
inbound_sender: mpsc::Sender<(
RequestId,
InboundRequestId,
TCodec::Request,
oneshot::Sender<TCodec::Response>,
)>,

inbound_request_id: Arc<AtomicU64>,

worker_streams:
futures_bounded::FuturesMap<(RequestId, Direction), Result<Event<TCodec>, io::Error>>,
worker_streams: futures_bounded::FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
}

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum Direction {
Inbound,
Outbound,
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum RequestId {
Inbound(InboundRequestId),
Outbound(OutboundRequestId),
}

impl<TCodec> Handler<TCodec>
Expand Down Expand Up @@ -130,7 +129,7 @@ where
>,
) {
let mut codec = self.codec.clone();
let request_id = RequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed));
let request_id = InboundRequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed));
let mut sender = self.inbound_sender.clone();

let recv = async move {
Expand Down Expand Up @@ -160,7 +159,7 @@ where

if self
.worker_streams
.try_push((request_id, Direction::Inbound), recv.boxed())
.try_push(RequestId::Inbound(request_id), recv.boxed())
.is_ok()
{
self.pending_events
Expand Down Expand Up @@ -203,7 +202,7 @@ where

if self
.worker_streams
.try_push((request_id, Direction::Outbound), send.boxed())
.try_push(RequestId::Outbound(request_id), send.boxed())
.is_err()
{
log::warn!("Dropping outbound stream because we are at capacity")
Expand Down Expand Up @@ -263,34 +262,34 @@ where
TCodec: Codec,
{
/// A request is going to be received.
IncomingRequest { request_id: RequestId },
IncomingRequest { request_id: InboundRequestId },
/// A request has been received.
Request {
request_id: RequestId,
request_id: InboundRequestId,
request: TCodec::Request,
sender: oneshot::Sender<TCodec::Response>,
},
/// A response has been received.
Response {
request_id: RequestId,
request_id: OutboundRequestId,
response: TCodec::Response,
},
/// A response to an inbound request has been sent.
ResponseSent(RequestId),
ResponseSent(InboundRequestId),
/// A response to an inbound request was omitted as a result
/// of dropping the response `sender` of an inbound `Request`.
ResponseOmission(RequestId),
ResponseOmission(InboundRequestId),
/// An outbound request timed out while sending the request
/// or waiting for the response.
OutboundTimeout(RequestId),
OutboundTimeout(OutboundRequestId),
/// An outbound request failed to negotiate a mutually supported protocol.
OutboundUnsupportedProtocols(RequestId),
OutboundUnsupportedProtocols(OutboundRequestId),
OutboundStreamFailed {
request_id: RequestId,
request_id: OutboundRequestId,
error: io::Error,
},
InboundStreamFailed {
request_id: RequestId,
request_id: InboundRequestId,
error: io::Error,
},
}
Expand Down Expand Up @@ -348,7 +347,7 @@ impl<TCodec: Codec> fmt::Debug for Event<TCodec> {
}

pub struct OutboundMessage<TCodec: Codec> {
pub(crate) request_id: RequestId,
pub(crate) request_id: OutboundRequestId,
pub(crate) request: TCodec::Request,
pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>,
}
Expand Down Expand Up @@ -414,15 +413,15 @@ where
Poll::Ready((_, Ok(Ok(event)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
}
Poll::Ready(((id, direction), Ok(Err(e)))) => {
log::debug!("Stream for request {id} failed: {e}");
Poll::Ready((id, Ok(Err(e)))) => {
log::debug!("Stream for request {id:?} failed: {e}");

let event = match direction {
Direction::Inbound => Event::InboundStreamFailed {
let event = match id {
RequestId::Inbound(id) => Event::InboundStreamFailed {
request_id: id,
error: e,
},
Direction::Outbound => Event::OutboundStreamFailed {
RequestId::Outbound(id) => Event::OutboundStreamFailed {
request_id: id,
error: e,
},
Expand All @@ -433,12 +432,17 @@ where
// should be forwarded to the upper layer.
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
}
Poll::Ready(((id, direction), Err(futures_bounded::Timeout { .. }))) => {
log::debug!("Stream for request {id} timed out");

if direction == Direction::Outbound {
let event = Event::OutboundTimeout(id);
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
Poll::Ready((id, Err(futures_bounded::Timeout { .. }))) => {
log::debug!("Stream for request {id:?} timed out");

match id {
RequestId::Inbound(_id) => {
// TODO
}
RequestId::Outbound(id) => {
let event = Event::OutboundTimeout(id);
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
}
}
}
Poll::Pending => break,
Expand Down
Loading

0 comments on commit 845b50e

Please sign in to comment.