Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Base HTTP-types (request, headers, response, status code, etc) on the ones from the http crate #93

Merged
merged 9 commits into from
Nov 25, 2019
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ authors = ["Alexey Galakhov"]
license = "MIT/Apache-2.0"
readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.9.3"
documentation = "https://docs.rs/tungstenite/0.10.0"
repository = "https://github.com/snapview/tungstenite-rs"
version = "0.9.3"
version = "0.10.0"
edition = "2018"

[features]
Expand Down
4 changes: 2 additions & 2 deletions examples/autobahn-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use url::Url;

use tungstenite::{connect, Error, Message, Result};

const AGENT: &'static str = "Tungstenite";
const AGENT: &str = "Tungstenite";

fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
Expand Down Expand Up @@ -47,7 +47,7 @@ fn main() {

let total = get_case_count().unwrap();

for case in 1..(total + 1) {
for case in 1..=total {
if let Err(e) = run_test(case) {
match e {
Error::Protocol(_) => {}
Expand Down
14 changes: 7 additions & 7 deletions examples/callback-error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ use std::net::TcpListener;
use std::thread::spawn;

use tungstenite::accept_hdr;
use tungstenite::handshake::server::{ErrorResponse, Request};
use tungstenite::handshake::server::{Request, Response};
use tungstenite::http::StatusCode;

fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() {
spawn(move || {
let callback = |_req: &Request| {
Err(ErrorResponse {
error_code: StatusCode::FORBIDDEN,
headers: None,
body: Some("Access denied".into()),
})
let callback = |_req: &Request, _resp| {
let resp = Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Some("Access denied".into()))
.unwrap();
Err(resp)
};
accept_hdr(stream.unwrap(), callback).unwrap_err();
});
Expand Down
4 changes: 2 additions & 2 deletions examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ fn main() {
connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect");

println!("Connected to the server");
println!("Response HTTP code: {}", response.code);
println!("Response HTTP code: {}", response.status());
println!("Response contains the following headers:");
for &(ref header, _ /*value*/) in response.headers.iter() {
for (ref header, _value) in response.headers() {
println!("* {}", header);
}

Expand Down
21 changes: 9 additions & 12 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,27 @@ use std::net::TcpListener;
use std::thread::spawn;

use tungstenite::accept_hdr;
use tungstenite::handshake::server::Request;
use tungstenite::handshake::server::{Request, Response};

fn main() {
env_logger::init();
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() {
spawn(move || {
let callback = |req: &Request| {
let callback = |req: &Request, mut response: Response| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.path);
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
for &(ref header, _ /* value */) in req.headers.iter() {
for (ref header, _value) in req.headers() {
println!("* {}", header);
}

// Let's add an additional header to our response to the client.
let extra_headers = vec![
(String::from("MyCustomHeader"), String::from(":)")),
(
String::from("SOME_TUNGSTENITE_HEADER"),
String::from("header_value"),
),
];
Ok(Some(extra_headers))
let headers = response.headers_mut();
headers.append("MyCustomHeader", ":)".parse().unwrap());
headers.append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap());

Ok(response)
};
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap();

Expand Down
148 changes: 104 additions & 44 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;

use http::Uri;
use log::*;

use url::Url;

use crate::handshake::client::Response;
use crate::handshake::client::{Request, Response};
use crate::protocol::WebSocketConfig;

#[cfg(feature = "tls")]
Expand Down Expand Up @@ -64,7 +66,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;

use crate::error::{Error, Result};
use crate::handshake::client::{ClientHandshake, Request};
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay};
Expand All @@ -84,37 +86,23 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into();
let mode = url_mode(&request.url)?;
let request: Request = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.url
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request
.url
.port_or_known_default()
.ok_or_else(|| Error::Url("No port number in the URL".into()))?;
let addrs;
let addr;
let addrs = match host {
url::Host::Domain(domain) => {
addrs = (domain, port).to_socket_addrs()?;
addrs.as_slice()
}
url::Host::Ipv4(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
url::Host::Ipv6(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
};
let mut stream = connect_to_some(addrs, &request.url, mode)?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
Expand All @@ -134,35 +122,33 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(
request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> {
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None)
}

fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> {
let domain = url
.host_str()
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream);
}
}
}
Err(Error::Url(format!("Unable to connect to {}", url).into()))
Err(Error::Url(format!("Unable to connect to {}", uri).into()))
}

/// Get the mode of the given URL.
///
/// This function may be used to ease the creation of custom TLS streams
/// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`.
pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match uri.scheme_str() {
Some("ws") => Ok(Mode::Plain),
Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())),
}
}
Expand All @@ -173,30 +159,104 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<'t, Stream, Req>(
pub fn client_with_config<Stream, Req>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
ClientHandshake::start(stream, request.into(), config).handshake()
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}

/// Do the client handshake over the given stream.
///
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(
pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
client_with_config(request, stream, None)
}

/// Trait for converting various types into HTTP requests used for a client connection.
///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`.
pub trait IntoClientRequest {
/// Convert into a `Request` that can be used for a client connection.
fn into_client_request(self) -> Result<Request>;
}

impl<'a> IntoClientRequest for &'a str {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl<'a> IntoClientRequest for &'a String {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for String {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl<'a> IntoClientRequest for &'a Uri {
fn into_client_request(self) -> Result<Request> {
Ok(Request::get(self.clone()).body(())?)
}
}

impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request> {
Ok(Request::get(self).body(())?)
}
}

impl<'a> IntoClientRequest for &'a Url {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.as_str().parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for Url {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.as_str().parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for Request {
fn into_client_request(self) -> Result<Request> {
Ok(self)
}
}

impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
fn into_client_request(self) -> Result<Request> {
use crate::handshake::headers::FromHttparse;
Request::from_httparse(self)
}
}
Loading