diff --git a/src/error.rs b/src/error.rs index 8629efcc..49233945 100644 --- a/src/error.rs +++ b/src/error.rs @@ -121,6 +121,18 @@ impl From for Error { } } +impl From for Error { + fn from(_: http::header::InvalidHeaderValue) -> Self { + Error::Utf8 + } +} + +impl From for Error { + fn from(_: http::header::InvalidHeaderName) -> Self { + Error::Utf8 + } +} + #[cfg(feature = "tls")] impl From for Error { fn from(err: tls::Error) -> Self { diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 7e23af43..8aec86b1 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -4,11 +4,12 @@ use std::borrow::Cow; use std::io::{Read, Write}; use std::marker::PhantomData; +use http::HeaderMap; use httparse::Status; use log::*; use url::Url; -use super::headers::{FromHttparse, Headers, MAX_HEADERS}; +use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; @@ -171,7 +172,10 @@ impl VerifyData { // _Fail the WebSocket Connection_. (RFC 6455) if !response .headers - .header_is_ignore_case("Upgrade", "websocket") + .get("Upgrade") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) { return Err(Error::Protocol( "No \"Upgrade: websocket\" in server reply".into(), @@ -183,7 +187,10 @@ impl VerifyData { // MUST _Fail the WebSocket Connection_. (RFC 6455) if !response .headers - .header_is_ignore_case("Connection", "Upgrade") + .get("Connection") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) { return Err(Error::Protocol( "No \"Connection: upgrade\" in server reply".into(), @@ -195,7 +202,9 @@ impl VerifyData { // Connection_. (RFC 6455) if !response .headers - .header_is("Sec-WebSocket-Accept", &self.accept_key) + .get("Sec-WebSocket-Accept") + .map(|h| h == &self.accept_key) + .unwrap_or(false) { return Err(Error::Protocol( "Key mismatch in Sec-WebSocket-Accept".into(), @@ -225,7 +234,7 @@ pub struct Response { /// HTTP response code of the response. pub code: u16, /// Received headers. - pub headers: Headers, + pub headers: HeaderMap, } impl TryParse for Response { @@ -248,7 +257,7 @@ impl<'h, 'b: 'h> FromHttparse> for Response { } Ok(Response { code: raw.code.expect("Bug: no HTTP response code"), - headers: Headers::from_httparse(raw.headers)?, + headers: HeaderMap::from_httparse(raw.headers)?, }) } } @@ -287,9 +296,6 @@ mod tests { const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); assert_eq!(resp.code, 200); - assert_eq!( - resp.headers.find_first("Content-Type"), - Some(&b"text/html"[..]) - ); + assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],); } } diff --git a/src/handshake/headers.rs b/src/handshake/headers.rs index 097b22bd..ba129540 100644 --- a/src/handshake/headers.rs +++ b/src/handshake/headers.rs @@ -1,8 +1,7 @@ //! HTTP Request and response header handling. -use std::slice; -use std::str::from_utf8; - +use http; +use http::header::{HeaderMap, HeaderName, HeaderValue}; use httparse; use httparse::Status; @@ -12,90 +11,31 @@ use crate::error::Result; /// Limit for the number of header lines. pub const MAX_HEADERS: usize = 124; -/// HTTP request or response headers. -#[derive(Debug)] -pub struct Headers { - data: Vec<(String, Box<[u8]>)>, +/// Trait to convert raw objects into HTTP parseables. +pub(crate) trait FromHttparse: Sized { + /// Convert raw object into parsed HTTP headers. + fn from_httparse(raw: T) -> Result; } -impl Headers { - /// Get first header with the given name, if any. - pub fn find_first(&self, name: &str) -> Option<&[u8]> { - self.find(name).next() - } - - /// Iterate over all headers with the given name. - pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> { - HeadersIter { - name, - iter: self.data.iter(), +impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for HeaderMap { + fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { + let mut headers = HeaderMap::new(); + for h in raw { + headers.append( + HeaderName::from_bytes(h.name.as_bytes())?, + HeaderValue::from_bytes(h.value)?, + ); } - } - - /// Check if the given header has the given value. - pub fn header_is(&self, name: &str, value: &str) -> bool { - self.find_first(name) - .map(|v| v == value.as_bytes()) - .unwrap_or(false) - } - - /// Check if the given header has the given value (case-insensitive). - pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool { - self.find_first(name) - .ok_or(()) - .and_then(|val_raw| from_utf8(val_raw).map_err(|_| ())) - .map(|val| val.eq_ignore_ascii_case(value)) - .unwrap_or(false) - } - - /// Allows to iterate over available headers. - pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> { - self.data.iter() - } -} -/// The iterator over headers. -#[derive(Debug)] -pub struct HeadersIter<'name, 'headers> { - name: &'name str, - iter: slice::Iter<'headers, (String, Box<[u8]>)>, -} - -impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> { - type Item = &'headers [u8]; - fn next(&mut self) -> Option { - while let Some(&(ref name, ref value)) = self.iter.next() { - if name.eq_ignore_ascii_case(self.name) { - return Some(value); - } - } - None + Ok(headers) } } - -/// Trait to convert raw objects into HTTP parseables. -pub trait FromHttparse: Sized { - /// Convert raw object into parsed HTTP headers. - fn from_httparse(raw: T) -> Result; -} - -impl TryParse for Headers { +impl TryParse for HeaderMap { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; Ok(match httparse::parse_headers(buf, &mut hbuffer)? { Status::Partial => None, - Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)), - }) - } -} - -impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { - fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result { - Ok(Headers { - data: raw - .iter() - .map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice())) - .collect(), + Status::Complete((size, hdr)) => Some((size, HeaderMap::from_httparse(hdr)?)), }) } } @@ -104,7 +44,7 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers { mod tests { use super::super::machine::TryParse; - use super::Headers; + use super::HeaderMap; #[test] fn headers() { @@ -112,14 +52,10 @@ mod tests { Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ \r\n"; - let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); - assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..])); - assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..])); - assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..])); - - assert!(hdr.header_is("upgrade", "websocket")); - assert!(!hdr.header_is("upgrade", "Websocket")); - assert!(hdr.header_is_ignore_case("upgrade", "Websocket")); + let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap(); + assert_eq!(hdr.get("Host").unwrap(), &b"foo.com"[..]); + assert_eq!(hdr.get("Upgrade").unwrap(), &b"websocket"[..]); + assert_eq!(hdr.get("Connection").unwrap(), &b"Upgrade"[..]); } #[test] @@ -130,10 +66,10 @@ mod tests { Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\ Upgrade: websocket\r\n\ \r\n"; - let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap(); - let mut iter = hdr.find("Sec-WebSocket-Extensions"); - assert_eq!(iter.next(), Some(&b"permessage-deflate"[..])); - assert_eq!(iter.next(), Some(&b"permessage-unknown"[..])); + let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap(); + let mut iter = hdr.get_all("Sec-WebSocket-Extensions").iter(); + assert_eq!(iter.next().unwrap(), &b"permessage-deflate"[..]); + assert_eq!(iter.next().unwrap(), &b"permessage-unknown"[..]); assert_eq!(iter.next(), None); } @@ -142,7 +78,7 @@ mod tests { const DATA: &'static [u8] = b"Host: foo.com\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n"; - let hdr = Headers::try_parse(DATA).unwrap(); + let hdr = HeaderMap::try_parse(DATA).unwrap(); assert!(hdr.is_none()); } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index d3624c3a..6510e860 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -5,11 +5,11 @@ use std::io::{Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; -use http::StatusCode; +use http::{HeaderMap, StatusCode}; use httparse::Status; use log::*; -use super::headers::{FromHttparse, Headers, MAX_HEADERS}; +use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; @@ -21,15 +21,15 @@ pub struct Request { /// Path part of the URL. pub path: String, /// HTTP headers. - pub headers: Headers, + pub headers: HeaderMap, } impl Request { /// Reply to the response. - pub fn reply(&self, extra_headers: Option>) -> Result> { + pub fn reply(&self, extra_headers: Option) -> Result> { let key = self .headers - .find_first("Sec-WebSocket-Key") + .get("Sec-WebSocket-Key") .ok_or_else(|| Error::Protocol("Missing Sec-WebSocket-Key".into()))?; let mut reply = format!( "\ @@ -37,17 +37,21 @@ impl Request { Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ Sec-WebSocket-Accept: {}\r\n", - convert_key(key)? + convert_key(key.as_bytes())? ); add_headers(&mut reply, extra_headers); Ok(reply.into()) } } -fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) { +fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) { if let Some(eh) = extra_headers { for (k, v) in eh { - writeln!(reply, "{}: {}\r", k, v).unwrap(); + if let Some(k) = k { + // FIXME unwrap, should use http types for serialization instead of working on + // strings everywhere here + writeln!(reply, "{}: {}\r", k, v.to_str().unwrap()).unwrap(); + } } } writeln!(reply, "\r").unwrap(); @@ -76,21 +80,18 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } Ok(Request { path: raw.path.expect("Bug: no path in header").into(), - headers: Headers::from_httparse(raw.headers)?, + headers: HeaderMap::from_httparse(raw.headers)?, }) } } -/// Extra headers for responses. -pub type ExtraHeaders = Vec<(String, String)>; - /// An error response sent to the client. #[derive(Debug)] pub struct ErrorResponse { /// HTTP error code. pub error_code: StatusCode, /// Extra response headers, if any. - pub headers: Option, + pub headers: Option, /// Response body, if any. pub body: Option, } @@ -115,14 +116,14 @@ pub trait Callback: Sized { /// Called whenever the server read the request from the client and is ready to reply to it. /// May return additional reply headers. /// Returning an error resulting in rejecting the incoming connection. - fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; + fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; } impl Callback for F where - F: FnOnce(&Request) -> StdResult, ErrorResponse>, + F: FnOnce(&Request) -> StdResult, ErrorResponse>, { - fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { + fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { self(request) } } @@ -132,7 +133,7 @@ where pub struct NoCallback; impl Callback for NoCallback { - fn on_request(self, _request: &Request) -> StdResult, ErrorResponse> { + fn on_request(self, _request: &Request) -> StdResult, ErrorResponse> { Ok(None) } } @@ -241,14 +242,15 @@ impl HandshakeRole for ServerHandshake { mod tests { use super::super::client::Response; use super::super::machine::TryParse; - use super::Request; + use super::{HeaderMap, Request}; + use http::header::HeaderName; #[test] fn request_parsing() { const DATA: &'static [u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n"; let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); assert_eq!(req.path, "/script.ws"); - assert_eq!(req.headers.find_first("Host"), Some(&b"foo.com"[..])); + assert_eq!(req.headers.get("Host").unwrap(), &b"foo.com"[..]); } #[test] @@ -264,19 +266,25 @@ mod tests { let (_, req) = Request::try_parse(DATA).unwrap().unwrap(); let _ = req.reply(None).unwrap(); - let extra_headers = Some(vec![ - ( - String::from("MyCustomHeader"), - String::from("MyCustomValue"), - ), - (String::from("MyVersion"), String::from("LOL")), - ]); - let reply = req.reply(extra_headers).unwrap(); + let extra_headers = { + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_bytes(&b"MyCustomHeader"[..]).unwrap(), + "MyCustomValue".parse().unwrap(), + ); + headers.insert( + HeaderName::from_bytes(&b"MyVersion"[..]).unwrap(), + "LOL".parse().unwrap(), + ); + + headers + }; + let reply = req.reply(Some(extra_headers)).unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); assert_eq!( - req.headers.find_first("MyCustomHeader"), - Some(b"MyCustomValue".as_ref()) + req.headers.get("MyCustomHeader").unwrap(), + b"MyCustomValue".as_ref() ); - assert_eq!(req.headers.find_first("MyVersion"), Some(b"LOL".as_ref())); + assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref()); } }