From 6c5aeb2418bc0a08c856be5759eb0c2661c8eba0 Mon Sep 17 00:00:00 2001 From: Alex Butler Date: Sat, 14 Dec 2024 03:57:05 +0000 Subject: [PATCH] UtfPayload --- benches/read.rs | 4 +- benches/write.rs | 2 +- examples/autobahn-client.rs | 2 +- src/lib.rs | 2 +- src/protocol/frame/frame.rs | 2 +- src/protocol/frame/mod.rs | 64 +++++++++------- src/protocol/frame/payload.rs | 134 +++++++++++++++++++++++++++++++++- src/protocol/message.rs | 56 +++++++------- 8 files changed, 203 insertions(+), 63 deletions(-) diff --git a/benches/read.rs b/benches/read.rs index a60c86e0..252ef3e3 100644 --- a/benches/read.rs +++ b/benches/read.rs @@ -52,7 +52,7 @@ fn benchmark(c: &mut Criterion) { writer .send(match i { _ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()), - _ => Message::Text(format!("{{\"id\":{i}}}")), + _ => Message::text(format!("{{\"id\":{i}}}")), }) .unwrap(); sum += i; @@ -68,7 +68,7 @@ fn benchmark(c: &mut Criterion) { sum += u64::from_le_bytes(*a); } Message::Text(msg) => { - let i: u64 = msg[6..msg.len() - 1].parse().unwrap(); + let i: u64 = msg.as_str()[6..msg.len() - 1].parse().unwrap(); sum += i; } m => panic!("Unexpected {m}"), diff --git a/benches/write.rs b/benches/write.rs index 435f9a3b..ae1512f2 100644 --- a/benches/write.rs +++ b/benches/write.rs @@ -58,7 +58,7 @@ fn benchmark(c: &mut Criterion) { for i in 0_u64..100_000 { let msg = match i { _ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()), - _ => Message::Text(format!("{{\"id\":{i}}}")), + _ => Message::text(format!("{{\"id\":{i}}}")), }; ws.write(msg).unwrap(); } diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index dcc3e75f..4ba5f6b9 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -8,7 +8,7 @@ fn get_case_count() -> Result { let (mut socket, _) = connect("ws://localhost:9001/getCaseCount")?; let msg = socket.read()?; socket.close(None)?; - Ok(msg.into_text()?.parse::().unwrap()) + Ok(msg.into_text()?.as_str().parse::().unwrap()) } fn update_reports() -> Result<()> { diff --git a/src/lib.rs b/src/lib.rs index 6b79d5be..d6189688 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ //! Lightweight, flexible WebSockets for Rust. #![deny( - missing_docs, + // missing_docs, missing_copy_implementations, missing_debug_implementations, trivial_casts, diff --git a/src/protocol/frame/frame.rs b/src/protocol/frame/frame.rs index c484c06b..5f1d4698 100644 --- a/src/protocol/frame/frame.rs +++ b/src/protocol/frame/frame.rs @@ -209,7 +209,7 @@ impl FrameHeader { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Frame { header: FrameHeader, - payload: Payload, + pub(crate) payload: Payload, } impl Frame { diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs index eb714b80..e2d8a64e 100644 --- a/src/protocol/frame/mod.rs +++ b/src/protocol/frame/mod.rs @@ -7,17 +7,18 @@ mod frame; mod mask; mod payload; +pub use self::{ + frame::{CloseFrame, Frame, FrameHeader}, + payload::{Payload, Utf8Payload}, +}; + use crate::{ error::{CapacityError, Error, Result}, - Message, ReadBuffer, + Message, }; +use bytes::BytesMut; use log::*; -use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write}; - -pub use self::{ - frame::{CloseFrame, Frame, FrameHeader}, - payload::Payload, -}; +use std::io::{Cursor, Error as IoError, ErrorKind as IoErrorKind, Read, Write}; /// A reader and writer for WebSocket frames. #[derive(Debug)] @@ -40,8 +41,8 @@ impl FrameSocket { } /// Extract a stream from the socket. - pub fn into_inner(self) -> (Stream, Vec) { - (self.stream, self.codec.in_buffer.into_vec()) + pub fn into_inner(self) -> (Stream, BytesMut) { + (self.stream, self.codec.in_buffer) } /// Returns a shared reference to the inner stream. @@ -98,7 +99,7 @@ where #[derive(Debug)] pub(super) struct FrameCodec { /// Buffer to read data from the stream. - in_buffer: ReadBuffer, + in_buffer: BytesMut, /// Buffer to send packets to the network. out_buffer: Vec, /// Capacity limit for `out_buffer`. @@ -113,11 +114,13 @@ pub(super) struct FrameCodec { header: Option<(FrameHeader, u64)>, } +const READ_BUFFER_CAP: usize = 64 * 1024; + impl FrameCodec { /// Create a new frame codec. pub(super) fn new() -> Self { Self { - in_buffer: ReadBuffer::new(), + in_buffer: BytesMut::with_capacity(READ_BUFFER_CAP), out_buffer: Vec::new(), max_out_buffer_len: usize::MAX, out_buffer_write_len: 0, @@ -127,8 +130,10 @@ impl FrameCodec { /// Create a new frame codec from partially read data. pub(super) fn from_partially_read(part: Vec) -> Self { + let mut in_buffer = BytesMut::from_iter(part); + in_buffer.reserve(READ_BUFFER_CAP.saturating_sub(in_buffer.len())); Self { - in_buffer: ReadBuffer::from_partially_read(part), + in_buffer, out_buffer: Vec::new(), max_out_buffer_len: usize::MAX, out_buffer_write_len: 0, @@ -160,38 +165,39 @@ impl FrameCodec { let payload = loop { { - let cursor = self.in_buffer.as_cursor_mut(); - if self.header.is_none() { - self.header = FrameHeader::parse(cursor)?; + let mut cursor = Cursor::new(&mut self.in_buffer); + self.header = FrameHeader::parse(&mut cursor)?; + let advanced = cursor.position(); + bytes::Buf::advance(&mut self.in_buffer, advanced as _); } - if let Some((_, ref length)) = self.header { - let length = *length; + if let Some((_, len)) = &self.header { + let len = *len as usize; // Enforce frame size limit early and make sure `length` // is not too big (fits into `usize`). - if length > max_size as u64 { + if len > max_size { return Err(Error::Capacity(CapacityError::MessageTooLong { - size: length as usize, + size: len, max_size, })); } - let input_size = cursor.get_ref().len() as u64 - cursor.position(); - if length <= input_size { - // No truncation here since `length` is checked above - let mut payload = Vec::with_capacity(length as usize); - if length > 0 { - cursor.take(length).read_to_end(&mut payload)?; - } - break payload; + if len <= self.in_buffer.len() { + break self.in_buffer.split_to(len).freeze(); } } } // Not enough data in buffer. - let size = self.in_buffer.read_from(stream)?; + let reserve_len = self.header.as_ref().map(|(_, l)| *l as usize).unwrap_or(6); + self.in_buffer.reserve(reserve_len); + let mut buf = self.in_buffer.split_off(self.in_buffer.len()); + buf.resize(reserve_len.max(buf.capacity()), 0); + let size = stream.read(&mut buf)?; + buf.truncate(size); + self.in_buffer.unsplit(buf); if size == 0 { trace!("no frame received"); return Ok(None); @@ -267,6 +273,8 @@ mod tests { #[test] fn read_frames() { + env_logger::init(); + let raw = Cursor::new(vec![ 0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01, 0x99, diff --git a/src/protocol/frame/payload.rs b/src/protocol/frame/payload.rs index fac722c3..b92f209c 100644 --- a/src/protocol/frame/payload.rs +++ b/src/protocol/frame/payload.rs @@ -1,9 +1,87 @@ -use std::{mem, string::FromUtf8Error}; +use std::{fmt::Display, mem, string::FromUtf8Error}; use bytes::Bytes; +use core::str; + +/// Utf8 payload. +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub struct Utf8Payload(Payload); + +impl Utf8Payload { + #[inline] + pub const fn from_static(str: &'static str) -> Self { + Self(Payload::Shared(Bytes::from_static(str.as_bytes()))) + } + + /// Returns a slice of the payload. + #[inline] + pub fn as_slice(&self) -> &[u8] { + self.0.as_slice() + } + + #[inline] + pub fn as_str(&self) -> &str { + // safety: is valid uft8 + unsafe { str::from_utf8_unchecked(self.as_slice()) } + } + + #[inline] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl TryFrom for Utf8Payload { + type Error = str::Utf8Error; + + #[inline] + fn try_from(payload: Payload) -> Result { + str::from_utf8(payload.as_slice())?; + Ok(Self(payload)) + } +} + +impl From for Utf8Payload { + #[inline] + fn from(s: String) -> Self { + Self(s.into()) + } +} + +impl From<&str> for Utf8Payload { + #[inline] + fn from(s: &str) -> Self { + Self(Payload::Owned(s.as_bytes().to_vec())) + } +} + +impl From<&String> for Utf8Payload { + #[inline] + fn from(s: &String) -> Self { + s.as_str().into() + } +} + +impl From for Payload { + #[inline] + fn from(Utf8Payload(payload): Utf8Payload) -> Self { + payload + } +} + +impl Display for Utf8Payload { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} /// A payload of a WebSocket frame. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone)] pub enum Payload { /// Owned data with unique ownership. Owned(Vec), @@ -12,6 +90,19 @@ pub enum Payload { } impl Payload { + #[inline] + pub const fn from_static(bytes: &'static [u8]) -> Self { + Self::Shared(Bytes::from_static(bytes)) + } + + #[inline] + pub fn from_owner(owner: T) -> Self + where + T: AsRef<[u8]> + Send + 'static, + { + Self::Shared(Bytes::from_owner(owner)) + } + /// Returns a slice of the payload. #[inline] pub fn as_slice(&self) -> &[u8] { @@ -70,32 +161,67 @@ impl Payload { } } +impl Default for Payload { + #[inline] + fn default() -> Self { + Self::Owned(Vec::new()) + } +} + impl From> for Payload { + #[inline] fn from(v: Vec) -> Self { Payload::Owned(v) } } impl From for Payload { + #[inline] fn from(v: String) -> Self { Payload::Owned(v.into()) } } impl From for Payload { + #[inline] fn from(v: Bytes) -> Self { Payload::Shared(v) } } impl From<&'static [u8]> for Payload { + #[inline] fn from(v: &'static [u8]) -> Self { - Payload::Shared(Bytes::from_static(v)) + Self::from_static(v) } } impl From<&'static str> for Payload { + #[inline] fn from(v: &'static str) -> Self { - Payload::Shared(Bytes::from_static(v.as_bytes())) + Self::from_static(v.as_bytes()) + } +} + +impl PartialEq for Payload { + #[inline] + fn eq(&self, other: &Payload) -> bool { + self.as_slice() == other.as_slice() + } +} + +impl Eq for Payload {} + +impl PartialEq<[u8]> for Payload { + #[inline] + fn eq(&self, other: &[u8]) -> bool { + self.as_slice() == other + } +} + +impl PartialEq<&[u8; N]> for Payload { + #[inline] + fn eq(&self, other: &&[u8; N]) -> bool { + self.as_slice() == *other } } diff --git a/src/protocol/message.rs b/src/protocol/message.rs index 502906ba..a39a89cd 100644 --- a/src/protocol/message.rs +++ b/src/protocol/message.rs @@ -1,6 +1,6 @@ -use std::{fmt, result::Result as StdResult, str}; +use std::{borrow::Cow, fmt, result::Result as StdResult, str}; -use super::frame::{CloseFrame, Frame, Payload}; +use super::frame::{CloseFrame, Frame, Payload, Utf8Payload}; use crate::error::{CapacityError, Error, Result}; mod string_collect { @@ -138,7 +138,7 @@ impl IncompleteMessage { IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v.into())), IncompleteMessageCollector::Text(t) => { let text = t.into_string()?; - Ok(Message::Text(text)) + Ok(Message::text(text)) } } } @@ -154,7 +154,7 @@ pub enum IncompleteMessageType { #[derive(Debug, Eq, PartialEq, Clone)] pub enum Message { /// A text WebSocket message - Text(String), + Text(Utf8Payload), /// A binary WebSocket message Binary(Payload), /// A ping message with the specified payload @@ -175,7 +175,7 @@ impl Message { /// Create a new text WebSocket message from a stringable. pub fn text(string: S) -> Message where - S: Into, + S: Into, { Message::Text(string.into()) } @@ -232,26 +232,32 @@ impl Message { } /// Consume the WebSocket and return it as binary data. - pub fn into_data(self) -> Vec { + pub fn into_data(self) -> Payload { match self { - Message::Text(string) => string.into_bytes(), - Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data.into_data(), - Message::Close(None) => Vec::new(), - Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), - Message::Frame(frame) => frame.into_data(), + Message::Text(string) => string.into(), + Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data, + Message::Close(None) => <_>::default(), + Message::Close(Some(frame)) => match frame.reason { + Cow::Borrowed(s) => Payload::from_static(s.as_bytes()), + Cow::Owned(s) => s.into(), + }, + Message::Frame(frame) => frame.payload, } } /// Attempt to consume the WebSocket message and convert it to a String. - pub fn into_text(self) -> Result { + pub fn into_text(self) -> Result { match self { - Message::Text(string) => Ok(string), + Message::Text(txt) => Ok(txt), Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => { - Ok(data.into_text()?) + Ok(data.try_into()?) } - Message::Close(None) => Ok(String::new()), - Message::Close(Some(frame)) => Ok(frame.reason.into_owned()), - Message::Frame(frame) => Ok(frame.into_text()?), + Message::Close(None) => Ok(<_>::default()), + Message::Close(Some(frame)) => Ok(match frame.reason { + Cow::Borrowed(s) => Utf8Payload::from_static(s), + Cow::Owned(s) => s.into(), + }), + Message::Frame(frame) => Ok(frame.payload.try_into()?), } } @@ -259,7 +265,7 @@ impl Message { /// this will try to convert binary data to utf8. pub fn to_text(&self) -> Result<&str> { match *self { - Message::Text(ref string) => Ok(string), + Message::Text(ref string) => Ok(string.as_str()), Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => { Ok(str::from_utf8(data.as_slice())?) } @@ -297,17 +303,17 @@ impl From> for Message { impl From for Vec { fn from(message: Message) -> Self { - message.into_data() + message.into_data().as_slice().into() } } -impl TryFrom for String { - type Error = Error; +// impl TryFrom for String { +// type Error = Error; - fn try_from(value: Message) -> StdResult { - value.into_text() - } -} +// fn try_from(value: Message) -> StdResult { +// Ok(value.into_text()?.as_str().into()) +// } +// } impl fmt::Display for Message { fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {