diff --git a/benches/write.rs b/benches/write.rs index ae1512f2..573fe944 100644 --- a/benches/write.rs +++ b/benches/write.rs @@ -1,8 +1,9 @@ //! Benchmarks for write performance. +use bytes::{BufMut, BytesMut}; use criterion::Criterion; use std::{ - hint, - io::{self, Read, Write}, + fmt::Write as _, + hint, io, time::{Duration, Instant}, }; use tungstenite::{protocol::Role, Message, WebSocket}; @@ -16,12 +17,12 @@ const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024; /// Each `flush` takes **~8µs** to simulate flush io. struct MockWrite(Vec); -impl Read for MockWrite { +impl io::Read for MockWrite { fn read(&mut self, _: &mut [u8]) -> io::Result { Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported")) } } -impl Write for MockWrite { +impl io::Write for MockWrite { fn write(&mut self, buf: &[u8]) -> io::Result { if self.0.len() + buf.len() > MOCK_WRITE_LEN { self.flush()?; @@ -54,11 +55,19 @@ fn benchmark(c: &mut Criterion) { let mut ws = WebSocket::from_raw_socket(MockWrite(Vec::with_capacity(MOCK_WRITE_LEN)), role, None); + let mut buf = BytesMut::with_capacity(128 * 1024); + b.iter(|| { for i in 0_u64..100_000 { let msg = match i { - _ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()), - _ => Message::text(format!("{{\"id\":{i}}}")), + _ if i % 3 == 0 => { + buf.put_slice(&i.to_le_bytes()); + Message::binary(buf.split()) + } + _ => { + buf.write_fmt(format_args!("{{\"id\":{i}}}")).unwrap(); + Message::Text(buf.split().try_into().unwrap()) + } }; ws.write(msg).unwrap(); } diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index 5f1d4698..3a9ff515 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -5,7 +5,7 @@ use std::{ io::{Cursor, ErrorKind, Read, Write}, result::Result as StdResult, str::Utf8Error, - string::{FromUtf8Error, String}, + string::String, }; use byteorder::{NetworkEndian, ReadBytesExt}; @@ -16,7 +16,11 @@ use super::{ mask::{apply_mask, generate_mask}, Payload, }; -use crate::error::{Error, ProtocolError, Result}; +use crate::{ + error::{Error, ProtocolError, Result}, + protocol::frame::Utf8Payload, +}; +use bytes::{Buf, BytesMut}; /// A struct representing the close command. #[derive(Debug, Clone, Eq, PartialEq)] @@ -209,7 +213,7 @@ impl FrameHeader { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Frame { header: FrameHeader, - pub(crate) payload: Payload, + payload: Payload, } impl Frame { @@ -275,15 +279,9 @@ impl Frame { } } - /// Consume the frame into its payload as binary. - #[inline] - pub fn into_data(self) -> Vec { - self.payload.into_data() - } - /// Consume the frame into its payload as string. #[inline] - pub fn into_text(self) -> StdResult { + pub fn into_text(self) -> StdResult { self.payload.into_text() } @@ -308,8 +306,8 @@ impl Frame { _ => { let mut data = self.payload.into_data(); let code = u16::from_be_bytes([data[0], data[1]]).into(); - data.drain(0..2); - let text = String::from_utf8(data)?; + data.advance(2); + let text = String::from_utf8(data.to_vec())?; Ok(Some(CloseFrame { code, reason: text.into() })) } } @@ -353,12 +351,12 @@ impl Frame { #[inline] pub fn close(msg: Option) -> Frame { let payload = if let Some(CloseFrame { code, reason }) = msg { - let mut p = Vec::with_capacity(reason.len() + 2); + let mut p = BytesMut::with_capacity(reason.len() + 2); p.extend(u16::from(code).to_be_bytes()); p.extend_from_slice(reason.as_bytes()); p } else { - Vec::new() + <_>::default() }; Frame { header: FrameHeader::default(), payload: Payload::Owned(payload) } @@ -476,7 +474,7 @@ mod tests { let mut payload = Vec::new(); raw.read_to_end(&mut payload).unwrap(); let frame = Frame::from_payload(header, payload.into()); - assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); + assert_eq!(frame.into_payload().as_slice(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); } #[test] @@ -489,7 +487,7 @@ mod tests { #[test] fn display() { - let f = Frame::message("hi there", OpCode::Data(Data::Text), true); + let f = Frame::message(Payload::from_static(b"hi there"), OpCode::Data(Data::Text), true); let view = format!("{f}"); assert!(view.contains("payload:")); } diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index e2d8a64e..76baeae7 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -185,7 +185,7 @@ impl FrameCodec { } if len <= self.in_buffer.len() { - break self.in_buffer.split_to(len).freeze(); + break self.in_buffer.split_to(len); } } } @@ -206,7 +206,7 @@ impl FrameCodec { let (header, length) = self.header.take().expect("Bug: no frame header"); debug_assert_eq!(payload.len() as u64, length); - let frame = Frame::from_payload(header, payload.into()); + let frame = Frame::from_payload(header, Payload::Owned(payload)); trace!("received frame {frame}"); Ok(Some(frame)) } @@ -282,10 +282,13 @@ mod tests { let mut sock = FrameSocket::new(raw); assert_eq!( - sock.read(None).unwrap().unwrap().into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + sock.read(None).unwrap().unwrap().into_payload().as_slice(), + &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + ); + assert_eq!( + sock.read(None).unwrap().unwrap().into_payload().as_slice(), + &[0x03, 0x02, 0x01] ); - assert_eq!(sock.read(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]); assert!(sock.read(None).unwrap().is_none()); let (_, rest) = sock.into_inner(); @@ -297,8 +300,8 @@ mod tests { let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]); assert_eq!( - sock.read(None).unwrap().unwrap().into_data(), - vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] + sock.read(None).unwrap().unwrap().into_payload().as_slice(), + &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07] ); } diff --git a/src/protocol/frame/payload.rs b/src/protocol/frame/payload.rs index b92f209c..eb7a5a29 100644 --- a/src/protocol/frame/payload.rs +++ b/src/protocol/frame/payload.rs @@ -1,7 +1,6 @@ -use std::{fmt::Display, mem, string::FromUtf8Error}; - -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use core::str; +use std::{fmt::Display, mem}; /// Utf8 payload. #[derive(Debug, Default, Clone, Eq, PartialEq)] @@ -46,6 +45,24 @@ impl TryFrom for Utf8Payload { } } +impl TryFrom for Utf8Payload { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: Bytes) -> Result { + Payload::from(bytes).try_into() + } +} + +impl TryFrom for Utf8Payload { + type Error = str::Utf8Error; + + #[inline] + fn try_from(bytes: BytesMut) -> Result { + Payload::from(bytes).try_into() + } +} + impl From for Utf8Payload { #[inline] fn from(s: String) -> Self { @@ -56,7 +73,7 @@ impl From for Utf8Payload { impl From<&str> for Utf8Payload { #[inline] fn from(s: &str) -> Self { - Self(Payload::Owned(s.as_bytes().to_vec())) + Self(Payload::Owned(s.as_bytes().into())) } } @@ -84,7 +101,7 @@ impl Display for Utf8Payload { #[derive(Debug, Clone)] pub enum Payload { /// Owned data with unique ownership. - Owned(Vec), + Owned(BytesMut), /// Shared data with shared ownership. Shared(Bytes), } @@ -103,6 +120,15 @@ impl Payload { Self::Shared(Bytes::from_owner(owner)) } + /// If owned converts into shared & then clones (cheaply). + #[inline] + pub fn share(&mut self) -> Self { + if let Self::Owned(bytes) = self { + *self = Self::Shared(mem::take(bytes).freeze()); + } + self.clone() + } + /// Returns a slice of the payload. #[inline] pub fn as_slice(&self) -> &[u8] { @@ -144,7 +170,7 @@ impl Payload { /// Consumes the payload and returns the underlying data as a vector. #[inline] - pub fn into_data(self) -> Vec { + pub fn into_data(self) -> BytesMut { match self { Payload::Owned(v) => v, Payload::Shared(v) => v.into(), @@ -153,32 +179,29 @@ impl Payload { /// Consumes the payload and returns the underlying data as a string. #[inline] - pub fn into_text(self) -> Result { - match self { - Payload::Owned(v) => Ok(String::from_utf8(v)?), - Payload::Shared(v) => Ok(String::from_utf8(v.into())?), - } + pub fn into_text(self) -> Result { + self.try_into() } } impl Default for Payload { #[inline] fn default() -> Self { - Self::Owned(Vec::new()) + Self::Owned(<_>::default()) } } impl From> for Payload { #[inline] fn from(v: Vec) -> Self { - Payload::Owned(v) + Payload::Owned(BytesMut::from_iter(v)) } } impl From for Payload { #[inline] fn from(v: String) -> Self { - Payload::Owned(v.into()) + Vec::from(v).into() } } @@ -189,17 +212,17 @@ impl From for Payload { } } -impl From<&'static [u8]> for Payload { +impl From for Payload { #[inline] - fn from(v: &'static [u8]) -> Self { - Self::from_static(v) + fn from(v: BytesMut) -> Self { + Payload::Owned(v) } } -impl From<&'static str> for Payload { +impl From<&[u8]> for Payload { #[inline] - fn from(v: &'static str) -> Self { - Self::from_static(v.as_bytes()) + fn from(v: &[u8]) -> Self { + Self::Owned(v.into()) } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index a39a89cd..c8011c15 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -241,7 +241,7 @@ impl Message { Cow::Borrowed(s) => Payload::from_static(s.as_bytes()), Cow::Owned(s) => s.into(), }, - Message::Frame(frame) => frame.payload, + Message::Frame(frame) => frame.into_payload(), } } @@ -257,7 +257,7 @@ impl Message { Cow::Borrowed(s) => Utf8Payload::from_static(s), Cow::Owned(s) => s.into(), }), - Message::Frame(frame) => Ok(frame.payload.try_into()?), + Message::Frame(frame) => Ok(frame.into_text()?), } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 2dee5a12..faa0415a 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -603,12 +603,12 @@ impl WebSocketContext { Err(Error::Protocol(ProtocolError::UnknownControlFrameType(i))) } OpCtl::Ping => { - let data = frame.into_data(); + let mut data = frame.into_payload(); // No ping processing after we sent a close frame. if self.state.is_active() { - self.set_additional(Frame::pong(data.clone())); + self.set_additional(Frame::pong(data.share())); } - Ok(Some(Message::Ping(data.into()))) + Ok(Some(Message::Ping(data))) } OpCtl::Pong => Ok(Some(Message::Pong(frame.into_payload()))), } @@ -619,7 +619,10 @@ impl WebSocketContext { match data { OpData::Continue => { if let Some(ref mut msg) = self.incomplete { - msg.extend(frame.into_data(), self.config.max_message_size)?; + msg.extend( + frame.into_payload().as_slice(), + self.config.max_message_size, + )?; } else { return Err(Error::Protocol( ProtocolError::UnexpectedContinueFrame, @@ -634,23 +637,21 @@ impl WebSocketContext { c if self.incomplete.is_some() => { Err(Error::Protocol(ProtocolError::ExpectedFragment(c))) } + OpData::Text if fin => Ok(Some(Message::Text(frame.into_text()?))), + OpData::Binary if fin => Ok(Some(Message::Binary(frame.into_payload()))), OpData::Text | OpData::Binary => { - let msg = { - let message_type = match data { - OpData::Text => IncompleteMessageType::Text, - OpData::Binary => IncompleteMessageType::Binary, - _ => panic!("Bug: message is not text nor binary"), - }; - let mut m = IncompleteMessage::new(message_type); - m.extend(frame.into_data(), self.config.max_message_size)?; - m + let message_type = match data { + OpData::Text => IncompleteMessageType::Text, + OpData::Binary => IncompleteMessageType::Binary, + _ => panic!("Bug: message is not text nor binary"), }; - if fin { - Ok(Some(msg.complete()?)) - } else { - self.incomplete = Some(msg); - Ok(None) - } + let mut incomplete = IncompleteMessage::new(message_type); + incomplete.extend( + frame.into_payload().as_slice(), + self.config.max_message_size, + )?; + self.incomplete = Some(incomplete); + Ok(None) } OpData::Reserved(i) => { Err(Error::Protocol(ProtocolError::UnknownDataFrameType(i)))