diff --git a/Cargo.lock b/Cargo.lock index d139ea00..cd7ae25b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2762,6 +2762,8 @@ dependencies = [ "env_logger", "flate2", "form_urlencoded", + "futures-core", + "http", "http-body-util", "hyper", "hyper-util", @@ -2777,6 +2779,7 @@ dependencies = [ "once_cell", "os_display", "pem", + "pin-project-lite", "predicates", "rand", "regex-lite", diff --git a/Cargo.toml b/Cargo.toml index 1ba36c7a..6935e9b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,8 +27,11 @@ dirs = "5.0" encoding_rs = "0.8.28" encoding_rs_io = "0.1.7" flate2 = "1.0.22" +futures-core = { version = "0.3.28", default-features = false } +http = "1" # Add "tracing" feature to hyper once it stabilizes hyper = { version = "1.2", default-features = false } +hyper-util = { version = "0.1", features = ["tokio"] } indicatif = "0.17" jsonxf = "1.1.0" memchr = "2.4.1" @@ -38,6 +41,7 @@ mime_guess = "2.0" once_cell = "1.8.0" os_display = "0.1.3" pem = "3.0" +pin-project-lite = "0.2" regex-lite = "0.1.5" roff = "0.2.1" rpassword = "7.2.0" @@ -48,6 +52,7 @@ serde_urlencoded = "0.7.0" supports-hyperlinks = "3.0.0" termcolor = "1.1.2" time = "0.3.16" +tokio = { version = "1", features = ["rt-multi-thread"] } unicode-width = "0.1.9" url = "2.2.2" ruzstd = { version = "0.7", default-features = false, features = ["std"]} diff --git a/src/cli.rs b/src/cli.rs index 8d227a18..840ee9fb 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -347,6 +347,16 @@ Example: --print=Hb" #[clap(short = '6', long)] pub ipv6: bool, + /// Connect using a Unix domain socket. + /// + /// Example: xh :/index.html --unix-socket=/var/run/temp.sock + #[clap( + long, + value_name = "FILE", + conflicts_with_all=["proxy", "verify", "cert", "cert_key", "ssl", "resolve", "interface", "ipv4", "ipv6", "https", "http_version"] + )] + pub unix_socket: Option, + /// Do not attempt to read stdin. /// /// This disables the default behaviour of reading the request body from stdin @@ -1011,7 +1021,7 @@ impl FromStr for Print { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct Timeout(Duration); impl Timeout { diff --git a/src/cookie.rs b/src/cookie.rs new file mode 100644 index 00000000..1ffa2dd2 --- /dev/null +++ b/src/cookie.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use anyhow::Result; +use reqwest::{ + blocking::{Request, Response}, + cookie::CookieStore, + header, +}; + +use crate::middleware::{Context, Middleware}; + +pub struct CookieMiddleware(Arc); + +impl CookieMiddleware { + pub fn new(cookie_jar: Arc) -> Self { + CookieMiddleware(cookie_jar) + } +} + +impl Middleware for CookieMiddleware { + fn handle(&mut self, mut ctx: Context, mut request: Request) -> Result { + let url = request.url().clone(); + + if let Some(header) = self.0.cookies(&url) { + request + .headers_mut() + .entry(header::COOKIE) + .or_insert(header); + } + + let response = self.next(&mut ctx, request)?; + + let mut cookies = response + .headers() + .get_all(header::SET_COOKIE) + .iter() + .peekable(); + if cookies.peek().is_some() { + self.0.set_cookies(&mut cookies, &url); + } + + Ok(response) + } +} diff --git a/src/main.rs b/src/main.rs index 2caac847..7dad0de6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod auth; mod buffer; mod cli; +mod cookie; mod decoder; mod download; mod formatting; @@ -15,6 +16,8 @@ mod redirect; mod request_items; mod session; mod to_curl; +#[cfg(unix)] +mod unix_socket; mod utils; use std::env; @@ -27,6 +30,7 @@ use std::str::FromStr; use std::sync::Arc; use anyhow::{anyhow, Context, Result}; +use cookie::CookieMiddleware; use cookie_store::{CookieStore, RawCookie}; use redirect::RedirectFollower; use reqwest::blocking::Client; @@ -35,7 +39,6 @@ use reqwest::header::{ }; use reqwest::tls; use url::Host; -use utils::reason_phrase; use crate::auth::{Auth, DigestAuthMiddleware}; use crate::buffer::Buffer; @@ -45,7 +48,7 @@ use crate::middleware::ClientWithMiddleware; use crate::printer::Printer; use crate::request_items::{Body, FORM_CONTENT_TYPE, JSON_ACCEPT, JSON_CONTENT_TYPE}; use crate::session::Session; -use crate::utils::{test_mode, test_pretend_term, url_with_query}; +use crate::utils::{reason_phrase, test_mode, test_pretend_term, url_with_query}; #[cfg(not(any(feature = "native-tls", feature = "rustls")))] compile_error!("Either native-tls or rustls feature must be enabled!"); @@ -86,6 +89,15 @@ fn main() { eprintln!(); eprintln!("Try running without the --native-tls flag."); } + if msg.starts_with("deadline has elapsed") { + process::exit(2); + } + #[cfg(unix)] + { + if err.downcast_ref::().is_some() { + process::exit(2); + } + } if let Some(err) = err.downcast_ref::() { if err.is_timeout() { process::exit(2); @@ -286,9 +298,6 @@ fn run(args: Cli) -> Result { None => client, }; - let cookie_jar = Arc::new(reqwest_cookie_store::CookieStoreMutex::default()); - client = client.cookie_provider(cookie_jar.clone()); - client = match (args.ipv4, args.ipv6) { (true, false) => client.local_address(IpAddr::from(Ipv4Addr::UNSPECIFIED)), (false, true) => client.local_address(IpAddr::from(Ipv6Addr::UNSPECIFIED)), @@ -340,6 +349,8 @@ fn run(args: Cli) -> Result { log::trace!("{client:#?}"); let client = client.build()?; + let cookie_jar = Arc::new(reqwest_cookie_store::CookieStoreMutex::default()); + let mut session = match &args.session { Some(name_or_path) => Some( Session::load_session(url.clone(), name_or_path.clone(), args.is_session_read_only) @@ -561,42 +572,55 @@ fn run(args: Cli) -> Result { printer.print_request_body(&mut request)?; } + let mut client = ClientWithMiddleware::new(client); + if !args.offline { let mut response = { let history_print = args.history_print.unwrap_or(print); - let mut client = ClientWithMiddleware::new(&client); - if args.all { - client = client.with_printer(|prev_response, next_request| { - if history_print.response_headers { - printer.print_response_headers(prev_response)?; - } - if history_print.response_body { - printer.print_response_body( - prev_response, - response_charset, - response_mime, - )?; - printer.print_separator()?; - } - if history_print.response_meta { - printer.print_response_meta(prev_response)?; - } - if history_print.request_headers { - printer.print_request_headers(next_request, &*cookie_jar)?; - } - if history_print.request_body { - printer.print_request_body(next_request)?; - } - Ok(()) - }); - } if args.follow { client = client.with(RedirectFollower::new(args.max_redirects.unwrap_or(10))); } if let Some(Auth::Digest(username, password)) = &auth { client = client.with(DigestAuthMiddleware::new(username, password)); } - client.execute(request)? + client = client.with(CookieMiddleware::new(cookie_jar.clone())); + if let Some(socket_path) = args.unix_socket { + #[cfg(not(unix))] + { + return Err(anyhow::anyhow!( + "HTTP over Unix domain sockets is not supported on this platform" + )); + } + #[cfg(unix)] + { + client = client.with_unix_socket( + socket_path, + args.timeout.and_then(|t| t.as_duration()), + )?; + } + } + client.execute(request, |prev_response, next_request| { + if !args.all { + return Ok(()); + } + if history_print.response_headers { + printer.print_response_headers(prev_response)?; + } + if history_print.response_body { + printer.print_response_body(prev_response, response_charset, response_mime)?; + printer.print_separator()?; + } + if history_print.response_meta { + printer.print_response_meta(prev_response)?; + } + if history_print.request_headers { + printer.print_request_headers(next_request, &*cookie_jar)?; + } + if history_print.request_body { + printer.print_request_body(next_request)?; + } + Ok(()) + })? }; let status = response.status(); diff --git a/src/middleware.rs b/src/middleware.rs index d60c24a2..d3bac384 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,7 +1,10 @@ -use std::time::{Duration, Instant}; +use std::{ + path::PathBuf, + time::{Duration, Instant}, +}; use anyhow::Result; -use reqwest::blocking::{Client, Request, Response}; +use reqwest::blocking::{Request, Response}; #[derive(Clone)] pub struct ResponseMeta { @@ -24,18 +27,18 @@ impl ResponseExt for Response { } } -type Printer<'a, 'b> = &'a mut (dyn FnMut(&mut Response, &mut Request) -> Result<()> + 'b); +type Printer<'a> = &'a mut (dyn FnMut(&mut Response, &mut Request) -> Result<()> + 'a); -pub struct Context<'a, 'b> { - client: &'a Client, - printer: Option>, +pub struct Context<'a, 'b, 'c, 'd> { + client: &'d Client, + printer: Printer<'c>, middlewares: &'a mut [Box], } -impl<'a, 'b> Context<'a, 'b> { +impl<'a, 'b, 'c, 'd> Context<'a, 'b, 'c, 'd> { fn new( - client: &'a Client, - printer: Option>, + client: &'d Client, + printer: Printer<'c>, middlewares: &'a mut [Box], ) -> Self { Context { @@ -49,18 +52,20 @@ impl<'a, 'b> Context<'a, 'b> { match self.middlewares { [] => { let starting_time = Instant::now(); - let mut response = self.client.execute(request)?; + let mut response = match self.client { + Client::Http(client) => client.execute(request)?, + #[cfg(unix)] + Client::Unix(client) => client.execute(request)?, + }; response.extensions_mut().insert(ResponseMeta { request_duration: starting_time.elapsed(), content_download_duration: None, }); Ok(response) } - [ref mut head, tail @ ..] => head.handle( - #[allow(clippy::needless_option_as_deref)] - Context::new(self.client, self.printer.as_deref_mut(), tail), - request, - ), + [ref mut head, tail @ ..] => { + head.handle(Context::new(self.client, self.printer, tail), request) + } } } } @@ -78,38 +83,38 @@ pub trait Middleware { response: &mut Response, request: &mut Request, ) -> Result<()> { - if let Some(ref mut printer) = ctx.printer { - printer(response, request)?; - } - + (ctx.printer)(response, request)?; Ok(()) } } -pub struct ClientWithMiddleware<'a, T> -where - T: FnMut(&mut Response, &mut Request) -> Result<()>, -{ - client: &'a Client, - printer: Option, +enum Client { + Http(reqwest::blocking::Client), + #[cfg(unix)] + Unix(crate::unix_socket::UnixClient), +} + +pub struct ClientWithMiddleware<'a> { + client: Client, middlewares: Vec>, } -impl<'a, T> ClientWithMiddleware<'a, T> -where - T: FnMut(&mut Response, &mut Request) -> Result<()> + 'a, -{ - pub fn new(client: &'a Client) -> Self { +impl<'a> ClientWithMiddleware<'a> { + pub fn new(client: reqwest::blocking::Client) -> Self { ClientWithMiddleware { - client, - printer: None, + client: Client::Http(client), middlewares: vec![], } } - pub fn with_printer(mut self, printer: T) -> Self { - self.printer = Some(printer); - self + #[cfg(unix)] + pub fn with_unix_socket( + mut self, + socket_path: PathBuf, + timeout: Option, + ) -> Result { + self.client = Client::Unix(crate::unix_socket::UnixClient::new(socket_path, timeout)); + Ok(self) } pub fn with(mut self, middleware: impl Middleware + 'a) -> Self { @@ -117,12 +122,11 @@ where self } - pub fn execute(&mut self, request: Request) -> Result { - let mut ctx = Context::new( - self.client, - self.printer.as_mut().map(|p| p as _), - &mut self.middlewares[..], - ); + pub fn execute<'b, T>(&mut self, request: Request, mut printer: T) -> Result + where + T: FnMut(&mut Response, &mut Request) -> Result<()> + 'b, + { + let mut ctx = Context::new(&self.client, &mut printer, &mut self.middlewares[..]); ctx.execute(request) } } diff --git a/src/printer.rs b/src/printer.rs index 729bba26..cd4eb822 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -20,7 +20,7 @@ use crate::{ formatting::serde_json_format, formatting::{get_json_formatter, Highlighter}, middleware::ResponseExt, - utils::{copy_largebuf, test_mode, BUFFER_SIZE}, + utils::{copy_largebuf, BUFFER_SIZE}, }; const BINARY_SUPPRESSOR: &str = concat!( @@ -345,9 +345,7 @@ impl Printer { // even know if we're going to use HTTP/2 yet. headers.entry(HOST).or_insert_with(|| { // Added at https://github.com/hyperium/hyper-util/blob/53aadac50d/src/client/legacy/client.rs#L278 - if test_mode() { - HeaderValue::from_str("http.mock") - } else if let Some(port) = request.url().port() { + if let Some(port) = request.url().port() { HeaderValue::from_str(&format!("{}:{}", host, port)) } else { HeaderValue::from_str(host) diff --git a/src/to_curl.rs b/src/to_curl.rs index 744970e4..a6f0f3d7 100644 --- a/src/to_curl.rs +++ b/src/to_curl.rs @@ -299,6 +299,11 @@ pub fn translate(args: Cli) -> Result { cmd.arg(interface); }; + if let Some(unix_socket) = args.unix_socket { + cmd.arg("--unix-socket"); + cmd.arg(unix_socket); + } + if !args.resolve.is_empty() { let port = url .port_or_known_default() diff --git a/src/unix_socket.rs b/src/unix_socket.rs new file mode 100644 index 00000000..baae32f5 --- /dev/null +++ b/src/unix_socket.rs @@ -0,0 +1,168 @@ +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use anyhow::{anyhow, Result}; +use pin_project_lite::pin_project; +use reqwest::blocking::{Request, Response}; +use reqwest::header::{HeaderValue, HOST}; + +pub struct UnixClient { + rt: tokio::runtime::Runtime, + socket_path: PathBuf, + timeout: Option, +} + +impl UnixClient { + pub fn new(socket_path: PathBuf, timeout: Option) -> Self { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + Self { + rt, + socket_path, + timeout, + } + } + + async fn connect(&self) -> Result> { + // TODO: Add support for Windows named pipes by replacing UnixStream with namedPipeClient. + // See https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.ClientOptions.html#method.open + let stream = tokio::net::UnixStream::connect(&self.socket_path).await?; + let (sender, conn) = hyper::client::conn::http1::Builder::new() + .title_case_headers(true) + .handshake(hyper_util::rt::TokioIo::new(stream)) + .await?; + + tokio::task::spawn(async move { + if let Err(err) = conn.await { + log::error!("Connection failed: {:?}", err); + } + }); + + Ok(sender) + } + + pub fn execute(&self, request: Request) -> Result { + self.rt.block_on(async { + let http_request = into_async_request(request)?; + + let mut sender = with_timeout(self.connect(), self.timeout).await??; + let response = with_timeout(sender.send_request(http_request), self.timeout).await??; + + Ok(Response::from(response.map(|body| { + if let Some(timeout) = self.timeout { + reqwest::Body::wrap(TotalTimeoutBody::new(body, timeout)) + } else { + reqwest::Body::wrap(body) + } + }))) + }) + } +} + +fn into_async_request(mut request: Request) -> Result> { + let mut http_request = http::Request::builder() + .version(request.version()) + .method(request.method()) + .uri(request.url().as_str()) + .body(reqwest::Body::default())?; + + *http_request.headers_mut() = request.headers_mut().clone(); + + if let Some(host) = request.url().host_str() { + http_request.headers_mut().entry(HOST).or_insert_with(|| { + if let Some(port) = request.url().port() { + HeaderValue::from_str(&format!("{}:{}", host, port)) + } else { + HeaderValue::from_str(host) + } + .expect("hostname should already be validated/parsed") + }); + } + + if let Some(body) = request.body_mut().as_mut() { + *http_request.body_mut() = reqwest::Body::from(body.buffer()?.to_owned()); + } + + Ok(http_request) +} + +async fn with_timeout(fut: F, timeout: Option) -> Result +where + F: std::future::IntoFuture, +{ + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, fut) + .await + .map_err(|_| anyhow!(TimeoutError)) + } else { + Ok(fut.await) + } +} + +#[derive(Debug, Clone)] +pub struct TimeoutError; + +impl std::fmt::Display for TimeoutError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "operation timed out") + } +} + +// Copied from https://github.com/seanmonstar/reqwest/blob/8b8fdd2552ad645c7e9dd494930b3e95e2aedef2/src/async_impl/body.rs#L314 +// with some slight tweaks +pin_project! { + pub(crate) struct TotalTimeoutBody { + #[pin] + inner: B, + timeout: Pin>, + } +} + +impl TotalTimeoutBody { + fn new(body: B, timeout: Duration) -> TotalTimeoutBody { + let total_timeout = Box::pin(tokio::time::sleep(timeout)); + TotalTimeoutBody { + inner: body, + timeout: total_timeout, + } + } +} + +impl hyper::body::Body for TotalTimeoutBody +where + B: hyper::body::Body, + B::Error: Into>, +{ + type Data = B::Data; + type Error = anyhow::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll, Self::Error>>> { + let this = self.project(); + if let Poll::Ready(()) = this.timeout.as_mut().poll(cx) { + return Poll::Ready(Some(Err(anyhow!(TimeoutError)))); + } + Poll::Ready( + futures_core::ready!(this.inner.poll_frame(cx)) + .map(|opt_chunk| opt_chunk.map_err(|e| anyhow!(e.into()))), + ) + } + + #[inline] + fn size_hint(&self) -> hyper::body::SizeHint { + self.inner.size_hint() + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } +} diff --git a/tests/cases/http_unix.rs b/tests/cases/http_unix.rs new file mode 100644 index 00000000..8c2012eb --- /dev/null +++ b/tests/cases/http_unix.rs @@ -0,0 +1,217 @@ +#[cfg(unix)] +use indoc::indoc; +use predicates::str::contains; + +use crate::prelude::*; + +#[cfg(not(unix))] +#[test] +fn error_on_unsupported_platform() { + get_command() + .arg(format!("--unix-socket=/tmp/missing.sock",)) + .arg(":/index.html") + .assert() + .failure() + .stderr(contains( + "HTTP over Unix domain sockets is not supported on this platform", + )); +} + +#[cfg(unix)] +#[test] +fn json_post() { + let server = server::http_unix(|req| async move { + assert_eq!(req.method(), "POST"); + assert_eq!(req.headers()["Content-Type"], "application/json"); + assert_eq!(req.headers()["Host"], "example.com"); + assert_eq!(req.body_as_string().await, "{\"foo\":\"bar\"}"); + + hyper::Response::builder() + .header(hyper::header::CONTENT_TYPE, "application/json") + .body(r#"{"status":"ok"}"#.into()) + .unwrap() + }); + + get_command() + .arg("--print=b") + .arg("--pretty=format") + .arg("post") + .arg("http://example.com") + .arg(format!( + "--unix-socket={}", + server.socket_path().to_string_lossy() + )) + .arg("foo=bar") + .assert() + .stdout(indoc! {r#" + { + "status": "ok" + } + + + "#}); +} + +#[cfg(unix)] +#[test] +fn redirects_stay_on_same_server() { + let server = server::http_unix(|req| async move { + match dbg!(req.uri().to_string().as_str()) { + "http://example.com/first_page" => hyper::Response::builder() + .status(302) + .header("Date", "N/A") + .header("Location", "http://localhost:8000/second_page") + .body("redirecting...".into()) + .unwrap(), + "http://localhost:8000/second_page" => hyper::Response::builder() + .status(302) + .header("Date", "N/A") + .header("Location", "/third_page") + .body("redirecting...".into()) + .unwrap(), + "http://localhost:8000/third_page" => hyper::Response::builder() + .header("Date", "N/A") + .body("final destination".into()) + .unwrap(), + _ => panic!("unknown path"), + } + }); + + get_command() + .arg("http://example.com/first_page") + .arg(format!( + "--unix-socket={}", + server.socket_path().to_string_lossy() + )) + .arg("--follow") + .arg("--verbose") + .arg("--all") + .assert() + .stdout(indoc! {r#" + GET /first_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Host: example.com + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 302 Found + Content-Length: 14 + Date: N/A + Location: http://localhost:8000/second_page + + redirecting... + + GET /second_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Host: localhost:8000 + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 302 Found + Content-Length: 14 + Date: N/A + Location: /third_page + + redirecting... + + GET /third_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Host: localhost:8000 + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 200 OK + Content-Length: 17 + Date: N/A + + final destination + "#}); + + server.assert_hits(3); +} + +#[cfg(unix)] +#[test] +fn cookies_persist_across_redirects() { + let server = server::http_unix(|req| async move { + match req.uri().path() { + "/first_page" => hyper::Response::builder() + .status(302) + .header("Date", "N/A") + .header("Location", "/second_page") + .header("set-cookie", "hello=world") + .body("redirecting...".into()) + .unwrap(), + "/second_page" => hyper::Response::builder() + .header("Date", "N/A") + .body("final destination".into()) + .unwrap(), + _ => panic!("unknown path"), + } + }); + + get_command() + .arg("localhost:3000/first_page") + .arg(format!( + "--unix-socket={}", + server.socket_path().to_string_lossy() + )) + .arg("--follow") + .arg("--verbose") + .arg("--all") + .assert() + .stdout(indoc! {r#" + GET /first_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Host: localhost:3000 + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 302 Found + Content-Length: 14 + Date: N/A + Location: /second_page + Set-Cookie: hello=world + + redirecting... + + GET /second_page HTTP/1.1 + Accept: */* + Accept-Encoding: gzip, deflate, br, zstd + Connection: keep-alive + Cookie: hello=world + Host: localhost:3000 + User-Agent: xh/0.0.0 (test mode) + + HTTP/1.1 200 OK + Content-Length: 17 + Date: N/A + + final destination + "#}); +} + +#[cfg(unix)] +#[test] +fn timeout() { + let mut server = server::http_unix(|_req| async move { + tokio::time::sleep(std::time::Duration::from_secs_f32(0.5)).await; + hyper::Response::default() + }); + server.disable_hit_checks(); + + get_command() + .arg(":") + .arg(format!( + "--unix-socket={}", + server.socket_path().to_string_lossy() + )) + .arg("--timeout=0.1") + .assert() + .code(2) + .stderr(contains("operation timed out")); +} diff --git a/tests/cases/mod.rs b/tests/cases/mod.rs index bd9fa8b8..b8433031 100644 --- a/tests/cases/mod.rs +++ b/tests/cases/mod.rs @@ -1,2 +1,3 @@ mod download; +mod http_unix; mod logging; diff --git a/tests/cli.rs b/tests/cli.rs index c2008a64..adc861da 100644 --- a/tests/cli.rs +++ b/tests/cli.rs @@ -15,7 +15,7 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use assert_cmd::cmd::Command; use http_body_util::BodyExt; -use indoc::indoc; +use indoc::{formatdoc, indoc}; use predicates::function::function; use predicates::str::contains; use reqwest::header::HeaderValue; @@ -429,19 +429,19 @@ fn verbose() { get_command() .args(["--verbose", &server.base_url(), "x=y"]) .assert() - .stdout(indoc! {r#" + .stdout(formatdoc! {r#" POST / HTTP/1.1 Accept: application/json, */*;q=0.5 Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive Content-Length: 9 Content-Type: application/json - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) - { + {{ "x": "y" - } + }} @@ -451,7 +451,7 @@ fn verbose() { X-Foo: Bar a body - "#}); + "#, port = server.port() }); } #[test] @@ -836,12 +836,12 @@ fn digest_auth_with_redirection() { .arg("--verbose") .arg(server.url("/login_page")) .assert() - .stdout(indoc! {r#" + .stdout(formatdoc! {r#" GET /login_page HTTP/1.1 Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 401 Unauthorized @@ -856,7 +856,7 @@ fn digest_auth_with_redirection() { Accept-Encoding: gzip, deflate, br, zstd Authorization: Digest username="ahmed", realm="me@xh.com", nonce="e5051361f053723a807674177fc7022f", uri="/login_page", qop=auth, nc=00000001, cnonce="f2/wE4q74E6zIJEtWaHKaf5wv/H5QzzpXusqGemxURZJ", response="894fd5ee1dcc702df7e4a6abed37fd56", opaque="9dcf562038f1ec1c8d02f218ef0e7a4b", algorithm=MD5 Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 302 Found @@ -870,7 +870,7 @@ fn digest_auth_with_redirection() { Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 200 OK @@ -878,7 +878,7 @@ fn digest_auth_with_redirection() { Date: N/A admin page - "#}); + "#, port = server.port() }); server.assert_hits(3); } @@ -1626,12 +1626,12 @@ fn redirect_support_utf8_location() { get_command() .args([&server.url("/first_page"), "--follow", "--verbose", "--all"]) .assert() - .stdout(indoc! {r#" + .stdout(formatdoc! {r#" GET /first_page HTTP/1.1 Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 302 Found @@ -1645,7 +1645,7 @@ fn redirect_support_utf8_location() { Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 200 OK @@ -1653,7 +1653,7 @@ fn redirect_support_utf8_location() { Date: N/A final destination - "#}); + "#, port = server.port() }); } #[test] @@ -2048,7 +2048,7 @@ fn can_unset_default_headers() { Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: localhost "#}); } @@ -2064,7 +2064,7 @@ fn can_unset_headers() { Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive Hello: world - Host: http.mock + Host: localhost User-Agent: xh/0.0.0 (test mode) "#}); @@ -2081,7 +2081,7 @@ fn can_set_unset_header() { Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive Hello: world - Host: http.mock + Host: localhost User-Agent: xh/0.0.0 (test mode) "#}); @@ -2857,12 +2857,12 @@ fn print_intermediate_requests_and_responses() { get_command() .args([&server.url("/first_page"), "--follow", "--verbose", "--all"]) .assert() - .stdout(indoc! {r#" + .stdout(formatdoc! {r#" GET /first_page HTTP/1.1 Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 302 Found @@ -2876,7 +2876,7 @@ fn print_intermediate_requests_and_responses() { Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 200 OK @@ -2884,7 +2884,7 @@ fn print_intermediate_requests_and_responses() { Date: N/A final destination - "#}); + "#, port = server.port() }); } #[test] @@ -2912,12 +2912,12 @@ fn history_print() { .arg("--history-print=Hh") .arg("--all") .assert() - .stdout(indoc! {r#" + .stdout(formatdoc! {r#" GET /first_page HTTP/1.1 Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 302 Found @@ -2929,7 +2929,7 @@ fn history_print() { Accept: */* Accept-Encoding: gzip, deflate, br, zstd Connection: keep-alive - Host: http.mock + Host: 127.0.0.1:{port} User-Agent: xh/0.0.0 (test mode) HTTP/1.1 200 OK @@ -2937,7 +2937,7 @@ fn history_print() { Date: N/A final destination - "#}); + "#, port = server.port() }); } #[test] diff --git a/tests/server/mod.rs b/tests/server/mod.rs index 9ca6c604..97f81579 100644 --- a/tests/server/mod.rs +++ b/tests/server/mod.rs @@ -2,7 +2,7 @@ // with some slight tweaks use std::convert::Infallible; use std::future::Future; -use std::net; +use std::path::PathBuf; use std::sync::mpsc as std_mpsc; use std::sync::{Arc, Mutex}; use std::thread; @@ -18,8 +18,20 @@ use tokio::sync::oneshot; type Body = Full; type Builder = hyper_util::server::conn::auto::Builder; +enum Addr { + TcpAddr(std::net::SocketAddr), + #[cfg(unix)] + UnixAddr(tokio::net::unix::SocketAddr), +} + +enum Listener { + TcpListener(tokio::net::TcpListener), + #[cfg(unix)] + UnixListener(tokio::net::UnixListener), +} + pub struct Server { - addr: net::SocketAddr, + addr: Addr, panic_rx: std_mpsc::Receiver<()>, successful_hits: Arc>, total_hits: Arc>, @@ -29,19 +41,43 @@ pub struct Server { impl Server { pub fn base_url(&self) -> String { - format!("http://{}", self.addr) + match self.addr { + Addr::TcpAddr(addr) => format!("http://{}", addr), + #[cfg(unix)] + _ => panic!("no base_url for unix server"), + } } pub fn url(&self, path: &str) -> String { - format!("http://{}{}", self.addr, path) + match self.addr { + Addr::TcpAddr(addr) => format!("http://{}{}", addr, path), + #[cfg(unix)] + _ => panic!("no url for unix server"), + } } pub fn host(&self) -> String { - String::from("127.0.0.1") + match self.addr { + Addr::TcpAddr(_) => String::from("127.0.0.1"), + #[cfg(unix)] + _ => panic!("no host for unix server"), + } + } + + #[cfg(unix)] + pub fn socket_path(&self) -> PathBuf { + match &self.addr { + Addr::UnixAddr(addr) => addr.as_pathname().unwrap().to_path_buf(), + _ => panic!("no socket_path for tcp server"), + } } pub fn port(&self) -> u16 { - self.addr.port() + match self.addr { + Addr::TcpAddr(addr) => addr.port(), + #[cfg(unix)] + _ => panic!("no port for unix server"), + } } pub fn assert_hits(&self, hits: u8) { @@ -89,13 +125,36 @@ where F: Fn(Request) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { - http_inner(Arc::new(move |req| Box::new(Box::pin(func(req))))) + http_inner(Arc::new(move |req| Box::new(Box::pin(func(req)))), None) +} + +#[cfg(unix)] +pub fn http_unix(func: F) -> Server +where + F: Fn(Request) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, +{ + use rand::Rng; + let file_name: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(10) + .map(char::from) + .collect(); + let path = PathBuf::from(format!("/tmp/{file_name}.sock")); + if path.exists() { + std::fs::remove_file(&path).expect("could not remove old socket"); + } + + http_inner( + Arc::new(move |req| Box::new(Box::pin(func(req)))), + Some(path), + ) } type Serv = dyn Fn(Request) -> Box + Send + Sync; type ServFut = dyn Future> + Send + Unpin; -fn http_inner(func: Arc) -> Server { +fn http_inner(func: Arc, socket_path: Option) -> Server { // Spawn new runtime in thread to prevent reactor execution context conflict thread::spawn(move || { let rt = runtime::Builder::new_current_thread() @@ -104,12 +163,30 @@ fn http_inner(func: Arc) -> Server { .expect("new rt"); let successful_hits = Arc::new(Mutex::new(0)); let total_hits = Arc::new(Mutex::new(0)); - let listener = rt.block_on(async move { - tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) - .await - .unwrap() + + let (listener, addr) = rt.block_on(async move { + #[allow(unused_variables)] + if let Some(path) = &socket_path { + #[cfg(unix)] + { + let listener = tokio::net::UnixListener::bind(path).unwrap(); + let addr = listener.local_addr().unwrap(); + (Listener::UnixListener(listener), Addr::UnixAddr(addr)) + } + + #[cfg(not(unix))] + { + unreachable!("cannot create http_unix server outside of unix target_family") + } + } else { + let listener = + tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + (Listener::TcpListener(listener), Addr::TcpAddr(addr)) + } }); - let addr = listener.local_addr().unwrap(); let (shutdown_tx, shutdown_rx) = oneshot::channel(); let (panic_tx, panic_rx) = std_mpsc::channel(); @@ -145,14 +222,26 @@ fn http_inner(func: Arc) -> Server { }) }; - let (io, _) = listener.accept().await.unwrap(); - let builder = builder.clone(); - tokio::spawn(async move { - let _ = builder - .serve_connection(hyper_util::rt::TokioIo::new(io), svc) - .await; - }); + match &listener { + Listener::TcpListener(listener) => { + let (io, _) = listener.accept().await.unwrap(); + tokio::spawn(async move { + let _ = builder + .serve_connection(hyper_util::rt::TokioIo::new(io), svc) + .await; + }); + } + #[cfg(unix)] + Listener::UnixListener(listener) => { + let (io, _) = listener.accept().await.unwrap(); + tokio::spawn(async move { + let _ = builder + .serve_connection(hyper_util::rt::TokioIo::new(io), svc) + .await; + }); + } + } } }); let _ = rt.block_on(shutdown_rx);