From ecb4616d2a3cb4ec283b27c491ff56dc5f6d6e71 Mon Sep 17 00:00:00 2001 From: Park Joon-Kyu Date: Wed, 6 Nov 2024 21:57:48 +0900 Subject: [PATCH 1/8] Implement RSV bits --- actix-http/src/ws/codec.rs | 40 ++++++++++++++++++++++++++++----- actix-http/src/ws/frame.rs | 45 +++++++++++++++++++++++--------------- actix-http/src/ws/mod.rs | 2 +- actix-http/src/ws/proto.rs | 19 ++++++++++++++++ 4 files changed, 82 insertions(+), 24 deletions(-) diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index ad487e400fb..fe2ca43bef5 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -6,7 +6,7 @@ use tracing::error; use super::{ frame::Parser, - proto::{CloseReason, OpCode}, + proto::{CloseReason, OpCode, RsvBits}, ProtocolError, }; @@ -71,6 +71,9 @@ pub enum Item { pub struct Codec { flags: Flags, max_size: usize, + + inbound_rsv_bits: Option, + outbound_rsv_bits: RsvBits, } bitflags! { @@ -88,6 +91,9 @@ impl Codec { Codec { max_size: 65_536, flags: Flags::SERVER, + + inbound_rsv_bits: None, + outbound_rsv_bits: RsvBits::empty(), } } @@ -108,6 +114,18 @@ impl Codec { self.flags.remove(Flags::SERVER); self } + + /// Get inbound RSV bits. + /// + /// Returns None if there's no received frame yet. + pub fn get_inbound_rsv_bits(&self) -> Option { + self.inbound_rsv_bits + } + + /// Set outbound RSV bits. + pub fn set_outbound_rsv_bits(&mut self, rsv_bits: RsvBits) { + self.outbound_rsv_bits = rsv_bits; + } } impl Default for Codec { @@ -125,6 +143,7 @@ impl Encoder for Codec { dst, txt, OpCode::Text, + self.outbound_rsv_bits, true, !self.flags.contains(Flags::SERVER), ), @@ -132,6 +151,7 @@ impl Encoder for Codec { dst, bin, OpCode::Binary, + self.outbound_rsv_bits, true, !self.flags.contains(Flags::SERVER), ), @@ -139,6 +159,7 @@ impl Encoder for Codec { dst, txt, OpCode::Ping, + self.outbound_rsv_bits, true, !self.flags.contains(Flags::SERVER), ), @@ -146,12 +167,16 @@ impl Encoder for Codec { dst, txt, OpCode::Pong, + self.outbound_rsv_bits, true, !self.flags.contains(Flags::SERVER), ), - Message::Close(reason) => { - Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER)) - } + Message::Close(reason) => Parser::write_close( + dst, + reason, + self.outbound_rsv_bits, + !self.flags.contains(Flags::SERVER), + ), Message::Continuation(cont) => match cont { Item::FirstText(data) => { if self.flags.contains(Flags::W_CONTINUATION) { @@ -162,6 +187,7 @@ impl Encoder for Codec { dst, &data[..], OpCode::Text, + self.outbound_rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -176,6 +202,7 @@ impl Encoder for Codec { dst, &data[..], OpCode::Binary, + self.outbound_rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -187,6 +214,7 @@ impl Encoder for Codec { dst, &data[..], OpCode::Continue, + self.outbound_rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -201,6 +229,7 @@ impl Encoder for Codec { dst, &data[..], OpCode::Continue, + self.outbound_rsv_bits, true, !self.flags.contains(Flags::SERVER), ) @@ -221,7 +250,8 @@ impl Decoder for Codec { fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) { - Ok(Some((finished, opcode, payload))) => { + Ok(Some((finished, opcode, rsv_bits, payload))) => { + self.inbound_rsv_bits = Some(rsv_bits); // continuation is not supported if !finished { return match opcode { diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index 35b3f8e668e..e166f1cf516 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -5,7 +5,7 @@ use tracing::debug; use super::{ mask::apply_mask, - proto::{CloseCode, CloseReason, OpCode}, + proto::{CloseCode, CloseReason, OpCode, RsvBits}, ProtocolError, }; @@ -17,7 +17,7 @@ impl Parser { fn parse_metadata( src: &[u8], server: bool, - ) -> Result)>, ProtocolError> { + ) -> Result)>, ProtocolError> { let chunk_len = src.len(); let mut idx = 2; @@ -37,6 +37,9 @@ impl Parser { return Err(ProtocolError::MaskedFrame); } + // RSV bits + let rsv_bits = RsvBits::from_bits((first & 0x70) >> 4).unwrap_or(RsvBits::empty()); + // Op code let opcode = OpCode::from(first & 0x0F); @@ -79,7 +82,7 @@ impl Parser { None }; - Ok(Some((idx, finished, opcode, length, mask))) + Ok(Some((idx, finished, opcode, rsv_bits, length, mask))) } /// Parse the input stream into a frame. @@ -87,12 +90,13 @@ impl Parser { src: &mut BytesMut, server: bool, max_size: usize, - ) -> Result)>, ProtocolError> { + ) -> Result)>, ProtocolError> { // try to parse ws frame metadata - let (idx, finished, opcode, length, mask) = match Parser::parse_metadata(src, server)? { - None => return Ok(None), - Some(res) => res, - }; + let (idx, finished, opcode, rsv_bits, length, mask) = + match Parser::parse_metadata(src, server)? { + None => return Ok(None), + Some(res) => res, + }; // not enough data if src.len() < idx + length { @@ -115,7 +119,7 @@ impl Parser { // no need for body if length == 0 { - return Ok(Some((finished, opcode, None))); + return Ok(Some((finished, opcode, rsv_bits, None))); } let mut data = src.split_to(length); @@ -127,7 +131,7 @@ impl Parser { } OpCode::Close if length > 125 => { debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Some((true, OpCode::Close, None))); + return Ok(Some((true, OpCode::Close, rsv_bits, None))); } _ => {} } @@ -137,7 +141,7 @@ impl Parser { apply_mask(&mut data, mask); } - Ok(Some((finished, opcode, Some(data)))) + Ok(Some((finished, opcode, rsv_bits, Some(data)))) } /// Parse the payload of a close frame. @@ -161,15 +165,15 @@ impl Parser { dst: &mut BytesMut, pl: B, op: OpCode, + rsv_bits: RsvBits, fin: bool, mask: bool, ) { let payload = pl.as_ref(); - let one: u8 = if fin { - 0x80 | Into::::into(op) - } else { - op.into() - }; + let fin_bits = if fin { 0x80 } else { 0x00 }; + let rsv_bits = rsv_bits.bits() << 4; + + let one: u8 = fin_bits | rsv_bits | Into::::into(op); let payload_len = payload.len(); let (two, p_len) = if mask { (0x80, payload_len + 4) @@ -203,7 +207,12 @@ impl Parser { /// Create a new Close control frame. #[inline] - pub fn write_close(dst: &mut BytesMut, reason: Option, mask: bool) { + pub fn write_close( + dst: &mut BytesMut, + reason: Option, + rsv_bits: RsvBits, + mask: bool, + ) { let payload = match reason { None => Vec::new(), Some(reason) => { @@ -215,7 +224,7 @@ impl Parser { } }; - Parser::write_message(dst, payload, OpCode::Close, true, mask) + Parser::write_message(dst, payload, OpCode::Close, rsv_bits, true, mask) } } diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 88053b254d5..811e634747c 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -20,7 +20,7 @@ pub use self::{ codec::{Codec, Frame, Item, Message}, dispatcher::Dispatcher, frame::Parser, - proto::{hash_key, CloseCode, CloseReason, OpCode}, + proto::{hash_key, CloseCode, CloseReason, OpCode, RsvBits}, }; /// WebSocket protocol errors. diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs index 27815eaf248..6941f5828cd 100644 --- a/actix-http/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -222,6 +222,25 @@ impl> From<(CloseCode, T)> for CloseReason { } } +bitflags::bitflags! { + /// RSV bits defined in [RFC 6455 §5.2]. + /// Reserved for extensions and should be set to zero if no extensions are applicable. + /// + /// [RFC 6455]: https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + #[derive(Debug, Eq, PartialEq, Clone, Copy)] + pub struct RsvBits: u8 { + const RSV1 = 0b0000_0100; + const RSV2 = 0b0000_0010; + const RSV3 = 0b0000_0001; + } +} + +impl Default for RsvBits { + fn default() -> Self { + Self::empty() + } +} + /// The WebSocket GUID as stated in the spec. /// See . static WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; From c97b6ac0d48ac4f5ebe94689ef904b2cc0783bf5 Mon Sep 17 00:00:00 2001 From: Park Joon-Kyu Date: Wed, 6 Nov 2024 21:57:48 +0900 Subject: [PATCH 2/8] implement DEFLATE compression --- actix-http/Cargo.toml | 31 ++- actix-http/src/ws/codec.rs | 391 ++++++++++++++++++++------ actix-http/src/ws/deflate.rs | 516 +++++++++++++++++++++++++++++++++++ actix-http/src/ws/frame.rs | 43 ++- actix-http/src/ws/mod.rs | 71 +++++ awc/Cargo.toml | 24 +- awc/src/ws.rs | 73 ++++- 7 files changed, 1028 insertions(+), 121 deletions(-) create mode 100644 actix-http/src/ws/deflate.rs diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 3f81ea9f000..e94c6745ebd 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,10 +1,7 @@ [package] name = "actix-http" version = "3.9.0" -authors = [ - "Nikolay Kim ", - "Rob Ede ", -] +authors = ["Nikolay Kim ", "Rob Ede "] description = "HTTP types and services for the Actix ecosystem" keywords = ["actix", "http", "framework", "async", "futures"] homepage = "https://actix.rs" @@ -32,6 +29,7 @@ features = [ "compress-brotli", "compress-gzip", "compress-zstd", + "compress-ws-deflate", ] [package.metadata.cargo_check_external_types] @@ -62,12 +60,7 @@ default = [] http2 = ["dep:h2"] # WebSocket protocol implementation -ws = [ - "dep:local-channel", - "dep:base64", - "dep:rand", - "dep:sha1", -] +ws = ["dep:local-channel", "dep:base64", "dep:rand", "dep:sha1"] # TLS via OpenSSL openssl = ["__tls", "actix-tls/accept", "actix-tls/openssl"] @@ -89,8 +82,9 @@ rustls-0_23 = ["__tls", "actix-tls/accept", "actix-tls/rustls-0_23"] # Compression codecs compress-brotli = ["__compress", "dep:brotli"] -compress-gzip = ["__compress", "dep:flate2"] -compress-zstd = ["__compress", "dep:zstd"] +compress-gzip = ["__compress", "dep:flate2"] +compress-zstd = ["__compress", "dep:zstd"] +compress-ws-deflate = ["dep:flate2", "flate2/zlib-default"] # Internal (PRIVATE!) features used to aid testing and checking feature status. # Don't rely on these whatsoever. They are semver-exempt and may disappear at anytime. @@ -112,7 +106,9 @@ bytes = "1" bytestring = "1" derive_more = { version = "1", features = ["as_ref", "deref", "deref_mut", "display", "error", "from"] } encoding_rs = "0.8" -futures-core = { version = "0.3.17", default-features = false, features = ["alloc"] } +futures-core = { version = "0.3.17", default-features = false, features = [ + "alloc", +] } http = "0.2.7" httparse = "1.5.1" httpdate = "1.0.1" @@ -146,14 +142,19 @@ zstd = { version = "0.13", optional = true } [dev-dependencies] actix-http-test = { version = "3", features = ["openssl"] } actix-server = "2" -actix-tls = { version = "3.4", features = ["openssl", "rustls-0_23-webpki-roots"] } +actix-tls = { version = "3.4", features = [ + "openssl", + "rustls-0_23-webpki-roots", +] } actix-web = "4" async-stream = "0.3" criterion = { version = "0.5", features = ["html_reports"] } divan = "0.1.8" env_logger = "0.11" -futures-util = { version = "0.3.17", default-features = false, features = ["alloc"] } +futures-util = { version = "0.3.17", default-features = false, features = [ + "alloc", +] } memchr = "2.4" once_cell = "1.9" rcgen = "0.13" diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index fe2ca43bef5..6c096eb4b68 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -1,9 +1,13 @@ use bitflags::bitflags; use bytes::{Bytes, BytesMut}; use bytestring::ByteString; -use tokio_util::codec::{Decoder, Encoder}; +use tokio_util::codec; use tracing::error; +#[cfg(feature = "compress-ws-deflate")] +use super::deflate::{ + DeflateCompressionContext, DeflateContext, DeflateDecompressionContext, RSV_BIT_DEFLATE_FLAG, +}; use super::{ frame::Parser, proto::{CloseReason, OpCode, RsvBits}, @@ -66,16 +70,6 @@ pub enum Item { Last(Bytes), } -/// WebSocket protocol codec. -#[derive(Debug, Clone)] -pub struct Codec { - flags: Flags, - max_size: usize, - - inbound_rsv_bits: Option, - outbound_rsv_bits: RsvBits, -} - bitflags! { #[derive(Debug, Clone, Copy)] struct Flags: u8 { @@ -85,81 +79,116 @@ bitflags! { } } -impl Codec { - /// Create new WebSocket frames decoder. - pub const fn new() -> Codec { - Codec { - max_size: 65_536, +/// WebSocket message encoder. +#[derive(Debug, Clone)] +pub struct Encoder { + flags: Flags, + + #[cfg(feature = "compress-ws-deflate")] + deflate_compress: Option, +} + +impl Encoder { + /// Create new WebSocket frames encoder. + pub const fn new() -> Encoder { + Encoder { flags: Flags::SERVER, - inbound_rsv_bits: None, - outbound_rsv_bits: RsvBits::empty(), + #[cfg(feature = "compress-ws-deflate")] + deflate_compress: None, } } - /// Set max frame size. - /// - /// By default max size is set to 64KiB. - #[must_use = "This returns the a new Codec, without modifying the original."] - pub fn max_size(mut self, size: usize) -> Self { - self.max_size = size; + /// Create new WebSocket frames encoder with `permessage-deflate` extension support. + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate(compress: DeflateCompressionContext) -> Encoder { + Encoder { + flags: Flags::SERVER, + + deflate_compress: Some(compress), + } + } + + fn set_client_mode(mut self) -> Self { + self.flags = Flags::empty(); self } - /// Set decoder to client mode. - /// - /// By default decoder works in server mode. - #[must_use = "This returns the a new Codec, without modifying the original."] - pub fn client_mode(mut self) -> Self { - self.flags.remove(Flags::SERVER); + #[cfg(feature = "compress-ws-deflate")] + fn set_client_mode_deflate( + mut self, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + self.deflate_compress = self + .deflate_compress + .map(|c| c.reset_with(remote_no_context_takeover, remote_max_window_bits)); self } - /// Get inbound RSV bits. - /// - /// Returns None if there's no received frame yet. - pub fn get_inbound_rsv_bits(&self) -> Option { - self.inbound_rsv_bits + #[cfg(feature = "compress-ws-deflate")] + fn process_payload( + &mut self, + fin: bool, + bytes: Bytes, + ) -> Result<(Bytes, RsvBits), ProtocolError> { + if let Some(compress) = &mut self.deflate_compress { + Ok((compress.compress(fin, bytes)?, RSV_BIT_DEFLATE_FLAG)) + } else { + Ok((bytes, RsvBits::empty())) + } } - /// Set outbound RSV bits. - pub fn set_outbound_rsv_bits(&mut self, rsv_bits: RsvBits) { - self.outbound_rsv_bits = rsv_bits; + #[cfg(not(feature = "compress-ws-deflate"))] + fn process_payload( + &mut self, + _fin: bool, + bytes: Bytes, + ) -> Result<(Bytes, RsvBits), ProtocolError> { + Ok((bytes, RsvBits::empty())) } } -impl Default for Codec { +impl Default for Encoder { fn default() -> Self { Self::new() } } -impl Encoder for Codec { +impl codec::Encoder for Encoder { type Error = ProtocolError; fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { - Message::Text(txt) => Parser::write_message( - dst, - txt, - OpCode::Text, - self.outbound_rsv_bits, - true, - !self.flags.contains(Flags::SERVER), - ), - Message::Binary(bin) => Parser::write_message( - dst, - bin, - OpCode::Binary, - self.outbound_rsv_bits, - true, - !self.flags.contains(Flags::SERVER), - ), + Message::Text(txt) => { + let (bytes, rsv_bits) = self.process_payload(true, txt.into_bytes())?; + + Parser::write_message( + dst, + bytes, + OpCode::Text, + rsv_bits, + true, + !self.flags.contains(Flags::SERVER), + ) + } + Message::Binary(bin) => { + let (bin, rsv_bits) = self.process_payload(true, bin)?; + + Parser::write_message( + dst, + bin, + OpCode::Binary, + rsv_bits, + true, + !self.flags.contains(Flags::SERVER), + ) + } Message::Ping(txt) => Parser::write_message( dst, txt, OpCode::Ping, - self.outbound_rsv_bits, + RsvBits::empty(), true, !self.flags.contains(Flags::SERVER), ), @@ -167,14 +196,14 @@ impl Encoder for Codec { dst, txt, OpCode::Pong, - self.outbound_rsv_bits, + RsvBits::empty(), true, !self.flags.contains(Flags::SERVER), ), Message::Close(reason) => Parser::write_close( dst, reason, - self.outbound_rsv_bits, + RsvBits::empty(), !self.flags.contains(Flags::SERVER), ), Message::Continuation(cont) => match cont { @@ -182,12 +211,14 @@ impl Encoder for Codec { if self.flags.contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { + let (data, rsv_bits) = self.process_payload(false, data)?; + self.flags.insert(Flags::W_CONTINUATION); Parser::write_message( dst, - &data[..], + data, OpCode::Text, - self.outbound_rsv_bits, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -197,12 +228,14 @@ impl Encoder for Codec { if self.flags.contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { + let (data, rsv_bits) = self.process_payload(false, data)?; + self.flags.insert(Flags::W_CONTINUATION); Parser::write_message( dst, - &data[..], + data, OpCode::Binary, - self.outbound_rsv_bits, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -210,11 +243,13 @@ impl Encoder for Codec { } Item::Continue(data) => { if self.flags.contains(Flags::W_CONTINUATION) { + let (data, rsv_bits) = self.process_payload(false, data)?; + Parser::write_message( dst, - &data[..], + data, OpCode::Continue, - self.outbound_rsv_bits, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -225,11 +260,14 @@ impl Encoder for Codec { Item::Last(data) => { if self.flags.contains(Flags::W_CONTINUATION) { self.flags.remove(Flags::W_CONTINUATION); + + let (data, rsv_bits) = self.process_payload(true, data)?; + Parser::write_message( dst, - &data[..], + data, OpCode::Continue, - self.outbound_rsv_bits, + rsv_bits, true, !self.flags.contains(Flags::SERVER), ) @@ -244,21 +282,120 @@ impl Encoder for Codec { } } -impl Decoder for Codec { +/// WebSocket message decoder. +#[derive(Debug, Clone)] +pub struct Decoder { + flags: Flags, + max_size: usize, + + #[cfg(feature = "compress-ws-deflate")] + deflate_decompress: Option, +} + +impl Decoder { + /// Create new WebSocket frames decoder. + pub const fn new() -> Decoder { + Decoder { + flags: Flags::SERVER, + max_size: 65_536, + + #[cfg(feature = "compress-ws-deflate")] + deflate_decompress: None, + } + } + + /// Create new WebSocket frames decoder with `permessage-deflate` extension support. + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate(decompress: DeflateDecompressionContext) -> Decoder { + Decoder { + flags: Flags::SERVER, + max_size: 65_536, + + deflate_decompress: Some(decompress), + } + } + + fn set_client_mode(mut self) -> Self { + self.flags = Flags::empty(); + self + } + + #[cfg(feature = "compress-ws-deflate")] + fn set_client_mode_deflate( + mut self, + local_no_context_takeover: bool, + local_max_window_bits: u8, + ) -> Self { + if let Some(decompress) = &mut self.deflate_decompress { + decompress.reset_with(local_no_context_takeover, local_max_window_bits); + } + + self + } + + fn set_max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + #[cfg(feature = "compress-ws-deflate")] + fn process_payload( + &mut self, + fin: bool, + opcode: OpCode, + rsv_bits: RsvBits, + bytes: Option, + ) -> Result, ProtocolError> { + if let Some(bytes) = bytes { + if let Some(decompress) = &mut self.deflate_decompress { + Ok(Some(decompress.decompress(fin, opcode, rsv_bits, bytes)?)) + } else { + Ok(Some(bytes)) + } + } else { + Ok(None) + } + } + + #[cfg(not(feature = "compress-ws-deflate"))] + fn process_payload( + &mut self, + _fin: bool, + _opcode: OpCode, + _rsv_bits: RsvBits, + bytes: Option, + ) -> Result, ProtocolError> { + Ok(bytes) + } +} + +impl Default for Decoder { + fn default() -> Self { + Self::new() + } +} + +impl codec::Decoder for Decoder { type Item = Frame; type Error = ProtocolError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) { Ok(Some((finished, opcode, rsv_bits, payload))) => { - self.inbound_rsv_bits = Some(rsv_bits); + let payload = self.process_payload( + finished, + opcode, + rsv_bits, + payload.map(BytesMut::freeze), + )?; + // continuation is not supported if !finished { return match opcode { OpCode::Continue => { if self.flags.contains(Flags::CONTINUATION) { Ok(Some(Frame::Continuation(Item::Continue( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationNotStarted) @@ -268,7 +405,7 @@ impl Decoder for Codec { if !self.flags.contains(Flags::CONTINUATION) { self.flags.insert(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstBinary( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationStarted) @@ -278,7 +415,7 @@ impl Decoder for Codec { if !self.flags.contains(Flags::CONTINUATION) { self.flags.insert(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstText( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationStarted) @@ -296,7 +433,7 @@ impl Decoder for Codec { if self.flags.contains(Flags::CONTINUATION) { self.flags.remove(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::Last( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationNotStarted) @@ -311,18 +448,10 @@ impl Decoder for Codec { Ok(Some(Frame::Close(None))) } } - OpCode::Ping => Ok(Some(Frame::Ping( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Pong => Ok(Some(Frame::Pong( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Binary => Ok(Some(Frame::Binary( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Text => Ok(Some(Frame::Text( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), + OpCode::Ping => Ok(Some(Frame::Ping(payload.unwrap_or_else(Bytes::new)))), + OpCode::Pong => Ok(Some(Frame::Pong(payload.unwrap_or_else(Bytes::new)))), + OpCode::Binary => Ok(Some(Frame::Binary(payload.unwrap_or_else(Bytes::new)))), + OpCode::Text => Ok(Some(Frame::Text(payload.unwrap_or_else(Bytes::new)))), } } Ok(None) => Ok(None), @@ -330,3 +459,95 @@ impl Decoder for Codec { } } } + +/// WebSocket protocol codec. +#[derive(Debug, Default, Clone)] +pub struct Codec { + encoder: Encoder, + decoder: Decoder, +} + +impl Codec { + /// Create new WebSocket frames codec. + pub fn new() -> Codec { + Codec { + encoder: Encoder::new(), + decoder: Decoder::new(), + } + } + + /// Create new WebSocket frames codec with DEFLATE compression. + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate(context: DeflateContext) -> Codec { + let DeflateContext { + compress, + decompress, + } = context; + + Codec { + encoder: Encoder::new_deflate(compress), + decoder: Decoder::new_deflate(decompress), + } + } + + /// Set max frame size. + /// + /// By default max size is set to 64KiB. + #[must_use = "This returns the a new Codec, without modifying the original."] + pub fn max_size(self, size: usize) -> Self { + let Self { encoder, decoder } = self; + + Codec { + encoder, + decoder: decoder.set_max_size(size), + } + } + + /// Set decoder to client mode. + /// + /// By default decoder works in server mode. + #[must_use = "This returns the a new Codec, without modifying the original."] + pub fn client_mode(self) -> Self { + let Self { + mut encoder, + mut decoder, + } = self; + + encoder = encoder.set_client_mode(); + decoder = decoder.set_client_mode(); + #[cfg(feature = "compress-ws-deflate")] + { + if let Some(decoder) = &decoder.deflate_decompress { + encoder = encoder.set_client_mode_deflate( + decoder.local_no_context_takeover, + decoder.local_max_window_bits, + ); + } + if let Some(encoder) = &encoder.deflate_compress { + decoder = decoder.set_client_mode_deflate( + encoder.remote_no_context_takeover, + encoder.remote_max_window_bits, + ); + } + } + + Self { encoder, decoder } + } +} + +impl codec::Decoder for Codec { + type Item = Frame; + type Error = ProtocolError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + self.decoder.decode(src) + } +} + +impl codec::Encoder for Codec { + type Error = ProtocolError; + + fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.encoder.encode(item, dst) + } +} diff --git a/actix-http/src/ws/deflate.rs b/actix-http/src/ws/deflate.rs new file mode 100644 index 00000000000..bb256c18f2c --- /dev/null +++ b/actix-http/src/ws/deflate.rs @@ -0,0 +1,516 @@ +use std::convert::Infallible; + +use bytes::Bytes; +pub use flate2::Compression as DeflateCompressionLevel; + +use super::{OpCode, ProtocolError, RsvBits}; +use crate::header::{HeaderName, HeaderValue, TryIntoHeaderPair, SEC_WEBSOCKET_EXTENSIONS}; + +const MAX_WINDOW_BITS_RANGE: std::ops::RangeInclusive = 9..=15; +const DEFAULT_WINDOW_BITS: u8 = 15; +const BUF_SIZE: usize = 2048; + +pub(super) const RSV_BIT_DEFLATE_FLAG: RsvBits = RsvBits::RSV1; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum DeflateHandshakeError { + UnknownWebSocketParameters, + DuplicateParameter(&'static str), + MaxWindowBitsOutOfRange, + NoSuitableConfigurationFound, +} + +impl std::fmt::Display for DeflateHandshakeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::UnknownWebSocketParameters => { + write!(f, "Unknown WebSocket `permessage-deflate` parameters.") + } + Self::DuplicateParameter(p) => { + write!(f, "Duplicate WebSocket `permessage-deflate` parameter: {p}") + } + Self::MaxWindowBitsOutOfRange => write!( + f, + "Max window bits out of range. ({} to {} expected)", + MAX_WINDOW_BITS_RANGE.start(), + MAX_WINDOW_BITS_RANGE.end() + ), + Self::NoSuitableConfigurationFound => write!( + f, + "No suitable WebSocket `permedia-deflate` parameter configurations found." + ), + } + } +} + +impl std::error::Error for DeflateHandshakeError {} + +#[derive(Copy, Clone, Debug)] +pub enum ClientMaxWindowBits { + NotSpecified, + Specified(u8), +} + +#[derive(Debug, Clone, Default)] +pub struct DeflateSessionParameters { + pub server_no_context_takeover: bool, + pub client_no_context_takeover: bool, + pub server_max_window_bits: Option, + pub client_max_window_bits: Option, +} + +impl TryIntoHeaderPair for DeflateSessionParameters { + type Error = Infallible; + + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + let mut response_extension = vec!["permessage-deflate".to_owned()]; + + if self.server_no_context_takeover { + response_extension.push("server_no_context_takeover".to_owned()); + } + if self.client_no_context_takeover { + response_extension.push("client_no_context_takeover".to_owned()); + } + if let Some(server_max_window_bits) = self.server_max_window_bits { + response_extension.push(format!("server_max_window_bits={server_max_window_bits}")); + } + if let Some(client_max_window_bits) = self.client_max_window_bits { + match client_max_window_bits { + ClientMaxWindowBits::NotSpecified => { + response_extension.push("client_max_window_bits".to_string()); + } + ClientMaxWindowBits::Specified(bits) => { + response_extension.push(format!("client_max_window_bits={bits}")); + } + } + } + + Ok(( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_str(&response_extension.join("; ")).unwrap(), + )) + } +} + +impl DeflateSessionParameters { + fn parse<'a>( + extension_frags: impl Iterator, + ) -> Result { + let mut client_max_window_bits = None; + let mut server_max_window_bits = None; + let mut client_no_context_takeover = None; + let mut server_no_context_takeover = None; + + let mut unknown_parameters = vec![]; + + for fragment in extension_frags { + if fragment == "client_max_window_bits" { + if client_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_max_window_bits", + )); + } + client_max_window_bits = Some(ClientMaxWindowBits::NotSpecified); + } else if let Some(value) = fragment.strip_prefix("client_max_window_bits=") { + if client_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_max_window_bits", + )); + } + let bits = value + .parse::() + .map_err(|_| DeflateHandshakeError::MaxWindowBitsOutOfRange)?; + if !MAX_WINDOW_BITS_RANGE.contains(&bits) { + return Err(DeflateHandshakeError::MaxWindowBitsOutOfRange); + } + client_max_window_bits = Some(ClientMaxWindowBits::Specified(bits)); + } else if let Some(value) = fragment.strip_prefix("server_max_window_bits=") { + if server_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "server_max_window_bits", + )); + } + let bits = value + .parse::() + .map_err(|_| DeflateHandshakeError::MaxWindowBitsOutOfRange)?; + if !MAX_WINDOW_BITS_RANGE.contains(&bits) { + return Err(DeflateHandshakeError::MaxWindowBitsOutOfRange); + } + server_max_window_bits = Some(bits); + } else if fragment == "server_no_context_takeover" { + if server_no_context_takeover.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "server_no_context_takeover", + )); + } + server_no_context_takeover = Some(true); + } else if fragment == "client_no_context_takeover" { + if client_no_context_takeover.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_no_context_takeover", + )); + } + client_no_context_takeover = Some(true); + } else { + unknown_parameters.push(fragment.to_owned()); + } + } + + if !unknown_parameters.is_empty() { + Err(DeflateHandshakeError::UnknownWebSocketParameters) + } else { + Ok(DeflateSessionParameters { + server_no_context_takeover: server_no_context_takeover.unwrap_or(false), + client_no_context_takeover: client_no_context_takeover.unwrap_or(false), + server_max_window_bits, + client_max_window_bits, + }) + } + } + + pub fn from_extension_header(header_value: &str) -> Vec> { + let mut results = vec![]; + for extension in header_value.split(',').map(str::trim) { + let mut fragments = extension.split(';').map(str::trim); + if fragments.next() == Some("permessage-deflate") { + results.push(Self::parse(fragments)); + } + } + + results + } + + pub fn create_context( + &self, + compression_level: Option, + is_client_mode: bool, + ) -> DeflateContext { + let client_max_window_bits = + if let Some(ClientMaxWindowBits::Specified(value)) = self.client_max_window_bits { + value + } else { + DEFAULT_WINDOW_BITS + }; + let server_max_window_bits = self.server_max_window_bits.unwrap_or(DEFAULT_WINDOW_BITS); + + let (remote_no_context_takeover, remote_max_window_bits) = if is_client_mode { + (self.server_no_context_takeover, server_max_window_bits) + } else { + (self.client_no_context_takeover, client_max_window_bits) + }; + + let (local_no_context_takeover, local_max_window_bits) = if is_client_mode { + (self.client_no_context_takeover, client_max_window_bits) + } else { + (self.server_no_context_takeover, server_max_window_bits) + }; + + DeflateContext { + compress: DeflateCompressionContext::new( + compression_level, + remote_no_context_takeover, + remote_max_window_bits, + ), + decompress: DeflateDecompressionContext::new( + local_no_context_takeover, + local_max_window_bits, + ), + } + } +} + +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct DeflateServerConfig { + pub compression_level: Option, + + pub server_no_context_takeover: bool, + pub client_no_context_takeover: bool, + pub server_max_window_bits: Option, + pub client_max_window_bits: Option, +} + +impl DeflateServerConfig { + pub fn negotiate(&self, params: DeflateSessionParameters) -> DeflateSessionParameters { + let server_no_context_takeover = + if self.server_no_context_takeover && !params.server_no_context_takeover { + true + } else { + params.server_no_context_takeover + }; + + let client_no_context_takeover = + if self.client_no_context_takeover && !params.client_no_context_takeover { + true + } else { + params.client_no_context_takeover + }; + + let server_max_window_bits = + match (self.server_max_window_bits, params.server_max_window_bits) { + (None, value) => value, + (Some(config_value), None) => Some(config_value), + (Some(config_value), Some(value)) => { + if value > config_value { + Some(config_value) + } else { + Some(value) + } + } + }; + + let client_max_window_bits = + match (self.client_max_window_bits, params.client_max_window_bits) { + (None, None | Some(ClientMaxWindowBits::NotSpecified)) => None, + (None, Some(ClientMaxWindowBits::Specified(value))) => Some(value), + (Some(_), None) => None, + (Some(config_value), Some(ClientMaxWindowBits::NotSpecified)) => Some(config_value), + (Some(config_value), Some(ClientMaxWindowBits::Specified(value))) => { + if value > config_value { + Some(config_value) + } else { + Some(value) + } + } + }; + + DeflateSessionParameters { + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits: client_max_window_bits.map(ClientMaxWindowBits::Specified), + } + } +} + +#[derive(Debug)] +pub struct DeflateDecompressionContext { + pub(super) local_no_context_takeover: bool, + pub(super) local_max_window_bits: u8, + + decompress: flate2::Decompress, + + decode_continuation: bool, + total_bytes_written: u64, + total_bytes_read: u64, +} + +impl Clone for DeflateDecompressionContext { + fn clone(&self) -> Self { + // Create with empty context because the context is not meant to be cloned. + Self::new(self.local_no_context_takeover, self.local_max_window_bits) + } +} + +impl DeflateDecompressionContext { + fn new(local_no_context_takeover: bool, local_max_window_bits: u8) -> Self { + Self { + local_no_context_takeover, + local_max_window_bits, + + decompress: flate2::Decompress::new_with_window_bits(false, local_max_window_bits), + + decode_continuation: false, + total_bytes_written: 0, + total_bytes_read: 0, + } + } + + pub fn reset_with(&mut self, local_no_context_takeover: bool, local_max_window_bits: u8) { + *self = Self::new(local_no_context_takeover, local_max_window_bits); + } + + pub fn decompress( + &mut self, + fin: bool, + opcode: OpCode, + rsv: RsvBits, + payload: Bytes, + ) -> Result { + if !matches!(opcode, OpCode::Text | OpCode::Binary | OpCode::Continue) + || !rsv.contains(RSV_BIT_DEFLATE_FLAG) + { + return Ok(payload); + } + + if opcode == OpCode::Continue { + if !self.decode_continuation { + return Ok(payload); + } + } else { + self.decode_continuation = true; + } + + let mut output: Vec = vec![]; + let mut buf = [0u8; BUF_SIZE]; + + let mut offset: usize = 0; + loop { + let res = if offset >= payload.len() { + self.decompress + .decompress( + &[0x00, 0x00, 0xff, 0xff], + &mut buf, + flate2::FlushDecompress::Finish, + ) + .map_err(|e| { + self.reset(); + ProtocolError::Io(e.into()) + })? + } else { + self.decompress + .decompress(&payload[offset..], &mut buf, flate2::FlushDecompress::None) + .map_err(|e| { + self.reset(); + ProtocolError::Io(e.into()) + })? + }; + + let read = self.decompress.total_in() - self.total_bytes_read; + let written = self.decompress.total_out() - self.total_bytes_written; + + offset += read as usize; + self.total_bytes_read += read; + if written > 0 { + output.extend(buf.iter().take(written as usize)); + self.total_bytes_written += written; + } + + if res != flate2::Status::Ok { + break; + } + } + + if fin { + self.decode_continuation = false; + if self.local_no_context_takeover { + self.reset(); + } + } + + Ok(output.into()) + } + + pub(super) fn reset(&mut self) { + self.decompress.reset(false); + self.total_bytes_read = 0; + self.total_bytes_written = 0; + } +} + +#[derive(Debug)] +pub struct DeflateCompressionContext { + compression_level: flate2::Compression, + pub(super) remote_no_context_takeover: bool, + pub(super) remote_max_window_bits: u8, + + compress: flate2::Compress, + total_bytes_written: u64, + total_bytes_read: u64, +} + +impl Clone for DeflateCompressionContext { + fn clone(&self) -> Self { + // Create with empty context because the context is not meant to be cloned. + Self::new( + Some(self.compression_level), + self.remote_no_context_takeover, + self.remote_max_window_bits, + ) + } +} + +impl DeflateCompressionContext { + fn new( + compression_level: Option, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + let compression_level = compression_level.unwrap_or_default(); + + Self { + compression_level, + remote_no_context_takeover, + remote_max_window_bits, + + compress: flate2::Compress::new_with_window_bits( + compression_level, + false, + remote_max_window_bits, + ), + + total_bytes_written: 0, + total_bytes_read: 0, + } + } + + pub fn reset_with( + mut self, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + self = Self::new( + Some(self.compression_level), + remote_no_context_takeover, + remote_max_window_bits, + ); + + self + } + + pub fn compress(&mut self, fin: bool, payload: Bytes) -> Result { + let mut output = vec![]; + let mut buf = [0u8; BUF_SIZE]; + + loop { + let total_in = self.compress.total_in() - self.total_bytes_read; + let res = if total_in >= payload.len() as u64 { + self.compress + .compress(&[], &mut buf, flate2::FlushCompress::Sync) + .map_err(|e| { + self.reset(); + ProtocolError::Io(e.into()) + })? + } else { + self.compress + .compress(&payload, &mut buf, flate2::FlushCompress::None) + .map_err(|e| { + self.reset(); + ProtocolError::Io(e.into()) + })? + }; + + let written = self.compress.total_out() - self.total_bytes_written; + if written > 0 { + output.extend(buf.iter().take(written as usize)); + self.total_bytes_written += written; + } + + if res != flate2::Status::Ok { + break; + } + } + self.total_bytes_read = self.compress.total_in(); + + if output.iter().rev().take(4).eq(&[0xff, 0xff, 0x00, 0x00]) { + output.drain(output.len() - 4..); + } + + if fin && self.remote_no_context_takeover { + self.reset(); + } + + Ok(output.into()) + } + + fn reset(&mut self) { + self.compress.reset(); + self.total_bytes_read = 0; + self.total_bytes_written = 0; + } +} + +#[derive(Debug)] +pub struct DeflateContext { + pub compress: DeflateCompressionContext, + pub decompress: DeflateDecompressionContext, +} diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index e166f1cf516..0bd64a46522 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -237,18 +237,22 @@ mod tests { struct F { finished: bool, opcode: OpCode, + rsv_bits: RsvBits, payload: Bytes, } - fn is_none(frm: &Result)>, ProtocolError>) -> bool { + fn is_none( + frm: &Result)>, ProtocolError>, + ) -> bool { matches!(*frm, Ok(None)) } - fn extract(frm: Result)>, ProtocolError>) -> F { + fn extract(frm: Result)>, ProtocolError>) -> F { match frm { - Ok(Some((finished, opcode, payload))) => F { + Ok(Some((finished, opcode, rsv_bits, payload))) => F { finished, opcode, + rsv_bits, payload: payload .map(|b| b.freeze()) .unwrap_or_else(|| Bytes::from("")), @@ -269,6 +273,17 @@ mod tests { assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1"[..]); + + let mut buf = BytesMut::from(&[0b1111_0001u8, 0b0000_0001u8][..]); + buf.extend(b"2"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"2"[..]); + assert!(frame.rsv_bits.contains(RsvBits::RSV1)); + assert!(frame.rsv_bits.contains(RsvBits::RSV2)); + assert!(frame.rsv_bits.contains(RsvBits::RSV3)); } #[test] @@ -377,7 +392,14 @@ mod tests { #[test] fn test_ping_frame() { let mut buf = BytesMut::new(); - Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); + Parser::write_message( + &mut buf, + Vec::from("data"), + OpCode::Ping, + RsvBits::empty(), + true, + false, + ); let mut v = vec![137u8, 4u8]; v.extend(b"data"); @@ -387,7 +409,14 @@ mod tests { #[test] fn test_pong_frame() { let mut buf = BytesMut::new(); - Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); + Parser::write_message( + &mut buf, + Vec::from("data"), + OpCode::Pong, + RsvBits::empty(), + true, + false, + ); let mut v = vec![138u8, 4u8]; v.extend(b"data"); @@ -398,7 +427,7 @@ mod tests { fn test_close_frame() { let mut buf = BytesMut::new(); let reason = (CloseCode::Normal, "data"); - Parser::write_close(&mut buf, Some(reason.into()), false); + Parser::write_close(&mut buf, Some(reason.into()), RsvBits::empty(), false); let mut v = vec![136u8, 6u8, 3u8, 232u8]; v.extend(b"data"); @@ -408,7 +437,7 @@ mod tests { #[test] fn test_empty_close_frame() { let mut buf = BytesMut::new(); - Parser::write_close(&mut buf, None, false); + Parser::write_close(&mut buf, None, RsvBits::empty(), false); assert_eq!(&buf[..], &vec![0x88, 0x00][..]); } } diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 811e634747c..90ef4a932c7 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -11,11 +11,15 @@ use http::{header, Method, StatusCode}; use crate::{body::BoxBody, header::HeaderValue, RequestHead, Response, ResponseBuilder}; mod codec; +#[cfg(feature = "compress-ws-deflate")] +mod deflate; mod dispatcher; mod frame; mod mask; mod proto; +#[cfg(feature = "compress-ws-deflate")] +pub use self::deflate::{DeflateCompressionLevel, DeflateServerConfig, DeflateSessionParameters}; pub use self::{ codec::{Codec, Frame, Item, Message}, dispatcher::Dispatcher, @@ -93,6 +97,11 @@ pub enum HandshakeError { /// WebSocket key is not set or wrong. #[display("unknown WebSocket key")] BadWebsocketKey, + + /// Invalid `permessage-deflate` request. + #[cfg(feature = "compress-ws-deflate")] + #[display(fmt = "invalid WebSocket `permessage-deflate` extension request")] + BadDeflateRequest(deflate::DeflateHandshakeError), } impl From for Response { @@ -135,6 +144,13 @@ impl From for Response { res.head_mut().reason = Some("Handshake error"); res } + + #[cfg(feature = "compress-ws-deflate")] + HandshakeError::BadDeflateRequest(_) => { + let mut res = Response::bad_request(); + res.head_mut().reason = Some("Invalid permessage-deflate request"); + res + } } } } @@ -151,6 +167,60 @@ pub fn handshake(req: &RequestHead) -> Result { Ok(handshake_response(req)) } +/// Verify WebSocket handshake request with DEFLATE compression configurations. +#[cfg(feature = "compress-ws-deflate")] +pub fn handshake_with_deflate( + req: &RequestHead, + config: &deflate::DeflateServerConfig, +) -> Result<(ResponseBuilder, Option), HandshakeError> { + verify_handshake(req)?; + + let mut available_configurations = vec![]; + for header in req.headers().get_all(header::SEC_WEBSOCKET_EXTENSIONS) { + let Ok(header_str) = header.to_str() else { + continue; + }; + + available_configurations.extend(deflate::DeflateSessionParameters::from_extension_header( + header_str, + )); + } + + let mut selected_config = None; + let mut selected_error = None; + for config in available_configurations { + match config { + Ok(v) => { + selected_config = Some(v); + break; + } + Err(e) => { + if selected_error.is_none() { + selected_error = Some(e); + } else { + selected_error = + Some(deflate::DeflateHandshakeError::NoSuitableConfigurationFound); + } + } + } + } + + if let Some(selected_error) = selected_error { + Err(HandshakeError::BadDeflateRequest(selected_error)) + } else { + let mut response = handshake_response(req); + + if let Some(selected_config) = selected_config { + let param = config.negotiate(selected_config); + let context = param.create_context(config.compression_level, false); + response.insert_header(param); + Ok((response, Some(context))) + } else { + Ok((response, None)) + } + } +} + /// Verify WebSocket handshake request. pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { // WebSocket accepts only GET @@ -196,6 +266,7 @@ pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { return Err(HandshakeError::BadWebsocketKey); } + Ok(()) } diff --git a/awc/Cargo.toml b/awc/Cargo.toml index c09f32ac862..95db5e761b1 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -63,9 +63,15 @@ rustls-0_20 = ["tls-rustls-0_20", "actix-tls/rustls-0_20"] # TLS via Rustls v0.21 rustls-0_21 = ["tls-rustls-0_21", "actix-tls/rustls-0_21"] # TLS via Rustls v0.22 (WebPKI roots) -rustls-0_22-webpki-roots = ["tls-rustls-0_22", "actix-tls/rustls-0_22-webpki-roots"] +rustls-0_22-webpki-roots = [ + "tls-rustls-0_22", + "actix-tls/rustls-0_22-webpki-roots", +] # TLS via Rustls v0.22 (Native roots) -rustls-0_22-native-roots = ["tls-rustls-0_22", "actix-tls/rustls-0_22-native-roots"] +rustls-0_22-native-roots = [ + "tls-rustls-0_22", + "actix-tls/rustls-0_22-native-roots", +] # TLS via Rustls v0.23 rustls-0_23 = ["tls-rustls-0_23", "actix-tls/rustls-0_23"] # TLS via Rustls v0.23 (WebPKI roots) @@ -79,6 +85,8 @@ compress-brotli = ["actix-http/compress-brotli", "__compress"] compress-gzip = ["actix-http/compress-gzip", "__compress"] # Zstd algorithm content-encoding support compress-zstd = ["actix-http/compress-zstd", "__compress"] +# Deflate compression for WebSocket +compress-ws-deflate = ["actix-http/compress-ws-deflate"] # Cookie parsing and cookie jar cookies = ["dep:cookie"] @@ -112,7 +120,7 @@ futures-util = { version = "0.3.17", default-features = false, features = ["allo h2 = "0.3.26" http = "0.2.7" itoa = "1" -log =" 0.4" +log = " 0.4" mime = "0.3" percent-encoding = "2.1" pin-project-lite = "0.2" @@ -125,8 +133,12 @@ tokio = { version = "1.24.2", features = ["sync"] } cookie = { version = "0.16", features = ["percent-encode"], optional = true } tls-openssl = { package = "openssl", version = "0.10.55", optional = true } -tls-rustls-0_20 = { package = "rustls", version = "0.20", optional = true, features = ["dangerous_configuration"] } -tls-rustls-0_21 = { package = "rustls", version = "0.21", optional = true, features = ["dangerous_configuration"] } +tls-rustls-0_20 = { package = "rustls", version = "0.20", optional = true, features = [ + "dangerous_configuration", +] } +tls-rustls-0_21 = { package = "rustls", version = "0.21", optional = true, features = [ + "dangerous_configuration", +] } tls-rustls-0_22 = { package = "rustls", version = "0.22", optional = true } tls-rustls-0_23 = { package = "rustls", version = "0.23", optional = true, default-features = false } @@ -151,7 +163,7 @@ rcgen = "0.13" rustls-pemfile = "2" tokio = { version = "1.24.2", features = ["rt-multi-thread", "macros"] } zstd = "0.13" -tls-rustls-0_23 = { package = "rustls", version = "0.23" } # add rustls 0.23 with default features to make aws_lc_rs work in tests +tls-rustls-0_23 = { package = "rustls", version = "0.23" } # add rustls 0.23 with default features to make aws_lc_rs work in tests [lints] workspace = true diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 760331e9d6a..376dccaa9bb 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -30,6 +30,8 @@ use std::{fmt, net::SocketAddr, str}; use actix_codec::Framed; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; +#[cfg(feature = "compress-ws-deflate")] +pub use actix_http::ws::{DeflateCompressionLevel, DeflateSessionParameters}; use actix_http::{ws, Payload, RequestHead}; use actix_rt::time::timeout; use actix_service::Service as _; @@ -59,6 +61,9 @@ pub struct WebsocketsRequest { server_mode: bool, config: ClientConfig, + #[cfg(feature = "compress-ws-deflate")] + deflate_compression_level: Option, + #[cfg(feature = "cookies")] cookies: Option, } @@ -94,6 +99,8 @@ impl WebsocketsRequest { protocols: None, max_size: 65_536, server_mode: false, + #[cfg(feature = "compress-ws-deflate")] + deflate_compression_level: None, #[cfg(feature = "cookies")] cookies: None, } @@ -249,6 +256,22 @@ impl WebsocketsRequest { self.header(AUTHORIZATION, format!("Bearer {}", token)) } + /// Enable DEFLATE compression + #[cfg(feature = "compress-ws-deflate")] + pub fn deflate( + mut self, + compression_level: Option, + params: DeflateSessionParameters, + ) -> Self { + use actix_http::header::TryIntoHeaderPair; + // Assume session parameters are always valid. + let (key, value) = params.try_into_pair().unwrap(); + + self.deflate_compression_level = compression_level; + + self.header(key, value) + } + /// Complete request construction and connect to a WebSocket server. pub async fn connect( mut self, @@ -409,17 +432,51 @@ impl WebsocketsRequest { return Err(WsClientError::MissingWebSocketAcceptHeader); }; - // response and ws framed - Ok(( - ClientResponse::new(head, Payload::None), - framed.into_map_codec(|_| { + #[cfg(feature = "compress-ws-deflate")] + let framed = { + let selected_parameter = head + .headers + .get_all(header::SEC_WEBSOCKET_EXTENSIONS) + .filter_map(|header| { + if let Ok(header_str) = header.to_str() { + Some(DeflateSessionParameters::from_extension_header(header_str)) + } else { + None + } + }) + .flatten() + .filter_map(Result::ok) + .next(); + + framed.into_map_codec(move |_| { + let codec = if let Some(parameter) = selected_parameter.clone() { + let context = parameter.create_context(self.deflate_compression_level, false); + Codec::new_deflate(context) + } else { + Codec::new() + } + .max_size(max_size); + if server_mode { - ws::Codec::new().max_size(max_size) + codec } else { - ws::Codec::new().max_size(max_size).client_mode() + codec.client_mode() } - }), - )) + }) + }; + #[cfg(not(feature = "compress-ws-deflate"))] + let framed = framed.into_map_codec(move |_| { + let codec = Codec::new().max_size(max_size); + + if server_mode { + codec + } else { + codec.client_mode() + } + }); + + // response and ws framed + Ok((ClientResponse::new(head, Payload::None), framed)) } } From 5bab1567006c04a553fdc71a7ef242f4df1c44ec Mon Sep 17 00:00:00 2001 From: Park Joon-Kyu Date: Wed, 6 Nov 2024 21:57:48 +0900 Subject: [PATCH 3/8] add docs --- actix-http/src/ws/codec.rs | 39 ++++++++++++++++-- actix-http/src/ws/deflate.rs | 79 ++++++++++++++++++++++-------------- actix-http/src/ws/mod.rs | 10 ++--- 3 files changed, 89 insertions(+), 39 deletions(-) diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index 6c096eb4b68..6add7ada4e7 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -80,7 +80,7 @@ bitflags! { } /// WebSocket message encoder. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Encoder { flags: Flags, @@ -283,7 +283,7 @@ impl codec::Encoder for Encoder { } /// WebSocket message decoder. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Decoder { flags: Flags, max_size: usize, @@ -461,12 +461,45 @@ impl codec::Decoder for Decoder { } /// WebSocket protocol codec. -#[derive(Debug, Default, Clone)] +/// +/// # Note +/// Cloning [`Codec`] creates a new codec with existing configurations +/// and will not preserve the current context. +#[derive(Debug, Default)] pub struct Codec { encoder: Encoder, decoder: Decoder, } +impl Clone for Codec { + fn clone(&self) -> Self { + Self { + encoder: Encoder { + flags: self.encoder.flags & Flags::SERVER, + #[cfg(feature = "compress-ws-deflate")] + deflate_compress: self.encoder.deflate_compress.as_ref().map(|c| { + DeflateCompressionContext::new( + Some(c.compression_level), + c.remote_no_context_takeover, + c.remote_max_window_bits, + ) + }), + }, + decoder: Decoder { + flags: self.decoder.flags & Flags::SERVER, + max_size: self.decoder.max_size, + #[cfg(feature = "compress-ws-deflate")] + deflate_decompress: self.decoder.deflate_decompress.as_ref().map(|d| { + DeflateDecompressionContext::new( + d.local_no_context_takeover, + d.local_max_window_bits, + ) + }), + }, + } + } +} + impl Codec { /// Create new WebSocket frames codec. pub fn new() -> Codec { diff --git a/actix-http/src/ws/deflate.rs b/actix-http/src/ws/deflate.rs index bb256c18f2c..ef3dbf48cf9 100644 --- a/actix-http/src/ws/deflate.rs +++ b/actix-http/src/ws/deflate.rs @@ -1,3 +1,7 @@ +//! WebSocket permessage-deflate compression implementation. +//! +//! + use std::convert::Infallible; use bytes::Bytes; @@ -6,17 +10,31 @@ pub use flate2::Compression as DeflateCompressionLevel; use super::{OpCode, ProtocolError, RsvBits}; use crate::header::{HeaderName, HeaderValue, TryIntoHeaderPair, SEC_WEBSOCKET_EXTENSIONS}; +// NOTE: according to [RFC 7692 §7.1.2.1] window bit size should be within 8..=15 +// but we have to limit the range to 9..=15 because [flate2] only supports window bit within 9..=15. +// +// [RFC 6792]: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1.2.1 +// [flate2]: https://docs.rs/flate2/latest/flate2/struct.Compress.html#method.new_with_window_bits const MAX_WINDOW_BITS_RANGE: std::ops::RangeInclusive = 9..=15; const DEFAULT_WINDOW_BITS: u8 = 15; + const BUF_SIZE: usize = 2048; pub(super) const RSV_BIT_DEFLATE_FLAG: RsvBits = RsvBits::RSV1; +/// DEFLATE compression related handshake errors. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum DeflateHandshakeError { + /// Unknown extension parameter given. UnknownWebSocketParameters, + + /// Duplicate parameter found in single extension statement. DuplicateParameter(&'static str), + + /// Max window bits size out of range. Should be in 9..=15 MaxWindowBitsOutOfRange, + + /// Multiple `permessage-deflate` statements found but failed to negotiate any. NoSuitableConfigurationFound, } @@ -45,17 +63,28 @@ impl std::fmt::Display for DeflateHandshakeError { impl std::error::Error for DeflateHandshakeError {} +/// Maximum size of client's DEFLATE sliding window. #[derive(Copy, Clone, Debug)] pub enum ClientMaxWindowBits { + /// Unspecified. Indicates server should decide its size. NotSpecified, + /// Specified size of client's DEFLATE sliding window size in bits, between 9 and 15. Specified(u8), } +/// DEFLATE negotiation parameter. It can be used both client and server side. +/// At client side, it can be used to pass desired configuration to server. +/// At server side, negotiated parameter will be sent to client with this. +/// This can be represented in HTTP header form as it implements [`TryIntoHeaderPair`] trait. #[derive(Debug, Clone, Default)] pub struct DeflateSessionParameters { + /// Disallow server from take over context. pub server_no_context_takeover: bool, + /// Disallow client from take over context. pub client_no_context_takeover: bool, + /// Maximum size of server's DEFLATE sliding window in bits, between 9 and 15. pub server_max_window_bits: Option, + /// Maximum size of client's DEFLATE sliding window. pub client_max_window_bits: Option, } @@ -219,8 +248,10 @@ impl DeflateSessionParameters { } } +/// Server-side DEFLATE configuration. #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct DeflateServerConfig { + /// DEFLATE compression level. See [`flate2::`] pub compression_level: Option, pub server_no_context_takeover: bool, @@ -294,15 +325,8 @@ pub struct DeflateDecompressionContext { total_bytes_read: u64, } -impl Clone for DeflateDecompressionContext { - fn clone(&self) -> Self { - // Create with empty context because the context is not meant to be cloned. - Self::new(self.local_no_context_takeover, self.local_max_window_bits) - } -} - impl DeflateDecompressionContext { - fn new(local_no_context_takeover: bool, local_max_window_bits: u8) -> Self { + pub(super) fn new(local_no_context_takeover: bool, local_max_window_bits: u8) -> Self { Self { local_no_context_takeover, local_max_window_bits, @@ -315,7 +339,11 @@ impl DeflateDecompressionContext { } } - pub fn reset_with(&mut self, local_no_context_takeover: bool, local_max_window_bits: u8) { + pub(super) fn reset_with( + &mut self, + local_no_context_takeover: bool, + local_max_window_bits: u8, + ) { *self = Self::new(local_no_context_takeover, local_max_window_bits); } @@ -352,16 +380,16 @@ impl DeflateDecompressionContext { &mut buf, flate2::FlushDecompress::Finish, ) - .map_err(|e| { + .map_err(|err| { self.reset(); - ProtocolError::Io(e.into()) + ProtocolError::Io(err.into()) })? } else { self.decompress .decompress(&payload[offset..], &mut buf, flate2::FlushDecompress::None) - .map_err(|e| { + .map_err(|err| { self.reset(); - ProtocolError::Io(e.into()) + ProtocolError::Io(err.into()) })? }; @@ -399,7 +427,7 @@ impl DeflateDecompressionContext { #[derive(Debug)] pub struct DeflateCompressionContext { - compression_level: flate2::Compression, + pub(super) compression_level: flate2::Compression, pub(super) remote_no_context_takeover: bool, pub(super) remote_max_window_bits: u8, @@ -408,19 +436,8 @@ pub struct DeflateCompressionContext { total_bytes_read: u64, } -impl Clone for DeflateCompressionContext { - fn clone(&self) -> Self { - // Create with empty context because the context is not meant to be cloned. - Self::new( - Some(self.compression_level), - self.remote_no_context_takeover, - self.remote_max_window_bits, - ) - } -} - impl DeflateCompressionContext { - fn new( + pub(super) fn new( compression_level: Option, remote_no_context_takeover: bool, remote_max_window_bits: u8, @@ -443,7 +460,7 @@ impl DeflateCompressionContext { } } - pub fn reset_with( + pub(super) fn reset_with( mut self, remote_no_context_takeover: bool, remote_max_window_bits: u8, @@ -466,16 +483,16 @@ impl DeflateCompressionContext { let res = if total_in >= payload.len() as u64 { self.compress .compress(&[], &mut buf, flate2::FlushCompress::Sync) - .map_err(|e| { + .map_err(|err| { self.reset(); - ProtocolError::Io(e.into()) + ProtocolError::Io(err.into()) })? } else { self.compress .compress(&payload, &mut buf, flate2::FlushCompress::None) - .map_err(|e| { + .map_err(|err| { self.reset(); - ProtocolError::Io(e.into()) + ProtocolError::Io(err.into()) })? }; diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 90ef4a932c7..a5521e500bf 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -100,7 +100,7 @@ pub enum HandshakeError { /// Invalid `permessage-deflate` request. #[cfg(feature = "compress-ws-deflate")] - #[display(fmt = "invalid WebSocket `permessage-deflate` extension request")] + #[display("invalid WebSocket `permessage-deflate` extension request")] BadDeflateRequest(deflate::DeflateHandshakeError), } @@ -190,13 +190,13 @@ pub fn handshake_with_deflate( let mut selected_error = None; for config in available_configurations { match config { - Ok(v) => { - selected_config = Some(v); + Ok(config) => { + selected_config = Some(config); break; } - Err(e) => { + Err(err) => { if selected_error.is_none() { - selected_error = Some(e); + selected_error = Some(err); } else { selected_error = Some(deflate::DeflateHandshakeError::NoSuitableConfigurationFound); From a2e3b62926b95a286165fa5d23473c64d6c86c2d Mon Sep 17 00:00:00 2001 From: Park Joon-Kyu Date: Wed, 6 Nov 2024 21:57:48 +0900 Subject: [PATCH 4/8] Add tests --- actix-http/src/lib.rs | 25 +-- actix-http/src/ws/codec.rs | 58 ++++-- actix-http/src/ws/deflate.rs | 360 ++++++++++++++++++++++++++++++++--- actix-http/src/ws/mod.rs | 21 +- actix-http/src/ws/proto.rs | 2 +- 5 files changed, 403 insertions(+), 63 deletions(-) diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index 734e6e1e159..a1b218f26aa 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -2,18 +2,19 @@ //! //! ## Crate Features //! -//! | Feature | Functionality | -//! | ------------------- | ------------------------------------------- | -//! | `http2` | HTTP/2 support via [h2]. | -//! | `openssl` | TLS support via [OpenSSL]. | -//! | `rustls-0_20` | TLS support via rustls 0.20. | -//! | `rustls-0_21` | TLS support via rustls 0.21. | -//! | `rustls-0_22` | TLS support via rustls 0.22. | -//! | `rustls-0_23` | TLS support via [rustls] 0.23. | -//! | `compress-brotli` | Payload compression support: Brotli. | -//! | `compress-gzip` | Payload compression support: Deflate, Gzip. | -//! | `compress-zstd` | Payload compression support: Zstd. | -//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. | +//! | Feature | Functionality | +//! | --------------------- | ------------------------------------------- | +//! | `http2` | HTTP/2 support via [h2]. | +//! | `openssl` | TLS support via [OpenSSL]. | +//! | `rustls-0_20` | TLS support via rustls 0.20. | +//! | `rustls-0_21` | TLS support via rustls 0.21. | +//! | `rustls-0_22` | TLS support via rustls 0.22. | +//! | `rustls-0_23` | TLS support via [rustls] 0.23. | +//! | `compress-brotli` | Payload compression support: Brotli. | +//! | `compress-gzip` | Payload compression support: Deflate, Gzip. | +//! | `compress-zstd` | Payload compression support: Zstd. | +//! | `compress-ws-deflate` | WebSocket DEFLATE compression support. | +//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. | //! //! [h2]: https://crates.io/crates/h2 //! [OpenSSL]: https://crates.io/crates/openssl diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index 6add7ada4e7..526ce23bdc6 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -6,7 +6,7 @@ use tracing::error; #[cfg(feature = "compress-ws-deflate")] use super::deflate::{ - DeflateCompressionContext, DeflateContext, DeflateDecompressionContext, RSV_BIT_DEFLATE_FLAG, + DeflateCompressionContext, DeflateDecompressionContext, RSV_BIT_DEFLATE_FLAG, }; use super::{ frame::Parser, @@ -100,6 +100,8 @@ impl Encoder { } /// Create new WebSocket frames encoder with `permessage-deflate` extension support. + /// Compression context can be made from + /// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context). #[cfg(feature = "compress-ws-deflate")] pub fn new_deflate(compress: DeflateCompressionContext) -> Encoder { Encoder { @@ -109,7 +111,11 @@ impl Encoder { } } - fn set_client_mode(mut self) -> Self { + /// Set encoder to client mode. + /// + /// By default encoder works in server mode. + #[must_use = "This returns the a new Encoder, without modifying the original."] + pub fn client_mode(mut self) -> Self { self.flags = Flags::empty(); self } @@ -305,6 +311,8 @@ impl Decoder { } /// Create new WebSocket frames decoder with `permessage-deflate` extension support. + /// Decompression context can be made from + /// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context). #[cfg(feature = "compress-ws-deflate")] pub fn new_deflate(decompress: DeflateDecompressionContext) -> Decoder { Decoder { @@ -315,7 +323,20 @@ impl Decoder { } } - fn set_client_mode(mut self) -> Self { + /// Set max frame size. + /// + /// By default max size is set to 64KiB. + #[must_use = "This returns the a new Decoder, without modifying the original."] + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Set decoder to client mode. + /// + /// By default decoder works in server mode. + #[must_use = "This returns the a new Decoder, without modifying the original."] + pub fn client_mode(mut self) -> Self { self.flags = Flags::empty(); self } @@ -333,11 +354,6 @@ impl Decoder { self } - fn set_max_size(mut self, size: usize) -> Self { - self.max_size = size; - self - } - #[cfg(feature = "compress-ws-deflate")] fn process_payload( &mut self, @@ -461,10 +477,12 @@ impl codec::Decoder for Decoder { } /// WebSocket protocol codec. +/// This is essentially a combination of [`Encoder`] and [`Decoder`] and +/// actual conversion behaviors are defined in both structs respectively. /// /// # Note /// Cloning [`Codec`] creates a new codec with existing configurations -/// and will not preserve the current context. +/// and will not preserve the context information. #[derive(Debug, Default)] pub struct Codec { encoder: Encoder, @@ -510,13 +528,13 @@ impl Codec { } /// Create new WebSocket frames codec with DEFLATE compression. + /// Both compression and decompression contexts can be made from + /// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context). #[cfg(feature = "compress-ws-deflate")] - pub fn new_deflate(context: DeflateContext) -> Codec { - let DeflateContext { - compress, - decompress, - } = context; - + pub fn new_deflate( + compress: DeflateCompressionContext, + decompress: DeflateDecompressionContext, + ) -> Codec { Codec { encoder: Encoder::new_deflate(compress), decoder: Decoder::new_deflate(decompress), @@ -532,13 +550,13 @@ impl Codec { Codec { encoder, - decoder: decoder.set_max_size(size), + decoder: decoder.max_size(size), } } - /// Set decoder to client mode. + /// Set codec to client mode. /// - /// By default decoder works in server mode. + /// By default codec works in server mode. #[must_use = "This returns the a new Codec, without modifying the original."] pub fn client_mode(self) -> Self { let Self { @@ -546,8 +564,8 @@ impl Codec { mut decoder, } = self; - encoder = encoder.set_client_mode(); - decoder = decoder.set_client_mode(); + encoder = encoder.client_mode(); + decoder = decoder.client_mode(); #[cfg(feature = "compress-ws-deflate")] { if let Some(decoder) = &decoder.deflate_decompress { diff --git a/actix-http/src/ws/deflate.rs b/actix-http/src/ws/deflate.rs index ef3dbf48cf9..dc02f46ef1f 100644 --- a/actix-http/src/ws/deflate.rs +++ b/actix-http/src/ws/deflate.rs @@ -1,6 +1,4 @@ //! WebSocket permessage-deflate compression implementation. -//! -//! use std::convert::Infallible; @@ -13,8 +11,8 @@ use crate::header::{HeaderName, HeaderValue, TryIntoHeaderPair, SEC_WEBSOCKET_EX // NOTE: according to [RFC 7692 §7.1.2.1] window bit size should be within 8..=15 // but we have to limit the range to 9..=15 because [flate2] only supports window bit within 9..=15. // -// [RFC 6792]: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1.2.1 -// [flate2]: https://docs.rs/flate2/latest/flate2/struct.Compress.html#method.new_with_window_bits +// [RFC 6792 §7.1.2.1]: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1.2.1 +// [flate2]: https://docs.rs/flate2/latest/flate2/struct.Compress.html#method.new_with_window_bits const MAX_WINDOW_BITS_RANGE: std::ops::RangeInclusive = 9..=15; const DEFAULT_WINDOW_BITS: u8 = 15; @@ -64,7 +62,7 @@ impl std::fmt::Display for DeflateHandshakeError { impl std::error::Error for DeflateHandshakeError {} /// Maximum size of client's DEFLATE sliding window. -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ClientMaxWindowBits { /// Unspecified. Indicates server should decide its size. NotSpecified, @@ -76,7 +74,7 @@ pub enum ClientMaxWindowBits { /// At client side, it can be used to pass desired configuration to server. /// At server side, negotiated parameter will be sent to client with this. /// This can be represented in HTTP header form as it implements [`TryIntoHeaderPair`] trait. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Eq, PartialEq)] pub struct DeflateSessionParameters { /// Disallow server from take over context. pub server_no_context_takeover: bool, @@ -133,7 +131,9 @@ impl DeflateSessionParameters { let mut unknown_parameters = vec![]; for fragment in extension_frags { - if fragment == "client_max_window_bits" { + if fragment.is_empty() { + continue; + } else if fragment == "client_max_window_bits" { if client_max_window_bits.is_some() { return Err(DeflateHandshakeError::DuplicateParameter( "client_max_window_bits", @@ -197,6 +197,9 @@ impl DeflateSessionParameters { } } + /// Parse desired parameters from `Sec-WebSocket-Extensions` header. + /// The result may contain multiple values as it's possible to pass multiple parameters + /// separated with comma. pub fn from_extension_header(header_value: &str) -> Vec> { let mut results = vec![]; for extension in header_value.split(',').map(str::trim) { @@ -209,11 +212,12 @@ impl DeflateSessionParameters { results } + /// Create compression and decompression context based on the parameter. pub fn create_context( &self, compression_level: Option, is_client_mode: bool, - ) -> DeflateContext { + ) -> (DeflateCompressionContext, DeflateDecompressionContext) { let client_max_window_bits = if let Some(ClientMaxWindowBits::Specified(value)) = self.client_max_window_bits { value @@ -234,33 +238,76 @@ impl DeflateSessionParameters { (self.server_no_context_takeover, server_max_window_bits) }; - DeflateContext { - compress: DeflateCompressionContext::new( + ( + DeflateCompressionContext::new( compression_level, remote_no_context_takeover, remote_max_window_bits, ), - decompress: DeflateDecompressionContext::new( - local_no_context_takeover, - local_max_window_bits, - ), - } + DeflateDecompressionContext::new(local_no_context_takeover, local_max_window_bits), + ) } } /// Server-side DEFLATE configuration. #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct DeflateServerConfig { - /// DEFLATE compression level. See [`flate2::`] + /// DEFLATE compression level. See [`flate2::Compression`] for details. pub compression_level: Option, - + /// Disallow server from take over context. Default is false. pub server_no_context_takeover: bool, + /// Disallow client from take over context. Default is false. pub client_no_context_takeover: bool, + /// Maximum size of server's DEFLATE sliding window in bits, between 9 and 15. Default is 15. pub server_max_window_bits: Option, + /// Maximum size of client's DEFLATE sliding window in bits, between 9 and 15. Default is 15. pub client_max_window_bits: Option, } impl DeflateServerConfig { + /// Negotiate context parameters. + /// Since parameters from the client may be incompatible with the server configuration, + /// actual parameters could be adjusted here. Conversion rules are as follows: + /// + /// ## server_no_context_takeover + /// + /// | Config | Request | Response | + /// | ------ | ------- | --------- | + /// | false | false | false | + /// | false | true | true | + /// | true | false | true | + /// | true | true | true | + /// + /// ## client_no_context_takeover + /// + /// | Config | Request | Response | + /// | ------ | ------- | --------- | + /// | false | false | false | + /// | false | true | true | + /// | true | false | true | + /// | true | true | true | + /// + /// ## server_max_window_bits + /// + /// | Config | Request | Response | + /// | ------------ | ------------ | -------- | + /// | None | None | None | + /// | None | 9 <= R <= 15 | R | + /// | 9 <= C <= 15 | None | C | + /// | 9 <= C <= 15 | 9 <= R <= C | R | + /// | 9 <= C <= 15 | C <= R <= 15 | C | + /// + /// ## client_max_window_bits + /// + /// | Config | Request | Response | + /// | ------------ | ------------ | -------- | + /// | None | None | None | + /// | None | Unspecified | None | + /// | None | 9 <= R <= 15 | R | + /// | 9 <= C <= 15 | None | None | + /// | 9 <= C <= 15 | Unspecified | C | + /// | 9 <= C <= 15 | 9 <= R <= C | R | + /// | 9 <= C <= 15 | C <= R <= 15 | C | pub fn negotiate(&self, params: DeflateSessionParameters) -> DeflateSessionParameters { let server_no_context_takeover = if self.server_no_context_takeover && !params.server_no_context_takeover { @@ -313,6 +360,7 @@ impl DeflateServerConfig { } } +/// DEFLATE decompression context. #[derive(Debug)] pub struct DeflateDecompressionContext { pub(super) local_no_context_takeover: bool, @@ -347,7 +395,7 @@ impl DeflateDecompressionContext { *self = Self::new(local_no_context_takeover, local_max_window_bits); } - pub fn decompress( + pub(super) fn decompress( &mut self, fin: bool, opcode: OpCode, @@ -418,13 +466,14 @@ impl DeflateDecompressionContext { Ok(output.into()) } - pub(super) fn reset(&mut self) { + fn reset(&mut self) { self.decompress.reset(false); self.total_bytes_read = 0; self.total_bytes_written = 0; } } +/// DEFLATE compression context. #[derive(Debug)] pub struct DeflateCompressionContext { pub(super) compression_level: flate2::Compression, @@ -474,7 +523,7 @@ impl DeflateCompressionContext { self } - pub fn compress(&mut self, fin: bool, payload: Bytes) -> Result { + pub(super) fn compress(&mut self, fin: bool, payload: Bytes) -> Result { let mut output = vec![]; let mut buf = [0u8; BUF_SIZE]; @@ -526,8 +575,271 @@ impl DeflateCompressionContext { } } -#[derive(Debug)] -pub struct DeflateContext { - pub compress: DeflateCompressionContext, - pub decompress: DeflateDecompressionContext, +#[cfg(test)] +mod tests { + use crate::body::MessageBody; + + use super::*; + + #[test] + fn test_session_parameters() { + let extension = "abc, def, permessage-deflate"; + assert_eq!( + DeflateSessionParameters::from_extension_header(&extension), + vec![Ok(DeflateSessionParameters::default())] + ); + + let extension = "permessage-deflate; unknown_parameter"; + assert_eq!( + DeflateSessionParameters::from_extension_header(&extension), + vec![Err(DeflateHandshakeError::UnknownWebSocketParameters)] + ); + + let extension = "permessage-deflate; client_max_window_bits=9; client_max_window_bits=10"; + assert_eq!( + DeflateSessionParameters::from_extension_header(&extension), + vec![Err(DeflateHandshakeError::DuplicateParameter( + "client_max_window_bits" + ))] + ); + + let extension = "permessage-deflate; server_max_window_bits=8"; + assert_eq!( + DeflateSessionParameters::from_extension_header(&extension), + vec![Err(DeflateHandshakeError::MaxWindowBitsOutOfRange)] + ); + + let extension = "permessage-deflate; server_max_window_bits=16"; + assert_eq!( + DeflateSessionParameters::from_extension_header(&extension), + vec![Err(DeflateHandshakeError::MaxWindowBitsOutOfRange)] + ); + + let extension = "permessage-deflate; client_max_window_bits; server_max_window_bits=15; \ + client_no_context_takeover; server_no_context_takeover, \ + permessage-deflate; client_max_window_bits=10"; + assert_eq!( + DeflateSessionParameters::from_extension_header(&extension), + vec![ + Ok(DeflateSessionParameters { + server_no_context_takeover: true, + client_no_context_takeover: true, + server_max_window_bits: Some(15), + client_max_window_bits: Some(ClientMaxWindowBits::NotSpecified) + }), + Ok(DeflateSessionParameters { + server_no_context_takeover: false, + client_no_context_takeover: false, + server_max_window_bits: None, + client_max_window_bits: Some(ClientMaxWindowBits::Specified(10)) + }) + ] + ); + } + + #[test] + fn test_compress() { + // With context takeover + + let mut compress = DeflateCompressionContext::new(None, false, 15); + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2@0\x01\0") + ); + + // Without context takeover + + let mut compress = DeflateCompressionContext::new(None, true, 15); + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + + // With continuation + assert_eq!( + compress + .compress(false, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + // Continuation keeps context. + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2@0\x01\0") + ); + // after continuation, context resets + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + } + + #[test] + fn test_decompress() { + // With context takeover + + let mut decompress = DeflateDecompressionContext::new(false, 15); + + // Without RSV1 bit, decompression does not happen. + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::empty(), + Bytes::from_static(b"Hello World") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Control frames (such as ping/pong) are not decompressed + assert_eq!( + decompress + .decompress( + true, + OpCode::Ping, + RsvBits::RSV1, + Bytes::from_static(b"Hello World") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Successful decompression + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Success subsequent decompression + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2@0\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Invalid compression payload + assert!(decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"Hello World") + ) + .is_err()); + + // When there was error, context is reset. + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Without context takeover + + let mut decompress = DeflateDecompressionContext::new(true, 15); + + // Successful decompression + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Context has been reset. + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // With continuation + assert_eq!( + decompress + .decompress( + false, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + // Continuation keeps context. + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2@0\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + // When continuation has finished, context is reset. + assert_eq!( + decompress + .decompress( + false, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + } } diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index a5521e500bf..f08012573b6 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -21,7 +21,7 @@ mod proto; #[cfg(feature = "compress-ws-deflate")] pub use self::deflate::{DeflateCompressionLevel, DeflateServerConfig, DeflateSessionParameters}; pub use self::{ - codec::{Codec, Frame, Item, Message}, + codec::{Codec, Decoder, Encoder, Frame, Item, Message}, dispatcher::Dispatcher, frame::Parser, proto::{hash_key, CloseCode, CloseReason, OpCode, RsvBits}, @@ -169,10 +169,19 @@ pub fn handshake(req: &RequestHead) -> Result { /// Verify WebSocket handshake request with DEFLATE compression configurations. #[cfg(feature = "compress-ws-deflate")] -pub fn handshake_with_deflate( - req: &RequestHead, +pub fn handshake_deflate( config: &deflate::DeflateServerConfig, -) -> Result<(ResponseBuilder, Option), HandshakeError> { + req: &RequestHead, +) -> Result< + ( + ResponseBuilder, + Option<( + deflate::DeflateCompressionContext, + deflate::DeflateDecompressionContext, + )>, + ), + HandshakeError, +> { verify_handshake(req)?; let mut available_configurations = vec![]; @@ -212,9 +221,9 @@ pub fn handshake_with_deflate( if let Some(selected_config) = selected_config { let param = config.negotiate(selected_config); - let context = param.create_context(config.compression_level, false); + let contexts = param.create_context(config.compression_level, false); response.insert_header(param); - Ok((response, Some(context))) + Ok((response, Some(contexts))) } else { Ok((response, None)) } diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs index 6941f5828cd..1bdfcf8f7f6 100644 --- a/actix-http/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -226,7 +226,7 @@ bitflags::bitflags! { /// RSV bits defined in [RFC 6455 §5.2]. /// Reserved for extensions and should be set to zero if no extensions are applicable. /// - /// [RFC 6455]: https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + /// [RFC 6455 §5.2]: https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 #[derive(Debug, Eq, PartialEq, Clone, Copy)] pub struct RsvBits: u8 { const RSV1 = 0b0000_0100; From 8edbcd03d3a6f39af30e3885d2410a98f696e759 Mon Sep 17 00:00:00 2001 From: Park Joon-Kyu Date: Wed, 6 Nov 2024 22:05:45 +0900 Subject: [PATCH 5/8] revert autoformatting --- actix-http/CHANGES.md | 1 + actix-http/Cargo.toml | 29 +++++++++++++++-------------- awc/Cargo.toml | 22 ++++++---------------- 3 files changed, 22 insertions(+), 30 deletions(-) diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 982add26a23..3d924a241cb 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -5,6 +5,7 @@ ### Added - Add `header::CLEAR_SITE_DATA` constant. +- Add DEFLATE compression support for WebSocket. ### Changed diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index e94c6745ebd..704f06c1e06 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "actix-http" version = "3.9.0" -authors = ["Nikolay Kim ", "Rob Ede "] +authors = [ + "Nikolay Kim ", + "Rob Ede ", +] description = "HTTP types and services for the Actix ecosystem" keywords = ["actix", "http", "framework", "async", "futures"] homepage = "https://actix.rs" @@ -60,7 +63,12 @@ default = [] http2 = ["dep:h2"] # WebSocket protocol implementation -ws = ["dep:local-channel", "dep:base64", "dep:rand", "dep:sha1"] +ws = [ + "dep:local-channel", + "dep:base64", + "dep:rand", + "dep:sha1", +] # TLS via OpenSSL openssl = ["__tls", "actix-tls/accept", "actix-tls/openssl"] @@ -82,8 +90,8 @@ rustls-0_23 = ["__tls", "actix-tls/accept", "actix-tls/rustls-0_23"] # Compression codecs compress-brotli = ["__compress", "dep:brotli"] -compress-gzip = ["__compress", "dep:flate2"] -compress-zstd = ["__compress", "dep:zstd"] +compress-gzip = ["__compress", "dep:flate2"] +compress-zstd = ["__compress", "dep:zstd"] compress-ws-deflate = ["dep:flate2", "flate2/zlib-default"] # Internal (PRIVATE!) features used to aid testing and checking feature status. @@ -106,9 +114,7 @@ bytes = "1" bytestring = "1" derive_more = { version = "1", features = ["as_ref", "deref", "deref_mut", "display", "error", "from"] } encoding_rs = "0.8" -futures-core = { version = "0.3.17", default-features = false, features = [ - "alloc", -] } +futures-core = { version = "0.3.17", default-features = false, features = ["alloc"] } http = "0.2.7" httparse = "1.5.1" httpdate = "1.0.1" @@ -142,19 +148,14 @@ zstd = { version = "0.13", optional = true } [dev-dependencies] actix-http-test = { version = "3", features = ["openssl"] } actix-server = "2" -actix-tls = { version = "3.4", features = [ - "openssl", - "rustls-0_23-webpki-roots", -] } +actix-tls = { version = "3.4", features = ["openssl", "rustls-0_23-webpki-roots"] } actix-web = "4" async-stream = "0.3" criterion = { version = "0.5", features = ["html_reports"] } divan = "0.1.8" env_logger = "0.11" -futures-util = { version = "0.3.17", default-features = false, features = [ - "alloc", -] } +futures-util = { version = "0.3.17", default-features = false, features = ["alloc"] } memchr = "2.4" once_cell = "1.9" rcgen = "0.13" diff --git a/awc/Cargo.toml b/awc/Cargo.toml index 95db5e761b1..4ef3c52397d 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -63,15 +63,9 @@ rustls-0_20 = ["tls-rustls-0_20", "actix-tls/rustls-0_20"] # TLS via Rustls v0.21 rustls-0_21 = ["tls-rustls-0_21", "actix-tls/rustls-0_21"] # TLS via Rustls v0.22 (WebPKI roots) -rustls-0_22-webpki-roots = [ - "tls-rustls-0_22", - "actix-tls/rustls-0_22-webpki-roots", -] +rustls-0_22-webpki-roots = ["tls-rustls-0_22", "actix-tls/rustls-0_22-webpki-roots"] # TLS via Rustls v0.22 (Native roots) -rustls-0_22-native-roots = [ - "tls-rustls-0_22", - "actix-tls/rustls-0_22-native-roots", -] +rustls-0_22-native-roots = ["tls-rustls-0_22", "actix-tls/rustls-0_22-native-roots"] # TLS via Rustls v0.23 rustls-0_23 = ["tls-rustls-0_23", "actix-tls/rustls-0_23"] # TLS via Rustls v0.23 (WebPKI roots) @@ -120,7 +114,7 @@ futures-util = { version = "0.3.17", default-features = false, features = ["allo h2 = "0.3.26" http = "0.2.7" itoa = "1" -log = " 0.4" +log = "0.4" mime = "0.3" percent-encoding = "2.1" pin-project-lite = "0.2" @@ -133,12 +127,8 @@ tokio = { version = "1.24.2", features = ["sync"] } cookie = { version = "0.16", features = ["percent-encode"], optional = true } tls-openssl = { package = "openssl", version = "0.10.55", optional = true } -tls-rustls-0_20 = { package = "rustls", version = "0.20", optional = true, features = [ - "dangerous_configuration", -] } -tls-rustls-0_21 = { package = "rustls", version = "0.21", optional = true, features = [ - "dangerous_configuration", -] } +tls-rustls-0_20 = { package = "rustls", version = "0.20", optional = true, features = ["dangerous_configuration"] } +tls-rustls-0_21 = { package = "rustls", version = "0.21", optional = true, features = ["dangerous_configuration"] } tls-rustls-0_22 = { package = "rustls", version = "0.22", optional = true } tls-rustls-0_23 = { package = "rustls", version = "0.23", optional = true, default-features = false } @@ -163,7 +153,7 @@ rcgen = "0.13" rustls-pemfile = "2" tokio = { version = "1.24.2", features = ["rt-multi-thread", "macros"] } zstd = "0.13" -tls-rustls-0_23 = { package = "rustls", version = "0.23" } # add rustls 0.23 with default features to make aws_lc_rs work in tests +tls-rustls-0_23 = { package = "rustls", version = "0.23" } # add rustls 0.23 with default features to make aws_lc_rs work in tests [lints] workspace = true From cc54b2b89d5f707ceae108cd196b1127747d09d3 Mon Sep 17 00:00:00 2001 From: Park Joon-Kyu Date: Wed, 6 Nov 2024 22:39:10 +0900 Subject: [PATCH 6/8] fix fmt, compile failures --- actix-http/src/ws/deflate.rs | 3 +-- awc/src/ws.rs | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/actix-http/src/ws/deflate.rs b/actix-http/src/ws/deflate.rs index dc02f46ef1f..d5528d9880d 100644 --- a/actix-http/src/ws/deflate.rs +++ b/actix-http/src/ws/deflate.rs @@ -577,9 +577,8 @@ impl DeflateCompressionContext { #[cfg(test)] mod tests { - use crate::body::MessageBody; - use super::*; + use crate::body::MessageBody; #[test] fn test_session_parameters() { diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 376dccaa9bb..507bcd3c484 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -450,8 +450,9 @@ impl WebsocketsRequest { framed.into_map_codec(move |_| { let codec = if let Some(parameter) = selected_parameter.clone() { - let context = parameter.create_context(self.deflate_compression_level, false); - Codec::new_deflate(context) + let (compress, decompress) = + parameter.create_context(self.deflate_compression_level, false); + Codec::new_deflate(compress, decompress) } else { Codec::new() } From e3f3c055dbc74d881dd1020d10847da63bd2c739 Mon Sep 17 00:00:00 2001 From: Park Joon-Kyu Date: Wed, 6 Nov 2024 22:43:45 +0900 Subject: [PATCH 7/8] please clippy --- actix-http/src/ws/deflate.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/actix-http/src/ws/deflate.rs b/actix-http/src/ws/deflate.rs index d5528d9880d..a28ae8fb4dd 100644 --- a/actix-http/src/ws/deflate.rs +++ b/actix-http/src/ws/deflate.rs @@ -584,19 +584,19 @@ mod tests { fn test_session_parameters() { let extension = "abc, def, permessage-deflate"; assert_eq!( - DeflateSessionParameters::from_extension_header(&extension), + DeflateSessionParameters::from_extension_header(extension), vec![Ok(DeflateSessionParameters::default())] ); let extension = "permessage-deflate; unknown_parameter"; assert_eq!( - DeflateSessionParameters::from_extension_header(&extension), + DeflateSessionParameters::from_extension_header(extension), vec![Err(DeflateHandshakeError::UnknownWebSocketParameters)] ); let extension = "permessage-deflate; client_max_window_bits=9; client_max_window_bits=10"; assert_eq!( - DeflateSessionParameters::from_extension_header(&extension), + DeflateSessionParameters::from_extension_header(extension), vec![Err(DeflateHandshakeError::DuplicateParameter( "client_max_window_bits" ))] @@ -604,13 +604,13 @@ mod tests { let extension = "permessage-deflate; server_max_window_bits=8"; assert_eq!( - DeflateSessionParameters::from_extension_header(&extension), + DeflateSessionParameters::from_extension_header(extension), vec![Err(DeflateHandshakeError::MaxWindowBitsOutOfRange)] ); let extension = "permessage-deflate; server_max_window_bits=16"; assert_eq!( - DeflateSessionParameters::from_extension_header(&extension), + DeflateSessionParameters::from_extension_header(extension), vec![Err(DeflateHandshakeError::MaxWindowBitsOutOfRange)] ); @@ -618,7 +618,7 @@ mod tests { client_no_context_takeover; server_no_context_takeover, \ permessage-deflate; client_max_window_bits=10"; assert_eq!( - DeflateSessionParameters::from_extension_header(&extension), + DeflateSessionParameters::from_extension_header(extension), vec![ Ok(DeflateSessionParameters { server_no_context_takeover: true, From 5cadd2cec7bf4aad434faa0f0d3c137b36ee4898 Mon Sep 17 00:00:00 2001 From: Park Joon-Kyu Date: Fri, 8 Nov 2024 13:23:07 +0900 Subject: [PATCH 8/8] Doc cleanup --- actix-http/src/ws/deflate.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/actix-http/src/ws/deflate.rs b/actix-http/src/ws/deflate.rs index a28ae8fb4dd..76bfebb350b 100644 --- a/actix-http/src/ws/deflate.rs +++ b/actix-http/src/ws/deflate.rs @@ -9,7 +9,7 @@ use super::{OpCode, ProtocolError, RsvBits}; use crate::header::{HeaderName, HeaderValue, TryIntoHeaderPair, SEC_WEBSOCKET_EXTENSIONS}; // NOTE: according to [RFC 7692 §7.1.2.1] window bit size should be within 8..=15 -// but we have to limit the range to 9..=15 because [flate2] only supports window bit within 9..=15. +// but we have to limit the range to 9..=15 because [flate2] only supports window bit within 9..=15. // // [RFC 6792 §7.1.2.1]: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1.2.1 // [flate2]: https://docs.rs/flate2/latest/flate2/struct.Compress.html#method.new_with_window_bits @@ -64,13 +64,15 @@ impl std::error::Error for DeflateHandshakeError {} /// Maximum size of client's DEFLATE sliding window. #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ClientMaxWindowBits { - /// Unspecified. Indicates server should decide its size. + /// Unspecified. Indicates client will follow server configuration. NotSpecified, /// Specified size of client's DEFLATE sliding window size in bits, between 9 and 15. Specified(u8), } -/// DEFLATE negotiation parameter. It can be used both client and server side. +/// Per-session DEFLATE configuration parameter. +/// +/// It can be used both client and server side. /// At client side, it can be used to pass desired configuration to server. /// At server side, negotiated parameter will be sent to client with this. /// This can be represented in HTTP header form as it implements [`TryIntoHeaderPair`] trait.