diff --git a/Cargo.lock b/Cargo.lock index f2c3c20bec8..cd78185085b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -445,6 +445,19 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "asynchronous-codec" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a860072022177f903e59730004fb5dc13db9275b79bb2aef7ba8ce831956c233" +dependencies = [ + "bytes", + "futures-sink", + "futures-util", + "memchr", + "pin-project-lite", +] + [[package]] name = "atomic-waker" version = "1.1.1" @@ -2480,7 +2493,7 @@ name = "libp2p-dcutr" version = "0.11.0" dependencies = [ "async-std", - "asynchronous-codec", + "asynchronous-codec 0.6.2", "clap", "either", "env_logger 0.10.0", @@ -2530,7 +2543,7 @@ dependencies = [ name = "libp2p-floodsub" version = "0.44.0" dependencies = [ - "asynchronous-codec", + "asynchronous-codec 0.6.2", "cuckoofilter", "fnv", "futures", @@ -2550,7 +2563,7 @@ name = "libp2p-gossipsub" version = "0.46.0" dependencies = [ "async-std", - "asynchronous-codec", + "asynchronous-codec 0.6.2", "base64 0.21.4", "byteorder", "bytes", @@ -2588,7 +2601,7 @@ name = "libp2p-identify" version = "0.44.0" dependencies = [ "async-std", - "asynchronous-codec", + "asynchronous-codec 0.6.2", "either", "env_logger 0.10.0", "futures", @@ -2642,7 +2655,7 @@ version = "0.45.0" dependencies = [ "arrayvec", "async-std", - "asynchronous-codec", + "asynchronous-codec 0.6.2", "bytes", "either", "env_logger 0.10.0", @@ -2737,7 +2750,7 @@ name = "libp2p-mplex" version = "0.41.0" dependencies = [ "async-std", - "asynchronous-codec", + "asynchronous-codec 0.6.2", "bytes", "criterion", "env_logger 0.10.0", @@ -2771,6 +2784,7 @@ dependencies = [ name = "libp2p-noise" version = "0.44.0" dependencies = [ + "asynchronous-codec 0.7.0", "bytes", "curve25519-dalek", "env_logger 0.10.0", @@ -2847,7 +2861,7 @@ dependencies = [ name = "libp2p-plaintext" version = "0.41.0" dependencies = [ - "asynchronous-codec", + "asynchronous-codec 0.6.2", "bytes", "env_logger 0.10.0", "futures", @@ -2915,7 +2929,7 @@ dependencies = [ name = "libp2p-relay" version = "0.17.0" dependencies = [ - "asynchronous-codec", + "asynchronous-codec 0.6.2", "bytes", "either", "env_logger 0.10.0", @@ -2944,7 +2958,7 @@ name = "libp2p-rendezvous" version = "0.14.0" dependencies = [ "async-trait", - "asynchronous-codec", + "asynchronous-codec 0.6.2", "bimap", "env_logger 0.10.0", "futures", @@ -3178,7 +3192,7 @@ dependencies = [ name = "libp2p-webrtc-utils" version = "0.1.0" dependencies = [ - "asynchronous-codec", + "asynchronous-codec 0.6.2", "bytes", "futures", "hex", @@ -4198,7 +4212,7 @@ dependencies = [ name = "quick-protobuf-codec" version = "0.2.0" dependencies = [ - "asynchronous-codec", + "asynchronous-codec 0.6.2", "bytes", "quick-protobuf", "thiserror", @@ -5866,7 +5880,7 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6889a77d49f1f013504cec6bf97a2c730394adedaeb1deb5ea08949a50541105" dependencies = [ - "asynchronous-codec", + "asynchronous-codec 0.6.2", "bytes", ] diff --git a/transports/noise/Cargo.toml b/transports/noise/Cargo.toml index 82ba697edcf..cea4743a72a 100644 --- a/transports/noise/Cargo.toml +++ b/transports/noise/Cargo.toml @@ -9,6 +9,7 @@ license = "MIT" repository = "https://github.com/libp2p/rust-libp2p" [dependencies] +asynchronous-codec = "0.7" bytes = "1" curve25519-dalek = "4.1.1" futures = "0.3.28" diff --git a/transports/noise/src/io.rs b/transports/noise/src/io.rs index ee184695696..c43e1dd67a1 100644 --- a/transports/noise/src/io.rs +++ b/transports/noise/src/io.rs @@ -22,8 +22,9 @@ mod framed; pub(crate) mod handshake; +use asynchronous_codec::Framed; use bytes::Bytes; -use framed::{NoiseFramed, MAX_FRAME_LEN}; +use framed::{Codec, MAX_FRAME_LEN}; use futures::prelude::*; use futures::ready; use log::trace; @@ -38,7 +39,7 @@ use std::{ /// /// `T` is the type of the underlying I/O resource. pub struct Output { - io: NoiseFramed, + io: Framed>, recv_buffer: Bytes, recv_offset: usize, send_buffer: Vec, @@ -47,12 +48,12 @@ pub struct Output { impl fmt::Debug for Output { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("NoiseOutput").field("io", &self.io).finish() + f.debug_struct("NoiseOutput").finish() } } impl Output { - fn new(io: NoiseFramed) -> Self { + fn new(io: Framed>) -> Self { Output { io, recv_buffer: Bytes::new(), diff --git a/transports/noise/src/io/framed.rs b/transports/noise/src/io/framed.rs index d7fa79fc815..739b0eea426 100644 --- a/transports/noise/src/io/framed.rs +++ b/transports/noise/src/io/framed.rs @@ -18,20 +18,18 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -//! This module provides a `Sink` and `Stream` for length-delimited -//! Noise protocol messages in form of [`NoiseFramed`]. +//! Provides a [`Codec`] type implementing the [`Encoder`] and [`Decoder`] traits. +//! +//! Alongside a [`asynchronous_codec::Framed`] this provides a [Sink](futures::Sink) +//! and [Stream](futures::Stream) for length-delimited Noise protocol messages. -use crate::io::Output; +use super::handshake::proto; use crate::{protocol::PublicKey, Error}; -use bytes::{Bytes, BytesMut}; -use futures::prelude::*; -use futures::ready; -use log::{debug, trace}; -use std::{ - fmt, io, - pin::Pin, - task::{Context, Poll}, -}; +use asynchronous_codec::{Decoder, Encoder}; +use bytes::{Buf, Bytes, BytesMut}; +use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer}; +use std::io; +use std::mem::size_of; /// Max. size of a noise message. const MAX_NOISE_MSG_LEN: usize = 65535; @@ -43,61 +41,49 @@ static_assertions::const_assert! { MAX_FRAME_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_MSG_LEN } -/// A `NoiseFramed` is a `Sink` and `Stream` for length-delimited -/// Noise protocol messages. -/// -/// `T` is the type of the underlying I/O resource and `S` the -/// type of the Noise session state. -pub(crate) struct NoiseFramed { - io: T, +/// Codec holds the noise session state `S` and acts as a medium for +/// encoding and decoding length-delimited session messages. +pub(crate) struct Codec { session: S, - read_state: ReadState, - write_state: WriteState, - read_buffer: Vec, - write_buffer: Vec, - decrypt_buffer: BytesMut, -} -impl fmt::Debug for NoiseFramed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("NoiseFramed") - .field("read_state", &self.read_state) - .field("write_state", &self.write_state) - .finish() - } + // We reuse write and encryption buffers across multiple messages to avoid reallocations. + // We cannot reuse read and decryption buffers because we cannot return borrowed data. + write_buffer: BytesMut, + encrypt_buffer: BytesMut, } -impl NoiseFramed { - /// Creates a nwe `NoiseFramed` for beginning a Noise protocol handshake. - pub(crate) fn new(io: T, state: snow::HandshakeState) -> Self { - NoiseFramed { - io, - session: state, - read_state: ReadState::Ready, - write_state: WriteState::Ready, - read_buffer: Vec::new(), - write_buffer: Vec::new(), - decrypt_buffer: BytesMut::new(), +impl Codec { + pub(crate) fn new(session: S) -> Self { + Codec { + session, + write_buffer: BytesMut::default(), + encrypt_buffer: BytesMut::default(), } } +} +impl Codec { + /// Checks if the session was started in the `initiator` role. pub(crate) fn is_initiator(&self) -> bool { self.session.is_initiator() } + /// Checks if the session was started in the `responder` role. pub(crate) fn is_responder(&self) -> bool { !self.session.is_initiator() } - /// Converts the `NoiseFramed` into a `NoiseOutput` encrypted data stream - /// once the handshake is complete, including the static DH [`PublicKey`] - /// of the remote, if received. + /// Converts the underlying Noise session from the [`snow::HandshakeState`] to a + /// [`snow::TransportState`] once the handshake is complete, including the static + /// DH [`PublicKey`] of the remote if received. /// - /// If the underlying Noise protocol session state does not permit - /// transitioning to transport mode because the handshake is incomplete, - /// an error is returned. Similarly if the remote's static DH key, if - /// present, cannot be parsed. - pub(crate) fn into_transport(self) -> Result<(PublicKey, Output), Error> { + /// If the Noise protocol session state does not permit transitioning to + /// transport mode because the handshake is incomplete, an error is returned. + /// + /// An error is also returned if the remote's static DH key is not present or + /// cannot be parsed, as that indicates a fatal handshake error for the noise + /// `XX` pattern, which is the only handshake protocol libp2p currently supports. + pub(crate) fn into_transport(self) -> Result<(PublicKey, Codec), Error> { let dh_remote_pubkey = self.session.get_remote_static().ok_or_else(|| { Error::Io(io::Error::new( io::ErrorKind::Other, @@ -106,355 +92,152 @@ impl NoiseFramed { })?; let dh_remote_pubkey = PublicKey::from_slice(dh_remote_pubkey)?; + let codec = Codec::new(self.session.into_transport_mode()?); - let io = NoiseFramed { - session: self.session.into_transport_mode()?, - io: self.io, - read_state: ReadState::Ready, - write_state: WriteState::Ready, - read_buffer: self.read_buffer, - write_buffer: self.write_buffer, - decrypt_buffer: self.decrypt_buffer, - }; - - Ok((dh_remote_pubkey, Output::new(io))) + Ok((dh_remote_pubkey, codec)) } } -/// The states for reading Noise protocol frames. -#[derive(Debug)] -enum ReadState { - /// Ready to read another frame. - Ready, - /// Reading frame length. - ReadLen { buf: [u8; 2], off: usize }, - /// Reading frame data. - ReadData { len: usize, off: usize }, - /// EOF has been reached (terminal state). - /// - /// The associated result signals if the EOF was unexpected or not. - Eof(Result<(), ()>), - /// A decryption error occurred (terminal state). - DecErr, -} +impl Encoder for Codec { + type Error = io::Error; + type Item<'a> = &'a proto::NoiseHandshakePayload; -/// The states for writing Noise protocol frames. -#[derive(Debug)] -enum WriteState { - /// Ready to write another frame. - Ready, - /// Writing the frame length. - WriteLen { - len: usize, - buf: [u8; 2], - off: usize, - }, - /// Writing the frame data. - WriteData { len: usize, off: usize }, - /// EOF has been reached unexpectedly (terminal state). - Eof, - /// An encryption error occurred (terminal state). - EncErr, -} + fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> { + let item_size = item.get_size(); -impl WriteState { - fn is_ready(&self) -> bool { - if let WriteState::Ready = self { - return true; - } - false - } -} + self.write_buffer.resize(item_size, 0); + let mut writer = Writer::new(&mut self.write_buffer[..item_size]); + item.write_message(&mut writer) + .expect("Protobuf encoding to succeed"); -impl futures::stream::Stream for NoiseFramed -where - T: AsyncRead + Unpin, - S: SessionState + Unpin, -{ - type Item = io::Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::into_inner(self); - loop { - trace!("read state: {:?}", this.read_state); - match this.read_state { - ReadState::Ready => { - this.read_state = ReadState::ReadLen { - buf: [0, 0], - off: 0, - }; - } - ReadState::ReadLen { mut buf, mut off } => { - let n = match read_frame_len(&mut this.io, cx, &mut buf, &mut off) { - Poll::Ready(Ok(Some(n))) => n, - Poll::Ready(Ok(None)) => { - trace!("read: eof"); - this.read_state = ReadState::Eof(Ok(())); - return Poll::Ready(None); - } - Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), - Poll::Pending => { - this.read_state = ReadState::ReadLen { buf, off }; - return Poll::Pending; - } - }; - trace!("read: frame len = {}", n); - if n == 0 { - trace!("read: empty frame"); - this.read_state = ReadState::Ready; - continue; - } - this.read_buffer.resize(usize::from(n), 0u8); - this.read_state = ReadState::ReadData { - len: usize::from(n), - off: 0, - } - } - ReadState::ReadData { len, ref mut off } => { - let n = { - let f = - Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off..len]); - match ready!(f) { - Ok(n) => n, - Err(e) => return Poll::Ready(Some(Err(e))), - } - }; - trace!("read: {}/{} bytes", *off + n, len); - if n == 0 { - trace!("read: eof"); - this.read_state = ReadState::Eof(Err(())); - return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); - } - *off += n; - if len == *off { - trace!("read: decrypting {} bytes", len); - this.decrypt_buffer.resize(len, 0); - if let Ok(n) = this - .session - .read_message(&this.read_buffer, &mut this.decrypt_buffer) - { - this.decrypt_buffer.truncate(n); - trace!("read: payload len = {} bytes", n); - this.read_state = ReadState::Ready; - // Return an immutable view into the current buffer. - // If the view is dropped before the next frame is - // read, the `BytesMut` will reuse the same buffer - // for the next frame. - let view = this.decrypt_buffer.split().freeze(); - return Poll::Ready(Some(Ok(view))); - } else { - debug!("read: decryption error"); - this.read_state = ReadState::DecErr; - return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))); - } - } - } - ReadState::Eof(Ok(())) => { - trace!("read: eof"); - return Poll::Ready(None); - } - ReadState::Eof(Err(())) => { - trace!("read: eof (unexpected)"); - return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); - } - ReadState::DecErr => { - return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))) - } - } - } + encrypt( + &self.write_buffer[..item_size], + dst, + &mut self.encrypt_buffer, + |item, buffer| self.session.write_message(item, buffer), + )?; + + Ok(()) } } -impl futures::sink::Sink<&Vec> for NoiseFramed -where - T: AsyncWrite + Unpin, - S: SessionState + Unpin, -{ +impl Decoder for Codec { type Error = io::Error; + type Item = proto::NoiseHandshakePayload; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let cleartext = match decrypt(src, |ciphertext, decrypt_buffer| { + self.session.read_message(ciphertext, decrypt_buffer) + })? { + None => return Ok(None), + Some(cleartext) => cleartext, + }; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::into_inner(self); - loop { - trace!("write state {:?}", this.write_state); - match this.write_state { - WriteState::Ready => { - return Poll::Ready(Ok(())); - } - WriteState::WriteLen { len, buf, mut off } => { - trace!("write: frame len ({}, {:?}, {}/2)", len, buf, off); - match write_frame_len(&mut this.io, cx, &buf, &mut off) { - Poll::Ready(Ok(true)) => (), - Poll::Ready(Ok(false)) => { - trace!("write: eof"); - this.write_state = WriteState::Eof; - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => { - this.write_state = WriteState::WriteLen { len, buf, off }; - return Poll::Pending; - } - } - this.write_state = WriteState::WriteData { len, off: 0 } - } - WriteState::WriteData { len, ref mut off } => { - let n = { - let f = - Pin::new(&mut this.io).poll_write(cx, &this.write_buffer[*off..len]); - match ready!(f) { - Ok(n) => n, - Err(e) => return Poll::Ready(Err(e)), - } - }; - if n == 0 { - trace!("write: eof"); - this.write_state = WriteState::Eof; - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - *off += n; - trace!("write: {}/{} bytes written", *off, len); - if len == *off { - trace!("write: finished with {} bytes", len); - this.write_state = WriteState::Ready; - } - } - WriteState::Eof => { - trace!("write: eof"); - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())), - } - } - } + let mut reader = BytesReader::from_bytes(&cleartext[..]); + let pb = + proto::NoiseHandshakePayload::from_reader(&mut reader, &cleartext).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "Failed decoding handshake payload", + ) + })?; - fn start_send(self: Pin<&mut Self>, frame: &Vec) -> Result<(), Self::Error> { - assert!(frame.len() <= MAX_FRAME_LEN); - let this = Pin::into_inner(self); - assert!(this.write_state.is_ready()); - - this.write_buffer - .resize(frame.len() + EXTRA_ENCRYPT_SPACE, 0u8); - match this - .session - .write_message(frame, &mut this.write_buffer[..]) - { - Ok(n) => { - trace!("write: cipher text len = {} bytes", n); - this.write_buffer.truncate(n); - this.write_state = WriteState::WriteLen { - len: n, - buf: u16::to_be_bytes(n as u16), - off: 0, - }; - Ok(()) - } - Err(e) => { - log::error!("encryption error: {:?}", e); - this.write_state = WriteState::EncErr; - Err(io::ErrorKind::InvalidData.into()) - } - } + Ok(Some(pb)) } +} + +impl Encoder for Codec { + type Error = io::Error; + type Item<'a> = &'a [u8]; - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().poll_ready(cx))?; - Pin::new(&mut self.io).poll_flush(cx) + fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> { + encrypt(item, dst, &mut self.encrypt_buffer, |item, buffer| { + self.session.write_message(item, buffer) + }) } +} + +impl Decoder for Codec { + type Error = io::Error; + type Item = Bytes; - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().poll_flush(cx))?; - Pin::new(&mut self.io).poll_close(cx) + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + decrypt(src, |ciphertext, decrypt_buffer| { + self.session.read_message(ciphertext, decrypt_buffer) + }) } } -/// A stateful context in which Noise protocol messages can be read and written. -pub(crate) trait SessionState { - fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result; - fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result; +/// Encrypts the given cleartext to `dst`. +/// +/// This is a standalone function to allow us reusing the `encrypt_buffer` and to use to across different session states of the noise protocol. +fn encrypt( + cleartext: &[u8], + dst: &mut BytesMut, + encrypt_buffer: &mut BytesMut, + encrypt_fn: impl FnOnce(&[u8], &mut [u8]) -> Result, +) -> io::Result<()> { + log::trace!("Encrypting {} bytes", cleartext.len()); + + encrypt_buffer.resize(cleartext.len() + EXTRA_ENCRYPT_SPACE, 0); + let n = encrypt_fn(cleartext, encrypt_buffer).map_err(into_io_error)?; + + log::trace!("Outgoing ciphertext has {n} bytes"); + + encode_length_prefixed(&encrypt_buffer[..n], dst); + + Ok(()) } -impl SessionState for snow::HandshakeState { - fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result { - self.read_message(msg, buf) - } +/// Encrypts the given ciphertext. +/// +/// This is a standalone function so we can use it across different session states of the noise protocol. +/// In case `ciphertext` does not contain enough bytes to decrypt the entire frame, `Ok(None)` is returned. +fn decrypt( + ciphertext: &mut BytesMut, + decrypt_fn: impl FnOnce(&[u8], &mut [u8]) -> Result, +) -> io::Result> { + let ciphertext = match decode_length_prefixed(ciphertext)? { + Some(b) => b, + None => return Ok(None), + }; - fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result { - self.write_message(msg, buf) - } + log::trace!("Incoming ciphertext has {} bytes", ciphertext.len()); + + let mut decrypt_buffer = BytesMut::zeroed(ciphertext.len()); + let n = decrypt_fn(&ciphertext, &mut decrypt_buffer).map_err(into_io_error)?; + + log::trace!("Decrypted cleartext has {n} bytes"); + + Ok(Some(decrypt_buffer.split_to(n).freeze())) } -impl SessionState for snow::TransportState { - fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result { - self.read_message(msg, buf) - } +fn into_io_error(err: snow::Error) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, err) +} - fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result { - self.write_message(msg, buf) - } +const U16_LENGTH: usize = size_of::(); + +fn encode_length_prefixed(src: &[u8], dst: &mut BytesMut) { + dst.reserve(U16_LENGTH + src.len()); + dst.extend_from_slice(&(src.len() as u16).to_be_bytes()); + dst.extend_from_slice(src); } -/// Read 2 bytes as frame length from the given source into the given buffer. -/// -/// Panics if `off >= 2`. -/// -/// When [`Poll::Pending`] is returned, the given buffer and offset -/// may have been updated (i.e. a byte may have been read) and must be preserved -/// for the next invocation. -/// -/// Returns `None` if EOF has been encountered. -fn read_frame_len( - mut io: &mut R, - cx: &mut Context<'_>, - buf: &mut [u8; 2], - off: &mut usize, -) -> Poll>> { - loop { - match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off..])) { - Ok(n) => { - if n == 0 { - return Poll::Ready(Ok(None)); - } - *off += n; - if *off == 2 { - return Poll::Ready(Ok(Some(u16::from_be_bytes(*buf)))); - } - } - Err(e) => { - return Poll::Ready(Err(e)); - } - } +fn decode_length_prefixed(src: &mut BytesMut) -> Result, io::Error> { + if src.len() < size_of::() { + return Ok(None); } -} -/// Write 2 bytes as frame length from the given buffer into the given sink. -/// -/// Panics if `off >= 2`. -/// -/// When [`Poll::Pending`] is returned, the given offset -/// may have been updated (i.e. a byte may have been written) and must -/// be preserved for the next invocation. -/// -/// Returns `false` if EOF has been encountered. -fn write_frame_len( - mut io: &mut W, - cx: &mut Context<'_>, - buf: &[u8; 2], - off: &mut usize, -) -> Poll> { - loop { - match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off..])) { - Ok(n) => { - if n == 0 { - return Poll::Ready(Ok(false)); - } - *off += n; - if *off == 2 { - return Poll::Ready(Ok(true)); - } - } - Err(e) => { - return Poll::Ready(Err(e)); - } - } + let mut len_bytes = [0u8; U16_LENGTH]; + len_bytes.copy_from_slice(&src[..U16_LENGTH]); + let len = u16::from_be_bytes(len_bytes) as usize; + + if src.len() - U16_LENGTH >= len { + // Skip the length header we already read. + src.advance(U16_LENGTH); + Ok(Some(src.split_to(len).freeze())) + } else { + Ok(None) } } diff --git a/transports/noise/src/io/handshake.rs b/transports/noise/src/io/handshake.rs index c853af7b189..7cc0f859e6e 100644 --- a/transports/noise/src/io/handshake.rs +++ b/transports/noise/src/io/handshake.rs @@ -20,23 +20,24 @@ //! Noise protocol handshake I/O. -mod proto { +pub(super) mod proto { #![allow(unreachable_pub)] include!("../generated/mod.rs"); pub use self::payload::proto::NoiseExtensions; pub use self::payload::proto::NoiseHandshakePayload; } -use crate::io::{framed::NoiseFramed, Output}; -use crate::protocol::{KeypairIdentity, STATIC_KEY_DOMAIN}; -use crate::{DecodeError, Error}; -use bytes::Bytes; +use super::framed::Codec; +use crate::io::Output; +use crate::protocol::{KeypairIdentity, PublicKey, STATIC_KEY_DOMAIN}; +use crate::Error; +use asynchronous_codec::Framed; use futures::prelude::*; use libp2p_identity as identity; use multihash::Multihash; -use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer}; +use quick_protobuf::MessageWrite; use std::collections::HashSet; -use std::io; +use std::{io, mem}; ////////////////////////////////////////////////////////////////////////////// // Internal @@ -44,7 +45,7 @@ use std::io; /// Handshake state. pub(crate) struct State { /// The underlying I/O resource. - io: NoiseFramed, + io: Framed>, /// The associated public identity of the local node's static DH keypair, /// which can be sent to the remote as part of an authenticated handshake. identity: KeypairIdentity, @@ -63,7 +64,10 @@ struct Extensions { webtransport_certhashes: HashSet>, } -impl State { +impl State +where + T: AsyncRead + AsyncWrite, +{ /// Initializes the state for a new Noise handshake, using the given local /// identity keypair and local DH static public key. The handshake messages /// will be sent and received on the given I/O resource and using the @@ -79,7 +83,7 @@ impl State { ) -> Self { Self { identity, - io: NoiseFramed::new(io, session), + io: Framed::new(io, Codec::new(session)), dh_remote_pubkey_sig: None, id_remote_pubkey: expected_remote_key, responder_webtransport_certhashes, @@ -88,12 +92,16 @@ impl State { } } -impl State { +impl State +where + T: AsyncRead + AsyncWrite, +{ /// Finish a handshake, yielding the established remote identity and the /// [`Output`] for communicating on the encrypted channel. pub(crate) fn finish(self) -> Result<(identity::PublicKey, Output), Error> { - let is_initiator = self.io.is_initiator(); - let (pubkey, io) = self.io.into_transport()?; + let is_initiator = self.io.codec().is_initiator(); + + let (pubkey, framed) = map_into_transport(self.io)?; let id_pk = self .id_remote_pubkey @@ -131,10 +139,34 @@ impl State { } } - Ok((id_pk, io)) + Ok((id_pk, Output::new(framed))) } } +/// Maps the provided [`Framed`] from the [`snow::HandshakeState`] into the [`snow::TransportState`]. +/// +/// This is a bit tricky because [`Framed`] cannot just be de-composed but only into its [`FramedParts`](asynchronous_codec::FramedParts). +/// However, we need to retain the original [`FramedParts`](asynchronous_codec::FramedParts) because they contain the active read & write buffers. +/// +/// Those are likely **not** empty because the remote may directly write to the stream again after the noise handshake finishes. +fn map_into_transport( + framed: Framed>, +) -> Result<(PublicKey, Framed>), Error> +where + T: AsyncRead + AsyncWrite, +{ + let mut parts = framed.into_parts().map_codec(Some); + + let (pubkey, codec) = mem::take(&mut parts.codec) + .expect("We just set it to `Some`") + .into_transport()?; + + let parts = parts.map_codec(|_| codec); + let framed = Framed::from_parts(parts); + + Ok((pubkey, framed)) +} + impl From for Extensions { fn from(value: proto::NoiseExtensions) -> Self { Extensions { @@ -151,14 +183,14 @@ impl From for Extensions { // Handshake Message Futures /// A future for receiving a Noise handshake message. -async fn recv(state: &mut State) -> Result +async fn recv(state: &mut State) -> Result where T: AsyncRead + Unpin, { match state.io.next().await { None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "eof").into()), Some(Err(e)) => Err(e.into()), - Some(Ok(m)) => Ok(m), + Some(Ok(p)) => Ok(p), } } @@ -167,12 +199,11 @@ pub(crate) async fn recv_empty(state: &mut State) -> Result<(), Error> where T: AsyncRead + Unpin, { - let msg = recv(state).await?; - if !msg.is_empty() { - return Err( - io::Error::new(io::ErrorKind::InvalidData, "Unexpected handshake payload.").into(), - ); + let payload = recv(state).await?; + if payload.get_size() != 0 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "Expected empty payload.").into()); } + Ok(()) } @@ -181,7 +212,10 @@ pub(crate) async fn send_empty(state: &mut State) -> Result<(), Error> where T: AsyncWrite + Unpin, { - state.io.send(&Vec::new()).await?; + state + .io + .send(&proto::NoiseHandshakePayload::default()) + .await?; Ok(()) } @@ -190,11 +224,7 @@ pub(crate) async fn recv_identity(state: &mut State) -> Result<(), Error> where T: AsyncRead + Unpin, { - let msg = recv(state).await?; - let mut reader = BytesReader::from_bytes(&msg[..]); - let pb = - proto::NoiseHandshakePayload::from_reader(&mut reader, &msg[..]).map_err(DecodeError)?; - + let pb = recv(state).await?; state.id_remote_pubkey = Some(identity::PublicKey::try_decode_protobuf(&pb.identity_key)?); if !pb.identity_sig.is_empty() { @@ -211,7 +241,7 @@ where /// Send a Noise handshake message with a payload identifying the local node to the remote. pub(crate) async fn send_identity(state: &mut State) -> Result<(), Error> where - T: AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, { let mut pb = proto::NoiseHandshakePayload { identity_key: state.identity.public.encode_protobuf(), @@ -221,7 +251,7 @@ where pb.identity_sig = state.identity.signature.clone(); // If this is the responder then send WebTransport certhashes to initiator, if any. - if state.io.is_responder() { + if state.io.codec().is_responder() { if let Some(ref certhashes) = state.responder_webtransport_certhashes { let ext = pb .extensions @@ -231,11 +261,7 @@ where } } - let mut msg = Vec::with_capacity(pb.get_size()); - - let mut writer = Writer::new(&mut msg); - pb.write_message(&mut writer).expect("Encoding to succeed"); - state.io.send(&msg).await?; + state.io.send(&pb).await?; Ok(()) } diff --git a/transports/noise/src/lib.rs b/transports/noise/src/lib.rs index 485f5d68155..7b1c8e27e8d 100644 --- a/transports/noise/src/lib.rs +++ b/transports/noise/src/lib.rs @@ -122,7 +122,7 @@ impl Config { self } - fn into_responder(self, socket: S) -> Result, Error> { + fn into_responder(self, socket: S) -> Result, Error> { let session = noise_params_into_builder( self.params, &self.prologue, @@ -142,7 +142,7 @@ impl Config { Ok(state) } - fn into_initiator(self, socket: S) -> Result, Error> { + fn into_initiator(self, socket: S) -> Result, Error> { let session = noise_params_into_builder( self.params, &self.prologue, @@ -188,7 +188,7 @@ where handshake::send_identity(&mut state).await?; handshake::recv_identity(&mut state).await?; - let (pk, io) = state.finish()?; + let (pk, io) = dbg!(state.finish())?; Ok((pk.to_peer_id(), io)) }