Skip to content

Commit

Permalink
add support for body-based parameters (#6)
Browse files Browse the repository at this point in the history
Co-authored-by: Shane Myrick <[email protected]>
  • Loading branch information
lleadbet and smyrick authored Sep 17, 2024
1 parent 0c0e558 commit f38a052
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ pub struct Endpoint {
pub query_params: Option<Vec<Parameter>>,
/// The path arguments that the endpoint should accept
pub path_arguments: Option<Vec<Parameter>>,
/// The body parameters that the endpoint should accept
pub body_params: Option<Vec<Parameter>>,
}

#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, JsonSchema)]
Expand Down
6 changes: 5 additions & 1 deletion src/graphql_request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ impl Client {
// Remove the host header to prevent issues with the proxy
request_headers.remove("host");

// Remove the content-length header to prevent issues with the proxy for POST requests
request_headers.remove("content-length");

request = request.headers(request_headers.clone());
debug!("Request Headers: {:?}", request_headers);
debug!("Making request to: {}", &self.url);
Expand All @@ -77,7 +80,7 @@ impl Client {

match serde_json::to_string(&body) {
Ok(json) => {
debug!("JSON: {:?}", json);
debug!("Request JSON: {:?}", json);
request = request.body(json);
}
Err(e) => return Err(Box::from(e.to_string().as_str())),
Expand Down Expand Up @@ -123,6 +126,7 @@ mod tests {
pq_id: "test".to_string(),
path_arguments: None,
query_params: None,
body_params: None,
};

let response = client
Expand Down
196 changes: 187 additions & 9 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::config::Parameter;
use crate::{config::Endpoint, graphql_request::Client};
use axum::http::StatusCode;
use axum::{
extract::{Path, Query, State},
extract::{Json as ExtractJson, Path, Query, State},
http::HeaderMap,
response::IntoResponse,
Json,
Expand All @@ -27,6 +27,7 @@ pub async fn handler(
Path(path_parameters): Path<HashMap<String, String>>,
State(state): State<EndpointHandler>,
Query(query_parameters): Query<HashMap<String, String>>,
body: Option<ExtractJson<Value>>,
) -> impl IntoResponse {
let mut request_variables = HashMap::<String, Value>::new();

Expand All @@ -48,6 +49,32 @@ pub async fn handler(

request_variables.extend(path_variables);

// If the endpoint is configured to use the request body as variables, parse the body
let bp = match body {
Some(body) => {
let mut m = HashMap::<String, String>::new();
// convert the body to a hashmap of strings since we really only care about top level keys at the moment
for (key, val) in body.as_object().unwrap() {
let value = val.as_str();
match value {
Some(v) => {
m.insert(key.clone(), v.to_string());
}
None => {
m.insert(key.clone(), val.to_string());
}
}
}
m
}
None => HashMap::<String, String>::new(),
};
let body_params = match parse_parameters(bp, state.endpoint.body_params.clone()) {
Ok(p) => p,
Err(e) => return build_error_response(StatusCode::BAD_REQUEST, e),
};
request_variables.extend(body_params);

debug!("Request Parameters: {:?}", request_variables);
let response = state
.client
Expand Down Expand Up @@ -97,18 +124,16 @@ fn build_error_response(
}

fn parse_parameters(
parameters: HashMap<String, String>,
request_parameters: HashMap<String, String>,
config_parameters: Option<Vec<Parameter>>,
) -> Result<HashMap<String, Value>, String> {
let mut request_parameters = HashMap::<String, Value>::new();
let mut parameters = HashMap::<String, Value>::new();
if let Some(params) = config_parameters {
for param in params {
if parameters.contains_key(param.from.clone().as_str()) {
if let Some(value) = parameters.get(param.from.as_str()) {
if request_parameters.contains_key(param.from.clone().as_str()) {
if let Some(value) = request_parameters.get(param.from.as_str()) {
match param.kind.clone().from_str(value) {
Ok(p) => {
request_parameters.insert(param.to.unwrap_or(param.from.clone()), p)
}
Ok(p) => parameters.insert(param.to.unwrap_or(param.from.clone()), p),
Err(e) => return Err(e.to_string()),
};
}
Expand All @@ -117,7 +142,7 @@ fn parse_parameters(
}
}
}
Ok(request_parameters)
Ok(parameters)
}
#[cfg(test)]
mod tests {
Expand All @@ -129,6 +154,45 @@ mod tests {
use crate::config::{Endpoint, ParamKind, Parameter};
use crate::Client;

#[tokio::test]
async fn test_parse_parameters() {
let request_parameters = vec![("param1".to_string(), "value1".to_string())]
.into_iter()
.collect();

let config_parameters = Some(vec![Parameter {
from: "param1".to_string(),
to: Some("param1".to_string()),
kind: ParamKind::STRING,
required: true,
}]);

let result = parse_parameters(request_parameters, config_parameters);
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(
result.get("param1").unwrap(),
&Value::String("value1".to_string())
);
}

#[tokio::test]
async fn test_parse_parameters_with_missing_required_parameter() {
let request_parameters = HashMap::new();

let config_parameters = Some(vec![Parameter {
from: "param1".to_string(),
to: Some("param1".to_string()),
kind: ParamKind::STRING,
required: true,
}]);

let result = parse_parameters(request_parameters, config_parameters);
assert!(result.is_err());
assert_eq!(result.err().unwrap(), "Missing required parameter: param1");
}

#[tokio::test]
async fn test_handler_with_valid_parameters() {
let mut server = mockito::Server::new_async().await;
Expand Down Expand Up @@ -160,6 +224,7 @@ mod tests {
kind: ParamKind::STRING,
required: true,
}]),
body_params: None,
};

let client = Client::new(server.url().as_str());
Expand All @@ -174,6 +239,7 @@ mod tests {
Path(vec![].into_iter().collect()),
State(state),
Query(query_parameters),
None,
)
.await
.into_response()
Expand Down Expand Up @@ -206,6 +272,7 @@ mod tests {
kind: ParamKind::STRING,
required: true,
}]),
body_params: None,
};

let client = Client::new("");
Expand All @@ -222,6 +289,7 @@ mod tests {
Path(path_parameters),
State(state),
Query(query_parameters),
None,
)
.await
.into_response()
Expand Down Expand Up @@ -253,6 +321,7 @@ mod tests {
required: true,
}]),
query_params: None,
body_params: None,
};

let client = Client::new("");
Expand All @@ -269,6 +338,7 @@ mod tests {
Path(path_parameters),
State(state),
Query(query_parameters),
None,
)
.await
.into_response()
Expand All @@ -287,4 +357,112 @@ mod tests {
.to_string()
);
}

#[tokio::test]
async fn test_handler_with_body_params() {
let mut server = mockito::Server::new_async().await;
let mock_endpoint = server
.mock("POST", "/")
.with_header("content-type", "application/json")
.match_body(mockito::Matcher::Json(json!({
"variables": {
"param1": "value1"
},
"extensions": {
"persistedQuery": {
"sha256Hash": "test",
"version": 1
}
}
})))
.with_body(json!({"data": "test"}).to_string())
.create();

let endpoint = Endpoint {
method: crate::config::HttpMethod::POST,
path: "/test".to_string(),
pq_id: "test".to_string(),
path_arguments: None,
query_params: None,
body_params: Some(vec![Parameter {
from: "param1".to_string(),
to: Some("param1".to_string()),
kind: ParamKind::STRING,
required: true,
}]),
};

let client = Client::new(server.url().as_str());
let state = EndpointHandler { endpoint, client };
let query_parameters = HashMap::new();

let (response, body) = handler(
HeaderMap::new(),
Path(vec![].into_iter().collect()),
State(state),
Query(query_parameters),
Some(Json(json!({"param1": "value1"}))),
)
.await
.into_response()
.into_parts();

mock_endpoint.assert();
assert_eq!(response.status, StatusCode::OK);

let body_bytes = to_bytes(body, usize::MAX).await.unwrap();
let body_string = std::str::from_utf8(&body_bytes).unwrap();
assert_eq!(
body_string,
json!({
"data": "test"
})
.to_string()
);
}

#[tokio::test]
async fn test_handler_with_missing_body_params() {
let endpoint = Endpoint {
method: crate::config::HttpMethod::POST,
path: "/test".to_string(),
pq_id: "test".to_string(),
path_arguments: None,
query_params: None,
body_params: Some(vec![Parameter {
from: "param1".to_string(),
to: Some("param1".to_string()),
kind: ParamKind::STRING,
required: true,
}]),
};

let client = Client::new("");
let state = EndpointHandler { endpoint, client };
let query_parameters = HashMap::new();

let (response, body) = handler(
HeaderMap::new(),
Path(vec![].into_iter().collect()),
State(state),
Query(query_parameters),
None,
)
.await
.into_response()
.into_parts();

assert_eq!(response.status, StatusCode::BAD_REQUEST);

let body_bytes = to_bytes(body, usize::MAX).await.unwrap();
let body_string = std::str::from_utf8(&body_bytes).unwrap();
assert_eq!(
body_string,
json!({
"errors": [{"message": "Missing required parameter: param1"}],
"data": null
})
.to_string()
);
}
}

0 comments on commit f38a052

Please sign in to comment.