Skip to content

Commit

Permalink
UtfPayload
Browse files Browse the repository at this point in the history
  • Loading branch information
alexheretic authored and daniel-abramov committed Dec 14, 2024
1 parent b22f917 commit 6c5aeb2
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 63 deletions.
4 changes: 2 additions & 2 deletions benches/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn benchmark(c: &mut Criterion) {
writer
.send(match i {
_ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()),
_ => Message::Text(format!("{{\"id\":{i}}}")),
_ => Message::text(format!("{{\"id\":{i}}}")),
})
.unwrap();
sum += i;
Expand All @@ -68,7 +68,7 @@ fn benchmark(c: &mut Criterion) {
sum += u64::from_le_bytes(*a);
}
Message::Text(msg) => {
let i: u64 = msg[6..msg.len() - 1].parse().unwrap();
let i: u64 = msg.as_str()[6..msg.len() - 1].parse().unwrap();
sum += i;
}
m => panic!("Unexpected {m}"),
Expand Down
2 changes: 1 addition & 1 deletion benches/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ fn benchmark(c: &mut Criterion) {
for i in 0_u64..100_000 {
let msg = match i {
_ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()),
_ => Message::Text(format!("{{\"id\":{i}}}")),
_ => Message::text(format!("{{\"id\":{i}}}")),
};
ws.write(msg).unwrap();
}
Expand Down
2 changes: 1 addition & 1 deletion examples/autobahn-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect("ws://localhost:9001/getCaseCount")?;
let msg = socket.read()?;
socket.close(None)?;
Ok(msg.into_text()?.parse::<u32>().unwrap())
Ok(msg.into_text()?.as_str().parse::<u32>().unwrap())
}

fn update_reports() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Lightweight, flexible WebSockets for Rust.
#![deny(
missing_docs,
// missing_docs,
missing_copy_implementations,
missing_debug_implementations,
trivial_casts,
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/frame/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl FrameHeader {
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Frame {
header: FrameHeader,
payload: Payload,
pub(crate) payload: Payload,
}

impl Frame {
Expand Down
64 changes: 36 additions & 28 deletions src/protocol/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ mod frame;
mod mask;
mod payload;

pub use self::{
frame::{CloseFrame, Frame, FrameHeader},
payload::{Payload, Utf8Payload},
};

use crate::{
error::{CapacityError, Error, Result},
Message, ReadBuffer,
Message,
};
use bytes::BytesMut;
use log::*;
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};

pub use self::{
frame::{CloseFrame, Frame, FrameHeader},
payload::Payload,
};
use std::io::{Cursor, Error as IoError, ErrorKind as IoErrorKind, Read, Write};

/// A reader and writer for WebSocket frames.
#[derive(Debug)]
Expand All @@ -40,8 +41,8 @@ impl<Stream> FrameSocket<Stream> {
}

/// Extract a stream from the socket.
pub fn into_inner(self) -> (Stream, Vec<u8>) {
(self.stream, self.codec.in_buffer.into_vec())
pub fn into_inner(self) -> (Stream, BytesMut) {
(self.stream, self.codec.in_buffer)
}

/// Returns a shared reference to the inner stream.
Expand Down Expand Up @@ -98,7 +99,7 @@ where
#[derive(Debug)]
pub(super) struct FrameCodec {
/// Buffer to read data from the stream.
in_buffer: ReadBuffer,
in_buffer: BytesMut,
/// Buffer to send packets to the network.
out_buffer: Vec<u8>,
/// Capacity limit for `out_buffer`.
Expand All @@ -113,11 +114,13 @@ pub(super) struct FrameCodec {
header: Option<(FrameHeader, u64)>,
}

const READ_BUFFER_CAP: usize = 64 * 1024;

impl FrameCodec {
/// Create a new frame codec.
pub(super) fn new() -> Self {
Self {
in_buffer: ReadBuffer::new(),
in_buffer: BytesMut::with_capacity(READ_BUFFER_CAP),
out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
Expand All @@ -127,8 +130,10 @@ impl FrameCodec {

/// Create a new frame codec from partially read data.
pub(super) fn from_partially_read(part: Vec<u8>) -> Self {
let mut in_buffer = BytesMut::from_iter(part);
in_buffer.reserve(READ_BUFFER_CAP.saturating_sub(in_buffer.len()));
Self {
in_buffer: ReadBuffer::from_partially_read(part),
in_buffer,
out_buffer: Vec::new(),
max_out_buffer_len: usize::MAX,
out_buffer_write_len: 0,
Expand Down Expand Up @@ -160,38 +165,39 @@ impl FrameCodec {

let payload = loop {
{
let cursor = self.in_buffer.as_cursor_mut();

if self.header.is_none() {
self.header = FrameHeader::parse(cursor)?;
let mut cursor = Cursor::new(&mut self.in_buffer);
self.header = FrameHeader::parse(&mut cursor)?;
let advanced = cursor.position();
bytes::Buf::advance(&mut self.in_buffer, advanced as _);
}

if let Some((_, ref length)) = self.header {
let length = *length;
if let Some((_, len)) = &self.header {
let len = *len as usize;

// Enforce frame size limit early and make sure `length`
// is not too big (fits into `usize`).
if length > max_size as u64 {
if len > max_size {
return Err(Error::Capacity(CapacityError::MessageTooLong {
size: length as usize,
size: len,
max_size,
}));
}

let input_size = cursor.get_ref().len() as u64 - cursor.position();
if length <= input_size {
// No truncation here since `length` is checked above
let mut payload = Vec::with_capacity(length as usize);
if length > 0 {
cursor.take(length).read_to_end(&mut payload)?;
}
break payload;
if len <= self.in_buffer.len() {
break self.in_buffer.split_to(len).freeze();
}
}
}

// Not enough data in buffer.
let size = self.in_buffer.read_from(stream)?;
let reserve_len = self.header.as_ref().map(|(_, l)| *l as usize).unwrap_or(6);
self.in_buffer.reserve(reserve_len);
let mut buf = self.in_buffer.split_off(self.in_buffer.len());
buf.resize(reserve_len.max(buf.capacity()), 0);
let size = stream.read(&mut buf)?;
buf.truncate(size);
self.in_buffer.unsplit(buf);
if size == 0 {
trace!("no frame received");
return Ok(None);
Expand Down Expand Up @@ -267,6 +273,8 @@ mod tests {

#[test]
fn read_frames() {
env_logger::init();

let raw = Cursor::new(vec![
0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
0x99,
Expand Down
134 changes: 130 additions & 4 deletions src/protocol/frame/payload.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,87 @@
use std::{mem, string::FromUtf8Error};
use std::{fmt::Display, mem, string::FromUtf8Error};

use bytes::Bytes;
use core::str;

/// Utf8 payload.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct Utf8Payload(Payload);

impl Utf8Payload {
#[inline]
pub const fn from_static(str: &'static str) -> Self {
Self(Payload::Shared(Bytes::from_static(str.as_bytes())))
}

/// Returns a slice of the payload.
#[inline]
pub fn as_slice(&self) -> &[u8] {
self.0.as_slice()
}

#[inline]
pub fn as_str(&self) -> &str {
// safety: is valid uft8
unsafe { str::from_utf8_unchecked(self.as_slice()) }
}

#[inline]
pub fn len(&self) -> usize {
self.as_slice().len()
}

#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}

impl TryFrom<Payload> for Utf8Payload {
type Error = str::Utf8Error;

#[inline]
fn try_from(payload: Payload) -> Result<Self, Self::Error> {
str::from_utf8(payload.as_slice())?;
Ok(Self(payload))
}
}

impl From<String> for Utf8Payload {
#[inline]
fn from(s: String) -> Self {
Self(s.into())
}
}

impl From<&str> for Utf8Payload {
#[inline]
fn from(s: &str) -> Self {
Self(Payload::Owned(s.as_bytes().to_vec()))
}
}

impl From<&String> for Utf8Payload {
#[inline]
fn from(s: &String) -> Self {
s.as_str().into()
}
}

impl From<Utf8Payload> for Payload {
#[inline]
fn from(Utf8Payload(payload): Utf8Payload) -> Self {
payload
}
}

impl Display for Utf8Payload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

/// A payload of a WebSocket frame.
#[derive(Debug, Clone, Eq, PartialEq)]
#[derive(Debug, Clone)]
pub enum Payload {
/// Owned data with unique ownership.
Owned(Vec<u8>),
Expand All @@ -12,6 +90,19 @@ pub enum Payload {
}

impl Payload {
#[inline]
pub const fn from_static(bytes: &'static [u8]) -> Self {
Self::Shared(Bytes::from_static(bytes))
}

#[inline]
pub fn from_owner<T>(owner: T) -> Self
where
T: AsRef<[u8]> + Send + 'static,
{
Self::Shared(Bytes::from_owner(owner))
}

/// Returns a slice of the payload.
#[inline]
pub fn as_slice(&self) -> &[u8] {
Expand Down Expand Up @@ -70,32 +161,67 @@ impl Payload {
}
}

impl Default for Payload {
#[inline]
fn default() -> Self {
Self::Owned(Vec::new())
}
}

impl From<Vec<u8>> for Payload {
#[inline]
fn from(v: Vec<u8>) -> Self {
Payload::Owned(v)
}
}

impl From<String> for Payload {
#[inline]
fn from(v: String) -> Self {
Payload::Owned(v.into())
}
}

impl From<Bytes> for Payload {
#[inline]
fn from(v: Bytes) -> Self {
Payload::Shared(v)
}
}

impl From<&'static [u8]> for Payload {
#[inline]
fn from(v: &'static [u8]) -> Self {
Payload::Shared(Bytes::from_static(v))
Self::from_static(v)
}
}

impl From<&'static str> for Payload {
#[inline]
fn from(v: &'static str) -> Self {
Payload::Shared(Bytes::from_static(v.as_bytes()))
Self::from_static(v.as_bytes())
}
}

impl PartialEq<Payload> for Payload {
#[inline]
fn eq(&self, other: &Payload) -> bool {
self.as_slice() == other.as_slice()
}
}

impl Eq for Payload {}

impl PartialEq<[u8]> for Payload {
#[inline]
fn eq(&self, other: &[u8]) -> bool {
self.as_slice() == other
}
}

impl<const N: usize> PartialEq<&[u8; N]> for Payload {
#[inline]
fn eq(&self, other: &&[u8; N]) -> bool {
self.as_slice() == *other
}
}
Loading

0 comments on commit 6c5aeb2

Please sign in to comment.