Skip to content

Commit

Permalink
Remove custom Request/Response from server code
Browse files Browse the repository at this point in the history
  • Loading branch information
sdroege committed Nov 23, 2019
1 parent 0178f28 commit 0f2f0d9
Showing 1 changed file with 33 additions and 42 deletions.
75 changes: 33 additions & 42 deletions src/handshake/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::io::{Read, Write};
use std::marker::PhantomData;
use std::result::Result as StdResult;

use http::{HeaderMap, StatusCode};
use http::{HeaderMap, Response, StatusCode};
use httparse::Status;
use log::*;

Expand All @@ -15,6 +15,8 @@ use super::{convert_key, HandshakeRole, MidHandshake, ProcessingResult};
use crate::error::{Error, Result};
use crate::protocol::{Role, WebSocket, WebSocketConfig};

// TODO get rid of this too

/// Request from the client.
#[derive(Debug)]
pub struct Request {
Expand All @@ -39,17 +41,15 @@ impl Request {
Sec-WebSocket-Accept: {}\r\n",
convert_key(key.as_bytes())?
);
add_headers(&mut reply, extra_headers)?;
add_headers(&mut reply, extra_headers.as_ref())?;
Ok(reply.into())
}
}

fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<HeaderMap>) -> Result<()> {
fn add_headers(reply: &mut impl FmtWrite, extra_headers: Option<&HeaderMap>) -> Result<()> {
if let Some(eh) = extra_headers {
for (k, v) in eh {
if let Some(k) = k {
writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap();
}
writeln!(reply, "{}: {}\r", k, v.to_str()?).unwrap();
}
}
writeln!(reply, "\r").unwrap();
Expand Down Expand Up @@ -85,27 +85,6 @@ impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
}
}

/// An error response sent to the client.
#[derive(Debug)]
pub struct ErrorResponse {
/// HTTP error code.
pub error_code: StatusCode,
/// Extra response headers, if any.
pub headers: Option<HeaderMap>,
/// Response body, if any.
pub body: Option<String>,
}

impl From<StatusCode> for ErrorResponse {
fn from(error_code: StatusCode) -> Self {
ErrorResponse {
error_code,
headers: None,
body: None,
}
}
}

/// The callback trait.
///
/// The callback is called when the server receives an incoming WebSocket
Expand All @@ -116,14 +95,20 @@ pub trait Callback: Sized {
/// Called whenever the server read the request from the client and is ready to reply to it.
/// May return additional reply headers.
/// Returning an error resulting in rejecting the incoming connection.
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse>;
fn on_request(
self,
request: &Request,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>>;
}

impl<F> Callback for F
where
F: FnOnce(&Request) -> StdResult<Option<HeaderMap>, ErrorResponse>,
F: FnOnce(&Request) -> StdResult<Option<HeaderMap>, Response<Option<String>>>,
{
fn on_request(self, request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> {
fn on_request(
self,
request: &Request,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>> {
self(request)
}
}
Expand All @@ -133,7 +118,10 @@ where
pub struct NoCallback;

impl Callback for NoCallback {
fn on_request(self, _request: &Request) -> StdResult<Option<HeaderMap>, ErrorResponse> {
fn on_request(
self,
_request: &Request,
) -> StdResult<Option<HeaderMap>, Response<Option<String>>> {
Ok(None)
}
}
Expand Down Expand Up @@ -204,19 +192,22 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
}

Err(ErrorResponse {
error_code,
headers,
body,
}) => {
self.error_code = Some(error_code.as_u16());
Err(resp) => {
if resp.status().is_success() {
return Err(Error::Protocol(
"Custom response must not be successful".into(),
));
}

self.error_code = Some(resp.status().as_u16());
let mut response = format!(
"HTTP/1.1 {} {}\r\n",
error_code.as_str(),
error_code.canonical_reason().unwrap_or("")
"{version:?} {status} {reason}\r\n",
version = resp.version(),
status = resp.status().as_u16(),
reason = resp.status().canonical_reason().unwrap_or("")
);
add_headers(&mut response, headers)?;
if let Some(body) = body {
add_headers(&mut response, Some(resp.headers()))?;
if let Some(body) = resp.body() {
response += &body;
}
ProcessingResult::Continue(HandshakeMachine::start_write(stream, response))
Expand Down

0 comments on commit 0f2f0d9

Please sign in to comment.