From 37fd1857fe3fb83700202ccafe9f6cd8fabcb635 Mon Sep 17 00:00:00 2001 From: Josh Triplett Date: Thu, 18 Jan 2024 18:07:02 -0800 Subject: [PATCH] Make `into_websocket()` send the request if not yet sent Rework the error type to also handle HTTP errors, while still allowing recovery of the `Conn`. Add a test case for recovering the `Conn` from an error and using that to read an error message from the body. --- client/Cargo.toml | 2 +- client/src/websocket.rs | 93 +++++++++++++++++++++++++-------------- client/tests/websocket.rs | 36 ++++++++++++--- 3 files changed, 91 insertions(+), 40 deletions(-) diff --git a/client/Cargo.toml b/client/Cargo.toml index 963eb5c59e..d571874072 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -11,7 +11,7 @@ keywords = ["trillium", "framework", "async"] categories = ["web-programming", "web-programming::http-client"] [features] -websockets = ["dep:trillium-websockets"] +websockets = ["dep:trillium-websockets", "thiserror"] json = ["serde_json", "serde", "thiserror"] [dependencies] diff --git a/client/src/websocket.rs b/client/src/websocket.rs index 7c05001620..46ef031e84 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -14,46 +14,51 @@ use crate::{Conn, WebSocketConfig, WebSocketConn}; pub use trillium_websockets::Message; impl Conn { - /// Set the appropriate headers for upgrading to a WebSocket - pub fn with_websocket_upgrade_headers(self) -> Conn { - self.with_header(KnownHeaderName::Upgrade, "websocket") - .with_header(KnownHeaderName::Connection, "upgrade") - .with_header(KnownHeaderName::SecWebsocketVersion, "13") - .with_header(SecWebsocketKey, websocket_key()) + fn set_websocket_upgrade_headers(&mut self) { + let h = self.request_headers_mut(); + h.try_insert(KnownHeaderName::Upgrade, "websocket"); + h.try_insert(KnownHeaderName::Connection, "upgrade"); + h.try_insert(KnownHeaderName::SecWebsocketVersion, "13"); + h.try_insert(SecWebsocketKey, websocket_key()); } /// Turn this `Conn` into a [`WebSocketConn`] + /// + /// If the request has not yet been sent, this will call `with_websocket_upgrade_headers()` and + /// then send the request. pub async fn into_websocket(self) -> Result { self.into_websocket_with_config(WebSocketConfig::default()) .await } /// Turn this `Conn` into a [`WebSocketConn`], with a custom [`WebSocketConfig`] + /// + /// If the request has not yet been sent, this will call `with_websocket_upgrade_headers()` and + /// then send the request. pub async fn into_websocket_with_config( - self, + mut self, config: WebSocketConfig, ) -> Result { - let status = self - .status() - .expect("into_websocket() with request not yet sent; remember to call .await"); + let status = match self.status() { + Some(status) => status, + None => { + self.set_websocket_upgrade_headers(); + if let Err(e) = (&mut self).await { + return Err(WebSocketUpgradeError::new(self, e.into())); + } + self.status().expect("Response did not include status") + } + }; if status != Status::SwitchingProtocols { - return Err(WebSocketUpgradeError::new( - self, - "Expected status 101 (Switching Protocols)", - )); + return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status))); } - let Some(key) = self.request_headers().get_str(SecWebsocketKey) else { - return Err(WebSocketUpgradeError::new( - self, - "Request did not include Sec-WebSocket-Key", - )); - }; + let key = self + .request_headers() + .get_str(SecWebsocketKey) + .expect("Request did not include Sec-WebSocket-Key"); let accept_key = websocket_accept_hash(key); if self.response_headers().get_str(SecWebsocketAccept) != Some(&accept_key) { - return Err(WebSocketUpgradeError::new( - self, - "Response did not contain valid Sec-WebSocket-Accept", - )); + return Err(WebSocketUpgradeError::new(self, ErrorKind::InvalidAccept)); } let peer_ip = self.peer_addr().map(|addr| addr.ip()); let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await; @@ -62,20 +67,44 @@ impl Conn { } } +/// The kind of error that occurred when attempting a websocket upgrade +#[derive(thiserror::Error, Debug)] +#[non_exhaustive] +/// An Error type that represents all exceptional conditions that can be encoutered in the operation +/// of this crate +pub enum ErrorKind { + /// an HTTP error attempting to make the request + #[error(transparent)] + Http(#[from] trillium_http::Error), + + /// Response didn't have status 101 (Switching Protocols) + #[error("Expected status 101 (Switching Protocols), got {0}")] + Status(Status), + + /// Response Sec-WebSocket-Accept was missing or invalid; generally a server bug + #[error("Response Sec-WebSocket-Accept was missing or invalid")] + InvalidAccept, +} + /// An attempted upgrade to a WebSocket failed. You can transform this back into the Conn with -/// [`From::from`]/[`Into::into`]. +/// [`From::from`]/[`Into::into`], if you need to look at the server response. #[derive(Debug)] -pub struct WebSocketUpgradeError(Box, &'static str); +pub struct WebSocketUpgradeError { + /// The kind of error that occurred + pub kind: ErrorKind, + conn: Box, +} impl WebSocketUpgradeError { - fn new(conn: Conn, msg: &'static str) -> Self { - Self(Box::new(conn), msg) + fn new(conn: Conn, kind: ErrorKind) -> Self { + let conn = Box::new(conn); + Self { conn, kind } } } impl From for Conn { fn from(value: WebSocketUpgradeError) -> Self { - *value.0 + *value.conn } } @@ -83,12 +112,12 @@ impl Deref for WebSocketUpgradeError { type Target = Conn; fn deref(&self) -> &Self::Target { - &self.0 + &self.conn } } impl DerefMut for WebSocketUpgradeError { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 + &mut self.conn } } @@ -96,6 +125,6 @@ impl std::error::Error for WebSocketUpgradeError {} impl Display for WebSocketUpgradeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.1) + self.kind.fmt(f) } } diff --git a/client/tests/websocket.rs b/client/tests/websocket.rs index 00988d9ff4..0052563ada 100644 --- a/client/tests/websocket.rs +++ b/client/tests/websocket.rs @@ -1,5 +1,9 @@ use futures_lite::StreamExt; -use trillium_client::{websocket::Message, Client, WebSocketConn}; +use trillium_client::{ + websocket::{self, Message}, + Client, WebSocketConn, +}; +use trillium_http::Status; use trillium_testing::ClientConfig; use trillium_websockets::websocket; @@ -16,12 +20,7 @@ fn test_websockets() { let client = Client::new(ClientConfig::new()); trillium_testing::with_server(handler, move |url| async move { - let mut ws = client - .get(url) - .with_websocket_upgrade_headers() - .await? - .into_websocket() - .await?; + let mut ws = client.get(url).into_websocket().await?; ws.send_string("Client test message".to_string()).await?; @@ -35,3 +34,26 @@ fn test_websockets() { Ok(()) }) } + +#[test] +fn test_websockets_error() { + let handler = + |conn: trillium::Conn| async { conn.with_status(404).with_body("This does not exist") }; + let client = Client::new(ClientConfig::new()); + trillium_testing::with_server(handler, move |url| async move { + let err = client + .get(url) + .into_websocket() + .await + .expect_err("Expected a 404"); + assert!(matches!( + err.kind, + websocket::ErrorKind::Status(Status::NotFound), + )); + let mut conn = trillium_client::Conn::from(err); + let body = conn.response_body().read_string().await?; + assert_eq!(body, "This does not exist"); + + Ok(()) + }) +}