Skip to content

Commit

Permalink
Make into_websocket() send the request if not yet sent
Browse files Browse the repository at this point in the history
Rework the error type to also handle HTTP errors, while still allowing
recovery of the `Conn`.

Add a test case for recovering the `Conn` from an error and using that
to read an error message from the body.
  • Loading branch information
joshtriplett authored and jbr committed Jan 19, 2024
1 parent f4475fa commit 37fd185
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 40 deletions.
2 changes: 1 addition & 1 deletion client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ keywords = ["trillium", "framework", "async"]
categories = ["web-programming", "web-programming::http-client"]

[features]
websockets = ["dep:trillium-websockets"]
websockets = ["dep:trillium-websockets", "thiserror"]
json = ["serde_json", "serde", "thiserror"]

[dependencies]
Expand Down
93 changes: 61 additions & 32 deletions client/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,51 @@ use crate::{Conn, WebSocketConfig, WebSocketConn};
pub use trillium_websockets::Message;

impl Conn {
/// Set the appropriate headers for upgrading to a WebSocket
pub fn with_websocket_upgrade_headers(self) -> Conn {
self.with_header(KnownHeaderName::Upgrade, "websocket")
.with_header(KnownHeaderName::Connection, "upgrade")
.with_header(KnownHeaderName::SecWebsocketVersion, "13")
.with_header(SecWebsocketKey, websocket_key())
fn set_websocket_upgrade_headers(&mut self) {
let h = self.request_headers_mut();
h.try_insert(KnownHeaderName::Upgrade, "websocket");
h.try_insert(KnownHeaderName::Connection, "upgrade");
h.try_insert(KnownHeaderName::SecWebsocketVersion, "13");
h.try_insert(SecWebsocketKey, websocket_key());
}

/// Turn this `Conn` into a [`WebSocketConn`]
///
/// If the request has not yet been sent, this will call `with_websocket_upgrade_headers()` and
/// then send the request.
pub async fn into_websocket(self) -> Result<WebSocketConn, WebSocketUpgradeError> {
self.into_websocket_with_config(WebSocketConfig::default())
.await
}

/// Turn this `Conn` into a [`WebSocketConn`], with a custom [`WebSocketConfig`]
///
/// If the request has not yet been sent, this will call `with_websocket_upgrade_headers()` and
/// then send the request.
pub async fn into_websocket_with_config(
self,
mut self,
config: WebSocketConfig,
) -> Result<WebSocketConn, WebSocketUpgradeError> {
let status = self
.status()
.expect("into_websocket() with request not yet sent; remember to call .await");
let status = match self.status() {
Some(status) => status,
None => {
self.set_websocket_upgrade_headers();
if let Err(e) = (&mut self).await {
return Err(WebSocketUpgradeError::new(self, e.into()));
}
self.status().expect("Response did not include status")
}
};
if status != Status::SwitchingProtocols {
return Err(WebSocketUpgradeError::new(
self,
"Expected status 101 (Switching Protocols)",
));
return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
}
let Some(key) = self.request_headers().get_str(SecWebsocketKey) else {
return Err(WebSocketUpgradeError::new(
self,
"Request did not include Sec-WebSocket-Key",
));
};
let key = self
.request_headers()
.get_str(SecWebsocketKey)
.expect("Request did not include Sec-WebSocket-Key");
let accept_key = websocket_accept_hash(key);
if self.response_headers().get_str(SecWebsocketAccept) != Some(&accept_key) {
return Err(WebSocketUpgradeError::new(
self,
"Response did not contain valid Sec-WebSocket-Accept",
));
return Err(WebSocketUpgradeError::new(self, ErrorKind::InvalidAccept));
}
let peer_ip = self.peer_addr().map(|addr| addr.ip());
let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
Expand All @@ -62,40 +67,64 @@ impl Conn {
}
}

/// The kind of error that occurred when attempting a websocket upgrade
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
/// An Error type that represents all exceptional conditions that can be encoutered in the operation
/// of this crate
pub enum ErrorKind {
/// an HTTP error attempting to make the request
#[error(transparent)]
Http(#[from] trillium_http::Error),

/// Response didn't have status 101 (Switching Protocols)
#[error("Expected status 101 (Switching Protocols), got {0}")]
Status(Status),

/// Response Sec-WebSocket-Accept was missing or invalid; generally a server bug
#[error("Response Sec-WebSocket-Accept was missing or invalid")]
InvalidAccept,
}

/// An attempted upgrade to a WebSocket failed. You can transform this back into the Conn with
/// [`From::from`]/[`Into::into`].
/// [`From::from`]/[`Into::into`], if you need to look at the server response.
#[derive(Debug)]
pub struct WebSocketUpgradeError(Box<Conn>, &'static str);
pub struct WebSocketUpgradeError {
/// The kind of error that occurred
pub kind: ErrorKind,
conn: Box<Conn>,
}

impl WebSocketUpgradeError {
fn new(conn: Conn, msg: &'static str) -> Self {
Self(Box::new(conn), msg)
fn new(conn: Conn, kind: ErrorKind) -> Self {
let conn = Box::new(conn);
Self { conn, kind }
}
}

impl From<WebSocketUpgradeError> for Conn {
fn from(value: WebSocketUpgradeError) -> Self {
*value.0
*value.conn
}
}

impl Deref for WebSocketUpgradeError {
type Target = Conn;

fn deref(&self) -> &Self::Target {
&self.0
&self.conn
}
}
impl DerefMut for WebSocketUpgradeError {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
&mut self.conn
}
}

impl std::error::Error for WebSocketUpgradeError {}

impl Display for WebSocketUpgradeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.1)
self.kind.fmt(f)
}
}
36 changes: 29 additions & 7 deletions client/tests/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use futures_lite::StreamExt;
use trillium_client::{websocket::Message, Client, WebSocketConn};
use trillium_client::{
websocket::{self, Message},
Client, WebSocketConn,
};
use trillium_http::Status;
use trillium_testing::ClientConfig;
use trillium_websockets::websocket;

Expand All @@ -16,12 +20,7 @@ fn test_websockets() {
let client = Client::new(ClientConfig::new());

trillium_testing::with_server(handler, move |url| async move {
let mut ws = client
.get(url)
.with_websocket_upgrade_headers()
.await?
.into_websocket()
.await?;
let mut ws = client.get(url).into_websocket().await?;

ws.send_string("Client test message".to_string()).await?;

Expand All @@ -35,3 +34,26 @@ fn test_websockets() {
Ok(())
})
}

#[test]
fn test_websockets_error() {
let handler =
|conn: trillium::Conn| async { conn.with_status(404).with_body("This does not exist") };
let client = Client::new(ClientConfig::new());
trillium_testing::with_server(handler, move |url| async move {
let err = client
.get(url)
.into_websocket()
.await
.expect_err("Expected a 404");
assert!(matches!(
err.kind,
websocket::ErrorKind::Status(Status::NotFound),
));
let mut conn = trillium_client::Conn::from(err);
let body = conn.response_body().read_string().await?;
assert_eq!(body, "This does not exist");

Ok(())
})
}

0 comments on commit 37fd185

Please sign in to comment.