Move collected check to the agg share merge request
mendess committed Nov 29, 2023
1 parent 58d8ab7 commit 15f5930
Showing 9 changed files with 103 additions and 77 deletions.
9 changes: 8 additions & 1 deletion daphne/src/roles/aggregator.rs
@@ -38,6 +38,13 @@ pub trait DapReportInitializer {
) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError>;
}

#[derive(Debug)]
pub enum MergeAggShareError {
AlreadyCollected,
ReplaysDetected(HashSet<ReportId>),
Other(DapError),
}

/// DAP Aggregator functionality.
#[async_trait(?Send)]
pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
@@ -125,7 +132,7 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
task_id: &TaskId,
task_config: &DapTaskConfig,
agg_share_span: DapAggregateSpan<DapAggregateShare>,
) -> DapAggregateSpan<Result<HashSet<ReportId>, DapError>>;
) -> DapAggregateSpan<Result<(), MergeAggShareError>>;

/// Fetch the aggregate share for the given batch.
async fn get_agg_share(
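For orientation, here is a minimal self-contained sketch of the new per-bucket contract (the types below are simplified stand-ins, not the daphne definitions): `try_put_agg_share_span` now signals success with `Ok(())` and distinguishes the two expected rejection cases from fatal errors. The helper and leader changes below show the real handling of these three cases.

```rust
use std::collections::HashSet;

// Simplified stand-ins for illustration only.
type ReportId = [u8; 16];

#[derive(Debug)]
enum MergeAggShareError {
    AlreadyCollected,
    ReplaysDetected(HashSet<ReportId>),
    Other(String), // daphne wraps a DapError here
}

/// How a caller is expected to treat each bucket's merge result.
fn describe(result: &Result<(), MergeAggShareError>) -> &'static str {
    match result {
        // Every report in the bucket was committed to aggregate storage.
        Ok(()) => "aggregated",
        // The bucket already belongs to a collected aggregate share.
        Err(MergeAggShareError::AlreadyCollected) => "reject: batch collected",
        // At least one report was seen before; storage was left unchanged.
        Err(MergeAggShareError::ReplaysDetected(_)) => "reject: report replayed",
        // Anything else is treated as a fatal backend error.
        Err(MergeAggShareError::Other(_)) => "abort",
    }
}

fn main() {
    assert_eq!(describe(&Ok(())), "aggregated");
    assert_eq!(
        describe(&Err(MergeAggShareError::AlreadyCollected)),
        "reject: batch collected"
    );
}
```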
17 changes: 14 additions & 3 deletions daphne/src/roles/helper.rs
@@ -19,6 +19,7 @@ use crate::{
ReportId, TaskId, TransitionFailure, TransitionVar,
},
metrics::{DaphneMetrics, DaphneRequestType},
roles::aggregator::MergeAggShareError,
vdaf::ReportProcessedStatus,
DapAggregateShare, DapAggregateSpan, DapAggregationJobState, DapError,
DapHelperAggregationJobTransition, DapRequest, DapResource, DapResponse, DapTaskConfig,
@@ -498,7 +499,7 @@ async fn finish_agg_job_and_aggregate<S>(
for (_bucket, result) in put_shares_result {
match result {
// This bucket had no replays.
(Ok(replays), reports) if replays.is_empty() => {
(Ok(()), reports) => {
// Every report in the bucket has been committed to aggregate storage.
report_status.extend(
reports.into_iter().map(|(report_id, _time)| {
@@ -507,7 +508,7 @@
);
}
// This bucket had replays.
(Ok(replays), _reports) => {
(Err(MergeAggShareError::ReplaysDetected(replays)), _reports) => {
// At least one report was replayed (no change to aggregate storage).
report_status.extend(replays.into_iter().map(|report_id| {
(
@@ -517,6 +518,16 @@
}));
inc_restart_metric.call_once(|| metrics.agg_job_put_span_retry_inc());
}
// This bucket belongs to a collected aggregate share.
(Err(MergeAggShareError::AlreadyCollected), reports) => {
report_status.extend(reports.into_iter().map(|(report_id, _)| {
(
report_id,
ReportProcessedStatus::Rejected(TransitionFailure::BatchCollected),
)
}));
inc_restart_metric.call_once(|| metrics.agg_job_put_span_retry_inc());
}
// If this happens, the leader and helper can possibly have inconsistent state.
// The leader will still think all of the reports in this job have yet to be
// aggregated. But we could have aggregated some and not others due to the
@@ -526,7 +537,7 @@
// and if this error doesn't manifest itself all reports will be successfully
// aggregated. Which means that no reports will be lost in such a state that
// they can never be aggregated.
(Err(e), _) => return Err(e),
(Err(MergeAggShareError::Other(other)), _) => return Err(other),
}
}
if !inc_restart_metric.is_completed() {
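A side note on `inc_restart_metric` in the arms above: judging by `call_once` and `is_completed`, it is presumably a `std::sync::Once`, so the retry metric is incremented at most once per aggregation job even when several buckets report replays or collected batches. A minimal illustration of that pattern:

```rust
use std::sync::Once;

fn main() {
    let inc_restart_metric = Once::new();
    let mut calls = 0;

    for _ in 0..3 {
        // The closure runs only the first time, no matter how many buckets hit it.
        inc_restart_metric.call_once(|| calls += 1);
    }

    assert_eq!(calls, 1);
    assert!(inc_restart_metric.is_completed());
}
```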
14 changes: 11 additions & 3 deletions daphne/src/roles/leader.rs
@@ -8,7 +8,10 @@ use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode};
use tracing::{debug, error};
use url::Url;

use super::{check_batch, check_request_content_type, resolve_taskprov, DapAggregator};
use super::{
aggregator::MergeAggShareError, check_batch, check_request_content_type, resolve_taskprov,
DapAggregator,
};
use crate::{
constants::DapMediaType,
error::DapAbort,
@@ -440,8 +443,13 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
.try_put_agg_share_span(task_id, task_config, agg_span)
.await
.into_iter()
.map(|(_bucket, (result, _report_metadata))| {
result.map(|replayed_reports| replayed_reports.len())
.map(|(_bucket, (result, _report_metadata))| match result {
Ok(()) => Ok(0),
Err(MergeAggShareError::AlreadyCollected) => {
panic!("aggregated to a collected agg share")
}
Err(MergeAggShareError::ReplaysDetected(replays)) => Ok(replays.len()),
Err(MergeAggShareError::Other(e)) => Err(e),
})
.sum::<Result<usize, _>>()?;

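The leader's closing `.sum::<Result<usize, _>>()?` relies on the standard library's `Sum` implementation for `Result`: summing an iterator of `Result<usize, E>` yields the total on success and short-circuits on the first `Err`. A small illustration:

```rust
fn main() {
    // All items are Ok: the values are added up.
    let ok: Result<usize, String> = vec![Ok(1), Ok(2), Ok(3)].into_iter().sum();
    assert_eq!(ok, Ok(6));

    // One bucket failed: the first error wins and the rest are ignored.
    let err: Result<usize, String> = vec![Ok(1), Err("backend error".to_string()), Ok(3)]
        .into_iter()
        .sum();
    assert_eq!(err, Err("backend error".to_string()));
}
```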
2 changes: 1 addition & 1 deletion daphne/src/roles/mod.rs
@@ -3,7 +3,7 @@

//! Trait definitions for Daphne backends.
mod aggregator;
pub mod aggregator;
mod helper;
mod leader;

26 changes: 16 additions & 10 deletions daphne/src/testing.rs
@@ -16,7 +16,10 @@ use crate::{
TaskId, Time, TransitionFailure,
},
metrics::DaphneMetrics,
roles::{DapAggregator, DapAuthorizedSender, DapHelper, DapLeader, DapReportInitializer},
roles::{
aggregator::MergeAggShareError, DapAggregator, DapAuthorizedSender, DapHelper, DapLeader,
DapReportInitializer,
},
vdaf::{EarlyReportState, EarlyReportStateConsumed, EarlyReportStateInitialized},
DapAbort, DapAggregateResult, DapAggregateShare, DapAggregateSpan, DapAggregationJobState,
DapAggregationJobUncommitted, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig,
@@ -1068,7 +1071,7 @@ impl DapAggregator<BearerToken> for MockAggregator {
task_id: &TaskId,
_task_config: &DapTaskConfig,
agg_agg_span: DapAggregateSpan<DapAggregateShare>,
) -> DapAggregateSpan<Result<HashSet<ReportId>, DapError>> {
) -> DapAggregateSpan<Result<(), MergeAggShareError>> {
let mut report_store_guard = self
.report_store
.lock()
@@ -1079,7 +1082,7 @@

agg_agg_span
.into_iter()
.map(|(bucket, (agg_share, report_metadatas))| {
.map(|(bucket, (agg_share_delta, report_metadatas))| {
let replayed = report_metadatas
.iter()
.map(|(id, _)| *id)
@@ -1091,14 +1094,17 @@
.processed
.extend(report_metadatas.iter().map(|(id, _)| *id));
// Add to aggregate share.
agg_store
.entry(bucket.clone())
.or_default()
.agg_share
.merge(agg_share.clone())
.map(|()| HashSet::new())
let agg_share = agg_store.entry(bucket.clone()).or_default();
if agg_share.collected {
Err(MergeAggShareError::AlreadyCollected)
} else {
agg_share
.agg_share
.merge(agg_share_delta.clone())
.map_err(MergeAggShareError::Other)
}
} else {
Ok(replayed)
Err(MergeAggShareError::ReplaysDetected(replayed))
};
(bucket, (result, report_metadatas))
})
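The mock keys its in-memory aggregate store by bucket, and the entry type (not shown in this hunk) evidently carries both the running share and a `collected` flag. A rough, hypothetical sketch of that shape, with illustrative names and a trivial `merge`, just to make the control flow above easier to follow:

```rust
use std::collections::HashMap;

// Illustrative stand-ins; the real types live in daphne.
type DapBatchBucket = String;

#[derive(Default)]
struct DapAggregateShare {
    report_count: u64,
    // ...plus the encoded VDAF aggregate share, checksum, etc.
}

impl DapAggregateShare {
    fn merge(&mut self, delta: DapAggregateShare) -> Result<(), String> {
        self.report_count += delta.report_count;
        Ok(())
    }
}

// Plausible shape of an agg_store entry after this commit.
#[derive(Default)]
struct AggStoreState {
    agg_share: DapAggregateShare,
    collected: bool,
}

fn main() {
    let mut agg_store: HashMap<DapBatchBucket, AggStoreState> = HashMap::new();
    let entry = agg_store.entry("task-1/batch-0".to_string()).or_default();
    if entry.collected {
        println!("refusing to merge: bucket already collected");
    } else {
        entry
            .agg_share
            .merge(DapAggregateShare { report_count: 10 })
            .unwrap();
    }
    assert_eq!(agg_store["task-1/batch-0"].agg_share.report_count, 10);
}
```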
46 changes: 33 additions & 13 deletions daphne_worker/src/durable/aggregate_store.rs
@@ -78,6 +78,9 @@ const MAX_CHUNK_SIZE: usize = 128_000;
/// Key used to store metadata under.
const METADATA_KEY: &str = "meta";

/// Key used to store whether this share has been collected
const COLLECTED_KEY: &str = "collected";

#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "snake_case")]
enum VdafKind {
@@ -168,7 +171,7 @@ impl AggregateStore {
return Ok(DapAggregateShare::default());
}

let meta_key = JsValue::from_str("meta");
let meta_key = JsValue::from_str(METADATA_KEY);
let meta =
serde_wasm_bindgen::from_value::<DapAggregateShareMetadata>(values.get(&meta_key))
.unwrap_or_else(|e| {
@@ -266,6 +269,13 @@ fn shard_bytes_to_object(
Ok(())
}

#[derive(Debug, Serialize, Deserialize)]
pub enum AggregateStoreMergeResp {
Ok,
ReplaysDetected(HashSet<ReportId>),
AlreadyCollected,
}

#[durable_object]
impl DurableObject for AggregateStore {
fn new(state: State, env: Env) -> Self {
@@ -288,6 +298,16 @@ impl DurableObject for AggregateStore {
}

impl AggregateStore {
async fn is_collected(&mut self) -> Result<bool> {
Ok(if let Some(collected) = self.collected {
collected
} else {
let collected = state_get_or_default(&self.state, COLLECTED_KEY).await?;
self.collected = Some(collected);
collected
})
}

async fn handle(&mut self, req: Request) -> Result<Response> {
let mut req = match self
.schedule_for_garbage_collection(req, BINDING_DAP_AGGREGATE_STORE)
@@ -315,15 +335,22 @@

let chunks_map = js_sys::Object::default();

if self.is_collected().await? {
return Response::from_json(&AggregateStoreMergeResp::AlreadyCollected);
}

{
// check for replays
let mut merged_report_ids = self.load_aggregated_report_ids().await?;
let repeat_ids = contained_reports
.iter()
.filter(|id| merged_report_ids.contains(id))
.collect::<Vec<_>>();
.copied()
.collect::<HashSet<_>>();
if !repeat_ids.is_empty() {
return Response::from_json(&repeat_ids);
return Response::from_json(&AggregateStoreMergeResp::ReplaysDetected(
repeat_ids,
));
}
merged_report_ids.extend(contained_reports);
let mut as_bytes =
@@ -353,7 +380,7 @@

self.state.storage().put_multiple_raw(chunks_map).await?;

Response::from_json::<[ReportId; 0]>(&[])
Response::from_json(&AggregateStoreMergeResp::Ok)
}

// Get the current aggregate share.
@@ -370,7 +397,7 @@
// Non-idempotent (do not retry)
// Output: `()`
(DURABLE_AGGREGATE_STORE_MARK_COLLECTED, Method::Post) => {
self.state.storage().put("collected", true).await?;
self.state.storage().put(COLLECTED_KEY, true).await?;
self.collected = Some(true);
Response::from_json(&())
}
@@ -380,14 +407,7 @@
// Idempotent
// Output: `bool`
(DURABLE_AGGREGATE_STORE_CHECK_COLLECTED, Method::Get) => {
let collected = if let Some(collected) = self.collected {
collected
} else {
let collected = state_get_or_default(&self.state, "collected").await?;
self.collected = Some(collected);
collected
};
Response::from_json(&collected)
Response::from_json(&self.is_collected().await?)
}

_ => Err(int_err(format!(
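The merge endpoint now replies with `AggregateStoreMergeResp` rather than a bare set of report IDs. With serde's default externally tagged enum representation, the unit variants serialize as plain JSON strings and `ReplaysDetected` as a single-key object. A quick round-trip sketch; note that `ReportId` below is a stand-in, since the real daphne type may use its own encoding:

```rust
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

// Stand-in for daphne's ReportId, for illustration only.
type ReportId = [u8; 16];

#[derive(Debug, Serialize, Deserialize)]
pub enum AggregateStoreMergeResp {
    Ok,
    ReplaysDetected(HashSet<ReportId>),
    AlreadyCollected,
}

fn main() -> Result<(), serde_json::Error> {
    // Unit variants become bare strings on the wire...
    assert_eq!(serde_json::to_string(&AggregateStoreMergeResp::Ok)?, r#""Ok""#);
    assert_eq!(
        serde_json::to_string(&AggregateStoreMergeResp::AlreadyCollected)?,
        r#""AlreadyCollected""#
    );

    // ...while the tuple variant is wrapped in a single-key object.
    let resp = AggregateStoreMergeResp::ReplaysDetected(HashSet::from([[0u8; 16]]));
    println!("{}", serde_json::to_string(&resp)?);
    Ok(())
}
```

The worker-side `DapAggregator` implementation in the next file decodes this enum and maps it onto `MergeAggShareError`.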
62 changes: 20 additions & 42 deletions daphne_worker/src/roles/aggregator.rs
@@ -8,9 +8,9 @@ use crate::{
config::{DapTaskConfigKvPair, DaphneWorker},
durable::{
aggregate_store::{
AggregateStoreMergeReq, DURABLE_AGGREGATE_STORE_CHECK_COLLECTED,
DURABLE_AGGREGATE_STORE_GET, DURABLE_AGGREGATE_STORE_MARK_COLLECTED,
DURABLE_AGGREGATE_STORE_MERGE,
AggregateStoreMergeReq, AggregateStoreMergeResp,
DURABLE_AGGREGATE_STORE_CHECK_COLLECTED, DURABLE_AGGREGATE_STORE_GET,
DURABLE_AGGREGATE_STORE_MARK_COLLECTED, DURABLE_AGGREGATE_STORE_MERGE,
},
durable_name_agg_store, BINDING_DAP_AGGREGATE_STORE,
},
@@ -22,52 +22,25 @@ use daphne::{
auth::BearerTokenProvider,
fatal_error,
hpke::HpkeConfig,
messages::{BatchId, BatchSelector, PartialBatchSelector, ReportId, TaskId, TransitionFailure},
messages::{BatchId, BatchSelector, PartialBatchSelector, TaskId, TransitionFailure},
metrics::DaphneMetrics,
roles::{DapAggregator, DapReportInitializer},
roles::{aggregator::MergeAggShareError, DapAggregator, DapReportInitializer},
vdaf::{EarlyReportState, EarlyReportStateConsumed, EarlyReportStateInitialized},
DapAggregateShare, DapAggregateSpan, DapBatchBucket, DapError, DapGlobalConfig, DapRequest,
DapSender, DapTaskConfig,
};
use futures::{future::try_join_all, StreamExt, TryFutureExt, TryStreamExt};
use std::collections::{HashMap, HashSet};
use futures::{future::try_join_all, StreamExt};

#[async_trait(?Send)]
impl DapReportInitializer for DaphneWorker<'_> {
async fn initialize_reports<'req>(
&self,
is_leader: bool,
task_id: &TaskId,
_task_id: &TaskId,
task_config: &DapTaskConfig,
part_batch_sel: &PartialBatchSelector,
_part_batch_sel: &PartialBatchSelector,
consumed_reports: Vec<EarlyReportStateConsumed<'req>>,
) -> Result<Vec<EarlyReportStateInitialized<'req>>, DapError> {
let durable = self.durable();
let span = task_config
.as_ref()
.batch_span_for_meta(part_batch_sel, consumed_reports.iter())?;
let collected_reports = {
let task_id_hex = task_id.to_hex();

// Send AggregateStore requests.
futures::stream::iter(span.iter())
.map(|(bucket, _)| {
let durable_name =
durable_name_agg_store(task_config.version, &task_id_hex, bucket);
durable
.get(
BINDING_DAP_AGGREGATE_STORE,
DURABLE_AGGREGATE_STORE_CHECK_COLLECTED,
durable_name,
)
.map_ok(move |collected| (bucket, collected))
})
.buffer_unordered(usize::MAX)
.try_collect::<HashMap<&DapBatchBucket, bool>>()
.map_err(|e| fatal_error!(err = ?e, "failed to check collected"))
.await?
};

let min_time = self.least_valid_report_time(self.get_current_time());
let max_time = self.greatest_valid_report_time(self.get_current_time());

@@ -82,11 +55,6 @@ impl DapReportInitializer for DaphneWorker<'_> {
} else if metadata.time > max_time {
consumed_report
.into_initialized_rejected_due_to(TransitionFailure::ReportTooEarly)
} else if collected_reports
[&task_config.bucket_for(part_batch_sel, &consumed_report)]
{
consumed_report
.into_initialized_rejected_due_to(TransitionFailure::BatchCollected)
} else {
EarlyReportStateInitialized::initialize(
is_leader,
@@ -345,7 +313,7 @@ impl<'srv> DapAggregator<DaphneWorkerAuth> for DaphneWorker<'srv> {
task_id: &TaskId,
task_config: &DapTaskConfig,
agg_share_span: DapAggregateSpan<DapAggregateShare>,
) -> DapAggregateSpan<Result<HashSet<ReportId>, DapError>> {
) -> DapAggregateSpan<Result<(), MergeAggShareError>> {
let task_id_hex = task_id.to_hex();
let durable = self.durable().with_retry();

@@ -354,7 +322,7 @@
let agg_store_name =
durable_name_agg_store(task_config.version, &task_id_hex, &bucket);
let result = durable
.post::<_, HashSet<ReportId>>(
.post::<_, AggregateStoreMergeResp>(
BINDING_DAP_AGGREGATE_STORE,
DURABLE_AGGREGATE_STORE_MERGE,
agg_store_name,
@@ -365,6 +333,16 @@
)
.await
.map_err(|e| fatal_error!(err = ?e));
let result = match result {
Ok(AggregateStoreMergeResp::Ok) => Ok(()),
Ok(AggregateStoreMergeResp::AlreadyCollected) => {
Err(MergeAggShareError::AlreadyCollected)
}
Ok(AggregateStoreMergeResp::ReplaysDetected(replays)) => {
Err(MergeAggShareError::ReplaysDetected(replays))
}
Err(e) => Err(MergeAggShareError::Other(e)),
};
(bucket, (result, report_metadatas))
})
.buffer_unordered(usize::MAX)
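The merge requests above are fanned out with `futures::stream::iter(..).map(|..| async { .. }).buffer_unordered(usize::MAX)`, that is, one durable-object call per bucket with all calls in flight concurrently. A tiny self-contained illustration of that pattern (no Workers runtime involved):

```rust
use futures::{executor::block_on, stream, StreamExt};

fn main() {
    let results: Vec<u32> = block_on(
        stream::iter(1u32..=4)
            .map(|bucket| async move {
                // Stand-in for posting one merge request per bucket.
                bucket * 10
            })
            .buffer_unordered(usize::MAX)
            .collect(),
    );

    // Completion order is unspecified, but every request ran.
    assert_eq!(results.iter().sum::<u32>(), 100);
}
```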