diff --git a/daphne/src/roles/aggregator.rs b/daphne/src/roles/aggregator.rs
index 0e528883d..cb84bb1dc 100644
--- a/daphne/src/roles/aggregator.rs
+++ b/daphne/src/roles/aggregator.rs
@@ -38,6 +38,13 @@ pub trait DapReportInitializer {
     ) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError>;
 }
 
+#[derive(Debug)]
+pub enum MergeAggShareError {
+    AlreadyCollected,
+    ReplaysDetected(HashSet<ReportId>),
+    Other(DapError),
+}
+
 /// DAP Aggregator functionality.
 #[async_trait(?Send)]
 pub trait DapAggregator: HpkeDecrypter + DapReportInitializer + Sized {
@@ -125,7 +132,7 @@ pub trait DapAggregator: HpkeDecrypter + DapReportInitializer + Sized {
         task_id: &TaskId,
         task_config: &DapTaskConfig,
         agg_share_span: DapAggregateSpan<DapAggregateShare>,
-    ) -> DapAggregateSpan<Result<HashSet<ReportId>, DapError>>;
+    ) -> DapAggregateSpan<Result<(), MergeAggShareError>>;
 
     /// Fetch the aggregate share for the given batch.
     async fn get_agg_share(
diff --git a/daphne/src/roles/helper.rs b/daphne/src/roles/helper.rs
index 6d39a26aa..2435eaebc 100644
--- a/daphne/src/roles/helper.rs
+++ b/daphne/src/roles/helper.rs
@@ -19,6 +19,7 @@ use crate::{
         ReportId, TaskId, TransitionFailure, TransitionVar,
     },
     metrics::{DaphneMetrics, DaphneRequestType},
+    roles::aggregator::MergeAggShareError,
     vdaf::ReportProcessedStatus,
     DapAggregateShare, DapAggregateSpan, DapAggregationJobState, DapError,
     DapHelperAggregationJobTransition, DapRequest, DapResource, DapResponse, DapTaskConfig,
@@ -498,7 +499,7 @@ async fn finish_agg_job_and_aggregate(
     for (_bucket, result) in put_shares_result {
         match result {
             // This bucket had no replays.
-            (Ok(replays), reports) if replays.is_empty() => {
+            (Ok(()), reports) => {
                 // Every report in the bucket has been committed to aggregate storage.
                 report_status.extend(
                     reports.into_iter().map(|(report_id, _time)| {
@@ -507,7 +508,7 @@
                 );
             }
             // This bucket had replays.
-            (Ok(replays), _reports) => {
+            (Err(MergeAggShareError::ReplaysDetected(replays)), _reports) => {
                 // At least one report was replayed (no change to aggregate storage).
                 report_status.extend(replays.into_iter().map(|report_id| {
                     (
@@ -517,6 +518,16 @@
                 }));
                 inc_restart_metric.call_once(|| metrics.agg_job_put_span_retry_inc());
             }
+            // This bucket belongs to a collected aggregate share.
+            (Err(MergeAggShareError::AlreadyCollected), reports) => {
+                report_status.extend(reports.into_iter().map(|(report_id, _)| {
+                    (
+                        report_id,
+                        ReportProcessedStatus::Rejected(TransitionFailure::BatchCollected),
+                    )
+                }));
+                inc_restart_metric.call_once(|| metrics.agg_job_put_span_retry_inc());
+            }
             // If this happens, the leader and helper can possibly have inconsistent state.
             // The leader will still think all of the reports in this job have yet to be
             // aggregated. But we could have aggregated some and not others due to the
@@ -526,7 +537,7 @@
             // and if this error doesn't manifest itself all reports will be successfully
             // aggregated. Which means that no reports will be lost in a such a state that
             // they can never be aggregated.
-            (Err(e), _) => return Err(e),
+            (Err(MergeAggShareError::Other(other)), _) => return Err(other),
         }
     }
     if !inc_restart_metric.is_completed() {
diff --git a/daphne/src/roles/leader.rs b/daphne/src/roles/leader.rs
index de2bdff94..fe9fec314 100644
--- a/daphne/src/roles/leader.rs
+++ b/daphne/src/roles/leader.rs
@@ -8,7 +8,10 @@ use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode};
 use tracing::{debug, error};
 use url::Url;
 
-use super::{check_batch, check_request_content_type, resolve_taskprov, DapAggregator};
+use super::{
+    aggregator::MergeAggShareError, check_batch, check_request_content_type, resolve_taskprov,
+    DapAggregator,
+};
 use crate::{
     constants::DapMediaType,
     error::DapAbort,
@@ -440,8 +443,13 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator {
             .try_put_agg_share_span(task_id, task_config, agg_span)
             .await
             .into_iter()
-            .map(|(_bucket, (result, _report_metadata))| {
-                result.map(|replayed_reports| replayed_reports.len())
+            .map(|(_bucket, (result, _report_metadata))| match result {
+                Ok(()) => Ok(0),
+                Err(MergeAggShareError::AlreadyCollected) => {
+                    panic!("aggregated to a collected agg share")
+                }
+                Err(MergeAggShareError::ReplaysDetected(replays)) => Ok(replays.len()),
+                Err(MergeAggShareError::Other(e)) => Err(e),
             })
             .sum::<Result<usize, DapError>>()?;
diff --git a/daphne/src/roles/mod.rs b/daphne/src/roles/mod.rs
index 6c75f3d68..9bd1153ee 100644
--- a/daphne/src/roles/mod.rs
+++ b/daphne/src/roles/mod.rs
@@ -3,7 +3,7 @@
 
 //! Trait definitions for Daphne backends.
 
-mod aggregator;
+pub mod aggregator;
 mod helper;
 mod leader;
 
diff --git a/daphne/src/testing.rs b/daphne/src/testing.rs
index 868aa1b5b..4323b574e 100644
--- a/daphne/src/testing.rs
+++ b/daphne/src/testing.rs
@@ -16,7 +16,10 @@ use crate::{
         TaskId, Time, TransitionFailure,
     },
     metrics::DaphneMetrics,
-    roles::{DapAggregator, DapAuthorizedSender, DapHelper, DapLeader, DapReportInitializer},
+    roles::{
+        aggregator::MergeAggShareError, DapAggregator, DapAuthorizedSender, DapHelper, DapLeader,
+        DapReportInitializer,
+    },
     vdaf::{EarlyReportState, EarlyReportStateConsumed, EarlyReportStateInitialized},
     DapAbort, DapAggregateResult, DapAggregateShare, DapAggregateSpan, DapAggregationJobState,
     DapAggregationJobUncommitted, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig,
@@ -1068,7 +1071,7 @@ impl DapAggregator for MockAggregator {
         task_id: &TaskId,
         _task_config: &DapTaskConfig,
         agg_agg_span: DapAggregateSpan<DapAggregateShare>,
-    ) -> DapAggregateSpan<Result<HashSet<ReportId>, DapError>> {
+    ) -> DapAggregateSpan<Result<(), MergeAggShareError>> {
         let mut report_store_guard = self
             .report_store
             .lock()
@@ -1079,7 +1082,7 @@
         agg_agg_span
             .into_iter()
-            .map(|(bucket, (agg_share, report_metadatas))| {
+            .map(|(bucket, (agg_share_delta, report_metadatas))| {
                 let replayed = report_metadatas
                     .iter()
                     .map(|(id, _)| *id)
@@ -1091,14 +1094,17 @@
                         .processed
                         .extend(report_metadatas.iter().map(|(id, _)| *id));
                     // Add to aggregate share.
-                    agg_store
-                        .entry(bucket.clone())
-                        .or_default()
-                        .agg_share
-                        .merge(agg_share.clone())
-                        .map(|()| HashSet::new())
+                    let agg_share = agg_store.entry(bucket.clone()).or_default();
+                    if agg_share.collected {
+                        Err(MergeAggShareError::AlreadyCollected)
+                    } else {
+                        agg_share
+                            .agg_share
+                            .merge(agg_share_delta.clone())
+                            .map_err(MergeAggShareError::Other)
+                    }
                 } else {
-                    Ok(replayed)
+                    Err(MergeAggShareError::ReplaysDetected(replayed))
                 };
                 (bucket, (result, report_metadatas))
             })
diff --git a/daphne_worker/src/durable/aggregate_store.rs b/daphne_worker/src/durable/aggregate_store.rs
index ea725ad34..672024950 100644
--- a/daphne_worker/src/durable/aggregate_store.rs
+++ b/daphne_worker/src/durable/aggregate_store.rs
@@ -78,6 +78,9 @@ const MAX_CHUNK_SIZE: usize = 128_000;
 /// Key used to store metadata under.
 const METADATA_KEY: &str = "meta";
 
+/// Key used to store whether this share has been collected.
+const COLLECTED_KEY: &str = "collected";
+
 #[derive(Debug, Serialize, Deserialize, Clone, Copy)]
 #[serde(rename_all = "snake_case")]
 enum VdafKind {
@@ -168,7 +171,7 @@ impl AggregateStore {
             return Ok(DapAggregateShare::default());
         }
 
-        let meta_key = JsValue::from_str("meta");
+        let meta_key = JsValue::from_str(METADATA_KEY);
 
         let meta = serde_wasm_bindgen::from_value::(values.get(&meta_key))
             .unwrap_or_else(|e| {
@@ -266,6 +269,13 @@ fn shard_bytes_to_object(
     Ok(())
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+pub enum AggregateStoreMergeResp {
+    Ok,
+    ReplaysDetected(HashSet<ReportId>),
+    AlreadyCollected,
+}
+
 #[durable_object]
 impl DurableObject for AggregateStore {
     fn new(state: State, env: Env) -> Self {
@@ -288,6 +298,16 @@
 }
 
 impl AggregateStore {
+    async fn is_collected(&mut self) -> Result<bool> {
+        Ok(if let Some(collected) = self.collected {
+            collected
+        } else {
+            let collected = state_get_or_default(&self.state, COLLECTED_KEY).await?;
+            self.collected = Some(collected);
+            collected
+        })
+    }
+
     async fn handle(&mut self, req: Request) -> Result<Response> {
         let mut req = match self
             .schedule_for_garbage_collection(req, BINDING_DAP_AGGREGATE_STORE)
@@ -315,15 +335,22 @@
 
                 let chunks_map = js_sys::Object::default();
 
+                if self.is_collected().await? {
+                    return Response::from_json(&AggregateStoreMergeResp::AlreadyCollected);
+                }
+
                 {
                     // check for replays
                     let mut merged_report_ids = self.load_aggregated_report_ids().await?;
                     let repeat_ids = contained_reports
                         .iter()
                         .filter(|id| merged_report_ids.contains(id))
-                        .collect::<HashSet<_>>();
+                        .copied()
+                        .collect::<HashSet<_>>();
                     if !repeat_ids.is_empty() {
-                        return Response::from_json(&repeat_ids);
+                        return Response::from_json(&AggregateStoreMergeResp::ReplaysDetected(
+                            repeat_ids,
+                        ));
                     }
                     merged_report_ids.extend(contained_reports);
                     let mut as_bytes =
@@ -353,7 +380,7 @@
 
                 self.state.storage().put_multiple_raw(chunks_map).await?;
 
-                Response::from_json::<[ReportId; 0]>(&[])
+                Response::from_json(&AggregateStoreMergeResp::Ok)
             }
 
             // Get the current aggregate share.
@@ -370,7 +397,7 @@ impl AggregateStore {
             // Non-idempotent (do not retry)
             // Output: `()`
             (DURABLE_AGGREGATE_STORE_MARK_COLLECTED, Method::Post) => {
-                self.state.storage().put("collected", true).await?;
+                self.state.storage().put(COLLECTED_KEY, true).await?;
                 self.collected = Some(true);
                 Response::from_json(&())
             }
@@ -380,14 +407,7 @@
             // Idempotent
             // Output: `bool`
             (DURABLE_AGGREGATE_STORE_CHECK_COLLECTED, Method::Get) => {
-                let collected = if let Some(collected) = self.collected {
-                    collected
-                } else {
-                    let collected = state_get_or_default(&self.state, "collected").await?;
-                    self.collected = Some(collected);
-                    collected
-                };
-                Response::from_json(&collected)
+                Response::from_json(&self.is_collected().await?)
             }
 
             _ => Err(int_err(format!(
diff --git a/daphne_worker/src/roles/aggregator.rs b/daphne_worker/src/roles/aggregator.rs
index 516ab90f7..57a43aba7 100644
--- a/daphne_worker/src/roles/aggregator.rs
+++ b/daphne_worker/src/roles/aggregator.rs
@@ -8,9 +8,9 @@ use crate::{
     config::{DapTaskConfigKvPair, DaphneWorker},
     durable::{
         aggregate_store::{
-            AggregateStoreMergeReq, DURABLE_AGGREGATE_STORE_CHECK_COLLECTED,
-            DURABLE_AGGREGATE_STORE_GET, DURABLE_AGGREGATE_STORE_MARK_COLLECTED,
-            DURABLE_AGGREGATE_STORE_MERGE,
+            AggregateStoreMergeReq, AggregateStoreMergeResp,
+            DURABLE_AGGREGATE_STORE_CHECK_COLLECTED, DURABLE_AGGREGATE_STORE_GET,
+            DURABLE_AGGREGATE_STORE_MARK_COLLECTED, DURABLE_AGGREGATE_STORE_MERGE,
        },
         durable_name_agg_store, BINDING_DAP_AGGREGATE_STORE,
     },
@@ -22,52 +22,25 @@ use daphne::{
     auth::BearerTokenProvider,
     fatal_error,
     hpke::HpkeConfig,
-    messages::{BatchId, BatchSelector, PartialBatchSelector, ReportId, TaskId, TransitionFailure},
+    messages::{BatchId, BatchSelector, PartialBatchSelector, TaskId, TransitionFailure},
     metrics::DaphneMetrics,
-    roles::{DapAggregator, DapReportInitializer},
+    roles::{aggregator::MergeAggShareError, DapAggregator, DapReportInitializer},
     vdaf::{EarlyReportState, EarlyReportStateConsumed, EarlyReportStateInitialized},
     DapAggregateShare, DapAggregateSpan, DapBatchBucket, DapError, DapGlobalConfig, DapRequest,
     DapSender, DapTaskConfig,
 };
-use futures::{future::try_join_all, StreamExt, TryFutureExt, TryStreamExt};
-use std::collections::{HashMap, HashSet};
+use futures::{future::try_join_all, StreamExt};
 
 #[async_trait(?Send)]
 impl DapReportInitializer for DaphneWorker<'_> {
     async fn initialize_reports<'req>(
         &self,
         is_leader: bool,
-        task_id: &TaskId,
+        _task_id: &TaskId,
         task_config: &DapTaskConfig,
-        part_batch_sel: &PartialBatchSelector,
+        _part_batch_sel: &PartialBatchSelector,
         consumed_reports: Vec<EarlyReportStateConsumed<'req>>,
     ) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError> {
-        let durable = self.durable();
-        let span = task_config
-            .as_ref()
-            .batch_span_for_meta(part_batch_sel, consumed_reports.iter())?;
-        let collected_reports = {
-            let task_id_hex = task_id.to_hex();
-
-            // Send AggregateStore requests.
-            futures::stream::iter(span.iter())
-                .map(|(bucket, _)| {
-                    let durable_name =
-                        durable_name_agg_store(task_config.version, &task_id_hex, bucket);
-                    durable
-                        .get(
-                            BINDING_DAP_AGGREGATE_STORE,
-                            DURABLE_AGGREGATE_STORE_CHECK_COLLECTED,
-                            durable_name,
-                        )
-                        .map_ok(move |collected| (bucket, collected))
-                })
-                .buffer_unordered(usize::MAX)
-                .try_collect::<HashMap<_, _>>()
-                .map_err(|e| fatal_error!(err = ?e, "failed to check collected"))
-                .await?
-        };
-
         let min_time = self.least_valid_report_time(self.get_current_time());
         let max_time = self.greatest_valid_report_time(self.get_current_time());
 
@@ -82,11 +55,6 @@ impl DapReportInitializer for DaphneWorker<'_> {
                     } else if metadata.time > max_time {
                         consumed_report
                             .into_initialized_rejected_due_to(TransitionFailure::ReportTooEarly)
-                    } else if collected_reports
-                        [&task_config.bucket_for(part_batch_sel, &consumed_report)]
-                    {
-                        consumed_report
-                            .into_initialized_rejected_due_to(TransitionFailure::BatchCollected)
                     } else {
                         EarlyReportStateInitialized::initialize(
                             is_leader,
@@ -345,7 +313,7 @@ impl<'srv> DapAggregator for DaphneWorker<'srv> {
         task_id: &TaskId,
         task_config: &DapTaskConfig,
         agg_share_span: DapAggregateSpan<DapAggregateShare>,
-    ) -> DapAggregateSpan<Result<HashSet<ReportId>, DapError>> {
+    ) -> DapAggregateSpan<Result<(), MergeAggShareError>> {
         let task_id_hex = task_id.to_hex();
         let durable = self.durable().with_retry();
 
@@ -354,7 +322,7 @@
                 let agg_store_name =
                     durable_name_agg_store(task_config.version, &task_id_hex, &bucket);
                 let result = durable
-                    .post::<_, HashSet<ReportId>>(
+                    .post::<_, AggregateStoreMergeResp>(
                         BINDING_DAP_AGGREGATE_STORE,
                         DURABLE_AGGREGATE_STORE_MERGE,
                         agg_store_name,
@@ -365,6 +333,16 @@
                     )
                     .await
                     .map_err(|e| fatal_error!(err = ?e));
+                let result = match result {
+                    Ok(AggregateStoreMergeResp::Ok) => Ok(()),
+                    Ok(AggregateStoreMergeResp::AlreadyCollected) => {
+                        Err(MergeAggShareError::AlreadyCollected)
+                    }
+                    Ok(AggregateStoreMergeResp::ReplaysDetected(replays)) => {
+                        Err(MergeAggShareError::ReplaysDetected(replays))
+                    }
+                    Err(e) => Err(MergeAggShareError::Other(e)),
+                };
                 (bucket, (result, report_metadatas))
             })
             .buffer_unordered(usize::MAX)
diff --git a/daphne_worker_test/wrangler.toml b/daphne_worker_test/wrangler.toml
index 46508d056..8f2ed019f 100644
--- a/daphne_worker_test/wrangler.toml
+++ b/daphne_worker_test/wrangler.toml
@@ -64,7 +64,6 @@ bindings = [
     { name = "DAP_LEADER_COL_JOB_QUEUE", class_name = "LeaderCollectionJobQueue" },
     { name = "DAP_GARBAGE_COLLECTOR", class_name = "GarbageCollector" },
     { name = "DAP_REPORTS_PENDING", class_name = "ReportsPending" },
-    { name = "DAP_REPORTS_PROCESSED", class_name = "ReportsProcessed" },
 ]
 
@@ -163,5 +162,4 @@ new_classes = [
     "LeaderCollectionJobQueue",
     "GarbageCollector",
     "ReportsPending",
-    "ReportsProcessed",
 ]
diff --git a/docker/wrangler.toml b/docker/wrangler.toml
index eaa039f9a..feedde4c3 100644
--- a/docker/wrangler.toml
+++ b/docker/wrangler.toml
@@ -64,7 +64,6 @@ bindings = [
     { name = "DAP_LEADER_COL_JOB_QUEUE", class_name = "LeaderCollectionJobQueue" },
     { name = "DAP_GARBAGE_COLLECTOR", class_name = "GarbageCollector" },
     { name = "DAP_REPORTS_PENDING", class_name = "ReportsPending" },
-    { name = "DAP_REPORTS_PROCESSED", class_name = "ReportsProcessed" },
 ]
 
@@ -163,5 +162,4 @@ new_classes = [
     "LeaderCollectionJobQueue",
     "GarbageCollector",
     "ReportsPending",
-    "ReportsProcessed",
 ]
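The per-bucket contract introduced by this change is that `try_put_agg_share_span` now reports one of four outcomes for each bucket: a successful merge, a set of replayed report IDs, an already-collected aggregate share, or an opaque error. The sketch below is illustrative only; it uses local stand-in definitions of `ReportId`, `DapError`, and the two enums rather than the real `daphne`/`daphne_worker` imports, and mirrors the match added in `daphne_worker/src/roles/aggregator.rs` that maps the durable object's `AggregateStoreMergeResp` onto `MergeAggShareError`.

```rust
use std::collections::HashSet;

// Stand-ins for the real types in the diff (daphne::messages::ReportId, DapError, etc.).
type ReportId = [u8; 16];
#[derive(Debug)]
struct DapError(String);

#[derive(Debug)]
enum AggregateStoreMergeResp {
    Ok,
    ReplaysDetected(HashSet<ReportId>),
    AlreadyCollected,
}

#[derive(Debug)]
enum MergeAggShareError {
    AlreadyCollected,
    ReplaysDetected(HashSet<ReportId>),
    Other(DapError),
}

// Classify one bucket's merge outcome the way the new worker code does: a successful
// merge becomes Ok(()), replays and already-collected buckets become dedicated error
// variants, and transport failures are passed through as Other.
fn classify(resp: Result<AggregateStoreMergeResp, DapError>) -> Result<(), MergeAggShareError> {
    match resp {
        Ok(AggregateStoreMergeResp::Ok) => Ok(()),
        Ok(AggregateStoreMergeResp::AlreadyCollected) => Err(MergeAggShareError::AlreadyCollected),
        Ok(AggregateStoreMergeResp::ReplaysDetected(replays)) => {
            Err(MergeAggShareError::ReplaysDetected(replays))
        }
        Err(e) => Err(MergeAggShareError::Other(e)),
    }
}

fn main() {
    let replays: HashSet<ReportId> = HashSet::from([[0u8; 16]]);
    assert!(classify(Ok(AggregateStoreMergeResp::Ok)).is_ok());
    assert!(matches!(
        classify(Ok(AggregateStoreMergeResp::ReplaysDetected(replays))),
        Err(MergeAggShareError::ReplaysDetected(_))
    ));
    assert!(matches!(
        classify(Ok(AggregateStoreMergeResp::AlreadyCollected)),
        Err(MergeAggShareError::AlreadyCollected)
    ));
    assert!(matches!(
        classify(Err(DapError("durable object request failed".into()))),
        Err(MergeAggShareError::Other(_))
    ));
}
```

The helper treats `ReplaysDetected` and `AlreadyCollected` as per-report rejections (`ReportReplayed`, `BatchCollected`) and retries the job, while the leader only tolerates replays and treats an already-collected bucket as a logic error.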