diff --git a/Cargo.lock b/Cargo.lock index 42378f90fe0..98f2503fa2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3234,6 +3234,22 @@ dependencies = [ "zeroize", ] +[[package]] +name = "libp2p-stream" +version = "0.1.0-alpha" +dependencies = [ + "futures", + "libp2p-core", + "libp2p-identity", + "libp2p-swarm", + "libp2p-swarm-test", + "rand 0.8.5", + "tokio", + "tracing", + "tracing-subscriber", + "void", +] + [[package]] name = "libp2p-swarm" version = "0.44.1" @@ -5565,6 +5581,20 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stream-example" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures", + "libp2p", + "libp2p-stream", + "rand 0.8.5", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "stringmatch" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index d10ed7e3bbf..9215d68dd9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ "examples/ping", "examples/relay-server", "examples/rendezvous", + "examples/stream", "examples/upnp", "hole-punching-tests", "identity", @@ -45,10 +46,11 @@ members = [ "protocols/relay", "protocols/rendezvous", "protocols/request-response", + "protocols/stream", "protocols/upnp", - "swarm", "swarm-derive", "swarm-test", + "swarm", "transports/dns", "transports/noise", "transports/plaintext", @@ -57,11 +59,11 @@ members = [ "transports/tcp", "transports/tls", "transports/uds", - "transports/webrtc", "transports/webrtc-websys", + "transports/webrtc", + "transports/websocket-websys", "transports/websocket", "transports/webtransport-websys", - "transports/websocket-websys", "wasm-tests/webtransport-tests", ] resolver = "2" @@ -99,6 +101,7 @@ libp2p-relay = { version = "0.17.1", path = "protocols/relay" } libp2p-rendezvous = { version = "0.14.0", path = "protocols/rendezvous" } libp2p-request-response = { version = "0.26.1", path = "protocols/request-response" } libp2p-server = { version = "0.12.5", path = "misc/server" } +libp2p-stream = { version = "0.1.0-alpha", path = "protocols/stream" } libp2p-swarm = { version = "0.44.1", path = "swarm" } libp2p-swarm-derive = { version = "=0.34.2", path = "swarm-derive" } # `libp2p-swarm-derive` may not be compatible with different `libp2p-swarm` non-breaking releases. E.g. `libp2p-swarm` might introduce a new enum variant `FromSwarm` (which is `#[non-exhaustive]`) in a non-breaking release. Older versions of `libp2p-swarm-derive` would not forward this enum variant within the `NetworkBehaviour` hierarchy. Thus the version pinning is required. libp2p-swarm-test = { version = "0.3.0", path = "swarm-test" } diff --git a/examples/stream/Cargo.toml b/examples/stream/Cargo.toml new file mode 100644 index 00000000000..5aab488358b --- /dev/null +++ b/examples/stream/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "stream-example" +version = "0.1.0" +edition = "2021" +publish = false +license = "MIT" + +[package.metadata.release] +release = false + +[dependencies] +anyhow = "1" +futures = "0.3.29" +libp2p = { path = "../../libp2p", features = [ "tokio", "quic"] } +libp2p-stream = { path = "../../protocols/stream", version = "0.1.0-alpha" } +rand = "0.8" +tokio = { version = "1.35", features = ["full"] } +tracing = "0.1.37" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[lints] +workspace = true diff --git a/examples/stream/README.md b/examples/stream/README.md new file mode 100644 index 00000000000..8437a5ea21e --- /dev/null +++ b/examples/stream/README.md @@ -0,0 +1,35 @@ +## Description + +This example shows the usage of the `stream::Behaviour`. +As a counter-part to the `request_response::Behaviour`, the `stream::Behaviour` allows users to write stream-oriented protocols whilst having minimal interaction with the `Swarm`. + +In this showcase, we implement an echo protocol: All incoming data is echoed back to the dialer, until the stream is closed. + +## Usage + +To run the example, follow these steps: + +1. Start an instance of the example in one terminal: + + ```sh + cargo run --bin stream-example + ``` + + Observe printed listen address. + +2. Start another instance in a new terminal, providing the listen address of the first one. + + ```sh + cargo run --bin stream-example --
+ ``` + +3. Both terminals should now continuosly print messages. + +## Conclusion + +The `stream::Behaviour` is an "escape-hatch" from the way typical rust-libp2p protocols are written. +It is suitable for several scenarios including: + +- prototyping of new protocols +- experimentation with rust-libp2p +- integration in `async/await`-heavy applications \ No newline at end of file diff --git a/examples/stream/src/main.rs b/examples/stream/src/main.rs new file mode 100644 index 00000000000..872ab8c3b98 --- /dev/null +++ b/examples/stream/src/main.rs @@ -0,0 +1,154 @@ +use std::{io, time::Duration}; + +use anyhow::{Context, Result}; +use futures::{AsyncReadExt, AsyncWriteExt, StreamExt}; +use libp2p::{multiaddr::Protocol, Multiaddr, PeerId, Stream, StreamProtocol}; +use libp2p_stream as stream; +use rand::RngCore; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::EnvFilter; + +const ECHO_PROTOCOL: StreamProtocol = StreamProtocol::new("/echo"); + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env()?, + ) + .init(); + + let maybe_address = std::env::args() + .nth(1) + .map(|arg| arg.parse::()) + .transpose() + .context("Failed to parse argument as `Multiaddr`")?; + + let mut swarm = libp2p::SwarmBuilder::with_new_identity() + .with_tokio() + .with_quic() + .with_behaviour(|_| stream::Behaviour::new())? + .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(10))) + .build(); + + swarm.listen_on("/ip4/127.0.0.1/udp/0/quic-v1".parse()?)?; + + let mut incoming_streams = swarm + .behaviour() + .new_control() + .accept(ECHO_PROTOCOL) + .unwrap(); + + // Deal with incoming streams. + // Spawning a dedicated task is just one way of doing this. + // libp2p doesn't care how you handle incoming streams but you _must_ handle them somehow. + // To mitigate DoS attacks, libp2p will internally drop incoming streams if your application cannot keep up processing them. + tokio::spawn(async move { + // This loop handles incoming streams _sequentially_ but that doesn't have to be the case. + // You can also spawn a dedicated task per stream if you want to. + // Be aware that this breaks backpressure though as spawning new tasks is equivalent to an unbounded buffer. + // Each task needs memory meaning an aggressive remote peer may force you OOM this way. + + while let Some((peer, stream)) = incoming_streams.next().await { + match echo(stream).await { + Ok(n) => { + tracing::info!(%peer, "Echoed {n} bytes!"); + } + Err(e) => { + tracing::warn!(%peer, "Echo failed: {e}"); + continue; + } + }; + } + }); + + // In this demo application, the dialing peer initiates the protocol. + if let Some(address) = maybe_address { + let Some(Protocol::P2p(peer_id)) = address.iter().last() else { + anyhow::bail!("Provided address does not end in `/p2p`"); + }; + + swarm.dial(address)?; + + tokio::spawn(connection_handler(peer_id, swarm.behaviour().new_control())); + } + + // Poll the swarm to make progress. + loop { + let event = swarm.next().await.expect("never terminates"); + + match event { + libp2p::swarm::SwarmEvent::NewListenAddr { address, .. } => { + let listen_address = address.with_p2p(*swarm.local_peer_id()).unwrap(); + tracing::info!(%listen_address); + } + event => tracing::trace!(?event), + } + } +} + +/// A very simple, `async fn`-based connection handler for our custom echo protocol. +async fn connection_handler(peer: PeerId, mut control: stream::Control) { + loop { + tokio::time::sleep(Duration::from_secs(1)).await; // Wait a second between echos. + + let stream = match control.open_stream(peer, ECHO_PROTOCOL).await { + Ok(stream) => stream, + Err(error @ stream::OpenStreamError::UnsupportedProtocol(_)) => { + tracing::info!(%peer, %error); + return; + } + Err(error) => { + // Other errors may be temporary. + // In production, something like an exponential backoff / circuit-breaker may be more appropriate. + tracing::debug!(%peer, %error); + continue; + } + }; + + if let Err(e) = send(stream).await { + tracing::warn!(%peer, "Echo protocol failed: {e}"); + continue; + } + + tracing::info!(%peer, "Echo complete!") + } +} + +async fn echo(mut stream: Stream) -> io::Result { + let mut total = 0; + + let mut buf = [0u8; 100]; + + loop { + let read = stream.read(&mut buf).await?; + if read == 0 { + return Ok(total); + } + + total += read; + stream.write_all(&buf[..read]).await?; + } +} + +async fn send(mut stream: Stream) -> io::Result<()> { + let num_bytes = rand::random::() % 1000; + + let mut bytes = vec![0; num_bytes]; + rand::thread_rng().fill_bytes(&mut bytes); + + stream.write_all(&bytes).await?; + + let mut buf = vec![0; num_bytes]; + stream.read_exact(&mut buf).await?; + + if bytes != buf { + return Err(io::Error::new(io::ErrorKind::Other, "incorrect echo")); + } + + stream.close().await?; + + Ok(()) +} diff --git a/protocols/stream/CHANGELOG.md b/protocols/stream/CHANGELOG.md new file mode 100644 index 00000000000..2e177e2f1bc --- /dev/null +++ b/protocols/stream/CHANGELOG.md @@ -0,0 +1,3 @@ +## 0.1.0-alpha + +Initial release. diff --git a/protocols/stream/Cargo.toml b/protocols/stream/Cargo.toml new file mode 100644 index 00000000000..be340939720 --- /dev/null +++ b/protocols/stream/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "libp2p-stream" +version = "0.1.0-alpha" +edition = "2021" +rust-version.workspace = true +description = "Generic stream protocols for libp2p" +license = "MIT" +repository = "https://github.com/libp2p/rust-libp2p" +keywords = ["peer-to-peer", "libp2p", "networking"] +categories = ["network-programming", "asynchronous"] + +[dependencies] +futures = "0.3.29" +libp2p-core = { workspace = true } +libp2p-identity = { workspace = true, features = ["peerid"] } +libp2p-swarm = { workspace = true } +tracing = "0.1.37" +void = "1" +rand = "0.8" + +[dev-dependencies] +libp2p-swarm-test = { workspace = true } +tokio = { version = "1", features = ["full"] } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[lints] +workspace = true diff --git a/protocols/stream/README.md b/protocols/stream/README.md new file mode 100644 index 00000000000..c8a56e119ca --- /dev/null +++ b/protocols/stream/README.md @@ -0,0 +1,69 @@ +# Generic (stream) protocols + +This module provides a generic [`NetworkBehaviour`](libp2p_swarm::NetworkBehaviour) for stream-oriented protocols. +Streams are the fundamental primitive of libp2p and all other protocols are implemented using streams. +In contrast to other [`NetworkBehaviour`](libp2p_swarm::NetworkBehaviour)s, this module takes a different design approach. +All interaction happens through a [`Control`] that can be obtained via [`Behaviour::new_control`]. +[`Control`]s can be cloned and thus shared across your application. + +## Inbound + +To accept streams for a particular [`StreamProtocol`](libp2p_swarm::StreamProtocol) using this module, use [`Control::accept`]: + +### Example + +```rust,no_run +# fn main() { +# use libp2p_swarm::{Swarm, StreamProtocol}; +# use libp2p_stream as stream; +# use futures::StreamExt as _; +let mut swarm: Swarm = todo!(); + +let mut control = swarm.behaviour().new_control(); +let mut incoming = control.accept(StreamProtocol::new("/my-protocol")).unwrap(); + +let handler_future = async move { + while let Some((peer, stream)) = incoming.next().await { + // Execute your protocol using `stream`. + } +}; +# } +``` + +### Resource management + +[`Control::accept`] returns you an instance of [`IncomingStreams`]. +This struct implements [`Stream`](futures::Stream) and like other streams, is lazy. +You must continuously poll it to make progress. +In the example above, this taken care of by using the [`StreamExt::next`](futures::StreamExt::next) helper. + +Internally, we will drop streams if your application falls behind in processing these incoming streams, i.e. if whatever loop calls `.next()` is not fast enough. + +### Drop + +As soon as you drop [`IncomingStreams`], the protocol will be de-registered. +Any further attempt by remote peers to open a stream using the provided protocol will result in a negotiation error. + +## Outbound + +To open a new outbound stream for a particular protocol, use [`Control::open_stream`]. + +### Example + +```rust,no_run +# fn main() { +# use libp2p_swarm::{Swarm, StreamProtocol}; +# use libp2p_stream as stream; +# use libp2p_identity::PeerId; +let mut swarm: Swarm = todo!(); +let peer_id: PeerId = todo!(); + +let mut control = swarm.behaviour().new_control(); + +let protocol_future = async move { + let stream = control.open_stream(peer_id, StreamProtocol::new("/my-protocol")).await.unwrap(); + + // Execute your protocol here using `stream`. +}; +# } +``` \ No newline at end of file diff --git a/protocols/stream/src/behaviour.rs b/protocols/stream/src/behaviour.rs new file mode 100644 index 00000000000..e02aca884b7 --- /dev/null +++ b/protocols/stream/src/behaviour.rs @@ -0,0 +1,143 @@ +use core::fmt; +use std::{ + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use futures::{channel::mpsc, StreamExt}; +use libp2p_core::{Endpoint, Multiaddr}; +use libp2p_identity::PeerId; +use libp2p_swarm::{ + self as swarm, dial_opts::DialOpts, ConnectionDenied, ConnectionId, FromSwarm, + NetworkBehaviour, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm, +}; +use swarm::{ + behaviour::ConnectionEstablished, dial_opts::PeerCondition, ConnectionClosed, DialError, + DialFailure, +}; + +use crate::{handler::Handler, shared::Shared, Control}; + +/// A generic behaviour for stream-oriented protocols. +pub struct Behaviour { + shared: Arc>, + dial_receiver: mpsc::Receiver, +} + +impl Default for Behaviour { + fn default() -> Self { + Self::new() + } +} + +impl Behaviour { + pub fn new() -> Self { + let (dial_sender, dial_receiver) = mpsc::channel(0); + + Self { + shared: Arc::new(Mutex::new(Shared::new(dial_sender))), + dial_receiver, + } + } + + /// Obtain a new [`Control`]. + pub fn new_control(&self) -> Control { + Control::new(self.shared.clone()) + } +} + +/// The protocol is already registered. +#[derive(Debug)] +pub struct AlreadyRegistered; + +impl fmt::Display for AlreadyRegistered { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "The protocol is already registered") + } +} + +impl std::error::Error for AlreadyRegistered {} + +impl NetworkBehaviour for Behaviour { + type ConnectionHandler = Handler; + type ToSwarm = (); + + fn handle_established_inbound_connection( + &mut self, + connection_id: ConnectionId, + peer: PeerId, + _: &Multiaddr, + _: &Multiaddr, + ) -> Result, ConnectionDenied> { + Ok(Handler::new( + peer, + self.shared.clone(), + Shared::lock(&self.shared).receiver(peer, connection_id), + )) + } + + fn handle_established_outbound_connection( + &mut self, + connection_id: ConnectionId, + peer: PeerId, + _: &Multiaddr, + _: Endpoint, + ) -> Result, ConnectionDenied> { + Ok(Handler::new( + peer, + self.shared.clone(), + Shared::lock(&self.shared).receiver(peer, connection_id), + )) + } + + fn on_swarm_event(&mut self, event: FromSwarm) { + match event { + FromSwarm::ConnectionEstablished(ConnectionEstablished { + peer_id, + connection_id, + .. + }) => Shared::lock(&self.shared).on_connection_established(connection_id, peer_id), + FromSwarm::ConnectionClosed(ConnectionClosed { connection_id, .. }) => { + Shared::lock(&self.shared).on_connection_closed(connection_id) + } + FromSwarm::DialFailure(DialFailure { + peer_id: Some(peer_id), + error: + error @ (DialError::Transport(_) + | DialError::Denied { .. } + | DialError::NoAddresses + | DialError::WrongPeerId { .. }), + .. + }) => { + let reason = error.to_string(); // We can only forward the string repr but it is better than nothing. + + Shared::lock(&self.shared).on_dial_failure(peer_id, reason) + } + _ => {} + } + } + + fn on_connection_handler_event( + &mut self, + _peer_id: PeerId, + _connection_id: ConnectionId, + event: THandlerOutEvent, + ) { + void::unreachable(event); + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + if let Poll::Ready(Some(peer)) = self.dial_receiver.poll_next_unpin(cx) { + return Poll::Ready(ToSwarm::Dial { + opts: DialOpts::peer_id(peer) + .condition(PeerCondition::DisconnectedAndNotDialing) + .build(), + }); + } + + Poll::Pending + } +} diff --git a/protocols/stream/src/control.rs b/protocols/stream/src/control.rs new file mode 100644 index 00000000000..6aabaaff30e --- /dev/null +++ b/protocols/stream/src/control.rs @@ -0,0 +1,124 @@ +use core::fmt; +use std::{ + io, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use crate::AlreadyRegistered; +use crate::{handler::NewStream, shared::Shared}; + +use futures::{ + channel::{mpsc, oneshot}, + SinkExt as _, StreamExt as _, +}; +use libp2p_identity::PeerId; +use libp2p_swarm::{Stream, StreamProtocol}; + +/// A (remote) control for opening new streams and registration of inbound protocols. +/// +/// A [`Control`] can be cloned and thus allows for concurrent access. +#[derive(Clone)] +pub struct Control { + shared: Arc>, +} + +impl Control { + pub(crate) fn new(shared: Arc>) -> Self { + Self { shared } + } + + /// Attempt to open a new stream for the given protocol and peer. + /// + /// In case we are currently not connected to the peer, we will attempt to make a new connection. + /// + /// ## Backpressure + /// + /// [`Control`]s support backpressure similarly to bounded channels: + /// Each [`Control`] has a guaranteed slot for internal messages. + /// A single control will always open one stream at a time which is enforced by requiring `&mut self`. + /// + /// This backpressure mechanism breaks if you clone [`Control`]s excessively. + pub async fn open_stream( + &mut self, + peer: PeerId, + protocol: StreamProtocol, + ) -> Result { + tracing::debug!(%peer, "Requesting new stream"); + + let mut new_stream_sender = Shared::lock(&self.shared).sender(peer); + + let (sender, receiver) = oneshot::channel(); + + new_stream_sender + .send(NewStream { protocol, sender }) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))?; + + let stream = receiver + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))??; + + Ok(stream) + } + + /// Accept inbound streams for the provided protocol. + /// + /// To stop accepting streams, simply drop the returned [`IncomingStreams`] handle. + pub fn accept( + &mut self, + protocol: StreamProtocol, + ) -> Result { + Shared::lock(&self.shared).accept(protocol) + } +} + +/// Errors while opening a new stream. +#[derive(Debug)] +#[non_exhaustive] +pub enum OpenStreamError { + /// The remote does not support the requested protocol. + UnsupportedProtocol(StreamProtocol), + /// IO Error that occurred during the protocol handshake. + Io(std::io::Error), +} + +impl From for OpenStreamError { + fn from(v: std::io::Error) -> Self { + Self::Io(v) + } +} + +impl fmt::Display for OpenStreamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + OpenStreamError::UnsupportedProtocol(p) => { + write!(f, "failed to open stream: remote peer does not support {p}") + } + OpenStreamError::Io(e) => { + write!(f, "failed to open stream: io error: {e}") + } + } + } +} + +/// A handle to inbound streams for a particular protocol. +#[must_use = "Streams do nothing unless polled."] +pub struct IncomingStreams { + receiver: mpsc::Receiver<(PeerId, Stream)>, +} + +impl IncomingStreams { + pub(crate) fn new(receiver: mpsc::Receiver<(PeerId, Stream)>) -> Self { + Self { receiver } + } +} + +impl futures::Stream for IncomingStreams { + type Item = (PeerId, Stream); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.receiver.poll_next_unpin(cx) + } +} diff --git a/protocols/stream/src/handler.rs b/protocols/stream/src/handler.rs new file mode 100644 index 00000000000..f63b93c1761 --- /dev/null +++ b/protocols/stream/src/handler.rs @@ -0,0 +1,165 @@ +use std::{ + io, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use futures::{ + channel::{mpsc, oneshot}, + StreamExt as _, +}; +use libp2p_identity::PeerId; +use libp2p_swarm::{ + self as swarm, + handler::{ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound}, + ConnectionHandler, Stream, StreamProtocol, +}; + +use crate::{shared::Shared, upgrade::Upgrade, OpenStreamError}; + +pub struct Handler { + remote: PeerId, + shared: Arc>, + + receiver: mpsc::Receiver, + pending_upgrade: Option<( + StreamProtocol, + oneshot::Sender>, + )>, +} + +impl Handler { + pub(crate) fn new( + remote: PeerId, + shared: Arc>, + receiver: mpsc::Receiver, + ) -> Self { + Self { + shared, + receiver, + pending_upgrade: None, + remote, + } + } +} + +impl ConnectionHandler for Handler { + type FromBehaviour = void::Void; + type ToBehaviour = void::Void; + type InboundProtocol = Upgrade; + type OutboundProtocol = Upgrade; + type InboundOpenInfo = (); + type OutboundOpenInfo = (); + + fn listen_protocol( + &self, + ) -> swarm::SubstreamProtocol { + swarm::SubstreamProtocol::new( + Upgrade { + supported_protocols: Shared::lock(&self.shared).supported_inbound_protocols(), + }, + (), + ) + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + swarm::ConnectionHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::ToBehaviour, + >, + > { + if self.pending_upgrade.is_some() { + return Poll::Pending; + } + + match self.receiver.poll_next_unpin(cx) { + Poll::Ready(Some(new_stream)) => { + self.pending_upgrade = Some((new_stream.protocol.clone(), new_stream.sender)); + return Poll::Ready(swarm::ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: swarm::SubstreamProtocol::new( + Upgrade { + supported_protocols: vec![new_stream.protocol], + }, + (), + ), + }); + } + Poll::Ready(None) => {} // Sender is gone, no more work to do. + Poll::Pending => {} + } + + Poll::Pending + } + + fn on_behaviour_event(&mut self, event: Self::FromBehaviour) { + void::unreachable(event) + } + + fn on_connection_event( + &mut self, + event: ConnectionEvent< + Self::InboundProtocol, + Self::OutboundProtocol, + Self::InboundOpenInfo, + Self::OutboundOpenInfo, + >, + ) { + match event { + ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound { + protocol: (stream, protocol), + info: (), + }) => { + Shared::lock(&self.shared).on_inbound_stream(self.remote, stream, protocol); + } + ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { + protocol: (stream, actual_protocol), + info: (), + }) => { + let Some((expected_protocol, sender)) = self.pending_upgrade.take() else { + debug_assert!( + false, + "Negotiated an outbound stream without a back channel" + ); + return; + }; + debug_assert_eq!(expected_protocol, actual_protocol); + + let _ = sender.send(Ok(stream)); + } + ConnectionEvent::DialUpgradeError(DialUpgradeError { error, info: () }) => { + let Some((p, sender)) = self.pending_upgrade.take() else { + debug_assert!( + false, + "Received a `DialUpgradeError` without a back channel" + ); + return; + }; + + let error = match error { + swarm::StreamUpgradeError::Timeout => { + OpenStreamError::Io(io::Error::from(io::ErrorKind::TimedOut)) + } + swarm::StreamUpgradeError::Apply(v) => void::unreachable(v), + swarm::StreamUpgradeError::NegotiationFailed => { + OpenStreamError::UnsupportedProtocol(p) + } + swarm::StreamUpgradeError::Io(io) => OpenStreamError::Io(io), + }; + + let _ = sender.send(Err(error)); + } + _ => {} + } + } +} + +/// Message from a [`Control`](crate::Control) to a [`ConnectionHandler`] to negotiate a new outbound stream. +#[derive(Debug)] +pub(crate) struct NewStream { + pub(crate) protocol: StreamProtocol, + pub(crate) sender: oneshot::Sender>, +} diff --git a/protocols/stream/src/lib.rs b/protocols/stream/src/lib.rs new file mode 100644 index 00000000000..d498a1b71e5 --- /dev/null +++ b/protocols/stream/src/lib.rs @@ -0,0 +1,10 @@ +#![doc = include_str!("../README.md")] + +mod behaviour; +mod control; +mod handler; +mod shared; +mod upgrade; + +pub use behaviour::{AlreadyRegistered, Behaviour}; +pub use control::{Control, IncomingStreams, OpenStreamError}; diff --git a/protocols/stream/src/shared.rs b/protocols/stream/src/shared.rs new file mode 100644 index 00000000000..48aa6613d83 --- /dev/null +++ b/protocols/stream/src/shared.rs @@ -0,0 +1,167 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + io, + sync::{Arc, Mutex, MutexGuard}, +}; + +use futures::channel::mpsc; +use libp2p_identity::PeerId; +use libp2p_swarm::{ConnectionId, Stream, StreamProtocol}; +use rand::seq::IteratorRandom as _; + +use crate::{handler::NewStream, AlreadyRegistered, IncomingStreams}; + +pub(crate) struct Shared { + /// Tracks the supported inbound protocols created via [`Control::accept`](crate::Control::accept). + /// + /// For each [`StreamProtocol`], we hold the [`mpsc::Sender`] corresponding to the [`mpsc::Receiver`] in [`IncomingStreams`]. + supported_inbound_protocols: HashMap>, + + connections: HashMap, + senders: HashMap>, + + /// Tracks channel pairs for a peer whilst we are dialing them. + pending_channels: HashMap, mpsc::Receiver)>, + + /// Sender for peers we want to dial. + /// + /// We manage this through a channel to avoid locks as part of [`NetworkBehaviour::poll`](libp2p_swarm::NetworkBehaviour::poll). + dial_sender: mpsc::Sender, +} + +impl Shared { + pub(crate) fn lock(shared: &Arc>) -> MutexGuard<'_, Shared> { + shared.lock().unwrap_or_else(|e| e.into_inner()) + } +} + +impl Shared { + pub(crate) fn new(dial_sender: mpsc::Sender) -> Self { + Self { + dial_sender, + connections: Default::default(), + senders: Default::default(), + pending_channels: Default::default(), + supported_inbound_protocols: Default::default(), + } + } + + pub(crate) fn accept( + &mut self, + protocol: StreamProtocol, + ) -> Result { + if self.supported_inbound_protocols.contains_key(&protocol) { + return Err(AlreadyRegistered); + } + + let (sender, receiver) = mpsc::channel(0); + self.supported_inbound_protocols + .insert(protocol.clone(), sender); + + Ok(IncomingStreams::new(receiver)) + } + + /// Lists the protocols for which we have an active [`IncomingStreams`] instance. + pub(crate) fn supported_inbound_protocols(&mut self) -> Vec { + self.supported_inbound_protocols + .retain(|_, sender| !sender.is_closed()); + + self.supported_inbound_protocols.keys().cloned().collect() + } + + pub(crate) fn on_inbound_stream( + &mut self, + remote: PeerId, + stream: Stream, + protocol: StreamProtocol, + ) { + match self.supported_inbound_protocols.entry(protocol.clone()) { + Entry::Occupied(mut entry) => match entry.get_mut().try_send((remote, stream)) { + Ok(()) => {} + Err(e) if e.is_full() => { + tracing::debug!(%protocol, "Channel is full, dropping inbound stream"); + } + Err(e) if e.is_disconnected() => { + tracing::debug!(%protocol, "Channel is gone, dropping inbound stream"); + entry.remove(); + } + _ => unreachable!(), + }, + Entry::Vacant(_) => { + tracing::debug!(%protocol, "channel is gone, dropping inbound stream"); + } + } + } + + pub(crate) fn on_connection_established(&mut self, conn: ConnectionId, peer: PeerId) { + self.connections.insert(conn, peer); + } + + pub(crate) fn on_connection_closed(&mut self, conn: ConnectionId) { + self.connections.remove(&conn); + } + + pub(crate) fn on_dial_failure(&mut self, peer: PeerId, reason: String) { + let Some((_, mut receiver)) = self.pending_channels.remove(&peer) else { + return; + }; + + while let Ok(Some(new_stream)) = receiver.try_next() { + let _ = new_stream + .sender + .send(Err(crate::OpenStreamError::Io(io::Error::new( + io::ErrorKind::NotConnected, + reason.clone(), + )))); + } + } + + pub(crate) fn sender(&mut self, peer: PeerId) -> mpsc::Sender { + let maybe_sender = self + .connections + .iter() + .filter_map(|(c, p)| (p == &peer).then_some(c)) + .choose(&mut rand::thread_rng()) + .and_then(|c| self.senders.get(c)); + + match maybe_sender { + Some(sender) => { + tracing::debug!("Returning sender to existing connection"); + + sender.clone() + } + None => { + tracing::debug!(%peer, "Not connected to peer, initiating dial"); + + let (sender, _) = self + .pending_channels + .entry(peer) + .or_insert_with(|| mpsc::channel(0)); + + let _ = self.dial_sender.try_send(peer); + + sender.clone() + } + } + } + + pub(crate) fn receiver( + &mut self, + peer: PeerId, + connection: ConnectionId, + ) -> mpsc::Receiver { + if let Some((sender, receiver)) = self.pending_channels.remove(&peer) { + tracing::debug!(%peer, %connection, "Returning existing pending receiver"); + + self.senders.insert(connection, sender); + return receiver; + } + + tracing::debug!(%peer, %connection, "Creating new channel pair"); + + let (sender, receiver) = mpsc::channel(0); + self.senders.insert(connection, sender); + + receiver + } +} diff --git a/protocols/stream/src/upgrade.rs b/protocols/stream/src/upgrade.rs new file mode 100644 index 00000000000..ac9fb3ed992 --- /dev/null +++ b/protocols/stream/src/upgrade.rs @@ -0,0 +1,42 @@ +use std::future::{ready, Ready}; + +use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; +use libp2p_swarm::{Stream, StreamProtocol}; + +pub struct Upgrade { + pub(crate) supported_protocols: Vec, +} + +impl UpgradeInfo for Upgrade { + type Info = StreamProtocol; + + type InfoIter = std::vec::IntoIter; + + fn protocol_info(&self) -> Self::InfoIter { + self.supported_protocols.clone().into_iter() + } +} + +impl InboundUpgrade for Upgrade { + type Output = (Stream, StreamProtocol); + + type Error = void::Void; + + type Future = Ready>; + + fn upgrade_inbound(self, socket: Stream, info: Self::Info) -> Self::Future { + ready(Ok((socket, info))) + } +} + +impl OutboundUpgrade for Upgrade { + type Output = (Stream, StreamProtocol); + + type Error = void::Void; + + type Future = Ready>; + + fn upgrade_outbound(self, socket: Stream, info: Self::Info) -> Self::Future { + ready(Ok((socket, info))) + } +} diff --git a/protocols/stream/tests/lib.rs b/protocols/stream/tests/lib.rs new file mode 100644 index 00000000000..cd6caaced5e --- /dev/null +++ b/protocols/stream/tests/lib.rs @@ -0,0 +1,80 @@ +use std::io; + +use futures::{AsyncReadExt as _, AsyncWriteExt as _, StreamExt as _}; +use libp2p_identity::PeerId; +use libp2p_stream as stream; +use libp2p_swarm::{StreamProtocol, Swarm}; +use libp2p_swarm_test::SwarmExt as _; +use stream::OpenStreamError; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::EnvFilter; + +const PROTOCOL: StreamProtocol = StreamProtocol::new("/test"); + +#[tokio::test] +async fn dropping_incoming_streams_deregisters() { + let _ = tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::builder() + .with_default_directive(LevelFilter::DEBUG.into()) + .from_env() + .unwrap(), + ) + .with_test_writer() + .try_init(); + + let mut swarm1 = Swarm::new_ephemeral(|_| stream::Behaviour::new()); + let mut swarm2 = Swarm::new_ephemeral(|_| stream::Behaviour::new()); + + let mut control = swarm1.behaviour().new_control(); + let mut incoming = swarm2.behaviour().new_control().accept(PROTOCOL).unwrap(); + + swarm2.listen().with_memory_addr_external().await; + swarm1.connect(&mut swarm2).await; + + let swarm2_peer_id = *swarm2.local_peer_id(); + + let handle = tokio::spawn(async move { + while let Some((_, mut stream)) = incoming.next().await { + stream.write_all(&[42]).await.unwrap(); + stream.close().await.unwrap(); + } + }); + tokio::spawn(swarm1.loop_on_next()); + tokio::spawn(swarm2.loop_on_next()); + + let mut stream = control.open_stream(swarm2_peer_id, PROTOCOL).await.unwrap(); + + let mut buf = [0u8; 1]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!([42], buf); + + handle.abort(); + let _ = handle.await; + + let error = control + .open_stream(swarm2_peer_id, PROTOCOL) + .await + .unwrap_err(); + assert!(matches!(error, OpenStreamError::UnsupportedProtocol(_))); +} + +#[tokio::test] +async fn dial_errors_are_propagated() { + let swarm1 = Swarm::new_ephemeral(|_| stream::Behaviour::new()); + + let mut control = swarm1.behaviour().new_control(); + tokio::spawn(swarm1.loop_on_next()); + + let error = control + .open_stream(PeerId::random(), PROTOCOL) + .await + .unwrap_err(); + + let OpenStreamError::Io(e) = error else { + panic!("Unexpected error: {error}") + }; + + assert_eq!(e.kind(), io::ErrorKind::NotConnected); + assert_eq!("Dial error: no addresses for peer.", e.to_string()); +}