Skip to content

Commit

Permalink
Remove custom Request/Response types from client code
Browse files Browse the repository at this point in the history
  • Loading branch information
sdroege committed Nov 23, 2019
1 parent 54bacd7 commit 5f11ff7
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 147 deletions.
146 changes: 100 additions & 46 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;

use http::{Request, Response, Uri};
use log::*;

use url::Url;

use crate::handshake::client::Response;
use crate::protocol::WebSocketConfig;

#[cfg(feature = "tls")]
Expand Down Expand Up @@ -64,7 +65,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;

use crate::error::{Error, Result};
use crate::handshake::client::{ClientHandshake, Request};
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay};
Expand All @@ -84,37 +85,23 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into();
let mode = url_mode(&request.url)?;
) -> Result<(WebSocket<AutoStream>, Response<()>)> {
let request: Request<()> = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.url
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request
.url
.port_or_known_default()
.ok_or_else(|| Error::Url("No port number in the URL".into()))?;
let addrs;
let addr;
let addrs = match host {
url::Host::Domain(domain) => {
addrs = (domain, port).to_socket_addrs()?;
addrs.as_slice()
}
url::Host::Ipv4(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
url::Host::Ipv6(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
};
let mut stream = connect_to_some(addrs, &request.url, mode)?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
Expand All @@ -134,35 +121,35 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(
pub fn connect<Req: IntoClientRequest>(
request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> {
) -> Result<(WebSocket<AutoStream>, Response<()>)> {
connect_with_config(request, None)
}

fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> {
let domain = url
.host_str()
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream);
}
}
}
Err(Error::Url(format!("Unable to connect to {}", url).into()))
Err(Error::Url(format!("Unable to connect to {}", uri).into()))
}

/// Get the mode of the given URL.
///
/// This function may be used to ease the creation of custom TLS streams
/// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`.
pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match uri.scheme_str() {
Some("ws") => Ok(Mode::Plain),
Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())),
}
}
Expand All @@ -173,30 +160,97 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<'t, Stream, Req>(
pub fn client_with_config<Stream, Req>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
) -> StdResult<(WebSocket<Stream>, Response<()>), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
ClientHandshake::start(stream, request.into(), config).handshake()
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}

/// Do the client handshake over the given stream.
///
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(
pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
) -> StdResult<(WebSocket<Stream>, Response<()>), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
client_with_config(request, stream, None)
}

/// Trait for converting various types into HTTP requests used for a client connection.
///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`.
pub trait IntoClientRequest {
/// Convert into a `Request<()>` that can be used for a client connection.
fn into_client_request(self) -> Result<Request<()>>;
}

impl<'a> IntoClientRequest for &'a str {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl<'a> IntoClientRequest for &'a String {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for String {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl<'a> IntoClientRequest for &'a Uri {
fn into_client_request(self) -> Result<Request<()>> {
Ok(Request::get(self.clone()).body(())?)
}
}

impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request<()>> {
Ok(Request::get(self).body(())?)
}
}

impl<'a> IntoClientRequest for &'a Url {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.as_str().parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for Url {
fn into_client_request(self) -> Result<Request<()>> {
let uri: Uri = self.as_str().parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for Request<()> {
fn into_client_request(self) -> Result<Request<()>> {
Ok(self)
}
}
39 changes: 34 additions & 5 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::result;
use std::str;
use std::string;

use http;
use httparse;

use crate::protocol::Message;
Expand Down Expand Up @@ -64,7 +65,9 @@ pub enum Error {
/// Invlid URL.
Url(Cow<'static, str>),
/// HTTP error.
Http(u16),
Http(http::StatusCode),
/// HTTP format error.
HttpFormat(http::Error),
}

impl fmt::Display for Error {
Expand All @@ -80,7 +83,8 @@ impl fmt::Display for Error {
Error::SendQueueFull(_) => write!(f, "Send queue is full"),
Error::Utf8 => write!(f, "UTF-8 encoding error"),
Error::Url(ref msg) => write!(f, "URL error: {}", msg),
Error::Http(code) => write!(f, "HTTP code: {}", code),
Error::Http(code) => write!(f, "HTTP error: {}", code),
Error::HttpFormat(ref err) => write!(f, "HTTP format error: {}", err),
}
}
}
Expand All @@ -99,6 +103,7 @@ impl ErrorTrait for Error {
Error::Utf8 => "",
Error::Url(ref msg) => msg.borrow(),
Error::Http(_) => "",
Error::HttpFormat(ref err) => err.description(),
}
}
}
Expand All @@ -122,17 +127,41 @@ impl From<string::FromUtf8Error> for Error {
}

impl From<http::header::InvalidHeaderValue> for Error {
fn from(_: http::header::InvalidHeaderValue) -> Self {
Error::Utf8
fn from(err: http::header::InvalidHeaderValue) -> Self {
Error::HttpFormat(err.into())
}
}

impl From<http::header::InvalidHeaderName> for Error {
fn from(_: http::header::InvalidHeaderName) -> Self {
fn from(err: http::header::InvalidHeaderName) -> Self {
Error::HttpFormat(err.into())
}
}

impl From<http::header::ToStrError> for Error {
fn from(_: http::header::ToStrError) -> Self {
Error::Utf8
}
}

impl From<http::uri::InvalidUri> for Error {
fn from(err: http::uri::InvalidUri) -> Self {
Error::HttpFormat(err.into())
}
}

impl From<http::status::InvalidStatusCode> for Error {
fn from(err: http::status::InvalidStatusCode) -> Self {
Error::HttpFormat(err.into())
}
}

impl From<http::Error> for Error {
fn from(err: http::Error) -> Self {
Error::HttpFormat(err)
}
}

#[cfg(feature = "tls")]
impl From<tls::Error> for Error {
fn from(err: tls::Error) -> Self {
Expand Down
Loading

0 comments on commit 5f11ff7

Please sign in to comment.