diff --git a/Cargo.lock b/Cargo.lock index 5c33dd3..7009ffd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -400,6 +400,7 @@ dependencies = [ "base64", "bytes", "criterion", + "futures", "http", "http-body-util", "hyper", @@ -412,6 +413,7 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", + "tokio-util", "trybuild", "utf-8", "webpki-roots", @@ -432,6 +434,20 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -439,6 +455,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -447,6 +464,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + [[package]] name = "futures-sink" version = "0.3.30" @@ -465,10 +488,15 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -1386,16 +1414,15 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" dependencies = [ "bytes", "futures-core", "futures-sink", "pin-project-lite", "tokio", - "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index d9d427e..b286716 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,8 @@ bytes = "1.5.0" axum-core = { version = "0.4.3", optional = true } http = { version = "1", optional = true } async-trait = { version = "0.1", optional = true } +tokio-util = { version = "0.7", features = ["codec", "io"] } +futures = { version = "0.3", default-features = false, features = ["std"] } [features] default = ["simd"] diff --git a/src/fragment.rs b/src/fragment.rs index 937eeaa..10e36b9 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -105,7 +105,7 @@ impl<'f, S> FragmentCollector { { loop { let (res, obligated_send) = - self.read_half.read_frame_inner(&mut self.stream).await; + self.read_half.read_frame(&mut self.stream).await; let is_closed = self.write_half.closed; if let Some(obligated_send) = obligated_send { if !is_closed { @@ -179,7 +179,7 @@ impl<'f, S> FragmentCollectorRead { { loop { let (res, obligated_send) = - self.read_half.read_frame_inner(&mut self.stream).await; + self.read_half.read_frame(&mut self.stream).await; if let Some(frame) = obligated_send { let res = send_fn(frame).await; res.map_err(|e| WebSocketError::SendError(e.into()))?; diff --git a/src/frame.rs b/src/frame.rs index 7178e8b..7126fbb 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -138,7 +138,7 @@ pub struct Frame<'f> { pub payload: Payload<'f>, } -const MAX_HEAD_SIZE: usize = 16; +pub(crate) const MAX_HEAD_SIZE: usize = 16; impl<'f> Frame<'f> { /// Creates a new WebSocket `Frame`. @@ -321,6 +321,9 @@ impl<'f> Frame<'f> { } /// Writes the frame to the buffer and returns a slice of the buffer containing the frame. 
+ /// + /// This function will NOT append the frame to the Vec, but rather replace the current bytes + /// with the frame's serialized bytes. pub fn write<'a>(&mut self, buf: &'a mut Vec) -> &'a [u8] { fn reserve_enough(buf: &mut Vec, len: usize) { if buf.len() < len { @@ -330,7 +333,7 @@ impl<'f> Frame<'f> { let len = self.payload.len(); reserve_enough(buf, len + MAX_HEAD_SIZE); - let size = self.fmt_head(buf); + let size = self.fmt_head(&mut *buf); buf[size..size + len].copy_from_slice(&self.payload); &buf[..size + len] } diff --git a/src/lib.rs b/src/lib.rs index ca1d5b6..46c2086 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,13 +167,25 @@ pub mod upgrade; use bytes::Buf; use bytes::BytesMut; +use frame::MAX_HEAD_SIZE; +use futures::task::AtomicWaker; +use std::collections::VecDeque; +use std::future::poll_fn; +use std::io::IoSlice; +use std::mem; +use std::ops::Deref; +use std::ops::DerefMut; +use std::pin::pin; +use std::sync::Arc; +use std::task::ready; +use std::task::Context; +use std::task::Poll; + #[cfg(feature = "unstable-split")] use std::future::Future; use tokio::io::AsyncRead; -use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; -use tokio::io::AsyncWriteExt; pub use crate::close::CloseCode; pub use crate::error::WebSocketError; @@ -185,6 +197,62 @@ pub use crate::frame::OpCode; pub use crate::frame::Payload; pub use crate::mask::unmask; +enum ContextKind { + /// Read is used when the cx is called from WebSocketRead. + Read, + /// Write is used when the cx is called from WebSocketWrite. + Write, +} + +// WakerDemux keeps track of whether the waker was called from a reader or a writer. +// +// This is important because the reader can also write, in order to reply to Ping or Close messages. +// If we didn't implement the WakerDemux the reader could hijack the writer's Waker and the writer's task +// would never get notified. +// +// Waking up the WakerDemux will wake the read and write tasks. 
+#[derive(Default)] +struct WakerDemux { + read_waker: AtomicWaker, + write_waker: AtomicWaker, +} + +impl futures::task::ArcWake for WakerDemux { + fn wake_by_ref(this: &Arc) { + this.read_waker.wake(); + this.write_waker.wake(); + } +} + +impl WakerDemux { + /// Set the Waker to the corresponding slot. + #[inline] + fn set_waker(&self, kind: ContextKind, waker: &futures::task::Waker) { + match kind { + ContextKind::Read => { + self.read_waker.register(waker); + } + ContextKind::Write => { + self.write_waker.register(waker); + } + } + } + + #[inline] + fn with_context(self: &Arc, f: F) -> R + where + F: FnOnce(&mut Context<'_>) -> R, + { + let waker = futures::task::waker_ref(&self); + let mut cx = Context::from_waker(&waker); + f(&mut cx) + } +} + +/// The role the connection is taking. +/// +/// When a server role is taken the frames will not be masked, unlike +/// the client role, in which frames are masked. #[derive(Copy, Clone, PartialEq)] pub enum Role { Server, @@ -198,6 +266,18 @@ pub(crate) struct WriteHalf { auto_apply_mask: bool, writev_threshold: usize, write_buffer: Vec, + // where in the write_buffer we should read from when writing to the stream + read_head: usize, + // only used with vectored writes. stores the frame payloads + payloads: VecDeque, +} + +struct WriteBuffer { + // where in the write_buffer this payload should be inserted + position: usize, + read_head: usize, + // TODO(dgrr): add a lifetime instead of using 'static? 
+ payload: Payload<'static>, } pub(crate) struct ReadHalf { @@ -207,16 +287,39 @@ pub(crate) struct ReadHalf { auto_pong: bool, writev_threshold: usize, max_message_size: usize, + read_state: Option, buffer: BytesMut, } +struct Header { + fin: bool, + masked: bool, + opcode: OpCode, + extra: usize, + length_code: u8, + header_size: usize, +} + +struct HeaderAndMask { + header: Header, + mask: Option<[u8; 4]>, + payload_len: usize, +} + +enum ReadState { + Header(Header), + Payload(HeaderAndMask), +} + #[cfg(feature = "unstable-split")] +/// Read end of a WebSocket connection. pub struct WebSocketRead { stream: S, read_half: ReadHalf, } #[cfg(feature = "unstable-split")] +/// Write end of a WebSocket connection. pub struct WebSocketWrite { stream: S, write_half: WriteHalf, @@ -297,7 +400,7 @@ impl<'f, S> WebSocketRead { { loop { let (res, obligated_send) = - self.read_half.read_frame_inner(&mut self.stream).await; + self.read_half.read_frame(&mut self.stream).await; if let Some(frame) = obligated_send { let res = send_fn(frame).await; res.map_err(|e| WebSocketError::SendError(e.into()))?; @@ -307,6 +410,18 @@ impl<'f, S> WebSocketRead { } } } + + /// Reads a frame from the stream. + #[inline(always)] + pub fn poll_read_frame( + &mut self, + cx: &mut Context<'_>, + ) -> Poll<(Result, WebSocketError>, Option)> + where + S: AsyncRead + Unpin, + { + self.read_half.poll_read_frame(&mut self.stream, cx) + } } #[cfg(feature = "unstable-split")] @@ -329,10 +444,12 @@ impl<'f, S> WebSocketWrite { self.write_half.auto_apply_mask = auto_apply_mask; } + /// Returns whether the connection was closed or not. pub fn is_closed(&self) -> bool { self.write_half.closed } + /// Sends a frame. pub async fn write_frame( &mut self, frame: Frame<'f>, @@ -342,13 +459,79 @@ impl<'f, S> WebSocketWrite { { self.write_half.write_frame(&mut self.stream, frame).await } + + /// Serializes the frame into the internal buffer and tries to flush the contents. 
+ /// + /// If the function returns Poll::Pending, the user needs to call poll_flush. + pub fn poll_write_frame( + &mut self, + cx: &mut Context<'_>, + frame: Frame<'f>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + self.write_half.start_send_frame(frame)?; + self.write_half.poll_flush(&mut self.stream, cx) + } +} + +/// Keep track of the state of the Stream +enum StreamState { + // reading from Stream + Reading(S), + // flushing obligated send + Flushing(S), + // keep the stream here just in case the user wants to access to it + Closed(S), + // used temporarily + None, +} + +impl Deref for StreamState { + type Target = S; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + match self { + StreamState::Reading(stream) => stream, + StreamState::Flushing(stream) => stream, + StreamState::Closed(stream) => stream, + StreamState::None => unreachable!(), + } + } +} + +impl DerefMut for StreamState { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + StreamState::Reading(stream) => stream, + StreamState::Flushing(stream) => stream, + StreamState::Closed(stream) => stream, + StreamState::None => unreachable!(), + } + } +} + +impl StreamState { + #[inline(always)] + fn into_inner(self) -> S { + match self { + StreamState::Reading(stream) => stream, + StreamState::Flushing(stream) => stream, + StreamState::Closed(stream) => stream, + StreamState::None => unreachable!(), + } + } } /// WebSocket protocol implementation over an async stream. 
pub struct WebSocket { - stream: S, + stream: StreamState, write_half: WriteHalf, read_half: ReadHalf, + waker: Arc, } impl<'f, S> WebSocket { @@ -375,17 +558,19 @@ impl<'f, S> WebSocket { where S: AsyncRead + AsyncWrite + Unpin, { + let waker = Arc::new(WakerDemux::default()); Self { - stream, + waker, + stream: StreamState::Reading(stream), write_half: WriteHalf::after_handshake(role), read_half: ReadHalf::after_handshake(role), } } + #[cfg(feature = "unstable-split")] /// Split a [`WebSocket`] into a [`WebSocketRead`] and [`WebSocketWrite`] half. Note that the split version does not /// handle fragmented packets and you may wish to create a [`FragmentCollectorRead`] over top of the read half that /// is returned. - #[cfg(feature = "unstable-split")] pub fn split( self, split_fn: impl Fn(S) -> (R, W), @@ -413,13 +598,13 @@ impl<'f, S> WebSocket { #[inline] pub fn into_inner(self) -> S { // self.write_half.into_inner().stream - self.stream + self.stream.into_inner() } /// Consumes the `WebSocket` and returns the underlying stream. #[inline] pub(crate) fn into_parts_internal(self) -> (S, ReadHalf, WriteHalf) { - (self.stream, self.read_half, self.write_half) + (self.stream.into_inner(), self.read_half, self.write_half) } /// Sets whether to use vectored writes. This option does not guarantee that vectored writes will be always used. @@ -463,6 +648,7 @@ impl<'f, S> WebSocket { self.write_half.auto_apply_mask = auto_apply_mask; } + /// Returns whether the connection is closed or not. pub fn is_closed(&self) -> bool { self.write_half.closed } @@ -491,10 +677,61 @@ impl<'f, S> WebSocket { where S: AsyncRead + AsyncWrite + Unpin, { - self.write_half.write_frame(&mut self.stream, frame).await?; + self + .write_half + .write_frame(self.stream.deref_mut(), frame) + .await?; Ok(()) } + /// Serializes a frame into the internal buffer. + /// + /// This method is similar to [Sink::start_send](https://docs.rs/futures/0.3.30/futures/sink/trait.Sink.html#tymethod.start_send). 
+ pub fn start_send_frame( + &mut self, + frame: Frame<'f>, + ) -> Result<(), WebSocketError> + where + S: AsyncWrite + Unpin, + { + self.write_half.start_send_frame(frame) + } + + /// Serializes a frame into the internal buffer. + /// + /// Beware of the internal buffer. If the other end of the connection is not consuming fast enough it might fill fast. + /// + /// This method is similar to [Sink::start_send](https://docs.rs/futures/0.3.30/futures/sink/trait.Sink.html#tymethod.start_send). + #[inline(always)] + pub fn poll_write_frame( + &mut self, + cx: &mut Context<'_>, + frame: Frame<'f>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + self.write_half.start_send_frame(frame)?; + self.poll_flush(cx) + } + + /// Flushes the internal buffer into the Stream. + /// + /// Returns Poll::Ready(Ok(())) when no more bytes are left. + #[inline(always)] + pub fn poll_flush( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + self.waker.set_waker(ContextKind::Write, cx.waker()); + self.waker.with_context(|cx| { + self.write_half.poll_flush(self.stream.deref_mut(), cx) + }) + } + /// Reads a frame from the stream. /// /// This method will unmask the frame payload. For fragmented frames, use `FragmentCollector::read_frame`. @@ -521,37 +758,101 @@ impl<'f, S> WebSocket { /// Ok(()) /// } /// ``` + #[inline(always)] pub async fn read_frame(&mut self) -> Result, WebSocketError> + where + S: AsyncRead + AsyncWrite + Unpin, + { + poll_fn(|cx| self.poll_read_frame(cx)).await + } + + /// Polls the next frame from the Stream. 
+  ///
+  /// Frames that have to be answered (e.g. the Pong for a Ping or the echo of
+  /// a Close) are queued on the write half and flushed by this method before
+  /// it resumes reading. Once the stream has reached the closed state every
+  /// call returns `WebSocketError::ConnectionClosed`.
+  pub fn poll_read_frame(
+    &mut self,
+    cx: &mut Context<'_>,
+  ) -> Poll<Result<Frame<'f>, WebSocketError>>
   where
     S: AsyncRead + AsyncWrite + Unpin,
   {
     loop {
-      let (res, obligated_send) =
-        self.read_half.read_frame_inner(&mut self.stream).await;
-      let is_closed = self.write_half.closed;
-      if let Some(frame) = obligated_send {
-        if !is_closed {
-          self.write_half.write_frame(&mut self.stream, frame).await?;
+      match mem::replace(&mut self.stream, StreamState::None) {
+        StreamState::None => unreachable!(),
+        StreamState::Reading(mut stream) => {
+          let (res, obligated_send) =
+            match self.read_half.poll_read_frame(&mut stream, cx) {
+              Poll::Ready(res) => res,
+              Poll::Pending => {
+                self.stream = StreamState::Reading(stream);
+                break Poll::Pending;
+              }
+            };
+
+          let is_closed = self.write_half.closed;
+
+          macro_rules! try_send_obligated {
+            () => {
+              if let Some(frame) = obligated_send {
+                // if the write half didn't emit the close frame
+                if !is_closed {
+                  // hand the stream back before the fallible call: an early
+                  // return must never leave `self.stream` as `None`
+                  self.stream = StreamState::Flushing(stream);
+                  self.write_half.start_send_frame(frame)?;
+                } else {
+                  self.stream = StreamState::Reading(stream);
+                }
+              } else {
+                self.stream = StreamState::Reading(stream);
+              }
+            };
+          }
+
+          let res = match res {
+            Ok(res) => res,
+            Err(e) => {
+              // restore the stream (queueing the reply that goes with the
+              // error, if any) so later calls and `into_inner` never see the
+              // `None` placeholder
+              try_send_obligated!();
+              break Poll::Ready(Err(e));
+            }
+          };
+
+          if let Some(frame) = res {
+            if is_closed && frame.opcode != OpCode::Close {
+              self.stream = StreamState::Closed(stream);
+              break Poll::Ready(Err(WebSocketError::ConnectionClosed));
+            }
+
+            try_send_obligated!();
+            break Poll::Ready(Ok(frame));
+          }
+
+          try_send_obligated!();
         }
-      }
-      if let Some(frame) = res? {
-        if is_closed && frame.opcode != OpCode::Close {
-          return Err(WebSocketError::ConnectionClosed);
+        StreamState::Flushing(mut stream) => {
+          self.waker.set_waker(ContextKind::Read, cx.waker());
+
+          let res = self
+            .waker
+            .with_context(|cx| self.write_half.poll_flush(&mut stream, cx));
+          match res {
+            Poll::Ready(ok) => {
+              self.stream = if self.is_closed() {
+                StreamState::Closed(stream)
+              } else {
+                StreamState::Reading(stream)
+              };
+              ok?;
+            }
+            Poll::Pending => {
+              self.stream = StreamState::Flushing(stream);
+              // the waker is registered; yield instead of spinning on the flush
+              break Poll::Pending;
+            }
+          }
+        }
+        StreamState::Closed(stream) => {
+          self.stream = StreamState::Closed(stream);
+          break Poll::Ready(Err(WebSocketError::ConnectionClosed));
         }
-        break Ok(frame);
       }
     }
   }
 }
 
-const MAX_HEADER_SIZE: usize = 14;
-
 impl ReadHalf {
   pub fn after_handshake(role: Role) -> Self {
     let buffer = BytesMut::with_capacity(8192);
 
     Self {
       role,
+      read_state: None,
       auto_apply_mask: true,
       auto_close: true,
       auto_pong: true,
@@ -567,16 +868,29 @@ impl ReadHalf {
   /// has been closed.
   ///
   /// XXX: Do not expose this method to the public API.
-  pub(crate) async fn read_frame_inner<'f, S>(
+  #[inline(always)]
+  pub(crate) async fn read_frame<'f, S>(
     &mut self,
     stream: &mut S,
   ) -> (Result<Option<Frame<'f>>, WebSocketError>, Option<Frame<'f>>)
   where
     S: AsyncRead + Unpin,
   {
-    let mut frame = match self.parse_frame_header(stream).await {
+    poll_fn(|cx| self.poll_read_frame(stream, cx)).await
+  }
+
+  /// Reads a frame from the Stream.
+ pub(crate) fn poll_read_frame<'f, S>( + &mut self, + stream: &mut S, + cx: &mut Context<'_>, + ) -> Poll<(Result>, WebSocketError>, Option>)> + where + S: AsyncRead + Unpin, + { + let mut frame = match ready!(self.poll_parse_frame_header(stream, cx)) { Ok(frame) => frame, - Err(e) => return (Err(e), None), + Err(e) => return Poll::Ready((Err(e), None)), }; if self.role == Role::Server && self.auto_apply_mask { @@ -587,7 +901,9 @@ impl ReadHalf { OpCode::Close if self.auto_close => { match frame.payload.len() { 0 => {} - 1 => return (Err(WebSocketError::InvalidCloseFrame), None), + 1 => { + return Poll::Ready((Err(WebSocketError::InvalidCloseFrame), None)) + } _ => { let code = close::CloseCode::from(u16::from_be_bytes( frame.payload[0..2].try_into().unwrap(), @@ -595,130 +911,187 @@ impl ReadHalf { #[cfg(feature = "simd")] if simdutf8::basic::from_utf8(&frame.payload[2..]).is_err() { - return (Err(WebSocketError::InvalidUTF8), None); + return Poll::Ready((Err(WebSocketError::InvalidUTF8), None)); }; #[cfg(not(feature = "simd"))] if std::str::from_utf8(&frame.payload[2..]).is_err() { - return (Err(WebSocketError::InvalidUTF8), None); + return Poll::Ready((Err(WebSocketError::InvalidUTF8), None)); }; if !code.is_allowed() { - return ( + return Poll::Ready(( Err(WebSocketError::InvalidCloseCode), Some(Frame::close(1002, &frame.payload[2..])), - ); + )); } } }; let obligated_send = Frame::close_raw(frame.payload.to_owned().into()); - (Ok(Some(frame)), Some(obligated_send)) + Poll::Ready((Ok(Some(frame)), Some(obligated_send))) } OpCode::Ping if self.auto_pong => { - (Ok(None), Some(Frame::pong(frame.payload))) + Poll::Ready((Ok(None), Some(Frame::pong(frame.payload)))) } OpCode::Text => { if frame.fin && !frame.is_utf8() { - (Err(WebSocketError::InvalidUTF8), None) + Poll::Ready((Err(WebSocketError::InvalidUTF8), None)) } else { - (Ok(Some(frame)), None) + Poll::Ready((Ok(Some(frame)), None)) } } - _ => (Ok(Some(frame)), None), + _ => Poll::Ready((Ok(Some(frame)), 
None)), } } - async fn parse_frame_header<'a, S>( + /// Reads a frame from the Stream parsing the headers. + fn poll_parse_frame_header<'a, S>( &mut self, stream: &mut S, - ) -> Result, WebSocketError> + cx: &mut Context<'_>, + ) -> Poll, WebSocketError>> where S: AsyncRead + Unpin, { - macro_rules! eof { - ($n:expr) => {{ - if $n == 0 { - return Err(WebSocketError::UnexpectedEOF); + macro_rules! read_next { + ($variant:expr,$value:expr) => {{ + let bytes_read = match tokio_util::io::poll_read_buf( + pin!(&mut *stream), + cx, + &mut self.buffer, + ) { + Poll::Ready(ready) => ready, + Poll::Pending => { + self.read_state = Some($variant($value)); + return Poll::Pending; + } + }?; + if bytes_read == 0 { + return Poll::Ready(Err(WebSocketError::UnexpectedEOF)); } }}; } - // Read the first two bytes - while self.buffer.remaining() < 2 { - eof!(stream.read_buf(&mut self.buffer).await?); - } - - let fin = self.buffer[0] & 0b10000000 != 0; - let rsv1 = self.buffer[0] & 0b01000000 != 0; - let rsv2 = self.buffer[0] & 0b00100000 != 0; - let rsv3 = self.buffer[0] & 0b00010000 != 0; - - if rsv1 || rsv2 || rsv3 { - return Err(WebSocketError::ReservedBitsNotZero); - } - - let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?; - let masked = self.buffer[1] & 0b10000000 != 0; + loop { + match self.read_state.take() { + None => { + // Read the first two bytes + while self.buffer.remaining() < 2 { + let bytes_read = ready!(tokio_util::io::poll_read_buf( + pin!(&mut *stream), + cx, + &mut self.buffer + ))?; + if bytes_read == 0 { + return Poll::Ready(Err(WebSocketError::UnexpectedEOF)); + } + } - let length_code = self.buffer[1] & 0x7F; - let extra = match length_code { - 126 => 2, - 127 => 8, - _ => 0, - }; + let fin = self.buffer[0] & 0b10000000 != 0; + let rsv1 = self.buffer[0] & 0b01000000 != 0; + let rsv2 = self.buffer[0] & 0b00100000 != 0; + let rsv3 = self.buffer[0] & 0b00010000 != 0; - self.buffer.advance(2); - while self.buffer.remaining() < extra + masked as 
usize * 4 { - eof!(stream.read_buf(&mut self.buffer).await?); - } + if rsv1 || rsv2 || rsv3 { + return Poll::Ready(Err(WebSocketError::ReservedBitsNotZero)); + } - let payload_len: usize = match extra { - 0 => usize::from(length_code), - 2 => self.buffer.get_u16() as usize, - #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))] - 8 => self.buffer.get_u64() as usize, - // On 32bit systems, usize is only 4bytes wide so we must check for usize overflowing - #[cfg(any( - target_pointer_width = "8", - target_pointer_width = "16", - target_pointer_width = "32" - ))] - 8 => match usize::try_from(self.buffer.get_u64()) { - Ok(length) => length, - Err(_) => return Err(WebSocketError::FrameTooLarge), - }, - _ => unreachable!(), - }; + let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?; + let masked = self.buffer[1] & 0b10000000 != 0; + + let length_code = self.buffer[1] & 0x7F; + let extra = match length_code { + 126 => 2, + 127 => 8, + _ => 0, + }; + + let header_size = extra + masked as usize * 4; + self.buffer.advance(2); + + self.read_state = Some(ReadState::Header(Header { + fin, + masked, + opcode, + length_code, + extra, + header_size, + })); + } + Some(ReadState::Header(header)) => { + // total header size + while self.buffer.remaining() < header.header_size { + read_next!(ReadState::Header, header); + } - let mask = if masked { - Some(self.buffer.get_u32().to_be_bytes()) - } else { - None - }; + let payload_len: usize = match header.extra { + 0 => usize::from(header.length_code), + 2 => self.buffer.get_u16() as usize, + #[cfg(any( + target_pointer_width = "64", + target_pointer_width = "128" + ))] + 8 => self.buffer.get_u64() as usize, + // On 32bit systems, usize is only 4bytes wide so we must check for usize overflowing + #[cfg(any( + target_pointer_width = "8", + target_pointer_width = "16", + target_pointer_width = "32" + ))] + 8 => match usize::try_from(self.buffer.get_u64()) { + Ok(length) => length, + Err(_) => return 
Poll::Ready(Err(WebSocketError::FrameTooLarge)),
+            },
+            _ => unreachable!(),
+          };
+
+          let mask = if header.masked {
+            Some(self.buffer.get_u32().to_be_bytes())
+          } else {
+            None
+          };
+
+          if frame::is_control(header.opcode) && !header.fin {
+            return Poll::Ready(Err(WebSocketError::ControlFrameFragmented));
+          }
 
-    if frame::is_control(opcode) && !fin {
-      return Err(WebSocketError::ControlFrameFragmented);
-    }
+          if header.opcode == OpCode::Ping && payload_len > 125 {
+            return Poll::Ready(Err(WebSocketError::PingFrameTooLarge));
+          }
 
-    if opcode == OpCode::Ping && payload_len > 125 {
-      return Err(WebSocketError::PingFrameTooLarge);
-    }
+          if payload_len >= self.max_message_size {
+            return Poll::Ready(Err(WebSocketError::FrameTooLarge));
+          }
 
-    if payload_len >= self.max_message_size {
-      return Err(WebSocketError::FrameTooLarge);
-    }
+          self.read_state = Some(ReadState::Payload(HeaderAndMask {
+            header,
+            mask,
+            payload_len,
+          }));
+        }
+        Some(ReadState::Payload(header_and_mask)) => {
+          // Reserve a bit more to try to get next frame header and avoid a syscall to read it next time
+          // (14 = largest frame header: 2 fixed bytes + 8 length bytes + 4 mask bytes)
+          self.buffer.reserve(header_and_mask.payload_len + 14);
+          while self.buffer.remaining() < header_and_mask.payload_len {
+            read_next!(ReadState::Payload, header_and_mask);
+          }
 
-    // Reserve a bit more to try to get next frame header and avoid a syscall to read it next time
-    self.buffer.reserve(payload_len + MAX_HEADER_SIZE);
-    while payload_len > self.buffer.remaining() {
-      eof!(stream.read_buf(&mut self.buffer).await?);
+          let header = header_and_mask.header;
+          let mask = header_and_mask.mask;
+          let payload_len = header_and_mask.payload_len;
+
+          // if we read too much it will stay in the buffer, for the next call to this method
+          let payload = self.buffer.split_to(payload_len);
+          let frame = Frame::new(
+            header.fin,
+            header.opcode,
+            mask,
+            Payload::Bytes(payload),
+          );
+          break Poll::Ready(Ok(frame));
+        }
+      }
     }
-
-    // if we read too much it will stay in the buffer, for the next call to this method
-    let payload = self.buffer.split_to(payload_len);
-    let frame = Frame::new(fin, opcode, mask, 
Payload::Bytes(payload)); - Ok(frame) } } @@ -730,7 +1103,9 @@ impl WriteHalf { auto_apply_mask: true, vectored: true, writev_threshold: 1024, - write_buffer: Vec::with_capacity(2), + read_head: 0, + write_buffer: Vec::with_capacity(1024), + payloads: VecDeque::with_capacity(1), } } @@ -738,11 +1113,60 @@ impl WriteHalf { pub async fn write_frame<'a, S>( &'a mut self, stream: &mut S, - mut frame: Frame<'a>, + frame: Frame<'a>, ) -> Result<(), WebSocketError> where S: AsyncWrite + Unpin, { + // maybe_frame determines the state. + // If a frame is present we need to poll_ready, else flush it. + let mut maybe_frame = Some(frame); + poll_fn(|cx| loop { + match maybe_frame.take() { + Some(frame) => match self.poll_ready(stream, cx) { + Poll::Ready(res) => { + res?; + self.start_send_frame(frame)?; + } + Poll::Pending => { + maybe_frame = Some(frame); + return Poll::Pending; + } + }, + None => { + return self.poll_flush(stream, cx); + } + } + }) + .await + } + + /// Ensures that the underlying connection is ready. It will try to flush the contents if any. + /// + /// If you prefer to buffer requests as much as possible you can skip this step, generally and + /// call start_send_frame. + pub fn poll_ready( + &mut self, + stream: &mut S, + cx: &mut Context<'_>, + ) -> Poll> + where + S: AsyncWrite + Unpin, + { + while self.read_head < self.write_buffer.len() || !self.payloads.is_empty() + { + ready!(self.write(stream, cx))?; + } + + Poll::Ready(Ok(())) + } + + pub fn start_send_frame<'a>( + &'a mut self, + mut frame: Frame<'a>, + ) -> Result<(), WebSocketError> { + // TODO(dario): backpressure check? 
tokio codec does it
+
     if self.role == Role::Client && self.auto_apply_mask {
       frame.mask();
     }
@@ -753,19 +1177,106 @@ impl WriteHalf {
       return Err(WebSocketError::ConnectionClosed);
     }
 
-    if self.vectored && frame.payload.len() > self.writev_threshold {
-      frame.writev(stream).await?;
-    } else {
-      let text = frame.write(&mut self.write_buffer);
-      stream.write_all(text).await?;
+    // TODO(dgrr): Cap max payload size with a user setting?
+
+    let payload_len = frame.payload.len();
+    let max_len = payload_len + MAX_HEAD_SIZE;
+    if self.write_buffer.len() + max_len > self.write_buffer.capacity() {
+      // if the len we need for this frame will require a realloc, let's clear the written head of the buffer
+      let consumed = self.read_head;
+      self.write_buffer.splice(0..consumed, [0u8; 0]);
+      // queued vectored payloads remember an absolute offset into
+      // write_buffer, so they have to move back together with the data
+      // (position >= read_head always holds for queued payloads)
+      for pending in self.payloads.iter_mut() {
+        pending.position -= consumed;
+      }
+      self.read_head = 0;
+      self.write_buffer.reserve(max_len);
+    }
+    // resize the buffer so we have room to write the head
+    let current_len = self.write_buffer.len();
+    self.write_buffer.resize(current_len + MAX_HEAD_SIZE, 0);
+
+    let buf = &mut self.write_buffer[current_len..];
+    let size = frame.fmt_head(buf);
+    // drop the part of the reserved head room the header did not use
+    self.write_buffer.truncate(current_len + size);
+
+    let vectored = self.vectored && frame.payload.len() > self.writev_threshold;
+    match frame.payload {
+      // large owned payloads are not copied: they are queued next to the
+      // offset in write_buffer at which they have to be sent
+      Payload::Owned(b) if vectored => self.payloads.push_back(WriteBuffer {
+        position: self.write_buffer.len(),
+        read_head: 0,
+        payload: Payload::Owned(b),
+      }),
+      Payload::Bytes(b) if vectored => self.payloads.push_back(WriteBuffer {
+        position: self.write_buffer.len(),
+        read_head: 0,
+        payload: Payload::Bytes(b),
+      }),
+      _ => {
+        self.write_buffer.extend_from_slice(&frame.payload);
+      }
     }
 
     Ok(())
   }
+
+  /// Writes everything that is still pending and then flushes the stream.
+  pub fn poll_flush<'a, S>(
+    &'a mut self,
+    stream: &mut S,
+    cx: &mut Context<'_>,
+  ) -> Poll<Result<(), WebSocketError>>
+  where
+    S: AsyncWrite + Unpin,
+  {
+    ready!(self.poll_ready(stream, cx))?;
+
+    // flush the stream
+    pin!(&mut *stream).poll_flush(cx).map_err(Into::into)
+  }
+
+  /// Performs a single write to the stream: the unsent part of `write_buffer`
+  /// or, when a vectored payload is queued, the buffered bytes that precede
+  /// it together with that payload.
+  fn write<S>(
+    &mut self,
+    stream: &mut S,
+    cx: &mut Context<'_>,
+  ) -> Poll<Result<(), WebSocketError>>
+  where
+    S: AsyncWrite + 
Unpin,
+  {
+    let written = if let Some(front) = self.payloads.front_mut() {
+      let b = [
+        IoSlice::new(&self.write_buffer[self.read_head..front.position]),
+        // resume after whatever an earlier partial write already sent
+        IoSlice::new(&front.payload[front.read_head..]),
+      ];
+
+      let written = ready!(pin!(&mut *stream).poll_write_vectored(cx, &b))?;
+
+      if written < b[0].len() {
+        self.read_head += written;
+      } else {
+        // bytes beyond the buffered part belong to the queued payload
+        let written = written - b[0].len();
+        self.read_head = front.position;
+        front.read_head += written;
+        if front.read_head == front.payload.len() {
+          self.payloads.pop_front();
+        }
+      }
+
+      written
+    } else {
+      let written =
+        ready!(pin!(&mut *stream)
+          .poll_write(cx, &self.write_buffer[self.read_head..]))?;
+      self.read_head += written;
+      written
+    };
+
+    // a write that accepts nothing is reported as a closed connection
+    if written == 0 {
+      return Poll::Ready(Err(WebSocketError::ConnectionClosed));
+    }
+
+    Poll::Ready(Ok(()))
+  }
 }
 
 #[cfg(test)]
 mod tests {
+  use std::ops::Deref;
+
   use super::*;
 
   const _: () = {
@@ -792,4 +1303,93 @@ mod tests {
     }
     assert_unsync::<WebSocket<tokio::net::TcpStream>>();
   };
+
+  #[tokio::test]
+  async fn test_contiguous_simple_and_vectored_writes() {
+    struct MockStream(Vec<u8>);
+
+    impl AsyncRead for MockStream {
+      fn poll_read(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+      ) -> Poll<std::io::Result<()>> {
+        let this = self.get_mut();
+        if this.0.is_empty() {
+          return Poll::Ready(Ok(()));
+        }
+
+        let size_before = buf.filled().len();
+        buf.put_slice(&this.0);
+        let diff = buf.filled().len() - size_before;
+
+        this.0.drain(..diff);
+
+        Poll::Ready(Ok(()))
+      }
+    }
+
+    impl AsyncWrite for MockStream {
+      fn poll_write(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &[u8],
+      ) -> Poll<std::io::Result<usize>> {
+        self.get_mut().0.extend_from_slice(buf);
+        Poll::Ready(Ok(buf.len()))
+      }
+
+      fn poll_flush(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+      ) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+      }
+
+      fn poll_shutdown(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+      ) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+      }
+    }
+
+    let simple_string = b"1234".to_vec();
+    // copy this 
string more than 1024 times to trigger the vector writes + let long_string = b"A".repeat(1025); + + let mut stream = MockStream(vec![]); + let mut write_half = super::WriteHalf::after_handshake(Role::Server); + let mut read_half = super::ReadHalf::after_handshake(Role::Client); + + poll_fn(|cx| { + // write + assert!(write_half.poll_ready(&mut stream, cx).is_ready()); + // serialize both frames at the same time + assert!(write_half + .start_send_frame(Frame::text(Payload::Owned(simple_string.clone()))) + .is_ok()); + assert!(write_half + .start_send_frame(Frame::text(Payload::Owned(long_string.clone()))) + .is_ok()); + assert!(write_half.poll_flush(&mut stream, cx).is_ready()); + + // read + for body in [&simple_string, &long_string] { + let Poll::Ready((res, mandatory_send)) = + read_half.poll_read_frame(&mut stream, cx) + else { + unreachable!() + }; + + assert!(mandatory_send.is_none()); + + let frame = res.unwrap().unwrap(); + assert_eq!(frame.payload.deref(), body); + } + + Poll::Ready(()) + }) + .await; + } }