Skip to content

Commit

Permalink
Merge pull request #93 from sdroege/http-types
Browse files Browse the repository at this point in the history
Base HTTP-types (request, headers, response, status code, etc) on the ones from the http crate
  • Loading branch information
daniel-abramov authored Nov 25, 2019
2 parents bb80143 + e1a5153 commit 345d262
Show file tree
Hide file tree
Showing 14 changed files with 438 additions and 379 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ authors = ["Alexey Galakhov"]
license = "MIT/Apache-2.0"
readme = "README.md"
homepage = "https://github.com/snapview/tungstenite-rs"
documentation = "https://docs.rs/tungstenite/0.9.3"
documentation = "https://docs.rs/tungstenite/0.10.0"
repository = "https://github.com/snapview/tungstenite-rs"
version = "0.9.3"
version = "0.10.0"
edition = "2018"

[features]
Expand Down
4 changes: 2 additions & 2 deletions examples/autobahn-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use url::Url;

use tungstenite::{connect, Error, Message, Result};

const AGENT: &'static str = "Tungstenite";
const AGENT: &str = "Tungstenite";

fn get_case_count() -> Result<u32> {
let (mut socket, _) = connect(Url::parse("ws://localhost:9001/getCaseCount").unwrap())?;
Expand Down Expand Up @@ -47,7 +47,7 @@ fn main() {

let total = get_case_count().unwrap();

for case in 1..(total + 1) {
for case in 1..=total {
if let Err(e) = run_test(case) {
match e {
Error::Protocol(_) => {}
Expand Down
14 changes: 7 additions & 7 deletions examples/callback-error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ use std::net::TcpListener;
use std::thread::spawn;

use tungstenite::accept_hdr;
use tungstenite::handshake::server::{ErrorResponse, Request};
use tungstenite::handshake::server::{Request, Response};
use tungstenite::http::StatusCode;

fn main() {
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() {
spawn(move || {
let callback = |_req: &Request| {
Err(ErrorResponse {
error_code: StatusCode::FORBIDDEN,
headers: None,
body: Some("Access denied".into()),
})
let callback = |_req: &Request, _resp| {
let resp = Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Some("Access denied".into()))
.unwrap();
Err(resp)
};
accept_hdr(stream.unwrap(), callback).unwrap_err();
});
Expand Down
4 changes: 2 additions & 2 deletions examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ fn main() {
connect(Url::parse("ws://localhost:3012/socket").unwrap()).expect("Can't connect");

println!("Connected to the server");
println!("Response HTTP code: {}", response.code);
println!("Response HTTP code: {}", response.status());
println!("Response contains the following headers:");
for &(ref header, _ /*value*/) in response.headers.iter() {
for (ref header, _value) in response.headers() {
println!("* {}", header);
}

Expand Down
21 changes: 9 additions & 12 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,27 @@ use std::net::TcpListener;
use std::thread::spawn;

use tungstenite::accept_hdr;
use tungstenite::handshake::server::Request;
use tungstenite::handshake::server::{Request, Response};

fn main() {
env_logger::init();
let server = TcpListener::bind("127.0.0.1:3012").unwrap();
for stream in server.incoming() {
spawn(move || {
let callback = |req: &Request| {
let callback = |req: &Request, mut response: Response| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.path);
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
for &(ref header, _ /* value */) in req.headers.iter() {
for (ref header, _value) in req.headers() {
println!("* {}", header);
}

// Let's add an additional header to our response to the client.
let extra_headers = vec![
(String::from("MyCustomHeader"), String::from(":)")),
(
String::from("SOME_TUNGSTENITE_HEADER"),
String::from("header_value"),
),
];
Ok(Some(extra_headers))
let headers = response.headers_mut();
headers.append("MyCustomHeader", ":)".parse().unwrap());
headers.append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap());

Ok(response)
};
let mut websocket = accept_hdr(stream.unwrap(), callback).unwrap();

Expand Down
148 changes: 104 additions & 44 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::result::Result as StdResult;

use http::Uri;
use log::*;

use url::Url;

use crate::handshake::client::Response;
use crate::handshake::client::{Request, Response};
use crate::protocol::WebSocketConfig;

#[cfg(feature = "tls")]
Expand Down Expand Up @@ -64,7 +66,7 @@ use self::encryption::wrap_stream;
pub use self::encryption::AutoStream;

use crate::error::{Error, Result};
use crate::handshake::client::{ClientHandshake, Request};
use crate::handshake::client::ClientHandshake;
use crate::handshake::HandshakeError;
use crate::protocol::WebSocket;
use crate::stream::{Mode, NoDelay};
Expand All @@ -84,37 +86,23 @@ use crate::stream::{Mode, NoDelay};
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
pub fn connect_with_config<Req: IntoClientRequest>(
request: Req,
config: Option<WebSocketConfig>,
) -> Result<(WebSocket<AutoStream>, Response)> {
let request: Request = request.into();
let mode = url_mode(&request.url)?;
let request: Request = request.into_client_request()?;
let uri = request.uri();
let mode = uri_mode(uri)?;
let host = request
.url
.uri()
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
let port = request
.url
.port_or_known_default()
.ok_or_else(|| Error::Url("No port number in the URL".into()))?;
let addrs;
let addr;
let addrs = match host {
url::Host::Domain(domain) => {
addrs = (domain, port).to_socket_addrs()?;
addrs.as_slice()
}
url::Host::Ipv4(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
url::Host::Ipv6(ip) => {
addr = (ip, port).into();
std::slice::from_ref(&addr)
}
};
let mut stream = connect_to_some(addrs, &request.url, mode)?;
let port = uri.port_u16().unwrap_or(match mode {
Mode::Plain => 80,
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri(), mode)?;
NoDelay::set_nodelay(&mut stream, true)?;
client_with_config(request, stream, config).map_err(|e| match e {
HandshakeError::Failure(f) => f,
Expand All @@ -134,35 +122,33 @@ pub fn connect_with_config<'t, Req: Into<Request<'t>>>(
/// This function uses `native_tls` to do TLS. If you want to use other TLS libraries,
/// use `client` instead. There is no need to enable the "tls" feature if you don't call
/// `connect` since it's the only function that uses native_tls.
pub fn connect<'t, Req: Into<Request<'t>>>(
request: Req,
) -> Result<(WebSocket<AutoStream>, Response)> {
pub fn connect<Req: IntoClientRequest>(request: Req) -> Result<(WebSocket<AutoStream>, Response)> {
connect_with_config(request, None)
}

fn connect_to_some(addrs: &[SocketAddr], url: &Url, mode: Mode) -> Result<AutoStream> {
let domain = url
.host_str()
fn connect_to_some(addrs: &[SocketAddr], uri: &Uri, mode: Mode) -> Result<AutoStream> {
let domain = uri
.host()
.ok_or_else(|| Error::Url("No host name in the URL".into()))?;
for addr in addrs {
debug!("Trying to contact {} at {}...", url, addr);
debug!("Trying to contact {} at {}...", uri, addr);
if let Ok(raw_stream) = TcpStream::connect(addr) {
if let Ok(stream) = wrap_stream(raw_stream, domain, mode) {
return Ok(stream);
}
}
}
Err(Error::Url(format!("Unable to connect to {}", url).into()))
Err(Error::Url(format!("Unable to connect to {}", uri).into()))
}

/// Get the mode of the given URL.
///
/// This function may be used to ease the creation of custom TLS streams
/// in non-blocking algorithmss or for use with TLS libraries other than `native_tls`.
pub fn url_mode(url: &Url) -> Result<Mode> {
match url.scheme() {
"ws" => Ok(Mode::Plain),
"wss" => Ok(Mode::Tls),
pub fn uri_mode(uri: &Uri) -> Result<Mode> {
match uri.scheme_str() {
Some("ws") => Ok(Mode::Plain),
Some("wss") => Ok(Mode::Tls),
_ => Err(Error::Url("URL scheme not supported".into())),
}
}
Expand All @@ -173,30 +159,104 @@ pub fn url_mode(url: &Url) -> Result<Mode> {
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client_with_config<'t, Stream, Req>(
pub fn client_with_config<Stream, Req>(
request: Req,
stream: Stream,
config: Option<WebSocketConfig>,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
ClientHandshake::start(stream, request.into(), config).handshake()
ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
}

/// Do the client handshake over the given stream.
///
/// Use this function if you need a nonblocking handshake support or if you
/// want to use a custom stream like `mio::tcp::TcpStream` or `openssl::ssl::SslStream`.
/// Any stream supporting `Read + Write` will do.
pub fn client<'t, Stream, Req>(
pub fn client<Stream, Req>(
request: Req,
stream: Stream,
) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
where
Stream: Read + Write,
Req: Into<Request<'t>>,
Req: IntoClientRequest,
{
client_with_config(request, stream, None)
}

/// Trait for converting various types into HTTP requests used for a client connection.
///
/// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
/// `http::Request<()>`.
pub trait IntoClientRequest {
/// Convert into a `Request` that can be used for a client connection.
fn into_client_request(self) -> Result<Request>;
}

impl<'a> IntoClientRequest for &'a str {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl<'a> IntoClientRequest for &'a String {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for String {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl<'a> IntoClientRequest for &'a Uri {
fn into_client_request(self) -> Result<Request> {
Ok(Request::get(self.clone()).body(())?)
}
}

impl IntoClientRequest for Uri {
fn into_client_request(self) -> Result<Request> {
Ok(Request::get(self).body(())?)
}
}

impl<'a> IntoClientRequest for &'a Url {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.as_str().parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for Url {
fn into_client_request(self) -> Result<Request> {
let uri: Uri = self.as_str().parse()?;

Ok(Request::get(uri).body(())?)
}
}

impl IntoClientRequest for Request {
fn into_client_request(self) -> Result<Request> {
Ok(self)
}
}

impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
fn into_client_request(self) -> Result<Request> {
use crate::handshake::headers::FromHttparse;
Request::from_httparse(self)
}
}
Loading

0 comments on commit 345d262

Please sign in to comment.