Skip to content

Commit

Permalink
Remove custom Headers type and use http::HeaderMap instead
Browse files Browse the repository at this point in the history
  • Loading branch information
sdroege committed Nov 23, 2019
1 parent bb80143 commit 9ae7a16
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 136 deletions.
46 changes: 43 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub enum Error {
/// connection when it really shouldn't anymore, so this really indicates a programmer
/// error on your part.
AlreadyClosed,
/// Input-output error. Appart from WouldBlock, these are generally errors with the
/// Input-output error. Apart from WouldBlock, these are generally errors with the
/// underlying connection and you should probably consider them fatal.
Io(io::Error),
#[cfg(feature = "tls")]
Expand All @@ -61,10 +61,12 @@ pub enum Error {
SendQueueFull(Message),
/// UTF coding error
Utf8,
/// Invlid URL.
/// Invalid URL.
Url(Cow<'static, str>),
/// HTTP error.
Http(u16),
/// HTTP format error.
HttpFormat(http::Error),
}

impl fmt::Display for Error {
Expand All @@ -80,7 +82,8 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP code: {}", code),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
}
}
}
Expand All @@ -99,6 +102,7 @@ impl ErrorTrait for Error {
Error::Utf8 => "",
Error::Url(ref msg) => msg.borrow(),
Error::Http(_) => "",
Error::HttpFormat(ref err) => err.description(),
}
}
}
Expand All @@ -121,6 +125,42 @@ impl From<string::FromUtf8Error> for Error {
}
}

impl From<http::header::InvalidHeaderValue> for Error {
fn from(err: http::header::InvalidHeaderValue) -> Self {
Error::HttpFormat(err.into())
}
}

impl From<http::header::InvalidHeaderName> for Error {
fn from(err: http::header::InvalidHeaderName) -> Self {
Error::HttpFormat(err.into())
}
}

impl From<http::header::ToStrError> for Error {
fn from(_: http::header::ToStrError) -> Self {
Error::Utf8
}
}

impl From<http::uri::InvalidUri> for Error {
fn from(err: http::uri::InvalidUri) -> Self {
Error::HttpFormat(err.into())
}
}

impl From<http::status::InvalidStatusCode> for Error {
fn from(err: http::status::InvalidStatusCode) -> Self {
Error::HttpFormat(err.into())
}
}

impl From<http::Error> for Error {
fn from(err: http::Error) -> Self {
Error::HttpFormat(err)
}
}

#[cfg(feature = "tls")]
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {
Expand Down
26 changes: 16 additions & 10 deletions src/handshake/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use std::borrow::Cow;
use std::io::{Read, Write};
use std::marker::PhantomData;

use http::HeaderMap;
use httparse::Status;
use log::*;
use url::Url;

use super::headers::{FromHttparse, Headers, MAX_HEADERS};
use super::headers::{FromHttparse, MAX_HEADERS};
use super::machine::{HandshakeMachine, StageResult, TryParse};
use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
Expand Down Expand Up @@ -171,7 +172,10 @@ impl VerifyData {
// _Fail the WebSocket Connection_. (RFC 6455)
if !response
.headers
.header_is_ignore_case("Upgrade", "websocket")
.get("Upgrade")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Upgrade: websocket\" in server reply".into(),
Expand All @@ -183,7 +187,10 @@ impl VerifyData {
// MUST _Fail the WebSocket Connection_. (RFC 6455)
if !response
.headers
.header_is_ignore_case("Connection", "Upgrade")
.get("Connection")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(Error::Protocol(
"No \"Connection: upgrade\" in server reply".into(),
Expand All @@ -195,7 +202,9 @@ impl VerifyData {
// Connection_. (RFC 6455)
if !response
.headers
.header_is("Sec-WebSocket-Accept", &self.accept_key)
.get("Sec-WebSocket-Accept")
.map(|h| h == &self.accept_key)
.unwrap_or(false)
{
return Err(Error::Protocol(
"Key mismatch in Sec-WebSocket-Accept".into(),
Expand Down Expand Up @@ -225,7 +234,7 @@ pub struct Response {
/// HTTP response code of the response.
pub code: u16,
/// Received headers.
pub headers: Headers,
pub headers: HeaderMap,
}

impl TryParse for Response {
Expand All @@ -248,7 +257,7 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Response<'h, 'b>> for Response {
}
Ok(Response {
code: raw.code.expect("Bug: no HTTP response code"),
headers: Headers::from_httparse(raw.headers)?,
headers: HeaderMap::from_httparse(raw.headers)?,
})
}
}
Expand Down Expand Up @@ -287,9 +296,6 @@ mod tests {
const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n";
let (_, resp) = Response::try_parse(DATA).unwrap().unwrap();
assert_eq!(resp.code, 200);
assert_eq!(
resp.headers.find_first("Content-Type"),
Some(&b"text/html"[..])
);
assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],);
}
}
118 changes: 27 additions & 91 deletions src/handshake/headers.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! HTTP Request and response header handling.
use std::slice;
use std::str::from_utf8;

use http;
use http::header::{HeaderMap, HeaderName, HeaderValue};
use httparse;
use httparse::Status;

Expand All @@ -12,90 +11,31 @@ use crate::error::Result;
/// Limit for the number of header lines.
pub const MAX_HEADERS: usize = 124;

/// HTTP request or response headers.
#[derive(Debug)]
pub struct Headers {
data: Vec<(String, Box<[u8]>)>,
/// Trait to convert raw objects into HTTP parseables.
pub(crate) trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers.
fn from_httparse(raw: T) -> Result<Self>;
}

impl Headers {
/// Get first header with the given name, if any.
pub fn find_first(&self, name: &str) -> Option<&[u8]> {
self.find(name).next()
}

/// Iterate over all headers with the given name.
pub fn find<'headers, 'name>(&'headers self, name: &'name str) -> HeadersIter<'name, 'headers> {
HeadersIter {
name,
iter: self.data.iter(),
impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for HeaderMap {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
let mut headers = HeaderMap::new();
for h in raw {
headers.append(
HeaderName::from_bytes(h.name.as_bytes())?,
HeaderValue::from_bytes(h.value)?,
);
}
}

/// Check if the given header has the given value.
pub fn header_is(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.map(|v| v == value.as_bytes())
.unwrap_or(false)
}

/// Check if the given header has the given value (case-insensitive).
pub fn header_is_ignore_case(&self, name: &str, value: &str) -> bool {
self.find_first(name)
.ok_or(())
.and_then(|val_raw| from_utf8(val_raw).map_err(|_| ()))
.map(|val| val.eq_ignore_ascii_case(value))
.unwrap_or(false)
}

/// Allows to iterate over available headers.
pub fn iter(&self) -> slice::Iter<(String, Box<[u8]>)> {
self.data.iter()
}
}

/// The iterator over headers.
#[derive(Debug)]
pub struct HeadersIter<'name, 'headers> {
name: &'name str,
iter: slice::Iter<'headers, (String, Box<[u8]>)>,
}

impl<'name, 'headers> Iterator for HeadersIter<'name, 'headers> {
type Item = &'headers [u8];
fn next(&mut self) -> Option<Self::Item> {
while let Some(&(ref name, ref value)) = self.iter.next() {
if name.eq_ignore_ascii_case(self.name) {
return Some(value);
}
}
None
Ok(headers)
}
}

/// Trait to convert raw objects into HTTP parseables.
pub trait FromHttparse<T>: Sized {
/// Convert raw object into parsed HTTP headers.
fn from_httparse(raw: T) -> Result<Self>;
}

impl TryParse for Headers {
impl TryParse for HeaderMap {
fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
Ok(match httparse::parse_headers(buf, &mut hbuffer)? {
Status::Partial => None,
Status::Complete((size, hdr)) => Some((size, Headers::from_httparse(hdr)?)),
})
}
}

impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
fn from_httparse(raw: &'b [httparse::Header<'h>]) -> Result<Self> {
Ok(Headers {
data: raw
.iter()
.map(|h| (h.name.into(), Vec::from(h.value).into_boxed_slice()))
.collect(),
Status::Complete((size, hdr)) => Some((size, HeaderMap::from_httparse(hdr)?)),
})
}
}
Expand All @@ -104,22 +44,18 @@ impl<'b: 'h, 'h> FromHttparse<&'b [httparse::Header<'h>]> for Headers {
mod tests {

use super::super::machine::TryParse;
use super::Headers;
use super::HeaderMap;

#[test]
fn headers() {
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
\r\n";
let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap();
assert_eq!(hdr.find_first("Host"), Some(&b"foo.com"[..]));
assert_eq!(hdr.find_first("Upgrade"), Some(&b"websocket"[..]));
assert_eq!(hdr.find_first("Connection"), Some(&b"Upgrade"[..]));

assert!(hdr.header_is("upgrade", "websocket"));
assert!(!hdr.header_is("upgrade", "Websocket"));
assert!(hdr.header_is_ignore_case("upgrade", "Websocket"));
let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap();
assert_eq!(hdr.get("Host").unwrap(), &b"foo.com"[..]);
assert_eq!(hdr.get("Upgrade").unwrap(), &b"websocket"[..]);
assert_eq!(hdr.get("Connection").unwrap(), &b"Upgrade"[..]);
}

#[test]
Expand All @@ -130,10 +66,10 @@ mod tests {
Sec-WebSocket-ExtenSIONS: permessage-unknown\r\n\
Upgrade: websocket\r\n\
\r\n";
let (_, hdr) = Headers::try_parse(DATA).unwrap().unwrap();
let mut iter = hdr.find("Sec-WebSocket-Extensions");
assert_eq!(iter.next(), Some(&b"permessage-deflate"[..]));
assert_eq!(iter.next(), Some(&b"permessage-unknown"[..]));
let (_, hdr) = HeaderMap::try_parse(DATA).unwrap().unwrap();
let mut iter = hdr.get_all("Sec-WebSocket-Extensions").iter();
assert_eq!(iter.next().unwrap(), &b"permessage-deflate"[..]);
assert_eq!(iter.next().unwrap(), &b"permessage-unknown"[..]);
assert_eq!(iter.next(), None);
}

Expand All @@ -142,7 +78,7 @@ mod tests {
const DATA: &'static [u8] = b"Host: foo.com\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n";
let hdr = Headers::try_parse(DATA).unwrap();
let hdr = HeaderMap::try_parse(DATA).unwrap();
assert!(hdr.is_none());
}
}
Loading

0 comments on commit 9ae7a16

Please sign in to comment.