Skip to content

Commit

Permalink
Make abort->error conversion one way
Browse files Browse the repository at this point in the history
Having DapError convertible into DapAbort means that an error can be
converted into an abort and an abort converted back into an error an
unlimited number of times, so the type can grow in an unbounded way at
runtime. This is very confusing to reason about when reading the
signature of a method that returns a `DapAbort`. Now a method that
returns a `DapAbort` can only fail due to a documented protocol error
and not another system-related error (such as misconfiguration or IO).
  • Loading branch information
mendess committed Nov 23, 2023
1 parent 7989c48 commit 21bc5c6
Show file tree
Hide file tree
Showing 16 changed files with 223 additions and 195 deletions.
33 changes: 10 additions & 23 deletions daphne/src/error/aborts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ pub enum DapAbort {
#[error("batchOverlap")]
BatchOverlap { detail: String, task_id: TaskId },

/// Internal error.
#[error("internal error")]
Internal(#[source] Box<dyn std::error::Error + 'static + Send + Sync>),

/// Invalid batch size (either too small or too large). Sent in response to a CollectReq or
/// AggregateShareReq.
#[error("invalidBatchSize")]
Expand Down Expand Up @@ -143,7 +139,7 @@ impl DapAbort {
Some(agg_job_id_base64url),
),
Self::InvalidMessage { detail, task_id } => (task_id, Some(detail), None),
Self::ReportTooLate | Self::UnrecognizedTask | Self::Internal(_) => (None, None, None),
Self::ReportTooLate | Self::UnrecognizedTask => (None, None, None),
};

ProblemDetails {
Expand Down Expand Up @@ -207,23 +203,25 @@ impl DapAbort {
}

#[inline]
pub fn report_rejected(failure_reason: TransitionFailure) -> Self {
pub fn report_rejected(failure_reason: TransitionFailure) -> Result<Self, DapError> {
let detail = match failure_reason {
TransitionFailure::BatchCollected => {
"The report pertains to a batch that has already been collected."
}
TransitionFailure::ReportReplayed => {
"A report with the same ID was uploaded previously."
}
_ => return fatal_error!(
err = "Attempted to construct a \"reportRejected\" abort with unexpected transition failure",
unexpected_transition_failure = ?failure_reason,
).into(),
_ => {
return Err(fatal_error!(
err = "Attempted to construct a \"reportRejected\" abort with unexpected transition failure",
unexpected_transition_failure = ?failure_reason,
))
}
};

Self::ReportRejected {
Ok(Self::ReportRejected {
detail: detail.into(),
}
})
}

fn title_and_type(&self) -> (String, Option<String>) {
Expand Down Expand Up @@ -267,7 +265,6 @@ impl DapAbort {
Some(self.to_string()),
),
Self::BadRequest(..) => ("Bad request", None),
Self::Internal(..) => ("Internal server error", None),
};

(
Expand All @@ -277,16 +274,6 @@ impl DapAbort {
}
}

/// Lossy conversion from an internal [`DapError`] to the protocol-level
/// [`DapAbort`] reported to peers.
///
/// NOTE(review): together with `DapError::Abort`, this makes the conversion
/// bidirectional, so an error can be re-wrapped indefinitely — this is the
/// cycle the surrounding commit removes.
impl From<DapError> for DapAbort {
    fn from(e: DapError) -> Self {
        match e {
            // Fatal errors have no protocol meaning; hide them behind an
            // opaque "internal error" abort (boxed as a trait object).
            e @ DapError::Fatal(..) => Self::Internal(Box::new(e)),
            // Already a protocol abort: pass it through unchanged.
            DapError::Abort(abort) => abort,
            // A report transition failure maps to a "reportRejected" abort.
            // (Here `report_rejected` still returns `Self` directly; the
            // commit changes it to return `Result<Self, DapError>`.)
            DapError::Transition(failure_reason) => Self::report_rejected(failure_reason),
        }
    }
}

impl DapAbort {
pub fn from_codec_error<Id: Into<Option<TaskId>>>(e: CodecError, task_id: Id) -> Self {
Self::InvalidMessage {
Expand Down
19 changes: 19 additions & 0 deletions daphne/src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use std::fmt::{Debug, Display};
use crate::{messages::TransitionFailure, vdaf::VdafError};
pub use aborts::DapAbort;

use self::aborts::ProblemDetails;

/// DAP errors.
#[derive(Debug, thiserror::Error)]
pub enum DapError {
Expand All @@ -28,6 +30,23 @@ pub enum DapError {
Transition(#[from] TransitionFailure),
}

impl DapError {
    /// Render this error as the problem-details object sent back to clients.
    ///
    /// Protocol aborts carry their own structured details; anything else is
    /// deliberately reported as an opaque "Internal server error" so that
    /// internal failure information never leaks to peers.
    pub fn into_problem_details(self) -> ProblemDetails {
        match self {
            Self::Abort(abort) => abort.into_problem_details(),
            _ => ProblemDetails {
                typ: None,
                title: "Internal server error".into(),
                agg_job_id: None,
                task_id: None,
                instance: None,
                detail: None,
            },
        }
    }
}

impl FatalDapError {
#[doc(hidden)]
pub fn __use_the_macro(s: String) -> Self {
Expand Down
11 changes: 5 additions & 6 deletions daphne/src/roles/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
) -> Result<(), DapError>;

/// Handle request for the Aggregator's HPKE configuration.
async fn handle_hpke_config_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapAbort> {
async fn handle_hpke_config_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapError> {
let metrics = self.metrics();

// Parse the task ID from the query string, ensuring that it is the only query parameter.
let mut id = None;
for (k, v) in req.url.query_pairs() {
if k != "task_id" {
return Err(DapAbort::BadRequest("unexpected query parameter".into()));
return Err(DapAbort::BadRequest("unexpected query parameter".into()).into());
}

let bytes = decode_base64url(v.as_bytes()).ok_or(DapAbort::BadRequest(
Expand All @@ -169,10 +169,9 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {

// Check whether the DAP version in the request matches the task config.
if task_config.as_ref().version != req.version {
return Err(DapAbort::version_mismatch(
req.version,
task_config.as_ref().version,
));
return Err(
DapAbort::version_mismatch(req.version, task_config.as_ref().version).into(),
);
}
}

Expand Down
52 changes: 27 additions & 25 deletions daphne/src/roles/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
req: &'req DapRequest<S>,
metrics: &DaphneMetrics,
task_id: &TaskId,
) -> Result<DapResponse, DapAbort> {
) -> Result<DapResponse, DapError> {
let agg_job_init_req =
AggregationJobInitReq::get_decoded_with_param(&req.version, &req.payload)
.map_err(|e| DapAbort::from_codec_error(e, *task_id))?;
Expand Down Expand Up @@ -88,7 +88,8 @@ pub trait DapHelper<S>: DapAggregator<S> {
detail: "some reports include the taskprov extensions and some do not"
.to_string(),
task_id: Some(*task_id),
});
}
.into());
}
};
resolve_taskprov(self, task_id, req, first_metadata).await?;
Expand All @@ -105,14 +106,15 @@ pub trait DapHelper<S>: DapAggregator<S> {
return Err(DapAbort::UnauthorizedRequest {
detail: reason,
task_id: *task_id,
});
}
.into());
}

let agg_job_id = resolve_agg_job_id(req, agg_job_init_req.draft02_agg_job_id.as_ref())?;

// Check whether the DAP version in the request matches the task config.
if task_config.version != req.version {
return Err(DapAbort::version_mismatch(req.version, task_config.version));
return Err(DapAbort::version_mismatch(req.version, task_config.version).into());
}

// Ensure we know which batch the request pertains to.
Expand Down Expand Up @@ -140,7 +142,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
metrics,
)?
else {
return Err(DapAbort::from(fatal_error!(err = "unexpected transition")));
return Err(fatal_error!(err = "unexpected transition"));
};

if !self
Expand All @@ -150,7 +152,8 @@ pub trait DapHelper<S>: DapAggregator<S> {
// TODO spec: Consider an explicit abort for this case.
return Err(DapAbort::BadRequest(
"unexpected message for aggregation job (already exists)".into(),
));
)
.into());
}
metrics.agg_job_started_inc();
agg_job_resp
Expand All @@ -173,9 +176,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
metrics,
)?
else {
return Err(DapAbort::from(fatal_error!(
err = "unexpected transition"
)));
return Err(fatal_error!(err = "unexpected transition"));
};
Ok((agg_span, agg_job_resp))
},
Expand Down Expand Up @@ -209,7 +210,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
req: &'req DapRequest<S>,
metrics: &DaphneMetrics,
task_id: &TaskId,
) -> Result<DapResponse, DapAbort> {
) -> Result<DapResponse, DapError> {
if self.get_global_config().allow_taskprov {
resolve_taskprov(self, task_id, req, None).await?;
}
Expand All @@ -224,12 +225,13 @@ pub trait DapHelper<S>: DapAggregator<S> {
return Err(DapAbort::UnauthorizedRequest {
detail: reason,
task_id: *task_id,
});
}
.into());
}

// Check whether the DAP version in the request matches the task config.
if task_config.version != req.version {
return Err(DapAbort::version_mismatch(req.version, task_config.version));
return Err(DapAbort::version_mismatch(req.version, task_config.version).into());
}

let agg_job_cont_req =
Expand Down Expand Up @@ -290,7 +292,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
}

/// Handle a request pertaining to an aggregation job.
async fn handle_agg_job_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapAbort> {
async fn handle_agg_job_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapError> {
let metrics = self.metrics();
let task_id = req.task_id()?;

Expand All @@ -302,13 +304,13 @@ pub trait DapHelper<S>: DapAggregator<S> {
self.handle_agg_job_cont_req(req, metrics, task_id).await
}
//TODO spec: Specify this behavior.
_ => Err(DapAbort::BadRequest("unexpected media type".into())),
_ => Err(DapAbort::BadRequest("unexpected media type".into()).into()),
}
}

/// Handle a request for an aggregate share. This is called by the Leader to complete a
/// collection job.
async fn handle_agg_share_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapAbort> {
async fn handle_agg_share_req(&self, req: &DapRequest<S>) -> Result<DapResponse, DapError> {
let now = self.get_current_time();
let metrics = self.metrics();
let task_id = req.task_id()?;
Expand All @@ -330,12 +332,13 @@ pub trait DapHelper<S>: DapAggregator<S> {
return Err(DapAbort::UnauthorizedRequest {
detail: reason,
task_id: *task_id,
});
}
.into());
}

// Check whether the DAP version in the request matches the task config.
if task_config.version != req.version {
return Err(DapAbort::version_mismatch(req.version, task_config.version));
return Err(DapAbort::version_mismatch(req.version, task_config.version).into());
}

let agg_share_req = AggregateShareReq::get_decoded_with_param(&req.version, &req.payload)
Expand Down Expand Up @@ -368,7 +371,7 @@ pub trait DapHelper<S>: DapAggregator<S> {
agg_share.report_count,
hex::encode(agg_share.checksum)),
task_id: *task_id,
});
}.into());
}

// Check the batch size.
Expand All @@ -382,7 +385,8 @@ pub trait DapHelper<S>: DapAggregator<S> {
agg_share.report_count, task_config.min_batch_size
),
task_id: *task_id,
});
}
.into());
}

// Mark each aggregated report as collected.
Expand Down Expand Up @@ -468,9 +472,9 @@ async fn finish_agg_job_and_aggregate<S>(
&HashMap<ReportId, ReportProcessedStatus>,
) -> Result<
(DapAggregateSpan<DapAggregateShare>, AggregationJobResp),
DapAbort,
DapError,
>,
) -> Result<AggregationJobResp, DapAbort> {
) -> Result<AggregationJobResp, DapError> {
// This loop is intended to run at most once on the "happy path". The intent is as follows:
//
// - try to aggregate the output shares into an `DapAggregateShareSpan`
Expand Down Expand Up @@ -522,7 +526,7 @@ async fn finish_agg_job_and_aggregate<S>(
// and if this error doesn't manifest itself all reports will be successfully
// aggregated. Which means that no reports will be lost in a such a state that
// they can never be aggregated.
(Err(e), _) => return Err(e.into()),
(Err(e), _) => return Err(e),
}
}
if !inc_restart_metric.is_completed() {
Expand All @@ -541,9 +545,7 @@ async fn finish_agg_job_and_aggregate<S>(

// We need to prevent an attacker from keeping this loop running for too long, potentially
// enabling an DOS attack.
Err(DapAbort::BadRequest(
"AggregationJobContinueReq contained too many replays".into(),
))
Err(DapAbort::BadRequest("AggregationJobContinueReq contained too many replays".into()).into())
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit 21bc5c6

Please sign in to comment.