From 82ac126dcbe3ce62ee85d8e7b62ab4f171ac4666 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Dr=C3=B6ge?= Date: Fri, 22 Nov 2019 15:59:27 +0200 Subject: [PATCH] Remove custom Request/Response types from client code https://github.com/snapview/tungstenite-rs/issues/92 --- src/client.rs | 153 ++++++++++++++++++++++++++++------------ src/error.rs | 3 +- src/handshake/client.rs | 153 ++++++++++++++++------------------------ src/handshake/server.rs | 8 +-- 4 files changed, 173 insertions(+), 144 deletions(-) diff --git a/src/client.rs b/src/client.rs index e35a24d5..66b73fd8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,10 +4,11 @@ use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::result::Result as StdResult; +use http::{Request, Response, Uri}; use log::*; + use url::Url; -use crate::handshake::client::Response; use crate::protocol::WebSocketConfig; #[cfg(feature = "tls")] @@ -64,7 +65,7 @@ use self::encryption::wrap_stream; pub use self::encryption::AutoStream; use crate::error::{Error, Result}; -use crate::handshake::client::{ClientHandshake, Request}; +use crate::handshake::client::ClientHandshake; use crate::handshake::HandshakeError; use crate::protocol::WebSocket; use crate::stream::{Mode, NoDelay}; @@ -84,37 +85,23 @@ use crate::stream::{Mode, NoDelay}; /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect_with_config<'t, Req: Into>>( +pub fn connect_with_config( request: Req, config: Option, -) -> Result<(WebSocket, Response)> { - let request: Request = request.into(); - let mode = url_mode(&request.url)?; +) -> Result<(WebSocket, Response<()>)> { + let request: Request<()> = request.into_client_request()?; + let uri = request.uri(); + let mode = uri_mode(uri)?; let host = request - .url + .uri() .host() .ok_or_else(|| Error::Url("No host name in the URL".into()))?; - let port = request - .url - .port_or_known_default() - .ok_or_else(|| Error::Url("No port number in the URL".into()))?; - let addrs; - let addr; - let addrs = match host { - url::Host::Domain(domain) => { - addrs = (domain, port).to_socket_addrs()?; - addrs.as_slice() - } - url::Host::Ipv4(ip) => { - addr = (ip, port).into(); - std::slice::from_ref(&addr) - } - url::Host::Ipv6(ip) => { - addr = (ip, port).into(); - std::slice::from_ref(&addr) - } - }; - let mut stream = connect_to_some(addrs, &request.url, mode)?; + let port = uri.port_u16().unwrap_or(match mode { + Mode::Plain => 80, + Mode::Tls => 443, + }); + let addrs = (host, port).to_socket_addrs()?; + let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?; NoDelay::set_nodelay(&mut stream, true)?; client_with_config(request, stream, config).map_err(|e| match e { HandshakeError::Failure(f) => f, @@ -134,35 +121,35 @@ pub fn connect_with_config<'t, Req: Into>>( /// This function uses `native_tls` to do TLS. If you want to use other TLS libraries, /// use `client` instead. There is no need to enable the "tls" feature if you don't call /// `connect` since it's the only function that uses native_tls. -pub fn connect<'t, Req: Into>>( +pub fn connect( request: Req, -) -> Result<(WebSocket, Response)> { +) -> Result<(WebSocket, Response<()>)> { connect_with_config(request, None) } -fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result { - let domain = url - .host_str() +fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result { + let domain = uri + .host() .ok_or_else(|| Error::Url("No host name in the URL".into()))?; for addr in addrs { - debug!("Trying to contact {} at {}...", url, addr); + debug!("Trying to contact {} at {}...", uri, addr); if let Ok(raw_stream) = TcpStream::connect(addr) { if let Ok(stream) = wrap_stream(raw_stream, domain, mode) { return Ok(stream); } } } - Err(Error::Url(format!("Unable to connect to {}", url).into())) + Err(Error::Url(format!("Unable to connect to {}", uri).into())) } /// Get the mode of the given URL. /// /// This function may be used to ease the creation of custom TLS streams /// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`. -pub fn url_mode(url: &Url) -> Result { - match url.scheme() { - "ws" => Ok(Mode::Plain), - "wss" => Ok(Mode::Tls), +pub fn uri_mode(uri: &Uri) -> Result { + match uri.scheme_str() { + Some("ws") => Ok(Mode::Plain), + Some("wss") => Ok(Mode::Tls), _ => Err(Error::Url("URL scheme not supported".into())), } } @@ -173,16 +160,16 @@ pub fn url_mode(url: &Url) -> Result { /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client_with_config<'t, Stream, Req>( +pub fn client_with_config( request: Req, stream: Stream, config: Option, -) -> StdResult<(WebSocket, Response), HandshakeError>> +) -> StdResult<(WebSocket, Response<()>), HandshakeError>> where Stream: Read + Write, - Req: Into>, + Req: IntoClientRequest, { - ClientHandshake::start(stream, request.into(), config).handshake() + ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake() } /// Do the client handshake over the given stream. @@ -190,13 +177,87 @@ where /// Use this function if you need a nonblocking handshake support or if you /// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`. /// Any stream supporting `Read + Write` will do. -pub fn client<'t, Stream, Req>( +pub fn client( request: Req, stream: Stream, -) -> StdResult<(WebSocket, Response), HandshakeError>> +) -> StdResult<(WebSocket, Response<()>), HandshakeError>> where Stream: Read + Write, - Req: Into>, + Req: IntoClientRequest, { client_with_config(request, stream, None) } + +/// Trait for converting various types into HTTP requests used for a client connection. +/// +/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and +/// `http::Request<()>`. +pub trait IntoClientRequest { + /// Convert into a `Request<()>` that can be used for a client connection. + fn into_client_request(self) -> Result>; +} + +impl<'a> IntoClientRequest for &'a str { + fn into_client_request(self) -> Result> { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a String { + fn into_client_request(self) -> Result> { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for String { + fn into_client_request(self) -> Result> { + let uri: Uri = self.parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a Uri { + fn into_client_request(self) -> Result> { + Ok(Request::get(self.clone()).body(())?) + } +} + +impl IntoClientRequest for Uri { + fn into_client_request(self) -> Result> { + Ok(Request::get(self).body(())?) + } +} + +impl<'a> IntoClientRequest for &'a Url { + fn into_client_request(self) -> Result> { + let uri: Uri = self.as_str().parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for Url { + fn into_client_request(self) -> Result> { + let uri: Uri = self.as_str().parse()?; + + Ok(Request::get(uri).body(())?) + } +} + +impl IntoClientRequest for Request<()> { + fn into_client_request(self) -> Result> { + Ok(self) + } +} + +impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> { + fn into_client_request(self) -> Result> { + use crate::handshake::headers::FromHttparse; + Request::<()>::from_httparse(self) + } +} diff --git a/src/error.rs b/src/error.rs index e04c1998..ac2de26f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,7 @@ use std::result; use std::str; use std::string; +use http; use httparse; use crate::protocol::Message; @@ -64,7 +65,7 @@ pub enum Error { /// Invalid URL. Url(Cow<'static, str>), /// HTTP error. - Http(u16), + Http(http::StatusCode), /// HTTP format error. HttpFormat(http::Error), } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 8aec86b1..a0bb951a 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -1,13 +1,11 @@ //! Client handshake machine. -use std::borrow::Cow; use std::io::{Read, Write}; use std::marker::PhantomData; -use http::HeaderMap; +use http::{HeaderMap, Request, Response, StatusCode}; use httparse::Status; use log::*; -use url::Url; use super::headers::{FromHttparse, MAX_HEADERS}; use super::machine::{HandshakeMachine, StageResult, TryParse}; @@ -15,57 +13,6 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; -/// Client request. -#[derive(Debug)] -pub struct Request<'t> { - /// `ws://` or `wss://` URL to connect to. - pub url: Url, - /// Extra HTTP headers to append to the request. - pub extra_headers: Option, Cow<'t, str>)>>, -} - -impl<'t> Request<'t> { - /// Returns the GET part of the request. - fn get_path(&self) -> String { - if let Some(query) = self.url.query() { - format!("{path}?{query}", path = self.url.path(), query = query) - } else { - self.url.path().into() - } - } - - /// Returns the host part of the request. - fn get_host(&self) -> String { - let host = self.url.host_str().expect("Bug: URL without host"); - if let Some(port) = self.url.port() { - format!("{host}:{port}", host = host, port = port) - } else { - host.into() - } - } - - /// Adds a WebSocket protocol to the request. - pub fn add_protocol(&mut self, protocol: Cow<'t, str>) { - self.add_header(Cow::from("Sec-WebSocket-Protocol"), protocol); - } - - /// Adds a custom header to the request. - pub fn add_header(&mut self, name: Cow<'t, str>, value: Cow<'t, str>) { - let mut headers = self.extra_headers.take().unwrap_or_else(Vec::new); - headers.push((name, value)); - self.extra_headers = Some(headers); - } -} - -impl From for Request<'static> { - fn from(value: Url) -> Self { - Request { - url: value, - extra_headers: None, - } - } -} - /// Client handshake role. #[derive(Debug)] pub struct ClientHandshake { @@ -78,31 +25,51 @@ impl ClientHandshake { /// Initiate a client handshake. pub fn start( stream: S, - request: Request, + request: Request<()>, config: Option, - ) -> MidHandshake { + ) -> Result> { + if request.method() != http::Method::GET { + return Err(Error::Protocol( + "Invalid HTTP method, only GET supported".into(), + )); + } + + if request.version() < http::Version::HTTP_11 { + return Err(Error::Protocol( + "HTTP version should be 1.1 or higher".into(), + )); + } + + // Check the URI scheme: only ws or wss are supported + let _ = crate::client::uri_mode(request.uri())?; + let key = generate_key(); let machine = { let mut req = Vec::new(); + let uri = request.uri(); write!( req, "\ - GET {path} HTTP/1.1\r\n\ + GET {path} {version:?}\r\n\ Host: {host}\r\n\ Connection: Upgrade\r\n\ Upgrade: websocket\r\n\ Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Key: {key}\r\n", - host = request.get_host(), - path = request.get_path(), + version = request.version(), + host = uri + .host() + .ok_or_else(|| Error::Url("No host name in the URL".into()))?, + path = uri + .path_and_query() + .ok_or_else(|| Error::Url("No path/query in URL".into()))? + .as_str(), key = key ) .unwrap(); - if let Some(eh) = request.extra_headers { - for (k, v) in eh { - writeln!(req, "{}: {}\r", k, v).unwrap(); - } + for (k, v) in request.headers() { + writeln!(req, "{}: {}\r", k, v.to_str()?).unwrap(); } writeln!(req, "\r").unwrap(); HandshakeMachine::start_write(stream, req) @@ -118,17 +85,17 @@ impl ClientHandshake { }; trace!("Client handshake initiated."); - MidHandshake { + Ok(MidHandshake { role: client, machine, - } + }) } } impl HandshakeRole for ClientHandshake { - type IncomingData = Response; + type IncomingData = Response<()>; type InternalStream = S; - type FinalResult = (WebSocket, Response); + type FinalResult = (WebSocket, Response<()>); fn stage_finished( &mut self, finish: StageResult, @@ -160,18 +127,19 @@ struct VerifyData { } impl VerifyData { - pub fn verify_response(&self, response: &Response) -> Result<()> { + pub fn verify_response(&self, response: &Response<()>) -> Result<()> { // 1. If the status code received from the server is not 101, the // client handles the response per HTTP [RFC2616] procedures. (RFC 6455) - if response.code != 101 { - return Err(Error::Http(response.code)); + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(Error::Http(response.status())); } + let headers = response.headers(); + // 2. If the response lacks an |Upgrade| header field or the |Upgrade| // header field contains a value that is not an ASCII case- // insensitive match for the value "websocket", the client MUST // _Fail the WebSocket Connection_. (RFC 6455) - if !response - .headers + if !headers .get("Upgrade") .and_then(|h| h.to_str().ok()) .map(|h| h.eq_ignore_ascii_case("websocket")) @@ -185,8 +153,7 @@ impl VerifyData { // |Connection| header field doesn't contain a token that is an // ASCII case-insensitive match for the value "Upgrade", the client // MUST _Fail the WebSocket Connection_. (RFC 6455) - if !response - .headers + if !headers .get("Connection") .and_then(|h| h.to_str().ok()) .map(|h| h.eq_ignore_ascii_case("Upgrade")) @@ -200,8 +167,7 @@ impl VerifyData { // the |Sec-WebSocket-Accept| contains a value other than the // base64-encoded SHA-1 of ... the client MUST _Fail the WebSocket // Connection_. (RFC 6455) - if !response - .headers + if !headers .get("Sec-WebSocket-Accept") .map(|h| h == &self.accept_key) .unwrap_or(false) @@ -228,16 +194,7 @@ impl VerifyData { } } -/// Server response. -#[derive(Debug)] -pub struct Response { - /// HTTP response code of the response. - pub code: u16, - /// Received headers. - pub headers: HeaderMap, -} - -impl TryParse for Response { +impl TryParse for Response<()> { fn try_parse(buf: &[u8]) -> Result> { let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut req = httparse::Response::new(&mut hbuffer); @@ -248,17 +205,24 @@ impl TryParse for Response { } } -impl<'h, 'b: 'h> FromHttparse> for Response { +impl<'h, 'b: 'h> FromHttparse> for Response<()> { fn from_httparse(raw: httparse::Response<'h, 'b>) -> Result { if raw.version.expect("Bug: no HTTP version") < /*1.*/1 { return Err(Error::Protocol( "HTTP version should be 1.1 or higher".into(), )); } - Ok(Response { - code: raw.code.expect("Bug: no HTTP response code"), - headers: HeaderMap::from_httparse(raw.headers)?, - }) + + let headers = HeaderMap::from_httparse(raw.headers)?; + + let mut response = Response::new(()); + *response.status_mut() = StatusCode::from_u16(raw.code.expect("Bug: no HTTP status code"))?; + *response.headers_mut() = headers; + // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0 + // so the only valid value we could get in the response would be 1.1. + *response.version_mut() = http::Version::HTTP_11; + + Ok(response) } } @@ -295,7 +259,10 @@ mod tests { fn response_parsing() { const DATA: &'static [u8] = b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; let (_, resp) = Response::try_parse(DATA).unwrap().unwrap(); - assert_eq!(resp.code, 200); - assert_eq!(resp.headers.get("Content-Type").unwrap(), &b"text/html"[..],); + assert_eq!(resp.status(), http::StatusCode::OK); + assert_eq!( + resp.headers().get("Content-Type").unwrap(), + &b"text/html"[..], + ); } } diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 3f551ea0..5a5890cd 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -227,7 +227,7 @@ impl HandshakeRole for ServerHandshake { StageResult::DoneWriting(stream) => { if let Some(err) = self.error_code.take() { debug!("Server handshake failed."); - return Err(Error::Http(err)); + return Err(Error::Http(StatusCode::from_u16(err)?)); } else { debug!("Server handshake done."); let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config); @@ -240,10 +240,10 @@ impl HandshakeRole for ServerHandshake { #[cfg(test)] mod tests { - use super::super::client::Response; use super::super::machine::TryParse; use super::{HeaderMap, Request}; use http::header::HeaderName; + use http::Response; #[test] fn request_parsing() { @@ -282,9 +282,9 @@ mod tests { let reply = req.reply(Some(extra_headers)).unwrap(); let (_, req) = Response::try_parse(&reply).unwrap().unwrap(); assert_eq!( - req.headers.get("MyCustomHeader").unwrap(), + req.headers().get("MyCustomHeader").unwrap(), b"MyCustomValue".as_ref() ); - assert_eq!(req.headers.get("MyVersion").unwrap(), b"LOL".as_ref()); + assert_eq!(req.headers().get("MyVersion").unwrap(), b"LOL".as_ref()); } }