
Add replay checking back to the leader
mendess committed Nov 24, 2023
1 parent 93e1e94 commit 30f35ab
Showing 3 changed files with 112 additions and 28 deletions.
33 changes: 30 additions & 3 deletions daphne/src/roles/leader.rs
@@ -1,7 +1,7 @@
 // Copyright (c) 2023 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 
 use async_trait::async_trait;
 use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode};
@@ -15,7 +15,8 @@ use crate::{
     fatal_error,
     messages::{
         AggregateShare, AggregateShareReq, AggregationJobResp, BatchSelector, Collection,
-        CollectionJobId, CollectionReq, Interval, PartialBatchSelector, Query, Report, TaskId,
+        CollectionJobId, CollectionReq, Interval, PartialBatchSelector, Query, Report, ReportId,
+        TaskId, Time,
     },
     metrics::DaphneRequestType,
     DapCollectJob, DapError, DapLeaderAggregationJobTransition, DapLeaderProcessTelemetry,
@@ -98,6 +99,17 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
     /// Data type used to guide selection of a set of reports for aggregation.
     type ReportSelector;
 
+    /// Register these report ids as received by the leader, returning all ids that have been
+    /// registered before.
+    async fn register_received_reports<I>(
+        &self,
+        task_config: &DapTaskConfig,
+        task_id: &TaskId,
+        reports: I,
+    ) -> Result<HashSet<ReportId>, DapError>
+    where
+        I: Iterator<Item = (ReportId, Time)>;
+
     /// Store a report for use later on.
     async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError>;
 
@@ -327,7 +339,8 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
             metrics,
         )
         .await?;
-        let (state, agg_job_init_req) = match transition {
+
+        let (mut state, mut agg_job_init_req) = match transition {
             DapLeaderAggregationJobTransition::Continued(state, agg_job_init_req) => {
                 (state, agg_job_init_req)
             }
@@ -341,6 +354,20 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
                 return Err(fatal_error!(err = "unexpected state transition (uncommitted)").into())
             }
         };
+        {
+            // register received reports and reject replays
+            let replays = self
+                .register_received_reports(
+                    task_config,
+                    task_id,
+                    state.seq.iter().map(|r| (r.report_id, r.time)),
+                )
+                .await?;
+            state.seq.retain(|r| !replays.contains(&r.report_id));
+            agg_job_init_req
+                .prep_inits
+                .retain(|r| !replays.contains(&r.report_share.report_metadata.id));
+        }
         let method = if task_config.version != DapVersion::Draft02 {
             LeaderHttpRequestMethod::Put
         } else {
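
The new block in roles/leader.rs keeps the leader's aggregation state and the outgoing AggregationJobInitReq consistent: a report dropped from one must be dropped from the other. Below is a minimal, self-contained sketch of that filtering step, with a toy Id type standing in for daphne's ReportId and plain Vecs standing in for the job state:

use std::collections::HashSet;

// Toy id standing in for daphne's ReportId.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct Id(u8);

fn main() {
    // Pretend the backend reported Id(2) as a replay.
    let replays: HashSet<Id> = HashSet::from([Id(2)]);
    let mut seq = vec![Id(1), Id(2), Id(3)]; // stand-in for state.seq
    let mut prep_inits = vec![Id(1), Id(2), Id(3)]; // stand-in for agg_job_init_req.prep_inits

    // Mirrors the new block: drop replayed reports from the local state and
    // from the request sent to the helper, so the two stay in sync.
    seq.retain(|id| !replays.contains(id));
    prep_inits.retain(|id| !replays.contains(id));

    assert_eq!(seq, [Id(1), Id(3)]);
    assert_eq!(prep_inits, [Id(1), Id(3)]);
}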
54 changes: 32 additions & 22 deletions daphne/src/testing.rs
@@ -710,11 +710,10 @@ impl MockAggregator {
     /// Conducts checks on a received report to see whether:
     /// 1) the report falls into a batch that has been already collected, or
     /// 2) the report has been submitted by the client in the past.
-    fn check_report_early_fail(
+    fn check_report_has_been_collected(
         &self,
         task_id: &TaskId,
         bucket: &DapBatchBucket,
-        id: &ReportId,
     ) -> Option<TransitionFailure> {
         // Check AggStateStore to see whether the report is part of a batch that has already
         // been collected.
@@ -724,16 +723,6 @@ impl MockAggregator {
             return Some(TransitionFailure::BatchCollected);
         }
 
-        // Check whether the same report has been submitted in the past.
-        let mut guard = self
-            .report_store
-            .lock()
-            .expect("report_store: failed to lock");
-        let report_store = guard.entry(*task_id).or_default();
-        if report_store.processed.contains(id) {
-            return Some(TransitionFailure::ReportReplayed);
-        }
-
         None
     }
 
@@ -937,7 +926,8 @@ impl DapReportInitializer for MockAggregator {
         for (bucket, ((), report_ids_and_time)) in span.iter() {
             for (id, _) in report_ids_and_time {
                 // Check whether Report has been collected or replayed.
-                if let Some(transition_failure) = self.check_report_early_fail(task_id, bucket, id)
+                if let Some(transition_failure) =
+                    self.check_report_has_been_collected(task_id, bucket)
                 {
                     early_fails.insert(*id, transition_failure);
                 };
@@ -1223,27 +1213,47 @@ impl DapHelper<BearerToken> for MockAggregator {
 impl DapLeader<BearerToken> for MockAggregator {
     type ReportSelector = MockAggregatorReportSelector;
 
+    async fn register_received_reports<I>(
+        &self,
+        _task_config: &DapTaskConfig,
+        task_id: &TaskId,
+        reports: I,
+    ) -> Result<HashSet<ReportId>, DapError>
+    where
+        I: Iterator<Item = (ReportId, Time)>,
+    {
+        let mut report_store = self
+            .report_store
+            .lock()
+            .expect("report_store: failed to lock");
+        Ok(reports
+            .into_iter()
+            .filter_map(|(id, _)| {
+                // Check whether the same report has been submitted in the past.
+                report_store
+                    .entry(*task_id)
+                    .or_default()
+                    .processed
+                    .contains(&id)
+                    .then_some(id)
+            })
+            .collect())
+    }
+
     async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError> {
         let bucket = self
             .assign_report_to_bucket(report, task_id)
             .await
             .expect("could not determine batch for report");
 
-        // Check whether Report has been collected or replayed.
-        if let Some(transition_failure) =
-            self.check_report_early_fail(task_id, &bucket, &report.report_metadata.id)
-        {
-            return Err(DapError::Transition(transition_failure));
-        };
-
         // Store Report for future processing.
         let mut guard = self
             .report_store
             .lock()
             .expect("report_store: failed to lock");
         let queue = guard
-            .get_mut(task_id)
-            .expect("report_store: unrecognized task")
+            .entry(*task_id)
+            .or_default()
             .pending
             .entry(bucket)
             .or_default();
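
Note that the mock implementation above only reads the processed set; marking reports as processed presumably happens elsewhere in the test harness. For contrast, here is a minimal sketch of an eager variant of the same contract; the InMemoryLeader type, its processed field, and the toy ReportId/Time aliases are hypothetical stand-ins, not part of this change:

use std::collections::HashSet;
use std::sync::Mutex;

type ReportId = [u8; 16]; // toy stand-in for daphne's ReportId
type Time = u64; // toy stand-in for daphne's Time

// Hypothetical leader state holding the set of already-seen report ids.
struct InMemoryLeader {
    processed: Mutex<HashSet<ReportId>>,
}

impl InMemoryLeader {
    // Mark every incoming id as received and hand back the ones that were
    // already registered, i.e. the replays.
    fn register_received_reports<I>(&self, reports: I) -> HashSet<ReportId>
    where
        I: Iterator<Item = (ReportId, Time)>,
    {
        let mut processed = self.processed.lock().expect("poisoned lock");
        reports
            .filter_map(|(id, _time)| {
                if processed.contains(&id) {
                    Some(id) // seen before: report it as a replay
                } else {
                    processed.insert(id); // first sighting: register it
                    None
                }
            })
            .collect()
    }
}

Either reading satisfies the trait doc; the key property is that a given id is flagged as a replay on every sighting after its first registration.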
53 changes: 50 additions & 3 deletions daphne_worker/src/roles/leader.rs
@@ -21,8 +21,10 @@ use crate::{
             PendingReport, ReportsPendingResult, DURABLE_REPORTS_PENDING_GET,
             DURABLE_REPORTS_PENDING_PUT,
         },
+        reports_processed::DURABLE_REPORTS_INITIALIZED_REGISTER,
         BINDING_DAP_LEADER_AGG_JOB_QUEUE, BINDING_DAP_LEADER_BATCH_QUEUE,
         BINDING_DAP_LEADER_COL_JOB_QUEUE, BINDING_DAP_REPORTS_PENDING,
+        BINDING_DAP_REPORTS_PROCESSED,
     },
     DaphneWorkerReportSelector,
@@ -33,14 +35,15 @@ use daphne::{
    error::DapAbort,
    fatal_error,
    messages::{
-        Collection, CollectionJobId, CollectionReq, PartialBatchSelector, Report, TaskId,
-        TransitionFailure,
+        Collection, CollectionJobId, CollectionReq, PartialBatchSelector, Report, ReportId, TaskId,
+        Time, TransitionFailure,
    },
    roles::{DapAuthorizedSender, DapLeader},
    DapCollectJob, DapError, DapQueryConfig, DapRequest, DapResponse, DapTaskConfig,
 };
+use futures::{StreamExt, TryStreamExt};
 use prio::codec::{ParameterizedDecode, ParameterizedEncode};
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use tracing::debug;
 
 #[async_trait(?Send)]
@@ -70,6 +73,50 @@ impl DapAuthorizedSender<DaphneWorkerAuth> for DaphneWorker<'_> {
 impl<'srv> DapLeader<DaphneWorkerAuth> for DaphneWorker<'srv> {
     type ReportSelector = DaphneWorkerReportSelector;
 
+    async fn register_received_reports<I>(
+        &self,
+        task_config: &DapTaskConfig,
+        task_id: &TaskId,
+        reports: I,
+    ) -> Result<HashSet<ReportId>, DapError>
+    where
+        I: Iterator<Item = (ReportId, Time)>,
+    {
+        let task_id_hex = task_id.to_hex();
+        let durable = self.durable();
+        let mut reports_processed_request_data = HashMap::new();
+        for (id, time) in reports {
+            let durable_name = self.config().durable_name_report_store(
+                task_config.as_ref(),
+                &task_id_hex,
+                &id,
+                time,
+            );
+            reports_processed_request_data
+                .entry(durable_name)
+                .or_insert_with(Vec::new)
+                .push(id);
+        }
+        futures::stream::iter(reports_processed_request_data)
+            .map(|(durable_name, reports)| async {
+                durable
+                    .post::<_, HashSet<ReportId>>(
+                        BINDING_DAP_REPORTS_PROCESSED,
+                        DURABLE_REPORTS_INITIALIZED_REGISTER,
+                        durable_name,
+                        reports,
+                    )
+                    .await
+            })
+            .buffer_unordered(usize::MAX)
+            .try_fold(HashSet::new(), |mut acc, replays| async {
+                acc.extend(replays);
+                Ok(acc)
+            })
+            .await
+            .map_err(|e| fatal_error!(err = ?e, "checking for replayed reports"))
+    }
+
     async fn put_report(
         &self,
         report: &Report,
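
The worker implementation above is a scatter-gather: report ids are bucketed by the durable object instance that owns them, one request is posted per instance, and the per-instance replay sets are folded into a single set. A self-contained sketch of that shape, with a toy ReportId and a hypothetical post_shard RPC standing in for the durable object POST:

use std::collections::{HashMap, HashSet};

use futures::{stream, StreamExt, TryStreamExt};

type ReportId = u64; // toy stand-in for daphne's ReportId

// Hypothetical per-shard RPC; the real code POSTs the batch to a
// ReportsProcessed durable object and receives back the ids it already knows.
async fn post_shard(_shard: String, ids: Vec<ReportId>) -> Result<HashSet<ReportId>, String> {
    // Pretend even ids were seen before.
    Ok(ids.into_iter().filter(|id| id % 2 == 0).collect())
}

async fn register(reports: impl Iterator<Item = ReportId>) -> Result<HashSet<ReportId>, String> {
    // 1. Bucket the ids by the shard that owns them.
    let mut by_shard: HashMap<String, Vec<ReportId>> = HashMap::new();
    for id in reports {
        by_shard.entry(format!("shard-{}", id % 4)).or_default().push(id);
    }
    // 2. Fan out one request per shard, running them concurrently.
    stream::iter(by_shard)
        .map(|(shard, ids)| post_shard(shard, ids))
        .buffer_unordered(8)
        // 3. Union the per-shard replay sets into one.
        .try_fold(HashSet::new(), |mut acc, replays| async move {
            acc.extend(replays);
            Ok(acc)
        })
        .await
}

fn main() {
    let replays = futures::executor::block_on(register(0u64..10)).unwrap();
    assert_eq!(replays, HashSet::from([0, 2, 4, 6, 8]));
}

The real code uses buffer_unordered(usize::MAX), i.e. no concurrency cap; the cap of 8 here is an arbitrary choice for the sketch.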
