Skip to content

Commit

Permalink
Refactor response build in RSGI protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Jan 16, 2023
1 parent e9925c9 commit e5139e2
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 100 deletions.
13 changes: 7 additions & 6 deletions src/rsgi/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
use super::{
errors::error_proto,
io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol},
types::{RSGIScope as Scope, Response}
types::{RSGIScope as Scope, PyResponse, PyResponseBytes}
};


Expand Down Expand Up @@ -41,9 +41,10 @@ impl CallbackWatcherHTTP {
impl CallbackWatcherHTTP {
fn done(&mut self, py: Python) {
if let Ok(mut proto) = self.proto.as_ref(py).try_borrow_mut() {
if let (Some(tx), Some(mut res)) = proto.tx() {
res.error();
let _ = tx.send(res);
if let Some(tx) = proto.tx() {
let _ = tx.send(
PyResponse::Bytes(PyResponseBytes::empty(500, Vec::new()))
);
}
}
}
Expand Down Expand Up @@ -99,7 +100,7 @@ pub(crate) async fn call_rtb_http(
rt: RuntimeRef,
req: hyper::Request<hyper::Body>,
scope: Scope
) -> PyResult<Response> {
) -> PyResult<PyResponse> {
let callback = cb.callback.clone();
let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, tx, req);
Expand All @@ -124,7 +125,7 @@ pub(crate) async fn call_rtt_http(
rt: RuntimeRef,
req: hyper::Request<hyper::Body>,
scope: Scope
) -> PyResult<Response> {
) -> PyResult<PyResponse> {
let callback = cb.callback.clone();
let (tx, rx) = oneshot::channel();
let protocol = HTTPProtocol::new(rt, tx, req);
Expand Down
29 changes: 7 additions & 22 deletions src/rsgi/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ use hyper::{
http::response::Builder as ResponseBuilder
};
use std::net::SocketAddr;
use tokio::{fs::File, sync::mpsc};
use tokio_util::codec::{BytesCodec, FramedRead};
use tokio::sync::mpsc;

use crate::{
callbacks::CallbackWrapper,
Expand All @@ -18,16 +17,10 @@ use crate::{
};
use super::{
callbacks::{call_rtb_http, call_rtb_ws, call_rtt_http, call_rtt_ws},
types::{ResponseType, RSGIScope as Scope}
types::{RSGIScope as Scope, PyResponse}
};


async fn file_body(file_path: String) -> Body {
let file = File::open(file_path).await.unwrap();
let stream = FramedRead::new(file, BytesCodec::new());
Body::wrap_stream(stream)
}

macro_rules! default_scope {
($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => {
Scope::new(
Expand All @@ -46,19 +39,11 @@ macro_rules! default_scope {
macro_rules! handle_http_response {
($handler:expr, $rt:expr, $callback:expr, $req:expr, $scope:expr) => {
match $handler($callback, $rt, $req, $scope).await {
Ok(pyres) => {
let res = match pyres.mode {
ResponseType::Body => {
pyres.inner.body(pyres.body)
},
ResponseType::File => {
pyres.inner.body(file_body(pyres.file.unwrap()).await)
}
};
match res {
Ok(res) => res,
_ => response_500()
}
Ok(PyResponse::Bytes(pyres)) => {
pyres.to_response()
},
Ok(PyResponse::File(pyres)) => {
pyres.to_response().await
},
_ => response_500()
}
Expand Down
67 changes: 30 additions & 37 deletions src/rsgi/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,34 @@ use crate::{
runtime::{RuntimeRef, future_into_py},
ws::{HyperWebsocket, UpgradeData}
};
use super::{errors::{error_proto, error_stream}, types::{Response, ResponseType}};
use super::{
errors::{error_proto, error_stream},
types::{PyResponse, PyResponseBytes, PyResponseFile}
};


#[pyclass(module="granian._granian")]
pub(crate) struct RSGIHTTPProtocol {
rt: RuntimeRef,
tx: Option<oneshot::Sender<Response>>,
request: Arc<Mutex<Request<Body>>>,
response: Option<Response>
tx: Option<oneshot::Sender<super::types::PyResponse>>,
request: Arc<Mutex<Request<Body>>>
}

impl RSGIHTTPProtocol {
pub fn new(
rt: RuntimeRef,
tx: oneshot::Sender<Response>,
tx: oneshot::Sender<super::types::PyResponse>,
request: Request<Body>
) -> Self {
Self {
rt: rt,
tx: Some(tx),
request: Arc::new(Mutex::new(request)),
response: Some(Response::new())
request: Arc::new(Mutex::new(request))
}
}

pub fn tx(&mut self) -> (Option<oneshot::Sender<Response>>, Option<Response>) {
return (self.tx.take(), self.response.take())
pub fn tx(&mut self) -> Option<oneshot::Sender<super::types::PyResponse>> {
self.tx.take()
}
}

Expand All @@ -55,46 +56,38 @@ impl RSGIHTTPProtocol {
}

#[args(status="200", headers="vec![]")]
fn response_empty(&mut self, status: u16, headers: Vec<(&str, &str)>) {
if let Some(mut response) = self.response.take() {
response.head(status, &headers);
if let Some(tx) = self.tx.take() {
let _ = tx.send(response);
}
fn response_empty(&mut self, status: u16, headers: Vec<(String, String)>) {
if let Some(tx) = self.tx.take() {
let _ = tx.send(
PyResponse::Bytes(PyResponseBytes::empty(status, headers))
);
}
}

#[args(status="200", headers="vec![]")]
fn response_bytes(&mut self, status: u16, headers: Vec<(&str, &str)>, body: Vec<u8>) {
if let Some(mut response) = self.response.take() {
response.head(status, &headers);
response.body = Body::from(body);
if let Some(tx) = self.tx.take() {
let _ = tx.send(response);
}
fn response_bytes(&mut self, status: u16, headers: Vec<(String, String)>, body: Vec<u8>) {
if let Some(tx) = self.tx.take() {
let _ = tx.send(
PyResponse::Bytes(PyResponseBytes::from_bytes(status, headers, body))
);
}
}

#[args(status="200", headers="vec![]")]
fn response_str(&mut self, status: u16, headers: Vec<(&str, &str)>, body: String) {
if let Some(mut response) = self.response.take() {
response.head(status, &headers);
response.body = Body::from(body);
if let Some(tx) = self.tx.take() {
let _ = tx.send(response);
}
fn response_str(&mut self, status: u16, headers: Vec<(String, String)>, body: String) {
if let Some(tx) = self.tx.take() {
let _ = tx.send(
PyResponse::Bytes(PyResponseBytes::from_string(status, headers, body))
);
}
}

#[args(status="200", headers="vec![]")]
fn response_file(&mut self, status: u16, headers: Vec<(&str, &str)>, file: String) {
if let Some(mut response) = self.response.take() {
response.mode = ResponseType::File;
response.head(status, &headers);
response.file = Some(file);
if let Some(tx) = self.tx.take() {
let _ = tx.send(response);
}
fn response_file(&mut self, status: u16, headers: Vec<(String, String)>, file: String) {
if let Some(tx) = self.tx.take() {
let _ = tx.send(
PyResponse::File(PyResponseFile::new(status, headers, file))
);
}
}
}
Expand Down
105 changes: 70 additions & 35 deletions src/rsgi/types.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use hyper::{
header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER},
http::response::Builder as ResponseBuilder, Body, Uri, Version
Body, Uri, Version
};
use pyo3::prelude::*;
use pyo3::types::{PyString};
use pyo3::types::PyString;
use std::net::SocketAddr;
use tokio::fs::File;
use tokio_util::codec::{BytesCodec, FramedRead};

use crate::http::HV_SERVER;

Expand Down Expand Up @@ -138,51 +140,84 @@ impl RSGIScope {
}
}

#[derive(Debug)]
pub(crate) enum ResponseType {
Body = 1,
File = 10
pub(crate) enum PyResponse {
Bytes(PyResponseBytes),
File(PyResponseFile)
}

#[pyclass(frozen)]
#[derive(Debug)]
pub(crate) struct Response {
pub inner: ResponseBuilder,
pub mode: ResponseType,
pub body: Body,
pub file: Option<String>
pub(crate) struct PyResponseBytes {
status: u16,
headers: Vec<(String, String)>,
body: hyper::body::Bytes
}

impl Response {
pub fn new() -> Self {
pub(crate) struct PyResponseFile {
status: u16,
headers: Vec<(String, String)>,
file_path: String
}

macro_rules! response_head_from_py {
($status:expr, $headers:expr, $res:expr) => {
{
let mut rh = hyper::http::HeaderMap::new();

rh.insert(HK_SERVER, HV_SERVER);
for (key, value) in $headers {
rh.append(
HeaderName::from_bytes(key.as_bytes()).unwrap(),
HeaderValue::from_str(&value).unwrap()
);
}

*$res.status_mut() = $status.try_into().unwrap();
*$res.headers_mut() = rh;
}
}
}

impl PyResponseBytes {
pub fn empty(status: u16, headers: Vec<(String, String)>) -> Self {
Self {
inner: ResponseBuilder::new().status(200),
mode: ResponseType::Body,
body: Body::empty(),
file: None
status,
headers,
body: hyper::body::Bytes::new()
}
}

pub fn head(&mut self, status: u16, headers: &Vec<(&str, &str)>) {
match status {
200 => {},
_ => {
self.inner = ResponseBuilder::new().status(status);
}
pub fn from_bytes(status: u16, headers: Vec<(String, String)>, body: Vec<u8>) -> Self {
Self {
status,
headers,
body: hyper::body::Bytes::from(body)
}
}

let rh = self.inner.headers_mut().unwrap();
rh.insert(HK_SERVER, HV_SERVER);
for (key, value) in headers {
rh.append(
HeaderName::from_bytes(key.as_bytes()).unwrap(),
HeaderValue::from_str(value).unwrap()
);
pub fn from_string(status: u16, headers: Vec<(String, String)>, body: String) -> Self {
Self {
status,
headers,
body: hyper::body::Bytes::from(body)
}
}

pub fn error(&mut self) {
self.inner = ResponseBuilder::new().status(500);
self.body = Body::from("Internal server error");
pub fn to_response(&self) -> hyper::Response::<Body> {
let mut res = hyper::Response::<Body>::new(self.body.to_owned().into());
response_head_from_py!(self.status, &self.headers, res);
res
}
}

impl PyResponseFile {
pub fn new(status: u16, headers: Vec<(String, String)>, file_path: String) -> Self {
Self { status, headers, file_path }
}

pub async fn to_response(&self) -> hyper::Response::<Body> {
let file = File::open(&self.file_path).await.unwrap();
let stream = FramedRead::new(file, BytesCodec::new());
let mut res = hyper::Response::<Body>::new(Body::wrap_stream(stream));
response_head_from_py!(self.status, &self.headers, res);
res
}
}

0 comments on commit e5139e2

Please sign in to comment.