diff --git a/src/handler.rs b/src/handler.rs index 1e62fb7..d80dcd4 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -8,8 +8,8 @@ use axum::{ Json, }; use axum_macros::debug_handler; -use serde_json::json; -use serde_json::Value; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; use std::collections::HashMap; use tracing::debug; @@ -18,6 +18,20 @@ pub struct EndpointHandler { pub endpoint: Endpoint, pub client: Client, } +#[derive(Serialize, Deserialize, Debug, Clone)] +struct ClientResponse { + #[serde(skip_serializing_if = "Option::is_none")] + data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + errors: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + extensions: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +struct ClientResponseError { + message: String, +} #[debug_handler] /// The handler function for the endpoints. @@ -86,17 +100,27 @@ pub async fn handler( debug!("Response: {:?}", resp); debug!("Response headers: {:?}", resp.headers()); - let status = resp.status(); + let mut status = resp.status(); let mut headers = resp.headers().clone(); // Remove transfer-encoding header to prevent issues with gzip responses headers.remove("transfer-encoding"); - let json = resp.json::().await; + let json = resp.json::().await; match json { Ok(json) => { debug!("JSON: {:?}", json); - (status, headers, Json(json)) + if let Some(ref errors) = json.errors { + // If there are errors in the response, set the status to 500 if the response is 200 or 400; this prioritizes the status returned by the router in non-compliant situations + if status == StatusCode::OK && !errors.is_empty() { + status = StatusCode::INTERNAL_SERVER_ERROR; + // If there is data in the response, set the status to 206 to indicate partial content per RFC + if json.data.is_some() { + status = StatusCode::PARTIAL_CONTENT; + } + } + } + (status, headers, Json(json!(json))) } Err(e) => build_error_response(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), } @@ -259,6 +283,131 @@ mod tests { ); } + #[tokio::test] + async fn test_handler_returns_500() { + let server_body = json!({ + "errors": [{"message": "test"}] + }); + let mut server = mockito::Server::new_async().await; + let mock_endpoint = server + .mock("POST", "/") + .with_header("content-type", "application/json") + .match_body(mockito::Matcher::Json(json!({ + "variables": { + "param1": "value1" + }, + "extensions": { + "persistedQuery": { + "sha256Hash": "test", + "version": 1 + } + } + }))) + .with_body(server_body.to_string()) + .create(); + + let endpoint = Endpoint { + path: "/test".to_string(), + pq_id: "test".to_string(), + method: crate::config::HttpMethod::GET, + path_arguments: None, + query_params: Some(vec![Parameter { + from: "param1".to_string(), + to: Some("param1".to_string()), + kind: ParamKind::STRING, + required: true, + }]), + body_params: None, + }; + let client = Client::new(server.url().as_str()); + let state = EndpointHandler { endpoint, client }; + + let query_parameters = vec![("param1".to_string(), "value1".to_string())] + .into_iter() + .collect(); + + let (response, body) = handler( + HeaderMap::new(), + Path(vec![].into_iter().collect()), + State(state), + Query(query_parameters), + None, + ) + .await + .into_response() + .into_parts(); + + let body_bytes = to_bytes(body, usize::MAX).await.unwrap(); + let body_string = std::str::from_utf8(&body_bytes).unwrap(); + + assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR); + mock_endpoint.assert(); + assert_eq!(body_string, server_body.to_string()); + } + + #[tokio::test] + async fn test_handler_returns_206() { + let server_body = json!({ + "data": "test", + "errors": [{"message": "test"}] + }); + let mut server = mockito::Server::new_async().await; + let mock_endpoint = server + .mock("POST", "/") + .with_header("content-type", "application/json") + .match_body(mockito::Matcher::Json(json!({ + "variables": { + "param1": "value1" + }, + "extensions": { + "persistedQuery": { + "sha256Hash": "test", + "version": 1 + } + } + }))) + .with_body(server_body.to_string()) + .create(); + + let endpoint = Endpoint { + path: "/test".to_string(), + pq_id: "test".to_string(), + method: crate::config::HttpMethod::GET, + path_arguments: None, + query_params: Some(vec![Parameter { + from: "param1".to_string(), + to: Some("param1".to_string()), + kind: ParamKind::STRING, + required: true, + }]), + body_params: None, + }; + let client = Client::new(server.url().as_str()); + let state = EndpointHandler { endpoint, client }; + + let query_parameters = vec![("param1".to_string(), "value1".to_string())] + .into_iter() + .collect(); + + let (response, body) = handler( + HeaderMap::new(), + Path(vec![].into_iter().collect()), + State(state), + Query(query_parameters), + None, + ) + .await + .into_response() + .into_parts(); + + let body_bytes = to_bytes(body, usize::MAX).await.unwrap(); + let body_string = std::str::from_utf8(&body_bytes).unwrap(); + + assert_eq!(response.status, StatusCode::PARTIAL_CONTENT); + mock_endpoint.assert(); + assert_eq!(body_string, server_body.to_string()); + } + #[tokio::test] async fn test_handler_with_missing_required_parameter() { let endpoint = Endpoint {