Skip to content

Commit

Permalink
Avoid incomplete buffer for complete messages
Browse files Browse the repository at this point in the history
payload: Use BytesMut as Owned
  • Loading branch information
alexheretic authored and daniel-abramov committed Dec 14, 2024
1 parent 6c5aeb2 commit 7c5c3d2
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 70 deletions.
21 changes: 15 additions & 6 deletions benches/write.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Benchmarks for write performance.
use bytes::{BufMut, BytesMut};
use criterion::Criterion;
use std::{
hint,
io::{self, Read, Write},
fmt::Write as _,
hint, io,
time::{Duration, Instant},
};
use tungstenite::{protocol::Role, Message, WebSocket};
Expand All @@ -16,12 +17,12 @@ const MOCK_WRITE_LEN: usize = 8 * 1024 * 1024;
/// Each `flush` takes **~8µs** to simulate flush io.
struct MockWrite(Vec<u8>);

impl Read for MockWrite {
impl io::Read for MockWrite {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, "reads not supported"))
}
}
impl Write for MockWrite {
impl io::Write for MockWrite {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.0.len() + buf.len() > MOCK_WRITE_LEN {
self.flush()?;
Expand Down Expand Up @@ -54,11 +55,19 @@ fn benchmark(c: &mut Criterion) {
let mut ws =
WebSocket::from_raw_socket(MockWrite(Vec::with_capacity(MOCK_WRITE_LEN)), role, None);

let mut buf = BytesMut::with_capacity(128 * 1024);

b.iter(|| {
for i in 0_u64..100_000 {
let msg = match i {
_ if i % 3 == 0 => Message::binary(i.to_le_bytes().to_vec()),
_ => Message::text(format!("{{\"id\":{i}}}")),
_ if i % 3 == 0 => {
buf.put_slice(&i.to_le_bytes());
Message::binary(buf.split())
}
_ => {
buf.write_fmt(format_args!("{{\"id\":{i}}}")).unwrap();
Message::Text(buf.split().try_into().unwrap())
}
};
ws.write(msg).unwrap();
}
Expand Down
30 changes: 14 additions & 16 deletions src/protocol/frame/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
io::{Cursor, ErrorKind, Read, Write},
result::Result as StdResult,
str::Utf8Error,
string::{FromUtf8Error, String},
string::String,
};

use byteorder::{NetworkEndian, ReadBytesExt};
Expand All @@ -16,7 +16,11 @@ use super::{
mask::{apply_mask, generate_mask},
Payload,
};
use crate::error::{Error, ProtocolError, Result};
use crate::{
error::{Error, ProtocolError, Result},
protocol::frame::Utf8Payload,
};
use bytes::{Buf, BytesMut};

/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -209,7 +213,7 @@ impl FrameHeader {
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Frame {
header: FrameHeader,
pub(crate) payload: Payload,
payload: Payload,
}

impl Frame {
Expand Down Expand Up @@ -275,15 +279,9 @@ impl Frame {
}
}

/// Consume the frame into its payload as binary.
#[inline]
pub fn into_data(self) -> Vec<u8> {
self.payload.into_data()
}

/// Consume the frame into its payload as string.
#[inline]
pub fn into_text(self) -> StdResult<String, FromUtf8Error> {
pub fn into_text(self) -> StdResult<Utf8Payload, Utf8Error> {
self.payload.into_text()
}

Expand All @@ -308,8 +306,8 @@ impl Frame {
_ => {
let mut data = self.payload.into_data();
let code = u16::from_be_bytes([data[0], data[1]]).into();
data.drain(0..2);
let text = String::from_utf8(data)?;
data.advance(2);
let text = String::from_utf8(data.to_vec())?;
Ok(Some(CloseFrame { code, reason: text.into() }))
}
}
Expand Down Expand Up @@ -353,12 +351,12 @@ impl Frame {
#[inline]
pub fn close(msg: Option<CloseFrame>) -> Frame {
let payload = if let Some(CloseFrame { code, reason }) = msg {
let mut p = Vec::with_capacity(reason.len() + 2);
let mut p = BytesMut::with_capacity(reason.len() + 2);
p.extend(u16::from(code).to_be_bytes());
p.extend_from_slice(reason.as_bytes());
p
} else {
Vec::new()
<_>::default()
};

Frame { header: FrameHeader::default(), payload: Payload::Owned(payload) }
Expand Down Expand Up @@ -476,7 +474,7 @@ mod tests {
let mut payload = Vec::new();
raw.read_to_end(&mut payload).unwrap();
let frame = Frame::from_payload(header, payload.into());
assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
assert_eq!(frame.into_payload().as_slice(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
}

#[test]
Expand All @@ -489,7 +487,7 @@ mod tests {

#[test]
fn display() {
let f = Frame::message("hi there", OpCode::Data(Data::Text), true);
let f = Frame::message(Payload::from_static(b"hi there"), OpCode::Data(Data::Text), true);
let view = format!("{f}");
assert!(view.contains("payload:"));
}
Expand Down
17 changes: 10 additions & 7 deletions src/protocol/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl FrameCodec {
}

if len <= self.in_buffer.len() {
break self.in_buffer.split_to(len).freeze();
break self.in_buffer.split_to(len);
}
}
}
Expand All @@ -206,7 +206,7 @@ impl FrameCodec {

let (header, length) = self.header.take().expect("Bug: no frame header");
debug_assert_eq!(payload.len() as u64, length);
let frame = Frame::from_payload(header, payload.into());
let frame = Frame::from_payload(header, Payload::Owned(payload));
trace!("received frame {frame}");
Ok(Some(frame))
}
Expand Down Expand Up @@ -282,10 +282,13 @@ mod tests {
let mut sock = FrameSocket::new(raw);

assert_eq!(
sock.read(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
sock.read(None).unwrap().unwrap().into_payload().as_slice(),
&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
assert_eq!(
sock.read(None).unwrap().unwrap().into_payload().as_slice(),
&[0x03, 0x02, 0x01]
);
assert_eq!(sock.read(None).unwrap().unwrap().into_data(), vec![0x03, 0x02, 0x01]);
assert!(sock.read(None).unwrap().is_none());

let (_, rest) = sock.into_inner();
Expand All @@ -297,8 +300,8 @@ mod tests {
let raw = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
let mut sock = FrameSocket::from_partially_read(raw, vec![0x82, 0x07, 0x01]);
assert_eq!(
sock.read(None).unwrap().unwrap().into_data(),
vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
sock.read(None).unwrap().unwrap().into_payload().as_slice(),
&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
);
}

Expand Down
63 changes: 43 additions & 20 deletions src/protocol/frame/payload.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{fmt::Display, mem, string::FromUtf8Error};

use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use core::str;
use std::{fmt::Display, mem};

/// Utf8 payload.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -46,6 +45,24 @@ impl TryFrom<Payload> for Utf8Payload {
}
}

impl TryFrom<Bytes> for Utf8Payload {
type Error = str::Utf8Error;

#[inline]
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
Payload::from(bytes).try_into()
}
}

impl TryFrom<BytesMut> for Utf8Payload {
type Error = str::Utf8Error;

#[inline]
fn try_from(bytes: BytesMut) -> Result<Self, Self::Error> {
Payload::from(bytes).try_into()
}
}

impl From<String> for Utf8Payload {
#[inline]
fn from(s: String) -> Self {
Expand All @@ -56,7 +73,7 @@ impl From<String> for Utf8Payload {
impl From<&str> for Utf8Payload {
#[inline]
fn from(s: &str) -> Self {
Self(Payload::Owned(s.as_bytes().to_vec()))
Self(Payload::Owned(s.as_bytes().into()))
}
}

Expand Down Expand Up @@ -84,7 +101,7 @@ impl Display for Utf8Payload {
#[derive(Debug, Clone)]
pub enum Payload {
/// Owned data with unique ownership.
Owned(Vec<u8>),
Owned(BytesMut),
/// Shared data with shared ownership.
Shared(Bytes),
}
Expand All @@ -103,6 +120,15 @@ impl Payload {
Self::Shared(Bytes::from_owner(owner))
}

/// If owned converts into shared & then clones (cheaply).
#[inline]
pub fn share(&mut self) -> Self {
if let Self::Owned(bytes) = self {
*self = Self::Shared(mem::take(bytes).freeze());
}
self.clone()
}

/// Returns a slice of the payload.
#[inline]
pub fn as_slice(&self) -> &[u8] {
Expand Down Expand Up @@ -144,7 +170,7 @@ impl Payload {

/// Consumes the payload and returns the underlying data as a vector.
#[inline]
pub fn into_data(self) -> Vec<u8> {
pub fn into_data(self) -> BytesMut {
match self {
Payload::Owned(v) => v,
Payload::Shared(v) => v.into(),
Expand All @@ -153,32 +179,29 @@ impl Payload {

/// Consumes the payload and returns the underlying data as a string.
#[inline]
pub fn into_text(self) -> Result<String, FromUtf8Error> {
match self {
Payload::Owned(v) => Ok(String::from_utf8(v)?),
Payload::Shared(v) => Ok(String::from_utf8(v.into())?),
}
pub fn into_text(self) -> Result<Utf8Payload, str::Utf8Error> {
self.try_into()
}
}

impl Default for Payload {
#[inline]
fn default() -> Self {
Self::Owned(Vec::new())
Self::Owned(<_>::default())
}
}

impl From<Vec<u8>> for Payload {
#[inline]
fn from(v: Vec<u8>) -> Self {
Payload::Owned(v)
Payload::Owned(BytesMut::from_iter(v))
}
}

impl From<String> for Payload {
#[inline]
fn from(v: String) -> Self {
Payload::Owned(v.into())
Vec::from(v).into()
}
}

Expand All @@ -189,17 +212,17 @@ impl From<Bytes> for Payload {
}
}

impl From<&'static [u8]> for Payload {
impl From<BytesMut> for Payload {
#[inline]
fn from(v: &'static [u8]) -> Self {
Self::from_static(v)
fn from(v: BytesMut) -> Self {
Payload::Owned(v)
}
}

impl From<&'static str> for Payload {
impl From<&[u8]> for Payload {
#[inline]
fn from(v: &'static str) -> Self {
Self::from_static(v.as_bytes())
fn from(v: &[u8]) -> Self {
Self::Owned(v.into())
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/protocol/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ impl Message {
Cow::Borrowed(s) => Payload::from_static(s.as_bytes()),
Cow::Owned(s) => s.into(),
},
Message::Frame(frame) => frame.payload,
Message::Frame(frame) => frame.into_payload(),
}
}

Expand All @@ -257,7 +257,7 @@ impl Message {
Cow::Borrowed(s) => Utf8Payload::from_static(s),
Cow::Owned(s) => s.into(),
}),
Message::Frame(frame) => Ok(frame.payload.try_into()?),
Message::Frame(frame) => Ok(frame.into_text()?),
}
}

Expand Down
Loading

0 comments on commit 7c5c3d2

Please sign in to comment.