diff --git a/daphne/src/roles/leader.rs b/daphne/src/roles/leader.rs index bae42fded..a67560555 100644 --- a/daphne/src/roles/leader.rs +++ b/daphne/src/roles/leader.rs @@ -1,7 +1,7 @@ // Copyright (c) 2023 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use async_trait::async_trait; use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode}; @@ -15,7 +15,8 @@ use crate::{ fatal_error, messages::{ AggregateShare, AggregateShareReq, AggregationJobResp, BatchSelector, Collection, - CollectionJobId, CollectionReq, Interval, PartialBatchSelector, Query, Report, TaskId, + CollectionJobId, CollectionReq, Interval, PartialBatchSelector, Query, Report, ReportId, + TaskId, Time, }, metrics::DaphneRequestType, DapCollectJob, DapError, DapLeaderAggregationJobTransition, DapLeaderProcessTelemetry, @@ -98,6 +99,17 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator { /// Data type used to guide selection of a set of reports for aggregation. type ReportSelector; + /// Register these report ids as received by the leader. Returns all ids that have been + /// registered before. + async fn register_received_reports<I>( + &self, + task_config: &DapTaskConfig, + task_id: &TaskId, + reports: I, + ) -> Result<HashSet<ReportId>, DapError> + where + I: Iterator<Item = (ReportId, Time)>; + /// Store a report for use later on.
async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError>; @@ -327,7 +339,8 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator { metrics, ) .await?; - let (state, agg_job_init_req) = match transition { + + let (mut state, mut agg_job_init_req) = match transition { DapLeaderAggregationJobTransition::Continued(state, agg_job_init_req) => { (state, agg_job_init_req) } @@ -341,6 +354,20 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator { return Err(fatal_error!(err = "unexpected state transition (uncommitted)").into()) } }; + { + // register received reports and reject replays + let replays = self + .register_received_reports( + task_config, + task_id, + state.seq.iter().map(|r| (r.report_id, r.time)), + ) + .await?; + state.seq.retain(|r| !replays.contains(&r.report_id)); + agg_job_init_req + .prep_inits + .retain(|r| !replays.contains(&r.report_share.report_metadata.id)); + } let method = if task_config.version != DapVersion::Draft02 { LeaderHttpRequestMethod::Put } else { diff --git a/daphne/src/testing.rs b/daphne/src/testing.rs index 560f31e5d..9a2b3f171 100644 --- a/daphne/src/testing.rs +++ b/daphne/src/testing.rs @@ -710,11 +710,10 @@ impl MockAggregator { /// Conducts checks on a received report to see whether: /// 1) the report falls into a batch that has been already collected, or /// 2) the report has been submitted by the client in the past. - fn check_report_early_fail( + fn check_report_has_been_collected( &self, task_id: &TaskId, bucket: &DapBatchBucket, - id: &ReportId, ) -> Option<TransitionFailure> { // Check AggStateStore to see whether the report is part of a batch that has already // been collected. @@ -724,16 +723,6 @@ impl MockAggregator { return Some(TransitionFailure::BatchCollected); } - // Check whether the same report has been submitted in the past.
- let mut guard = self - .report_store - .lock() - .expect("report_store: failed to lock"); - let report_store = guard.entry(*task_id).or_default(); - if report_store.processed.contains(id) { - return Some(TransitionFailure::ReportReplayed); - } - None } @@ -937,7 +926,8 @@ impl DapReportInitializer for MockAggregator { for (bucket, ((), report_ids_and_time)) in span.iter() { for (id, _) in report_ids_and_time { // Check whether Report has been collected or replayed. - if let Some(transition_failure) = self.check_report_early_fail(task_id, bucket, id) + if let Some(transition_failure) = + self.check_report_has_been_collected(task_id, bucket) { early_fails.insert(*id, transition_failure); }; @@ -1223,27 +1213,47 @@ impl DapHelper for MockAggregator { impl DapLeader for MockAggregator { type ReportSelector = MockAggregatorReportSelector; + async fn register_received_reports<I>( + &self, + _task_config: &DapTaskConfig, + task_id: &TaskId, + reports: I, + ) -> Result<HashSet<ReportId>, DapError> + where + I: Iterator<Item = (ReportId, Time)>, + { + let mut report_store = self + .report_store + .lock() + .expect("report_store: failed to lock"); + Ok(reports + .into_iter() + .filter_map(|(id, _)| { + // Check whether the same report has been submitted in the past. + report_store + .entry(*task_id) + .or_default() + .processed + .contains(&id) + .then_some(id) + }) + .collect()) + } + async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError> { let bucket = self .assign_report_to_bucket(report, task_id) .await .expect("could not determine batch for report"); - // Check whether Report has been collected or replayed. - if let Some(transition_failure) = - self.check_report_early_fail(task_id, &bucket, &report.report_metadata.id) - { - return Err(DapError::Transition(transition_failure)); - }; - // Store Report for future processing.
let mut guard = self .report_store .lock() .expect("report_store: failed to lock"); let queue = guard - .get_mut(task_id) - .expect("report_store: unrecognized task") + .entry(*task_id) + .or_default() .pending .entry(bucket) .or_default(); diff --git a/daphne_worker/src/roles/leader.rs b/daphne_worker/src/roles/leader.rs index def60fa14..65f9c272b 100644 --- a/daphne_worker/src/roles/leader.rs +++ b/daphne_worker/src/roles/leader.rs @@ -21,8 +21,10 @@ use crate::{ PendingReport, ReportsPendingResult, DURABLE_REPORTS_PENDING_GET, DURABLE_REPORTS_PENDING_PUT, }, + reports_processed::DURABLE_REPORTS_INITIALIZED_REGISTER, BINDING_DAP_LEADER_AGG_JOB_QUEUE, BINDING_DAP_LEADER_BATCH_QUEUE, BINDING_DAP_LEADER_COL_JOB_QUEUE, BINDING_DAP_REPORTS_PENDING, + BINDING_DAP_REPORTS_PROCESSED, }, DaphneWorkerReportSelector, }; @@ -33,14 +35,15 @@ use daphne::{ error::DapAbort, fatal_error, messages::{ - Collection, CollectionJobId, CollectionReq, PartialBatchSelector, Report, TaskId, - TransitionFailure, + Collection, CollectionJobId, CollectionReq, PartialBatchSelector, Report, ReportId, TaskId, + Time, TransitionFailure, }, roles::{DapAuthorizedSender, DapLeader}, DapCollectJob, DapError, DapQueryConfig, DapRequest, DapResponse, DapTaskConfig, }; +use futures::{StreamExt, TryStreamExt}; use prio::codec::{ParameterizedDecode, ParameterizedEncode}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use tracing::debug; #[async_trait(?Send)] @@ -70,6 +73,50 @@ impl DapAuthorizedSender for DaphneWorker<'_> { impl<'srv> DapLeader for DaphneWorker<'srv> { type ReportSelector = DaphneWorkerReportSelector; + async fn register_received_reports<I>( + &self, + task_config: &DapTaskConfig, + task_id: &TaskId, + reports: I, + ) -> Result<HashSet<ReportId>, DapError> + where + I: Iterator<Item = (ReportId, Time)>, + { + let task_id_hex = task_id.to_hex(); + let durable = self.durable(); + let mut reports_processed_request_data = HashMap::new(); + for (id, time) in reports { + let durable_name =
self.config().durable_name_report_store( + task_config.as_ref(), + &task_id_hex, + &id, + time, + ); + reports_processed_request_data + .entry(durable_name) + .or_insert_with(Vec::new) + .push(id); + } + futures::stream::iter(reports_processed_request_data) + .map(|(durable_name, reports)| async { + durable + .post::<_, HashSet<ReportId>>( + BINDING_DAP_REPORTS_PROCESSED, + DURABLE_REPORTS_INITIALIZED_REGISTER, + durable_name, + reports, + ) + .await + }) + .buffer_unordered(usize::MAX) + .try_fold(HashSet::new(), |mut acc, replays| async { + acc.extend(replays); + Ok(acc) + }) + .await + .map_err(|e| fatal_error!(err = ?e, "checking for replayed reports")) + } + async fn put_report( + &self, + report: &Report,