diff --git a/daphne/src/error/aborts.rs b/daphne/src/error/aborts.rs
index a7ef1b5f9..83b3f6dee 100644
--- a/daphne/src/error/aborts.rs
+++ b/daphne/src/error/aborts.rs
@@ -35,10 +35,6 @@ pub enum DapAbort {
     #[error("batchOverlap")]
     BatchOverlap { detail: String, task_id: TaskId },
 
-    /// Internal error.
-    #[error("internal error")]
-    Internal(#[source] Box<dyn std::error::Error + 'static + Send + Sync>),
-
     /// Invalid batch size (either too small or too large). Sent in response to a CollectReq or
     /// AggregateShareReq.
     #[error("invalidBatchSize")]
@@ -143,7 +139,7 @@ impl DapAbort {
                 Some(agg_job_id_base64url),
             ),
             Self::InvalidMessage { detail, task_id } => (task_id, Some(detail), None),
-            Self::ReportTooLate | Self::UnrecognizedTask | Self::Internal(_) => (None, None, None),
+            Self::ReportTooLate | Self::UnrecognizedTask => (None, None, None),
         };
 
         ProblemDetails {
@@ -207,7 +203,7 @@ impl DapAbort {
     }
 
     #[inline]
-    pub fn report_rejected(failure_reason: TransitionFailure) -> Self {
+    pub fn report_rejected(failure_reason: TransitionFailure) -> Result<Self, DapError> {
         let detail = match failure_reason {
             TransitionFailure::BatchCollected => {
                 "The report pertains to a batch that has already been collected."
@@ -215,15 +211,17 @@ impl DapAbort {
             TransitionFailure::ReportReplayed => {
                 "A report with the same ID was uploaded previously."
             }
-            _ => return fatal_error!(
-                err = "Attempted to construct a \"reportRejected\" abort with unexpected transition failure",
-                unexpected_transition_failure = ?failure_reason,
-            ).into(),
+            _ => {
+                return Err(fatal_error!(
+                    err = "Attempted to construct a \"reportRejected\" abort with unexpected transition failure",
+                    unexpected_transition_failure = ?failure_reason,
+                ))
+            }
         };
 
-        Self::ReportRejected {
+        Ok(Self::ReportRejected {
             detail: detail.into(),
-        }
+        })
     }
 
     fn title_and_type(&self) -> (String, Option<String>) {
@@ -267,7 +265,6 @@ impl DapAbort {
                 Some(self.to_string()),
             ),
             Self::BadRequest(..) => ("Bad request", None),
-            Self::Internal(..) => ("Internal server error", None),
         };
 
         (
@@ -277,16 +274,6 @@ impl DapAbort {
     }
 }
 
-impl From<DapError> for DapAbort {
-    fn from(e: DapError) -> Self {
-        match e {
-            e @ DapError::Fatal(..) => Self::Internal(Box::new(e)),
-            DapError::Abort(abort) => abort,
-            DapError::Transition(failure_reason) => Self::report_rejected(failure_reason),
-        }
-    }
-}
-
 impl DapAbort {
     pub fn from_codec_error<Id: Into<Option<TaskId>>>(e: CodecError, task_id: Id) -> Self {
         Self::InvalidMessage {
diff --git a/daphne/src/error/mod.rs b/daphne/src/error/mod.rs
index be9a644ff..16bbf91e1 100644
--- a/daphne/src/error/mod.rs
+++ b/daphne/src/error/mod.rs
@@ -8,6 +8,8 @@ use std::fmt::{Debug, Display};
 use crate::{messages::TransitionFailure, vdaf::VdafError};
 pub use aborts::DapAbort;
 
+use self::aborts::ProblemDetails;
+
 /// DAP errors.
 #[derive(Debug, thiserror::Error)]
 pub enum DapError {
@@ -28,6 +30,23 @@ pub enum DapError {
     Transition(#[from] TransitionFailure),
 }
 
+impl DapError {
+    pub fn into_problem_details(self) -> ProblemDetails {
+        if let Self::Abort(a) = self {
+            return a.into_problem_details();
+        }
+
+        ProblemDetails {
+            typ: None,
+            title: "Internal server error".into(),
+            agg_job_id: None,
+            task_id: None,
+            instance: None,
+            detail: None,
+        }
+    }
+}
+
 impl FatalDapError {
     #[doc(hidden)]
     pub fn __use_the_macro(s: String) -> Self {
diff --git a/daphne/src/roles/aggregator.rs b/daphne/src/roles/aggregator.rs
index d99cac452..4c6f1a0ae 100644
--- a/daphne/src/roles/aggregator.rs
+++ b/daphne/src/roles/aggregator.rs
@@ -142,14 +142,14 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
     ) -> Result<(), DapError>;
 
     /// Handle request for the Aggregator's HPKE configuration.
-    async fn handle_hpke_config_req(&self, req: &DapRequest<S>) -> Result<HpkeConfig, DapAbort> {
+    async fn handle_hpke_config_req(&self, req: &DapRequest<S>) -> Result<HpkeConfig, DapError> {
         let metrics = self.metrics();
 
         // Parse the task ID from the query string, ensuring that it is the only query parameter.
         let mut id = None;
         for (k, v) in req.url.query_pairs() {
             if k != "task_id" {
-                return Err(DapAbort::BadRequest("unexpected query parameter".into()));
+                return Err(DapAbort::BadRequest("unexpected query parameter".into()).into());
             }
 
             let bytes = decode_base64url(v.as_bytes()).ok_or(DapAbort::BadRequest(
@@ -169,10 +169,9 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
 
             // Check whether the DAP version in the request matches the task config.
             if task_config.as_ref().version != req.version {
-                return Err(DapAbort::version_mismatch(
-                    req.version,
-                    task_config.as_ref().version,
-                ));
+                return Err(
+                    DapAbort::version_mismatch(req.version, task_config.as_ref().version).into(),
+                );
             }
         }
 
diff --git a/daphne/src/roles/helper.rs b/daphne/src/roles/helper.rs
index 797e6740a..6d39a26aa 100644
--- a/daphne/src/roles/helper.rs
+++ b/daphne/src/roles/helper.rs
@@ -50,7 +50,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
         req: &'req DapRequest<S>,
         metrics: &DaphneMetrics,
         task_id: &TaskId,
-    ) -> Result<DapResponse, DapAbort> {
+    ) -> Result<DapResponse, DapError> {
         let agg_job_init_req =
             AggregationJobInitReq::get_decoded_with_param(&req.version, &req.payload)
                 .map_err(|e| DapAbort::from_codec_error(e, *task_id))?;
@@ -88,7 +88,8 @@ pub trait DapHelper<S>: DapAggregator<S> {
                         detail: "some reports include the taskprov extensions and some do not"
                             .to_string(),
                         task_id: Some(*task_id),
-                    });
+                    }
+                    .into());
                 }
             };
             resolve_taskprov(self, task_id, req, first_metadata).await?;
@@ -105,14 +106,15 @@ pub trait DapHelper<S>: DapAggregator<S> {
             return Err(DapAbort::UnauthorizedRequest {
                 detail: reason,
                 task_id: *task_id,
-            });
+            }
+            .into());
         }
 
         let agg_job_id = resolve_agg_job_id(req, agg_job_init_req.draft02_agg_job_id.as_ref())?;
 
         // Check whether the DAP version in the request matches the task config.
         if task_config.version != req.version {
-            return Err(DapAbort::version_mismatch(req.version, task_config.version));
+            return Err(DapAbort::version_mismatch(req.version, task_config.version).into());
         }
 
         // Ensure we know which batch the request pertains to.
@@ -140,7 +142,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
                     metrics,
                 )?
                 else {
-                    return Err(DapAbort::from(fatal_error!(err = "unexpected transition")));
+                    return Err(fatal_error!(err = "unexpected transition"));
                 };
 
                 if !self
@@ -150,7 +152,8 @@ pub trait DapHelper<S>: DapAggregator<S> {
                     // TODO spec: Consider an explicit abort for this case.
                     return Err(DapAbort::BadRequest(
                         "unexpected message for aggregation job (already exists)".into(),
-                    ));
+                    )
+                    .into());
                 }
                 metrics.agg_job_started_inc();
                 agg_job_resp
@@ -173,9 +176,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
                     metrics,
                 )?
                 else {
-                    return Err(DapAbort::from(fatal_error!(
-                        err = "unexpected transition"
-                    )));
+                    return Err(fatal_error!(err = "unexpected transition"));
                 };
                 Ok((agg_span, agg_job_resp))
             },
@@ -209,7 +210,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
         req: &'req DapRequest<S>,
         metrics: &DaphneMetrics,
         task_id: &TaskId,
-    ) -> Result<DapResponse, DapAbort> {
+    ) -> Result<DapResponse, DapError> {
         if self.get_global_config().allow_taskprov {
             resolve_taskprov(self, task_id, req, None).await?;
         }
@@ -224,12 +225,13 @@ pub trait DapHelper<S>: DapAggregator<S> {
             return Err(DapAbort::UnauthorizedRequest {
                 detail: reason,
                 task_id: *task_id,
-            });
+            }
+            .into());
         }
 
         // Check whether the DAP version in the request matches the task config.
         if task_config.version != req.version {
-            return Err(DapAbort::version_mismatch(req.version, task_config.version));
+            return Err(DapAbort::version_mismatch(req.version, task_config.version).into());
         }
 
         let agg_job_cont_req =
@@ -290,7 +292,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
     }
 
     /// Handle a request pertaining to an aggregation job.
-    async fn handle_agg_job_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapAbort> {
+    async fn handle_agg_job_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapError> {
         let metrics = self.metrics();
         let task_id = req.task_id()?;
 
@@ -302,13 +304,13 @@ pub trait DapHelper<S>: DapAggregator<S> {
                 self.handle_agg_job_cont_req(req, metrics, task_id).await
             }
             //TODO spec: Specify this behavior.
-            _ => Err(DapAbort::BadRequest("unexpected media type".into())),
+            _ => Err(DapAbort::BadRequest("unexpected media type".into()).into()),
         }
     }
 
     /// Handle a request for an aggregate share. This is called by the Leader to complete a
     /// collection job.
-    async fn handle_agg_share_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapAbort> {
+    async fn handle_agg_share_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapError> {
         let now = self.get_current_time();
         let metrics = self.metrics();
         let task_id = req.task_id()?;
@@ -330,12 +332,13 @@ pub trait DapHelper<S>: DapAggregator<S> {
             return Err(DapAbort::UnauthorizedRequest {
                 detail: reason,
                 task_id: *task_id,
-            });
+            }
+            .into());
         }
 
         // Check whether the DAP version in the request matches the task config.
         if task_config.version != req.version {
-            return Err(DapAbort::version_mismatch(req.version, task_config.version));
+            return Err(DapAbort::version_mismatch(req.version, task_config.version).into());
         }
 
         let agg_share_req = AggregateShareReq::get_decoded_with_param(&req.version, &req.payload)
@@ -368,7 +371,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
                     agg_share.report_count,
                     hex::encode(agg_share.checksum)),
                 task_id: *task_id,
-            });
+            }.into());
         }
 
         // Check the batch size.
@@ -382,7 +385,8 @@ pub trait DapHelper<S>: DapAggregator<S> {
                     agg_share.report_count, task_config.min_batch_size
                 ),
                 task_id: *task_id,
-            });
+            }
+            .into());
         }
 
         // Mark each aggregated report as collected.
@@ -468,9 +472,9 @@ async fn finish_agg_job_and_aggregate(
         &HashMap<ReportId, TransitionFailure>,
     ) -> Result<
         (DapAggregateSpan<DapAggregateShare>, AggregationJobResp),
-        DapAbort,
+        DapError,
     >,
-) -> Result<AggregationJobResp, DapAbort> {
+) -> Result<AggregationJobResp, DapError> {
     // This loop is intended to run at most once on the "happy path". The intent is as follows:
    //
     // - try to aggregate the output shares into an `DapAggregateShareSpan`
@@ -522,7 +526,7 @@ async fn finish_agg_job_and_aggregate(
                 // and if this error doesn't manifest itself all reports will be successfully
                 // aggregated. Which means that no reports will be lost in a such a state that
                 // they can never be aggregated.
-                (Err(e), _) => return Err(e.into()),
+                (Err(e), _) => return Err(e),
             }
         }
         if !inc_restart_metric.is_completed() {
@@ -541,9 +545,7 @@ async fn finish_agg_job_and_aggregate(
 
     // We need to prevent an attacker from keeping this loop running for too long, potentially
     // enabling an DOS attack.
-    Err(DapAbort::BadRequest(
-        "AggregationJobContinueReq contained too many replays".into(),
-    ))
+    Err(DapAbort::BadRequest("AggregationJobContinueReq contained too many replays".into()).into())
 }
 
 #[cfg(test)]
diff --git a/daphne/src/roles/leader.rs b/daphne/src/roles/leader.rs
index a2e478c3e..2625eed49 100644
--- a/daphne/src/roles/leader.rs
+++ b/daphne/src/roles/leader.rs
@@ -146,7 +146,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
     async fn send_http_put(&self, req: DapRequest<S>) -> Result<DapResponse, DapError>;
 
     /// Handle a report from a Client.
-    async fn handle_upload_req(&self, req: &DapRequest<S>) -> Result<(), DapAbort> {
+    async fn handle_upload_req(&self, req: &DapRequest<S>) -> Result<(), DapError> {
         let metrics = self.metrics();
         let task_id = req.task_id()?;
         debug!("upload for task {task_id}");
@@ -167,10 +167,9 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
 
         // Check whether the DAP version in the request matches the task config.
         if task_config.as_ref().version != req.version {
-            return Err(DapAbort::version_mismatch(
-                req.version,
-                task_config.as_ref().version,
-            ));
+            return Err(
+                DapAbort::version_mismatch(req.version, task_config.as_ref().version).into(),
+            );
         }
 
         if report.encrypted_input_shares.len() != 2 {
@@ -180,7 +179,8 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
                     report.encrypted_input_shares.len()
                 ),
                 task_id: Some(*task_id),
-            });
+            }
+            .into());
         }
 
         // Check that the indicated HpkeConfig is present.
@@ -192,12 +192,13 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
         {
             return Err(DapAbort::ReportRejected {
                 detail: "No current HPKE configuration matches the indicated ID.".into(),
-            });
+            }
+            .into());
         }
 
         // Check that the task has not expired.
         if report.report_metadata.time >= task_config.as_ref().expiration {
-            return Err(DapAbort::ReportTooLate);
+            return Err(DapAbort::ReportTooLate.into());
         }
 
         // Store the report for future processing. At this point, the report may be rejected if
@@ -211,7 +212,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
 
     /// Handle a collect job from the Collector. The response is the URI that the Collector will
     /// poll later on to get the collection.
-    async fn handle_collect_job_req(&self, req: &DapRequest<S>) -> Result<Url, DapAbort> {
+    async fn handle_collect_job_req(&self, req: &DapRequest<S>) -> Result<Url, DapError> {
         let now = self.get_current_time();
         let metrics = self.metrics();
         let task_id = req.task_id()?;
@@ -234,7 +235,8 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
             return Err(DapAbort::UnauthorizedRequest {
                 detail: reason,
                 task_id: *task_id,
-            });
+            }
+            .into());
         }
 
         let mut collect_req =
@@ -243,7 +245,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
 
         // Check whether the DAP version in the request matches the task config.
         if task_config.version != req.version {
-            return Err(DapAbort::version_mismatch(req.version, task_config.version));
+            return Err(DapAbort::version_mismatch(req.version, task_config.version).into());
         }
 
         if collect_req.query == Query::FixedSizeCurrentBatch {
@@ -283,7 +285,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
                 Some(*collect_job_id)
             }
             (DapVersion::Draft07, DapResource::Undefined) => {
-                return Err(DapAbort::BadRequest("undefined resource".into()));
+                return Err(DapAbort::BadRequest("undefined resource".into()).into());
             }
             _ => unreachable!("unhandled resource {:?}", req.resource),
         };
@@ -309,7 +311,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
         task_config: &DapTaskConfig,
         part_batch_sel: &PartialBatchSelector,
         reports: Vec<Report>,
-    ) -> Result<u64, DapAbort> {
+    ) -> Result<u64, DapError> {
         let metrics = self.metrics();
 
         // Prepare AggregationJobInitReq.
@@ -338,7 +340,9 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
             }
             DapLeaderAggregationJobTransition::Finished(..)
             | DapLeaderAggregationJobTransition::Uncommitted(..) => {
-                return Err(fatal_error!(err = "unexpected state transition (uncommitted)").into())
+                return Err(fatal_error!(
+                    err = "unexpected state transition (uncommitted)"
+                ))
             }
         };
 
         let method = if task_config.version != DapVersion::Draft02 {
@@ -421,7 +425,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
                 }
             }
             DapLeaderAggregationJobTransition::Continued(..) => {
-                return Err(fatal_error!(err = "unexpected state transition (continue)").into())
+                return Err(fatal_error!(err = "unexpected state transition (continue)"))
             }
         };
 
@@ -460,7 +464,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
         collect_id: &CollectionJobId,
         task_config: &DapTaskConfig,
         collect_req: &CollectionReq,
-    ) -> Result<u64, DapAbort> {
+    ) -> Result<u64, DapError> {
         let metrics = self.metrics();
         debug!("collecting id {collect_id}");
@@ -567,7 +571,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
         &self,
         selector: &Self::ReportSelector,
         host: &str,
-    ) -> Result<DapLeaderProcessTelemetry, DapAbort> {
+    ) -> Result<DapLeaderProcessTelemetry, DapError> {
         let mut telem = DapLeaderProcessTelemetry::default();
 
         tracing::debug!("RUNNING get_reports");
diff --git a/daphne/src/roles/mod.rs b/daphne/src/roles/mod.rs
index 95109e1c2..84cfa030d 100644
--- a/daphne/src/roles/mod.rs
+++ b/daphne/src/roles/mod.rs
@@ -25,7 +25,7 @@ async fn check_batch(
     batch_sel: &BatchSelector,
     agg_param: &[u8],
     now: Time,
-) -> Result<(), DapAbort> {
+) -> Result<(), DapError> {
     let global_config = agg.get_global_config();
     let batch_overlapping = agg.is_batch_overlapping(task_id, batch_sel);
@@ -35,7 +35,8 @@ async fn check_batch(
         return Err(DapAbort::InvalidMessage {
             detail: "invalid aggregation parameter".into(),
             task_id: Some(*task_id),
-        });
+        }
+        .into());
     }
 
     // Check that the batch boundaries are valid.
@@ -48,23 +49,23 @@ async fn check_batch(
                 return Err(DapAbort::BatchInvalid {
                     detail: format!("The queried batch interval ({batch_interval:?}) is too small or its boundaries are misaligned. The time precision for this task is {}s.", task_config.time_precision),
                     task_id: *task_id,
-                });
+                }.into());
             }
 
             if batch_interval.duration > global_config.max_batch_duration {
-                return Err(DapAbort::BadRequest("batch interval too large".to_string()));
+                return Err(DapAbort::BadRequest("batch interval too large".to_string()).into());
             }
 
             if now.abs_diff(batch_interval.start) > global_config.min_batch_interval_start {
-                return Err(DapAbort::BadRequest(
-                    "batch interval too far into past".to_string(),
-                ));
+                return Err(
+                    DapAbort::BadRequest("batch interval too far into past".to_string()).into(),
+                );
             }
 
             if now.abs_diff(batch_interval.end()) > global_config.max_batch_interval_end {
-                return Err(DapAbort::BadRequest(
-                    "batch interval too far into future".to_string(),
-                ));
+                return Err(
+                    DapAbort::BadRequest("batch interval too far into future".to_string()).into(),
+                );
             }
         }
         (DapQueryConfig::FixedSize { .. }, BatchSelector::FixedSizeByBatchId { batch_id }) => {
@@ -81,21 +82,16 @@ async fn check_batch(
                         batch_id.to_base64url()
                     ),
                     task_id: *task_id,
-                });
+                }
+                .into());
             }
         }
-        _ => {
-            return Err(DapAbort::query_mismatch(
-                task_id,
-                &task_config.query,
-                batch_sel,
-            ))
-        }
+        _ => return Err(DapAbort::query_mismatch(task_id, &task_config.query, batch_sel).into()),
     };
 
     // Check that the batch does not overlap with any previously collected batch.
     if batch_overlapping.await? {
-        return Err(DapAbort::batch_overlap(task_id, batch_sel));
+        return Err(DapAbort::batch_overlap(task_id, batch_sel).into());
     }
 
     Ok(())
@@ -174,7 +170,7 @@ mod test {
         test_versions,
         testing::{AggStore, MockAggregator, MockAggregatorReportSelector},
         vdaf::VdafVerifyKey,
-        DapAbort, DapAggregateShare, DapBatchBucket, DapCollectJob, DapGlobalConfig,
+        DapAbort, DapAggregateShare, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig,
         DapLeaderAggregationJobTransition, DapMeasurement, DapQueryConfig, DapRequest,
         DapResource, DapTaskConfig, DapVersion, MetaAggregationJobId, Prio3Config, VdafConfig,
     };
@@ -646,7 +642,7 @@ mod test {
             .unwrap()
     }
 
-    pub async fn run_agg_job(&self, task_id: &TaskId) -> Result<(), DapAbort> {
+    pub async fn run_agg_job(&self, task_id: &TaskId) -> Result<(), DapError> {
         let wrapped = self.leader.get_task_config_for(task_id).await.unwrap();
         let task_config = wrapped.as_ref().unwrap();
 
@@ -661,7 +657,7 @@ mod test {
         Ok(())
     }
 
-    pub async fn run_col_job(&self, task_id: &TaskId, query: &Query) -> Result<(), DapAbort> {
+    pub async fn run_col_job(&self, task_id: &TaskId, query: &Query) -> Result<(), DapError> {
         let wrapped = self.leader.get_task_config_for(task_id).await.unwrap();
         let task_config = wrapped.as_ref().unwrap();
 
@@ -781,7 +777,7 @@ mod test {
             .await;
         assert_matches!(
             t.helper.handle_agg_job_req(&req).await.unwrap_err(),
-            DapAbort::QueryMismatch { .. }
+            DapError::Abort(DapAbort::QueryMismatch { .. })
         );
 
         assert_eq!(t.helper.audit_log.invocations(), 0);
@@ -830,14 +826,14 @@ mod test {
         // Expect failure due to missing bearer token.
         assert_matches!(
             t.helper.handle_agg_job_req(&req).await,
-            Err(DapAbort::UnauthorizedRequest { .. })
+            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
        );
 
         // Expect failure due to incorrect bearer token.
         req.sender_auth = Some(BearerToken::from("incorrect auth token!".to_string()));
         assert_matches!(
             t.helper.handle_agg_job_req(&req).await,
-            Err(DapAbort::UnauthorizedRequest { .. })
+            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
         );
 
         assert_eq!(t.helper.audit_log.invocations(), 0);
@@ -865,7 +861,7 @@ mod test {
 
         assert_matches!(
             t.leader.handle_hpke_config_req(&req).await,
-            Err(DapAbort::UnrecognizedTask)
+            Err(DapError::Abort(DapAbort::UnrecognizedTask))
         );
     }
 
@@ -888,7 +884,7 @@ mod test {
         // used for all tasks.
         assert_matches!(
             t.leader.handle_hpke_config_req(&req).await,
-            Err(DapAbort::MissingTaskId)
+            Err(DapError::Abort(DapAbort::MissingTaskId))
         );
     }
 
@@ -905,14 +901,14 @@ mod test {
         // Expect failure due to missing bearer token.
         assert_matches!(
             t.helper.handle_agg_job_req(&req).await,
-            Err(DapAbort::UnauthorizedRequest { .. })
+            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
         );
 
         // Expect failure due to incorrect bearer token.
         req.sender_auth = Some(BearerToken::from("incorrect auth token!".to_string()));
         assert_matches!(
             t.helper.handle_agg_job_req(&req).await,
-            Err(DapAbort::UnauthorizedRequest { .. })
+            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
         );
 
         assert_eq!(t.helper.audit_log.invocations(), 0);
@@ -928,14 +924,14 @@ mod test {
         // Expect failure due to missing bearer token.
         assert_matches!(
             t.helper.handle_agg_share_req(&req).await,
-            Err(DapAbort::UnauthorizedRequest { .. })
+            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
         );
 
         // Expect failure due to incorrect bearer token.
         req.sender_auth = Some(BearerToken::from("incorrect auth token!".to_string()));
         assert_matches!(
             t.helper.handle_agg_share_req(&req).await,
-            Err(DapAbort::UnauthorizedRequest { .. })
+            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
         );
     }
 
@@ -971,7 +967,7 @@ mod test {
             .await;
         assert_matches!(
             t.helper.handle_agg_share_req(&req).await.unwrap_err(),
-            DapAbort::QueryMismatch { .. }
+            DapError::Abort(DapAbort::QueryMismatch { .. })
         );
 
         // Leader sends aggregate share request for unrecognized batch ID.
@@ -999,7 +995,7 @@ mod test {
             .await;
         assert_matches!(
             t.helper.handle_agg_share_req(&req).await.unwrap_err(),
-            DapAbort::BatchInvalid { .. }
+            DapError::Abort(DapAbort::BatchInvalid { .. })
         );
     }
 
@@ -1042,14 +1038,14 @@ mod test {
         // Expect failure due to missing bearer token.
         assert_matches!(
             t.leader.handle_collect_job_req(&req).await,
-            Err(DapAbort::UnauthorizedRequest { .. })
+            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
         );
 
         // Expect failure due to incorrect bearer token.
         req.sender_auth = Some(BearerToken::from("incorrect auth token!".to_string()));
         assert_matches!(
             t.leader.handle_collect_job_req(&req).await,
-            Err(DapAbort::UnauthorizedRequest { .. })
+            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
         );
     }
 
@@ -1222,7 +1218,7 @@ mod test {
         assert_eq!(t.helper.audit_log.invocations(), 1);
 
         // Expect failure due to overwriting existing helper state.
-        assert_matches!(err, DapAbort::BadRequest(e) =>
+        assert_matches!(err, DapError::Abort(DapAbort::BadRequest(e)) =>
             assert_eq!(e, "unexpected message for aggregation job (already exists)")
         );
     }
@@ -1240,7 +1236,10 @@ mod test {
         let err = t.helper.handle_agg_job_req(&req).await.unwrap_err();
 
         // Expect failure due to sending continue request before initialization request.
-        assert_matches!(err, DapAbort::UnrecognizedAggregationJob { .. });
+        assert_matches!(
+            err,
+            DapError::Abort(DapAbort::UnrecognizedAggregationJob { .. })
+        );
     }
 
     async_test_versions! { handle_agg_job_req_fail_send_cont_req }
@@ -1265,7 +1264,7 @@ mod test {
 
         // Expect failure due to invalid task ID in report.
         assert_matches!(
             t.leader.handle_upload_req(&req).await,
-            Err(DapAbort::UnrecognizedTask)
+            Err(DapError::Abort(DapAbort::UnrecognizedTask))
         );
     }
 
@@ -1290,7 +1289,7 @@ mod test {
 
         assert_matches!(
             t.leader.handle_upload_req(&req).await.unwrap_err(),
-            DapAbort::ReportTooLate
+            DapError::Abort(DapAbort::ReportTooLate)
         );
     }
 
@@ -1441,7 +1440,7 @@ mod test {
         let err = t.leader.handle_collect_job_req(&req).await.unwrap_err();
 
         // Fails because the requested batch interval is too large.
-        assert_matches!(err, DapAbort::BadRequest(s) => assert_eq!(s, "batch interval too large".to_string()));
+        assert_matches!(err, DapError::Abort(DapAbort::BadRequest(s)) => assert_eq!(s, "batch interval too large".to_string()));
 
         // Collector: Create a CollectReq with a batch interval in the past.
         let req = t.collector_authorized_req(
@@ -1467,7 +1466,7 @@ mod test {
         let err = t.leader.handle_collect_job_req(&req).await.unwrap_err();
 
         // Fails because the requested batch interval is too far into the past.
-        assert_matches!(err, DapAbort::BadRequest(s) => assert_eq!(s, "batch interval too far into past".to_string()));
+        assert_matches!(err, DapError::Abort(DapAbort::BadRequest(s)) => assert_eq!(s, "batch interval too far into past".to_string()));
 
         // Collector: Create a CollectReq with a batch interval in the future.
         let req = t.collector_authorized_req(
@@ -1493,7 +1492,7 @@ mod test {
         let err = t.leader.handle_collect_job_req(&req).await.unwrap_err();
 
         // Fails because the requested batch interval is too far into the future.
-        assert_matches!(err, DapAbort::BadRequest(s) => assert_eq!(s, "batch interval too far into future".to_string()));
+        assert_matches!(err, DapError::Abort(DapAbort::BadRequest(s)) => assert_eq!(s, "batch interval too far into future".to_string()));
     }
 
     async_test_versions! { handle_collect_job_req_fail_invalid_batch_interval }
@@ -1551,7 +1550,7 @@ mod test {
         // run a second collect job (expect failure due to overlapping batch).
         assert_matches!(
             t.run_col_job(task_id, &query).await.unwrap_err(),
-            DapAbort::BatchOverlap { .. }
+            DapError::Abort(DapAbort::BatchOverlap { .. })
         );
     }
 
@@ -1628,7 +1627,7 @@ mod test {
         );
         assert_matches!(
             t.leader.handle_collect_job_req(&req).await.unwrap_err(),
-            DapAbort::QueryMismatch { .. }
+            DapError::Abort(DapAbort::QueryMismatch { .. })
         );
 
         // Collector indicates unrecognized batch ID.
@@ -1651,7 +1650,7 @@ mod test {
         );
         assert_matches!(
             t.leader.handle_collect_job_req(&req).await.unwrap_err(),
-            DapAbort::BatchInvalid { .. }
+            DapError::Abort(DapAbort::BatchInvalid { .. })
         );
     }
 
diff --git a/daphne/src/taskprov.rs b/daphne/src/taskprov.rs
index 0853bd42d..cc360f916 100644
--- a/daphne/src/taskprov.rs
+++ b/daphne/src/taskprov.rs
@@ -125,8 +125,7 @@ pub fn resolve_advertised_task_config(
     report_metadata_advertisement: Option<&ReportMetadata>,
 ) -> Result<Option<DapTaskConfig>, DapError> {
     let Some(advertised_task_config) =
-        get_taskprov_task_config(req, task_id, report_metadata_advertisement)
-            .map_err(DapError::Abort)?
+        get_taskprov_task_config(req, task_id, report_metadata_advertisement)?
     else {
         return Ok(None);
     };
@@ -147,7 +146,7 @@ fn get_taskprov_task_config(
     req: &'_ DapRequest<S>,
     task_id: &TaskId,
     report_metadata_advertisement: Option<&ReportMetadata>,
-) -> Result<Option<TaskConfig>, DapAbort> {
+) -> Result<Option<TaskConfig>, DapError> {
     let taskprov_data = if let Some(ref taskprov_base64url) = req.taskprov {
         Cow::Owned(decode_base64url_vec(taskprov_base64url).ok_or_else(|| {
             DapAbort::BadRequest(
@@ -163,9 +162,7 @@ fn get_taskprov_task_config(
             .draft02_extensions
             .as_ref()
             .ok_or_else(|| {
-                DapAbort::from(fatal_error!(
-                    err = "draft02: encountered report metadata with no extensions"
-                ))
+                fatal_error!(err = "draft02: encountered report metadata with no extensions")
             })?
             .iter()
             .filter(|x| matches!(x, Extension::Taskprov { .. }))
@@ -188,7 +185,7 @@ fn get_taskprov_task_config(
 
     if compute_task_id(req.version, taskprov_data.as_ref()) != *task_id {
         // Return unrecognizedTask following section 5.1 of the taskprov draft.
-        return Err(DapAbort::UnrecognizedTask);
+        return Err(DapAbort::UnrecognizedTask.into());
     }
 
     // Return unrecognizedMessage if parsing fails following section 5.1 of the taskprov draft.
diff --git a/daphne/src/testing.rs b/daphne/src/testing.rs
index 917d116fc..99b73fc0b 100644
--- a/daphne/src/testing.rs
+++ b/daphne/src/testing.rs
@@ -280,7 +280,7 @@ impl AggregationJobTest {
         &self,
         leader_state: DapAggregationJobState,
         agg_job_resp: AggregationJobResp,
-    ) -> DapAbort {
+    ) -> DapError {
         let metrics = &self.leader_metrics;
         self.task_config
             .vdaf
@@ -321,7 +321,7 @@ impl AggregationJobTest {
         &self,
         helper_state: DapAggregationJobState,
         agg_job_cont_req: &AggregationJobContinueReq,
-    ) -> DapAbort {
+    ) -> DapError {
         self.task_config
             .vdaf
             .handle_agg_job_cont_req(
diff --git a/daphne/src/vdaf/mod.rs b/daphne/src/vdaf/mod.rs
index 55a84582b..c1b2624cb 100644
--- a/daphne/src/vdaf/mod.rs
+++ b/daphne/src/vdaf/mod.rs
@@ -773,7 +773,7 @@ impl VdafConfig {
         part_batch_sel: &PartialBatchSelector,
         reports: Vec<Report>,
         metrics: &DaphneMetrics,
-    ) -> Result<DapLeaderAggregationJobTransition<AggregationJobInitReq>, DapAbort> {
+    ) -> Result<DapLeaderAggregationJobTransition<AggregationJobInitReq>, DapError> {
         let mut processed = HashSet::with_capacity(reports.len());
         let mut states = Vec::with_capacity(reports.len());
         let mut prep_inits = Vec::with_capacity(reports.len());
@@ -784,8 +784,7 @@ impl VdafConfig {
                 return Err(fatal_error!(
                     err = "tried to process report sequence with non-unique report IDs",
                     non_unique_id = %report.report_metadata.id,
-                )
-                .into());
+                ));
             }
             processed.insert(report.report_metadata.id);
 
@@ -896,7 +895,7 @@ impl VdafConfig {
         task_id: &TaskId,
         task_config: &DapTaskConfig,
         agg_job_init_req: &'req AggregationJobInitReq,
-    ) -> Result<Vec<EarlyReportStateConsumed<'req>>, DapAbort> {
+    ) -> Result<Vec<EarlyReportStateConsumed<'req>>, DapError> {
         let num_reports = agg_job_init_req.prep_inits.len();
         let mut processed = HashSet::with_capacity(num_reports);
         let mut consumed_reports = Vec::with_capacity(num_reports);
@@ -908,7 +907,8 @@ impl VdafConfig {
                     prep_init.report_share.report_metadata.id.to_base64url()
                 ),
                 task_id: Some(*task_id),
-            });
+            }
+            .into());
         }
 
         processed.insert(prep_init.report_share.report_metadata.id);
@@ -950,7 +950,7 @@ impl VdafConfig {
         initialized_reports: &[EarlyReportStateInitialized<'_>],
         agg_job_init_req: &AggregationJobInitReq,
         metrics: &DaphneMetrics,
-    ) -> Result<DapHelperAggregationJobTransition<AggregationJobResp>, DapAbort> {
+    ) -> Result<DapHelperAggregationJobTransition<AggregationJobResp>, DapError> {
         match task_config.version {
             DapVersion::Draft02 => Ok(Self::draft02_handle_agg_job_init_req(
                 report_status,
@@ -1035,7 +1035,7 @@ impl VdafConfig {
         initialized_reports: &[EarlyReportStateInitialized<'_>],
         agg_job_init_req: &AggregationJobInitReq,
         metrics: &DaphneMetrics,
-    ) -> Result<DapHelperAggregationJobTransition<AggregationJobResp>, DapAbort> {
+    ) -> Result<DapHelperAggregationJobTransition<AggregationJobResp>, DapError> {
         let num_reports = agg_job_init_req.prep_inits.len();
         let mut agg_span = DapAggregateSpan::default();
         let mut transitions = Vec::with_capacity(num_reports);
@@ -1058,7 +1058,8 @@ impl VdafConfig {
                     return Err(DapAbort::InvalidMessage {
                         detail: "PrepareInit with missing payload".to_string(),
                         task_id: Some(*task_id),
-                    });
+                    }
+                    .into());
                 };
 
                 // Decode the ping-pong "initialize" message framing.
@@ -1142,16 +1143,18 @@ impl VdafConfig {
         state: DapAggregationJobState,
         agg_job_resp: AggregationJobResp,
         metrics: &DaphneMetrics,
-    ) -> Result<DapLeaderAggregationJobTransition<AggregationJobContinueReq>, DapAbort> {
+    ) -> Result<DapLeaderAggregationJobTransition<AggregationJobContinueReq>, DapError> {
         match task_config.version {
-            DapVersion::Draft02 => self.draft02_handle_agg_job_resp(
-                task_id,
-                task_config,
-                agg_job_id,
-                state,
-                agg_job_resp,
-                metrics,
-            ),
+            DapVersion::Draft02 => self
+                .draft02_handle_agg_job_resp(
+                    task_id,
+                    task_config,
+                    agg_job_id,
+                    state,
+                    agg_job_resp,
+                    metrics,
+                )
+                .map_err(Into::into),
             DapVersion::Draft07 => {
                 self.draft07_handle_agg_job_resp(task_id, task_config, state, agg_job_resp, metrics)
             }
@@ -1278,7 +1281,7 @@ impl VdafConfig {
         state: DapAggregationJobState,
         agg_job_resp: AggregationJobResp,
         metrics: &DaphneMetrics,
-    ) -> Result<DapLeaderAggregationJobTransition<AggregationJobContinueReq>, DapAbort> {
+    ) -> Result<DapLeaderAggregationJobTransition<AggregationJobContinueReq>, DapError> {
         if agg_job_resp.transitions.len() != state.seq.len() {
             return Err(DapAbort::InvalidMessage {
                 detail: format!(
@@ -1287,7 +1290,8 @@ impl VdafConfig {
                     state.seq.len(),
                 ),
                 task_id: Some(*task_id),
-            });
+            }
+            .into());
         }
 
         let mut agg_span = DapAggregateSpan::default();
@@ -1303,7 +1307,8 @@ impl VdafConfig {
                         helper.report_id.to_base64url()
                     ),
                     task_id: Some(*task_id),
-                });
+                }
+                .into());
             }
 
             let prep_msg = match &helper.var {
@@ -1319,7 +1324,7 @@ impl VdafConfig {
                         return Err(DapAbort::InvalidMessage {
                             detail: "The Helper's AggregationJobResp is invalid, but it may have already committed its state change. A batch mismatch is inevitable.".to_string(),
                             task_id: Some(*task_id),
-                        });
+                        }.into());
                     };
 
                     prep_msg
@@ -1335,7 +1340,8 @@ impl VdafConfig {
                 return Err(DapAbort::InvalidMessage {
                     detail: "helper sent unexpected `Finished` message".to_string(),
                     task_id: Some(*task_id),
-                })
+                }
+                .into())
             }
         };
 
@@ -1388,14 +1394,15 @@ impl VdafConfig {
         report_status: &HashMap<ReportId, TransitionFailure>,
         agg_job_id: &MetaAggregationJobId,
         agg_job_cont_req: &AggregationJobContinueReq,
-    ) -> Result<(DapAggregateSpan<DapAggregateShare>, AggregationJobResp), DapAbort> {
+    ) -> Result<(DapAggregateSpan<DapAggregateShare>, AggregationJobResp), DapError> {
         match agg_job_cont_req.round {
             Some(1) | None => {}
             Some(0) => {
                 return Err(DapAbort::InvalidMessage {
                     detail: "request shouldn't indicate round 0".into(),
                     task_id: Some(*task_id),
-                })
+                }
+                .into())
             }
             // TODO(bhalleycf) For now, there is only ever one round, and we don't try to do
            // aggregation-round-skew-recovery.
@@ -1404,7 +1411,8 @@ impl VdafConfig {
                     detail: format!("The request indicates round {r}; round 1 was expected."),
                     task_id: *task_id,
                     agg_job_id_base64url: agg_job_id.to_base64url(),
-                })
+                }
+                .into())
             }
         }
         let mut processed = HashSet::with_capacity(state.seq.len());
@@ -1432,7 +1440,8 @@ impl VdafConfig {
                         leader.report_id.to_base64url()
                     ),
                     task_id: Some(*task_id),
-                });
+                }
+                .into());
             }
             if processed.contains(&leader.report_id) {
                 return Err(DapAbort::InvalidMessage {
                     detail: format!(
                         "report ID {} appears twice in the same aggregation job",
                         leader.report_id.to_base64url()
                     ),
                     task_id: Some(*task_id),
-                });
+                }
+                .into());
             }
 
             // Find the next helper report that matches leader.report_id.
@@ -1467,7 +1477,8 @@ impl VdafConfig {
                 return Err(DapAbort::InvalidMessage {
                     detail: "helper sent unexpected message instead of `Continued`".to_string(),
                     task_id: Some(*task_id),
-                });
+                }
+                .into());
             };
 
             let var = match report_status.get(&leader.report_id) {
@@ -1519,7 +1530,7 @@ impl VdafConfig {
         state: DapAggregationJobUncommitted,
         agg_job_resp: AggregationJobResp,
         metrics: &DaphneMetrics,
-    ) -> Result<DapAggregateSpan<DapAggregateShare>, DapAbort> {
+    ) -> Result<DapAggregateSpan<DapAggregateShare>, DapError> {
         if agg_job_resp.transitions.len() != state.seq.len() {
             return Err(DapAbort::InvalidMessage {
                 detail: format!(
@@ -1528,7 +1539,8 @@ impl VdafConfig {
                     agg_job_resp.transitions.len()
                 ),
                 task_id: None,
-            });
+            }
+            .into());
        }
 
         let mut agg_span = DapAggregateSpan::default();
@@ -1540,7 +1552,8 @@ impl VdafConfig {
                     helper.report_id.to_base64url()
                 ),
                 task_id: None,
-            });
+            }
+            .into());
         }
 
         match &helper.var {
@@ -1548,7 +1561,8 @@ impl VdafConfig {
                 return Err(DapAbort::InvalidMessage {
                     detail: "helper sent unexpected `Continued` message".to_string(),
                     task_id: None,
-                })
+                }
+                .into())
             }
 
             // Skip report that can't be processed any further.
@@ -1580,7 +1594,7 @@ impl VdafConfig {
         agg_param: &[u8],
         agg_share: &DapAggregateShare,
         version: DapVersion,
-    ) -> Result<HpkeCiphertext, DapAbort> {
+    ) -> Result<HpkeCiphertext, DapError> {
         produce_encrypted_agg_share(
             true,
             hpke_config,
@@ -1602,7 +1616,7 @@ impl VdafConfig {
         agg_param: &[u8],
         agg_share: &DapAggregateShare,
         version: DapVersion,
-    ) -> Result<HpkeCiphertext, DapAbort> {
+    ) -> Result<HpkeCiphertext, DapError> {
         produce_encrypted_agg_share(
             false,
             hpke_config,
@@ -1703,7 +1717,7 @@ fn produce_encrypted_agg_share(
     agg_param: &[u8],
     agg_share: &DapAggregateShare,
     version: DapVersion,
-) -> Result<HpkeCiphertext, DapAbort> {
+) -> Result<HpkeCiphertext, DapError> {
     let agg_share_data = agg_share
         .data
         .as_ref()
@@ -1731,9 +1745,7 @@ fn produce_encrypted_agg_share(
     }
     batch_sel.encode(&mut aad);
 
-    let (enc, payload) = hpke_config
-        .encrypt(&info, &aad, &agg_share_data)
-        .map_err(|e| DapAbort::Internal(Box::new(e)))?;
+    let (enc, payload) = hpke_config.encrypt(&info, &aad, &agg_share_data)?;
     Ok(HpkeCiphertext {
         config_id: hpke_config.id,
         enc,
@@ -2190,7 +2202,7 @@ mod test {
 
         assert_matches!(
             t.handle_agg_job_resp_expect_err(leader_state, agg_job_resp),
-            DapAbort::InvalidMessage { .. }
+            DapError::Abort(DapAbort::InvalidMessage { .. })
         );
     }
 
@@ -2212,7 +2224,7 @@ mod test {
 
         assert_matches!(
             t.handle_agg_job_resp_expect_err(leader_state, agg_job_resp),
-            DapAbort::InvalidMessage { .. }
+            DapError::Abort(DapAbort::InvalidMessage { .. })
        );
     }
 
@@ -2237,7 +2249,7 @@ mod test {
 
         assert_matches!(
             t.handle_agg_job_resp_expect_err(leader_state, agg_job_resp),
-            DapAbort::InvalidMessage { .. }
+            DapError::Abort(DapAbort::InvalidMessage { .. })
         );
     }
 
@@ -2258,7 +2270,7 @@ mod test {
 
         assert_matches!(
             t.handle_agg_job_resp_expect_err(leader_state, agg_job_resp),
-            DapAbort::InvalidMessage { .. }
+            DapError::Abort(DapAbort::InvalidMessage { .. })
         );
     }
 
@@ -2439,7 +2451,7 @@ mod test {
 
         assert_matches!(
             t.handle_agg_job_cont_req_expect_err(helper_state, &agg_job_cont_req),
-            DapAbort::InvalidMessage { .. }
+            DapError::Abort(DapAbort::InvalidMessage { .. })
         );
     }
 
@@ -2465,7 +2477,7 @@ mod test {
 
         assert_matches!(
             t.handle_agg_job_cont_req_expect_err(helper_state, &agg_job_cont_req),
-            DapAbort::InvalidMessage { .. }
+            DapError::Abort(DapAbort::InvalidMessage { .. })
         );
     }
 
@@ -2490,7 +2502,7 @@ mod test {
 
         assert_matches!(
             t.handle_agg_job_cont_req_expect_err(helper_state, &agg_job_cont_req),
-            DapAbort::InvalidMessage { .. }
+            DapError::Abort(DapAbort::InvalidMessage { .. })
         );
     }
 
diff --git a/daphne_worker/src/config.rs b/daphne_worker/src/config.rs
index 57dcc3130..cb87a5f12 100644
--- a/daphne_worker/src/config.rs
+++ b/daphne_worker/src/config.rs
@@ -520,11 +520,12 @@ impl<'srv> DaphneWorkerRequestState<'srv> {
         Ok(())
     }
 
-    pub(crate) fn dap_abort_to_worker_response(
-        &self,
-        e: DapAbort,
-    ) -> Result<Response> {
-        let status = if matches!(e, DapAbort::Internal(..)) {
+    pub(crate) fn dap_abort_to_worker_response<E>(&self, e: E) -> Result<Response>
+    where
+        E: Into<DapError>,
+    {
+        let e = e.into();
+        let status = if matches!(e, DapError::Fatal(..) | DapError::Transition(_)) {
             self.error_reporter.report_abort(&e);
             500
         } else {
diff --git a/daphne_worker/src/error_reporting.rs b/daphne_worker/src/error_reporting.rs
index 0a8229830..0927e1eaa 100644
--- a/daphne_worker/src/error_reporting.rs
+++ b/daphne_worker/src/error_reporting.rs
@@ -3,17 +3,17 @@
 
 //! Daphne-Worker error reporting trait and default implementation.
 
-use daphne::error::DapAbort;
+use daphne::DapError;
 
 /// Interface for error reporting in Daphne
 /// Refer to `NoopErrorReporter` for implementation example.
 pub trait ErrorReporter {
-    fn report_abort(&self, error: &DapAbort);
+    fn report_abort(&self, error: &DapError);
 }
 
 /// Default implementation of the error reporting trait, which is a no-op.
 pub(crate) struct NoopErrorReporter {}
 
 impl ErrorReporter for NoopErrorReporter {
-    fn report_abort(&self, _error: &DapAbort) {}
+    fn report_abort(&self, _error: &DapError) {}
 }
diff --git a/daphne_worker/src/roles/mod.rs b/daphne_worker/src/roles/mod.rs
index 0ab70ffcb..dfcfff1ca 100644
--- a/daphne_worker/src/roles/mod.rs
+++ b/daphne_worker/src/roles/mod.rs
@@ -81,7 +81,7 @@ impl<'srv> HpkeDecrypter for DaphneWorker<'srv> {
             })
             .await
             .map_err(|e| fatal_error!(err = ?e))?
-            .ok_or_else(|| DapError::Transition(TransitionFailure::HpkeUnknownConfigId))?
+            .ok_or(DapError::Transition(TransitionFailure::HpkeUnknownConfigId))?
     }
 }
diff --git a/daphne_worker/src/router/aggregator.rs b/daphne_worker/src/router/aggregator.rs
index 6efc280a3..c89681559 100644
--- a/daphne_worker/src/router/aggregator.rs
+++ b/daphne_worker/src/router/aggregator.rs
@@ -10,7 +10,7 @@ pub(super) fn add_aggregator_routes(router: DapRouter<'_>) -> DapRouter<'_> {
         let daph = ctx.data.handler(&ctx.env);
         let req = match daph.worker_request_to_dap(req, &ctx).await {
             Ok(req) => req,
-            Err(e) => return daph.state.dap_abort_to_worker_response(e.into()),
+            Err(e) => return daph.state.dap_abort_to_worker_response(e),
         };
 
         let span = info_span_from_dap_request!("hpke_config", req);
diff --git a/daphne_worker/src/router/helper.rs b/daphne_worker/src/router/helper.rs
index 3b190fe56..b3f5718fb 100644
--- a/daphne_worker/src/router/helper.rs
+++ b/daphne_worker/src/router/helper.rs
@@ -35,7 +35,7 @@ async fn handle_agg_job(
     let daph = ctx.data.handler(&ctx.env);
     let req = match daph.worker_request_to_dap(req, &ctx).await {
         Ok(req) => req,
-        Err(e) => return daph.state.dap_abort_to_worker_response(e.into()),
+        Err(e) => return daph.state.dap_abort_to_worker_response(e),
     };
 
     let span = match req.media_type {
@@ -61,7 +61,7 @@ async fn handle_agg_share_req(
     let daph = ctx.data.handler(&ctx.env);
     let req = match daph.worker_request_to_dap(req, &ctx).await {
         Ok(req) => req,
-        Err(e) => return daph.state.dap_abort_to_worker_response(e.into()),
+        Err(e) => return daph.state.dap_abort_to_worker_response(e),
     };
 
     let span = info_span_from_dap_request!(MeasuredSpanName::AggregateShares.as_str(), req);
diff --git a/daphne_worker/src/router/leader.rs b/daphne_worker/src/router/leader.rs
index 437d086ea..348986eed 100644
--- a/daphne_worker/src/router/leader.rs
+++ b/daphne_worker/src/router/leader.rs
@@ -22,7 +22,7 @@ pub(super) fn add_leader_routes(router: DapRouter<'_>) -> DapRouter<'_> {
         let daph = ctx.data.handler(&ctx.env);
         let req = match daph.worker_request_to_dap(req, &ctx).await {
             Ok(req) => req,
-            Err(e) => return daph.state.dap_abort_to_worker_response(e.into()),
+            Err(e) => return daph.state.dap_abort_to_worker_response(e),
         };
         if req.version != DapVersion::Draft02 {
             return Response::error("not implemented", 404);
@@ -33,7 +33,7 @@ pub(super) fn add_leader_routes(router: DapRouter<'_>) -> DapRouter<'_> {
         let daph = ctx.data.handler(&ctx.env);
         let req = match daph.worker_request_to_dap(req, &ctx).await {
             Ok(req) => req,
-            Err(e) => return daph.state.dap_abort_to_worker_response(e.into()),
+            Err(e) => return daph.state.dap_abort_to_worker_response(e),
         };
         put_report_into_task(req, daph).await
     })
@@ -41,7 +41,7 @@ pub(super) fn add_leader_routes(router: DapRouter<'_>) -> DapRouter<'_> {
         let daph = ctx.data.handler(&ctx.env);
         let req = match daph.worker_request_to_dap(req, &ctx).await {
             Ok(req) => req,
-            Err(e) => return daph.state.dap_abort_to_worker_response(e.into()),
+            Err(e) => return daph.state.dap_abort_to_worker_response(e),
         };
 
         if req.version != DapVersion::Draft02 {
@@ -103,7 +103,7 @@ pub(super) fn add_leader_routes(router: DapRouter<'_>) -> DapRouter<'_> {
                         "unknown collect id".into(),
                     ))
                 }
-                Err(e) => daph.state.dap_abort_to_worker_response(e.into()),
+                Err(e) => daph.state.dap_abort_to_worker_response(e),
             }
         },
     ) // draft02
@@ -113,7 +113,7 @@ pub(super) fn add_leader_routes(router: DapRouter<'_>) -> DapRouter<'_> {
         let daph = ctx.data.handler(&ctx.env);
         let req = match daph.worker_request_to_dap(req, &ctx).await {
             Ok(req) => req,
-            Err(e) => return daph.state.dap_abort_to_worker_response(e.into()),
+            Err(e) => return daph.state.dap_abort_to_worker_response(e),
         };
 
         let span = info_span_from_dap_request!("collect (PUT)", req);
@@ -130,7 +130,7 @@ pub(super) fn add_leader_routes(router: DapRouter<'_>) -> DapRouter<'_> {
         let daph = ctx.data.handler(&ctx.env);
         let req = match daph.worker_request_to_dap(req, &ctx).await {
             Ok(req) => req,
-            Err(e) => return daph.state.dap_abort_to_worker_response(e.into()),
+            Err(e) => return daph.state.dap_abort_to_worker_response(e),
         };
         let task_id = match req.task_id() {
             Ok(id) => id,
@@ -175,7 +175,7 @@ pub(super) fn add_leader_routes(router: DapRouter<'_>) -> DapRouter<'_> {
                         "unknown collect id".into(),
                     ))
                 }
-                Err(e) => daph.state.dap_abort_to_worker_response(e.into()),
+                Err(e) => daph.state.dap_abort_to_worker_response(e),
             }
         },
     )
diff --git a/daphne_worker/src/router/test_routes.rs b/daphne_worker/src/router/test_routes.rs
index 4d0d5e41b..84cfdef89 100644
--- a/daphne_worker/src/router/test_routes.rs
+++ b/daphne_worker/src/router/test_routes.rs
@@ -55,7 +55,7 @@ pub(super) fn add_internal_test_routes(router: DapRouter<'_>, role: Role) -> Dap
                 Ok(batch_id) => {
                     Response::from_bytes(batch_id.to_base64url().as_bytes().to_owned())
                 }
-                Err(e) => daph.state.dap_abort_to_worker_response(e.into()),
+                Err(e) => daph.state.dap_abort_to_worker_response(e),
             }
         },
     )
@@ -71,7 +71,7 @@ pub(super) fn add_internal_test_routes(router: DapRouter<'_>, role: Role) -> Dap
             .await
         {
             Ok(()) => Response::empty(),
-            Err(e) => daph.state.dap_abort_to_worker_response(e.into()),
+            Err(e) => daph.state.dap_abort_to_worker_response(e),
         }
     })
    // Endpoints for draft-dcook-ppm-dap-interop-test-design-02
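
Taken together, the change funnels every handler error through `DapError`: protocol-level aborts are wrapped at the `return Err(...)` site via `DapError`'s `#[from] DapAbort` conversion (the `.into()` calls throughout), fatal and transition errors flow through unchanged, and the worker layer picks the HTTP status in one place (`dap_abort_to_worker_response`). Below is a minimal, self-contained Rust sketch of that pattern. The types and the `handle_req`/`status_for` functions are simplified hypothetical stand-ins for illustration, not daphne's real definitions.

// Hypothetical, simplified stand-ins for daphne's `DapAbort` and `DapError`.
#[derive(Debug)]
enum DapAbort {
    BadRequest(String),
    UnrecognizedTask,
}

#[derive(Debug)]
enum DapError {
    Fatal(String),
    Abort(DapAbort),
}

// Mirrors the `#[from] DapAbort` conversion that makes `.into()` work at
// every `return Err(DapAbort::...)` site in the diff.
impl From<DapAbort> for DapError {
    fn from(abort: DapAbort) -> Self {
        DapError::Abort(abort)
    }
}

// Handlers now uniformly return `DapError`; aborts are wrapped with `.into()`.
fn handle_req(task_id: Option<&str>) -> Result<(), DapError> {
    let Some(task_id) = task_id else {
        return Err(DapAbort::UnrecognizedTask.into());
    };
    if task_id.is_empty() {
        return Err(DapAbort::BadRequest("empty task ID".into()).into());
    }
    Ok(())
}

// The HTTP layer decides the status code in one place, loosely mirroring the
// new `dap_abort_to_worker_response`: fatal errors map to 500, aborts to a
// client error built from the problem details.
fn status_for(e: &DapError) -> u16 {
    match e {
        DapError::Fatal(_) => 500,
        DapError::Abort(_) => 400,
    }
}

fn main() {
    let abort = handle_req(None).unwrap_err();
    println!("{abort:?} -> HTTP {}", status_for(&abort)); // abort path

    let fatal = DapError::Fatal("backend unreachable".into());
    println!("{fatal:?} -> HTTP {}", status_for(&fatal)); // fatal path
}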