diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 5a5890cd..ea131d10 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -5,7 +5,7 @@ use std::io::{Read, Write}; use std::marker::PhantomData; use std::result::Result as StdResult; -use http::{HeaderMap, StatusCode}; +use http::{HeaderMap, Response, StatusCode}; use httparse::Status; use log::*; @@ -15,6 +15,8 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult}; use crate::error::{Error, Result}; use crate::protocol::{Role, WebSocket, WebSocketConfig}; +// TODO get rid of this too + /// Request from the client. #[derive(Debug)] pub struct Request { @@ -39,17 +41,15 @@ impl Request { Sec-WebSocket-Accept: {}\r\n", convert_key(key.as_bytes())? ); - add_headers(&mut reply, extra_headers)?; + add_headers(&mut reply, extra_headers.as_ref())?; Ok(reply.into()) } } -fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option) -> Result<()> { +fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<&HeaderMap>) -> Result<()> { if let Some(eh) = extra_headers { for (k, v) in eh { - if let Some(k) = k { - writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); - } + writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap(); } } writeln!(reply, "\r").unwrap(); @@ -85,27 +85,6 @@ impl<'h, 'b: 'h> FromHttparse> for Request { } } -/// An error response sent to the client. -#[derive(Debug)] -pub struct ErrorResponse { - /// HTTP error code. - pub error_code: StatusCode, - /// Extra response headers, if any. - pub headers: Option, - /// Response body, if any. - pub body: Option, -} - -impl From for ErrorResponse { - fn from(error_code: StatusCode) -> Self { - ErrorResponse { - error_code, - headers: None, - body: None, - } - } -} - /// The callback trait. /// /// The callback is called when the server receives an incoming WebSocket @@ -116,14 +95,20 @@ pub trait Callback: Sized { /// Called whenever the server read the request from the client and is ready to reply to it. /// May return additional reply headers. /// Returning an error resulting in rejecting the incoming connection. - fn on_request(self, request: &Request) -> StdResult, ErrorResponse>; + fn on_request( + self, + request: &Request, + ) -> StdResult, Response>>; } impl Callback for F where - F: FnOnce(&Request) -> StdResult, ErrorResponse>, + F: FnOnce(&Request) -> StdResult, Response>>, { - fn on_request(self, request: &Request) -> StdResult, ErrorResponse> { + fn on_request( + self, + request: &Request, + ) -> StdResult, Response>> { self(request) } } @@ -133,7 +118,10 @@ where pub struct NoCallback; impl Callback for NoCallback { - fn on_request(self, _request: &Request) -> StdResult, ErrorResponse> { + fn on_request( + self, + _request: &Request, + ) -> StdResult, Response>> { Ok(None) } } @@ -204,19 +192,22 @@ impl HandshakeRole for ServerHandshake { ProcessingResult::Continue(HandshakeMachine::start_write(stream, response)) } - Err(ErrorResponse { - error_code, - headers, - body, - }) => { - self.error_code = Some(error_code.as_u16()); + Err(resp) => { + if resp.status().is_success() { + return Err(Error::Protocol( + "Custom response must not be successful".into(), + )); + } + + self.error_code = Some(resp.status().as_u16()); let mut response = format!( - "HTTP/1.1 {} {}\r\n", - error_code.as_str(), - error_code.canonical_reason().unwrap_or("") + "{version:?} {status} {reason}\r\n", + version = resp.version(), + status = resp.status().as_u16(), + reason = resp.status().canonical_reason().unwrap_or("") ); - add_headers(&mut response, headers)?; - if let Some(body) = body { + add_headers(&mut response, Some(resp.headers()))?; + if let Some(body) = resp.body() { response += &body; } ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))