Skip to content

Commit

Permalink
implement rfc status code details (#8)
Browse files Browse the repository at this point in the history
Co-authored-by: Shane Myrick <[email protected]>
  • Loading branch information
lleadbet and smyrick authored Sep 23, 2024
1 parent f38a052 commit 6dabbea
Showing 1 changed file with 154 additions and 5 deletions.
159 changes: 154 additions & 5 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use axum::{
Json,
};
use axum_macros::debug_handler;
use serde_json::json;
use serde_json::Value;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use tracing::debug;

Expand All @@ -18,6 +18,20 @@ pub struct EndpointHandler {
pub endpoint: Endpoint,
pub client: Client,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
struct ClientResponse {
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
errors: Option<Vec<ClientResponseError>>,
#[serde(skip_serializing_if = "Option::is_none")]
extensions: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
struct ClientResponseError {
message: String,
}

#[debug_handler]
/// The handler function for the endpoints.
Expand Down Expand Up @@ -86,17 +100,27 @@ pub async fn handler(
debug!("Response: {:?}", resp);
debug!("Response headers: {:?}", resp.headers());

let status = resp.status();
let mut status = resp.status();
let mut headers = resp.headers().clone();

// Remove transfer-encoding header to prevent issues with gzip responses
headers.remove("transfer-encoding");

let json = resp.json::<Value>().await;
let json = resp.json::<ClientResponse>().await;
match json {
Ok(json) => {
debug!("JSON: {:?}", json);
(status, headers, Json(json))
if let Some(ref errors) = json.errors {
// If there are errors in the response, set the status to 500 if the response is 200 or 400; this prioritizes the status returned by the router in non-compliant situations
if status == StatusCode::OK && !errors.is_empty() {
status = StatusCode::INTERNAL_SERVER_ERROR;
// If there is data in the response, set the status to 206 to indicate partial content per RFC
if json.data.is_some() {
status = StatusCode::PARTIAL_CONTENT;
}
}
}
(status, headers, Json(json!(json)))
}
Err(e) => build_error_response(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
}
Expand Down Expand Up @@ -259,6 +283,131 @@ mod tests {
);
}

#[tokio::test]
async fn test_handler_returns_500() {
let server_body = json!({
"errors": [{"message": "test"}]
});
let mut server = mockito::Server::new_async().await;
let mock_endpoint = server
.mock("POST", "/")
.with_header("content-type", "application/json")
.match_body(mockito::Matcher::Json(json!({
"variables": {
"param1": "value1"
},
"extensions": {
"persistedQuery": {
"sha256Hash": "test",
"version": 1
}
}
})))
.with_body(server_body.to_string())
.create();

let endpoint = Endpoint {
path: "/test".to_string(),
pq_id: "test".to_string(),
method: crate::config::HttpMethod::GET,
path_arguments: None,
query_params: Some(vec![Parameter {
from: "param1".to_string(),
to: Some("param1".to_string()),
kind: ParamKind::STRING,
required: true,
}]),
body_params: None,
};
let client = Client::new(server.url().as_str());
let state = EndpointHandler { endpoint, client };

let query_parameters = vec![("param1".to_string(), "value1".to_string())]
.into_iter()
.collect();

let (response, body) = handler(
HeaderMap::new(),
Path(vec![].into_iter().collect()),
State(state),
Query(query_parameters),
None,
)
.await
.into_response()
.into_parts();

let body_bytes = to_bytes(body, usize::MAX).await.unwrap();
let body_string = std::str::from_utf8(&body_bytes).unwrap();

assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
mock_endpoint.assert();
assert_eq!(body_string, server_body.to_string());
}

#[tokio::test]
async fn test_handler_returns_206() {
let server_body = json!({
"data": "test",
"errors": [{"message": "test"}]
});
let mut server = mockito::Server::new_async().await;
let mock_endpoint = server
.mock("POST", "/")
.with_header("content-type", "application/json")
.match_body(mockito::Matcher::Json(json!({
"variables": {
"param1": "value1"
},
"extensions": {
"persistedQuery": {
"sha256Hash": "test",
"version": 1
}
}
})))
.with_body(server_body.to_string())
.create();

let endpoint = Endpoint {
path: "/test".to_string(),
pq_id: "test".to_string(),
method: crate::config::HttpMethod::GET,
path_arguments: None,
query_params: Some(vec![Parameter {
from: "param1".to_string(),
to: Some("param1".to_string()),
kind: ParamKind::STRING,
required: true,
}]),
body_params: None,
};
let client = Client::new(server.url().as_str());
let state = EndpointHandler { endpoint, client };

let query_parameters = vec![("param1".to_string(), "value1".to_string())]
.into_iter()
.collect();

let (response, body) = handler(
HeaderMap::new(),
Path(vec![].into_iter().collect()),
State(state),
Query(query_parameters),
None,
)
.await
.into_response()
.into_parts();

let body_bytes = to_bytes(body, usize::MAX).await.unwrap();
let body_string = std::str::from_utf8(&body_bytes).unwrap();

assert_eq!(response.status, StatusCode::PARTIAL_CONTENT);
mock_endpoint.assert();
assert_eq!(body_string, server_body.to_string());
}

#[tokio::test]
async fn test_handler_with_missing_required_parameter() {
let endpoint = Endpoint {
Expand Down

0 comments on commit 6dabbea

Please sign in to comment.