From f38a0523e1eb8d4b695ac8f2add0093f7d4dfe9e Mon Sep 17 00:00:00 2001 From: Lucas Leadbetter <5595530+lleadbet@users.noreply.github.com> Date: Tue, 17 Sep 2024 18:40:37 -0400 Subject: [PATCH] add support for body-based parameters (#6) Co-authored-by: Shane Myrick --- src/config/mod.rs | 2 + src/graphql_request/mod.rs | 6 +- src/handler.rs | 196 +++++++++++++++++++++++++++++++++++-- 3 files changed, 194 insertions(+), 10 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 7518c4e..cd96739 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -73,6 +73,8 @@ pub struct Endpoint { pub query_params: Option>, /// The path arguments that the endpoint should accept pub path_arguments: Option>, + /// The body parameters that the endpoint should accept + pub body_params: Option>, } #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, JsonSchema)] diff --git a/src/graphql_request/mod.rs b/src/graphql_request/mod.rs index 9f5de32..d19ad2a 100644 --- a/src/graphql_request/mod.rs +++ b/src/graphql_request/mod.rs @@ -60,6 +60,9 @@ impl Client { // Remove the host header to prevent issues with the proxy request_headers.remove("host"); + // Remove the content-length header to prevent issues with the proxy for POST requests + request_headers.remove("content-length"); + request = request.headers(request_headers.clone()); debug!("Request Headers: {:?}", request_headers); debug!("Making request to: {}", &self.url); @@ -77,7 +80,7 @@ impl Client { match serde_json::to_string(&body) { Ok(json) => { - debug!("JSON: {:?}", json); + debug!("Request JSON: {:?}", json); request = request.body(json); } Err(e) => return Err(Box::from(e.to_string().as_str())), @@ -123,6 +126,7 @@ mod tests { pq_id: "test".to_string(), path_arguments: None, query_params: None, + body_params: None, }; let response = client diff --git a/src/handler.rs b/src/handler.rs index 098c360..1e62fb7 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -2,7 +2,7 @@ use crate::config::Parameter; use crate::{config::Endpoint, graphql_request::Client}; use axum::http::StatusCode; use axum::{ - extract::{Path, Query, State}, + extract::{Json as ExtractJson, Path, Query, State}, http::HeaderMap, response::IntoResponse, Json, @@ -27,6 +27,7 @@ pub async fn handler( Path(path_parameters): Path>, State(state): State, Query(query_parameters): Query>, + body: Option>, ) -> impl IntoResponse { let mut request_variables = HashMap::::new(); @@ -48,6 +49,32 @@ pub async fn handler( request_variables.extend(path_variables); + // If the endpoint is configured to use the request body as variables, parse the body + let bp = match body { + Some(body) => { + let mut m = HashMap::::new(); + // convert the body to a hashmap of strings since we really only care about top level keys at the moment + for (key, val) in body.as_object().unwrap() { + let value = val.as_str(); + match value { + Some(v) => { + m.insert(key.clone(), v.to_string()); + } + None => { + m.insert(key.clone(), val.to_string()); + } + } + } + m + } + None => HashMap::::new(), + }; + let body_params = match parse_parameters(bp, state.endpoint.body_params.clone()) { + Ok(p) => p, + Err(e) => return build_error_response(StatusCode::BAD_REQUEST, e), + }; + request_variables.extend(body_params); + debug!("Request Parameters: {:?}", request_variables); let response = state .client @@ -97,18 +124,16 @@ fn build_error_response( } fn parse_parameters( - parameters: HashMap, + request_parameters: HashMap, config_parameters: Option>, ) -> Result, String> { - let mut request_parameters = HashMap::::new(); + let mut parameters = HashMap::::new(); if let Some(params) = config_parameters { for param in params { - if parameters.contains_key(param.from.clone().as_str()) { - if let Some(value) = parameters.get(param.from.as_str()) { + if request_parameters.contains_key(param.from.clone().as_str()) { + if let Some(value) = request_parameters.get(param.from.as_str()) { match param.kind.clone().from_str(value) { - Ok(p) => { - request_parameters.insert(param.to.unwrap_or(param.from.clone()), p) - } + Ok(p) => parameters.insert(param.to.unwrap_or(param.from.clone()), p), Err(e) => return Err(e.to_string()), }; } @@ -117,7 +142,7 @@ fn parse_parameters( } } } - Ok(request_parameters) + Ok(parameters) } #[cfg(test)] mod tests { @@ -129,6 +154,45 @@ mod tests { use crate::config::{Endpoint, ParamKind, Parameter}; use crate::Client; + #[tokio::test] + async fn test_parse_parameters() { + let request_parameters = vec![("param1".to_string(), "value1".to_string())] + .into_iter() + .collect(); + + let config_parameters = Some(vec![Parameter { + from: "param1".to_string(), + to: Some("param1".to_string()), + kind: ParamKind::STRING, + required: true, + }]); + + let result = parse_parameters(request_parameters, config_parameters); + assert!(result.is_ok()); + let result = result.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!( + result.get("param1").unwrap(), + &Value::String("value1".to_string()) + ); + } + + #[tokio::test] + async fn test_parse_parameters_with_missing_required_parameter() { + let request_parameters = HashMap::new(); + + let config_parameters = Some(vec![Parameter { + from: "param1".to_string(), + to: Some("param1".to_string()), + kind: ParamKind::STRING, + required: true, + }]); + + let result = parse_parameters(request_parameters, config_parameters); + assert!(result.is_err()); + assert_eq!(result.err().unwrap(), "Missing required parameter: param1"); + } + #[tokio::test] async fn test_handler_with_valid_parameters() { let mut server = mockito::Server::new_async().await; @@ -160,6 +224,7 @@ mod tests { kind: ParamKind::STRING, required: true, }]), + body_params: None, }; let client = Client::new(server.url().as_str()); @@ -174,6 +239,7 @@ mod tests { Path(vec![].into_iter().collect()), State(state), Query(query_parameters), + None, ) .await .into_response() @@ -206,6 +272,7 @@ mod tests { kind: ParamKind::STRING, required: true, }]), + body_params: None, }; let client = Client::new(""); @@ -222,6 +289,7 @@ mod tests { Path(path_parameters), State(state), Query(query_parameters), + None, ) .await .into_response() @@ -253,6 +321,7 @@ mod tests { required: true, }]), query_params: None, + body_params: None, }; let client = Client::new(""); @@ -269,6 +338,7 @@ mod tests { Path(path_parameters), State(state), Query(query_parameters), + None, ) .await .into_response() @@ -287,4 +357,112 @@ mod tests { .to_string() ); } + + #[tokio::test] + async fn test_handler_with_body_params() { + let mut server = mockito::Server::new_async().await; + let mock_endpoint = server + .mock("POST", "/") + .with_header("content-type", "application/json") + .match_body(mockito::Matcher::Json(json!({ + "variables": { + "param1": "value1" + }, + "extensions": { + "persistedQuery": { + "sha256Hash": "test", + "version": 1 + } + } + }))) + .with_body(json!({"data": "test"}).to_string()) + .create(); + + let endpoint = Endpoint { + method: crate::config::HttpMethod::POST, + path: "/test".to_string(), + pq_id: "test".to_string(), + path_arguments: None, + query_params: None, + body_params: Some(vec![Parameter { + from: "param1".to_string(), + to: Some("param1".to_string()), + kind: ParamKind::STRING, + required: true, + }]), + }; + + let client = Client::new(server.url().as_str()); + let state = EndpointHandler { endpoint, client }; + let query_parameters = HashMap::new(); + + let (response, body) = handler( + HeaderMap::new(), + Path(vec![].into_iter().collect()), + State(state), + Query(query_parameters), + Some(Json(json!({"param1": "value1"}))), + ) + .await + .into_response() + .into_parts(); + + mock_endpoint.assert(); + assert_eq!(response.status, StatusCode::OK); + + let body_bytes = to_bytes(body, usize::MAX).await.unwrap(); + let body_string = std::str::from_utf8(&body_bytes).unwrap(); + assert_eq!( + body_string, + json!({ + "data": "test" + }) + .to_string() + ); + } + + #[tokio::test] + async fn test_handler_with_missing_body_params() { + let endpoint = Endpoint { + method: crate::config::HttpMethod::POST, + path: "/test".to_string(), + pq_id: "test".to_string(), + path_arguments: None, + query_params: None, + body_params: Some(vec![Parameter { + from: "param1".to_string(), + to: Some("param1".to_string()), + kind: ParamKind::STRING, + required: true, + }]), + }; + + let client = Client::new(""); + let state = EndpointHandler { endpoint, client }; + let query_parameters = HashMap::new(); + + let (response, body) = handler( + HeaderMap::new(), + Path(vec![].into_iter().collect()), + State(state), + Query(query_parameters), + None, + ) + .await + .into_response() + .into_parts(); + + assert_eq!(response.status, StatusCode::BAD_REQUEST); + + let body_bytes = to_bytes(body, usize::MAX).await.unwrap(); + let body_string = std::str::from_utf8(&body_bytes).unwrap(); + assert_eq!( + body_string, + json!({ + "errors": [{"message": "Missing required parameter: param1"}], + "data": null + }) + .to_string() + ); + } }