Skip to content

Commit

Permalink
feat: hyper wip
Browse files Browse the repository at this point in the history
  • Loading branch information
junkurihara committed Jan 26, 2024
1 parent 359a246 commit d8ce19b
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 3 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ http = { version = "1.0.0" }
http-body = { version = "1.0.0" }
http-body-util = { version = "0.1.0" }
bytes = { version = "1.5.0" }

#
async-trait = "0.1.77"
url = "2.5.0"


[dev-dependencies]
Expand Down
160 changes: 160 additions & 0 deletions src/ext/hyper_http.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
use super::{ContentDigestType, CONTENT_DIGEST_HEADER};
use crate::{
message_component::{
DerivedComponentName, HttpMessageComponent, HttpMessageComponentIdentifier, HttpMessageComponentParam,
HttpMessageComponentValue,
},
signature_base::SignatureBase,
signature_params::HttpSignatureParams,
};
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use bytes::{Buf, Bytes};
Expand Down Expand Up @@ -139,11 +147,163 @@ where
}
}

/* --------------------------------------- */
#[async_trait]
/// A trait to set the http message signature from given http signature params
pub trait HyperRequestMessageSignature {
type Error;
async fn set_message_signature(
&mut self,
signature_params: &HttpSignatureParams,
) -> std::result::Result<(), Self::Error>
where
Self: Sized;
}

#[async_trait]
impl<D> HyperRequestMessageSignature for Request<D>
where
D: Send + Body,
{
type Error = anyhow::Error;

async fn set_message_signature(
&mut self,
signature_params: &HttpSignatureParams,
) -> std::result::Result<(), Self::Error>
where
Self: Sized,
{
let component_lines = signature_params
.covered_components
.iter()
.map(|component_id_str| {
let component_id = HttpMessageComponentIdentifier::from(component_id_str.as_str());

extract_component_from_request(self, &component_id)
})
.collect::<Vec<_>>();

anyhow::ensure!(
component_lines.iter().all(|c| c.is_ok()),
"Failed to extract component lines"
);
let component_lines = component_lines.into_iter().map(|c| c.unwrap()).collect::<Vec<_>>();

let signature_base = SignatureBase::try_new(&component_lines, signature_params);

Ok(())
}
}
/* --------------------------------------- */
/// Extract http message component from hyper http request
fn extract_component_from_request<B>(
req: &Request<B>,
target_component_id: &HttpMessageComponentIdentifier,
) -> Result<HttpMessageComponent, anyhow::Error> {
let params = match &target_component_id {
HttpMessageComponentIdentifier::HttpField(field_id) => &field_id.params,
HttpMessageComponentIdentifier::Derived(derived_id) => &derived_id.params,
};
anyhow::ensure!(
!params.0.contains(&HttpMessageComponentParam::Req),
"`req` is not allowed in request"
);

let field_values = match &target_component_id {
HttpMessageComponentIdentifier::HttpField(field_id) => {
let field_values = req
.headers()
.get_all(&field_id.filed_name)
.iter()
.map(|v| v.to_str().unwrap().to_owned())
.collect::<Vec<_>>();
field_values
}
HttpMessageComponentIdentifier::Derived(derived_id) => {
let url = url::Url::parse(&req.uri().to_string())?;
let field_value = match derived_id.component_name {
DerivedComponentName::Method => req.method().to_string(),
DerivedComponentName::TargetUri => url.to_string(),
DerivedComponentName::Authority => url.authority().to_string(),
DerivedComponentName::Scheme => url.scheme().to_string(),
DerivedComponentName::RequestTarget => match *req.method() {
http::Method::CONNECT => url.authority().to_string(),
http::Method::OPTIONS => "*".to_string(),
_ => {
let mut base = url.path().to_string();
if let Some(query) = url.query() {
base.push_str(&format!("?{query}"));
}
base
}
},
DerivedComponentName::Path => url.path().to_string(),
DerivedComponentName::Query => format!("?{}", url.query().unwrap_or("")),
DerivedComponentName::QueryParam => {
let query_pairs = url.query_pairs().collect::<Vec<_>>();
println!("query_param: {:?}", query_pairs);

todo!("not implemented yet") // TODO: dict
}
_ => panic!("invalid derived component name for request"),
};
vec![field_value]
}
};

let component = HttpMessageComponent {
id: target_component_id.clone(),
value: HttpMessageComponentValue::from(""),
};
Ok(component)
}

/* --------------------------------------- */
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_extract_component_from_request() {
let req = Request::builder()
.method("GET")
.uri("https://example.com/parameters?var=this%20is%20a%20big%0Amultiline%20value&bar=with+plus+whitespace&fa%C3%A7ade%22%3A%20=something")
.header("date", "Sun, 09 May 2021 18:30:00 GMT")
.header("content-type", "application/json")
.body(())
.unwrap();

let component_id_method = HttpMessageComponentIdentifier::from("\"@method\"");
let component = extract_component_from_request(&req, &component_id_method).unwrap();
println!("{:?}", component);

let component_id_query_param = HttpMessageComponentIdentifier::from("\"@query-param\"");
let component = extract_component_from_request(&req, &component_id_query_param).unwrap();
println!("{:?}", component);
// let component = extract_component_from_request(&req, &component_id).unwrap();
// assert_eq!(component.id, component_id);
// assert_eq!(component.field_values, vec!["GET".to_string()]);

// let component_id = HttpMessageComponentIdentifier::from("\"date\"");
// let component = extract_component_from_request(&req, &component_id).unwrap();
// assert_eq!(component.id, component_id);
// assert_eq!(
// component.field_values,
// vec!["Sun, 09 May 2021 18:30:00 GMT".to_string()]
// );

// let component_id = HttpMessageComponentIdentifier::from("\"content-type\"");
// let component = extract_component_from_request(&req, &component_id).unwrap();
// assert_eq!(component.id, component_id);
// assert_eq!(component.field_values, vec!["application/json".to_string()]);

// let component_id = HttpMessageComponentIdentifier::from("\"@signature-params\"");
// let component = extract_component_from_request(&req, &component_id).unwrap();
// assert_eq!(component.id, component_id);
// assert_eq!(component.field_values, vec!["".to_string()]);
}

#[tokio::test]
async fn content_digest() {
let body = Full::new(&b"{\"hello\": \"world\"}"[..]);
Expand Down
25 changes: 23 additions & 2 deletions src/message_component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ pub(crate) enum HttpMessageComponentParam {
Tr,
// req: https://www.ietf.org/archive/id/draft-ietf-httpbis-message-signatures-19.html#section-2.4
Req,
// name: https://www.ietf.org/archive/id/draft-ietf-httpbis-message-signatures-19.html#name-query-parameters
Name(String),
}

impl From<HttpMessageComponentParam> for String {
Expand All @@ -131,6 +133,7 @@ impl From<HttpMessageComponentParam> for String {
HttpMessageComponentParam::Bs => "bs".to_string(),
HttpMessageComponentParam::Tr => "tr".to_string(),
HttpMessageComponentParam::Req => "req".to_string(),
HttpMessageComponentParam::Name(v) => format!("name=\"{v}\""),
}
}
}
Expand All @@ -144,6 +147,8 @@ impl From<&str> for HttpMessageComponentParam {
_ => {
if val.starts_with("key=\"") && val.ends_with('"') {
Self::Key(val[5..val.len() - 1].to_string())
} else if val.starts_with("name=\"") && val.ends_with('"') {
Self::Name(val[6..val.len() - 1].to_string())
} else {
panic!("Invalid http field param: {}", val)
}
Expand Down Expand Up @@ -228,10 +233,13 @@ impl From<&str> for DerivedComponentId {
fn from(val: &str) -> Self {
let mut iter = val.split(';');
let component = iter.next().unwrap();
// only `req` field param is allowed for derived components unlike general http fields
// only `req` field param for any or `name="xx"` for @query-params are allowed for derived components unlike general http fields
let params = iter
.map(HttpMessageComponentParam::from)
.filter(|v| matches!(v, HttpMessageComponentParam::Req))
.filter(|v| {
matches!(v, HttpMessageComponentParam::Req)
|| (matches!(v, HttpMessageComponentParam::Name(_)) && matches!(component, "\"@query-param\""))
})
.collect::<HashSet<_>>();
Self {
component_name: DerivedComponentName::from(component),
Expand Down Expand Up @@ -387,6 +395,19 @@ mod tests {
);
}

#[test]
fn test_query_params() {
let params = DerivedComponentId::from("\"@query-param\";name=\"test\"");
assert_eq!(params.component_name, DerivedComponentName::QueryParam);
assert_eq!(
params.params.0,
vec![HttpMessageComponentParam::Name("test".to_string())]
.into_iter()
.collect::<HashSet<_>>()
);
assert_eq!(params.to_string(), "\"@query-param\";name=\"test\"");
}

#[test]
fn test_http_general_field() {
let params = HttpFieldComponentId::from("\"example-header\";bs");
Expand Down
2 changes: 1 addition & 1 deletion src/signature_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{message_component::HttpMessageComponent, signature_params::HttpSigna

/// Signature Base
/// https://www.ietf.org/archive/id/draft-ietf-httpbis-message-signatures-19.html#section-2.5
struct SignatureBase {
pub(crate) struct SignatureBase {
/// HTTP message field and derived components ordered as in the vector in signature params
component_lines: Vec<HttpMessageComponent>,
/// signature params
Expand Down

0 comments on commit d8ce19b

Please sign in to comment.