Skip to content

Commit

Permalink
Remove custom Request/Response types from client code
Browse files Browse the repository at this point in the history
  • Loading branch information
sdroege committed Nov 24, 2019
1 parent 9ae7a16 commit 82ac126
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 144 deletions.
153 changes: 107 additions & 46 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;

use http::{Request, Response, Uri};
use log::*;

use url::Url;

use crate::handshake::client::Response;
use crate::protocol::WebSocketConfig;

#[cfg(feature = "tls")]
Expand Down Expand Up @@ -64,7 +65,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;

use crate::error::{Error, Result};
use crate::handshake::client::{ClientHandshake, Request};
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay};
Expand All @@ -84,37 +85,23 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into();
let mode = url_mode(&request.url)?;
) -> Result<(WebSocket<AutoStream>, Response<()>)> {
let request: Request<()> = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.url
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request
.url
.port_or_known_default()
.ok_or_else(|| Error::Url("No port number in the URL".into()))?;
let addrs;
let addr;
let addrs = match host {
url::Host::Domain(domain) => {
addrs = (domain, port).to_socket_addrs()?;
addrs.as_slice()
}
url::Host::Ipv4(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
url::Host::Ipv6(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
};
let mut stream = connect_to_some(addrs, &request.url, mode)?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
Expand All @@ -134,35 +121,35 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(
pub fn connect<Req: IntoClientRequest>(
request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> {
) -> Result<(WebSocket<AutoStream>, Response<()>)> {
connect_with_config(request, None)
}

fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> {
let domain = url
.host_str()
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream);
}
}
}
Err(Error::Url(format!("Unable to connect to {}", url).into()))
Err(Error::Url(format!("Unable to connect to {}", uri).into()))
}

/// Get the mode of the given URL.
///
/// This function may be used to ease the creation of custom TLS streams
/// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`.
pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match uri.scheme_str() {
Some("ws") => Ok(Mode::Plain),
Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())),
}
}
Expand All @@ -173,30 +160,104 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<'t, Stream, Req>(
pub fn client_with_config<Stream, Req>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
) -> StdResult<(WebSocket<Stream>, Response<()>), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
ClientHandshake::start(stream, request.into(), config).handshake()
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}

/// Do the client handshake over the given stream.
///
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(
pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
) -> StdResult<(WebSocket<Stream>, Response<()>), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
client_with_config(request, stream, None)
}

/// Trait for converting various types into HTTP requests used for a client connection.
///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`.
pub trait IntoClientRequest {
/// Convert into a `Request<()>` that can be used for a client connection.
fn into_client_request(self) -> Result<Request<()>>;
}

impl<'a> IntoClientRequest for &'a str {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl<'a> IntoClientRequest for &'a String {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for String {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl<'a> IntoClientRequest for &'a Uri {
fn into_client_request(self) -> Result<Request<()>> {
Ok(Request::get(self.clone()).body(())?)
}
}

impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request<()>> {
Ok(Request::get(self).body(())?)
}
}

impl<'a> IntoClientRequest for &'a Url {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.as_str().parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for Url {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.as_str().parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for Request<()> {
fn into_client_request(self) -> Result<Request<()>> {
Ok(self)
}
}

impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
fn into_client_request(self) -> Result<Request<()>> {
use crate::handshake::headers::FromHttparse;
Request::<()>::from_httparse(self)
}
}
3 changes: 2 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::result;
use std::str;
use std::string;

use http;
use httparse;

use crate::protocol::Message;
Expand Down Expand Up @@ -64,7 +65,7 @@ pub enum Error {
/// Invalid URL.
Url(Cow<'static, str>),
/// HTTP error.
Http(u16),
Http(http::StatusCode),
/// HTTP format error.
HttpFormat(http::Error),
}
Expand Down
Loading

0 comments on commit 82ac126

Please sign in to comment.