diff --git a/daphne/src/roles/aggregator.rs b/daphne/src/roles/aggregator.rs
index cb84bb1dc..3fe391c05 100644
--- a/daphne/src/roles/aggregator.rs
+++ b/daphne/src/roles/aggregator.rs
@@ -25,8 +25,7 @@ use crate::{
 /// job.
 #[async_trait(?Send)]
 pub trait DapReportInitializer {
-    /// Initialize a sequence of reports that are in the "consumed" state by performing the early
-    /// validation steps (belongs to a batch that has been collected) and initializing VDAF
+    /// Initialize a sequence of reports that are in the "consumed" state by initializing VDAF
     /// preparation.
     async fn initialize_reports<'req>(
         &self,
@@ -124,9 +123,12 @@ pub trait DapAggregator: HpkeDecrypter + DapReportInitializer + Sized {
     ///
-    /// A span with the same buckets as the input `agg_share_span` where the value is one of 3
-    /// possible sets of values:
-    /// - `Ok(None)` if all went well and no reports were replays.
-    /// - `Ok(Some(set))` if at least one report was a replay. This also means no aggregate shares where merged.
-    /// - `Err(err)` if an error occurred.
+    /// A span with the same buckets as the input `agg_share_span` where the value is one of the
+    /// following:
+    /// - `Ok(())` if all went well and no reports were replays.
+    /// - `Err(MergeAggShareError::ReplaysDetected)` if at least one report was a replay. This also
+    ///   means no aggregate shares were merged.
+    /// - `Err(MergeAggShareError::AlreadyCollected)` if this span belongs to an aggregate share
+    ///   that has already been collected.
+    /// - `Err(MergeAggShareError::Other)` if another unrecoverable error occurred.
     async fn try_put_agg_share_span(
         &self,
         task_id: &TaskId,
diff --git a/daphne/src/roles/helper.rs b/daphne/src/roles/helper.rs
index 2435eaebc..4487e61db 100644
--- a/daphne/src/roles/helper.rs
+++ b/daphne/src/roles/helper.rs
@@ -184,6 +184,12 @@ pub trait DapHelper: DapAggregator {
         )
         .await?;
 
+        for transition in &agg_job_resp.transitions {
+            if let TransitionVar::Failed(failure) = &transition.var {
+                metrics.report_inc_by(&format!("rejected_{failure}"), 1);
+            }
+        }
+
         metrics.agg_job_started_inc();
         metrics.agg_job_completed_inc();
         agg_job_resp
@@ -518,7 +524,7 @@ async fn finish_agg_job_and_aggregate(
                 }));
                 inc_restart_metric.call_once(|| metrics.agg_job_put_span_retry_inc());
             }
-            // This bucket belongs to a collected aggregate share.
+            // This bucket is contained in an aggregate share span that has been collected.
             (Err(MergeAggShareError::AlreadyCollected), reports) => {
                 report_status.extend(reports.into_iter().map(|(report_id, _)| {
                     (
@@ -592,7 +598,7 @@ mod tests {
             .map(|r| r.report_metadata.id)
             .collect::<Vec<_>>();
 
-        let req = test
+        let (_, req) = test
             .gen_test_agg_job_init_req(&task_id, DapVersion::Draft02, reports)
             .await;
diff --git a/daphne/src/roles/leader.rs b/daphne/src/roles/leader.rs
index fe9fec314..608cc0d3e 100644
--- a/daphne/src/roles/leader.rs
+++ b/daphne/src/roles/leader.rs
@@ -435,31 +435,39 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator {
 
         let out_shares_count = agg_span.report_count() as u64;
 
-        // At this point we're committed to aggregating the reports: if we do detect a report was
-        // replayed at this stage, then we may end up with a batch mismatch. However, this should
-        // only happen if there are multiple aggregation jobs in-flight that include the same
-        // report.
-        let replayed = self
+        // At this point we're committed to aggregating the reports: if we do detect an error (a
+        // report was replayed at this stage or the span overlaps with a collected batch), then we
+        // may end up with a batch mismatch. However, this should only happen if there are multiple
+        // aggregation jobs in-flight that include the same report.
+        let (replayed, collected) = self
             .try_put_agg_share_span(task_id, task_config, agg_span)
             .await
             .into_iter()
             .map(|(_bucket, (result, _report_metadata))| match result {
-                Ok(()) => Ok(0),
-                Err(MergeAggShareError::AlreadyCollected) => {
-                    panic!("aggregated to a collected agg share")
-                }
-                Err(MergeAggShareError::ReplaysDetected(replays)) => Ok(replays.len()),
+                Ok(()) => Ok((0, 0)),
+                Err(MergeAggShareError::AlreadyCollected) => Ok((0, 1)),
+                Err(MergeAggShareError::ReplaysDetected(replays)) => Ok((replays.len(), 0)),
                 Err(MergeAggShareError::Other(e)) => Err(e),
             })
-            .sum::<Result<usize, DapError>>()?;
+            .try_fold((0, 0), |(replayed, collected), rc| {
+                let (r, c) = rc?;
+                Ok::<_, DapError>((replayed + r, collected + c))
+            })?;
 
         if replayed > 0 {
-            tracing::warn!(
+            tracing::error!(
                 replay_count = replayed,
                 "tried to aggregate replayed reports"
             );
         }
 
+        if collected > 0 {
+            tracing::error!(
+                collected_count = collected,
+                "tried to aggregate reports belonging to collected spans"
+            );
+        }
+
         metrics.report_inc_by("aggregated", out_shares_count);
 
         Ok(out_shares_count)
     }
diff --git a/daphne/src/roles/mod.rs b/daphne/src/roles/mod.rs
index 9bd1153ee..638383236 100644
--- a/daphne/src/roles/mod.rs
+++ b/daphne/src/roles/mod.rs
@@ -170,13 +170,14 @@ mod test {
         test_versions,
         testing::{AggStore, MockAggregator, MockAggregatorReportSelector},
         vdaf::VdafVerifyKey,
-        DapAbort, DapAggregateShare, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig,
-        DapLeaderAggregationJobTransition, DapMeasurement, DapQueryConfig, DapRequest, DapResource,
-        DapTaskConfig, DapVersion, MetaAggregationJobId, Prio3Config, VdafConfig,
+        DapAbort, DapAggregateShare, DapAggregationJobState, DapBatchBucket, DapCollectJob,
+        DapError, DapGlobalConfig, DapLeaderAggregationJobTransition, DapMeasurement,
+        DapQueryConfig, DapRequest, DapResource, DapTaskConfig, DapVersion, MetaAggregationJobId,
+        Prio3Config, VdafConfig,
     };
     use assert_matches::assert_matches;
     use matchit::Router;
-    use prio::codec::{Decode, ParameterizedEncode};
+    use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode};
     use rand::{thread_rng, Rng};
     use std::{collections::HashMap, sync::Arc, time::SystemTime, vec};
     use url::Url;
@@ -497,7 +498,7 @@ mod test {
         task_id: &TaskId,
         version: DapVersion,
         reports: Vec<Report>,
-    ) -> DapRequest<BearerToken> {
+    ) -> (DapAggregationJobState, DapRequest<BearerToken>) {
         let mut rng = thread_rng();
         let task_config = self.leader.unchecked_get_task_config(task_id).await;
         let part_batch_sel = match task_config.query {
@@ -509,7 +510,7 @@ mod test {
 
         let agg_job_id = MetaAggregationJobId::gen_for_version(version);
 
-        let DapLeaderAggregationJobTransition::Continued(_leader_state, agg_job_init_req) =
+        let DapLeaderAggregationJobTransition::Continued(leader_state, agg_job_init_req) =
             task_config
                 .vdaf
                 .produce_agg_job_init_req(
@@ -528,15 +529,18 @@ mod test {
             panic!("unexpected transition");
         };
 
-        self.leader_authorized_req(
-            task_id,
-            &task_config,
-            Some(&agg_job_id),
-            DapMediaType::AggregationJobInitReq,
-            agg_job_init_req,
-            task_config.helper_url.join("aggregate").unwrap(),
-        )
-        .await
+        (
+            leader_state,
+            self.leader_authorized_req(
+                task_id,
+                &task_config,
+                Some(&agg_job_id),
+                DapMediaType::AggregationJobInitReq,
+                agg_job_init_req,
+                task_config.helper_url.join("aggregate").unwrap(),
+            )
+            .await,
+        )
     }
 
     pub async fn gen_test_agg_job_cont_req_with_round(
@@ -818,7 +822,7 @@ mod test {
     async fn handle_agg_job_init_req_unauthorized_request(version: DapVersion) {
         let t = Test::new(version);
         let report = t.gen_test_report(&t.time_interval_task_id).await;
-        let mut req = t
+        let (_, mut req) = t
             .gen_test_agg_job_init_req(&t.time_interval_task_id, version, vec![report])
             .await;
         req.sender_auth = None;
@@ -1057,7 +1061,7 @@ mod test {
 
         let mut report = t.gen_test_report(task_id).await;
         report.encrypted_input_shares[1].payload[0] ^= 0xff; // Cause decryption to fail
-        let req = t
+        let (_, req) = t
             .gen_test_agg_job_init_req(task_id, version, vec![report])
             .await;
@@ -1082,7 +1086,7 @@ mod test {
         let task_id = &t.time_interval_task_id;
 
         let report = t.gen_test_report(task_id).await;
-        let req = t
+        let (_, req) = t
             .gen_test_agg_job_init_req(task_id, version, vec![report])
             .await;
@@ -1099,13 +1103,96 @@ mod test {
 
     async_test_versions! { handle_agg_job_req_transition_continue }
 
+    async fn handle_agg_job_req_failure_report_replayed(version: DapVersion) {
+        let t = Test::new(version);
+        let task_id = &t.time_interval_task_id;
+
+        let report = t.gen_test_report(task_id).await;
+        let (leader_state, req) = t
+            .gen_test_agg_job_init_req(task_id, version, vec![report.clone()])
+            .await;
+
+        // Mark the report as already processed in the report store backend so that the helper
+        // sees it as a replay. This is done in a new scope so that the lock on the report store
+        // is released before running the test.
+        {
+            let mut guard = t
+                .helper
+                .report_store
+                .lock()
+                .expect("report_store: failed to lock");
+            let report_store = guard.entry(*task_id).or_default();
+            report_store.processed.insert(report.report_metadata.id);
+        }
+
+        // Get the AggregationJobResp, then extract the transition data from inside.
+        let agg_job_resp = AggregationJobResp::get_decoded(
+            &t.helper.handle_agg_job_req(&req).await.unwrap().payload,
+        )
+        .unwrap();
+        let transitions = if version == DapVersion::Draft02 {
+            // In draft02, replays are only detected during the continuation phase.
+            let agg_job_id =
+                AggregationJobInitReq::get_decoded_with_param(&DapVersion::Draft02, &req.payload)
+                    .unwrap()
+                    .draft02_agg_job_id
+                    .unwrap();
+            let task_config = t.leader.unchecked_get_task_config(task_id).await;
+            let transition = task_config
+                .vdaf
+                .handle_agg_job_resp(
+                    task_id,
+                    &task_config,
+                    &MetaAggregationJobId::Draft02(agg_job_id),
+                    leader_state,
+                    agg_job_resp,
+                    t.leader.metrics(),
+                )
+                .unwrap();
+            let DapLeaderAggregationJobTransition::Uncommitted(
+                _,
+                AggregationJobContinueReq { transitions, .. },
+            ) = transition
+            else {
+                panic!("expected uncommitted transition, was {transition:?}");
+            };
+            let req = t
+                .gen_test_agg_job_cont_req(
+                    &MetaAggregationJobId::Draft02(agg_job_id),
+                    transitions,
+                    version,
+                )
+                .await;
+            AggregationJobResp::get_decoded(
+                &t.helper.handle_agg_job_req(&req).await.unwrap().payload,
+            )
+            .unwrap()
+            .transitions
+        } else {
+            agg_job_resp.transitions
+        };
+
+        // Expect the report to be rejected as a replay.
+        assert_matches!(
+            transitions[0].var,
+            TransitionVar::Failed(TransitionFailure::ReportReplayed)
+        );
+
+        assert_metrics_include!(t.helper_registry, {
+            r#"report_counter{env="test_helper",host="helper.org",status="rejected_report_replayed"}"#: 1,
+            r#"inbound_request_counter{env="test_helper",host="helper.org",type="aggregate"}"#: if version == DapVersion::Draft02 { 2 } else { 1 },
+            r#"aggregation_job_counter{env="test_helper",host="helper.org",status="started"}"#: 1,
+        });
+    }
+
+    async_test_versions! { handle_agg_job_req_failure_report_replayed }
+
     async fn handle_agg_job_req_failure_batch_collected(version: DapVersion) {
         let t = Test::new(version);
         let task_id = &t.time_interval_task_id;
         let task_config = t.helper.unchecked_get_task_config(task_id).await;
 
         let report = t.gen_test_report(task_id).await;
-        let req = t
+        let (_, req) = t
             .gen_test_agg_job_init_req(task_id, version, vec![report])
             .await;
@@ -1160,7 +1247,7 @@ mod test {
         let task_id = &t.time_interval_task_id;
 
         let report = t.gen_test_report(task_id).await;
-        let req = t
+        let (_, req) = t
             .gen_test_agg_job_init_req(task_id, DapVersion::Draft02, vec![report])
             .await;
diff --git a/daphne/src/testing.rs b/daphne/src/testing.rs
index 4323b574e..4e30028cd 100644
--- a/daphne/src/testing.rs
+++ b/daphne/src/testing.rs
@@ -723,7 +723,5 @@ impl MockAggregator {
 
-    /// Conducts checks on a received report to see whether:
-    /// 1) the report falls into a batch that has been already collected, or
-    /// 2) the report has been submitted by the client in the past.
+    /// Checks whether a received report falls into a batch that has already been collected.
     fn check_report_has_been_collected(
diff --git a/daphne/src/vdaf/mod.rs b/daphne/src/vdaf/mod.rs
index c1b2624cb..9616fbf41 100644
--- a/daphne/src/vdaf/mod.rs
+++ b/daphne/src/vdaf/mod.rs
@@ -964,7 +964,6 @@ impl VdafConfig {
                 report_status,
                 initialized_reports,
                 agg_job_init_req,
-                metrics,
             ),
         }
     }
@@ -1034,7 +1033,6 @@ impl VdafConfig {
         report_status: &HashMap<ReportId, ReportProcessedStatus>,
         initialized_reports: &[EarlyReportStateInitialized<'_>],
         agg_job_init_req: &AggregationJobInitReq,
-        metrics: &DaphneMetrics,
     ) -> Result<DapAggregateSpan<DapAggregateShare>, DapError> {
         let num_reports = agg_job_init_req.prep_inits.len();
         let mut agg_span = DapAggregateSpan::default();
@@ -1106,7 +1104,6 @@ impl VdafConfig {
                     Err(VdafError::Codec(..) | VdafError::Vdaf(..)) => {
                         let failure = TransitionFailure::VdafPrepError;
-                        metrics.report_inc_by(&format!("rejected_{failure}"), 1);
                         TransitionVar::Failed(failure)
                     }
                 }
@@ -1116,7 +1113,6 @@ impl VdafConfig {
                 metadata: _,
                 failure,
             } => {
-                metrics.report_inc_by(&format!("rejected_{failure}"), 1);
                 TransitionVar::Failed(*failure)
             }
         },
@@ -2094,10 +2090,6 @@ mod test {
             agg_job_resp.transitions[0].var,
             TransitionVar::Failed(TransitionFailure::HpkeDecryptError)
         );
-
-        assert_metrics_include!(t.helper_registry, {
-            r#"report_counter{env="test_helper",host="helper.org",status="rejected_hpke_decrypt_error"}"#: 1,
-        });
     }
 
     async_test_versions! { handle_agg_job_init_req_hpke_decrypt_err }
@@ -2123,10 +2115,6 @@ mod test {
             agg_job_resp.transitions[0].var,
             TransitionVar::Failed(TransitionFailure::HpkeUnknownConfigId)
         );
-
-        assert_metrics_include!(t.helper_registry, {
-            r#"report_counter{env="test_helper",host="helper.org",status="rejected_hpke_unknown_config_id"}"#: 1,
-        });
     }
 
     async_test_versions! { handle_agg_job_init_req_hpke_unknown_config_id }
@@ -2177,10 +2165,6 @@ mod test {
             agg_job_resp.transitions[1].var,
             TransitionVar::Failed(TransitionFailure::VdafPrepError)
        );
-
-        assert_metrics_include!(t.helper_registry, {
-            r#"report_counter{env="test_helper",host="helper.org",status="rejected_vdaf_prep_error"}"#: 2,
-        });
     }
 
     async_test_versions! { handle_agg_job_init_req_vdaf_prep_error }
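To make the revised `try_put_agg_share_span` contract concrete, the sketch below isolates the tallying pattern the leader now uses: per-bucket merge results are folded into `(replayed, collected)` counts, short-circuiting only on an unrecoverable error. This is a minimal, self-contained illustration with simplified stand-in types (report IDs are plain integers and `DapError` is collapsed to a single variant), not the real daphne definitions.

```rust
// Simplified stand-ins for daphne's error types (illustration only).
#[derive(Debug)]
enum DapError {
    #[allow(dead_code)]
    Fatal(String),
}

#[derive(Debug)]
enum MergeAggShareError {
    /// The span overlaps with an aggregate share that was already collected.
    AlreadyCollected,
    /// Some reports in this bucket were replays; nothing was merged.
    ReplaysDetected(Vec<u64>),
    /// An unrecoverable error.
    Other(DapError),
}

/// Fold per-bucket merge results into `(replayed, collected)` counts,
/// bailing out on the first unrecoverable error.
fn tally(
    results: impl IntoIterator<Item = Result<(), MergeAggShareError>>,
) -> Result<(usize, usize), DapError> {
    results
        .into_iter()
        .map(|result| match result {
            Ok(()) => Ok((0, 0)),
            Err(MergeAggShareError::AlreadyCollected) => Ok((0, 1)),
            Err(MergeAggShareError::ReplaysDetected(replays)) => Ok((replays.len(), 0)),
            Err(MergeAggShareError::Other(e)) => Err(e),
        })
        .try_fold((0, 0), |(replayed, collected), rc| {
            let (r, c) = rc?;
            Ok((replayed + r, collected + c))
        })
}

fn main() -> Result<(), DapError> {
    // Three buckets: one merged cleanly, one contained two replays, and one
    // was already collected. Neither recoverable failure aborts the fold.
    let (replayed, collected) = tally([
        Ok(()),
        Err(MergeAggShareError::ReplaysDetected(vec![1, 2])),
        Err(MergeAggShareError::AlreadyCollected),
    ])?;
    assert_eq!((replayed, collected), (2, 1));
    Ok(())
}
```

The design point mirrors the leader change above: where the old code panicked on `AlreadyCollected`, both recoverable failure kinds are now counted (and, in daphne, logged at error level) while the remaining buckets still merge.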