Skip to content

Commit

Permalink
Rework report replayed test
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Nov 30, 2023
1 parent 15f5930 commit 8231f9e
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 64 deletions.
12 changes: 7 additions & 5 deletions daphne/src/roles/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ use crate::{
/// job.
#[async_trait(?Send)]
pub trait DapReportInitializer {
/// Initialize a sequence of reports that are in the "consumed" state by performing the early
/// validation steps (belongs to a batch that has been collected) and initializing VDAF
/// Initialize a sequence of reports that are in the "consumed" state by initializing VDAF
/// preparation.
async fn initialize_reports<'req>(
&self,
Expand Down Expand Up @@ -124,9 +123,12 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
///
/// A span with the same buckets as the input `agg_share_span` where the value is one of 3
/// possible sets of values:
/// - `Ok(None)` if all went well and no reports were replays.
/// - `Ok(Some(set))` if at least one report was a replay. This also means no aggregate shares where merged.
/// - `Err(err)` if an error occurred.
/// - `Ok(())` if all went well and no reports were replays.
/// - `Err(MergeAggShareError::ReplaysDetected)` if at least one report was a replay. This also
/// means no aggregate shares were merged.
/// - `Err(MergeAggShareError::AlreadyCollected)` if this span belongs to an aggregate share that
/// has already been collected.
/// - `Err(MergeAggShareError::Other)` if another unrecoverable error occurred.
async fn try_put_agg_share_span(
&self,
task_id: &TaskId,
Expand Down
16 changes: 8 additions & 8 deletions daphne/src/roles/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,6 @@ pub trait DapHelper<S>: DapAggregator<S> {
})
.await?;

for transition in &agg_job_resp.transitions {
if let TransitionVar::Failed(failure) = &transition.var {
metrics.report_inc_by(&format!("rejected_{failure}"), 1);
}
}

let out_shares_count = agg_job_resp
.transitions
.iter()
Expand Down Expand Up @@ -518,7 +512,7 @@ async fn finish_agg_job_and_aggregate<S>(
}));
inc_restart_metric.call_once(|| metrics.agg_job_put_span_retry_inc());
}
// This bucket belongs to a collected aggregate share.
// This bucket is contained by an aggregate share span that has been collected.
(Err(MergeAggShareError::AlreadyCollected), reports) => {
report_status.extend(reports.into_iter().map(|(report_id, _)| {
(
Expand Down Expand Up @@ -550,6 +544,12 @@ async fn finish_agg_job_and_aggregate<S>(
.expect("usize to fit in u64");
metrics.report_inc_by("aggregated", out_shares_count);

for transition in &agg_job_resp.transitions {
if let TransitionVar::Failed(failure) = &transition.var {
metrics.report_inc_by(&format!("rejected_{failure}"), 1);
}
}

return Ok(agg_job_resp);
}
}
Expand Down Expand Up @@ -592,7 +592,7 @@ mod tests {
.map(|r| r.report_metadata.id)
.collect::<Vec<_>>();

let req = test
let (_, req) = test
.gen_test_agg_job_init_req(&task_id, DapVersion::Draft02, reports)
.await;

Expand Down
32 changes: 20 additions & 12 deletions daphne/src/roles/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,31 +435,39 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {

let out_shares_count = agg_span.report_count() as u64;

// At this point we're committed to aggregating the reports: if we do detect a report was
// replayed at this stage, then we may end up with a batch mismatch. However, this should
// only happen if there are multiple aggregation jobs in-flight that include the same
// report.
let replayed = self
// At this point we're committed to aggregating the reports: if we do detect an error (a
// report was replayed at this stage or the span overlaps with a collected batch), then we
// may end up with a batch mismatch. However, this should only happen if there are multiple
// aggregation jobs in-flight that include the same report.
let (replayed, collected) = self
.try_put_agg_share_span(task_id, task_config, agg_span)
.await
.into_iter()
.map(|(_bucket, (result, _report_metadata))| match result {
Ok(()) => Ok(0),
Err(MergeAggShareError::AlreadyCollected) => {
panic!("aggregated to a collected agg share")
}
Err(MergeAggShareError::ReplaysDetected(replays)) => Ok(replays.len()),
Ok(()) => Ok((0, 0)),
Err(MergeAggShareError::AlreadyCollected) => Ok((0, 1)),
Err(MergeAggShareError::ReplaysDetected(replays)) => Ok((replays.len(), 0)),
Err(MergeAggShareError::Other(e)) => Err(e),
})
.sum::<Result<usize, _>>()?;
.try_fold((0, 0), |(replayed, collected), rc| {
let (r, c) = rc?;
Ok::<_, DapError>((replayed + r, collected + c))
})?;

if replayed > 0 {
tracing::warn!(
tracing::error!(
replay_count = replayed,
"tried to aggregate replayed reports"
);
}

if collected > 0 {
tracing::error!(
collected_count = collected,
"tried to aggregate reports belonging to collected spans"
);
}

metrics.report_inc_by("aggregated", out_shares_count);
Ok(out_shares_count)
}
Expand Down
125 changes: 106 additions & 19 deletions daphne/src/roles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,14 @@ mod test {
test_versions,
testing::{AggStore, MockAggregator, MockAggregatorReportSelector},
vdaf::VdafVerifyKey,
DapAbort, DapAggregateShare, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig,
DapLeaderAggregationJobTransition, DapMeasurement, DapQueryConfig, DapRequest, DapResource,
DapTaskConfig, DapVersion, MetaAggregationJobId, Prio3Config, VdafConfig,
DapAbort, DapAggregateShare, DapAggregationJobState, DapBatchBucket, DapCollectJob,
DapError, DapGlobalConfig, DapLeaderAggregationJobTransition, DapMeasurement,
DapQueryConfig, DapRequest, DapResource, DapTaskConfig, DapVersion, MetaAggregationJobId,
Prio3Config, VdafConfig,
};
use assert_matches::assert_matches;
use matchit::Router;
use prio::codec::{Decode, ParameterizedEncode};
use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode};
use rand::{thread_rng, Rng};
use std::{collections::HashMap, sync::Arc, time::SystemTime, vec};
use url::Url;
Expand Down Expand Up @@ -497,7 +498,7 @@ mod test {
task_id: &TaskId,
version: DapVersion,
reports: Vec<Report>,
) -> DapRequest<BearerToken> {
) -> (DapAggregationJobState, DapRequest<BearerToken>) {
let mut rng = thread_rng();
let task_config = self.leader.unchecked_get_task_config(task_id).await;
let part_batch_sel = match task_config.query {
Expand All @@ -509,7 +510,7 @@ mod test {

let agg_job_id = MetaAggregationJobId::gen_for_version(version);

let DapLeaderAggregationJobTransition::Continued(_leader_state, agg_job_init_req) =
let DapLeaderAggregationJobTransition::Continued(leader_state, agg_job_init_req) =
task_config
.vdaf
.produce_agg_job_init_req(
Expand All @@ -528,15 +529,18 @@ mod test {
panic!("unexpected transition");
};

self.leader_authorized_req(
task_id,
&task_config,
Some(&agg_job_id),
DapMediaType::AggregationJobInitReq,
agg_job_init_req,
task_config.helper_url.join("aggregate").unwrap(),
(
leader_state,
self.leader_authorized_req(
task_id,
&task_config,
Some(&agg_job_id),
DapMediaType::AggregationJobInitReq,
agg_job_init_req,
task_config.helper_url.join("aggregate").unwrap(),
)
.await,
)
.await
}

pub async fn gen_test_agg_job_cont_req_with_round(
Expand Down Expand Up @@ -818,7 +822,7 @@ mod test {
async fn handle_agg_job_init_req_unauthorized_request(version: DapVersion) {
let t = Test::new(version);
let report = t.gen_test_report(&t.time_interval_task_id).await;
let mut req = t
let (_, mut req) = t
.gen_test_agg_job_init_req(&t.time_interval_task_id, version, vec![report])
.await;
req.sender_auth = None;
Expand Down Expand Up @@ -1057,7 +1061,7 @@ mod test {

let mut report = t.gen_test_report(task_id).await;
report.encrypted_input_shares[1].payload[0] ^= 0xff; // Cause decryption to fail
let req = t
let (_, req) = t
.gen_test_agg_job_init_req(task_id, version, vec![report])
.await;

Expand All @@ -1082,7 +1086,7 @@ mod test {
let task_id = &t.time_interval_task_id;

let report = t.gen_test_report(task_id).await;
let req = t
let (_, req) = t
.gen_test_agg_job_init_req(task_id, version, vec![report])
.await;

Expand All @@ -1099,13 +1103,96 @@ mod test {

async_test_versions! { handle_agg_job_req_transition_continue }

/// Checks that the helper rejects a report whose ID is already present in the `processed` set of
/// its report store: the first transition in the helper's response must be
/// `TransitionFailure::ReportReplayed`, and the `rejected_report_replayed` report counter must
/// read 1.
async fn handle_agg_job_req_failure_report_replayed(version: DapVersion) {
let t = Test::new(version);
let task_id = &t.time_interval_task_id;

let report = t.gen_test_report(task_id).await;
let (leader_state, req) = t
.gen_test_agg_job_init_req(task_id, version, vec![report.clone()])
.await;

// Record the report's ID as already processed in the helper's report store. This is done in a
// new scope so that the lock on the report store is released before running the test.
{
let mut guard = t
.helper
.report_store
.lock()
.expect("report_store: failed to lock");
let report_store = guard.entry(*task_id).or_default();
report_store.processed.insert(report.report_metadata.id);
}

// Send the init request to the helper and decode the AggregationJobResp from the response
// payload.
let agg_job_resp = AggregationJobResp::get_decoded(
&t.helper.handle_agg_job_req(&req).await.unwrap().payload,
)
.unwrap();
let transitions = if version == DapVersion::Draft02 {
// In Draft02 replays are only detected later, so the aggregation job is driven through a
// continue request before the transitions are inspected.
// Recover the aggregation job ID from the init request payload.
let agg_job_id =
AggregationJobInitReq::get_decoded_with_param(&DapVersion::Draft02, &req.payload)
.unwrap()
.draft02_agg_job_id
.unwrap();
let task_config = t.leader.unchecked_get_task_config(task_id).await;
// Feed the helper's response to the leader side to obtain the transitions for the continue
// request.
let transition = task_config
.vdaf
.handle_agg_job_resp(
task_id,
&task_config,
&MetaAggregationJobId::Draft02(agg_job_id),
leader_state,
agg_job_resp,
t.leader.metrics(),
)
.unwrap();
let DapLeaderAggregationJobTransition::Uncommitted(
_,
AggregationJobContinueReq { transitions, .. },
) = transition
else {
panic!("expected uncommitted transition, was {transition:?}");
};
// Send the continue request to the helper and use the transitions from its response.
let req = t
.gen_test_agg_job_cont_req(
&MetaAggregationJobId::Draft02(agg_job_id),
transitions,
version,
)
.await;
AggregationJobResp::get_decoded(
&t.helper.handle_agg_job_req(&req).await.unwrap().payload,
)
.unwrap()
.transitions
} else {
agg_job_resp.transitions
};

// Expect failure because the report ID was already marked as processed in the helper's report
// store.
assert_matches!(
transitions[0].var,
TransitionVar::Failed(TransitionFailure::ReportReplayed)
);

// Draft02 sends two requests to the helper (init and continue), hence the inbound request
// count of 2 for that version and 1 otherwise.
assert_metrics_include!(t.helper_registry, {
r#"report_counter{env="test_helper",host="helper.org",status="rejected_report_replayed"}"#: 1,
r#"inbound_request_counter{env="test_helper",host="helper.org",type="aggregate"}"#: if version == DapVersion::Draft02 { 2 } else { 1 },
r#"aggregation_job_counter{env="test_helper",host="helper.org",status="started"}"#: 1,
});
}

async_test_versions! { handle_agg_job_req_failure_report_replayed }

async fn handle_agg_job_req_failure_batch_collected(version: DapVersion) {
let t = Test::new(version);
let task_id = &t.time_interval_task_id;
let task_config = t.helper.unchecked_get_task_config(task_id).await;

let report = t.gen_test_report(task_id).await;
let req = t
let (_, req) = t
.gen_test_agg_job_init_req(task_id, version, vec![report])
.await;

Expand Down Expand Up @@ -1160,7 +1247,7 @@ mod test {
let task_id = &t.time_interval_task_id;

let report = t.gen_test_report(task_id).await;
let req = t
let (_, req) = t
.gen_test_agg_job_init_req(task_id, DapVersion::Draft02, vec![report])
.await;

Expand Down
1 change: 0 additions & 1 deletion daphne/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,6 @@ impl MockAggregator {

/// Conducts checks on a received report to see whether:
/// 1) the report falls into a batch that has been already collected, or
/// 2) the report has been submitted by the client in the past.
fn check_report_has_been_collected(
&self,
task_id: &TaskId,
Expand Down
20 changes: 1 addition & 19 deletions daphne/src/vdaf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,6 @@ impl VdafConfig {
report_status,
initialized_reports,
agg_job_init_req,
metrics,
),
}
}
Expand Down Expand Up @@ -1034,7 +1033,6 @@ impl VdafConfig {
report_status: &HashMap<ReportId, ReportProcessedStatus>,
initialized_reports: &[EarlyReportStateInitialized<'_>],
agg_job_init_req: &AggregationJobInitReq,
metrics: &DaphneMetrics,
) -> Result<DapHelperAggregationJobTransition<AggregationJobResp>, DapError> {
let num_reports = agg_job_init_req.prep_inits.len();
let mut agg_span = DapAggregateSpan::default();
Expand Down Expand Up @@ -1106,7 +1104,6 @@ impl VdafConfig {

Err(VdafError::Codec(..) | VdafError::Vdaf(..)) => {
let failure = TransitionFailure::VdafPrepError;
metrics.report_inc_by(&format!("rejected_{failure}"), 1);
TransitionVar::Failed(failure)
}
}
Expand All @@ -1115,10 +1112,7 @@ impl VdafConfig {
EarlyReportStateInitialized::Rejected {
metadata: _,
failure,
} => {
metrics.report_inc_by(&format!("rejected_{failure}"), 1);
TransitionVar::Failed(*failure)
}
} => TransitionVar::Failed(*failure),
},
};

Expand Down Expand Up @@ -2094,10 +2088,6 @@ mod test {
agg_job_resp.transitions[0].var,
TransitionVar::Failed(TransitionFailure::HpkeDecryptError)
);

assert_metrics_include!(t.helper_registry, {
r#"report_counter{env="test_helper",host="helper.org",status="rejected_hpke_decrypt_error"}"#: 1,
});
}

async_test_versions! { handle_agg_job_init_req_hpke_decrypt_err }
Expand All @@ -2123,10 +2113,6 @@ mod test {
agg_job_resp.transitions[0].var,
TransitionVar::Failed(TransitionFailure::HpkeUnknownConfigId)
);

assert_metrics_include!(t.helper_registry, {
r#"report_counter{env="test_helper",host="helper.org",status="rejected_hpke_unknown_config_id"}"#: 1,
});
}

async_test_versions! { handle_agg_job_init_req_hpke_unknown_config_id }
Expand Down Expand Up @@ -2177,10 +2163,6 @@ mod test {
agg_job_resp.transitions[1].var,
TransitionVar::Failed(TransitionFailure::VdafPrepError)
);

assert_metrics_include!(t.helper_registry, {
r#"report_counter{env="test_helper",host="helper.org",status="rejected_vdaf_prep_error"}"#: 2,
});
}

async_test_versions! { handle_agg_job_init_req_vdaf_prep_error }
Expand Down

0 comments on commit 8231f9e

Please sign in to comment.