diff --git a/Cargo.toml b/Cargo.toml index 42cc51ab..3152ba22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,8 @@ serde_json = { version = "1.0.51", optional = true } serde_crate = { version = "1.0.106", features = ["derive"], optional = true, package = "serde" } serde_urlencoded = { version = "0.7.0", optional = true} serde_qs = { version = "0.9.1", optional = true } - +miette = "3" +thiserror = "1.0.29" [dev-dependencies] http = "0.2.0" diff --git a/src/auth/authentication_scheme.rs b/src/auth/authentication_scheme.rs index 52b85b7f..9fbeb42d 100644 --- a/src/auth/authentication_scheme.rs +++ b/src/auth/authentication_scheme.rs @@ -1,7 +1,7 @@ use std::fmt::{self, Display}; use std::str::FromStr; -use crate::bail_status as bail; +use crate::errors::AuthError; /// HTTP Mutual Authentication Algorithms #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -47,7 +47,7 @@ impl Display for AuthenticationScheme { } impl FromStr for AuthenticationScheme { - type Err = crate::Error; + type Err = AuthError; fn from_str(s: &str) -> Result { // NOTE(yosh): matching here is lowercase as specified by RFC2617#section-1.2 @@ -65,7 +65,7 @@ impl FromStr for AuthenticationScheme { "scram-sha-1" => Ok(Self::ScramSha1), "scram-sha-256" => Ok(Self::ScramSha256), "vapid" => Ok(Self::Vapid), - s => bail!(400, "`{}` is not a recognized authentication scheme", s), + s => Err(AuthError::SchemeUnrecognized(s.to_string())), } } } diff --git a/src/auth/authorization.rs b/src/auth/authorization.rs index cf2f2ddb..ee59fd31 100644 --- a/src/auth/authorization.rs +++ b/src/auth/authorization.rs @@ -1,5 +1,5 @@ use crate::auth::AuthenticationScheme; -use crate::bail_status as bail; +use crate::errors::AuthError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, AUTHORIZATION}; /// Credentials to authenticate a user agent with a server. @@ -60,8 +60,8 @@ impl Authorization { let scheme = iter.next(); let credential = iter.next(); let (scheme, credentials) = match (scheme, credential) { - (None, _) => bail!(400, "Could not find scheme"), - (Some(_), None) => bail!(400, "Could not find credentials"), + (None, _) => return Err(AuthError::SchemeMissing.into()), + (Some(_), None) => return Err(AuthError::CredentialsMissing.into()), (Some(scheme), Some(credentials)) => (scheme.parse()?, credentials.to_owned()), }; @@ -107,8 +107,10 @@ impl Header for Authorization { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -133,6 +135,6 @@ mod test { .insert(AUTHORIZATION, "") .unwrap(); let err = Authorization::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/auth/basic_auth.rs b/src/auth/basic_auth.rs index 664cfcb5..b6781432 100644 --- a/src/auth/basic_auth.rs +++ b/src/auth/basic_auth.rs @@ -1,10 +1,9 @@ +use crate::errors::AuthError; use crate::headers::{HeaderName, HeaderValue, Headers, AUTHORIZATION}; -use crate::Status; use crate::{ auth::{AuthenticationScheme, Authorization}, headers::Header, }; -use crate::{bail_status as bail, ensure_status as ensure}; /// HTTP Basic authorization. /// @@ -60,19 +59,21 @@ impl BasicAuth { }; let scheme = auth.scheme(); - ensure!( + internal_ensure!( matches!(scheme, AuthenticationScheme::Basic), - 400, - "Expected basic auth scheme found `{}`", - scheme + AuthError::SchemeUnexpected(AuthenticationScheme::Basic, scheme.to_string()) ); Self::from_credentials(auth.credentials()).map(Some) } /// Create a new instance from the base64 encoded credentials. pub fn from_credentials(credentials: impl AsRef<[u8]>) -> crate::Result { - let bytes = base64::decode(credentials).status(400)?; - let credentials = String::from_utf8(bytes).status(400)?; + let bytes = base64::decode(credentials).map_err(|_| { + AuthError::CredentialsInvalid(AuthenticationScheme::Basic, "invalid base64") + })?; + let credentials = String::from_utf8(bytes).map_err(|_| { + AuthError::CredentialsInvalid(AuthenticationScheme::Basic, "invalid utf8 from base64") + })?; let mut iter = credentials.splitn(2, ':'); let username = iter.next(); @@ -80,8 +81,20 @@ impl BasicAuth { let (username, password) = match (username, password) { (Some(username), Some(password)) => (username.to_string(), password.to_string()), - (Some(_), None) => bail!(400, "Expected basic auth to contain a password"), - (None, _) => bail!(400, "Expected basic auth to contain a username"), + (Some(_), None) => { + return Err(AuthError::CredentialsInvalid( + AuthenticationScheme::Basic, + "missing password", + ) + .into()) + } + (None, _) => { + return Err(AuthError::CredentialsInvalid( + AuthenticationScheme::Basic, + "missing username", + ) + .into()) + } }; Ok(Self { username, password }) @@ -113,8 +126,10 @@ impl Header for BasicAuth { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -139,6 +154,6 @@ mod test { .insert(AUTHORIZATION, "") .unwrap(); let err = BasicAuth::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/auth/www_authenticate.rs b/src/auth/www_authenticate.rs index bf9d7a05..baad13cf 100644 --- a/src/auth/www_authenticate.rs +++ b/src/auth/www_authenticate.rs @@ -1,4 +1,4 @@ -use crate::bail_status as bail; +use crate::errors::{AuthError, HeaderError}; use crate::headers::{HeaderName, HeaderValue, Headers, WWW_AUTHENTICATE}; use crate::{auth::AuthenticationScheme, headers::Header}; @@ -63,15 +63,15 @@ impl WwwAuthenticate { let scheme = iter.next(); let credential = iter.next(); let (scheme, realm) = match (scheme, credential) { - (None, _) => bail!(400, "Could not find scheme"), - (Some(_), None) => bail!(400, "Could not find realm"), + (None, _) => return Err(AuthError::SchemeMissing.into()), + (Some(_), None) => return Err(AuthError::RealmMissing.into()), (Some(scheme), Some(realm)) => (scheme.parse()?, realm.to_owned()), }; let realm = realm.trim_start(); let realm = match realm.strip_prefix(r#"realm=""#) { Some(realm) => realm, - None => bail!(400, "realm not found"), + None => return Err(AuthError::RealmMissing.into()), }; let mut chars = realm.chars(); @@ -87,7 +87,7 @@ impl WwwAuthenticate { }) .collect(); if !closing_quote { - bail!(400, r"Expected a closing quote"); + return Err(HeaderError::WWWAuthenticateInvalid("Expected a closing quote").into()); } Ok(Some(Self { scheme, realm })) @@ -129,8 +129,10 @@ impl Header for WwwAuthenticate { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -160,6 +162,6 @@ mod test { .insert(WWW_AUTHENTICATE, "") .unwrap(); let err = WwwAuthenticate::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/body.rs b/src/body.rs index ac1acfc7..0238b6d0 100644 --- a/src/body.rs +++ b/src/body.rs @@ -7,8 +7,8 @@ use std::fmt::{self, Debug}; use std::pin::Pin; use std::task::{Context, Poll}; +use crate::errors::BodyError; use crate::mime::{self, Mime}; -use crate::{Status, StatusCode}; pin_project_lite::pin_project! { /// A streaming HTTP body. @@ -177,9 +177,7 @@ impl Body { /// ``` pub async fn into_bytes(mut self) -> crate::Result> { let mut buf = Vec::with_capacity(1024); - self.read_to_end(&mut buf) - .await - .status(StatusCode::UnprocessableEntity)?; + self.read_to_end(&mut buf).await?; Ok(buf) } @@ -225,11 +223,10 @@ impl Body { /// # Ok(()) }) } /// ``` pub async fn into_string(mut self) -> crate::Result { - let len = usize::try_from(self.len().unwrap_or(0)).status(StatusCode::PayloadTooLarge)?; + let len = usize::try_from(self.len().unwrap_or(0)) + .map_err(|_| BodyError::PayloadTooLarge(self.len()))?; let mut result = String::with_capacity(len); - self.read_to_string(&mut result) - .await - .status(StatusCode::UnprocessableEntity)?; + self.read_to_string(&mut result).await?; Ok(result) } @@ -249,7 +246,7 @@ impl Body { /// ``` #[cfg(feature = "serde")] pub fn from_json(json: &impl Serialize) -> crate::Result { - let bytes = serde_json::to_vec(&json)?; + let bytes = serde_json::to_vec(&json).map_err(BodyError::SerializeJSON)?; let body = Self { length: Some(bytes.len() as u64), reader: Box::new(io::Cursor::new(bytes)), @@ -283,7 +280,7 @@ impl Body { pub async fn into_json(mut self) -> crate::Result { let mut buf = Vec::with_capacity(1024); self.read_to_end(&mut buf).await?; - serde_json::from_slice(&buf).status(StatusCode::UnprocessableEntity) + serde_json::from_slice(&buf).map_err(|e| BodyError::DeserializeJSON(e).into()) } /// Creates a `Body` from a type, serializing it using form encoding. @@ -316,7 +313,7 @@ impl Body { /// ``` #[cfg(feature = "serde")] pub fn from_form(form: &impl Serialize) -> crate::Result { - let query = serde_urlencoded::to_string(form)?; + let query = serde_urlencoded::to_string(form).map_err(BodyError::SerializeForm)?; let bytes = query.into_bytes(); let body = Self { @@ -356,7 +353,7 @@ impl Body { #[cfg(feature = "serde")] pub async fn into_form(self) -> crate::Result { let s = self.into_string().await?; - serde_urlencoded::from_str(&s).status(StatusCode::UnprocessableEntity) + serde_urlencoded::from_str(&s).map_err(|e| BodyError::DeserializeForm(e).into()) } /// Create a `Body` from a file named by a path. @@ -636,10 +633,13 @@ fn guess_ext(path: &std::path::Path) -> Option { #[cfg(test)] mod test { - use super::*; use async_std::io::Cursor; use serde_crate::Deserialize; + use crate::StatusCode; + + use super::*; + #[async_std::test] async fn json_status() { #[derive(Debug, Deserialize)] @@ -650,7 +650,10 @@ mod test { } let body = Body::empty(); let res = body.into_json::().await; - assert_eq!(res.unwrap_err().status(), 422); + assert_eq!( + res.unwrap_err().associated_status_code(), + Some(StatusCode::UnprocessableEntity) + ); } #[async_std::test] @@ -663,10 +666,16 @@ mod test { } let body = Body::empty(); let res = body.into_form::().await; - assert_eq!(res.unwrap_err().status(), 422); + assert_eq!( + res.unwrap_err().associated_status_code(), + Some(StatusCode::UnprocessableEntity) + ); } - async fn read_with_buffers_of_size(reader: &mut R, size: usize) -> crate::Result + async fn read_with_buffers_of_size( + reader: &mut R, + size: usize, + ) -> crate::ResponseResult where R: AsyncRead + Unpin, { @@ -681,7 +690,7 @@ mod test { } #[async_std::test] - async fn attempting_to_read_past_length() -> crate::Result<()> { + async fn attempting_to_read_past_length() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from_reader(Cursor::new("hello world"), Some(5)); assert_eq!( @@ -695,7 +704,7 @@ mod test { } #[async_std::test] - async fn attempting_to_read_when_length_is_greater_than_content() -> crate::Result<()> { + async fn attempting_to_read_when_length_is_greater_than_content() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from_reader(Cursor::new("hello world"), Some(15)); assert_eq!( @@ -709,7 +718,7 @@ mod test { } #[async_std::test] - async fn attempting_to_read_when_length_is_exactly_right() -> crate::Result<()> { + async fn attempting_to_read_when_length_is_exactly_right() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from_reader(Cursor::new("hello world"), Some(11)); assert_eq!( @@ -723,7 +732,8 @@ mod test { } #[async_std::test] - async fn reading_in_various_buffer_lengths_when_there_is_no_length() -> crate::Result<()> { + async fn reading_in_various_buffer_lengths_when_there_is_no_length() -> crate::ResponseResult<()> + { for buf_len in 1..13 { let mut body = Body::from_reader(Cursor::new("hello world"), None); assert_eq!( @@ -737,7 +747,7 @@ mod test { } #[async_std::test] - async fn chain_strings() -> crate::Result<()> { + async fn chain_strings() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from("hello ").chain(Body::from("world")); assert_eq!(body.len(), Some(11)); @@ -753,7 +763,7 @@ mod test { } #[async_std::test] - async fn chain_mixed_bytes_string() -> crate::Result<()> { + async fn chain_mixed_bytes_string() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from(&b"hello "[..]).chain(Body::from("world")); assert_eq!(body.len(), Some(11)); @@ -769,7 +779,7 @@ mod test { } #[async_std::test] - async fn chain_mixed_reader_string() -> crate::Result<()> { + async fn chain_mixed_reader_string() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from_reader(Cursor::new("hello "), Some(6)).chain(Body::from("world")); @@ -786,7 +796,7 @@ mod test { } #[async_std::test] - async fn chain_mixed_nolen_len() -> crate::Result<()> { + async fn chain_mixed_nolen_len() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from_reader(Cursor::new("hello "), None).chain(Body::from("world")); @@ -803,7 +813,7 @@ mod test { } #[async_std::test] - async fn chain_mixed_len_nolen() -> crate::Result<()> { + async fn chain_mixed_len_nolen() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from("hello ").chain(Body::from_reader(Cursor::new("world"), None)); @@ -820,7 +830,7 @@ mod test { } #[async_std::test] - async fn chain_short() -> crate::Result<()> { + async fn chain_short() -> crate::ResponseResult<()> { for buf_len in 1..26 { let mut body = Body::from_reader(Cursor::new("hello xyz"), Some(6)) .chain(Body::from_reader(Cursor::new("world abc"), Some(5))); @@ -837,7 +847,7 @@ mod test { } #[async_std::test] - async fn chain_many() -> crate::Result<()> { + async fn chain_many() -> crate::ResponseResult<()> { for buf_len in 1..13 { let mut body = Body::from("hello") .chain(Body::from(&b" "[..])) @@ -855,7 +865,7 @@ mod test { } #[async_std::test] - async fn chain_skip_start() -> crate::Result<()> { + async fn chain_skip_start() -> crate::ResponseResult<()> { for buf_len in 1..26 { let mut body1 = Body::from_reader(Cursor::new("1234 hello xyz"), Some(11)); let mut buf = vec![0; 5]; diff --git a/src/cache/age.rs b/src/cache/age.rs index 0ee78a00..bcf089a2 100644 --- a/src/cache/age.rs +++ b/src/cache/age.rs @@ -1,5 +1,5 @@ +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, AGE}; -use crate::Status; use std::fmt::Debug; @@ -62,7 +62,10 @@ impl Age { // entry. We want the last entry. let header = headers.iter().last().unwrap(); - let num: u64 = header.as_str().parse::().status(400)?; + let num: u64 = header + .as_str() + .parse::() + .map_err(|_| HeaderError::AgeInvalid)?; let dur = Duration::from_secs_f64(num as f64); Ok(Some(Self { dur })) @@ -84,8 +87,10 @@ impl Header for Age { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -104,6 +109,6 @@ mod test { let mut headers = Headers::new(); headers.insert(AGE, "").unwrap(); let err = Age::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/cache/cache_control/cache_directive.rs b/src/cache/cache_control/cache_directive.rs index 5e8b0b98..3cae858c 100644 --- a/src/cache/cache_control/cache_directive.rs +++ b/src/cache/cache_control/cache_directive.rs @@ -1,5 +1,5 @@ +use crate::errors::HeaderError; use crate::headers::HeaderValue; -use crate::Status; use std::time::Duration; @@ -93,8 +93,10 @@ impl CacheDirective { let next = parts.next().unwrap(); let mut get_dur = || -> crate::Result { - let dur = parts.next().status(400)?; - let dur: u64 = dur.parse::().status(400)?; + let dur = parts.next().ok_or(HeaderError::CacheControlInvalid)?; + let dur: u64 = dur + .parse::() + .map_err(|_| HeaderError::CacheControlInvalid)?; Ok(Duration::new(dur, 0)) }; @@ -112,7 +114,9 @@ impl CacheDirective { "max-age" => Some(MaxAge(get_dur()?)), "max-stale" => match parts.next() { Some(secs) => { - let dur: u64 = secs.parse::().status(400)?; + let dur: u64 = secs + .parse::() + .map_err(|_| HeaderError::CacheControlInvalid)?; Some(MaxStale(Some(Duration::new(dur, 0)))) } None => Some(MaxStale(None)), diff --git a/src/cache/cache_control/mod.rs b/src/cache/cache_control/mod.rs index f9150363..7453b12c 100644 --- a/src/cache/cache_control/mod.rs +++ b/src/cache/cache_control/mod.rs @@ -15,8 +15,10 @@ pub use cache_directive::CacheDirective; #[cfg(test)] mod test { - use super::*; use crate::headers::{Header, Headers, CACHE_CONTROL}; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -49,6 +51,6 @@ mod test { let mut headers = Headers::new(); headers.insert(CACHE_CONTROL, "min-fresh=0.9").unwrap(); // floats are not supported let err = CacheControl::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/cache/clear_site_data/mod.rs b/src/cache/clear_site_data/mod.rs index bb38ab80..b6579d17 100644 --- a/src/cache/clear_site_data/mod.rs +++ b/src/cache/clear_site_data/mod.rs @@ -1,6 +1,7 @@ //! Clear browsing data (cookies, storage, cache) associated with the //! requesting website +use crate::errors::HeaderError; use crate::headers::{self, HeaderName, HeaderValue, Headers, CLEAR_SITE_DATA}; use std::fmt::{self, Debug, Write}; @@ -75,7 +76,9 @@ impl ClearSiteData { wildcard = true; continue; } - entries.push(ClearDirective::from_str(part)?); + entries.push( + ClearDirective::from_str(part).map_err(HeaderError::ClearSiteDataInvalid)?, + ); } } diff --git a/src/cache/expires.rs b/src/cache/expires.rs index 89c4e6f4..a0cce659 100644 --- a/src/cache/expires.rs +++ b/src/cache/expires.rs @@ -13,7 +13,7 @@ use std::time::{Duration, SystemTime}; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::Response; /// use http_types::cache::Expires; @@ -85,11 +85,13 @@ impl Header for Expires { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let time = SystemTime::now() + Duration::from_secs(5 * 60); let expires = Expires::new_at(time); @@ -109,6 +111,6 @@ mod test { let mut headers = Headers::new(); headers.insert(EXPIRES, "").unwrap(); let err = Expires::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/conditional/etag.rs b/src/conditional/etag.rs index c7245f86..73d7fa60 100644 --- a/src/conditional/etag.rs +++ b/src/conditional/etag.rs @@ -1,5 +1,5 @@ +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, ETAG}; -use crate::{Error, StatusCode}; use std::fmt::{self, Debug, Display}; @@ -89,22 +89,14 @@ impl ETag { let s = match s.strip_prefix('"').and_then(|s| s.strip_suffix('"')) { Some(s) => s.to_owned(), - None => { - return Err(Error::from_str( - StatusCode::BadRequest, - "Invalid ETag header", - )) - } + None => return Err(HeaderError::ETagInvalid.into()), }; if !s .bytes() .all(|c| c == 0x21 || (0x23..=0x7E).contains(&c) || c >= 0x80) { - return Err(Error::from_str( - StatusCode::BadRequest, - "Invalid ETag header", - )); + return Err(HeaderError::ETagInvalid.into()); } let etag = if weak { Self::Weak(s) } else { Self::Strong(s) }; @@ -134,8 +126,10 @@ impl Display for ETag { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -166,15 +160,21 @@ mod test { let mut headers = Headers::new(); headers.insert(ETAG, "").unwrap(); let err = ETag::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } #[test] fn validate_quotes() { - assert_entry_err(r#""hello"#, "Invalid ETag header"); - assert_entry_err(r#"hello""#, "Invalid ETag header"); - assert_entry_err(r#"/O"valid content""#, "Invalid ETag header"); - assert_entry_err(r#"/Wvalid content""#, "Invalid ETag header"); + assert_entry_err(r#""hello"#, "Header error: ETag header was invalid"); + assert_entry_err(r#"hello""#, "Header error: ETag header was invalid"); + assert_entry_err( + r#"/O"valid content""#, + "Header error: ETag header was invalid", + ); + assert_entry_err( + r#"/Wvalid content""#, + "Header error: ETag header was invalid", + ); } fn assert_entry_err(s: &str, msg: &str) { @@ -186,7 +186,7 @@ mod test { #[test] fn validate_characters() { - assert_entry_err(r#"""hello""#, "Invalid ETag header"); - assert_entry_err("\"hello\x7F\"", "Invalid ETag header"); + assert_entry_err(r#"""hello""#, "Header error: ETag header was invalid"); + assert_entry_err("\"hello\x7F\"", "Header error: ETag header was invalid"); } } diff --git a/src/conditional/if_modified_since.rs b/src/conditional/if_modified_since.rs index e5566a12..2721852b 100644 --- a/src/conditional/if_modified_since.rs +++ b/src/conditional/if_modified_since.rs @@ -15,7 +15,7 @@ use std::time::SystemTime; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::Response; /// use http_types::conditional::IfModifiedSince; @@ -81,12 +81,15 @@ impl Header for IfModifiedSince { #[cfg(test)] mod test { - use super::*; - use crate::headers::Headers; use std::time::Duration; + use crate::headers::Headers; + use crate::StatusCode; + + use super::*; + #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let time = SystemTime::now() + Duration::from_secs(5 * 60); let expires = IfModifiedSince::new(time); @@ -108,6 +111,6 @@ mod test { .insert(IF_MODIFIED_SINCE, "") .unwrap(); let err = IfModifiedSince::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/conditional/if_unmodified_since.rs b/src/conditional/if_unmodified_since.rs index 2d964271..338a8d81 100644 --- a/src/conditional/if_unmodified_since.rs +++ b/src/conditional/if_unmodified_since.rs @@ -15,7 +15,7 @@ use std::time::SystemTime; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::Response; /// use http_types::conditional::IfUnmodifiedSince; @@ -81,12 +81,15 @@ impl Header for IfUnmodifiedSince { #[cfg(test)] mod test { - use super::*; - use crate::headers::Headers; use std::time::Duration; + use crate::headers::Headers; + use crate::StatusCode; + + use super::*; + #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let time = SystemTime::now() + Duration::from_secs(5 * 60); let expires = IfUnmodifiedSince::new(time); @@ -108,6 +111,6 @@ mod test { .insert(IF_UNMODIFIED_SINCE, "") .unwrap(); let err = IfUnmodifiedSince::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/conditional/last_modified.rs b/src/conditional/last_modified.rs index dc71fc1b..0010a8a6 100644 --- a/src/conditional/last_modified.rs +++ b/src/conditional/last_modified.rs @@ -14,7 +14,7 @@ use std::time::SystemTime; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::Response; /// use http_types::conditional::LastModified; @@ -80,12 +80,15 @@ impl Header for LastModified { #[cfg(test)] mod test { - use super::*; - use crate::headers::Headers; use std::time::Duration; + use crate::headers::Headers; + use crate::StatusCode; + + use super::*; + #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let time = SystemTime::now() + Duration::from_secs(5 * 60); let last_modified = LastModified::new(time); @@ -107,6 +110,6 @@ mod test { .insert(LAST_MODIFIED, "") .unwrap(); let err = LastModified::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/content/accept.rs b/src/content/accept.rs index 87dad909..68b4ac35 100644 --- a/src/content/accept.rs +++ b/src/content/accept.rs @@ -1,5 +1,6 @@ //! Client header advertising which media types the client is able to understand. +use crate::errors::HeaderError; use crate::headers::{HeaderName, HeaderValue, Headers, ACCEPT}; use crate::mime::Mime; use crate::utils::sort_by_weight; @@ -7,7 +8,6 @@ use crate::{ content::{ContentType, MediaTypeProposal}, headers::Header, }; -use crate::{Error, StatusCode}; use std::fmt::{self, Debug, Write}; @@ -30,7 +30,7 @@ use std::slice; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::content::{Accept, MediaTypeProposal}; /// use http_types::{mime, Response}; @@ -86,7 +86,8 @@ impl Accept { // Try and parse a directive from a str. If the directive is // unkown we skip it. - let entry = MediaTypeProposal::from_str(part)?; + let entry = MediaTypeProposal::from_str(part) + .map_err(HeaderError::AcceptInvalidMediaType)?; entries.push(entry); } } @@ -141,9 +142,7 @@ impl Accept { } } - let mut err = Error::new_adhoc("No suitable Content-Type found"); - err.set_status(StatusCode::NotAcceptable); - Err(err) + Err(HeaderError::AcceptUnnegotiable.into()) } /// An iterator visiting all entries. @@ -288,9 +287,9 @@ impl Debug for Accept { #[cfg(test)] mod test { + use crate::{mime, Response, StatusCode}; + use super::*; - use crate::mime; - use crate::Response; #[test] fn smoke() -> crate::Result<()> { @@ -350,7 +349,7 @@ mod test { } #[test] - fn reorder_based_on_weight() -> crate::Result<()> { + fn reorder_based_on_weight() -> anyhow::Result<()> { let mut accept = Accept::new(); accept.push(MediaTypeProposal::new(mime::HTML, Some(0.4))?); accept.push(MediaTypeProposal::new(mime::XML, None)?); @@ -369,7 +368,7 @@ mod test { } #[test] - fn reorder_based_on_weight_and_location() -> crate::Result<()> { + fn reorder_based_on_weight_and_location() -> anyhow::Result<()> { let mut accept = Accept::new(); accept.push(MediaTypeProposal::new(mime::HTML, None)?); accept.push(MediaTypeProposal::new(mime::XML, None)?); @@ -388,7 +387,7 @@ mod test { } #[test] - fn negotiate() -> crate::Result<()> { + fn negotiate() -> anyhow::Result<()> { let mut accept = Accept::new(); accept.push(MediaTypeProposal::new(mime::HTML, Some(0.4))?); accept.push(MediaTypeProposal::new(mime::PLAIN, Some(0.8))?); @@ -399,20 +398,26 @@ mod test { } #[test] - fn negotiate_not_acceptable() -> crate::Result<()> { + fn negotiate_not_acceptable() -> anyhow::Result<()> { let mut accept = Accept::new(); let err = accept.negotiate(&[mime::JSON]).unwrap_err(); - assert_eq!(err.status(), 406); + assert_eq!( + err.associated_status_code(), + Some(StatusCode::NotAcceptable) + ); let mut accept = Accept::new(); accept.push(MediaTypeProposal::new(mime::JSON, Some(0.8))?); let err = accept.negotiate(&[mime::XML]).unwrap_err(); - assert_eq!(err.status(), 406); + assert_eq!( + err.associated_status_code(), + Some(StatusCode::NotAcceptable) + ); Ok(()) } #[test] - fn negotiate_wildcard() -> crate::Result<()> { + fn negotiate_wildcard() -> anyhow::Result<()> { let mut accept = Accept::new(); accept.push(MediaTypeProposal::new(mime::JSON, Some(0.8))?); accept.set_wildcard(true); @@ -422,7 +427,7 @@ mod test { } #[test] - fn negotiate_missing_encoding() -> crate::Result<()> { + fn negotiate_missing_encoding() -> anyhow::Result<()> { let mime_html = "text/html".parse::()?; let mut browser_accept = Accept::new(); diff --git a/src/content/accept_encoding.rs b/src/content/accept_encoding.rs index 5c469a40..43c283c9 100644 --- a/src/content/accept_encoding.rs +++ b/src/content/accept_encoding.rs @@ -1,12 +1,12 @@ //! Client header advertising available compression algorithms. +use crate::errors::HeaderError; use crate::headers::{HeaderName, HeaderValue, Headers, ACCEPT_ENCODING}; use crate::utils::sort_by_weight; use crate::{ content::{ContentEncoding, Encoding, EncodingProposal}, headers::Header, }; -use crate::{Error, StatusCode}; use std::fmt::{self, Debug, Write}; @@ -21,7 +21,7 @@ use std::slice; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::content::{AcceptEncoding, ContentEncoding, Encoding, EncodingProposal}; /// use http_types::Response; @@ -77,7 +77,9 @@ impl AcceptEncoding { // Try and parse a directive from a str. If the directive is // unkown we skip it. - if let Some(entry) = EncodingProposal::from_str(part)? { + if let Some(entry) = EncodingProposal::from_str(part) + .map_err(HeaderError::AcceptEncodingInvalidEncoding)? + { entries.push(entry); } } @@ -133,9 +135,7 @@ impl AcceptEncoding { } } - let mut err = Error::new_adhoc("No suitable Content-Encoding found"); - err.set_status(StatusCode::NotAcceptable); - Err(err) + Err(HeaderError::AcceptEncodingUnnegotiable.into()) } /// An iterator visiting all entries. @@ -281,9 +281,10 @@ impl Debug for AcceptEncoding { #[cfg(test)] mod test { - use super::*; use crate::content::Encoding; - use crate::Response; + use crate::{Response, StatusCode}; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -343,7 +344,7 @@ mod test { } #[test] - fn reorder_based_on_weight() -> crate::Result<()> { + fn reorder_based_on_weight() -> anyhow::Result<()> { let mut accept = AcceptEncoding::new(); accept.push(EncodingProposal::new(Encoding::Gzip, Some(0.4))?); accept.push(EncodingProposal::new(Encoding::Identity, None)?); @@ -362,7 +363,7 @@ mod test { } #[test] - fn reorder_based_on_weight_and_location() -> crate::Result<()> { + fn reorder_based_on_weight_and_location() -> anyhow::Result<()> { let mut accept = AcceptEncoding::new(); accept.push(EncodingProposal::new(Encoding::Identity, None)?); accept.push(EncodingProposal::new(Encoding::Gzip, None)?); @@ -381,7 +382,7 @@ mod test { } #[test] - fn negotiate() -> crate::Result<()> { + fn negotiate() -> anyhow::Result<()> { let mut accept = AcceptEncoding::new(); accept.push(EncodingProposal::new(Encoding::Brotli, Some(0.8))?); accept.push(EncodingProposal::new(Encoding::Gzip, Some(0.4))?); @@ -395,20 +396,26 @@ mod test { } #[test] - fn negotiate_not_acceptable() -> crate::Result<()> { + fn negotiate_not_acceptable() -> anyhow::Result<()> { let mut accept = AcceptEncoding::new(); let err = accept.negotiate(&[Encoding::Gzip]).unwrap_err(); - assert_eq!(err.status(), 406); + assert_eq!( + err.associated_status_code(), + Some(StatusCode::NotAcceptable) + ); let mut accept = AcceptEncoding::new(); accept.push(EncodingProposal::new(Encoding::Brotli, Some(0.8))?); let err = accept.negotiate(&[Encoding::Gzip]).unwrap_err(); - assert_eq!(err.status(), 406); + assert_eq!( + err.associated_status_code(), + Some(StatusCode::NotAcceptable) + ); Ok(()) } #[test] - fn negotiate_wildcard() -> crate::Result<()> { + fn negotiate_wildcard() -> anyhow::Result<()> { let mut accept = AcceptEncoding::new(); accept.push(EncodingProposal::new(Encoding::Brotli, Some(0.8))?); accept.set_wildcard(true); diff --git a/src/content/content_length.rs b/src/content/content_length.rs index 4bae42f2..9f2afb10 100644 --- a/src/content/content_length.rs +++ b/src/content/content_length.rs @@ -1,5 +1,5 @@ +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, CONTENT_LENGTH}; -use crate::Status; /// The size of the entity-body, in bytes, sent to the recipient. /// @@ -47,7 +47,11 @@ impl ContentLength { // If we successfully parsed the header then there's always at least one // entry. We want the last entry. let value = headers.iter().last().unwrap(); - let length = value.as_str().trim().parse::().status(400)?; + let length = value + .as_str() + .trim() + .parse::() + .map_err(|_| HeaderError::ContentLengthInvalid)?; Ok(Some(Self { length })) } @@ -76,8 +80,10 @@ impl Header for ContentLength { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -98,6 +104,6 @@ mod test { .insert(CONTENT_LENGTH, "") .unwrap(); let err = ContentLength::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/content/content_location.rs b/src/content/content_location.rs index 61c09135..9826f574 100644 --- a/src/content/content_location.rs +++ b/src/content/content_location.rs @@ -1,5 +1,6 @@ +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, CONTENT_LOCATION}; -use crate::{bail_status as bail, Status, Url}; +use crate::Url; use std::convert::TryInto; @@ -14,7 +15,7 @@ use std::convert::TryInto; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::{Response, Url}; /// use http_types::content::ContentLocation; @@ -45,7 +46,7 @@ impl ContentLocation { pub fn from_headers(base_url: U, headers: impl AsRef) -> crate::Result> where U: TryInto, - U::Error: std::fmt::Debug, + U::Error: std::fmt::Debug + Send + Sync + 'static, { let headers = match headers.as_ref().get(CONTENT_LOCATION) { Some(headers) => headers, @@ -54,13 +55,13 @@ impl ContentLocation { // If we successfully parsed the header then there's always at least one // entry. We want the last entry. - let value = headers.iter().last().unwrap(); - let base = match base_url.try_into() { - Ok(b) => b, - Err(_) => bail!(400, "Invalid base url provided"), + let header_value = headers.iter().last().unwrap(); + let url = match base_url.try_into() { + Ok(base_url) => base_url + .join(header_value.as_str().trim()) + .map_err(HeaderError::ContentLocationInvalidUrl)?, + Err(e) => return Err(HeaderError::ContentLocationInvalidBaseUrl(Box::new(e)).into()), }; - - let url = base.join(value.as_str().trim()).status(400)?; Ok(Some(Self { url })) } @@ -95,11 +96,12 @@ impl Header for ContentLocation { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use super::*; + #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let content_location = ContentLocation::new(Url::parse("https://example.net/test.json")?); let mut headers = Headers::new(); @@ -124,6 +126,6 @@ mod test { let err = ContentLocation::from_headers(Url::parse("https://example.net").unwrap(), headers) .unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), None); } } diff --git a/src/content/content_type.rs b/src/content/content_type.rs index 9397165c..d1b2fb28 100644 --- a/src/content/content_type.rs +++ b/src/content/content_type.rs @@ -1,5 +1,6 @@ use std::{convert::TryInto, str::FromStr}; +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, CONTENT_TYPE}; use crate::mime::Mime; @@ -15,7 +16,7 @@ use crate::mime::Mime; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::content::ContentType; /// use http_types::{headers::Header, Response}; @@ -67,10 +68,8 @@ impl ContentType { // entry. We want the last entry. let ctation = headers.iter().last().unwrap(); - let media_type = Mime::from_str(ctation.as_str()).map_err(|mut e| { - e.set_status(400); - e - })?; + let media_type = + Mime::from_str(ctation.as_str()).map_err(HeaderError::ContentTypeInvalidMediaType)?; Ok(Some(Self { media_type })) } } @@ -106,11 +105,13 @@ impl From for ContentType { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let ct = ContentType::new(Mime::from_str("text/*")?); let mut headers = Headers::new(); @@ -131,6 +132,6 @@ mod test { .insert(CONTENT_TYPE, "") .unwrap(); let err = ContentType::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/content/encoding_proposal.rs b/src/content/encoding_proposal.rs index 7d329447..1611fbf5 100644 --- a/src/content/encoding_proposal.rs +++ b/src/content/encoding_proposal.rs @@ -1,5 +1,5 @@ use crate::content::Encoding; -use crate::ensure; +use crate::errors::EncodingError; use crate::headers::HeaderValue; use crate::utils::parse_weight; @@ -20,11 +20,11 @@ pub struct EncodingProposal { impl EncodingProposal { /// Create a new instance of `EncodingProposal`. - pub fn new(encoding: impl Into, weight: Option) -> crate::Result { + pub fn new(encoding: impl Into, weight: Option) -> Result { if let Some(weight) = weight { - ensure!( + internal_ensure!( weight.is_sign_positive() && weight <= 1.0, - "EncodingProposal should have a weight between 0.0 and 1.0" + EncodingError::Proposal("should have a weight between 0.0 and 1.0") ) } @@ -44,13 +44,17 @@ impl EncodingProposal { self.weight } - pub(crate) fn from_str(s: &str) -> crate::Result> { + pub(crate) fn from_str(s: &str) -> Result, EncodingError> { let mut parts = s.split(';'); let encoding = match Encoding::from_str(parts.next().unwrap()) { Some(encoding) => encoding, None => return Ok(None), }; - let weight = parts.next().map(parse_weight).transpose()?; + let weight = parts + .next() + .map(parse_weight) + .transpose() + .map_err(|_| EncodingError::Proposal("weight not a valid float32"))?; Ok(Some(Self::new(encoding, weight)?)) } @@ -128,16 +132,4 @@ mod test { let _ = EncodingProposal::new(Encoding::Gzip, Some(0.5)).unwrap(); let _ = EncodingProposal::new(Encoding::Gzip, Some(1.0)).unwrap(); } - - #[test] - fn error_code_500() { - let err = EncodingProposal::new(Encoding::Gzip, Some(1.1)).unwrap_err(); - assert_eq!(err.status(), 500); - - let err = EncodingProposal::new(Encoding::Gzip, Some(-0.1)).unwrap_err(); - assert_eq!(err.status(), 500); - - let err = EncodingProposal::new(Encoding::Gzip, Some(-0.0)).unwrap_err(); - assert_eq!(err.status(), 500); - } } diff --git a/src/content/media_type_proposal.rs b/src/content/media_type_proposal.rs index 622558de..fdb1cc01 100644 --- a/src/content/media_type_proposal.rs +++ b/src/content/media_type_proposal.rs @@ -1,4 +1,4 @@ -use crate::ensure; +use crate::errors::MediaTypeError; use crate::headers::HeaderValue; use crate::mime::Mime; @@ -22,11 +22,11 @@ pub struct MediaTypeProposal { impl MediaTypeProposal { /// Create a new instance of `MediaTypeProposal`. - pub fn new(media_type: impl Into, weight: Option) -> crate::Result { + pub fn new(media_type: impl Into, weight: Option) -> Result { if let Some(weight) = weight { - ensure!( + internal_ensure!( weight.is_sign_positive() && weight <= 1.0, - "MediaTypeProposal should have a weight between 0.0 and 1.0" + MediaTypeError::Proposal("should have a weight between 0.0 and 1.0") ) } if weight.is_none() { @@ -57,11 +57,16 @@ impl MediaTypeProposal { /// Because `;` and `q=0.0` are all valid values for in use in a media type, /// we have to parse the full string to the media type first, and then see if /// a `q` value has been set. - pub(crate) fn from_str(s: &str) -> crate::Result { + pub(crate) fn from_str(s: &str) -> Result { let mut media_type = Mime::from_str(s)?; let weight = media_type .remove_param("q") - .map(|param| param.as_str().parse()) + .map(|param| { + param + .as_str() + .parse() + .map_err(|_| MediaTypeError::Proposal("weight not a valid float32")) + }) .transpose()?; Self::new(media_type, weight) } @@ -146,16 +151,4 @@ mod test { let _ = MediaTypeProposal::new(mime::XML, Some(0.5)).unwrap(); let _ = MediaTypeProposal::new(mime::HTML, Some(1.0)).unwrap(); } - - #[test] - fn error_code_500() { - let err = MediaTypeProposal::new(mime::JSON, Some(1.1)).unwrap_err(); - assert_eq!(err.status(), 500); - - let err = MediaTypeProposal::new(mime::XML, Some(-0.1)).unwrap_err(); - assert_eq!(err.status(), 500); - - let err = MediaTypeProposal::new(mime::HTML, Some(-0.0)).unwrap_err(); - assert_eq!(err.status(), 500); - } } diff --git a/src/content/mod.rs b/src/content/mod.rs index aff9fb4b..c6fe07a5 100644 --- a/src/content/mod.rs +++ b/src/content/mod.rs @@ -12,7 +12,7 @@ //! # Examples //! //! ``` -//! # fn main() -> http_types::Result<()> { +//! # fn main() -> anyhow::Result<()> { //! # //! use http_types::content::{Accept, MediaTypeProposal}; //! use http_types::{mime, Response}; diff --git a/src/errors/error_kind.rs b/src/errors/error_kind.rs new file mode 100644 index 00000000..c6fd2643 --- /dev/null +++ b/src/errors/error_kind.rs @@ -0,0 +1,312 @@ +use miette::Diagnostic; +use thiserror::Error as ThisError; + +use crate::auth::AuthenticationScheme; +use crate::StatusCode; + +/// Error kind for http-types +#[derive(Debug, Diagnostic, ThisError)] +#[allow(missing_docs)] +pub enum Error { + /// This only happens for APIs which support `TryInto` for an argument. + #[error("An argument failed to convert during TryInto: {}", .0)] + ArgTryIntoError(Box), + + #[error("HTTP Method was unrecognized: {}", .0)] + HttpMethodUnrecognized(String), + + #[error("I/O error: {}", .0)] + IO(#[from] std::io::Error), + + #[error("Invalid Status Code: {}", .0)] + StatusCodeInvalid(u16), + + #[error("Querystring deserialization error: {}", .0)] + QueryDeserialize(serde_qs::Error), + #[error("Querystring serialization error: {}", .0)] + QuerySerialize(serde_qs::Error), + + #[error("Body error: {}", .0)] + Body(BodyError), + + #[error("Header error: {}", .0)] + Header(HeaderError), + + #[cfg(feature = "hyperium_http")] + #[error("URL Parse error: {}", .0)] + URLParse(#[from] url::ParseError), + + #[cfg(feature = "hyperium_http")] + #[error("Hyperium HTTP error: {}", .0)] + HyperiumHttp(#[from] http::Error), +} + +#[derive(Debug, Diagnostic, ThisError)] +pub enum BodyError { + #[error("Body size too large: {:?} (PayloadTooLarge)", .0)] + PayloadTooLarge(Option), + + // Deserialization + #[error("Failed to Deserialize JSON: {:?}", .0)] + DeserializeJSON(serde_json::Error), + #[error("Failed to Deserialize x-form-urlencoded: {:?}", .0)] + DeserializeForm(#[from] serde_urlencoded::de::Error), + // #[error("Failed to Deserialize utf8: {:?}", .0)] + // DeserializeUTF8(#[from] std::str::Utf8Error), + + // Serialization + #[error("JSON: {:?}", .0)] + SerializeJSON(serde_json::Error), + #[error("x-form-urlencoded: {:?}", .0)] + SerializeForm(#[from] serde_urlencoded::ser::Error), +} + +impl From for Error { + fn from(other: BodyError) -> Self { + Error::Body(other) + } +} + +#[derive(Debug, Diagnostic, ThisError)] +pub enum HeaderError { + // Parse + #[error("Header value specificity was invalid")] + SpecificityInvalid, + /// Header name was invalid + #[error("Header name was invalid: {}", .0)] + NameInvalid(&'static str), + /// Header value was invalid + #[error("Header value was invalid: {}", .0)] + ValueInvalid(&'static str), + + // Parse specific headers + #[error("Date header was invalid: {}", .0)] + DateInvalid(DateError), + #[error("No suitable Transfer-Encoding found during negotiation")] + TransferEncodingUnnegotiable, + #[error("Transfer-Encoding header encoding was invalid: {}", .0)] + TransferEncodingInvalidEncoding(EncodingError), + #[error("Trace-Context header was invalid: {}", .0)] + TraceContextInvalid(&'static str), + #[error("Server-Timing header was invalid: {}", .0)] + ServerTimingInvalid(&'static str), + #[error("Server-Timing header metric was invalid: {}", .0)] + ServerTimingInvalidMetric(&'static str), + #[error("Timing-Allow-Origin header was invalid: {:?}", .0)] + TimingAllowOriginInvalidUrl(url::ParseError), + #[error("Forwarded header was invalid: {}", .0)] + ForwardedInvalid(&'static str), + #[error("Sourcemap header url was invalid: {:?}", .0)] + SourceMapInvalidUrl(url::ParseError), + #[error("Sourcemap header base url was invalid: {:?}", .0)] + SourceMapInvalidBaseUrl(Box), + #[error("Referer header url was invalid: {:?}", .0)] + RefererInvalidUrl(url::ParseError), + #[error("Referer header base url was invalid: {:?}", .0)] + RefererInvalidBaseUrl(Box), + #[error("Content-Type header MediaType (MIME) was invalid: {}", .0)] + ContentTypeInvalidMediaType(MediaTypeError), + #[error("Content-Length header was invalid: length out of bounds (unsized 64 bit integer)")] + ContentLengthInvalid, + #[error("Accept header was invalid: {}", .0)] + AcceptInvalidMediaType(MediaTypeError), + #[error("No suitable Content-Type header MediaType found during Accept negotiation")] + AcceptUnnegotiable, + #[error("Accept-Encoding header was invalid: {}", .0)] + AcceptEncodingInvalidEncoding(EncodingError), + #[error( + "No suitable Content-Encoding header Encoding found during Accept-Encoding negotiation" + )] + AcceptEncodingUnnegotiable, + #[error("ETag header was invalid")] + ETagInvalid, + #[error("Age header was invalid: length out of bounds (unsized 64 bit integer)")] + AgeInvalid, + #[error("Clear-Site-Data header was invalid: {}", .0)] + ClearSiteDataInvalid(std::string::ParseError), + #[error("Cache-Control header was invalid")] + CacheControlInvalid, + #[error("Authorization header was invalid: {}", .0)] + AuthorizationInvalid(AuthError), + #[error("WWW-Authenticate header was invalid: {}", .0)] + WWWAuthenticateInvalid(&'static str), + #[error("Content-Location header url was invalid: {:?}", .0)] + ContentLocationInvalidUrl(url::ParseError), + #[error("Content-Location header base url was invalid: {:?}", .0)] + ContentLocationInvalidBaseUrl(Box), + #[error("Expect header was malformed.")] + ExpectInvalid, + #[error("Timing-Allow-Origin header was invalid: {:?}", .0)] + StrictTransportSecurityInvalid(&'static str), +} + +impl From for Error { + fn from(other: HeaderError) -> Self { + Error::Header(other) + } +} + +#[derive(Debug, Diagnostic, ThisError)] +pub enum AuthError { + #[error("`{}` Auth had invalid credentials: {}", .0, .1)] + CredentialsInvalid(AuthenticationScheme, &'static str), + #[error("`{}` is not a recognized auth scheme.", .0)] + SchemeUnrecognized(String), + #[error("Could not find auth scheme")] + SchemeMissing, + #[error("Could not find auth credentials")] + CredentialsMissing, + #[error("Expected `{}` auth scheme but found `{}`", .0, .1)] + SchemeUnexpected(AuthenticationScheme, String), + #[error("Could not find www-auth realm")] + RealmMissing, +} + +impl From for Error { + fn from(other: AuthError) -> Self { + Error::Header(other.into()) + } +} + +impl From for HeaderError { + fn from(other: AuthError) -> Self { + HeaderError::AuthorizationInvalid(other) + } +} + +#[derive(Debug, Diagnostic, ThisError)] +pub enum DateError { + #[error("HTTP Date-Time not in {} format", .0)] + FormatInvalid(&'static str), + + #[error("HTTP Date-Time failed all parsings: imf_fixdate, rfc850, asctime")] + Unparseable, + + #[error("HTTP Date-Time invalid: parts out of logical bounds")] + OutOfBounds, + + #[error("HTTP Date-Time string was not ASCII")] + NotASCII, + + // Individual parts + #[error("HTTP Date-Time invalid seconds")] + SecondsInvalid, + #[error("HTTP Date-Time invalid minutes")] + MinutesInvalid, + #[error("HTTP Date-Time invalid hours")] + HourInvalid, + #[error("HTTP Date-Time invalid day")] + DayInvalid, + #[error("HTTP Date-Time invalid month")] + MonthInvalid, + #[error("HTTP Date-Time invalid year")] + YearInvalid, + #[error("HTTP Date-Time invalid week-day")] + WeekdayInvalid, +} + +impl From for Error { + fn from(other: DateError) -> Self { + Error::Header(other.into()) + } +} + +impl From for HeaderError { + fn from(other: DateError) -> Self { + HeaderError::DateInvalid(other) + } +} + +#[derive(Debug, Diagnostic, ThisError)] +pub enum MediaTypeError { + #[error("MediaType (MIME) parse error: {}", .0)] + Parse(&'static str), + #[error("MediaType (MIME) invalid Param name: {}", .0)] + ParamName(&'static str), + #[error("MediaType (MIME) invalid proposal: {}", .0)] + Proposal(&'static str), + #[error("Media Type (MIME) could not be sniffed / inferred")] + Sniff, +} + +#[derive(Debug, Diagnostic, ThisError)] +pub enum EncodingError { + #[error("Encoding parse error: {}", .0)] + Parse(&'static str), + #[error("Encoding invalid proposal: {}", .0)] + Proposal(&'static str), +} + +impl Error { + /// Maps this error to its associated http status code, if one logically exists. + /// + /// `None` is returned in cases where a default should be used. + /// It is suggested that frameworks using this code map these to 500 by default when there is no other developer intervention. + /// (This is what Tide does.) + pub fn associated_status_code(&self) -> Option { + use Error::*; + match self { + QueryDeserialize(_) => Some(StatusCode::BadRequest), // XXX(Jeremiah): should this also be 422? + Body(inner) => inner.associated_status_code(), + Header(inner) => inner.associated_status_code(), + _ => None, + } + } +} + +impl BodyError { + /// Maps this error to its associated http status code, if one logically exists. + /// + /// `None` is returned in cases where a default should be used. + /// It is suggested that frameworks using this code map these to 500 by default when there is no other developer intervention. + /// (This is what Tide does.) + pub fn associated_status_code(&self) -> Option { + use BodyError::*; + match self { + PayloadTooLarge(_) => Some(StatusCode::PayloadTooLarge), + DeserializeJSON(_) => Some(StatusCode::UnprocessableEntity), + DeserializeForm(_) => Some(StatusCode::UnprocessableEntity), + // XXX(Jeremiah): This is currently a std::io::Error but should probably be mapped to this + // BodyError::DeserializeUTF8(_) => Some(StatusCode::UnprocessableEntity), + _ => None, + } + } +} + +impl HeaderError { + /// Maps this error to its associated http status code, if one logically exists. + /// + /// `None` is returned in cases where a default should be used. + /// It is suggested that frameworks using this code map these to 500 by default when there is no other developer intervention. + /// (This is what Tide does.) + pub fn associated_status_code(&self) -> Option { + use HeaderError::*; + use StatusCode::*; + match self { + SpecificityInvalid => Some(BadRequest), + + DateInvalid(_) => Some(BadRequest), + TransferEncodingUnnegotiable => Some(NotAcceptable), + TransferEncodingInvalidEncoding(_) => Some(BadRequest), + TraceContextInvalid(_) => Some(BadRequest), + ServerTimingInvalid(_) => Some(BadRequest), + TimingAllowOriginInvalidUrl(_) => Some(BadRequest), + ForwardedInvalid(_) => Some(BadRequest), + ContentTypeInvalidMediaType(_) => Some(BadRequest), + ContentLengthInvalid => Some(BadRequest), + AcceptInvalidMediaType(_) => Some(BadRequest), + AcceptUnnegotiable => Some(NotAcceptable), + AcceptEncodingInvalidEncoding(_) => Some(BadRequest), + AcceptEncodingUnnegotiable => Some(NotAcceptable), + ETagInvalid => Some(BadRequest), + AgeInvalid => Some(BadRequest), + CacheControlInvalid => Some(BadRequest), + AuthorizationInvalid(_) => Some(BadRequest), + WWWAuthenticateInvalid(_) => Some(BadRequest), + ExpectInvalid => Some(BadRequest), + StrictTransportSecurityInvalid(_) => Some(BadRequest), + + _ => None, + } + } +} diff --git a/src/errors/mod.rs b/src/errors/mod.rs new file mode 100644 index 00000000..821595da --- /dev/null +++ b/src/errors/mod.rs @@ -0,0 +1,24 @@ +//! HTTP error types +//! +//! This includes two error types for different purposes: +//! One, to either be used as a Response and consumed by a server's middleware, or produced by +//! a client with middleware capabilities; with the ability to dynamically encapsulate +//! any error in handlers (or middleware). +//! Another, to be made by common http operation errors. + +mod error_kind; +mod request_error; +mod response_error; + +pub use error_kind::*; +pub use request_error::RequestError; +pub use response_error::ResponseError; + +/// Result type for errors from http-types. +pub type Result = std::result::Result; + +/// Result type for errors from a client making requests using http-types. +pub type RequestResult = std::result::Result; + +/// Result type for errors provided to a response handler using http-types. +pub type ResponseResult = std::result::Result; diff --git a/src/errors/request_error.rs b/src/errors/request_error.rs new file mode 100644 index 00000000..66db93c9 --- /dev/null +++ b/src/errors/request_error.rs @@ -0,0 +1,65 @@ +use std::error::Error as StdError; +use std::fmt::{self, Debug, Display}; + +use miette::Diagnostic; +use thiserror::Error as ThisError; + +use crate::StatusCode; + +use super::{Error, ResponseError}; + +#[derive(Debug, Diagnostic, ThisError)] +/// An error type to be used for clients which handle http requests. +pub enum RequestError { + #[error(transparent)] + /// An internal, concrete http-types error without indirection. + Internal(Error), + #[error(transparent)] + /// A dynamic error, usually generated in a response handler. + /// + /// This has a layer of indirection to get around trait conflicts regarding StdErr and anyhow. + Dynamic(ResponseErrorIndirection), +} + +pub struct ResponseErrorIndirection(ResponseError); + +impl StdError for ResponseErrorIndirection { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.0.error.source() + } + + #[cfg(backtrace)] + fn backtrace(&self) -> Option<&std::backtrace::Backtrace> { + Some(self.0.error.backtrace()) + } + + fn description(&self) -> &str { + "description() is deprecated; use Display" + } + + fn cause(&self) -> Option<&dyn StdError> { + self.source() + } +} + +impl RequestError { + /// Get the status code associated with this error. + pub fn status(&self) -> Option { + match self { + RequestError::Internal(inner) => inner.associated_status_code(), + RequestError::Dynamic(ResponseErrorIndirection(inner)) => inner.status(), + } + } +} + +impl Debug for ResponseErrorIndirection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_fmt(format_args!("{:?}", self.0)) + } +} + +impl Display for ResponseErrorIndirection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_fmt(format_args!("{}", self.0)) + } +} diff --git a/src/error.rs b/src/errors/response_error.rs similarity index 69% rename from src/error.rs rename to src/errors/response_error.rs index 04fbcd8d..02b28976 100644 --- a/src/error.rs +++ b/src/errors/response_error.rs @@ -1,21 +1,16 @@ -//! HTTP error types - +use std::convert::TryInto; use std::error::Error as StdError; use std::fmt::{self, Debug, Display}; -use crate::StatusCode; -use std::convert::TryInto; +use anyhow::Context; -/// A specialized `Result` type for HTTP operations. -/// -/// This type is broadly used across `http_types` for any operation which may -/// produce an error. -pub type Result = std::result::Result; +use crate::StatusCode; -/// The error type for HTTP operations. -pub struct Error { - error: anyhow::Error, - status: crate::StatusCode, +/// An error type to be used for where handlers and middleware can error when handling an http response. +#[derive(Debug)] +pub struct ResponseError { + pub(super) error: anyhow::Error, + status: Option, type_name: Option<&'static str>, } @@ -30,64 +25,85 @@ impl Display for BacktracePlaceholder { } } -impl Error { +impl ResponseError { /// Create a new error object from any error type. /// - /// The error type must be threadsafe and 'static, so that the Error will be + /// The error type must be thread-safe and 'static, so that the Error will be /// as well. If the error type does not provide a backtrace, a backtrace will /// be created here to ensure that a backtrace exists. - pub fn new(status: S, error: E) -> Self + pub fn new(error: E) -> Self where - S: TryInto, - S::Error: Debug, E: Into, { Self { - status: status - .try_into() - .expect("Could not convert into a valid `StatusCode`"), + status: None, error: error.into(), type_name: Some(std::any::type_name::()), } } - /// Create a new error object from static string. - pub fn from_str(status: S, msg: M) -> Self + /// Create a new error object from any error type. + /// + /// The error type must be thread-safe and 'static, so that the Error will be + /// as well. If the error type does not provide a backtrace, a backtrace will + /// be created here to ensure that a backtrace exists. + pub fn new_status(status: S, error: E) -> Self where S: TryInto, - S::Error: Debug, + S::Error: StdError + Send + Sync + 'static, + E: Into, + { + let mut err = Self::new(error); + if let Err(new_err) = err.set_status(status) { + return new_err; + } + err + } + + /// Create a new error object from static string. + #[allow(clippy::should_implement_trait)] + pub fn from_str(msg: M) -> Self + where M: Display + Debug + Send + Sync + 'static, { Self { - status: status - .try_into() - .expect("Could not convert into a valid `StatusCode`"), + status: None, error: anyhow::Error::msg(msg), type_name: None, } } - /// Create a new error from a message. - pub(crate) fn new_adhoc(message: M) -> Error + + /// Create a new error object from static string. + pub fn from_str_status(status: S, msg: M) -> Self where + S: TryInto, + S::Error: StdError + Send + Sync + 'static, M: Display + Debug + Send + Sync + 'static, { - Self::from_str(StatusCode::InternalServerError, message) + let mut err = Self::from_str(msg); + if let Err(new_err) = err.set_status(status) { + return new_err; + } + err } /// Get the status code associated with this error. - pub fn status(&self) -> StatusCode { + pub fn status(&self) -> Option { self.status } /// Set the status code associated with this error. - pub fn set_status(&mut self, status: S) + pub fn set_status(&mut self, status: S) -> Result<(), ResponseError> where S: TryInto, - S::Error: Debug, + S::Error: StdError + Send + Sync + 'static, { - self.status = status - .try_into() - .expect("Could not convert into a valid `StatusCode`"); + self.status = Some( + status + .try_into() + .context("Could not convert into a valid `StatusCode`")?, + ); + Ok(()) } /// Get the backtrace for this Error. @@ -165,7 +181,7 @@ impl Error { /// Converts anything which implements `Display` into an `http_types::Error`. /// /// This is handy for errors which are not `Send + Sync + 'static` because `std::error::Error` requires `Display`. - /// Note that any assiciated context not included in the `Display` output will be lost, + /// Note that any associated context not included in the `Display` output will be lost, /// and so this may be lossy for some types which implement `std::error::Error`. /// /// **Note: Prefer `error.into()` via `From>` when possible!** @@ -176,7 +192,7 @@ impl Error { /// Converts anything which implements `Debug` into an `http_types::Error`. /// /// This is handy for errors which are not `Send + Sync + 'static` because `std::error::Error` requires `Debug`. - /// Note that any assiciated context not included in the `Debug` output will be lost, + /// Note that any associated context not included in the `Debug` output will be lost, /// and so this may be lossy for some types which implement `std::error::Error`. /// /// **Note: Prefer `error.into()` via `From>` when possible!** @@ -185,61 +201,43 @@ impl Error { } } -impl Display for Error { +impl Display for ResponseError { fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { Display::fmt(&self.error, formatter) } } -impl Debug for Error { - fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.error, formatter) - } -} - -impl> From for Error { +impl> From for ResponseError { fn from(error: E) -> Self { - Self::new(StatusCode::InternalServerError, error) + Self::new(error) } } -impl AsRef for Error { +impl AsRef for ResponseError { fn as_ref(&self) -> &(dyn StdError + Send + Sync + 'static) { self.error.as_ref() } } -impl AsRef for Error { - fn as_ref(&self) -> &StatusCode { - &self.status - } -} - -impl AsMut for Error { - fn as_mut(&mut self) -> &mut StatusCode { - &mut self.status - } -} - -impl AsRef for Error { +impl AsRef for ResponseError { fn as_ref(&self) -> &(dyn StdError + 'static) { self.error.as_ref() } } -impl From for Box { - fn from(error: Error) -> Self { +impl From for Box { + fn from(error: ResponseError) -> Self { error.error.into() } } -impl From for Box { - fn from(error: Error) -> Self { +impl From for Box { + fn from(error: ResponseError) -> Self { Box::::from(error.error) } } -impl AsRef for Error { +impl AsRef for ResponseError { fn as_ref(&self) -> &anyhow::Error { &self.error } diff --git a/src/headers/header_name.rs b/src/headers/header_name.rs index 430687f8..0b929b08 100644 --- a/src/headers/header_name.rs +++ b/src/headers/header_name.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::fmt::{self, Debug, Display}; use std::str::FromStr; +use crate::errors::HeaderError; use crate::Error; use super::Header; @@ -17,7 +18,10 @@ impl HeaderName { /// /// This function will error if the bytes is not valid ASCII. pub fn from_bytes(mut bytes: Vec) -> Result { - crate::ensure!(bytes.is_ascii(), "Bytes should be valid ASCII"); + internal_ensure!( + bytes.is_ascii(), + HeaderError::NameInvalid("Bytes should be valid ASCII") + ); bytes.make_ascii_lowercase(); // This is permitted because ASCII is valid UTF-8, and we just checked that. @@ -79,7 +83,10 @@ impl FromStr for HeaderName { /// /// This checks it's valid ASCII, and lowercases it. fn from_str(s: &str) -> Result { - crate::ensure!(s.is_ascii(), "String slice should be valid ASCII"); + internal_ensure!( + s.is_ascii(), + HeaderError::NameInvalid("String slice should be valid ASCII") + ); Ok(HeaderName(Cow::Owned(s.to_ascii_lowercase()))) } } diff --git a/src/headers/header_value.rs b/src/headers/header_value.rs index c8287a1c..f712e7f5 100644 --- a/src/headers/header_value.rs +++ b/src/headers/header_value.rs @@ -4,6 +4,7 @@ use std::str::FromStr; #[cfg(feature = "cookies")] use crate::cookies::Cookie; +use crate::errors::HeaderError; use crate::headers::HeaderValues; use crate::mime::Mime; use crate::Error; @@ -21,7 +22,10 @@ impl HeaderValue { /// /// This function will error if the bytes is not valid ASCII. pub fn from_bytes(bytes: Vec) -> Result { - crate::ensure!(bytes.is_ascii(), "Bytes should be valid ASCII"); + internal_ensure!( + bytes.is_ascii(), + HeaderError::ValueInvalid("Bytes should be valid ASCII") + ); // This is permitted because ASCII is valid UTF-8, and we just checked that. let string = unsafe { String::from_utf8_unchecked(bytes) }; @@ -80,7 +84,10 @@ impl FromStr for HeaderValue { /// /// This checks it's valid ASCII. fn from_str(s: &str) -> Result { - crate::ensure!(s.is_ascii(), "String slice should be valid ASCII"); + internal_ensure!( + s.is_ascii(), + HeaderError::ValueInvalid("String slice should be valid ASCII") + ); Ok(Self { inner: String::from(s), }) diff --git a/src/headers/headers.rs b/src/headers/headers.rs index a7a9d90e..9b02159d 100644 --- a/src/headers/headers.rs +++ b/src/headers/headers.rs @@ -208,7 +208,7 @@ mod tests { const STATIC_HEADER: HeaderName = HeaderName::from_lowercase_str("hello"); #[test] - fn test_header_name_static_non_static() -> crate::Result<()> { + fn test_header_name_static_non_static() -> crate::ResponseResult<()> { let static_header = HeaderName::from_lowercase_str("hello"); let non_static_header = HeaderName::from_str("hello")?; diff --git a/src/headers/to_header_values.rs b/src/headers/to_header_values.rs index 37ef54d6..22bba471 100644 --- a/src/headers/to_header_values.rs +++ b/src/headers/to_header_values.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::io; use std::iter; use std::option; use std::slice; @@ -51,9 +50,7 @@ impl<'a> ToHeaderValues for &'a str { type Iter = option::IntoIter; fn to_header_values(&self) -> crate::Result { - let value = self - .parse() - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let value = self.parse()?; Ok(Some(value).into_iter()) } } diff --git a/src/hyperium_http.rs b/src/hyperium_http.rs index 81d33425..fa480e32 100644 --- a/src/hyperium_http.rs +++ b/src/hyperium_http.rs @@ -1,5 +1,6 @@ // This is the compat file for the "hyperium/http" crate. +use crate::errors::HeaderError; use crate::headers::{HeaderName, HeaderValue, Headers}; use crate::{Body, Error, Method, Request, Response, StatusCode, Url, Version}; use std::convert::{TryFrom, TryInto}; @@ -68,7 +69,8 @@ impl TryFrom for http::header::HeaderName { fn try_from(name: HeaderName) -> Result { let name = name.as_str().as_bytes(); - http::header::HeaderName::from_bytes(name).map_err(Error::new_adhoc) + http::header::HeaderName::from_bytes(name) + .map_err(|_| HeaderError::NameInvalid("(hyper http)").into()) } } @@ -86,7 +88,8 @@ impl TryFrom for http::header::HeaderValue { fn try_from(value: HeaderValue) -> Result { let value = value.as_str().as_bytes(); - http::header::HeaderValue::from_bytes(value).map_err(Error::new_adhoc) + http::header::HeaderValue::from_bytes(value) + .map_err(|_| HeaderError::ValueInvalid("(hyper http)").into()) } } diff --git a/src/lib.rs b/src/lib.rs index c8821a41..e221797e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -115,6 +115,9 @@ pub mod url { }; } +#[macro_use] +mod macros_internal; + #[macro_use] mod utils; @@ -132,7 +135,7 @@ pub mod transfer; pub mod upgrade; mod body; -mod error; +mod errors; mod extensions; mod macros; mod method; @@ -144,7 +147,7 @@ mod status_code; mod version; pub use body::Body; -pub use error::{Error, Result}; +pub use errors::{Error, RequestError, RequestResult, ResponseError, ResponseResult, Result}; pub use method::Method; pub use request::Request; pub use response::Response; @@ -175,15 +178,15 @@ pub mod convert { // Not public API. Referenced by macro-generated code. #[doc(hidden)] pub mod private { - use crate::Error; + use crate::ResponseError; pub use crate::StatusCode; use core::fmt::{Debug, Display}; pub use core::result::Result::Err; - pub fn new_adhoc(message: M) -> Error + pub fn new_adhoc(message: M) -> ResponseError where M: Display + Debug + Send + Sync + 'static, { - Error::new_adhoc(message) + ResponseError::from_str(message) } } diff --git a/src/macros.rs b/src/macros.rs index e29ae6c3..130f3e8e 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -78,7 +78,7 @@ macro_rules! format_err { }; ($err:expr $(,)?) => ({ let error = $err; - Error::new_adhoc(error) + ResponseError::new_adhoc(error) }); ($fmt:expr, $($arg:tt)*) => { $crate::private::new_adhoc(format!($fmt, $($arg)*)) diff --git a/src/macros_internal.rs b/src/macros_internal.rs new file mode 100644 index 00000000..22118d99 --- /dev/null +++ b/src/macros_internal.rs @@ -0,0 +1,30 @@ +/// Return early with an error if a condition is not satisfied. +/// +/// This macro is equivalent to `if !$cond { return Err(From::from($err)); }`. +/// +/// Analogously to `assert!`, `ensure!` takes a condition and exits the function +/// if the condition fails. Unlike `assert!`, `ensure!` returns an `Error` +/// rather than panicking. +macro_rules! internal_ensure { + ($cond:expr, $err:expr $(,)?) => { + if !$cond { + return $crate::private::Err($err.into()); + } + }; +} + +/// Return early with an error if two expressions are not equal to each other. +/// +/// This macro is equivalent to `if $left != $right { return Err(From::from($err)); }`. +/// +/// Analogously to `assert_eq!`, `ensure_eq!` takes two expressions and exits the function +/// if the expressions are not equal. Unlike `assert_eq!`, `ensure_eq!` returns an `Error` +/// rather than panicking. +#[macro_export] +macro_rules! internal_ensure_eq { + ($left:expr, $right:expr, $err:expr $(,)?) => { + if $left != $right { + return $crate::private::Err($err.into()); + } + }; +} diff --git a/src/method.rs b/src/method.rs index 7ad2b46b..8d343a41 100644 --- a/src/method.rs +++ b/src/method.rs @@ -1,6 +1,8 @@ use std::fmt::{self, Display}; use std::str::FromStr; +use crate::Error; + /// HTTP request methods. /// /// See also [Mozilla's documentation][Mozilla docs], the [RFC7231, Section 4][] and @@ -454,7 +456,7 @@ impl FromStr for Method { "UPDATE" => Ok(Self::Update), "UPDATEREDIRECTREF" => Ok(Self::UpdateRedirectRef), "VERSION-CONTROL" => Ok(Self::VersionControl), - _ => crate::bail!("Invalid HTTP method"), + s => Err(Error::HttpMethodUnrecognized(s.to_string())), } } } diff --git a/src/mime/mod.rs b/src/mime/mod.rs index f416de3d..d7de7b8b 100644 --- a/src/mime/mod.rs +++ b/src/mime/mod.rs @@ -12,6 +12,7 @@ use std::fmt::{self, Debug, Display}; use std::option; use std::str::FromStr; +use crate::errors::MediaTypeError; use crate::headers::{HeaderValue, ToHeaderValues}; use infer::Infer; @@ -41,11 +42,11 @@ pub struct Mime { impl Mime { /// Sniff the mime type from a byte slice. - pub fn sniff(bytes: &[u8]) -> crate::Result { + pub fn sniff(bytes: &[u8]) -> Result { let info = Infer::new(); let mime = match info.get(bytes) { Some(info) => info.mime_type(), - None => crate::bail!("Could not sniff the mime type"), + None => return Err(MediaTypeError::Sniff), }; Mime::from_str(mime) } @@ -179,7 +180,7 @@ impl Display for Mime { // } impl FromStr for Mime { - type Err = crate::Error; + type Err = MediaTypeError; /// Create a new `Mime`. /// @@ -225,13 +226,16 @@ impl Display for ParamName { } impl FromStr for ParamName { - type Err = crate::Error; + type Err = MediaTypeError; /// Create a new `HeaderName`. /// /// This checks it's valid ASCII, and lowercases it. fn from_str(s: &str) -> Result { - crate::ensure!(s.is_ascii(), "String slice should be valid ASCII"); + internal_ensure!( + s.is_ascii(), + MediaTypeError::ParamName("Param Name: String slice should be valid ASCII") + ); Ok(ParamName(Cow::Owned(s.to_ascii_lowercase()))) } } diff --git a/src/mime/parse.rs b/src/mime/parse.rs index f2ec5698..86c98eca 100644 --- a/src/mime/parse.rs +++ b/src/mime/parse.rs @@ -1,11 +1,13 @@ use std::borrow::Cow; use std::fmt; +use crate::errors::MediaTypeError; + use super::{Mime, ParamName, ParamValue}; /// Parse a string into a mime type. /// Follows the [WHATWG MIME parsing algorithm](https://mimesniff.spec.whatwg.org/#parsing-a-mime-type) -pub(crate) fn parse(input: &str) -> crate::Result { +pub(crate) fn parse(input: &str) -> Result { // 1 let input = input.trim_matches(is_http_whitespace_char); @@ -13,14 +15,20 @@ pub(crate) fn parse(input: &str) -> crate::Result { let (basetype, input) = collect_code_point_sequence_char(input, '/'); // 4. - crate::ensure!(!basetype.is_empty(), "MIME type should not be empty"); - crate::ensure!( + internal_ensure!( + !basetype.is_empty(), + MediaTypeError::Parse("MIME type should not be empty") + ); + internal_ensure!( basetype.chars().all(is_http_token_code_point), - "MIME type should ony contain valid HTTP token code points" + MediaTypeError::Parse("MIME type should ony contain valid HTTP token code points") ); // 5. - crate::ensure!(!input.is_empty(), "MIME must contain a sub type"); + internal_ensure!( + !input.is_empty(), + MediaTypeError::Parse("MIME must contain a sub type") + ); // 6. let input = &input[1..]; @@ -32,10 +40,13 @@ pub(crate) fn parse(input: &str) -> crate::Result { let subtype = subtype.trim_end_matches(is_http_whitespace_char); // 9. - crate::ensure!(!subtype.is_empty(), "MIME sub type should not be empty"); - crate::ensure!( + internal_ensure!( + !subtype.is_empty(), + MediaTypeError::Parse("MIME sub type should not be empty") + ); + internal_ensure!( subtype.chars().all(is_http_token_code_point), - "MIME sub type should ony contain valid HTTP token code points" + MediaTypeError::Parse("MIME sub type should ony contain valid HTTP token code points") ); // 10. diff --git a/src/other/date.rs b/src/other/date.rs index 1e381ffe..98ead812 100644 --- a/src/other/date.rs +++ b/src/other/date.rs @@ -12,7 +12,7 @@ use std::time::SystemTime; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::Response; /// use http_types::other::Date; @@ -60,14 +60,7 @@ impl Date { // If we successfully parsed the header then there's always at least one // entry. We want the last entry. let value = headers.iter().last().unwrap(); - let date: HttpDate = value - .as_str() - .trim() - .parse() - .map_err(|mut e: crate::Error| { - e.set_status(400); - e - })?; + let date: HttpDate = value.as_str().trim().parse()?; let at = date.into(); Ok(Some(Self { at })) } @@ -107,12 +100,15 @@ impl PartialEq for Date { #[cfg(test)] mod test { - use super::*; - use crate::headers::Headers; use std::time::Duration; + use crate::headers::Headers; + use crate::StatusCode; + + use super::*; + #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let now = SystemTime::now(); let date = Date::new(now); @@ -131,6 +127,6 @@ mod test { let mut headers = Headers::new(); headers.insert(DATE, "").unwrap(); let err = Date::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/other/expect.rs b/src/other/expect.rs index 086c0d09..437c6185 100644 --- a/src/other/expect.rs +++ b/src/other/expect.rs @@ -1,5 +1,5 @@ -use crate::headers::{HeaderName, HeaderValue, Headers, EXPECT}; -use crate::{ensure_eq_status, headers::Header}; +use crate::errors::HeaderError; +use crate::headers::{Header, HeaderName, HeaderValue, Headers, EXPECT}; use std::fmt::Debug; @@ -50,7 +50,7 @@ impl Expect { // If we successfully parsed the header then there's always at least one // entry. We want the last entry. let header = headers.iter().last().unwrap(); - ensure_eq_status!(header, "100-continue", 400, "malformed `Expect` header"); + internal_ensure_eq!(header, "100-continue", HeaderError::ExpectInvalid); Ok(Some(Self { _priv: () })) } @@ -69,8 +69,10 @@ impl Header for Expect { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -89,6 +91,6 @@ mod test { let mut headers = Headers::new(); headers.insert(EXPECT, "").unwrap(); let err = Expect::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/other/referer.rs b/src/other/referer.rs index 5a860ab9..f8c0edc0 100644 --- a/src/other/referer.rs +++ b/src/other/referer.rs @@ -1,5 +1,6 @@ +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, REFERER}; -use crate::{bail_status as bail, Status, Url}; +use crate::Url; use std::convert::TryInto; @@ -17,7 +18,7 @@ use std::convert::TryInto; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::{Response, Url}; /// use http_types::other::Referer; @@ -48,7 +49,7 @@ impl Referer { pub fn from_headers(base_url: U, headers: impl AsRef) -> crate::Result> where U: TryInto, - U::Error: std::fmt::Debug, + U::Error: std::fmt::Debug + Send + Sync + 'static, { let headers = match headers.as_ref().get(REFERER) { Some(headers) => headers, @@ -62,8 +63,10 @@ impl Referer { let url = match Url::parse(header_value.as_str()) { Ok(url) => url, Err(_) => match base_url.try_into() { - Ok(base_url) => base_url.join(header_value.as_str().trim()).status(500)?, - Err(_) => bail!(500, "Invalid base url provided"), + Ok(base_url) => base_url + .join(header_value.as_str().trim()) + .map_err(HeaderError::RefererInvalidUrl)?, + Err(e) => return Err(HeaderError::RefererInvalidBaseUrl(Box::new(e)).into()), }, }; @@ -101,11 +104,12 @@ impl Header for Referer { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use super::*; + #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let referer = Referer::new(Url::parse("https://example.net/test.json")?); let mut headers = Headers::new(); @@ -128,11 +132,11 @@ mod test { .unwrap(); let err = Referer::from_headers(Url::parse("https://example.net").unwrap(), headers).unwrap_err(); - assert_eq!(err.status(), 500); + assert_eq!(err.associated_status_code(), None); } #[test] - fn fallback_works() -> crate::Result<()> { + fn fallback_works() -> anyhow::Result<()> { let mut headers = Headers::new(); headers.insert(REFERER, "/test.json").unwrap(); diff --git a/src/other/retry_after.rs b/src/other/retry_after.rs index 21404598..b3a3e4ce 100644 --- a/src/other/retry_after.rs +++ b/src/other/retry_after.rs @@ -14,7 +14,7 @@ use crate::utils::{fmt_http_date, parse_http_date}; /// # Examples /// /// ```no_run -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::other::RetryAfter; /// use http_types::Response; @@ -129,7 +129,7 @@ mod test { use crate::headers::Headers; #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let retry = RetryAfter::new(Duration::from_secs(10)); let mut headers = Headers::new(); @@ -147,7 +147,7 @@ mod test { } #[test] - fn new_at() -> crate::Result<()> { + fn new_at() -> anyhow::Result<()> { let now = SystemTime::now(); let retry = RetryAfter::new_at(now + Duration::from_secs(10)); diff --git a/src/other/source_map.rs b/src/other/source_map.rs index 5412c6e1..1e42ee84 100644 --- a/src/other/source_map.rs +++ b/src/other/source_map.rs @@ -1,5 +1,6 @@ +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, SOURCE_MAP}; -use crate::{bail_status as bail, Status, Url}; +use crate::Url; use std::convert::TryInto; @@ -14,7 +15,7 @@ use std::convert::TryInto; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::{Response, Url}; /// use http_types::other::SourceMap; @@ -45,7 +46,7 @@ impl SourceMap { pub fn from_headers(base_url: U, headers: impl AsRef) -> crate::Result> where U: TryInto, - U::Error: std::fmt::Debug, + U::Error: std::fmt::Debug + Send + Sync + 'static, { let headers = match headers.as_ref().get(SOURCE_MAP) { Some(headers) => headers, @@ -59,8 +60,10 @@ impl SourceMap { let url = match Url::parse(header_value.as_str()) { Ok(url) => url, Err(_) => match base_url.try_into() { - Ok(base_url) => base_url.join(header_value.as_str().trim()).status(500)?, - Err(_) => bail!(500, "Invalid base url provided"), + Ok(base_url) => base_url + .join(header_value.as_str().trim()) + .map_err(HeaderError::SourceMapInvalidUrl)?, + Err(e) => return Err(HeaderError::SourceMapInvalidBaseUrl(Box::new(e)).into()), }, }; @@ -98,11 +101,12 @@ impl Header for SourceMap { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use super::*; + #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let source_map = SourceMap::new(Url::parse("https://example.net/test.json")?); let mut headers = Headers::new(); @@ -125,11 +129,11 @@ mod test { .unwrap(); let err = SourceMap::from_headers(Url::parse("https://example.net").unwrap(), headers) .unwrap_err(); - assert_eq!(err.status(), 500); + assert_eq!(err.associated_status_code(), None); } #[test] - fn fallback_works() -> crate::Result<()> { + fn fallback_works() -> anyhow::Result<()> { let mut headers = Headers::new(); headers.insert(SOURCE_MAP, "/test.json").unwrap(); diff --git a/src/proxies/forwarded.rs b/src/proxies/forwarded.rs index 30e491b0..3c5bdc14 100644 --- a/src/proxies/forwarded.rs +++ b/src/proxies/forwarded.rs @@ -1,6 +1,8 @@ use crate::{ + errors::HeaderError, headers::{Header, HeaderName, HeaderValue, Headers, FORWARDED}, parse_utils::{parse_quoted_string, parse_token}, + Error, Result, }; use std::{borrow::Cow, convert::TryFrom, fmt::Write, net::IpAddr}; @@ -76,7 +78,7 @@ impl<'a> Forwarded<'a> { /// # Ok(()) } /// ``` - pub fn from_headers(headers: &'a impl AsRef) -> Result, ParseError> { + pub fn from_headers(headers: &'a impl AsRef) -> Result> { if let Some(forwarded) = Self::from_forwarded_header(headers)? { Ok(Some(forwarded)) } else { @@ -88,8 +90,8 @@ impl<'a> Forwarded<'a> { /// /// # Examples /// ```rust - /// # use http_types::{proxies::Forwarded, Method::Get, Request, Url, Result}; - /// # fn main() -> Result<()> { + /// # use http_types::{proxies::Forwarded, Method::Get, Request, Url,}; + /// # fn main() -> anyhow::Result<()> { /// let mut request = Request::new(Get, Url::parse("http://_/")?); /// request.insert_header( /// "Forwarded", @@ -101,16 +103,14 @@ impl<'a> Forwarded<'a> { /// # Ok(()) } /// ``` /// ```rust - /// # use http_types::{proxies::Forwarded, Method::Get, Request, Url, Result}; - /// # fn main() -> Result<()> { + /// # use http_types::{proxies::Forwarded, Method::Get, Request, Url}; + /// # fn main() -> anyhow::Result<()> { /// let mut request = Request::new(Get, Url::parse("http://_/")?); /// request.insert_header("X-Forwarded-For", "192.0.2.43, 2001:db8:cafe::17"); /// assert!(Forwarded::from_forwarded_header(&request)?.is_none()); /// # Ok(()) } /// ``` - pub fn from_forwarded_header( - headers: &'a impl AsRef, - ) -> Result, ParseError> { + pub fn from_forwarded_header(headers: &'a impl AsRef) -> Result> { if let Some(headers) = headers.as_ref().get(FORWARDED) { Ok(Some(Self::parse(headers.as_ref().as_str())?)) } else { @@ -124,8 +124,8 @@ impl<'a> Forwarded<'a> { /// /// # Examples /// ```rust - /// # use http_types::{proxies::Forwarded, Method::Get, Request, Url, Result}; - /// # fn main() -> Result<()> { + /// # use http_types::{proxies::Forwarded, Method::Get, Request, Url}; + /// # fn main() -> anyhow::Result<()> { /// let mut request = Request::new(Get, Url::parse("http://_/")?); /// request.insert_header("X-Forwarded-For", "192.0.2.43, 2001:db8:cafe::17"); /// let forwarded = Forwarded::from_headers(&request)?.unwrap(); @@ -133,8 +133,8 @@ impl<'a> Forwarded<'a> { /// # Ok(()) } /// ``` /// ```rust - /// # use http_types::{proxies::Forwarded, Method::Get, Request, Url, Result}; - /// # fn main() -> Result<()> { + /// # use http_types::{proxies::Forwarded, Method::Get, Request, Url}; + /// # fn main() -> anyhow::Result<()> { /// let mut request = Request::new(Get, Url::parse("http://_/")?); /// request.insert_header( /// "Forwarded", @@ -143,7 +143,7 @@ impl<'a> Forwarded<'a> { /// assert!(Forwarded::from_x_headers(&request)?.is_none()); /// # Ok(()) } /// ``` - pub fn from_x_headers(headers: &'a impl AsRef) -> Result, ParseError> { + pub fn from_x_headers(headers: &'a impl AsRef) -> Result> { let headers = headers.as_ref(); let forwarded_for: Vec> = headers @@ -204,7 +204,7 @@ impl<'a> Forwarded<'a> { /// ); /// # Ok(()) } /// ``` - pub fn parse(input: &'a str) -> Result { + pub fn parse(input: &'a str) -> Result { let mut input = input; let mut forwarded = Forwarded::new(); @@ -219,7 +219,7 @@ impl<'a> Forwarded<'a> { Ok(forwarded) } - fn parse_forwarded_pair(&mut self, input: &'a str) -> Result<&'a str, ParseError> { + fn parse_forwarded_pair(&mut self, input: &'a str) -> Result<&'a str> { let (key, value, rest) = match parse_token(input) { (Some(key), rest) if rest.starts_with('=') => match parse_value(&rest[1..]) { (Some(value), rest) => Some((key, value, rest)), @@ -227,26 +227,35 @@ impl<'a> Forwarded<'a> { }, _ => None, } - .ok_or_else(|| ParseError::new("parse error in forwarded-pair"))?; + .ok_or(HeaderError::ForwardedInvalid( + "parse error in forwarded-pair", + ))?; match key { "by" => { if self.by.is_some() { - return Err(ParseError::new("parse error, duplicate `by` key")); + return Err( + HeaderError::ForwardedInvalid("parse error, duplicate `by` key").into(), + ); } self.by = Some(value); } "host" => { if self.host.is_some() { - return Err(ParseError::new("parse error, duplicate `host` key")); + return Err( + HeaderError::ForwardedInvalid("parse error, duplicate `host` key").into(), + ); } self.host = Some(value); } "proto" => { if self.proto.is_some() { - return Err(ParseError::new("parse error, duplicate `proto` key")); + return Err(HeaderError::ForwardedInvalid( + "parse error, duplicate `proto` key", + ) + .into()); } self.proto = Some(value); } @@ -260,13 +269,17 @@ impl<'a> Forwarded<'a> { } } - fn parse_for(&mut self, input: &'a str) -> Result<&'a str, ParseError> { + fn parse_for(&mut self, input: &'a str) -> Result<&'a str> { let mut rest = input; loop { rest = match match_ignore_case("for=", rest) { (true, rest) => rest, - (false, _) => return Err(ParseError::new("http list must start with for=")), + (false, _) => { + return Err( + HeaderError::ForwardedInvalid("http list must start with for=").into(), + ) + } }; let (value, rest_) = parse_value(rest); @@ -276,7 +289,7 @@ impl<'a> Forwarded<'a> { // add a successful for= value self.forwarded_for.push(value); } else { - return Err(ParseError::new("for= without valid value")); + return Err(HeaderError::ForwardedInvalid("for= without valid value").into()); } match rest.chars().next() { @@ -292,7 +305,12 @@ impl<'a> Forwarded<'a> { None => return Ok(rest), // bail - _ => return Err(ParseError::new("unexpected character after for= section")), + _ => { + return Err(HeaderError::ForwardedInvalid( + "unexpected character after for= section", + ) + .into()) + } } } } @@ -443,24 +461,9 @@ impl std::fmt::Display for Forwarded<'_> { } } -#[derive(Debug, Clone)] -pub struct ParseError(&'static str); -impl ParseError { - pub fn new(msg: &'static str) -> Self { - Self(msg) - } -} - -impl std::error::Error for ParseError {} -impl std::fmt::Display for ParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "unable to parse forwarded header: {}", self.0) - } -} - impl<'a> TryFrom<&'a str> for Forwarded<'a> { - type Error = ParseError; - fn try_from(value: &'a str) -> Result { + type Error = Error; + fn try_from(value: &'a str) -> std::result::Result { Self::parse(value) } } @@ -468,7 +471,7 @@ impl<'a> TryFrom<&'a str> for Forwarded<'a> { #[cfg(test)] mod tests { use super::*; - use crate::{Method::Get, Request, Response, Result}; + use crate::{Method::Get, Request, Response}; use url::Url; #[test] @@ -536,31 +539,31 @@ mod tests { let err = Forwarded::parse("by=proxy.com;for=client;host=example.com;host").unwrap_err(); assert_eq!( err.to_string(), - "unable to parse forwarded header: parse error in forwarded-pair" + "Header error: Forwarded header was invalid: parse error in forwarded-pair" ); let err = Forwarded::parse("by;for;host;proto").unwrap_err(); assert_eq!( err.to_string(), - "unable to parse forwarded header: parse error in forwarded-pair" + "Header error: Forwarded header was invalid: parse error in forwarded-pair" ); let err = Forwarded::parse("for=for, key=value").unwrap_err(); assert_eq!( err.to_string(), - "unable to parse forwarded header: http list must start with for=" + "Header error: Forwarded header was invalid: http list must start with for=" ); let err = Forwarded::parse(r#"for="unterminated string"#).unwrap_err(); assert_eq!( err.to_string(), - "unable to parse forwarded header: for= without valid value" + "Header error: Forwarded header was invalid: for= without valid value" ); let err = Forwarded::parse(r#"for=, for=;"#).unwrap_err(); assert_eq!( err.to_string(), - "unable to parse forwarded header: for= without valid value" + "Header error: Forwarded header was invalid: for= without valid value" ); } @@ -570,7 +573,7 @@ mod tests { response.append_header("forwarded", "uh oh").unwrap(); assert_eq!( Forwarded::from_headers(&response).unwrap_err().to_string(), - "unable to parse forwarded header: parse error in forwarded-pair" + "Header error: Forwarded header was invalid: parse error in forwarded-pair" ); let response = Response::new(200); @@ -579,7 +582,7 @@ mod tests { } #[test] - fn from_x_headers() -> Result<()> { + fn from_x_headers() -> anyhow::Result<()> { let mut request = Request::new(Get, Url::parse("http://_/")?); request .append_header(X_FORWARDED_FOR, "192.0.2.43, 2001:db8:cafe::17") @@ -634,7 +637,7 @@ mod tests { } #[test] - fn from_request() -> Result<()> { + fn from_request() -> anyhow::Result<()> { let mut request = Request::new(Get, Url::parse("http://_/")?); request.append_header("Forwarded", "for=for").unwrap(); @@ -645,7 +648,7 @@ mod tests { } #[test] - fn owned_can_outlive_request() -> Result<()> { + fn owned_can_outlive_request() -> anyhow::Result<()> { let forwarded = { let mut request = Request::new(Get, Url::parse("http://_/")?); request @@ -664,7 +667,7 @@ mod tests { "by=proxy.com;proto=https;host=example.com;for=a,for=b", ]; for input in inputs { - let forwarded = Forwarded::parse(input).map_err(|_| crate::Error::new_adhoc(input))?; + let forwarded = Forwarded::parse(input)?; let header = forwarded.header_value(); let parsed = Forwarded::parse(header.as_str())?; assert_eq!(forwarded, parsed); diff --git a/src/request.rs b/src/request.rs index 6b3bb355..fcd68916 100644 --- a/src/request.rs +++ b/src/request.rs @@ -8,6 +8,7 @@ use std::task::{Context, Poll}; #[cfg(feature = "serde")] use crate::convert::{DeserializeOwned, Serialize}; +use crate::errors::Error; use crate::headers::{ self, HeaderName, HeaderValue, HeaderValues, Headers, Names, ToHeaderValues, Values, CONTENT_TYPE, @@ -659,11 +660,7 @@ impl Request { // This allows successful deserialisation of structs where all fields are optional // when none of those fields has actually been passed by the caller. let query = self.url().query().unwrap_or(""); - serde_qs::from_str(query).map_err(|e| { - // Return the displayable version of the deserialisation error to the caller - // for easier debugging. - crate::Error::from_str(crate::StatusCode::BadRequest, format!("{}", e)) - }) + serde_qs::from_str(query).map_err(Error::QueryDeserialize) } /// Set the URL querystring. @@ -689,8 +686,7 @@ impl Request { /// ``` #[cfg(feature = "serde")] pub fn set_query(&mut self, query: &impl Serialize) -> crate::Result<()> { - let query = serde_qs::to_string(query) - .map_err(|e| crate::Error::from_str(crate::StatusCode::BadRequest, format!("{}", e)))?; + let query = serde_qs::to_string(query).map_err(Error::QuerySerialize)?; self.url.set_query(Some(&query)); Ok(()) } diff --git a/src/security/strict_transport_security.rs b/src/security/strict_transport_security.rs index f6c4fe4c..0f25d763 100644 --- a/src/security/strict_transport_security.rs +++ b/src/security/strict_transport_security.rs @@ -1,5 +1,5 @@ +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers}; -use crate::Status; use crate::headers::STRICT_TRANSPORT_SECURITY; use std::time::Duration; @@ -122,7 +122,11 @@ impl StrictTransportSecurity { }; if key == "max-age" { - let secs = value.parse::().status(400)?; + let secs = value.parse::().map_err(|_| { + HeaderError::StrictTransportSecurityInvalid( + "`max-age` directive not a valid u64", + ) + })?; max_age = Some(Duration::from_secs(secs)); } } @@ -131,10 +135,10 @@ impl StrictTransportSecurity { let max_age = match max_age { Some(max_age) => max_age, None => { - return Err(crate::format_err_status!( - 400, - "`Strict-Transport-Security` header did not contain a `max-age` directive", - )); + return Err(HeaderError::StrictTransportSecurityInvalid( + "did not contain a `max-age` directive", + ) + .into()); } }; @@ -160,10 +164,12 @@ impl From for StrictTransportSecurity { #[cfg(test)] mod test { - use super::*; - use crate::Response; use std::time::Duration; + use crate::{Response, StatusCode}; + + use super::*; + #[test] fn smoke() -> crate::Result<()> { let duration = Duration::from_secs(30); @@ -187,7 +193,7 @@ mod test { .insert_header(STRICT_TRANSPORT_SECURITY, "") .unwrap(); let err = StrictTransportSecurity::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } #[test] @@ -197,7 +203,7 @@ mod test { .insert_header(STRICT_TRANSPORT_SECURITY, "max-age=birds") .unwrap(); let err = StrictTransportSecurity::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } #[test] diff --git a/src/security/timing_allow_origin.rs b/src/security/timing_allow_origin.rs index 443427fb..922e9526 100644 --- a/src/security/timing_allow_origin.rs +++ b/src/security/timing_allow_origin.rs @@ -8,7 +8,7 @@ //! # Examples //! //! ``` -//! # fn main() -> http_types::Result<()> { +//! # fn main() -> anyhow::Result<()> { //! # //! use http_types::{Response, Url, headers::Header}; //! use http_types::security::TimingAllowOrigin; @@ -26,8 +26,9 @@ //! # Ok(()) } //! ``` +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, TIMING_ALLOW_ORIGIN}; -use crate::{Status, Url}; +use crate::Url; use std::fmt::Write; use std::fmt::{self, Debug}; @@ -40,7 +41,7 @@ use std::slice; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::{Response, Url}; /// use http_types::security::TimingAllowOrigin; @@ -91,7 +92,8 @@ impl TimingAllowOrigin { "*" => wildcard = true, r#""null""# => continue, origin => { - let url = Url::parse(origin).status(400)?; + let url = + Url::parse(origin).map_err(HeaderError::TimingAllowOriginInvalidUrl)?; origins.push(url); } } @@ -257,11 +259,12 @@ impl Debug for TimingAllowOrigin { #[cfg(test)] mod test { + use crate::{headers::Headers, StatusCode}; + use super::*; - use crate::headers::Headers; #[test] - fn smoke() -> crate::Result<()> { + fn smoke() -> anyhow::Result<()> { let mut origins = TimingAllowOrigin::new(); origins.push(Url::parse("https://example.com")?); @@ -275,7 +278,7 @@ mod test { } #[test] - fn multi() -> crate::Result<()> { + fn multi() -> anyhow::Result<()> { let mut origins = TimingAllowOrigin::new(); origins.push(Url::parse("https://example.com")?); origins.push(Url::parse("https://mozilla.org/")?); @@ -300,11 +303,11 @@ mod test { .insert(TIMING_ALLOW_ORIGIN, "server; ") .unwrap(); let err = TimingAllowOrigin::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } #[test] - fn wildcard() -> crate::Result<()> { + fn wildcard() -> anyhow::Result<()> { let mut origins = TimingAllowOrigin::new(); origins.push(Url::parse("https://example.com")?); origins.set_wildcard(true); diff --git a/src/status.rs b/src/status.rs index cf7baa1e..58725165 100644 --- a/src/status.rs +++ b/src/status.rs @@ -1,24 +1,24 @@ -use crate::{Error, StatusCode}; use core::convert::{Infallible, TryInto}; use std::error::Error as StdError; -use std::fmt::Debug; + +use crate::{ResponseError, StatusCode}; /// Provides the `status` method for `Result` and `Option`. /// /// This trait is sealed and cannot be implemented outside of `http-types`. pub trait Status: private::Sealed { /// Wrap the error value with an additional status code. - fn status(self, status: S) -> Result + fn status(self, status: S) -> Result where S: TryInto, - S::Error: Debug; + S::Error: StdError + Send + Sync + 'static; /// Wrap the error value with an additional status code that is evaluated /// lazily only once an error does occur. - fn with_status(self, f: F) -> Result + fn with_status(self, f: F) -> Result where S: TryInto, - S::Error: Debug, + S::Error: StdError + Send + Sync + 'static, F: FnOnce() -> S; } @@ -34,61 +34,12 @@ where /// /// [status]: crate::Status /// [statuscode]: crate::StatusCode - fn status(self, status: S) -> Result - where - S: TryInto, - S::Error: Debug, - { - self.map_err(|error| { - let status = status - .try_into() - .expect("Could not convert into a valid `StatusCode`"); - Error::new(status, error) - }) - } - - /// Wrap the error value with an additional status code that is evaluated - /// lazily only once an error does occur. - /// - /// # Panics - /// - /// Panics if [`Status`][status] is not a valid [`StatusCode`][statuscode]. - /// - /// [status]: crate::Status - /// [statuscode]: crate::StatusCode - fn with_status(self, f: F) -> Result - where - S: TryInto, - S::Error: Debug, - F: FnOnce() -> S, - { - self.map_err(|error| { - let status = f() - .try_into() - .expect("Could not convert into a valid `StatusCode`"); - Error::new(status, error) - }) - } -} - -impl Status for Result { - /// Wrap the error value with an additional status code. - /// - /// # Panics - /// - /// Panics if [`Status`][status] is not a valid [`StatusCode`][statuscode]. - /// - /// [status]: crate::Status - /// [statuscode]: crate::StatusCode - fn status(self, status: S) -> Result + fn status(self, status: S) -> Result where S: TryInto, - S::Error: Debug, + S::Error: StdError + Send + Sync + 'static, { - self.map_err(|mut error| { - error.set_status(status); - error - }) + self.map_err(|error| ResponseError::new_status(status, error)) } /// Wrap the error value with an additional status code that is evaluated @@ -100,16 +51,13 @@ impl Status for Result { /// /// [status]: crate::Status /// [statuscode]: crate::StatusCode - fn with_status(self, f: F) -> Result + fn with_status(self, f: F) -> Result where S: TryInto, - S::Error: Debug, + S::Error: StdError + Send + Sync + 'static, F: FnOnce() -> S, { - self.map_err(|mut error| { - error.set_status(f()); - error - }) + self.map_err(|error| ResponseError::new_status(f(), error)) } } @@ -122,17 +70,12 @@ impl Status for Option { /// /// [status]: crate::Status /// [statuscode]: crate::StatusCode - fn status(self, status: S) -> Result + fn status(self, status: S) -> Result where S: TryInto, - S::Error: Debug, + S::Error: StdError + Send + Sync + 'static, { - self.ok_or_else(|| { - let status = status - .try_into() - .expect("Could not convert into a valid `StatusCode`"); - Error::from_str(status, "NoneError") - }) + self.ok_or_else(|| ResponseError::from_str_status(status, "NoneError")) } /// Wrap the error value with an additional status code that is evaluated @@ -144,18 +87,13 @@ impl Status for Option { /// /// [status]: crate::Status /// [statuscode]: crate::StatusCode - fn with_status(self, f: F) -> Result + fn with_status(self, f: F) -> Result where S: TryInto, - S::Error: Debug, + S::Error: StdError + Send + Sync + 'static, F: FnOnce() -> S, { - self.ok_or_else(|| { - let status = f() - .try_into() - .expect("Could not convert into a valid `StatusCode`"); - Error::from_str(status, "NoneError") - }) + self.ok_or_else(|| ResponseError::from_str_status(f(), "NoneError")) } } diff --git a/src/status_code.rs b/src/status_code.rs index a86d8ac6..c7f5bb51 100644 --- a/src/status_code.rs +++ b/src/status_code.rs @@ -1,5 +1,7 @@ use std::fmt::{self, Debug, Display}; +use crate::Error; + /// HTTP response status codes. /// /// As defined by [rfc7231 section 6](https://tools.ietf.org/html/rfc7231#section-6). @@ -531,6 +533,12 @@ impl StatusCode { } } +impl Default for StatusCode { + fn default() -> Self { + Self::InternalServerError + } +} + #[cfg(feature = "serde")] mod serde { use super::StatusCode; @@ -624,7 +632,7 @@ impl From for u16 { } impl std::convert::TryFrom for StatusCode { - type Error = crate::Error; + type Error = Error; fn try_from(num: u16) -> Result { match num { @@ -687,7 +695,7 @@ impl std::convert::TryFrom for StatusCode { 508 => Ok(StatusCode::LoopDetected), 510 => Ok(StatusCode::NotExtended), 511 => Ok(StatusCode::NetworkAuthenticationRequired), - _ => crate::bail!("Invalid status code"), + _ => Err(Error::StatusCodeInvalid(num)), } } } diff --git a/src/trace/server_timing/metric.rs b/src/trace/server_timing/metric.rs index ed55c189..9e23df42 100644 --- a/src/trace/server_timing/metric.rs +++ b/src/trace/server_timing/metric.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use crate::headers::HeaderValue; +use crate::{errors::HeaderError, headers::HeaderValue}; /// An individual entry into `ServerTiming`. // @@ -28,9 +28,15 @@ impl Metric { /// /// An error will be returned if the string values are invalid ASCII. pub fn new(name: String, dur: Option, desc: Option) -> crate::Result { - crate::ensure!(name.is_ascii(), "Name should be valid ASCII"); + internal_ensure!( + name.is_ascii(), + HeaderError::ServerTimingInvalidMetric("Name should be valid ASCII") + ); if let Some(desc) = desc.as_ref() { - crate::ensure!(desc.is_ascii(), "Description should be valid ASCII"); + internal_ensure!( + desc.is_ascii(), + HeaderError::ServerTimingInvalidMetric("Description should be valid ASCII") + ); }; Ok(Self { name, dur, desc }) diff --git a/src/trace/server_timing/mod.rs b/src/trace/server_timing/mod.rs index ddd6c79a..d83315f9 100644 --- a/src/trace/server_timing/mod.rs +++ b/src/trace/server_timing/mod.rs @@ -216,8 +216,10 @@ impl<'a> Iterator for IterMut<'a> { #[cfg(test)] mod test { - use super::*; use crate::headers::Headers; + use crate::StatusCode; + + use super::*; #[test] fn smoke() -> crate::Result<()> { @@ -254,6 +256,6 @@ mod test { .insert(SERVER_TIMING, "server; ") .unwrap(); let err = ServerTiming::from_headers(headers).unwrap_err(); - assert_eq!(err.status(), 400); + assert_eq!(err.associated_status_code(), Some(StatusCode::BadRequest)); } } diff --git a/src/trace/server_timing/parse.rs b/src/trace/server_timing/parse.rs index 4469336a..35e6ecc9 100644 --- a/src/trace/server_timing/parse.rs +++ b/src/trace/server_timing/parse.rs @@ -1,17 +1,14 @@ use std::time::Duration; use super::Metric; -use crate::{ensure, format_err, StatusCode}; +use crate::errors::HeaderError; /// Parse multiple entries from a single header. /// /// Each entry is comma-delimited. pub(super) fn parse_header(s: &str, entries: &mut Vec) -> crate::Result<()> { for part in s.trim().split(',') { - let entry = parse_entry(part).map_err(|mut e| { - e.set_status(StatusCode::BadRequest); - e - })?; + let entry = parse_entry(part)?; entries.push(entry); } Ok(()) @@ -35,7 +32,7 @@ fn parse_entry(s: &str) -> crate::Result { // Get the name. This is non-optional. let name = parts .next() - .ok_or_else(|| format_err!("Server timing headers must include a name"))? + .ok_or(HeaderError::ServerTimingInvalid("must include a name"))? .trim_end(); // We must extract these values from the k-v pairs that follow. @@ -43,9 +40,9 @@ fn parse_entry(s: &str) -> crate::Result { let mut desc = None; for mut part in parts { - ensure!( + internal_ensure!( !part.is_empty(), - "Server timing params cannot end with a trailing `;`" + HeaderError::ServerTimingInvalid("params cannot end with a trailing `;`") ); part = part.trim_start(); @@ -53,33 +50,39 @@ fn parse_entry(s: &str) -> crate::Result { let mut params = part.split('='); let name = params .next() - .ok_or_else(|| format_err!("Server timing params must have a name"))? + .ok_or(HeaderError::ServerTimingInvalid("params must have a name"))? .trim_end(); let mut value = params .next() - .ok_or_else(|| format_err!("Server timing params must have a value"))? + .ok_or(HeaderError::ServerTimingInvalid("params must have a value"))? .trim_start(); match name { "dur" => { let millis: f64 = value.parse().map_err(|_| { - format_err!("Server timing duration params must be a valid double-precision floating-point number.") - })?; + HeaderError::ServerTimingInvalid( + "duration params must be a valid double-precision floating-point number.", + ) + })?; dur = Some(Duration::from_secs_f64(millis) / 1000); } "desc" => { // Ensure quotes line up, and strip them from the resulting output if value.starts_with('"') { value = &value[1..value.len()]; - ensure!( + internal_ensure!( value.ends_with('"'), - "Server timing description params must use matching quotes" + HeaderError::ServerTimingInvalid( + "description params must use matching quotes" + ) ); value = &value[0..value.len() - 1]; } else { - ensure!( + internal_ensure!( !value.ends_with('"'), - "Server timing description params must use matching quotes" + HeaderError::ServerTimingInvalid( + "description params must use matching quotes" + ) ); } desc = Some(value.to_string()); @@ -106,11 +109,11 @@ mod test { assert_entry("Server ", "Server", None, None)?; assert_entry_err( "Server ;", - "Server timing params cannot end with a trailing `;`", + "Header error: Server-Timing header was invalid: params cannot end with a trailing `;`", ); assert_entry_err( "Server; ", - "Server timing params cannot end with a trailing `;`", + "Header error: Server-Timing header was invalid: params cannot end with a trailing `;`", ); // Metric name + param @@ -120,7 +123,7 @@ mod test { assert_entry("Server; dur = 1000", "Server", Some(1000), None)?; assert_entry_err( "Server; dur=1000;", - "Server timing params cannot end with a trailing `;`", + "Header error: Server-Timing header was invalid: params cannot end with a trailing `;`", ); // Metric name + desc @@ -131,11 +134,11 @@ mod test { assert_entry(r#"DB; desc=a_db"#, "DB", None, Some("a_db"))?; assert_entry_err( r#"DB; desc="db"#, - "Server timing description params must use matching quotes", + "Header error: Server-Timing header was invalid: description params must use matching quotes", ); assert_entry_err( "Server; desc=a_db;", - "Server timing params cannot end with a trailing `;`", + "Header error: Server-Timing header was invalid: params cannot end with a trailing `;`", ); // Metric name + dur + desc @@ -147,7 +150,7 @@ mod test { )?; assert_entry_err( r#"Server; dur=1000; desc="a server";"#, - "Server timing params cannot end with a trailing `;`", + "Header error: Server-Timing header was invalid: params cannot end with a trailing `;`", ); Ok(()) } diff --git a/src/trace/trace_context.rs b/src/trace/trace_context.rs index 3bc062f7..5c746e61 100644 --- a/src/trace/trace_context.rs +++ b/src/trace/trace_context.rs @@ -1,7 +1,7 @@ use std::fmt; +use crate::errors::HeaderError; use crate::headers::{Header, HeaderName, HeaderValue, Headers, TRACEPARENT}; -use crate::Status; /// Extract and apply [Trace-Context](https://w3c.github.io/trace-context/) headers. /// @@ -110,10 +110,16 @@ impl TraceContext { Ok(Some(Self { id: fastrand::u64(..), - version: u8::from_str_radix(parts[0], 16)?, - trace_id: u128::from_str_radix(parts[1], 16).status(400)?, - parent_id: Some(u64::from_str_radix(parts[2], 16).status(400)?), - flags: u8::from_str_radix(parts[3], 16).status(400)?, + version: u8::from_str_radix(parts[0], 16) + .map_err(|_| HeaderError::TraceContextInvalid("version"))?, + trace_id: u128::from_str_radix(parts[1], 16) + .map_err(|_| HeaderError::TraceContextInvalid("trace_id"))?, + parent_id: Some( + u64::from_str_radix(parts[2], 16) + .map_err(|_| HeaderError::TraceContextInvalid("parent_id"))?, + ), + flags: u8::from_str_radix(parts[3], 16) + .map_err(|_| HeaderError::TraceContextInvalid("flags"))?, })) } diff --git a/src/transfer/encoding_proposal.rs b/src/transfer/encoding_proposal.rs index 72f083d5..14288d38 100644 --- a/src/transfer/encoding_proposal.rs +++ b/src/transfer/encoding_proposal.rs @@ -1,4 +1,4 @@ -use crate::ensure; +use crate::errors::EncodingError; use crate::headers::HeaderValue; use crate::transfer::Encoding; use crate::utils::parse_weight; @@ -22,11 +22,11 @@ pub struct EncodingProposal { impl EncodingProposal { /// Create a new instance of `EncodingProposal`. - pub fn new(encoding: impl Into, weight: Option) -> crate::Result { + pub fn new(encoding: impl Into, weight: Option) -> Result { if let Some(weight) = weight { - ensure!( + internal_ensure!( weight.is_sign_positive() && weight <= 1.0, - "EncodingProposal should have a weight between 0.0 and 1.0" + EncodingError::Proposal("should have a weight between 0.0 and 1.0") ) } @@ -46,13 +46,17 @@ impl EncodingProposal { self.weight } - pub(crate) fn from_str(s: &str) -> crate::Result> { + pub(crate) fn from_str(s: &str) -> Result, EncodingError> { let mut parts = s.split(';'); let encoding = match Encoding::from_str(parts.next().unwrap()) { Some(encoding) => encoding, None => return Ok(None), }; - let weight = parts.next().map(parse_weight).transpose()?; + let weight = parts + .next() + .map(parse_weight) + .transpose() + .map_err(|_| EncodingError::Proposal("weight not a valid float32"))?; Ok(Some(Self::new(encoding, weight)?)) } @@ -130,16 +134,4 @@ mod test { let _ = EncodingProposal::new(Encoding::Gzip, Some(0.5)).unwrap(); let _ = EncodingProposal::new(Encoding::Gzip, Some(1.0)).unwrap(); } - - #[test] - fn error_code_500() { - let err = EncodingProposal::new(Encoding::Gzip, Some(1.1)).unwrap_err(); - assert_eq!(err.status(), 500); - - let err = EncodingProposal::new(Encoding::Gzip, Some(-0.1)).unwrap_err(); - assert_eq!(err.status(), 500); - - let err = EncodingProposal::new(Encoding::Gzip, Some(-0.0)).unwrap_err(); - assert_eq!(err.status(), 500); - } } diff --git a/src/transfer/te.rs b/src/transfer/te.rs index acd1af80..c3412570 100644 --- a/src/transfer/te.rs +++ b/src/transfer/te.rs @@ -1,7 +1,7 @@ +use crate::errors::HeaderError; use crate::headers::{self, Header, HeaderName, HeaderValue, Headers}; use crate::transfer::{Encoding, EncodingProposal, TransferEncoding}; use crate::utils::sort_by_weight; -use crate::{Error, StatusCode}; use std::fmt::{self, Debug, Write}; @@ -19,7 +19,7 @@ use std::slice; /// # Examples /// /// ``` -/// # fn main() -> http_types::Result<()> { +/// # fn main() -> anyhow::Result<()> { /// # /// use http_types::transfer::{TE, TransferEncoding, Encoding, EncodingProposal}; /// use http_types::Response; @@ -75,7 +75,9 @@ impl TE { // Try and parse a directive from a str. If the directive is // unkown we skip it. - if let Some(entry) = EncodingProposal::from_str(part)? { + if let Some(entry) = EncodingProposal::from_str(part) + .map_err(HeaderError::TransferEncodingInvalidEncoding)? + { entries.push(entry); } } @@ -131,9 +133,7 @@ impl TE { } } - let mut err = Error::new_adhoc("No suitable Transfer-Encoding found"); - err.set_status(StatusCode::NotAcceptable); - Err(err) + Err(HeaderError::TransferEncodingUnnegotiable.into()) } /// An iterator visiting all entries. @@ -281,7 +281,7 @@ impl Debug for TE { mod test { use super::*; use crate::transfer::Encoding; - use crate::Response; + use crate::{Response, StatusCode}; #[test] fn smoke() -> crate::Result<()> { @@ -341,7 +341,7 @@ mod test { } #[test] - fn reorder_based_on_weight() -> crate::Result<()> { + fn reorder_based_on_weight() -> anyhow::Result<()> { let mut accept = TE::new(); accept.push(EncodingProposal::new(Encoding::Gzip, Some(0.4))?); accept.push(EncodingProposal::new(Encoding::Identity, None)?); @@ -360,7 +360,7 @@ mod test { } #[test] - fn reorder_based_on_weight_and_location() -> crate::Result<()> { + fn reorder_based_on_weight_and_location() -> anyhow::Result<()> { let mut accept = TE::new(); accept.push(EncodingProposal::new(Encoding::Identity, None)?); accept.push(EncodingProposal::new(Encoding::Gzip, None)?); @@ -379,7 +379,7 @@ mod test { } #[test] - fn negotiate() -> crate::Result<()> { + fn negotiate() -> anyhow::Result<()> { let mut accept = TE::new(); accept.push(EncodingProposal::new(Encoding::Brotli, Some(0.8))?); accept.push(EncodingProposal::new(Encoding::Gzip, Some(0.4))?); @@ -393,20 +393,26 @@ mod test { } #[test] - fn negotiate_not_acceptable() -> crate::Result<()> { + fn negotiate_not_acceptable() -> anyhow::Result<()> { let mut accept = TE::new(); let err = accept.negotiate(&[Encoding::Gzip]).unwrap_err(); - assert_eq!(err.status(), 406); + assert_eq!( + err.associated_status_code(), + Some(StatusCode::NotAcceptable) + ); let mut accept = TE::new(); accept.push(EncodingProposal::new(Encoding::Brotli, Some(0.8))?); let err = accept.negotiate(&[Encoding::Gzip]).unwrap_err(); - assert_eq!(err.status(), 406); + assert_eq!( + err.associated_status_code(), + Some(StatusCode::NotAcceptable) + ); Ok(()) } #[test] - fn negotiate_wildcard() -> crate::Result<()> { + fn negotiate_wildcard() -> anyhow::Result<()> { let mut accept = TE::new(); accept.push(EncodingProposal::new(Encoding::Brotli, Some(0.8))?); accept.set_wildcard(true); diff --git a/src/utils/date.rs b/src/utils/date.rs index 7f9f382b..00d9b194 100644 --- a/src/utils/date.rs +++ b/src/utils/date.rs @@ -2,8 +2,7 @@ use std::fmt::{self, Display, Formatter}; use std::str::{from_utf8, FromStr}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use crate::StatusCode; -use crate::{bail, ensure, format_err}; +use crate::errors::DateError; const IMF_FIXDATE_LENGTH: usize = 29; const RFC850_MAX_LENGTH: usize = 23; @@ -14,7 +13,7 @@ const SECONDS_IN_DAY: u64 = 86400; const SECONDS_IN_HOUR: u64 = 3600; /// Format using the `Display` trait. -/// Convert timestamp into/from `SytemTime` to use. +/// Convert timestamp into/from `SystemTime` to use. /// Supports comparison and sorting. #[derive(Copy, Clone, Debug, Eq)] pub(crate) struct HttpDate { @@ -39,11 +38,8 @@ pub(crate) struct HttpDate { /// Supports the preferred IMF-fixdate and the legacy RFC 805 and /// ascdate formats. Two digit years are mapped to dates between /// 1970 and 2069. -pub(crate) fn parse_http_date(s: &str) -> crate::Result { - s.parse::().map(|d| d.into()).map_err(|mut e| { - e.set_status(StatusCode::BadRequest); - e - }) +pub(crate) fn parse_http_date(s: &str) -> Result { + s.parse::().map(|d| d.into()) } /// Format a date to be used in a HTTP header field. @@ -69,7 +65,16 @@ impl HttpDate { } } -fn parse_imf_fixdate(s: &[u8]) -> crate::Result { +macro_rules! maybe_parse_number { + ($from:expr, $mapped_err:expr) => { + from_utf8($from) + .map_err(|_| $mapped_err)? + .parse() + .map_err(|_| $mapped_err)? + }; +} + +fn parse_imf_fixdate(s: &[u8]) -> Result { // Example: `Sun, 06 Nov 1994 08:49:37 GMT` if s.len() != IMF_FIXDATE_LENGTH || &s[25..] != b" GMT" @@ -77,13 +82,13 @@ fn parse_imf_fixdate(s: &[u8]) -> crate::Result { || s[19] != b':' || s[22] != b':' { - bail!("Date time not in imf fixdate format"); + return Err(DateError::FormatInvalid("imf fixdate")); } Ok(HttpDate { - second: from_utf8(&s[23..25])?.parse()?, - minute: from_utf8(&s[20..22])?.parse()?, - hour: from_utf8(&s[17..19])?.parse()?, - day: from_utf8(&s[5..7])?.parse()?, + second: maybe_parse_number!(&s[23..25], DateError::SecondsInvalid), + minute: maybe_parse_number!(&s[20..22], DateError::MinutesInvalid), + hour: maybe_parse_number!(&s[17..19], DateError::HourInvalid), + day: maybe_parse_number!(&s[5..7], DateError::DayInvalid), month: match &s[7..12] { b" Jan " => 1, b" Feb " => 2, @@ -97,9 +102,9 @@ fn parse_imf_fixdate(s: &[u8]) -> crate::Result { b" Oct " => 10, b" Nov " => 11, b" Dec " => 12, - _ => bail!("Invalid Month"), + _ => return Err(DateError::MonthInvalid), }, - year: from_utf8(&s[12..16])?.parse()?, + year: maybe_parse_number!(&s[12..16], DateError::YearInvalid), week_day: match &s[..5] { b"Mon, " => 1, b"Tue, " => 2, @@ -108,17 +113,13 @@ fn parse_imf_fixdate(s: &[u8]) -> crate::Result { b"Fri, " => 5, b"Sat, " => 6, b"Sun, " => 7, - _ => bail!("Invalid Day"), + _ => return Err(DateError::WeekdayInvalid), }, }) } -fn parse_rfc850_date(s: &[u8]) -> crate::Result { +fn parse_rfc850_date(s: &[u8]) -> Result { // Example: `Sunday, 06-Nov-94 08:49:37 GMT` - ensure!( - s.len() >= RFC850_MAX_LENGTH, - "Date time not in rfc850 format" - ); fn week_day<'a>(s: &'a [u8], week_day: u8, name: &'static [u8]) -> Option<(u8, &'a [u8])> { if &s[0..name.len()] == name { @@ -133,21 +134,26 @@ fn parse_rfc850_date(s: &[u8]) -> crate::Result { .or_else(|| week_day(s, 5, b"Friday, ")) .or_else(|| week_day(s, 6, b"Saturday, ")) .or_else(|| week_day(s, 7, b"Sunday, ")) - .ok_or_else(|| format_err!("Invalid day"))?; + .ok_or(DateError::FormatInvalid("rfc850"))?; + + if s.len() >= RFC850_MAX_LENGTH { + return Err(DateError::FormatInvalid("rfc850")); + } + if s.len() != 22 || s[12] != b':' || s[15] != b':' || &s[18..22] != b" GMT" { - bail!("Date time not in rfc950 fmt"); + return Err(DateError::FormatInvalid("rfc850")); // XXX(Jeremiah): More detailed error here? } - let mut year = from_utf8(&s[7..9])?.parse::()?; + let mut year = maybe_parse_number!(&s[7..9], DateError::YearInvalid); if year < 70 { year += 2000; } else { year += 1900; } Ok(HttpDate { - second: from_utf8(&s[16..18])?.parse()?, - minute: from_utf8(&s[13..15])?.parse()?, - hour: from_utf8(&s[10..12])?.parse()?, - day: from_utf8(&s[0..2])?.parse()?, + second: maybe_parse_number!(&s[16..18], DateError::SecondsInvalid), + minute: maybe_parse_number!(&s[13..15], DateError::MinutesInvalid), + hour: maybe_parse_number!(&s[10..12], DateError::HourInvalid), + day: maybe_parse_number!(&s[0..2], DateError::DayInvalid), month: match &s[2..7] { b"-Jan-" => 1, b"-Feb-" => 2, @@ -161,26 +167,29 @@ fn parse_rfc850_date(s: &[u8]) -> crate::Result { b"-Oct-" => 10, b"-Nov-" => 11, b"-Dec-" => 12, - _ => bail!("Invalid month"), + _ => return Err(DateError::MonthInvalid), }, year, week_day, }) } -fn parse_asctime(s: &[u8]) -> crate::Result { +fn parse_asctime(s: &[u8]) -> Result { // Example: `Sun Nov 6 08:49:37 1994` if s.len() != ASCTIME_LENGTH || s[10] != b' ' || s[13] != b':' || s[16] != b':' || s[19] != b' ' { - bail!("Date time not in asctime format"); + return Err(DateError::FormatInvalid("asctime")); } Ok(HttpDate { - second: from_utf8(&s[17..19])?.parse()?, - minute: from_utf8(&s[14..16])?.parse()?, - hour: from_utf8(&s[11..13])?.parse()?, + second: maybe_parse_number!(&s[17..19], DateError::SecondsInvalid), + minute: maybe_parse_number!(&s[14..16], DateError::MinutesInvalid), + hour: maybe_parse_number!(&s[11..13], DateError::HourInvalid), day: { let x = &s[8..10]; - from_utf8(if x[0] == b' ' { &x[1..2] } else { x })?.parse()? + maybe_parse_number!( + if x[0] == b' ' { &x[1..2] } else { x }, + DateError::DayInvalid + ) }, month: match &s[4..8] { b"Jan " => 1, @@ -195,9 +204,9 @@ fn parse_asctime(s: &[u8]) -> crate::Result { b"Oct " => 10, b"Nov " => 11, b"Dec " => 12, - _ => bail!("Invalid month"), + _ => return Err(DateError::MonthInvalid), }, - year: from_utf8(&s[20..24])?.parse()?, + year: maybe_parse_number!(&s[20..24], DateError::YearInvalid), week_day: match &s[0..4] { b"Mon " => 1, b"Tue " => 2, @@ -206,7 +215,7 @@ fn parse_asctime(s: &[u8]) -> crate::Result { b"Fri " => 5, b"Sat " => 6, b"Sun " => 7, - _ => bail!("Invalid day"), + _ => return Err(DateError::WeekdayInvalid), }, }) } @@ -329,16 +338,34 @@ impl From for SystemTime { } impl FromStr for HttpDate { - type Err = crate::Error; + type Err = DateError; fn from_str(s: &str) -> Result { - ensure!(s.is_ascii(), "String slice is not valid ASCII"); + internal_ensure!(s.is_ascii(), DateError::NotASCII); let x = s.trim().as_bytes(); - let date = parse_imf_fixdate(x) - .or_else(|_| parse_rfc850_date(x)) - .or_else(|_| parse_asctime(x))?; - ensure!(date.is_valid(), "Invalid date time"); - Ok(date) + + let mut date = None; + for parse_fn in [parse_imf_fixdate, parse_rfc850_date, parse_asctime] { + match parse_fn(x) { + Ok(d) => { + date = Some(d); + break; + } + Err(DateError::FormatInvalid(_)) => continue, + // XXX(Jeremiah): Is this correct or should we just continue to the next parser for _any_ error? + Err(date_err) => return Err(date_err), + } + } + + if let Some(date) = date { + if date.is_valid() { + Ok(date) + } else { + Err(DateError::OutOfBounds) + } + } else { + Err(DateError::Unparseable) + } } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 9c9ea923..4f7877a2 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -4,7 +4,7 @@ pub(crate) use date::fmt_http_date; pub(crate) use date::parse_http_date; pub(crate) use date::HttpDate; -use crate::{Error, Status, StatusCode}; +use crate::errors::HeaderError; use std::cmp::Ordering; use std::str::FromStr; @@ -13,20 +13,14 @@ use std::str::FromStr; pub(crate) fn parse_weight(s: &str) -> crate::Result { let mut parts = s.split('='); if !matches!(parts.next(), Some("q")) { - let mut err = Error::new_adhoc("invalid weight"); - err.set_status(StatusCode::BadRequest); - return Err(err); + return Err(HeaderError::SpecificityInvalid.into()); } match parts.next() { Some(s) => { - let weight = f32::from_str(s).status(400)?; + let weight = f32::from_str(s).map_err(|_| HeaderError::SpecificityInvalid)?; Ok(weight) } - None => { - let mut err = Error::new_adhoc("invalid weight"); - err.set_status(StatusCode::BadRequest); - Err(err) - } + None => Err(HeaderError::SpecificityInvalid.into()), } } diff --git a/tests/error.rs b/tests/error.rs index 4cf97923..22b1dd3a 100644 --- a/tests/error.rs +++ b/tests/error.rs @@ -1,11 +1,11 @@ -use http_types::{bail, ensure, ensure_eq, Error, StatusCode}; +use http_types::{bail, ensure, ensure_eq, Error, ResponseError, StatusCode}; use std::io; #[test] fn can_be_boxed() { fn can_be_boxed() -> Result<(), Box> { let err = io::Error::new(io::ErrorKind::Other, "Oh no"); - Err(Error::new(StatusCode::NotFound, err).into()) + Err(Error::IO(err).into()) } assert!(can_be_boxed().is_err()); } @@ -16,37 +16,37 @@ fn internal_server_error_by_default() { Err(io::Error::new(io::ErrorKind::Other, "Oh no").into()) } let err = run().unwrap_err(); - assert_eq!(err.status(), 500); + assert_eq!(err.associated_status_code(), None); } #[test] fn ensure() { - fn inner() -> http_types::Result<()> { + fn inner() -> http_types::ResponseResult<()> { ensure!(true, "Oh yes"); bail!("Oh no!"); } let res = inner(); assert!(res.is_err()); let err = res.unwrap_err(); - assert_eq!(err.status(), StatusCode::InternalServerError); + assert_eq!(err.status(), None); } #[test] fn ensure_eq() { - fn inner() -> http_types::Result<()> { + fn inner() -> http_types::ResponseResult<()> { ensure_eq!(1, 1, "Oh yes"); bail!("Oh no!"); } let res = inner(); assert!(res.is_err()); let err = res.unwrap_err(); - assert_eq!(err.status(), StatusCode::InternalServerError); + assert_eq!(err.status(), None); } #[test] fn result_ext() { use http_types::Status; - fn run() -> http_types::Result<()> { + fn run() -> http_types::ResponseResult<()> { let err = io::Error::new(io::ErrorKind::Other, "Oh no"); Err(err).status(StatusCode::NotFound)?; Ok(()) @@ -55,72 +55,72 @@ fn result_ext() { assert!(res.is_err()); let err = res.unwrap_err(); - assert_eq!(err.status(), StatusCode::NotFound); + assert_eq!(err.status(), Some(StatusCode::NotFound)); } #[test] fn option_ext() { use http_types::Status; - fn run() -> http_types::Result<()> { + fn run() -> http_types::ResponseResult<()> { None.status(StatusCode::NotFound) } let res = run(); assert!(res.is_err()); let err = res.unwrap_err(); - assert_eq!(err.status(), StatusCode::NotFound); + assert_eq!(err.status(), Some(StatusCode::NotFound)); } #[test] -fn anyhow_error_into_http_types_error() { +fn anyhow_error_into_http_types_response_error() { let anyhow_error = anyhow::Error::new(std::io::Error::new(std::io::ErrorKind::Other, "irrelevant")); - let http_types_error: Error = anyhow_error.into(); - assert_eq!(http_types_error.status(), StatusCode::InternalServerError); + let http_types_error: ResponseError = anyhow_error.into(); + assert_eq!(http_types_error.status(), None); let anyhow_error = anyhow::Error::new(std::io::Error::new(std::io::ErrorKind::Other, "irrelevant")); - let http_types_error: Error = Error::new(StatusCode::ImATeapot, anyhow_error); - assert_eq!(http_types_error.status(), StatusCode::ImATeapot); + let http_types_error: ResponseError = + ResponseError::new_status(StatusCode::ImATeapot, anyhow_error); + assert_eq!(http_types_error.status(), Some(StatusCode::ImATeapot)); } #[test] -fn normal_error_into_http_types_error() { - let http_types_error: Error = +fn normal_error_into_http_types_response_error() { + let http_types_error: ResponseError = std::io::Error::new(std::io::ErrorKind::Other, "irrelevant").into(); - assert_eq!(http_types_error.status(), StatusCode::InternalServerError); + assert_eq!(http_types_error.status(), None); - let http_types_error = Error::new( + let http_types_error = ResponseError::new_status( StatusCode::ImATeapot, std::io::Error::new(std::io::ErrorKind::Other, "irrelevant"), ); - assert_eq!(http_types_error.status(), StatusCode::ImATeapot); + assert_eq!(http_types_error.status(), Some(StatusCode::ImATeapot)); } #[test] fn u16_into_status_code_in_http_types_error() { - let http_types_error = Error::new(404, io::Error::new(io::ErrorKind::Other, "Not Found")); - let http_types_error2 = Error::new( - StatusCode::NotFound, - io::Error::new(io::ErrorKind::Other, "Not Found"), - ); - assert_eq!(http_types_error.status(), http_types_error2.status()); - - let http_types_error = Error::from_str(404, "Not Found"); - assert_eq!(http_types_error.status(), StatusCode::NotFound); + let http_types_error = ResponseError::from_str_status(404, "Not Found"); + assert_eq!(http_types_error.status(), Some(StatusCode::NotFound)); } #[test] -#[should_panic] +#[should_panic = "Could not convert into a valid `StatusCode`"] fn fail_test_u16_into_status_code_in_http_types_error_new() { - let _http_types_error = Error::new( - 1000, - io::Error::new(io::ErrorKind::Other, "Incorrect status code"), - ); + panic!( + "{}", + ResponseError::from_str_status( + 1000, + io::Error::new(io::ErrorKind::Other, "Incorrect status code"), + ) + ) } #[test] -#[should_panic] +#[should_panic = "Could not convert into a valid `StatusCode`"] fn fail_test_u16_into_status_code_in_http_types_error_from_str() { - let _http_types_error = Error::from_str(1000, "Incorrect status code"); + panic!( + "{}", + ResponseError::from_str_status(1000, "Incorrect status code",) + ) } diff --git a/tests/querystring.rs b/tests/querystring.rs index d0e3c8cf..61a9e0b0 100644 --- a/tests/querystring.rs +++ b/tests/querystring.rs @@ -32,7 +32,10 @@ fn unsuccessfully_deserialize_query() { let params = req.query::(); assert!(params.is_err()); - assert_eq!(params.err().unwrap().to_string(), "missing field `msg`"); + assert_eq!( + params.err().unwrap().to_string(), + "Querystring deserialization error: missing field `msg`" + ); } #[test] @@ -44,7 +47,10 @@ fn malformatted_query() { let params = req.query::(); assert!(params.is_err()); - assert_eq!(params.err().unwrap().to_string(), "missing field `msg`"); + assert_eq!( + params.err().unwrap().to_string(), + "Querystring deserialization error: missing field `msg`" + ); } #[test]