Remove ReportsProcessed DO #433

Merged · 3 commits · Dec 4, 2023

Changes from all commits
23 changes: 16 additions & 7 deletions daphne/src/roles/aggregator.rs
@@ -25,9 +25,8 @@ use crate::{
 /// job.
 #[async_trait(?Send)]
 pub trait DapReportInitializer {
-    /// Initialize a sequence of reports that are in the "consumed" state by performing the early
-    /// validation steps (check if the report was replayed, belongs to a batch that has been
-    /// collected) and initializing VDAF preparation.
+    /// Initialize a sequence of reports that are in the "consumed" state by initializing VDAF
+    /// preparation.
     async fn initialize_reports<'req>(
         &self,
         is_leader: bool,
@@ -38,6 +37,13 @@ pub trait DapReportInitializer {
     ) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError>;
 }
 
+#[derive(Debug)]
+pub enum MergeAggShareError {
+    AlreadyCollected,
+    ReplaysDetected(HashSet<ReportId>),
+    Other(DapError),
+}
+
 /// DAP Aggregator functionality.
 #[async_trait(?Send)]
 pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
@@ -117,15 +123,18 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
     ///
     /// A span with the same buckets as the input `agg_share_span` where the value is one of 3
    /// possible sets of values:
-    /// - `Ok(None)` if all went well and no reports were replays.
-    /// - `Ok(Some(set))` if at least one report was a replay. This also means no aggregate shares where merged.
-    /// - `Err(err)` if an error occurred.
+    /// - `Ok(())` if all went well and no reports were replays.
+    /// - `Err(MergeAggShareError::ReplaysDetected)` if at least one report was a replay. This also
+    ///   means no aggregate shares were merged.
+    /// - `Err(MergeAggShareError::AlreadyCollected)` if this span belongs to an aggregate share
+    ///   that has already been collected.
+    /// - `Err(MergeAggShareError::Other)` if another unrecoverable error occurred.
     async fn try_put_agg_share_span(
         &self,
         task_id: &TaskId,
         task_config: &DapTaskConfig,
         agg_share_span: DapAggregateSpan<DapAggregateShare>,
-    ) -> DapAggregateSpan<Result<HashSet<ReportId>, DapError>>;
+    ) -> DapAggregateSpan<Result<(), MergeAggShareError>>;
 
     /// Fetch the aggregate share for the given batch.
     async fn get_agg_share(
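The new contract is all-or-nothing per bucket: both `ReplaysDetected` and `AlreadyCollected` mean nothing from the bucket was merged. A minimal sketch of how a caller might consume the per-bucket result; the types here are simplified stand-ins, not Daphne's real definitions:

use std::collections::HashSet;

// Simplified stand-ins for the real Daphne types (illustration only).
#[derive(Debug)]
struct DapError(String);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ReportId(u128);

#[derive(Debug)]
enum MergeAggShareError {
    AlreadyCollected,
    ReplaysDetected(HashSet<ReportId>),
    Other(DapError),
}

// One bucket's merge result: either everything was merged, or nothing was.
fn handle_bucket(result: Result<(), MergeAggShareError>) -> Result<(), DapError> {
    match result {
        // All reports in the bucket were committed to aggregate storage.
        Ok(()) => Ok(()),
        // Nothing was merged; only the listed reports are rejected as replays.
        Err(MergeAggShareError::ReplaysDetected(replays)) => {
            eprintln!("rejecting {} replayed report(s)", replays.len());
            Ok(())
        }
        // Nothing was merged; the whole bucket falls in a collected batch.
        Err(MergeAggShareError::AlreadyCollected) => {
            eprintln!("rejecting bucket: batch already collected");
            Ok(())
        }
        // Any other error is unrecoverable for this aggregation job.
        Err(MergeAggShareError::Other(e)) => Err(e),
    }
}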
31 changes: 21 additions & 10 deletions daphne/src/roles/helper.rs
@@ -19,6 +19,7 @@ use crate::{
         ReportId, TaskId, TransitionFailure, TransitionVar,
     },
     metrics::{DaphneMetrics, DaphneRequestType},
+    roles::aggregator::MergeAggShareError,
     vdaf::ReportProcessedStatus,
     DapAggregateShare, DapAggregateSpan, DapAggregationJobState, DapError,
     DapHelperAggregationJobTransition, DapRequest, DapResource, DapResponse, DapTaskConfig,
@@ -266,12 +267,6 @@ pub trait DapHelper<S>: DapAggregator<S> {
             })
             .await?;
 
-        for transition in &agg_job_resp.transitions {
-            if let TransitionVar::Failed(failure) = &transition.var {
-                metrics.report_inc_by(&format!("rejected_{failure}"), 1);
-            }
-        }
-
         let out_shares_count = agg_job_resp
             .transitions
             .iter()
@@ -503,7 +498,7 @@ async fn finish_agg_job_and_aggregate<S>(
     for (_bucket, result) in put_shares_result {
         match result {
             // This bucket had no replays.
-            (Ok(replays), reports) if replays.is_empty() => {
+            (Ok(()), reports) => {
                 // Every report in the bucket has been committed to aggregate storage.
                 report_status.extend(
                     reports.into_iter().map(|(report_id, _time)| {
@@ -512,7 +507,7 @@ async fn finish_agg_job_and_aggregate<S>(
                 );
             }
             // This bucket had replays.
-            (Ok(replays), _reports) => {
+            (Err(MergeAggShareError::ReplaysDetected(replays)), _reports) => {
                 // At least one report was replayed (no change to aggregate storage).
                 report_status.extend(replays.into_iter().map(|report_id| {
                     (
@@ -522,6 +517,16 @@ async fn finish_agg_job_and_aggregate<S>(
                 }));
                 inc_restart_metric.call_once(|| metrics.agg_job_put_span_retry_inc());
             }
+            // This bucket is contained by an aggregate share span that has been collected.
+            (Err(MergeAggShareError::AlreadyCollected), reports) => {
+                report_status.extend(reports.into_iter().map(|(report_id, _)| {
+                    (
+                        report_id,
+                        ReportProcessedStatus::Rejected(TransitionFailure::BatchCollected),
+                    )
+                }));
+                inc_restart_metric.call_once(|| metrics.agg_job_put_span_retry_inc());
+            }
             // If this happens, the leader and helper can possibly have inconsistent state.
             // The leader will still think all of the reports in this job have yet to be
             // aggregated. But we could have aggregated some and not others due to the
@@ -531,7 +536,7 @@ async fn finish_agg_job_and_aggregate<S>(
             // and if this error doesn't manifest itself all reports will be successfully
             // aggregated. Which means that no reports will be lost in such a state that
             // they can never be aggregated.
-            (Err(e), _) => return Err(e),
+            (Err(MergeAggShareError::Other(other)), _) => return Err(other),
         }
     }
     if !inc_restart_metric.is_completed() {
@@ -544,6 +549,12 @@ async fn finish_agg_job_and_aggregate<S>(
         .expect("usize to fit in u64");
     metrics.report_inc_by("aggregated", out_shares_count);
 
+    for transition in &agg_job_resp.transitions {
+        if let TransitionVar::Failed(failure) = &transition.var {
+            metrics.report_inc_by(&format!("rejected_{failure}"), 1);
+        }
+    }
+
     return Ok(agg_job_resp);
 }
@@ -586,7 +597,7 @@ mod tests {
             .map(|r| r.report_metadata.id)
             .collect::<Vec<_>>();
 
-        let req = test
+        let (_, req) = test
             .gen_test_agg_job_init_req(&task_id, DapVersion::Draft02, reports)
             .await;
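Condensed, the per-bucket handling above maps each merge outcome to a status for every report in the bucket. A self-contained sketch with simplified stand-in types (not the merged code; `Status` stands in for Daphne's `ReportProcessedStatus`):

use std::collections::HashSet;

// Simplified stand-ins for Daphne's types (illustration only).
struct DapError(String);

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct ReportId(u128);

enum MergeAggShareError {
    AlreadyCollected,
    ReplaysDetected(HashSet<ReportId>),
    Other(DapError),
}

enum Status {
    Aggregated,
    RejectedReplayed,
    RejectedBatchCollected,
}

// Map one bucket's merge result to a per-report status.
fn bucket_statuses(
    result: Result<(), MergeAggShareError>,
    reports: Vec<ReportId>,
) -> Result<Vec<(ReportId, Status)>, DapError> {
    match result {
        // The whole bucket was committed to aggregate storage.
        Ok(()) => Ok(reports
            .into_iter()
            .map(|id| (id, Status::Aggregated))
            .collect()),
        // Only the replayed reports are rejected; the job can be retried.
        Err(MergeAggShareError::ReplaysDetected(replays)) => Ok(replays
            .into_iter()
            .map(|id| (id, Status::RejectedReplayed))
            .collect()),
        // The whole bucket is rejected: its batch was already collected.
        Err(MergeAggShareError::AlreadyCollected) => Ok(reports
            .into_iter()
            .map(|id| (id, Status::RejectedBatchCollected))
            .collect()),
        // Unrecoverable: abort the aggregation job.
        Err(MergeAggShareError::Other(e)) => Err(e),
    }
}

Note the ordering change as well: the `rejected_{failure}` metrics loop now runs after `finish_agg_job_and_aggregate`, so rejections produced during the merge itself (replays, collected batches) are counted too.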
37 changes: 27 additions & 10 deletions daphne/src/roles/leader.rs
@@ -8,7 +8,10 @@ use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode};
 use tracing::{debug, error};
 use url::Url;
 
-use super::{check_batch, check_request_content_type, resolve_taskprov, DapAggregator};
+use super::{
+    aggregator::MergeAggShareError, check_batch, check_request_content_type, resolve_taskprov,
+    DapAggregator,
+};
 use crate::{
     constants::DapMediaType,
     error::DapAbort,
@@ -333,6 +336,7 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
                 metrics,
             )
             .await?;
+
         let (state, agg_job_init_req) = match transition {
             DapLeaderAggregationJobTransition::Continued(state, agg_job_init_req) => {
                 (state, agg_job_init_req)
@@ -437,26 +441,39 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
 
         let out_shares_count = agg_span.report_count() as u64;
 
-        // At this point we're committed to aggregating the reports: if we do detect a report was
-        // replayed at this stage, then we may end up with a batch mismatch. However, this should
-        // only happen if there are multiple aggregation jobs in-flight that include the same
-        // report.
-        let replayed = self
+        // At this point we're committed to aggregating the reports: if we do detect an error (a
+        // report was replayed at this stage or the span overlaps with a collected batch), then we
+        // may end up with a batch mismatch. However, this should only happen if there are multiple
+        // aggregation jobs in-flight that include the same report.
+        let (replayed, collected) = self
             .try_put_agg_share_span(task_id, task_config, agg_span)
             .await
             .into_iter()
-            .map(|(_bucket, (result, _report_metadata))| {
-                result.map(|replayed_reports| replayed_reports.len())
+            .map(|(_bucket, (result, _report_metadata))| match result {
+                Ok(()) => Ok((0, 0)),
+                Err(MergeAggShareError::AlreadyCollected) => Ok((0, 1)),
+                Err(MergeAggShareError::ReplaysDetected(replays)) => Ok((replays.len(), 0)),
+                Err(MergeAggShareError::Other(e)) => Err(e),
             })
-            .sum::<Result<usize, _>>()?;
+            .try_fold((0, 0), |(replayed, collected), rc| {
+                let (r, c) = rc?;
+                Ok::<_, DapError>((replayed + r, collected + c))
+            })?;
 
         if replayed > 0 {
-            tracing::warn!(
+            tracing::error!(
                 replay_count = replayed,
                 "tried to aggregate replayed reports"
             );
         }
 
+        if collected > 0 {
+            tracing::error!(
+                collected_count = collected,
+                "tried to aggregate reports belonging to collected spans"
+            );
+        }
+
         metrics.report_inc_by("aggregated", out_shares_count);
         Ok(out_shares_count)
     }
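Read in isolation, the replay/collected tally above amounts to the following. A runnable sketch; `DapError` and `ReportId` are simplified stand-ins, not Daphne's definitions:

use std::collections::HashSet;

// Simplified stand-ins for Daphne's types (illustration only).
#[derive(Debug)]
struct DapError(String);

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct ReportId(u128);

enum MergeAggShareError {
    AlreadyCollected,
    ReplaysDetected(HashSet<ReportId>),
    Other(DapError),
}

// Tally how many reports were replayed and how many buckets were already
// collected, stopping at the first unrecoverable error.
fn tally(
    results: impl IntoIterator<Item = Result<(), MergeAggShareError>>,
) -> Result<(usize, usize), DapError> {
    results
        .into_iter()
        .map(|result| match result {
            Ok(()) => Ok((0, 0)),
            Err(MergeAggShareError::AlreadyCollected) => Ok((0, 1)),
            Err(MergeAggShareError::ReplaysDetected(replays)) => Ok((replays.len(), 0)),
            Err(MergeAggShareError::Other(e)) => Err(e),
        })
        .try_fold((0, 0), |(replayed, collected), rc| {
            let (r, c) = rc?;
            Ok((replayed + r, collected + c))
        })
}

fn main() {
    let results = vec![
        Ok(()),
        Err(MergeAggShareError::ReplaysDetected(HashSet::from([ReportId(1)]))),
        Err(MergeAggShareError::AlreadyCollected),
    ];
    let (replayed, collected) = tally(results).unwrap();
    // Prints "replayed: 1, collected: 1".
    println!("replayed: {replayed}, collected: {collected}");
}

Both counters are now logged at error level rather than the previous warn, since either condition can surface as a batch mismatch downstream.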
93 changes: 68 additions & 25 deletions daphne/src/roles/mod.rs
@@ -3,7 +3,7 @@
 
 //! Trait definitions for Daphne backends.
 
-mod aggregator;
+pub mod aggregator;
 mod helper;
 mod leader;
 
@@ -170,14 +170,14 @@ mod test {
         test_versions,
         testing::{AggStore, MockAggregator, MockAggregatorReportSelector},
         vdaf::VdafVerifyKey,
-        DapAbort, DapAggregateShare, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig,
-        DapLeaderAggregationJobTransition, DapMeasurement, DapQueryConfig, DapRequest, DapResource,
-        DapTaskConfig, DapTaskParameters, DapVersion, MetaAggregationJobId, Prio3Config,
-        VdafConfig,
+        DapAbort, DapAggregateShare, DapAggregationJobState, DapBatchBucket, DapCollectJob,
+        DapError, DapGlobalConfig, DapLeaderAggregationJobTransition, DapMeasurement,
+        DapQueryConfig, DapRequest, DapResource, DapTaskConfig, DapTaskParameters, DapVersion,
+        MetaAggregationJobId, Prio3Config, VdafConfig,
     };
     use assert_matches::assert_matches;
     use matchit::Router;
-    use prio::codec::{Decode, ParameterizedEncode};
+    use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode};
     use rand::{thread_rng, Rng};
     use std::{collections::HashMap, sync::Arc, time::SystemTime, vec};
     use url::Url;
@@ -498,7 +498,7 @@ mod test {
         task_id: &TaskId,
         version: DapVersion,
         reports: Vec<Report>,
-    ) -> DapRequest<BearerToken> {
+    ) -> (DapAggregationJobState, DapRequest<BearerToken>) {
         let mut rng = thread_rng();
         let task_config = self.leader.unchecked_get_task_config(task_id).await;
         let part_batch_sel = match task_config.query {
@@ -510,7 +510,7 @@ mod test {
 
         let agg_job_id = MetaAggregationJobId::gen_for_version(version);
 
-        let DapLeaderAggregationJobTransition::Continued(_leader_state, agg_job_init_req) =
+        let DapLeaderAggregationJobTransition::Continued(leader_state, agg_job_init_req) =
             task_config
                 .vdaf
                 .produce_agg_job_init_req(
@@ -529,15 +529,18 @@ mod test {
             panic!("unexpected transition");
         };
 
-        self.leader_authorized_req(
-            task_id,
-            &task_config,
-            Some(&agg_job_id),
-            DapMediaType::AggregationJobInitReq,
-            agg_job_init_req,
-            task_config.helper_url.join("aggregate").unwrap(),
-        )
-        .await
+        (
+            leader_state,
+            self.leader_authorized_req(
+                task_id,
+                &task_config,
+                Some(&agg_job_id),
+                DapMediaType::AggregationJobInitReq,
+                agg_job_init_req,
+                task_config.helper_url.join("aggregate").unwrap(),
+            )
+            .await,
+        )
     }
 
     pub async fn gen_test_agg_job_cont_req_with_round(
@@ -819,7 +822,7 @@ mod test {
     async fn handle_agg_job_init_req_unauthorized_request(version: DapVersion) {
         let t = Test::new(version);
         let report = t.gen_test_report(&t.time_interval_task_id).await;
-        let mut req = t
+        let (_, mut req) = t
             .gen_test_agg_job_init_req(&t.time_interval_task_id, version, vec![report])
             .await;
         req.sender_auth = None;
@@ -1058,7 +1061,7 @@ mod test {
 
         let mut report = t.gen_test_report(task_id).await;
         report.encrypted_input_shares[1].payload[0] ^= 0xff; // Cause decryption to fail
-        let req = t
+        let (_, req) = t
             .gen_test_agg_job_init_req(task_id, version, vec![report])
             .await;
@@ -1083,7 +1086,7 @@ mod test {
         let task_id = &t.time_interval_task_id;
 
         let report = t.gen_test_report(task_id).await;
-        let req = t
+        let (_, req) = t
             .gen_test_agg_job_init_req(task_id, version, vec![report])
             .await;
@@ -1105,7 +1108,7 @@ mod test {
         let task_id = &t.time_interval_task_id;
 
         let report = t.gen_test_report(task_id).await;
-        let req = t
+        let (leader_state, req) = t
             .gen_test_agg_job_init_req(task_id, version, vec![report.clone()])
             .await;
@@ -1126,17 +1129,57 @@ mod test {
             &t.helper.handle_agg_job_req(&req).await.unwrap().payload,
         )
         .unwrap();
-        let transition = &agg_job_resp.transitions[0];
+        let transitions = if version == DapVersion::Draft02 {
+            // In draft02, replays are only detected during the continue phase.
+            let agg_job_id =
+                AggregationJobInitReq::get_decoded_with_param(&DapVersion::Draft02, &req.payload)
+                    .unwrap()
+                    .draft02_agg_job_id
+                    .unwrap();
+            let task_config = t.leader.unchecked_get_task_config(task_id).await;
+            let transition = task_config
+                .vdaf
+                .handle_agg_job_resp(
+                    task_id,
+                    &task_config,
+                    &MetaAggregationJobId::Draft02(agg_job_id),
+                    leader_state,
+                    agg_job_resp,
+                    t.leader.metrics(),
+                )
+                .unwrap();
+            let DapLeaderAggregationJobTransition::Uncommitted(
+                _,
+                AggregationJobContinueReq { transitions, .. },
+            ) = transition
+            else {
+                panic!("expected uncommitted transition, was {transition:?}");
+            };
+            let req = t
+                .gen_test_agg_job_cont_req(
+                    &MetaAggregationJobId::Draft02(agg_job_id),
+                    transitions,
+                    version,
+                )
+                .await;
+            AggregationJobResp::get_decoded(
+                &t.helper.handle_agg_job_req(&req).await.unwrap().payload,
+            )
+            .unwrap()
+            .transitions
+        } else {
+            agg_job_resp.transitions
+        };
 
         // Expect failure due to report store marked as collected.
         assert_matches!(
-            transition.var,
+            transitions[0].var,
            TransitionVar::Failed(TransitionFailure::ReportReplayed)
        );
 
         assert_metrics_include!(t.helper_registry, {
             r#"report_counter{env="test_helper",host="helper.org",status="rejected_report_replayed"}"#: 1,
-            r#"inbound_request_counter{env="test_helper",host="helper.org",type="aggregate"}"#: 1,
+            r#"inbound_request_counter{env="test_helper",host="helper.org",type="aggregate"}"#: if version == DapVersion::Draft02 { 2 } else { 1 },
             r#"aggregation_job_counter{env="test_helper",host="helper.org",status="started"}"#: 1,
         });
     }
@@ -1149,7 +1192,7 @@ mod test {
         let task_id = &t.time_interval_task_id;
 
         let report = t.gen_test_report(task_id).await;
-        let req = t
+        let (_, req) = t
             .gen_test_agg_job_init_req(task_id, version, vec![report])
             .await;
@@ -1204,7 +1247,7 @@ mod test {
         let task_id = &t.time_interval_task_id;
 
         let report = t.gen_test_report(task_id).await;
-        let req = t
+        let (_, req) = t
             .gen_test_agg_job_init_req(task_id, DapVersion::Draft02, vec![report])
             .await;