From 52fd456f1eb8edae9d0dc48120fdf898ce4d700e Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Mon, 2 Dec 2024 17:03:23 -0800 Subject: [PATCH] Implement Leader async aggregation. This change includes unit tests, but no integration tests -- those will need to come with the Helper async aggregation implementation, as without it we do not have anything to integration test against. A few implementation notes: * I renamed the report aggregation states to better match their functionality (IMO). * If the Helper does not provide a retry-after header, the Leader will poll each "processing" aggregation job (at most) once per minute. * The retry-after header can specify either a number of seconds, or a specific date. Currently, we only support receiving a number of seconds. --- aggregator/src/aggregator.rs | 207 +- .../aggregator/aggregation_job_continue.rs | 16 +- .../src/aggregator/aggregation_job_creator.rs | 33 +- .../src/aggregator/aggregation_job_driver.rs | 589 +++-- .../aggregation_job_driver/tests.rs | 2244 ++++++++++++++++- aggregator/src/aggregator/batch_creator.rs | 2 +- .../src/aggregator/collection_job_driver.rs | 4 +- .../src/aggregator/garbage_collector.rs | 20 +- .../tests/aggregation_job_continue.rs | 30 +- aggregator/src/aggregator/taskprov_tests.rs | 2 +- aggregator_core/src/datastore.rs | 145 +- aggregator_core/src/datastore/models.rs | 110 +- aggregator_core/src/datastore/tests.rs | 125 +- aggregator_core/src/task.rs | 16 +- db/00000000000001_initial_schema.up.sql | 19 +- 15 files changed, 3116 insertions(+), 446 deletions(-) diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 2f671194e..184f44a34 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -60,7 +60,7 @@ use janus_core::vdaf::Prio3FixedPointBoundedL2VecSumBitSize; use janus_core::{ auth_tokens::AuthenticationToken, hpke::{self, HpkeApplicationInfo, Label}, - retries::retry_http_request_notify, + retries::{retry_http_request_notify, HttpResponse}, time::{Clock, DurationExt, IntervalExt, TimeExt}, vdaf::{ new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128, vdaf_application_context, @@ -1752,105 +1752,106 @@ impl VdafOps { C: Clock, for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, { - if let Some(existing_aggregation_job) = tx + let existing_aggregation_job = match tx .get_aggregation_job::(task_id, aggregation_job_id) .await? { - if existing_aggregation_job.state() == &AggregationJobState::Deleted { - return Err(datastore::Error::User( - Error::DeletedAggregationJob(*task_id, *aggregation_job_id).into(), - )); - } + Some(existing_aggregation_job) => existing_aggregation_job, + None => return Ok(None), + }; - if existing_aggregation_job.last_request_hash() != Some(request_hash) { - if let Some(log_forbidden_mutations) = log_forbidden_mutations { - let original_report_ids: Vec<_> = tx - .get_report_aggregations_for_aggregation_job( - vdaf, - &Role::Helper, - task_id, - aggregation_job_id, - ) - .await? - .iter() - .map(|ra| *ra.report_id()) - .collect(); - let mutating_request_report_ids: Vec<_> = req - .prepare_inits() - .iter() - .map(|pi| *pi.report_share().metadata().id()) - .collect(); - let event = AggregationJobInitForbiddenMutationEvent { - task_id: *task_id, - aggregation_job_id: *aggregation_job_id, - original_request_hash: existing_aggregation_job.last_request_hash(), - original_report_ids, - original_batch_id: format!( - "{:?}", - existing_aggregation_job.partial_batch_identifier() - ), - original_aggregation_parameter: existing_aggregation_job - .aggregation_parameter() - .get_encoded() - .map_err(|e| datastore::Error::User(e.into()))?, - mutating_request_hash: Some(request_hash), - mutating_request_report_ids, - mutating_request_batch_id: format!( - "{:?}", - req.batch_selector().batch_identifier() - ), - mutating_request_aggregation_parameter: req - .aggregation_parameter() - .to_vec(), - }; - let event_id = crate::diagnostic::write_event( - log_forbidden_mutations, - "agg-job-illegal-mutation", - event, - ) - .await - .map(|event_id| format!("{event_id:?}")) - .unwrap_or_else(|error| { - tracing::error!(?error, "failed to write hash mismatch event"); - "no event id".to_string() - }); - - tracing::info!( - ?event_id, - original_request_hash = existing_aggregation_job - .last_request_hash() - .map(hex::encode), - mutating_request_hash = hex::encode(request_hash), - "request hash mismatch on retried aggregation job request", - ); - } - return Err(datastore::Error::User( - Error::ForbiddenMutation { - resource_type: "aggregation job", - identifier: aggregation_job_id.to_string(), - } - .into(), - )); - } + if existing_aggregation_job.state() == &AggregationJobState::Deleted { + return Err(datastore::Error::User( + Error::DeletedAggregationJob(*task_id, *aggregation_job_id).into(), + )); + } - // This is a repeated request. Send the same response we computed last time. - return Ok(Some(AggregationJobResp::Finished { - prepare_resps: tx + if existing_aggregation_job.last_request_hash() != Some(request_hash) { + if let Some(log_forbidden_mutations) = log_forbidden_mutations { + let original_report_ids: Vec<_> = tx .get_report_aggregations_for_aggregation_job( vdaf, &Role::Helper, task_id, aggregation_job_id, + existing_aggregation_job.aggregation_parameter(), ) .await? .iter() - .filter_map(ReportAggregation::last_prep_resp) - .cloned() - .collect(), - })); + .map(|ra| *ra.report_id()) + .collect(); + let mutating_request_report_ids: Vec<_> = req + .prepare_inits() + .iter() + .map(|pi| *pi.report_share().metadata().id()) + .collect(); + let event = AggregationJobInitForbiddenMutationEvent { + task_id: *task_id, + aggregation_job_id: *aggregation_job_id, + original_request_hash: existing_aggregation_job.last_request_hash(), + original_report_ids, + original_batch_id: format!( + "{:?}", + existing_aggregation_job.partial_batch_identifier() + ), + original_aggregation_parameter: existing_aggregation_job + .aggregation_parameter() + .get_encoded() + .map_err(|e| datastore::Error::User(e.into()))?, + mutating_request_hash: Some(request_hash), + mutating_request_report_ids, + mutating_request_batch_id: format!( + "{:?}", + req.batch_selector().batch_identifier() + ), + mutating_request_aggregation_parameter: req.aggregation_parameter().to_vec(), + }; + let event_id = crate::diagnostic::write_event( + log_forbidden_mutations, + "agg-job-illegal-mutation", + event, + ) + .await + .map(|event_id| format!("{event_id:?}")) + .unwrap_or_else(|error| { + tracing::error!(?error, "failed to write hash mismatch event"); + "no event id".to_string() + }); + + tracing::info!( + ?event_id, + original_request_hash = existing_aggregation_job + .last_request_hash() + .map(hex::encode), + mutating_request_hash = hex::encode(request_hash), + "request hash mismatch on retried aggregation job request", + ); + } + return Err(datastore::Error::User( + Error::ForbiddenMutation { + resource_type: "aggregation job", + identifier: aggregation_job_id.to_string(), + } + .into(), + )); } - Ok(None) + // This is a repeated request. Send the same response we computed last time. + return Ok(Some(AggregationJobResp::Finished { + prepare_resps: tx + .get_report_aggregations_for_aggregation_job( + vdaf, + &Role::Helper, + task_id, + aggregation_job_id, + existing_aggregation_job.aggregation_parameter(), + ) + .await? + .iter() + .filter_map(ReportAggregation::last_prep_resp) + .cloned() + .collect(), + })); } /// Implements [helper aggregate initialization][1]. @@ -2202,7 +2203,7 @@ impl VdafOps { // Helper is not finished. Await the next message from the Leader to advance to // the next step. ( - ReportAggregationState::WaitingHelper { prepare_state }, + ReportAggregationState::HelperContinue { prepare_state }, PrepareStepResult::Continue { message: outgoing_message, }, @@ -2421,22 +2422,24 @@ impl VdafOps { Box::pin(async move { // Read existing state. - let (aggregation_job, report_aggregations) = try_join!( - tx.get_aggregation_job::(task.id(), &aggregation_job_id), - tx.get_report_aggregations_for_aggregation_job( + let aggregation_job = tx + .get_aggregation_job::(task.id(), &aggregation_job_id) + .await? + .ok_or_else(|| { + datastore::Error::User( + Error::UnrecognizedAggregationJob(*task.id(), aggregation_job_id) + .into(), + ) + })?; + let report_aggregations = tx + .get_report_aggregations_for_aggregation_job( vdaf.as_ref(), &Role::Helper, task.id(), &aggregation_job_id, + aggregation_job.aggregation_parameter(), ) - )?; - - let aggregation_job = aggregation_job.ok_or_else(|| { - datastore::Error::User( - Error::UnrecognizedAggregationJob(*task.id(), aggregation_job_id) - .into(), - ) - })?; + .await?; // Deleted aggregation jobs cannot be stepped if *aggregation_job.state() == AggregationJobState::Deleted { @@ -3201,6 +3204,12 @@ fn write_task_aggregation_counter( task_id: TaskId, counters: TaskAggregationCounter, ) { + if counters.is_zero() { + // Don't spawn a task or interact with the datastore if doing so won't change the state of + // the datastore. + return; + } + // We write task aggregation counters back in a separate tokio task & datastore transaction, // so that any slowness induced by writing the counters (e.g. due to transaction retry) does // not slow the main processing. The lack of transactionality between writing the updated @@ -3354,7 +3363,7 @@ async fn send_request_to_helper( request_body: Option, auth_token: &AuthenticationToken, http_request_duration_histogram: &Histogram, -) -> Result { +) -> Result { let (auth_header, auth_value) = auth_token.request_authentication(); let domain = Arc::from(url.domain().unwrap_or_default()); let method_str = Arc::from(method.as_str()); @@ -3383,7 +3392,7 @@ async fn send_request_to_helper( // Successful response. Ok(response) => { timer.finish_attempt("success"); - Ok(response.body().clone()) + Ok(response) } // HTTP-level error. diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index c9ae5089d..498dc385c 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -81,7 +81,7 @@ impl VdafOps { // the report was dropped (if it's not already in an error state) and continue. if matches!( report_agg.state(), - ReportAggregationState::WaitingHelper { .. } + ReportAggregationState::HelperContinue { .. } ) { report_aggregations_to_write.push(WritableReportAggregation::new( report_agg @@ -99,11 +99,11 @@ impl VdafOps { }; let prep_state = match report_aggregation.state() { - ReportAggregationState::WaitingHelper { prepare_state } => prepare_state.clone(), - ReportAggregationState::WaitingLeader { .. } => { + ReportAggregationState::HelperContinue { prepare_state } => prepare_state.clone(), + ReportAggregationState::LeaderContinue { .. } => { return Err(datastore::Error::User( Error::Internal( - "helper encountered unexpected ReportAggregationState::WaitingLeader" + "helper encountered unexpected ReportAggregationState::LeaderContinue" .to_string(), ) .into(), @@ -128,7 +128,7 @@ impl VdafOps { // the report was dropped (if it's not already in an error state) and continue. if matches!( report_aggregation.state(), - ReportAggregationState::WaitingHelper { .. } + ReportAggregationState::HelperContinue { .. } ) { report_aggregations_to_write.push(WritableReportAggregation::new( report_aggregation @@ -189,7 +189,7 @@ impl VdafOps { // state and await the next message from // the Leader to advance preparation. PingPongState::Continued(prepare_state) => ( - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state, }, None, @@ -517,7 +517,7 @@ mod tests { *prepare_init.report_share().metadata().time(), 0, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *transcript.helper_prepare_transitions[0] .prepare_state(), }, @@ -743,6 +743,7 @@ mod tests { &Role::Helper, &task_id, &aggregation_job_id, + &test_case.aggregation_parameter, ) .await .unwrap(); @@ -795,6 +796,7 @@ mod tests { &Role::Helper, &task_id, &aggregation_job_id, + &test_case.aggregation_parameter, ) .await .unwrap(); diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index da5cc97b2..3be377aff 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -670,7 +670,7 @@ impl AggregationJobCreator { *report.report_id(), *report.client_timestamp(), ord.try_into()?, - ReportAggregationMetadataState::Start, + ReportAggregationMetadataState::Init, )) }) .collect::>()?; @@ -812,7 +812,7 @@ impl AggregationJobCreator { *report.report_id(), *report.client_timestamp(), ord.try_into()?, - ReportAggregationMetadataState::Start, + ReportAggregationMetadataState::Init, )) }) .collect::>()?; @@ -1042,7 +1042,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -1222,7 +1222,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -1409,7 +1409,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -1466,7 +1466,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -1802,7 +1802,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -1995,7 +1995,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -2162,7 +2162,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -2257,7 +2257,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -2424,7 +2424,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -2527,7 +2527,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -2724,7 +2724,7 @@ mod tests { ( (*report.metadata().id(), ()), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ) @@ -2878,7 +2878,7 @@ mod tests { expected_report_aggregations.insert( (*report.metadata().id(), first_aggregation_param), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ); @@ -2897,14 +2897,14 @@ mod tests { expected_report_aggregations.insert( (*report.metadata().id(), first_aggregation_param), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ); expected_report_aggregations.insert( (*report.metadata().id(), second_aggregation_param), report - .as_start_leader_report_aggregation(random(), 0) + .as_leader_init_report_aggregation(random(), 0) .state() .clone(), ); @@ -3172,6 +3172,7 @@ mod tests { &Role::Leader, task_id, &agg_job_id, + agg_job.aggregation_parameter(), ) .await .unwrap(); diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index ff042d8ea..937bef26e 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -18,6 +18,7 @@ use backoff::backoff::Backoff; use bytes::Bytes; use educe::Educe; use futures::future::BoxFuture; +use http::{header::RETRY_AFTER, HeaderValue}; use janus_aggregator_core::{ datastore::{ self, @@ -38,8 +39,8 @@ use janus_core::{ use janus_messages::{ batch_mode::{LeaderSelected, TimeInterval}, AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp, - PartialBatchSelector, PrepareContinue, PrepareInit, PrepareStepResult, ReportError, - ReportMetadata, ReportShare, Role, + PartialBatchSelector, PrepareContinue, PrepareInit, PrepareResp, PrepareStepResult, + ReportError, ReportMetadata, ReportShare, Role, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter}, @@ -52,7 +53,7 @@ use prio::{ }; use rayon::iter::{IndexedParallelIterator as _, IntoParallelIterator as _, ParallelIterator as _}; use reqwest::Method; -use std::{collections::HashSet, panic, sync::Arc, time::Duration}; +use std::{collections::HashSet, panic, str::FromStr, sync::Arc, time::Duration}; use tokio::{join, sync::mpsc, try_join}; use tracing::{debug, error, info, info_span, trace_span, warn, Span}; @@ -180,37 +181,27 @@ where { // Read all information about the aggregation job. let (task, aggregation_job, report_aggregations, verify_key) = datastore - .run_tx("step_aggregation_job_1", |tx| { + .run_tx("step_aggregation_job_generic", |tx| { let (lease, vdaf) = (Arc::clone(&lease), Arc::clone(&vdaf)); Box::pin(async move { - let task = tx - .get_aggregator_task(lease.leased().task_id()) - .await? - .ok_or_else(|| { - datastore::Error::User( - anyhow!("couldn't find task {}", lease.leased().task_id()).into(), - ) - })?; - let verify_key = task.vdaf_verify_key().map_err(|_| { - datastore::Error::User( - anyhow!("VDAF verification key has wrong length").into(), - ) - })?; - + let task_future = tx.get_aggregator_task(lease.leased().task_id()); let aggregation_job_future = tx.get_aggregation_job::( lease.leased().task_id(), lease.leased().aggregation_job_id(), ); - let report_aggregations_future = tx - .get_report_aggregations_for_aggregation_job( - vdaf.as_ref(), - &Role::Leader, - lease.leased().task_id(), - lease.leased().aggregation_job_id(), - ); - let (aggregation_job, report_aggregations) = - try_join!(aggregation_job_future, report_aggregations_future)?; + let (task, aggregation_job) = try_join!(task_future, aggregation_job_future,)?; + + let task = task.ok_or_else(|| { + datastore::Error::User( + anyhow!("couldn't find task {}", lease.leased().task_id()).into(), + ) + })?; + let verify_key = task.vdaf_verify_key().map_err(|_| { + datastore::Error::User( + anyhow!("VDAF verification key has wrong length").into(), + ) + })?; let aggregation_job = aggregation_job.ok_or_else(|| { datastore::Error::User( anyhow!( @@ -222,6 +213,16 @@ where ) })?; + let report_aggregations = tx + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + &Role::Leader, + lease.leased().task_id(), + lease.leased().aggregation_job_id(), + aggregation_job.aggregation_parameter(), + ) + .await?; + Ok(( Arc::new(task), aggregation_job, @@ -233,26 +234,30 @@ where .await?; // Figure out the next step based on the non-error report aggregation states, and dispatch accordingly. - let (mut saw_start, mut saw_waiting, mut saw_finished) = (false, false, false); + let (mut saw_init, mut saw_continue, mut saw_poll, mut saw_finished) = + (false, false, false, false); for report_aggregation in &report_aggregations { match report_aggregation.state() { - ReportAggregationState::StartLeader { .. } => saw_start = true, - ReportAggregationState::WaitingLeader { .. } => saw_waiting = true, - ReportAggregationState::WaitingHelper { .. } => { + ReportAggregationState::LeaderInit { .. } => saw_init = true, + ReportAggregationState::LeaderContinue { .. } => saw_continue = true, + ReportAggregationState::LeaderPoll { .. } => saw_poll = true, + + ReportAggregationState::HelperContinue { .. } => { return Err(Error::Internal( - "Leader encountered unexpected ReportAggregationState::WaitingHelper" + "Leader encountered unexpected ReportAggregationState::HelperContinue" .to_string(), )); } + ReportAggregationState::Finished => saw_finished = true, ReportAggregationState::Failed { .. } => (), // ignore failed aggregations } } - match (saw_start, saw_waiting, saw_finished) { - // Only saw report aggregations in state "start" (or failed or invalid). - (true, false, false) => { + match (saw_init, saw_continue, saw_poll, saw_finished) { + // Only saw report aggregations in state "init" (or failed). + (true, false, false, false) => { self.step_aggregation_job_aggregate_init( - Arc::clone(&datastore), + datastore, vdaf, lease, task, @@ -263,10 +268,23 @@ where .await } - // Only saw report aggregations in state "waiting" (or failed or invalid). - (false, true, false) => { + // Only saw report aggregations in state "continue" (or failed). + (false, true, false, false) => { self.step_aggregation_job_aggregate_continue( - Arc::clone(&datastore), + datastore, + vdaf, + lease, + task, + aggregation_job, + report_aggregations, + ) + .await + } + + // Only saw report aggregations in state "poll" (or failed). + (false, false, true, false) => { + self.step_aggregation_job_aggregate_poll( + datastore, vdaf, lease, task, @@ -277,8 +295,9 @@ where } _ => Err(Error::Internal(format!( - "unexpected combination of report aggregation states (saw_start = {saw_start}, \ - saw_waiting = {saw_waiting}, saw_finished = {saw_finished})", + "unexpected combination of report aggregation states (saw_init = {saw_init}, \ + saw_continue = {saw_continue}, saw_poll = {saw_poll}, \ + saw_finished = {saw_finished})", ))), } } @@ -317,7 +336,7 @@ where .filter(|report_aggregation| { matches!( report_aggregation.state(), - &ReportAggregationState::StartLeader { .. } + &ReportAggregationState::LeaderInit { .. } ) }) .collect(); @@ -359,7 +378,7 @@ where leader_input_share, helper_encrypted_input_share, ) = match report_aggregation.state() { - ReportAggregationState::StartLeader { + ReportAggregationState::LeaderInit { public_extensions, public_share, leader_private_extensions, @@ -374,7 +393,7 @@ where ), // Panic safety: this can't happen because we filter to only - // StartLeader-state report aggregations before this loop. + // LeaderInit-state report aggregations before this loop. _ => panic!( "Unexpected report aggregation state: {:?}", report_aggregation.state() @@ -506,7 +525,7 @@ where ); assert_eq!(prepare_inits.len(), stepped_aggregations.len()); - let resp = if !prepare_inits.is_empty() { + let (resp, retry_after) = if !prepare_inits.is_empty() { // Construct request, send it to the helper, and process the response. let request = AggregationJobInitializeReq::::new( aggregation_job @@ -517,11 +536,11 @@ where prepare_inits, ); - let resp_bytes = send_request_to_helper( + let http_response = send_request_to_helper( &self.http_client, self.backoff.clone(), Method::PUT, - task.aggregation_job_uri(aggregation_job.id())? + task.aggregation_job_uri(aggregation_job.id(), None)? .ok_or_else(|| { Error::InvalidConfiguration("task is leader and has no aggregate share URI") })?, @@ -530,25 +549,38 @@ where content_type: AggregationJobInitializeReq::::MEDIA_TYPE, body: Bytes::from(request.get_encoded().map_err(Error::MessageEncode)?), }), - // The only way a task wouldn't have an aggregator auth token in it is in the taskprov - // case, and Janus never acts as the leader with taskprov enabled. + // The only way a task wouldn't have an aggregator auth token in it is in the + // taskprov case, and Janus never acts as the leader with taskprov enabled. task.aggregator_auth_token().ok_or_else(|| { Error::InvalidConfiguration("no aggregator auth token in task") })?, &self.http_request_duration_histogram, ) .await?; - AggregationJobResp::get_decoded(&resp_bytes).map_err(Error::MessageDecode)? + + let retry_after = http_response + .headers() + .get(RETRY_AFTER) + .map(parse_retry_after) + .transpose()?; + let resp = AggregationJobResp::get_decoded(http_response.body()) + .map_err(Error::MessageDecode)?; + + (resp, retry_after) } else { // If there are no prepare inits to send (because every report aggregation was filtered // by the block above), don't send a request to the Helper at all and process an // artificial aggregation job response instead, which will finish the aggregation job. - AggregationJobResp::Finished { - prepare_resps: Vec::new(), - } + ( + AggregationJobResp::Finished { + prepare_resps: Vec::new(), + }, + None, + ) }; - let aggregation_job = Arc::unwrap_or_clone(aggregation_job); + let aggregation_job: AggregationJob = + Arc::unwrap_or_clone(aggregation_job); self.process_response_from_helper( datastore, vdaf, @@ -557,6 +589,7 @@ where aggregation_job, stepped_aggregations, report_aggregations_to_write, + retry_after.as_ref(), resp, ) .await @@ -592,7 +625,7 @@ where .filter(|report_aggregation| { matches!( report_aggregation.state(), - &ReportAggregationState::WaitingLeader { .. } + &ReportAggregationState::LeaderContinue { .. } ) }) .collect(); @@ -626,9 +659,9 @@ where let _entered = span.enter(); let transition = match report_aggregation.state() { - ReportAggregationState::WaitingLeader { transition } => transition, + ReportAggregationState::LeaderContinue { transition } => transition, // Panic safety: this can't happen because we filter to only - // WaitingLeader-state report aggregations before this loop. + // LeaderContinue-state report aggregations before this loop. _ => panic!( "Unexpected report aggregation state: {:?}", report_aggregation.state() @@ -710,11 +743,11 @@ where // Construct request, send it to the helper, and process the response. let request = AggregationJobContinueReq::new(aggregation_job.step(), prepare_continues); - let resp_bytes = send_request_to_helper( + let http_response = send_request_to_helper( &self.http_client, self.backoff.clone(), Method::POST, - task.aggregation_job_uri(aggregation_job.id())? + task.aggregation_job_uri(aggregation_job.id(), None)? .ok_or_else(|| { Error::InvalidConfiguration("task is not leader and has no aggregate share URI") })?, @@ -730,7 +763,14 @@ where &self.http_request_duration_histogram, ) .await?; - let resp = AggregationJobResp::get_decoded(&resp_bytes).map_err(Error::MessageDecode)?; + + let retry_after = http_response + .headers() + .get(RETRY_AFTER) + .map(parse_retry_after) + .transpose()?; + let resp = + AggregationJobResp::get_decoded(http_response.body()).map_err(Error::MessageDecode)?; self.process_response_from_helper( datastore, @@ -740,6 +780,91 @@ where aggregation_job, stepped_aggregations, report_aggregations_to_write, + retry_after.as_ref(), + resp, + ) + .await + } + + async fn step_aggregation_job_aggregate_poll< + const SEED_SIZE: usize, + C: Clock, + B: CollectableBatchMode, + A: vdaf::Aggregator + Send + Sync + 'static, + >( + &self, + datastore: Arc>, + vdaf: Arc, + lease: Arc>, + task: Arc, + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result<(), Error> + where + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + A::OutputShare: Send + Sync, + A::PrepareState: Send + Sync + Encode, + A::PrepareShare: Send + Sync, + A::PrepareMessage: Send + Sync, + A::PublicShare: Send + Sync, + { + // Only process non-failed report aggregations; convert non-failed report aggregations into + // stepped aggregations to be compatible with `process_response_from_helper`. + let stepped_aggregations: Vec<_> = report_aggregations + .into_iter() + .filter_map(|report_aggregation| { + let leader_state = match report_aggregation.state() { + ReportAggregationState::LeaderPoll { leader_state } => { + Some(leader_state.clone()) + } + _ => None, + }; + + leader_state.map(|leader_state| SteppedAggregation { + report_aggregation, + leader_state, + }) + }) + .collect(); + + // Poll the Helper for completion. + let http_response = send_request_to_helper( + &self.http_client, + self.backoff.clone(), + Method::GET, + task.aggregation_job_uri(aggregation_job.id(), Some(aggregation_job.step()))? + .ok_or_else(|| { + Error::InvalidConfiguration("task is not leader and has no aggregate share URI") + })?, + AGGREGATION_JOB_ROUTE, + None, + // The only way a task wouldn't have an aggregator auth token in it is in the taskprov + // case, and Janus never acts as the leader with taskprov enabled. + task.aggregator_auth_token() + .ok_or_else(|| Error::InvalidConfiguration("no aggregator auth token in task"))?, + &self.http_request_duration_histogram, + ) + .await?; + + let retry_after = http_response + .headers() + .get(RETRY_AFTER) + .map(parse_retry_after) + .transpose()?; + let resp = + AggregationJobResp::get_decoded(http_response.body()).map_err(Error::MessageDecode)?; + + self.process_response_from_helper( + datastore, + vdaf, + lease, + task, + aggregation_job, + stepped_aggregations, + Vec::new(), + retry_after.as_ref(), resp, ) .await @@ -759,7 +884,8 @@ where task: Arc, aggregation_job: AggregationJob, stepped_aggregations: Vec>, - mut report_aggregations_to_write: Vec>, + report_aggregations_to_write: Vec>, + retry_after: Option<&Duration>, helper_resp: AggregationJobResp, ) -> Result<(), Error> where @@ -772,16 +898,155 @@ where A::PrepareState: Send + Sync + Encode, A::PublicShare: Send + Sync, { - let prepare_resps = match helper_resp { + match helper_resp { // TODO(#3436): implement asynchronous aggregation AggregationJobResp::Processing => { - return Err(Error::Internal( - "asynchronous aggregation not yet implemented".into(), - )) + self.process_response_from_helper_pending( + datastore, + vdaf, + lease, + task, + aggregation_job, + stepped_aggregations, + report_aggregations_to_write, + retry_after, + ) + .await } - AggregationJobResp::Finished { prepare_resps } => prepare_resps, - }; + AggregationJobResp::Finished { prepare_resps } => { + self.process_response_from_helper_finished( + datastore, + vdaf, + lease, + task, + aggregation_job, + stepped_aggregations, + report_aggregations_to_write, + prepare_resps, + ) + .await + } + } + } + + async fn process_response_from_helper_pending< + const SEED_SIZE: usize, + C: Clock, + B: CollectableBatchMode, + A: vdaf::Aggregator + Send + Sync + 'static, + >( + &self, + datastore: Arc>, + vdaf: Arc, + lease: Arc>, + task: Arc, + aggregation_job: AggregationJob, + stepped_aggregations: Vec>, + mut report_aggregations_to_write: Vec>, + retry_after: Option<&Duration>, + ) -> Result<(), Error> + where + A::AggregationParam: Send + Sync + Eq + PartialEq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + A::OutputShare: Send + Sync, + A::PrepareMessage: Send + Sync, + A::PrepareShare: Send + Sync, + A::PrepareState: Send + Sync + Encode, + A::PublicShare: Send + Sync, + { + // Any non-failed report aggregations are set to the Poll state, allowing them to be polled + // when the aggregation job is next picked up. + report_aggregations_to_write.extend(stepped_aggregations.into_iter().map( + |stepped_aggregation| { + WritableReportAggregation::new( + stepped_aggregation.report_aggregation.with_state( + ReportAggregationState::LeaderPoll { + leader_state: stepped_aggregation.leader_state, + }, + ), + // Even if we have recovered an output share (i.e., + // `stepped_aggregation.leader_state` is Finished), we don't include it here: we + // aren't done with aggregation until we receive a response from the Helper, so + // it would be incorrect to merge the results into the batch aggregations at + // this point. + None, + ) + }, + )); + + // Write everything back to storage. + let mut aggregation_job_writer = + AggregationJobWriter::::new( + Arc::clone(&task), + self.batch_aggregation_shard_count, + Some(AggregationJobWriterMetrics { + report_aggregation_success_counter: self.aggregation_success_counter.clone(), + aggregate_step_failure_counter: self.aggregate_step_failure_counter.clone(), + aggregated_report_share_dimension_histogram: self + .aggregated_report_share_dimension_histogram + .clone(), + }), + ); + aggregation_job_writer.put(aggregation_job, report_aggregations_to_write)?; + let aggregation_job_writer = Arc::new(aggregation_job_writer); + + let retry_after = retry_after + .copied() + .or_else(|| Some(Duration::from_secs(60))); + let counters = datastore + .run_tx("process_response_from_helper_pending", |tx| { + let vdaf = Arc::clone(&vdaf); + let aggregation_job_writer = Arc::clone(&aggregation_job_writer); + let lease = Arc::clone(&lease); + + Box::pin(async move { + let ((_, counters), _) = try_join!( + aggregation_job_writer.write(tx, Arc::clone(&vdaf)), + tx.release_aggregation_job(&lease, retry_after.as_ref()), + )?; + Ok(counters) + }) + }) + .await?; + + write_task_aggregation_counter( + datastore, + self.task_counter_shard_count, + *task.id(), + counters, + ); + + Ok(()) + } + + async fn process_response_from_helper_finished< + const SEED_SIZE: usize, + C: Clock, + B: CollectableBatchMode, + A: vdaf::Aggregator + Send + Sync + 'static, + >( + &self, + datastore: Arc>, + vdaf: Arc, + lease: Arc>, + task: Arc, + aggregation_job: AggregationJob, + stepped_aggregations: Vec>, + mut report_aggregations_to_write: Vec>, + prepare_resps: Vec, + ) -> Result<(), Error> + where + A::AggregationParam: Send + Sync + Eq + PartialEq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + A::OutputShare: Send + Sync, + A::PrepareMessage: Send + Sync, + A::PrepareShare: Send + Sync, + A::PrepareState: Send + Sync + Encode, + A::PublicShare: Send + Sync, + { // Handle response, computing the new report aggregations to be stored. let expected_report_aggregation_count = report_aggregations_to_write.len() + stepped_aggregations.len(); @@ -823,99 +1088,99 @@ where let ctx = vdaf_application_context(&task_id); stepped_aggregations.into_par_iter().zip(prepare_resps).try_for_each( - |(stepped_aggregation, helper_prep_resp)| { - let _entered = span.enter(); - - let (new_state, output_share) = match helper_prep_resp.result() { - PrepareStepResult::Continue { - message: helper_prep_msg, - } => { - let state_and_message = trace_span!("VDAF preparation (leader continuation)") - .in_scope(|| { - vdaf.leader_continued( - &ctx, - stepped_aggregation.leader_state.clone(), - aggregation_job.aggregation_parameter(), - helper_prep_msg, - ) - .map_err(|ping_pong_error| { - handle_ping_pong_error( - &task_id, - Role::Leader, - stepped_aggregation.report_aggregation.report_id(), - ping_pong_error, - &aggregate_step_failure_counter, + |(stepped_aggregation, helper_prep_resp)| { + let _entered = span.enter(); + + let (new_state, output_share) = match helper_prep_resp.result() { + PrepareStepResult::Continue { + message: helper_prep_msg, + } => { + let state_and_message = trace_span!("VDAF preparation (leader continuation)") + .in_scope(|| { + vdaf.leader_continued( + &ctx, + stepped_aggregation.leader_state.clone(), + aggregation_job.aggregation_parameter(), + helper_prep_msg, ) - }) - }); - - match state_and_message { - Ok(PingPongContinuedValue::WithMessage { transition }) => { - // Leader did not finish. Store our state and outgoing message for the - // next step. - // n.b. it's possible we finished and recovered an output share at the - // VDAF level (i.e., state may be PingPongState::Finished) but we cannot - // finish at the DAP layer and commit the output share until we get - // confirmation from the Helper that they finished, too. - (ReportAggregationState::WaitingLeader { transition }, None) + .map_err(|ping_pong_error| { + handle_ping_pong_error( + &task_id, + Role::Leader, + stepped_aggregation.report_aggregation.report_id(), + ping_pong_error, + &aggregate_step_failure_counter, + ) + }) + }); + + match state_and_message { + Ok(PingPongContinuedValue::WithMessage { transition }) => { + // Leader did not finish. Store our state and outgoing message for the + // next step. + // n.b. it's possible we finished and recovered an output share at the + // VDAF level (i.e., state may be PingPongState::Finished) but we cannot + // finish at the DAP layer and commit the output share until we get + // confirmation from the Helper that they finished, too. + (ReportAggregationState::LeaderContinue { transition }, None) + } + Ok(PingPongContinuedValue::FinishedNoMessage { output_share }) => { + // We finished and have no outgoing message, meaning the Helper was + // already finished. Commit the output share. + (ReportAggregationState::Finished, Some(output_share)) + } + Err(report_error) => { + (ReportAggregationState::Failed { report_error }, None) + } } - Ok(PingPongContinuedValue::FinishedNoMessage { output_share }) => { - // We finished and have no outgoing message, meaning the Helper was - // already finished. Commit the output share. + } + + PrepareStepResult::Finished => { + if let PingPongState::Finished(output_share) = stepped_aggregation.leader_state + { + // Helper finished and we had already finished. Commit the output share. (ReportAggregationState::Finished, Some(output_share)) - } - Err(report_error) => { - (ReportAggregationState::Failed { report_error }, None) + } else { + warn!( + report_id = %stepped_aggregation.report_aggregation.report_id(), + "Helper finished but Leader did not", + ); + aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "finish_mismatch")]); + ( + ReportAggregationState::Failed { + report_error: ReportError::VdafPrepError, + }, + None, + ) } } - } - PrepareStepResult::Finished => { - if let PingPongState::Finished(output_share) = stepped_aggregation.leader_state - { - // Helper finished and we had already finished. Commit the output share. - (ReportAggregationState::Finished, Some(output_share)) - } else { - warn!( + PrepareStepResult::Reject(err) => { + // If the helper failed, we move to FAILED immediately. + // TODO(#236): is it correct to just record the transition error that the helper reports? + info!( report_id = %stepped_aggregation.report_aggregation.report_id(), - "Helper finished but Leader did not", + helper_error = ?err, + "Helper couldn't step report aggregation", ); aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "finish_mismatch")]); + .add(1, &[KeyValue::new("type", "helper_step_failure")]); ( ReportAggregationState::Failed { - report_error: ReportError::VdafPrepError, + report_error: *err, }, None, ) } - } - - PrepareStepResult::Reject(err) => { - // If the helper failed, we move to FAILED immediately. - // TODO(#236): is it correct to just record the transition error that the helper reports? - info!( - report_id = %stepped_aggregation.report_aggregation.report_id(), - helper_error = ?err, - "Helper couldn't step report aggregation", - ); - aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "helper_step_failure")]); - ( - ReportAggregationState::Failed { - report_error: *err, - }, - None, - ) - } - }; + }; - ra_sender.send(WritableReportAggregation::new( - stepped_aggregation.report_aggregation.with_state(new_state), - output_share, - )) - } - ) + ra_sender.send(WritableReportAggregation::new( + stepped_aggregation.report_aggregation.with_state(new_state), + output_share, + )) + } + ) } }); @@ -961,7 +1226,7 @@ where let aggregation_job_writer = Arc::new(aggregation_job_writer); let counters = datastore - .run_tx("step_aggregation_job_2", |tx| { + .run_tx("process_response_from_helper_finished", |tx| { let vdaf = Arc::clone(&vdaf); let aggregation_job_writer = Arc::clone(&aggregation_job_writer); let lease = Arc::clone(&lease); @@ -969,7 +1234,7 @@ where Box::pin(async move { let ((_, counters), _) = try_join!( aggregation_job_writer.write(tx, Arc::clone(&vdaf)), - tx.release_aggregation_job(&lease), + tx.release_aggregation_job(&lease, None), )?; Ok(counters) }) @@ -1041,7 +1306,7 @@ where let vdaf = Arc::new(vdaf); let batch_aggregation_shard_count = self.batch_aggregation_shard_count; let (aggregation_job_uri, aggregator_auth_token) = datastore - .run_tx("cancel_aggregation_job", |tx| { + .run_tx("cancel_aggregation_job_generic", |tx| { let vdaf = Arc::clone(&vdaf); let lease = Arc::clone(&lease); @@ -1049,18 +1314,12 @@ where // On abandoning an aggregation job, we update the aggregation job's state field // to Abandoned, but leave all other state (e.g. report aggregations) alone to // ease debugging. - let (task, aggregation_job, report_aggregations) = try_join!( + let (task, aggregation_job) = try_join!( tx.get_aggregator_task(lease.leased().task_id()), tx.get_aggregation_job::( lease.leased().task_id(), lease.leased().aggregation_job_id() ), - tx.get_report_aggregations_for_aggregation_job( - vdaf.as_ref(), - &Role::Leader, - lease.leased().task_id(), - lease.leased().aggregation_job_id() - ), )?; let task = task.ok_or_else(|| { @@ -1081,13 +1340,21 @@ where })? .with_state(AggregationJobState::Abandoned); - let report_aggregations = report_aggregations + let report_aggregations = tx + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + &Role::Leader, + lease.leased().task_id(), + lease.leased().aggregation_job_id(), + aggregation_job.aggregation_parameter(), + ) + .await? .into_iter() .map(|ra| WritableReportAggregation::new(ra, None)) .collect(); let aggregation_job_uri = - task.aggregation_job_uri(lease.leased().aggregation_job_id()); + task.aggregation_job_uri(lease.leased().aggregation_job_id(), None); let aggregator_auth_token = task.aggregator_auth_token().cloned(); let mut aggregation_job_writer = @@ -1100,7 +1367,7 @@ where try_join!( aggregation_job_writer.write(tx, vdaf), - tx.release_aggregation_job(&lease), + tx.release_aggregation_job(&lease, None), )?; Ok((aggregation_job_uri, aggregator_auth_token)) @@ -1246,3 +1513,11 @@ struct SteppedAggregation, leader_state: PingPongState, } + +fn parse_retry_after(header_value: &HeaderValue) -> Result { + let val = header_value + .to_str() + .map_err(|err| Error::BadRequest(err.to_string()))?; + let val = u64::from_str(val).map_err(|err| Error::BadRequest(err.to_string()))?; + Ok(Duration::from_secs(val)) +} diff --git a/aggregator/src/aggregator/aggregation_job_driver/tests.rs b/aggregator/src/aggregator/aggregation_job_driver/tests.rs index b6e01f067..6586f4ed3 100644 --- a/aggregator/src/aggregator/aggregation_job_driver/tests.rs +++ b/aggregator/src/aggregator/aggregation_job_driver/tests.rs @@ -134,7 +134,7 @@ async fn aggregation_job_driver() { .await .unwrap(); tx.put_report_aggregation( - &report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await .unwrap(); @@ -319,6 +319,7 @@ async fn aggregation_job_driver() { task.id(), &aggregation_job_id, &report_id, + &aggregation_param, ) .await .unwrap() @@ -346,7 +347,7 @@ async fn aggregation_job_driver() { } #[tokio::test] -async fn step_time_interval_aggregation_job_init_single_step() { +async fn sync_time_interval_aggregation_job_init_single_step() { // Setup: insert a client report and add it to a new aggregation job. install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; @@ -475,7 +476,7 @@ async fn step_time_interval_aggregation_job_init_single_step() { .enumerate() { tx.put_report_aggregation( - &report.as_start_leader_report_aggregation(aggregation_job_id, ord as u64), + &report.as_leader_init_report_aggregation(aggregation_job_id, ord as u64), ) .await .unwrap(); @@ -564,7 +565,7 @@ async fn step_time_interval_aggregation_job_init_single_step() { AggregationJobInitializeReq::::MEDIA_TYPE, ) .match_body(leader_request.get_encoded().unwrap()) - .with_status(200) + .with_status(201) .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) .with_body(helper_response.get_encoded().unwrap()) .create_async() @@ -694,6 +695,7 @@ async fn step_time_interval_aggregation_job_init_single_step() { task.id(), &aggregation_job_id, &report_id, + &(), ) .await .unwrap() @@ -705,6 +707,7 @@ async fn step_time_interval_aggregation_job_init_single_step() { task.id(), &aggregation_job_id, &repeated_public_extension_report_id, + &(), ) .await .unwrap() @@ -716,6 +719,7 @@ async fn step_time_interval_aggregation_job_init_single_step() { task.id(), &aggregation_job_id, &repeated_private_extension_report_id, + &(), ) .await .unwrap() @@ -727,6 +731,7 @@ async fn step_time_interval_aggregation_job_init_single_step() { task.id(), &aggregation_job_id, &repeated_public_private_extension_report_id, + &(), ) .await .unwrap() @@ -771,7 +776,7 @@ async fn step_time_interval_aggregation_job_init_single_step() { } #[tokio::test] -async fn step_time_interval_aggregation_job_init_two_steps() { +async fn sync_time_interval_aggregation_job_init_two_steps() { // Setup: insert a client report and add it to a new aggregation job. install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; @@ -841,7 +846,7 @@ async fn step_time_interval_aggregation_job_init_two_steps() { .unwrap(); tx.put_report_aggregation( - &report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await .unwrap(); @@ -913,7 +918,7 @@ async fn step_time_interval_aggregation_job_init_two_steps() { AggregationJobInitializeReq::::MEDIA_TYPE, ) .match_body(leader_request.get_encoded().unwrap()) - .with_status(200) + .with_status(201) .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) .with_body(helper_response.get_encoded().unwrap()) .create_async() @@ -951,7 +956,7 @@ async fn step_time_interval_aggregation_job_init_two_steps() { *report.metadata().time(), 0, None, - ReportAggregationState::WaitingLeader { + ReportAggregationState::LeaderContinue { transition: transcript.leader_prepare_transitions[1] .transition .clone() @@ -994,6 +999,7 @@ async fn step_time_interval_aggregation_job_init_two_steps() { task.id(), &aggregation_job_id, &report_id, + &aggregation_param, ) .await .unwrap() @@ -1021,7 +1027,7 @@ async fn step_time_interval_aggregation_job_init_two_steps() { } #[tokio::test] -async fn step_time_interval_aggregation_job_init_partially_garbage_collected() { +async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { // This is a regression test for https://github.com/divviup/janus/issues/2464. const OLDEST_ALLOWED_REPORT_TIMESTAMP: Time = Time::from_seconds_since_epoch(1000); @@ -1144,12 +1150,12 @@ async fn step_time_interval_aggregation_job_init_partially_garbage_collected() { .await .unwrap(); tx.put_report_aggregation( - &gc_eligible_report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &gc_eligible_report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await .unwrap(); tx.put_report_aggregation( - &gc_ineligible_report.as_start_leader_report_aggregation(aggregation_job_id, 1), + &gc_ineligible_report.as_leader_init_report_aggregation(aggregation_job_id, 1), ) .await .unwrap(); @@ -1271,7 +1277,7 @@ async fn step_time_interval_aggregation_job_init_partially_garbage_collected() { AggregationJobInitializeReq::::MEDIA_TYPE, ) .match_body(leader_request.get_encoded().unwrap()) - .with_status(200) + .with_status(201) .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) .with_body(helper_response.get_encoded().unwrap()) .create_async() @@ -1369,6 +1375,7 @@ async fn step_time_interval_aggregation_job_init_partially_garbage_collected() { &Role::Leader, task.id(), &aggregation_job_id, + &(), ) .await .unwrap(); @@ -1392,7 +1399,7 @@ async fn step_time_interval_aggregation_job_init_partially_garbage_collected() { } #[tokio::test] -async fn step_leader_selected_aggregation_job_init_single_step() { +async fn sync_leader_selected_aggregation_job_init_single_step() { // Setup: insert a client report and add it to a new aggregation job. install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; @@ -1471,7 +1478,7 @@ async fn step_leader_selected_aggregation_job_init_single_step() { .unwrap(); tx.put_report_aggregation( - &report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await .unwrap(); @@ -1559,7 +1566,7 @@ async fn step_leader_selected_aggregation_job_init_single_step() { AggregationJobInitializeReq::::MEDIA_TYPE, ) .match_body(leader_request.get_encoded().unwrap()) - .with_status(200) + .with_status(201) .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) .with_body(helper_response.get_encoded().unwrap()) .create_async() @@ -1650,6 +1657,7 @@ async fn step_leader_selected_aggregation_job_init_single_step() { task.id(), &aggregation_job_id, &report_id, + &() ) .await .unwrap() @@ -1681,7 +1689,7 @@ async fn step_leader_selected_aggregation_job_init_single_step() { } #[tokio::test] -async fn step_leader_selected_aggregation_job_init_two_steps() { +async fn sync_leader_selected_aggregation_job_init_two_steps() { // Setup: insert a client report and add it to a new aggregation job. install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; @@ -1759,7 +1767,7 @@ async fn step_leader_selected_aggregation_job_init_two_steps() { .unwrap(); tx.put_report_aggregation( - &report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await .unwrap(); @@ -1831,7 +1839,7 @@ async fn step_leader_selected_aggregation_job_init_two_steps() { AggregationJobInitializeReq::::MEDIA_TYPE, ) .match_body(leader_request.get_encoded().unwrap()) - .with_status(200) + .with_status(201) .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) .with_body(helper_response.get_encoded().unwrap()) .create_async() @@ -1869,7 +1877,7 @@ async fn step_leader_selected_aggregation_job_init_two_steps() { *report.metadata().time(), 0, None, - ReportAggregationState::WaitingLeader { + ReportAggregationState::LeaderContinue { transition: transcript.leader_prepare_transitions[1] .transition .clone() @@ -1912,6 +1920,7 @@ async fn step_leader_selected_aggregation_job_init_two_steps() { task.id(), &aggregation_job_id, &report_id, + &aggregation_param, ) .await .unwrap() @@ -1939,7 +1948,7 @@ async fn step_leader_selected_aggregation_job_init_two_steps() { } #[tokio::test] -async fn step_time_interval_aggregation_job_continue() { +async fn sync_time_interval_aggregation_job_continue() { // Setup: insert a client report and add it to an aggregation job whose state has already // been stepped once. install_test_trace_subscriber(); @@ -2029,7 +2038,7 @@ async fn step_time_interval_aggregation_job_continue() { *report.metadata().time(), 0, None, - ReportAggregationState::WaitingLeader { + ReportAggregationState::LeaderContinue { transition: transcript.leader_prepare_transitions[1] .transition .clone() @@ -2126,7 +2135,7 @@ async fn step_time_interval_aggregation_job_continue() { .match_header(header, value.as_str()) .match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE) .match_body(leader_request.get_encoded().unwrap()) - .with_status(200) + .with_status(202) .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) .with_body(helper_response.get_encoded().unwrap()) .create_async() @@ -2232,6 +2241,7 @@ async fn step_time_interval_aggregation_job_continue() { task.id(), &aggregation_job_id, report_metadata.id(), + &aggregation_param, ) .await .unwrap() @@ -2260,7 +2270,7 @@ async fn step_time_interval_aggregation_job_continue() { } #[tokio::test] -async fn step_leader_selected_aggregation_job_continue() { +async fn sync_leader_selected_aggregation_job_continue() { // Setup: insert a client report and add it to an aggregation job whose state has already // been stepped once. install_test_trace_subscriber(); @@ -2350,7 +2360,7 @@ async fn step_leader_selected_aggregation_job_continue() { *report.metadata().time(), 0, None, - ReportAggregationState::WaitingLeader { + ReportAggregationState::LeaderContinue { transition: transcript.leader_prepare_transitions[1] .transition .clone() @@ -2431,7 +2441,7 @@ async fn step_leader_selected_aggregation_job_continue() { .match_header(header, value.as_str()) .match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE) .match_body(leader_request.get_encoded().unwrap()) - .with_status(200) + .with_status(202) .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) .with_body(helper_response.get_encoded().unwrap()) .create_async() @@ -2512,6 +2522,7 @@ async fn step_leader_selected_aggregation_job_continue() { task.id(), &aggregation_job_id, report_metadata.id(), + &aggregation_param, ) .await .unwrap() @@ -2546,6 +2557,2178 @@ async fn step_leader_selected_aggregation_job_continue() { .await; } +#[tokio::test] +async fn async_aggregation_job_init_to_pending() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(1)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let leader_task = task.leader_view().unwrap(); + + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(0); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation( + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), + ) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. + let leader_request = AggregationJobInitializeReq::new( + aggregation_param.get_encoded().unwrap(), + PartialBatchSelector::new_time_interval(), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded().unwrap(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), + )]), + ); + let helper_response = AggregationJobResp::Processing; + let (header, value) = agg_auth_token.request_authentication(); + let mocked_aggregate_request = server + .mock( + "PUT", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_header(header, value.as_str()) + .match_header( + CONTENT_TYPE.as_str(), + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .match_body(leader_request.get_encoded().unwrap()) + .with_status(201) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(1), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_request.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + ); + + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { + leader_state: transcript.leader_prepare_transitions[0].state.clone(), + }, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = task.clone(); + let report_id = *report.metadata().id(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + &report_id, + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn async_aggregation_job_init_to_pending_two_step() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let leader_task = task.leader_view().unwrap(); + + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(0); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation( + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), + ) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. + let leader_request = AggregationJobInitializeReq::new( + aggregation_param.get_encoded().unwrap(), + PartialBatchSelector::new_time_interval(), + Vec::from([PrepareInit::new( + ReportShare::new( + report.metadata().clone(), + report.public_share().get_encoded().unwrap(), + report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0].message.clone(), + )]), + ); + let helper_response = AggregationJobResp::Processing; + let (header, value) = agg_auth_token.request_authentication(); + let mocked_aggregate_request = server + .mock( + "PUT", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_header(header, value.as_str()) + .match_header( + CONTENT_TYPE.as_str(), + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .match_body(leader_request.get_encoded().unwrap()) + .with_status(201) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(1), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_request.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + ); + + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { + leader_state: transcript.leader_prepare_transitions[0].state.clone(), + }, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = task.clone(); + let report_id = *report.metadata().id(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + &report_id, + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn async_aggregation_job_continue_to_pending() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let leader_task = task.leader_view().unwrap(); + + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(0); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + let transition = transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(1), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderContinue { transition }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. + let leader_request = AggregationJobContinueReq::new( + AggregationJobStep::from(1), + Vec::from([PrepareContinue::new( + *report.metadata().id(), + transcript.leader_prepare_transitions[1].message.clone(), + )]), + ); + let helper_response = AggregationJobResp::Processing; + let (header, value) = agg_auth_token.request_authentication(); + let mocked_aggregate_request = server + .mock( + "POST", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_header(header, value.as_str()) + .match_header(CONTENT_TYPE.as_str(), AggregationJobContinueReq::MEDIA_TYPE) + .match_body(leader_request.get_encoded().unwrap()) + .with_status(202) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(1), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_request.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(1), + ); + + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { + leader_state: transcript.leader_prepare_transitions[1].state.clone(), + }, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = task.clone(); + let report_id = *report.metadata().id(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + &report_id, + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn async_aggregation_job_init_poll_to_pending() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(1)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let leader_task = task.leader_view().unwrap(); + + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(0); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + let leader_state = transcript.leader_prepare_transitions[0].state.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { leader_state }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. + let helper_response = AggregationJobResp::Processing; + let (header, value) = agg_auth_token.request_authentication(); + + let mocked_aggregate_request = server + .mock( + "GET", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_query("step=0") + .match_header(header, value.as_str()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(1), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_request.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + ); + + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { + leader_state: transcript.leader_prepare_transitions[0].state.clone(), + }, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = task.clone(); + let report_id = *report.metadata().id(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + &report_id, + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn async_aggregation_job_init_poll_to_pending_two_step() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let leader_task = task.leader_view().unwrap(); + + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(0); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + let leader_state = transcript.leader_prepare_transitions[0].state.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { leader_state }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. + let helper_response = AggregationJobResp::Processing; + let (header, value) = agg_auth_token.request_authentication(); + + let mocked_aggregate_request = server + .mock( + "GET", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_query("step=0") + .match_header(header, value.as_str()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(1), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_request.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + ); + + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { + leader_state: transcript.leader_prepare_transitions[0].state.clone(), + }, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = task.clone(); + let report_id = *report.metadata().id(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + &report_id, + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn async_aggregation_job_init_poll_to_finished() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(1)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let leader_task = task.leader_view().unwrap(); + + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(0); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + let leader_state = transcript.leader_prepare_transitions[0].state.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { leader_state }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )]), + }; + let (header, value) = agg_auth_token.request_authentication(); + + let mocked_aggregate_request = server + .mock( + "GET", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_query("step=0") + .match_header(header, value.as_str()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(1), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_request.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobStep::from(1), + ); + + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: Some(transcript.leader_output_share.into()), + report_count: 1, + checksum: ReportIdChecksum::for_report_id(report.metadata().id()), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 1, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = task.clone(); + let report_id = *report.metadata().id(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + &report_id, + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn async_aggregation_job_init_poll_to_continue() { + // Setup: insert a client report and add it to a new aggregation job. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + + let leader_task = task.leader_view().unwrap(); + + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(0); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &0, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + let leader_state = transcript.leader_prepare_transitions[0].state.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { leader_state }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0)) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP response. + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )]), + }; + let (header, value) = agg_auth_token.request_authentication(); + + let mocked_aggregate_request = server + .mock( + "GET", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_query("step=0") + .match_header(header, value.as_str()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run: create an aggregation job driver & try to step the aggregation we've created twice. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(1), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_request.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(1), + ); + + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderContinue { + transition: transcript.leader_prepare_transitions[1] + .transition + .clone() + .unwrap(), + }, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + batch_identifier, + aggregation_param, + 0, + Interval::from_time(&time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = task.clone(); + let report_id = *report.metadata().id(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + &report_id, + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn async_aggregation_job_continue_poll_to_pending() { + // Setup: insert a client report and add it to an aggregation job whose state has already + // been stepped once. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + let leader_task = task.leader_view().unwrap(); + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let active_batch_identifier = + TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + + let aggregation_param = dummy::AggregationParam(7); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &13, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + let leader_state = transcript.leader_prepare_transitions[1].state.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + tx.mark_report_aggregated(task.id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(1), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { leader_state }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(report.metadata().time()).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + let lease = tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0); + + Ok(lease) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP responses. + let helper_response = AggregationJobResp::Processing; + let (header, value) = agg_auth_token.request_authentication(); + let mocked_aggregate_success = server + .mock( + "GET", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_query("step=1") + .match_header(header, value.as_str()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(0), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_success.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(1), + ); + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { + leader_state: transcript.leader_prepare_transitions[1].state.clone(), + }, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(report.metadata().time()).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = leader_task.clone(); + let report_metadata = report.metadata().clone(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + report_metadata.id(), + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn async_aggregation_job_continue_poll_to_finished() { + // Setup: insert a client report and add it to an aggregation job whose state has already + // been stepped once. + install_test_trace_subscriber(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + let leader_task = task.leader_view().unwrap(); + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let active_batch_identifier = + TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + + let aggregation_param = dummy::AggregationParam(7); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &13, + ); + + let agg_auth_token = task.aggregator_auth_token(); + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let task = leader_task.clone(); + let report = report.clone(); + let leader_state = transcript.leader_prepare_transitions[1].state.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + tx.mark_report_aggregated(task.id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::InProgress, + AggregationJobStep::from(1), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::LeaderPoll { leader_state }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(report.metadata().time()).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + let lease = tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0); + + Ok(lease) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Setup: prepare mocked HTTP responses. + let helper_response = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report.metadata().id(), + PrepareStepResult::Finished, + )]), + }; + let (header, value) = agg_auth_token.request_authentication(); + let mocked_aggregate_success = server + .mock( + "GET", + task.aggregation_job_uri(&aggregation_job_id) + .unwrap() + .path(), + ) + .match_query("step=1") + .match_header(header, value.as_str()) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE) + .with_body(helper_response.get_encoded().unwrap()) + .create_async() + .await; + + // Run. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(0), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + ); + aggregation_job_driver + .step_aggregation_job(ds.clone(), Arc::new(lease)) + .await + .unwrap(); + + // Verify. + mocked_aggregate_success.assert_async().await; + + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobStep::from(2), + ); + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report.metadata().id(), + *report.metadata().time(), + 0, + None, + ReportAggregationState::Finished, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(report.metadata().time()).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: Some(transcript.leader_aggregate_share), + report_count: 1, + checksum: ReportIdChecksum::for_report_id(report.metadata().id()), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 1, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let task = leader_task.clone(); + let report_metadata = report.metadata().clone(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Leader, + task.id(), + &aggregation_job_id, + report_metadata.id(), + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(1)) + .await; +} + struct CancelAggregationJobTestCase { task: AggregatorTask, vdaf: Arc, @@ -2608,7 +4791,7 @@ async fn setup_cancel_aggregation_job_test() -> CancelAggregationJobTestCase { AggregationJobState::InProgress, AggregationJobStep::from(0), ); - let report_aggregation = report.as_start_leader_report_aggregation(aggregation_job_id, 0); + let report_aggregation = report.as_leader_init_report_aggregation(aggregation_job_id, 0); let lease = datastore .run_unnamed_tx(|tx| { @@ -2688,7 +4871,7 @@ async fn cancel_aggregation_job() { "DELETE", test_case .task - .aggregation_job_uri(test_case.aggregation_job.id()) + .aggregation_job_uri(test_case.aggregation_job.id(), None) .unwrap() .unwrap() .path(), @@ -2758,6 +4941,7 @@ async fn cancel_aggregation_job() { task.id(), aggregation_job.id(), &report_id, + &(), ) .await .unwrap() @@ -2797,7 +4981,7 @@ async fn cancel_aggregation_job_helper_aggregation_job_deletion_fails() { "DELETE", test_case .task - .aggregation_job_uri(test_case.aggregation_job.id()) + .aggregation_job_uri(test_case.aggregation_job.id(), None) .unwrap() .unwrap() .path(), @@ -2893,7 +5077,7 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { .unwrap(); tx.put_report_aggregation( - &report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await .unwrap(); @@ -3136,7 +5320,7 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { .unwrap(); tx.put_report_aggregation( - &report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await .unwrap(); diff --git a/aggregator/src/aggregator/batch_creator.rs b/aggregator/src/aggregator/batch_creator.rs index 2cce192f8..3fda9eb26 100644 --- a/aggregator/src/aggregator/batch_creator.rs +++ b/aggregator/src/aggregator/batch_creator.rs @@ -349,7 +349,7 @@ where *report.report_id(), client_timestamp, ord, - ReportAggregationMetadataState::Start, + ReportAggregationMetadataState::Init, ) }) .collect(); diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index 8cb77f8dd..03baa5498 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -347,7 +347,7 @@ where .map_err(Error::DifferentialPrivacy)?; // Send an aggregate share request to the helper. - let resp_bytes = send_request_to_helper( + let http_response = send_request_to_helper( &self.http_client, self.backoff.clone(), Method::POST, @@ -385,7 +385,7 @@ where collection_job.with_state(CollectionJobState::Finished { report_count, client_timestamp_interval, - encrypted_helper_aggregate_share: AggregateShare::get_decoded(&resp_bytes) + encrypted_helper_aggregate_share: AggregateShare::get_decoded(http_response.body()) .map_err(Error::MessageDecode)? .encrypted_aggregate_share() .clone(), diff --git a/aggregator/src/aggregator/garbage_collector.rs b/aggregator/src/aggregator/garbage_collector.rs index 1fbe1b073..6f87a05ba 100644 --- a/aggregator/src/aggregator/garbage_collector.rs +++ b/aggregator/src/aggregator/garbage_collector.rs @@ -210,6 +210,7 @@ mod tests { let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); let vdaf = dummy::Vdaf::new(1); + let aggregation_param = dummy::AggregationParam(0); // Setup. let task = ds @@ -236,7 +237,7 @@ mod tests { tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( *task.id(), aggregation_job_id, - dummy::AggregationParam(0), + aggregation_param, (), Interval::from_time(&client_timestamp).unwrap(), AggregationJobState::InProgress, @@ -246,7 +247,7 @@ mod tests { .unwrap(); tx.put_report_aggregation( - &report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await .unwrap(); @@ -329,6 +330,7 @@ mod tests { &vdaf, &Role::Leader, task.id(), + &aggregation_param, ) .await .unwrap() @@ -361,6 +363,7 @@ mod tests { let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); let vdaf = dummy::Vdaf::new(1); + let aggregation_param = dummy::AggregationParam(0); // Setup. let task = ds @@ -397,7 +400,7 @@ mod tests { tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( *task.id(), aggregation_job_id, - dummy::AggregationParam(0), + aggregation_param, (), Interval::from_time(&client_timestamp).unwrap(), AggregationJobState::InProgress, @@ -498,6 +501,7 @@ mod tests { &vdaf, &Role::Leader, task.id(), + &aggregation_param, ) .await .unwrap() @@ -533,6 +537,7 @@ mod tests { let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); let vdaf = dummy::Vdaf::new(1); + let aggregation_param = dummy::AggregationParam(0); // Setup. let task = ds @@ -566,7 +571,7 @@ mod tests { let aggregation_job = AggregationJob::<0, LeaderSelected, dummy::Vdaf>::new( *task.id(), random(), - dummy::AggregationParam(0), + aggregation_param, batch_id, Interval::from_time(&client_timestamp).unwrap(), AggregationJobState::InProgress, @@ -575,7 +580,7 @@ mod tests { tx.put_aggregation_job(&aggregation_job).await.unwrap(); let report_aggregation = - report.as_start_leader_report_aggregation(*aggregation_job.id(), 0); + report.as_leader_init_report_aggregation(*aggregation_job.id(), 0); tx.put_report_aggregation(&report_aggregation) .await .unwrap(); @@ -661,6 +666,7 @@ mod tests { &vdaf, &Role::Leader, task.id(), + &aggregation_param, ) .await .unwrap() @@ -701,6 +707,7 @@ mod tests { let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); let vdaf = dummy::Vdaf::new(1); + let aggregation_param = dummy::AggregationParam(0); // Setup. let task = ds @@ -744,7 +751,7 @@ mod tests { let aggregation_job = AggregationJob::<0, LeaderSelected, dummy::Vdaf>::new( *task.id(), random(), - dummy::AggregationParam(0), + aggregation_param, batch_id, Interval::from_time(&client_timestamp).unwrap(), AggregationJobState::InProgress, @@ -851,6 +858,7 @@ mod tests { &vdaf, &Role::Leader, task.id(), + &aggregation_param, ) .await .unwrap() diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs index 9d9fcc328..c26b18e7e 100644 --- a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs @@ -198,7 +198,7 @@ async fn aggregate_continue() { *report_metadata_0.time(), 0, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_0, }, )) @@ -211,7 +211,7 @@ async fn aggregate_continue() { *report_metadata_1.time(), 1, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_1, }, )) @@ -224,7 +224,7 @@ async fn aggregate_continue() { *report_metadata_2.time(), 2, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_2, }, )) @@ -299,6 +299,7 @@ async fn aggregate_continue() { &Role::Helper, task.id(), &aggregation_job_id, + &aggregation_param, ) .await .unwrap(); @@ -575,7 +576,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { *report_metadata_0.time(), 0, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_0, }, )) @@ -588,7 +589,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { *report_metadata_1.time(), 1, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_1, }, )) @@ -601,7 +602,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { *report_metadata_2.time(), 2, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_2, }, )) @@ -883,7 +884,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { *report_metadata_3.time(), 3, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_3, }, )) @@ -896,7 +897,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { *report_metadata_4.time(), 4, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_4, }, )) @@ -909,7 +910,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { *report_metadata_5.time(), 5, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: helper_prep_state_5, }, )) @@ -1112,7 +1113,7 @@ async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { *report_metadata.time(), 0, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *transcript.helper_prepare_transitions[0].prepare_state(), }, )) @@ -1226,7 +1227,7 @@ async fn aggregate_continue_prep_step_fails() { *report_metadata.time(), 0, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *transcript.helper_prepare_transitions[0].prepare_state(), }, )) @@ -1281,6 +1282,7 @@ async fn aggregate_continue_prep_step_fails() { task.id(), &aggregation_job_id, report_metadata.id(), + &aggregation_param, ) .await .unwrap() @@ -1401,7 +1403,7 @@ async fn aggregate_continue_unexpected_transition() { *report_metadata.time(), 0, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *transcript.helper_prepare_transitions[0].prepare_state(), }, )) @@ -1559,7 +1561,7 @@ async fn aggregate_continue_out_of_order_transition() { *report_metadata_0.time(), 0, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *transcript_0.helper_prepare_transitions[0].prepare_state(), }, )) @@ -1572,7 +1574,7 @@ async fn aggregate_continue_out_of_order_transition() { *report_metadata_1.time(), 1, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *transcript_1.helper_prepare_transitions[0].prepare_state(), }, )) diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 104a59772..1c2aa2aee 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -856,7 +856,7 @@ async fn taskprov_aggregate_continue() { *report_share.metadata().time(), 0, None, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *transcript.helper_prepare_transitions[0].prepare_state(), }, )) diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 8f56ab7c8..267dbfffb 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -40,7 +40,7 @@ use opentelemetry::{ use postgres_types::{FromSql, Json, Timestamp, ToSql}; use prio::{ codec::{decode_u16_items, encode_u16_items, CodecError, Decode, Encode, ParameterizedDecode}, - topology::ping_pong::PingPongTransition, + topology::ping_pong::{PingPongState, PingPongTransition}, vdaf, }; use rand::random; @@ -1573,7 +1573,7 @@ ON CONFLICT(task_id, report_id) DO UPDATE /// /// This method is intended for use by aggregators acting in the Leader role. Scrubbed reports /// can no longer be read, so this method should only be called once all aggregations over the - /// report have stepped past their START state. + /// report have stepped past their INIT state. #[tracing::instrument(skip(self), err(level = Level::DEBUG))] pub async fn scrub_client_report( &self, @@ -1915,11 +1915,15 @@ RETURNING tasks.task_id, tasks.batch_mode, tasks.vdaf, } /// release_aggregation_job releases an acquired (via e.g. acquire_incomplete_aggregation_jobs) - /// aggregation job. It returns an error if the aggregation job has no current lease. + /// aggregation job. If given, `reacquire_delay` determines the duration of time that must pass + /// before the aggregation job can be reacquired; this method assumes a reacquire delay + /// indicates that no progress was made, and will increment `step_attempts` accordingly. It + /// returns an error if the aggregation job has no current lease. #[tracing::instrument(skip(self), err(level = Level::DEBUG))] pub async fn release_aggregation_job( &self, lease: &Lease, + reacquire_delay: Option<&StdDuration>, ) -> Result<(), Error> { let task_info = match self.task_info_for(lease.leased().task_id()).await? { Some(task_info) => task_info, @@ -1927,26 +1931,33 @@ RETURNING tasks.task_id, tasks.batch_mode, tasks.vdaf, }; let now = self.clock.now().as_naive_date_time()?; + let lease_expiration = reacquire_delay + .map(|rd| add_naive_date_time_duration(&now, rd)) + .transpose()? + .map(Timestamp::Value) + .unwrap_or_else(|| Timestamp::NegInfinity); + let stmt = self .prepare_cached( "-- release_aggregation_job() UPDATE aggregation_jobs -SET lease_expiry = '-infinity'::TIMESTAMP, +SET lease_expiry = $1, lease_token = NULL, lease_attempts = 0, - updated_at = $1, - updated_by = $2 -WHERE aggregation_jobs.task_id = $3 - AND aggregation_jobs.aggregation_job_id = $4 - AND aggregation_jobs.lease_expiry = $5 - AND aggregation_jobs.lease_token = $6 - AND UPPER(aggregation_jobs.client_timestamp_interval) >= $7", + updated_at = $2, + updated_by = $3 +WHERE aggregation_jobs.task_id = $4 + AND aggregation_jobs.aggregation_job_id = $5 + AND aggregation_jobs.lease_expiry = $6 + AND aggregation_jobs.lease_token = $7 + AND UPPER(aggregation_jobs.client_timestamp_interval) >= $8", ) .await?; check_single_row_mutation( self.execute( &stmt, &[ + /* lease_expiry */ &lease_expiration, /* updated_at */ &now, /* updated_by */ &self.name, /* task_id */ &task_info.pkey, @@ -2082,6 +2093,7 @@ WHERE aggregation_jobs.task_id = $6 role: &Role, task_id: &TaskId, aggregation_job_id: &AggregationJobId, + aggregation_param: &A::AggregationParam, ) -> Result>, Error> where for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, @@ -2098,8 +2110,8 @@ SELECT ord, client_report_id, client_timestamp, last_prep_resp, report_aggregations.state, public_extensions, public_share, leader_private_extensions, leader_input_share, - helper_encrypted_input_share, leader_prep_transition, helper_prep_state, - error_code + helper_encrypted_input_share, leader_prep_transition, leader_prep_state, + leader_output_share, helper_prep_state, error_code FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id WHERE report_aggregations.task_id = $1 @@ -2127,6 +2139,7 @@ ORDER BY ord ASC", task_id, aggregation_job_id, &row.get_bytea_and_convert::("client_report_id")?, + aggregation_param, &row, ) }) @@ -2145,6 +2158,7 @@ ORDER BY ord ASC", task_id: &TaskId, aggregation_job_id: &AggregationJobId, report_id: &ReportId, + aggregation_param: &A::AggregationParam, ) -> Result>, Error> where for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, @@ -2161,7 +2175,7 @@ SELECT ord, client_timestamp, last_prep_resp, report_aggregations.state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, leader_prep_transition, - helper_prep_state, error_code + leader_prep_state, leader_output_share, helper_prep_state, error_code FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id @@ -2190,6 +2204,7 @@ WHERE report_aggregations.task_id = $1 task_id, aggregation_job_id, report_id, + aggregation_param, &row, ) }) @@ -2207,6 +2222,7 @@ WHERE report_aggregations.task_id = $1 vdaf: &A, role: &Role, task_id: &TaskId, + aggregation_param: &A::AggregationParam, ) -> Result>, Error> where for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, @@ -2224,7 +2240,7 @@ SELECT client_timestamp, last_prep_resp, report_aggregations.state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, leader_prep_transition, - helper_prep_state, error_code + leader_prep_state, leader_output_share, helper_prep_state, error_code FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id WHERE report_aggregations.task_id = $1 @@ -2249,6 +2265,7 @@ WHERE report_aggregations.task_id = $1 task_id, &row.get_bytea_and_convert::("aggregation_job_id")?, &row.get_bytea_and_convert::("client_report_id")?, + aggregation_param, &row, ) }) @@ -2261,6 +2278,7 @@ WHERE report_aggregations.task_id = $1 task_id: &TaskId, aggregation_job_id: &AggregationJobId, report_id: &ReportId, + aggregation_param: &A::AggregationParam, row: &Row, ) -> Result, Error> where @@ -2277,12 +2295,12 @@ WHERE report_aggregations.task_id = $1 .transpose()?; let agg_state = match state { - ReportAggregationStateCode::Start => { + ReportAggregationStateCode::Init => { let public_extensions_bytes = row .get::<_, Option>>("public_extensions") .ok_or_else(|| { Error::DbState( - "report aggregation in state START but public_extensions is NULL" + "report aggregation in state INIT but public_extensions is NULL" .to_string(), ) })?; @@ -2290,7 +2308,7 @@ WHERE report_aggregations.task_id = $1 row.get::<_, Option>>("public_share") .ok_or_else(|| { Error::DbState( - "report aggregation in state START but public_share is NULL" + "report aggregation in state INIT but public_share is NULL" .to_string(), ) })?; @@ -2298,7 +2316,7 @@ WHERE report_aggregations.task_id = $1 .get::<_, Option>>("leader_private_extensions") .ok_or_else(|| { Error::DbState( - "report aggregation in state START but leader_private_extensions is NULL" + "report aggregation in state INIT but leader_private_extensions is NULL" .to_string(), ) })?; @@ -2306,7 +2324,7 @@ WHERE report_aggregations.task_id = $1 .get::<_, Option>>("leader_input_share") .ok_or_else(|| { Error::DbState( - "report aggregation in state START but leader_input_share is NULL" + "report aggregation in state INIT but leader_input_share is NULL" .to_string(), ) })?; @@ -2314,7 +2332,7 @@ WHERE report_aggregations.task_id = $1 row.get::<_, Option>>("helper_encrypted_input_share") .ok_or_else(|| { Error::DbState( - "report aggregation in state START but helper_encrypted_input_share is NULL" + "report aggregation in state INIT but helper_encrypted_input_share is NULL" .to_string(), ) })?; @@ -2332,7 +2350,7 @@ WHERE report_aggregations.task_id = $1 let helper_encrypted_input_share = HpkeCiphertext::get_decoded(&helper_encrypted_input_share_bytes)?; - ReportAggregationState::StartLeader { + ReportAggregationState::LeaderInit { public_extensions, public_share, leader_private_extensions, @@ -2341,7 +2359,7 @@ WHERE report_aggregations.task_id = $1 } } - ReportAggregationStateCode::Waiting => { + ReportAggregationStateCode::Continue => { match role { Role::Leader => { let leader_prep_transition_bytes = row @@ -2357,7 +2375,7 @@ WHERE report_aggregations.task_id = $1 &leader_prep_transition_bytes, )?; - ReportAggregationState::WaitingLeader { + ReportAggregationState::LeaderContinue { transition: ping_pong_transition, } } @@ -2375,12 +2393,42 @@ WHERE report_aggregations.task_id = $1 &helper_prep_state_bytes, )?; - ReportAggregationState::WaitingHelper { prepare_state } + ReportAggregationState::HelperContinue { prepare_state } } _ => panic!("unexpected role"), } } + ReportAggregationStateCode::Poll => { + let leader_prep_state_bytes = row.get::<_, Option>>("leader_prep_state"); + let leader_output_share_bytes = + row.get::<_, Option>>("leader_output_share"); + + let leader_state = match (leader_prep_state_bytes, leader_output_share_bytes) { + (Some(leader_prep_state_bytes), None) => { + PingPongState::Continued(A::PrepareState::get_decoded_with_param( + &(vdaf, 0 /* leader */), + &leader_prep_state_bytes, + )?) + } + + (None, Some(leader_output_share_bytes)) => { + PingPongState::Finished(A::OutputShare::get_decoded_with_param( + &(vdaf, aggregation_param), + &leader_output_share_bytes, + )?) + } + + _ => return Err(Error::DbState( + "report aggregation in state POLL but both/neither of leader_prep_state \ + and leader_output_share are NULL" + .to_string(), + )), + }; + + ReportAggregationState::LeaderPoll { leader_state } + } + ReportAggregationStateCode::Finished => ReportAggregationState::Finished, ReportAggregationStateCode::Failed => { @@ -2447,11 +2495,11 @@ INSERT INTO report_aggregations (task_id, aggregation_job_id, ord, client_report_id, client_timestamp, last_prep_resp, state, public_extensions, public_share, leader_private_extensions, leader_input_share, - helper_encrypted_input_share, leader_prep_transition, helper_prep_state, - error_code, created_at, updated_at, updated_by) + helper_encrypted_input_share, leader_prep_transition, leader_prep_state, + leader_output_share, helper_prep_state, error_code, created_at, updated_at, updated_by) SELECT $1, aggregation_jobs.id, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, - $15, $16, $17, $18 + $15, $16, $17, $18, $19, $20 FROM aggregation_jobs WHERE task_id = $1 AND aggregation_job_id = $2 @@ -2460,20 +2508,21 @@ ON CONFLICT(task_id, aggregation_job_id, ord) DO UPDATE client_report_id, client_timestamp, last_prep_resp, state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, - leader_prep_transition, helper_prep_state, error_code, created_at, - updated_at, updated_by + leader_prep_transition, leader_prep_state, leader_output_share, + helper_prep_state, error_code, created_at, updated_at, updated_by ) = ( excluded.client_report_id, excluded.client_timestamp, excluded.last_prep_resp, excluded.state, excluded.public_extensions, excluded.public_share, excluded.leader_private_extensions, excluded.leader_input_share, excluded.helper_encrypted_input_share, - excluded.leader_prep_transition, excluded.helper_prep_state, + excluded.leader_prep_transition, excluded.leader_prep_state, + excluded.leader_output_share, excluded.helper_prep_state, excluded.error_code, excluded.created_at, excluded.updated_at, excluded.updated_by ) WHERE (SELECT UPPER(client_timestamp_interval) FROM aggregation_jobs - WHERE id = report_aggregations.aggregation_job_id) >= $19", + WHERE id = report_aggregations.aggregation_job_id) >= $21", ) .await?; check_insert( @@ -2497,6 +2546,8 @@ ON CONFLICT(task_id, aggregation_job_id, ord) DO UPDATE &encoded_state_values.helper_encrypted_input_share, /* leader_prep_transition */ &encoded_state_values.leader_prep_transition, + /* leader_prep_state */ &encoded_state_values.leader_prep_state, + /* leader_output_share */ &encoded_state_values.leader_output_share, /* helper_prep_state */ &encoded_state_values.helper_prep_state, /* error_code */ &encoded_state_values.report_error, /* created_at */ &now, @@ -2509,7 +2560,7 @@ ON CONFLICT(task_id, aggregation_job_id, ord) DO UPDATE ) } - /// Creates a report aggregation in the `StartLeader` state from its metadata. + /// Creates a report aggregation in the `LeaderInit` state from its metadata. /// /// Report shares are copied directly from the `client_reports` table. #[tracing::instrument(skip(self), err(level = Level::DEBUG))] @@ -2527,7 +2578,7 @@ ON CONFLICT(task_id, aggregation_job_id, ord) DO UPDATE let now = self.clock.now().as_naive_date_time()?; match report_aggregation_metadata.state() { - ReportAggregationMetadataState::Start => { + ReportAggregationMetadataState::Init => { let stmt = self .prepare_cached( "-- put_leader_report_aggregation() @@ -2537,7 +2588,7 @@ INSERT INTO report_aggregations leader_input_share, helper_encrypted_input_share, created_at, updated_at, updated_by) SELECT - $1, aggregation_jobs.id, $3, $4, $5, 'START'::REPORT_AGGREGATION_STATE, + $1, aggregation_jobs.id, $3, $4, $5, 'INIT'::REPORT_AGGREGATION_STATE, client_reports.public_extensions, client_reports.public_share, client_reports.leader_private_extensions, client_reports.leader_input_share, @@ -2685,16 +2736,18 @@ SET last_prep_resp = $1, state = $2, public_extensions = $3, public_share = $4, leader_private_extensions = $5, leader_input_share = $6, helper_encrypted_input_share = $7, leader_prep_transition = $8, - helper_prep_state = $9, error_code = $10, updated_at = $11, updated_by = $12 + leader_prep_state = $9, leader_output_share = $10, + helper_prep_state = $11, error_code = $12, updated_at = $13, + updated_by = $14 FROM aggregation_jobs WHERE report_aggregations.aggregation_job_id = aggregation_jobs.id - AND aggregation_jobs.aggregation_job_id = $13 - AND aggregation_jobs.task_id = $14 - AND report_aggregations.task_id = $14 - AND report_aggregations.client_report_id = $15 - AND report_aggregations.client_timestamp = $16 - AND report_aggregations.ord = $17 - AND UPPER(aggregation_jobs.client_timestamp_interval) >= $18", + AND aggregation_jobs.aggregation_job_id = $15 + AND aggregation_jobs.task_id = $16 + AND report_aggregations.task_id = $16 + AND report_aggregations.client_report_id = $17 + AND report_aggregations.client_timestamp = $18 + AND report_aggregations.ord = $19 + AND UPPER(aggregation_jobs.client_timestamp_interval) >= $20", ) .await?; check_single_row_mutation( @@ -2712,6 +2765,8 @@ WHERE report_aggregations.aggregation_job_id = aggregation_jobs.id &encoded_state_values.helper_encrypted_input_share, /* leader_prep_transition */ &encoded_state_values.leader_prep_transition, + /* leader_prep_state */ &encoded_state_values.leader_prep_state, + /* leader_output_share */ &encoded_state_values.leader_output_share, /* helper_prep_state */ &encoded_state_values.helper_prep_state, /* error_code */ &encoded_state_values.report_error, /* updated_at */ &now, @@ -4487,7 +4542,7 @@ WHERE task_id = $1 // * min_size is the minimum possible number of reports included in the batch, i.e. all report // aggregations in the batch which have reached the FINISHED state. // * max_size is the maximum possible number of reports included in the batch, i.e. all report - // aggregations in the batch which are in a non-failure state (START/WAITING/FINISHED). + // aggregations in the batch which are in a non-failure state (INIT/CONTINUE/FINISHED). async fn read_batch_size( &self, task_pkey: i64, @@ -4503,7 +4558,7 @@ WITH report_aggregations_count AS ( WHERE aggregation_jobs.task_id = $1 AND report_aggregations.task_id = aggregation_jobs.task_id AND aggregation_jobs.batch_id = $2 - AND report_aggregations.state in ('START', 'WAITING') + AND report_aggregations.state in ('INIT', 'CONTINUE') ), batch_aggregation_count AS ( SELECT SUM(report_count) AS count FROM batch_aggregations diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs index 46e262578..8175d1273 100644 --- a/aggregator_core/src/datastore/models.rs +++ b/aggregator_core/src/datastore/models.rs @@ -24,7 +24,7 @@ use postgres_protocol::types::{ use postgres_types::{accepts, to_sql_checked, FromSql, ToSql}; use prio::{ codec::{encode_u16_items, Encode}, - topology::ping_pong::PingPongTransition, + topology::ping_pong::{PingPongState, PingPongTransition}, vdaf::{self, Aggregatable}, }; use rand::{distributions::Standard, prelude::Distribution}; @@ -162,7 +162,7 @@ where } #[cfg(feature = "test-util")] - pub fn as_start_leader_report_aggregation( + pub fn as_leader_init_report_aggregation( &self, aggregation_job_id: AggregationJobId, ord: u64, @@ -174,7 +174,7 @@ where *self.metadata().time(), ord, None, - ReportAggregationState::StartLeader { + ReportAggregationState::LeaderInit { public_extensions: self.metadata().public_extensions().to_vec(), public_share: self.public_share().clone(), leader_private_extensions: self.leader_private_extensions().to_vec(), @@ -919,7 +919,11 @@ where #[derive(Clone, Educe)] #[educe(Debug)] pub enum ReportAggregationState> { - StartLeader { + // + // Leader-only states. + // + /// The Leader is ready to send an aggregation initialization request to the Helper. + LeaderInit { /// The sequence of public extensions from this report's metadata. #[educe(Debug(ignore))] public_extensions: Vec, @@ -936,20 +940,36 @@ pub enum ReportAggregationState, }, - WaitingHelper { + /// The Leader received a "processing" response from a previous aggregation initialization or + /// continuation request, and is ready to poll for completion. + LeaderPoll { + /// Leader's current aggregation state. + leader_state: PingPongState, + }, + + // + // Helper-only states. + // + /// The Helper is ready to receive an aggregation continuation request from the Leader. + HelperContinue { /// Helper's current preparation state #[educe(Debug(ignore))] prepare_state: A::PrepareState, }, + + // + // Common states. + // + /// Aggregation has completed successfully. Finished, - Failed { - report_error: ReportError, - }, + /// Aggregation has completed unsuccessfully. + Failed { report_error: ReportError }, } impl> @@ -957,9 +977,10 @@ impl> { pub(super) fn state_code(&self) -> ReportAggregationStateCode { match self { - ReportAggregationState::StartLeader { .. } => ReportAggregationStateCode::Start, - ReportAggregationState::WaitingLeader { .. } - | ReportAggregationState::WaitingHelper { .. } => ReportAggregationStateCode::Waiting, + ReportAggregationState::LeaderInit { .. } => ReportAggregationStateCode::Init, + ReportAggregationState::LeaderContinue { .. } + | ReportAggregationState::HelperContinue { .. } => ReportAggregationStateCode::Continue, + ReportAggregationState::LeaderPoll { .. } => ReportAggregationStateCode::Poll, ReportAggregationState::Finished => ReportAggregationStateCode::Finished, ReportAggregationState::Failed { .. } => ReportAggregationStateCode::Failed, } @@ -975,7 +996,7 @@ impl> A::PrepareState: Encode, { Ok(match self { - ReportAggregationState::StartLeader { + ReportAggregationState::LeaderInit { public_extensions, public_share, leader_private_extensions, @@ -1001,18 +1022,35 @@ impl> ..Default::default() } } - ReportAggregationState::WaitingLeader { transition } => { + ReportAggregationState::LeaderContinue { transition } => { EncodedReportAggregationStateValues { leader_prep_transition: Some(transition.get_encoded()?), ..Default::default() } } - ReportAggregationState::WaitingHelper { prepare_state } => { + ReportAggregationState::LeaderPoll { leader_state } => { + let (encoded_leader_prep_state, encoded_leader_output_share) = match leader_state { + PingPongState::Continued(prepare_state) => { + (Some(prepare_state.get_encoded()?), None) + } + PingPongState::Finished(output_share) => { + (None, Some(output_share.get_encoded()?)) + } + }; + EncodedReportAggregationStateValues { + leader_prep_state: encoded_leader_prep_state, + leader_output_share: encoded_leader_output_share, + ..Default::default() + } + } + + ReportAggregationState::HelperContinue { prepare_state } => { EncodedReportAggregationStateValues { helper_prep_state: Some(prepare_state.get_encoded()?), ..Default::default() } } + ReportAggregationState::Finished => EncodedReportAggregationStateValues::default(), ReportAggregationState::Failed { report_error } => { EncodedReportAggregationStateValues { @@ -1026,17 +1064,21 @@ impl> #[derive(Default)] pub(super) struct EncodedReportAggregationStateValues { - // State for StartLeader. + // State for LeaderInit. pub(super) public_extensions: Option>, pub(super) public_share: Option>, pub(super) leader_private_extensions: Option>, pub(super) leader_input_share: Option>, pub(super) helper_encrypted_input_share: Option>, - // State for WaitingLeader. + // State for LeaderContinue. pub(super) leader_prep_transition: Option>, - // State for WaitingHelper. + // State for LeaderPoll. + pub(super) leader_prep_state: Option>, + pub(super) leader_output_share: Option>, + + // State for HelperContinue. pub(super) helper_prep_state: Option>, // State for Failed. @@ -1049,10 +1091,12 @@ pub(super) struct EncodedReportAggregationStateValues { #[derive(Debug, Clone, Copy, PartialEq, Eq, FromSql, ToSql)] #[postgres(name = "report_aggregation_state")] pub(super) enum ReportAggregationStateCode { - #[postgres(name = "START")] - Start, - #[postgres(name = "WAITING")] - Waiting, + #[postgres(name = "INIT")] + Init, + #[postgres(name = "CONTINUE")] + Continue, + #[postgres(name = "POLL")] + Poll, #[postgres(name = "FINISHED")] Finished, #[postgres(name = "FAILED")] @@ -1074,14 +1118,14 @@ where fn eq(&self, other: &Self) -> bool { match (self, other) { ( - Self::StartLeader { + Self::LeaderInit { public_extensions: lhs_public_extensions, public_share: lhs_public_share, leader_private_extensions: lhs_leader_private_extensions, leader_input_share: lhs_leader_input_share, helper_encrypted_input_share: lhs_helper_encrypted_input_share, }, - Self::StartLeader { + Self::LeaderInit { public_extensions: rhs_public_extensions, public_share: rhs_public_share, leader_private_extensions: rhs_leader_private_extensions, @@ -1096,18 +1140,18 @@ where && lhs_helper_encrypted_input_share == rhs_helper_encrypted_input_share } ( - Self::WaitingLeader { + Self::LeaderContinue { transition: lhs_transition, }, - Self::WaitingLeader { + Self::LeaderContinue { transition: rhs_transition, }, ) => lhs_transition == rhs_transition, ( - Self::WaitingHelper { + Self::HelperContinue { prepare_state: lhs_state, }, - Self::WaitingHelper { + Self::HelperContinue { prepare_state: rhs_state, }, ) => lhs_state == rhs_state, @@ -1143,14 +1187,14 @@ where /// See also [`ReportAggregationState`]. #[derive(Clone, Debug)] pub enum ReportAggregationMetadataState { - Start, + Init, Failed { report_error: ReportError }, } /// Metadata from the state of a single client report's ongoing aggregation. This is like /// [`ReportAggregation`], but omits the report aggregation state and report shares. /// -/// This is only used with report aggregations in the `StartLeader` or `Failed` states. +/// This is only used with report aggregations in the `LeaderInit` or `Failed` states. #[derive(Clone, Debug)] pub struct ReportAggregationMetadata { task_id: TaskId, @@ -2440,4 +2484,10 @@ impl TaskAggregationCounter { pub fn increment_success(&mut self) { self.success += 1 } + + /// Returns true if and only if this task aggregation counter is "zero", i.e. it would not + /// change the state of the written task aggregation counters. + pub fn is_zero(&self) -> bool { + self == &TaskAggregationCounter::default() + } } diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index b4333a5bb..af8c20635 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -1054,7 +1054,7 @@ async fn get_unaggregated_client_report_ids_with_agg_param_for_task( )) .await?; tx.put_report_aggregation( - &aggregated_report.as_start_leader_report_aggregation(aggregation_job_id, 0), + &aggregated_report.as_leader_init_report_aggregation(aggregation_job_id, 0), ) .await }) @@ -1382,7 +1382,7 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto AggregationJobStep::from(0), ); let expired_report_aggregation = expired_report - .as_start_leader_report_aggregation(*expired_aggregation_job.id(), 0); + .as_leader_init_report_aggregation(*expired_aggregation_job.id(), 0); let aggregation_job_0 = AggregationJob::<0, LeaderSelected, dummy::Vdaf>::new( *task.id(), @@ -1395,9 +1395,9 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto AggregationJobStep::from(0), ); let aggregation_job_0_report_aggregation_0 = - report_0.as_start_leader_report_aggregation(*aggregation_job_0.id(), 1); + report_0.as_leader_init_report_aggregation(*aggregation_job_0.id(), 1); let aggregation_job_0_report_aggregation_1 = - report_1.as_start_leader_report_aggregation(*aggregation_job_0.id(), 2); + report_1.as_leader_init_report_aggregation(*aggregation_job_0.id(), 2); let aggregation_job_1 = AggregationJob::<0, LeaderSelected, dummy::Vdaf>::new( *task.id(), @@ -1410,9 +1410,9 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto AggregationJobStep::from(0), ); let aggregation_job_1_report_aggregation_0 = - report_0.as_start_leader_report_aggregation(*aggregation_job_1.id(), 0); + report_0.as_leader_init_report_aggregation(*aggregation_job_1.id(), 0); let aggregation_job_1_report_aggregation_1 = - report_1.as_start_leader_report_aggregation(*aggregation_job_1.id(), 1); + report_1.as_leader_init_report_aggregation(*aggregation_job_1.id(), 1); tx.put_client_report(&expired_report).await.unwrap(); tx.put_client_report(&report_0).await.unwrap(); @@ -2021,15 +2021,18 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore assert_eq!(want_aggregation_jobs, got_aggregation_jobs); - // Run: release a few jobs, then attempt to acquire jobs again. + // Run: release a few jobs with a delay before reacquiry, then attempt to acquire jobs again. const RELEASE_COUNT: usize = 2; + const REACQUIRE_DELAY: StdDuration = StdDuration::from_secs(10); - // Sanity check constants: ensure we release fewer jobs than we're about to acquire to - // ensure we can acquire them in all in a single call, while leaving headroom to acquire - // at least one unwanted job if there is a logic bug. + // Sanity check constants: ensure we release fewer jobs than we're about to acquire to ensure we + // can acquire them in all in a single call, while leaving headroom to acquire at least one + // unwanted job if there is a logic bug. And ensure that our reacquire delay is shorter than the + // lease duration, to ensure we don't timeout the leases which are not explicitly released. #[allow(clippy::assertions_on_constants)] { assert!(RELEASE_COUNT < MAXIMUM_ACQUIRE_COUNT); + assert!(REACQUIRE_DELAY < LEASE_DURATION); } let leases_to_release: Vec<_> = got_leases.into_iter().take(RELEASE_COUNT).collect(); @@ -2042,7 +2045,9 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore let leases_to_release = leases_to_release.clone(); Box::pin(async move { for lease in leases_to_release { - tx.release_aggregation_job(&lease).await.unwrap(); + tx.release_aggregation_job(&lease, Some(&REACQUIRE_DELAY)) + .await + .unwrap(); } Ok(()) }) @@ -2050,11 +2055,32 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore .await .unwrap(); + // Verify that we can't immediately acquire the jobs again. + ds.run_unnamed_tx(|tx| { + Box::pin(async move { + assert!(tx + .acquire_incomplete_aggregation_jobs(&LEASE_DURATION, MAXIMUM_ACQUIRE_COUNT) + .await + .unwrap() + .is_empty()); + Ok(()) + }) + }) + .await + .unwrap(); + + // Advance the clock past the reacquire delay, then reacquire the leases we released with a + // reacquire delay. + clock.advance(&Duration::from_seconds(REACQUIRE_DELAY.as_secs())); + let mut got_aggregation_jobs: Vec<_> = ds .run_unnamed_tx(|tx| { Box::pin(async move { - tx.acquire_incomplete_aggregation_jobs(&LEASE_DURATION, MAXIMUM_ACQUIRE_COUNT) - .await + tx.acquire_incomplete_aggregation_jobs( + &(LEASE_DURATION - REACQUIRE_DELAY), + MAXIMUM_ACQUIRE_COUNT, + ) + .await }) }) .await @@ -2072,7 +2098,9 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore // Run: advance time by the lease duration (which implicitly releases the jobs), and attempt // to acquire aggregation jobs again. - clock.advance(&Duration::from_seconds(LEASE_DURATION.as_secs())); + clock.advance(&Duration::from_seconds( + LEASE_DURATION.as_secs() - REACQUIRE_DELAY.as_secs(), + )); let want_expiry_time = clock.now().as_naive_date_time().unwrap() + chrono::Duration::from_std(LEASE_DURATION).unwrap(); let want_aggregation_jobs: Vec<_> = aggregation_job_ids @@ -2137,7 +2165,10 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore ); ds.run_unnamed_tx(|tx| { let lease_with_random_token = lease_with_random_token.clone(); - Box::pin(async move { tx.release_aggregation_job(&lease_with_random_token).await }) + Box::pin(async move { + tx.release_aggregation_job(&lease_with_random_token, None) + .await + }) }) .await .unwrap_err(); @@ -2146,7 +2177,7 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore // place. ds.run_unnamed_tx(|tx| { let lease = lease.clone(); - Box::pin(async move { tx.release_aggregation_job(&lease).await }) + Box::pin(async move { tx.release_aggregation_job(&lease, None).await }) }) .await .unwrap(); @@ -2320,7 +2351,7 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { for (ord, (role, state)) in [ ( Role::Leader, - ReportAggregationState::StartLeader { + ReportAggregationState::LeaderInit { public_extensions: Vec::from([Extension::new( ExtensionType::Tbd, "public_extension_tbd".into(), @@ -2340,16 +2371,28 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { ), ( Role::Leader, - ReportAggregationState::WaitingLeader { + ReportAggregationState::LeaderContinue { transition: vdaf_transcript.leader_prepare_transitions[1] .transition .clone() .unwrap(), }, ), + ( + Role::Leader, + ReportAggregationState::LeaderPoll { + leader_state: vdaf_transcript.leader_prepare_transitions[0].state.clone(), + }, + ), + ( + Role::Leader, + ReportAggregationState::LeaderPoll { + leader_state: vdaf_transcript.leader_prepare_transitions[1].state.clone(), + }, + ), ( Role::Helper, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *vdaf_transcript.helper_prepare_transitions[0].prepare_state(), }, ), @@ -2479,6 +2522,7 @@ WHERE client_report_id = $1", task.id(), &aggregation_job_id, &report_id, + &aggregation_param, ) .await }) @@ -2547,6 +2591,7 @@ SELECT updated_at, updated_by FROM report_aggregations task.id(), &aggregation_job_id, &report_id, + &aggregation_param, ) .await }) @@ -2570,6 +2615,7 @@ SELECT updated_at, updated_by FROM report_aggregations task.id(), &aggregation_job_id, &report_id, + &aggregation_param, ) .await }) @@ -2587,6 +2633,7 @@ async fn report_aggregation_not_found(ephemeral_datastore: EphemeralDatastore) { let ds = ephemeral_datastore.datastore(MockClock::default()).await; let vdaf = Arc::new(dummy::Vdaf::default()); + let aggregation_param = dummy::AggregationParam(5); let rslt = ds .run_unnamed_tx(|tx| { @@ -2599,6 +2646,7 @@ async fn report_aggregation_not_found(ephemeral_datastore: EphemeralDatastore) { &random(), &random(), &ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), + &aggregation_param, ) .await }) @@ -2682,7 +2730,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let mut want_report_aggregations = Vec::new(); for (ord, state) in [ - ReportAggregationState::StartLeader { + ReportAggregationState::LeaderInit { public_extensions: Vec::new(), public_share: vdaf_transcript.public_share, leader_private_extensions: Vec::new(), @@ -2693,7 +2741,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme Vec::from("payload"), ), }, - ReportAggregationState::WaitingHelper { + ReportAggregationState::HelperContinue { prepare_state: *vdaf_transcript.helper_prepare_transitions[0] .prepare_state(), }, @@ -2757,6 +2805,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme &Role::Helper, task.id(), &aggregation_job_id, + &aggregation_param, ) .await }) @@ -2777,6 +2826,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme &Role::Helper, task.id(), &aggregation_job_id, + &aggregation_param, ) .await }) @@ -2872,7 +2922,7 @@ async fn create_report_aggregation_from_client_reports_table( report_id, timestamp, 0, - ReportAggregationMetadataState::Start, + ReportAggregationMetadataState::Init, ); tx.put_leader_report_aggregation(&report_aggregation_metadata) .await @@ -2885,7 +2935,7 @@ async fn create_report_aggregation_from_client_reports_table( timestamp, 0, None, - ReportAggregationState::<0, dummy::Vdaf>::StartLeader { + ReportAggregationState::<0, dummy::Vdaf>::LeaderInit { public_extensions: leader_stored_report .metadata() .public_extensions() @@ -2915,6 +2965,7 @@ async fn create_report_aggregation_from_client_reports_table( &Role::Leader, task.id(), &aggregation_job_id, + &aggregation_param, ) .await }) @@ -5242,7 +5293,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { AggregationJobStep::from(1), ); let report_aggregation_0_0 = - report_1.as_start_leader_report_aggregation(*aggregation_job_0.id(), 0); + report_1.as_leader_init_report_aggregation(*aggregation_job_0.id(), 0); let report_id_0_1 = random(); let transcript = run_vdaf( @@ -5262,7 +5313,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { 1, None, // Counted among max_size. - ReportAggregationState::WaitingLeader { + ReportAggregationState::LeaderContinue { transition: transcript.helper_prepare_transitions[0].transition.clone(), }, ); @@ -5329,7 +5380,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { AggregationJobStep::from(1), ); let report_aggregation_2_0 = - report_2.as_start_leader_report_aggregation(*aggregation_job_2.id(), 0); + report_2.as_leader_init_report_aggregation(*aggregation_job_2.id(), 0); for aggregation_job in &[aggregation_job_0, aggregation_job_1, aggregation_job_2] { tx.put_aggregation_job(aggregation_job).await.unwrap(); @@ -5682,11 +5733,13 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP); let ds = ephemeral_datastore.datastore(clock.clone()).await; let vdaf = dummy::Vdaf::default(); + let aggregation_param = dummy::AggregationParam(0); // Setup. async fn write_aggregation_artifacts( tx: &Transaction<'_, MockClock>, task_id: &TaskId, + aggregation_param: &dummy::AggregationParam, client_timestamps: &[Time], ) -> ( B::BatchIdentifier, @@ -5717,7 +5770,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData let aggregation_job = AggregationJob::<0, B, dummy::Vdaf>::new( *task_id, random(), - dummy::AggregationParam(0), + *aggregation_param, B::partial_batch_identifier(&batch_identifier).clone(), client_timestamp_interval, AggregationJobState::InProgress, @@ -5727,7 +5780,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData for (ord, report) in reports.iter().enumerate() { let report_aggregation = report - .as_start_leader_report_aggregation(*aggregation_job.id(), ord.try_into().unwrap()); + .as_leader_init_report_aggregation(*aggregation_job.id(), ord.try_into().unwrap()); tx.put_report_aggregation(&report_aggregation) .await .unwrap(); @@ -5809,6 +5862,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, leader_time_interval_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .sub(&Duration::from_seconds(20)) @@ -5825,6 +5879,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, leader_time_interval_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .sub(&Duration::from_seconds(5)) @@ -5843,6 +5898,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, leader_time_interval_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .add(&Duration::from_seconds(19)) @@ -5860,6 +5916,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, helper_time_interval_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .sub(&Duration::from_seconds(20)) @@ -5876,6 +5933,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, helper_time_interval_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .sub(&Duration::from_seconds(5)) @@ -5894,6 +5952,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, helper_time_interval_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .add(&Duration::from_seconds(19)) @@ -5911,6 +5970,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, leader_leader_selected_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .sub(&Duration::from_seconds(20)) @@ -5927,6 +5987,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, leader_leader_selected_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .sub(&Duration::from_seconds(5)) @@ -5945,6 +6006,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, leader_leader_selected_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .add(&Duration::from_seconds(19)) @@ -5962,6 +6024,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, helper_leader_selected_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .sub(&Duration::from_seconds(20)) @@ -5978,6 +6041,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, helper_leader_selected_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .sub(&Duration::from_seconds(5)) @@ -5996,6 +6060,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData write_aggregation_artifacts::( tx, helper_leader_selected_task.id(), + &aggregation_param, &[ OLDEST_ALLOWED_REPORT_TIMESTAMP .add(&Duration::from_seconds(19)) @@ -6101,6 +6166,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData &vdaf, &Role::Leader, &leader_time_interval_task_id, + &aggregation_param, ) .await .unwrap(); @@ -6109,6 +6175,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData &vdaf, &Role::Helper, &helper_time_interval_task_id, + &aggregation_param, ) .await .unwrap(); @@ -6117,6 +6184,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData &vdaf, &Role::Leader, &leader_leader_selected_task_id, + &aggregation_param, ) .await .unwrap(); @@ -6125,6 +6193,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData &vdaf, &Role::Helper, &helper_leader_selected_task_id, + &aggregation_param, ) .await .unwrap(); diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 23690c5ee..77a59d0d4 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -8,7 +8,9 @@ use janus_core::{ time::TimeExt, vdaf::VdafInstance, }; -use janus_messages::{batch_mode, AggregationJobId, Duration, HpkeConfig, Role, TaskId, Time}; +use janus_messages::{ + batch_mode, AggregationJobId, AggregationJobStep, Duration, HpkeConfig, Role, TaskId, Time, +}; use rand::{distributions::Standard, random, thread_rng, Rng}; use serde::{de::Error as _, Deserialize, Deserializer, Serialize, Serializer}; use std::array::TryFromSliceError; @@ -378,15 +380,23 @@ impl AggregatorTask { pub fn aggregation_job_uri( &self, aggregation_job_id: &AggregationJobId, + step: Option, ) -> Result, Error> { if matches!( self.aggregator_parameters, AggregatorTaskParameters::Leader { .. } ) { - Ok(Some(self.peer_aggregator_endpoint().join(&format!( + let mut uri = self.peer_aggregator_endpoint().join(&format!( "{}/aggregation_jobs/{aggregation_job_id}", self.tasks_path() - ))?)) + ))?; + + if let Some(step) = step { + uri.query_pairs_mut() + .append_pair("step", &u16::from(step).to_string()); + } + + Ok(Some(uri)) } else { Ok(None) } diff --git a/db/00000000000001_initial_schema.up.sql b/db/00000000000001_initial_schema.up.sql index 8a9649648..e4ca5e9d9 100644 --- a/db/00000000000001_initial_schema.up.sql +++ b/db/00000000000001_initial_schema.up.sql @@ -248,10 +248,11 @@ CREATE INDEX aggregation_jobs_task_and_client_timestamp_interval ON aggregation_ -- Specifies the possible state of aggregating a single report. CREATE TYPE REPORT_AGGREGATION_STATE AS ENUM( - 'START', -- the aggregator is waiting to decrypt its input share & compute initial preparation state - 'WAITING', -- the aggregator is waiting for a message from its peer before proceeding - 'FINISHED', -- the aggregator has completed the preparation process and recovered an output share - 'FAILED' -- an error has occurred and an output share cannot be recovered + 'INIT', -- the aggregator is ready for the aggregation initialization step + 'CONTINUE', -- the aggregator is ready for an aggregation continuation step + 'POLL', -- the aggregator is polling for completion of a previous operation + 'FINISHED', -- the aggregator has completed the preparation process successfully + 'FAILED' -- the aggregator has completed the preparation process unsuccessfully ); -- An aggregation attempt for a single client report. An aggregation job logically contains a number @@ -267,17 +268,21 @@ CREATE TABLE report_aggregations( last_prep_resp BYTEA, -- the last PrepareResp message sent to the Leader, to assist in replay (opaque DAP message, populated for Helper only) state REPORT_AGGREGATION_STATE NOT NULL, -- the current state of this report aggregation - -- Additional data for state StartLeader. + -- Additional data for state LeaderInit. public_extensions BYTEA, -- encoded sequence of public Extension messages (opaque DAP messages) public_share BYTEA, -- the public share for the report (opaque VDAF message) leader_private_extensions BYTEA, -- encoded sequence of leader's private Extension messages (opaque DAP messages) leader_input_share BYTEA, -- encoded leader input share (opaque VDAF message) helper_encrypted_input_share BYTEA, -- encoded HPKE ciphertext of helper input share (opaque DAP message) - -- Additional data for state WaitingLeader. + -- Additional data for state LeaderContinue. leader_prep_transition BYTEA, -- the current VDAF prepare transition (opaque VDAF message) - -- Additional data for state WaitingHelper. + -- Additional data for state LeaderPoll. + leader_prep_state BYTEA, -- the current prepare state (opaque VDAF message) + leader_output_share BYTEA, -- the leader's recovered output share (opaque VDAF message) + + -- Additional data for state HelperContinue. helper_prep_state BYTEA, -- the current VDAF prepare state (opaque VDAF message) -- Additional data for state Failed.