diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs index 06f983822..fb108b6a8 100644 --- a/crates/dapf/src/acceptance/mod.rs +++ b/crates/dapf/src/acceptance/mod.rs @@ -533,7 +533,7 @@ impl Test { let agg_share_span = task_config.consume_agg_job_resp( task_id, agg_job_state, - agg_job_resp, + agg_job_resp.unwrap_ready(), // TODO: implement polling self.metrics(), )?; diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs index 8b86e2cb6..28ac198b6 100644 --- a/crates/daphne/src/messages/mod.rs +++ b/crates/daphne/src/messages/mod.rs @@ -951,19 +951,61 @@ impl std::fmt::Display for ReportError { } } +const AGG_JOB_RESP_PROCESSING: u8 = 0; +const AGG_JOB_RESP_READY: u8 = 1; + /// An aggregate response sent from the Helper to the Leader. -#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)] -pub struct AggregationJobResp { +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] +pub enum AggregationJobResp { + Ready { transitions: Vec }, + Processing, +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] +pub struct ReadyAggregationJobResp { pub transitions: Vec, } +impl From for AggregationJobResp { + fn from(value: ReadyAggregationJobResp) -> Self { + Self::Ready { + transitions: value.transitions, + } + } +} + +impl AggregationJobResp { + #[cfg(any(test, feature = "test-utils"))] + #[track_caller] + pub fn unwrap_ready(self) -> ReadyAggregationJobResp { + match self { + Self::Ready { transitions } => ReadyAggregationJobResp { transitions }, + Self::Processing => panic!("unwraped a Processing value"), + } + } +} + impl ParameterizedEncode for AggregationJobResp { fn encode_with_param( &self, version: &DapVersion, bytes: &mut Vec, ) -> Result<(), CodecError> { - encode_u32_items(bytes, version, &self.transitions) + match (self, version) { + (Self::Ready { transitions }, DapVersion::Draft09) => { + encode_u32_items(bytes, version, transitions) + } + (Self::Ready { transitions }, DapVersion::Latest) => { + AGG_JOB_RESP_READY.encode(bytes)?; + encode_u32_items(bytes, version, transitions) + } + (Self::Processing, DapVersion::Draft09) => Err(CodecError::Other( + "draft09: can't represent an agg job resp that's being processed".into(), + )), + (Self::Processing, DapVersion::Latest) => AGG_JOB_RESP_PROCESSING.encode(bytes), + } } } @@ -972,9 +1014,18 @@ impl ParameterizedDecode for AggregationJobResp { version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { - Ok(Self { - transitions: decode_u32_items(version, bytes)?, - }) + match version { + DapVersion::Draft09 => Ok(Self::Ready { + transitions: decode_u32_items(version, bytes)?, + }), + DapVersion::Latest => match u8::decode(bytes)? { + AGG_JOB_RESP_PROCESSING => Ok(Self::Processing), + AGG_JOB_RESP_READY => Ok(Self::Ready { + transitions: decode_u32_items(version, bytes)?, + }), + _ => Err(CodecError::UnexpectedValue), + }, + } } } @@ -1898,7 +1949,7 @@ mod test { 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 2, 7, ]; - let want = AggregationJobResp { + let want = AggregationJobResp::Ready { transitions: vec![ Transition { report_id: ReportId([22; 16]), @@ -1916,14 +1967,18 @@ mod test { }, ], }; - println!( - "want {:?}", - want.get_encoded_with_param(&DapVersion::Latest).unwrap() - ); - let got = - AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, TEST_DATA).unwrap(); + AggregationJobResp::get_decoded_with_param(&DapVersion::Draft09, TEST_DATA).unwrap(); assert_eq!(got, want); + let draft_latest_data = [&[1], TEST_DATA].concat(); + let got = + AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, &draft_latest_data) + .unwrap(); + assert_eq!(got, want); + assert_eq!( + AggregationJobResp::Processing, + AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, &[0]).unwrap(), + ); } #[test] diff --git a/crates/daphne/src/protocol/aggregator.rs b/crates/daphne/src/protocol/aggregator.rs index b169f3664..1883e4412 100644 --- a/crates/daphne/src/protocol/aggregator.rs +++ b/crates/daphne/src/protocol/aggregator.rs @@ -11,9 +11,9 @@ use crate::{ fatal_error, hpke::{info_and_aad, HpkeConfig, HpkeDecrypter}, messages::{ - self, encode_u32_bytes, AggregationJobInitReq, AggregationJobResp, Base64Encode, - BatchSelector, HpkeCiphertext, PartialBatchSelector, PrepareInit, Report, ReportError, - ReportId, ReportShare, TaskId, Transition, TransitionVar, + self, encode_u32_bytes, AggregationJobInitReq, Base64Encode, BatchSelector, HpkeCiphertext, + PartialBatchSelector, PrepareInit, ReadyAggregationJobResp, Report, ReportError, ReportId, + ReportShare, TaskId, Transition, TransitionVar, }, metrics::{DaphneMetrics, ReportStatus}, protocol::{decode_ping_pong_framed, PingPongMessageType}, @@ -279,7 +279,7 @@ impl DapTaskConfig { report_status: &HashMap, part_batch_sel: &PartialBatchSelector, initialized_reports: &[InitializedReport], - ) -> Result<(DapAggregateSpan, AggregationJobResp), DapError> { + ) -> Result<(DapAggregateSpan, ReadyAggregationJobResp), DapError> { let num_reports = initialized_reports.len(); let mut agg_span = DapAggregateSpan::default(); let mut transitions = Vec::with_capacity(num_reports); @@ -355,7 +355,7 @@ impl DapTaskConfig { }); } - Ok((agg_span, AggregationJobResp { transitions })) + Ok((agg_span, ReadyAggregationJobResp { transitions })) } /// Leader: Consume the `AggregationJobResp` message sent by the Helper and compute the @@ -364,7 +364,7 @@ impl DapTaskConfig { &self, task_id: &TaskId, state: DapAggregationJobState, - agg_job_resp: AggregationJobResp, + agg_job_resp: ReadyAggregationJobResp, metrics: &dyn DaphneMetrics, ) -> Result, DapError> { if agg_job_resp.transitions.len() != state.seq.len() { diff --git a/crates/daphne/src/protocol/mod.rs b/crates/daphne/src/protocol/mod.rs index 2f39c3efa..034a4fece 100644 --- a/crates/daphne/src/protocol/mod.rs +++ b/crates/daphne/src/protocol/mod.rs @@ -250,7 +250,7 @@ mod test { let (agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); assert_eq!(agg_span.report_count(), 3); - assert_eq!(agg_job_resp.transitions.len(), 3); + assert_eq!(agg_job_resp.unwrap_ready().transitions.len(), 3); } test_versions! { produce_agg_job_req } @@ -376,6 +376,7 @@ mod test { t.produce_agg_job_req(&DapAggregationParam::Empty, reports.clone()); let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let agg_job_resp = agg_job_resp.unwrap_ready(); assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( agg_job_resp.transitions[0].var, @@ -411,6 +412,7 @@ mod test { }; let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let agg_job_resp = agg_job_resp.unwrap_ready(); assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( agg_job_resp.transitions[0].var, @@ -446,6 +448,7 @@ mod test { }; let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let agg_job_resp = agg_job_resp.unwrap_ready(); assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( agg_job_resp.transitions[0].var, @@ -466,6 +469,7 @@ mod test { t.produce_agg_job_req(&DapAggregationParam::Empty, reports.clone()); let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let agg_job_resp = agg_job_resp.unwrap_ready(); assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( agg_job_resp.transitions[0].var, @@ -511,6 +515,7 @@ mod test { let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let agg_job_resp = agg_job_resp.unwrap_ready(); assert_eq!(agg_job_resp.transitions.len(), 2); assert_matches!( agg_job_resp.transitions[0].var, @@ -532,8 +537,9 @@ mod test { ]); let (leader_state, agg_job_init_req) = t.produce_agg_job_req(&DapAggregationParam::Empty, reports); - let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let mut agg_job_resp = agg_job_resp.unwrap_ready(); // Helper sends transitions out of order. let tmp = agg_job_resp.transitions[0].clone(); agg_job_resp.transitions[0] = agg_job_resp.transitions[1].clone(); @@ -555,8 +561,9 @@ mod test { ]); let (leader_state, agg_job_init_req) = t.produce_agg_job_req(&DapAggregationParam::Empty, reports); - let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let mut agg_job_resp = agg_job_resp.unwrap_ready(); // Helper sends a transition twice. let repeated_transition = agg_job_resp.transitions[0].clone(); agg_job_resp.transitions.push(repeated_transition); @@ -578,8 +585,9 @@ mod test { ]); let (leader_state, agg_job_init_req) = t.produce_agg_job_req(&DapAggregationParam::Empty, reports); - let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let mut agg_job_resp = agg_job_resp.unwrap_ready(); // Helper sent a transition with an unrecognized report ID. agg_job_resp.transitions.push(Transition { report_id: ReportId(rng.gen()), @@ -599,8 +607,9 @@ mod test { let reports = t.produce_reports(vec![DapMeasurement::U32Vec(vec![1; 10])]); let (leader_state, agg_job_init_req) = t.produce_agg_job_req(&DapAggregationParam::Empty, reports); - let (_helper_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let (_helper_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let mut agg_job_resp = agg_job_resp.unwrap_ready(); // Helper sent a transition with an unrecognized report ID. Simulate this by flipping the // first bit of the report ID. agg_job_resp.transitions[0].report_id.0[0] ^= 1; @@ -628,6 +637,7 @@ mod test { let (leader_agg_span, helper_agg_span) = { let (helper_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let agg_job_resp = agg_job_resp.unwrap_ready(); let leader_agg_span = t.consume_agg_job_resp(leader_state, agg_job_resp); (leader_agg_span, helper_agg_span) @@ -680,6 +690,7 @@ mod test { .map(|r| r.report_share.report_metadata.id) .collect::>(); let (helper_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let agg_job_resp = agg_job_resp.unwrap_ready(); assert_eq!(2, helper_agg_span.report_count()); assert_eq!(3, agg_job_resp.transitions.len()); diff --git a/crates/daphne/src/roles/helper/handle_agg_job.rs b/crates/daphne/src/roles/helper/handle_agg_job.rs index 026a399b5..89d99044d 100644 --- a/crates/daphne/src/roles/helper/handle_agg_job.rs +++ b/crates/daphne/src/roles/helper/handle_agg_job.rs @@ -336,7 +336,7 @@ impl HandleAggJob { 0, /* vdaf step */ ); - return Ok(agg_job_resp); + return Ok(agg_job_resp.into()); } } diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs index c65c0a8f8..9be463a41 100644 --- a/crates/daphne/src/roles/leader/mod.rs +++ b/crates/daphne/src/roles/leader/mod.rs @@ -373,6 +373,13 @@ async fn run_agg_job( AggregationJobResp::get_decoded_with_param(&task_config.version, &resp.payload) .map_err(|e| DapAbort::from_codec_error(e, *task_id))?; + let agg_job_resp = match agg_job_resp { + AggregationJobResp::Ready { transitions } => { + crate::messages::ReadyAggregationJobResp { transitions } + } + AggregationJobResp::Processing => todo!("polling not implemented yet"), + }; + // Handle AggregationJobResp. let agg_span = task_config.consume_agg_job_resp(task_id, agg_job_state, agg_job_resp, metrics)?; diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index 4a74078a1..a6ecb73c3 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -764,7 +764,7 @@ mod test { .payload, ) .unwrap(); - let transition = &agg_job_resp.transitions[0]; + let transition = agg_job_resp.unwrap_ready().transitions.remove(0); // Expect failure due to invalid ciphertext. assert_matches!( @@ -878,7 +878,7 @@ mod test { .payload, ) .unwrap(); - let transition = &agg_job_resp.transitions[0]; + let transition = agg_job_resp.unwrap_ready().transitions.remove(0); // Expect success due to valid ciphertext. assert_matches!(transition.var, TransitionVar::Continued(_)); diff --git a/crates/daphne/src/testing/mod.rs b/crates/daphne/src/testing/mod.rs index dd6d3df60..67dd2fbe1 100644 --- a/crates/daphne/src/testing/mod.rs +++ b/crates/daphne/src/testing/mod.rs @@ -14,7 +14,7 @@ use crate::{ messages::{ self, AggregationJobId, AggregationJobInitReq, AggregationJobResp, Base64Encode, BatchId, BatchSelector, Collection, CollectionJobId, HpkeCiphertext, Interval, PartialBatchSelector, - Report, ReportId, TaskId, Time, + ReadyAggregationJobResp, Report, ReportId, TaskId, Time, }, metrics::{prometheus::DaphnePromMetrics, DaphneMetrics}, roles::{ @@ -208,7 +208,8 @@ impl AggregationJobTest { &self, agg_job_init_req: AggregationJobInitReq, ) -> (DapAggregateSpan, AggregationJobResp) { - self.task_config + let (span, resp) = self + .task_config .produce_agg_job_resp( self.task_id, &HashMap::default(), @@ -224,7 +225,8 @@ impl AggregationJobTest { ) .unwrap(), ) - .unwrap() + .unwrap(); + (span, resp.into()) } /// Leader: Handle `AggregationJobResp`, produce `AggregationJobContinueReq`. @@ -233,7 +235,7 @@ impl AggregationJobTest { pub fn consume_agg_job_resp( &self, leader_state: DapAggregationJobState, - agg_job_resp: AggregationJobResp, + agg_job_resp: ReadyAggregationJobResp, ) -> DapAggregateSpan { self.task_config .consume_agg_job_resp( @@ -249,7 +251,7 @@ impl AggregationJobTest { pub fn consume_agg_job_resp_expect_err( &self, leader_state: DapAggregationJobState, - agg_job_resp: AggregationJobResp, + agg_job_resp: ReadyAggregationJobResp, ) -> DapError { let metrics = &self.leader_metrics; self.task_config @@ -337,8 +339,13 @@ impl AggregationJobTest { let (leader_state, agg_job_init_req) = self.produce_agg_job_req(&agg_param, reports); let (leader_agg_span, helper_agg_span) = { - let (helper_agg_span, agg_job_resp) = self.handle_agg_job_req(agg_job_init_req); - let leader_agg_span = self.consume_agg_job_resp(leader_state, agg_job_resp); + let (helper_agg_span, AggregationJobResp::Ready { transitions }) = + self.handle_agg_job_req(agg_job_init_req) + else { + panic!("testing should not be async") + }; + let leader_agg_span = + self.consume_agg_job_resp(leader_state, ReadyAggregationJobResp { transitions }); (leader_agg_span, helper_agg_span) };