Skip to content

Commit

Permalink
Align AggregationJobResp with spec
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess authored and cjpatton committed Jan 28, 2025
1 parent a7a1533 commit 3df9a4d
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 35 deletions.
2 changes: 1 addition & 1 deletion crates/dapf/src/acceptance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ impl Test {
let agg_share_span = task_config.consume_agg_job_resp(
task_id,
agg_job_state,
agg_job_resp,
agg_job_resp.unwrap_ready(), // TODO: implement polling
self.metrics(),
)?;

Expand Down
69 changes: 55 additions & 14 deletions crates/daphne/src/messages/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -951,10 +951,26 @@ impl std::fmt::Display for ReportError {
}
}

const AGG_JOB_RESP_PROCESSING: u8 = 0;
const AGG_JOB_RESP_READY: u8 = 1;

/// An aggregate response sent from the Helper to the Leader.
#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct AggregationJobResp {
pub transitions: Vec<Transition>,
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))]
pub enum AggregationJobResp {
Ready { transitions: Vec<Transition> },
Processing,
}

impl AggregationJobResp {
#[cfg(any(test, feature = "test-utils"))]
#[track_caller]
pub fn unwrap_ready(self) -> crate::protocol::ReadyAggregationJobResp {
match self {
Self::Ready { transitions } => crate::protocol::ReadyAggregationJobResp { transitions },
Self::Processing => panic!("unwrapped a Processing value"),
}
}
}

impl ParameterizedEncode<DapVersion> for AggregationJobResp {
Expand All @@ -963,7 +979,19 @@ impl ParameterizedEncode<DapVersion> for AggregationJobResp {
version: &DapVersion,
bytes: &mut Vec<u8>,
) -> Result<(), CodecError> {
encode_u32_items(bytes, version, &self.transitions)
match (self, version) {
(Self::Ready { transitions }, DapVersion::Draft09) => {
encode_u32_items(bytes, version, transitions)
}
(Self::Ready { transitions }, DapVersion::Latest) => {
AGG_JOB_RESP_READY.encode(bytes)?;
encode_u32_items(bytes, version, transitions)
}
(Self::Processing, DapVersion::Draft09) => Err(CodecError::Other(
"draft09: can't represent an agg job resp that's being processed".into(),
)),
(Self::Processing, DapVersion::Latest) => AGG_JOB_RESP_PROCESSING.encode(bytes),
}
}
}

Expand All @@ -972,9 +1000,18 @@ impl ParameterizedDecode<DapVersion> for AggregationJobResp {
version: &DapVersion,
bytes: &mut Cursor<&[u8]>,
) -> Result<Self, CodecError> {
Ok(Self {
transitions: decode_u32_items(version, bytes)?,
})
match version {
DapVersion::Draft09 => Ok(Self::Ready {
transitions: decode_u32_items(version, bytes)?,
}),
DapVersion::Latest => match u8::decode(bytes)? {
AGG_JOB_RESP_PROCESSING => Ok(Self::Processing),
AGG_JOB_RESP_READY => Ok(Self::Ready {
transitions: decode_u32_items(version, bytes)?,
}),
_ => Err(CodecError::UnexpectedValue),
},
}
}
}

Expand Down Expand Up @@ -1898,7 +1935,7 @@ mod test {
17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 2, 7,
];

let want = AggregationJobResp {
let want = AggregationJobResp::Ready {
transitions: vec![
Transition {
report_id: ReportId([22; 16]),
Expand All @@ -1916,14 +1953,18 @@ mod test {
},
],
};
println!(
"want {:?}",
want.get_encoded_with_param(&DapVersion::Latest).unwrap()
);

let got =
AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, TEST_DATA).unwrap();
AggregationJobResp::get_decoded_with_param(&DapVersion::Draft09, TEST_DATA).unwrap();
assert_eq!(got, want);
let draft_latest_data = [&[1], TEST_DATA].concat();
let got =
AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, &draft_latest_data)
.unwrap();
assert_eq!(got, want);
assert_eq!(
AggregationJobResp::Processing,
AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, &[0]).unwrap(),
);
}

#[test]
Expand Down
13 changes: 7 additions & 6 deletions crates/daphne/src/protocol/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
use super::{
check_no_duplicates,
report_init::{InitializedReport, WithPeerPrepShare},
ReadyAggregationJobResp,
};
use crate::{
constants::DapAggregatorRole,
error::DapAbort,
fatal_error,
hpke::{info_and_aad, HpkeConfig, HpkeDecrypter},
messages::{
self, encode_u32_bytes, AggregationJobInitReq, AggregationJobResp, Base64Encode,
BatchSelector, HpkeCiphertext, PartialBatchSelector, PrepareInit, Report, ReportError,
ReportId, ReportShare, TaskId, Transition, TransitionVar,
self, encode_u32_bytes, AggregationJobInitReq, Base64Encode, BatchSelector, HpkeCiphertext,
PartialBatchSelector, PrepareInit, Report, ReportError, ReportId, ReportShare, TaskId,
Transition, TransitionVar,
},
metrics::{DaphneMetrics, ReportStatus},
protocol::{decode_ping_pong_framed, PingPongMessageType},
Expand Down Expand Up @@ -285,7 +286,7 @@ impl DapTaskConfig {
report_status: &HashMap<ReportId, ReportProcessedStatus>,
part_batch_sel: &PartialBatchSelector,
initialized_reports: &[InitializedReport<WithPeerPrepShare>],
) -> Result<(DapAggregateSpan<DapAggregateShare>, AggregationJobResp), DapError> {
) -> Result<(DapAggregateSpan<DapAggregateShare>, ReadyAggregationJobResp), DapError> {
let num_reports = initialized_reports.len();
let mut agg_span = DapAggregateSpan::default();
let mut transitions = Vec::with_capacity(num_reports);
Expand Down Expand Up @@ -361,7 +362,7 @@ impl DapTaskConfig {
});
}

Ok((agg_span, AggregationJobResp { transitions }))
Ok((agg_span, ReadyAggregationJobResp { transitions }))
}

/// Leader: Consume the `AggregationJobResp` message sent by the Helper and compute the
Expand All @@ -370,7 +371,7 @@ impl DapTaskConfig {
&self,
task_id: &TaskId,
state: DapAggregationJobState,
agg_job_resp: AggregationJobResp,
agg_job_resp: ReadyAggregationJobResp,
metrics: &dyn DaphneMetrics,
) -> Result<DapAggregateSpan<DapAggregateShare>, DapError> {
if agg_job_resp.transitions.len() != state.seq.len() {
Expand Down
37 changes: 32 additions & 5 deletions crates/daphne/src/protocol/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use crate::messages;
use prio::codec::{CodecError, Decode as _};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, io::Cursor};

pub(crate) mod aggregator;
Expand Down Expand Up @@ -59,6 +61,20 @@ fn decode_ping_pong_framed(
Ok(&bytes[message_start..])
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))]
pub struct ReadyAggregationJobResp {
pub transitions: Vec<messages::Transition>,
}

impl From<ReadyAggregationJobResp> for messages::AggregationJobResp {
fn from(value: ReadyAggregationJobResp) -> Self {
Self::Ready {
transitions: value.transitions,
}
}
}

#[cfg(test)]
mod test {
use super::{report_init::InitializedReport, PingPongMessageType};
Expand Down Expand Up @@ -250,7 +266,7 @@ mod test {

let (agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);
assert_eq!(agg_span.report_count(), 3);
assert_eq!(agg_job_resp.transitions.len(), 3);
assert_eq!(agg_job_resp.unwrap_ready().transitions.len(), 3);
}

test_versions! { produce_agg_job_req }
Expand Down Expand Up @@ -376,6 +392,7 @@ mod test {
t.produce_agg_job_req(&DapAggregationParam::Empty, reports.clone());
let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let agg_job_resp = agg_job_resp.unwrap_ready();
assert_eq!(agg_job_resp.transitions.len(), 1);
assert_matches!(
agg_job_resp.transitions[0].var,
Expand Down Expand Up @@ -414,6 +431,7 @@ mod test {
};
let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let agg_job_resp = agg_job_resp.unwrap_ready();
assert_eq!(agg_job_resp.transitions.len(), 1);
assert_matches!(
agg_job_resp.transitions[0].var,
Expand Down Expand Up @@ -449,6 +467,7 @@ mod test {
};
let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let agg_job_resp = agg_job_resp.unwrap_ready();
assert_eq!(agg_job_resp.transitions.len(), 1);
assert_matches!(
agg_job_resp.transitions[0].var,
Expand All @@ -469,6 +488,7 @@ mod test {
t.produce_agg_job_req(&DapAggregationParam::Empty, reports.clone());
let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let agg_job_resp = agg_job_resp.unwrap_ready();
assert_eq!(agg_job_resp.transitions.len(), 1);
assert_matches!(
agg_job_resp.transitions[0].var,
Expand Down Expand Up @@ -514,6 +534,7 @@ mod test {

let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let agg_job_resp = agg_job_resp.unwrap_ready();
assert_eq!(agg_job_resp.transitions.len(), 2);
assert_matches!(
agg_job_resp.transitions[0].var,
Expand All @@ -535,8 +556,9 @@ mod test {
]);
let (leader_state, agg_job_init_req) =
t.produce_agg_job_req(&DapAggregationParam::Empty, reports);
let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);
let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let mut agg_job_resp = agg_job_resp.unwrap_ready();
// Helper sends transitions out of order.
let tmp = agg_job_resp.transitions[0].clone();
agg_job_resp.transitions[0] = agg_job_resp.transitions[1].clone();
Expand All @@ -558,8 +580,9 @@ mod test {
]);
let (leader_state, agg_job_init_req) =
t.produce_agg_job_req(&DapAggregationParam::Empty, reports);
let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);
let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let mut agg_job_resp = agg_job_resp.unwrap_ready();
// Helper sends a transition twice.
let repeated_transition = agg_job_resp.transitions[0].clone();
agg_job_resp.transitions.push(repeated_transition);
Expand All @@ -581,8 +604,9 @@ mod test {
]);
let (leader_state, agg_job_init_req) =
t.produce_agg_job_req(&DapAggregationParam::Empty, reports);
let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);
let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let mut agg_job_resp = agg_job_resp.unwrap_ready();
// Helper sent a transition with an unrecognized report ID.
agg_job_resp.transitions.push(Transition {
report_id: ReportId(rng.gen()),
Expand All @@ -602,8 +626,9 @@ mod test {
let reports = t.produce_reports(vec![DapMeasurement::U32Vec(vec![1; 10])]);
let (leader_state, agg_job_init_req) =
t.produce_agg_job_req(&DapAggregationParam::Empty, reports);
let (_helper_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);
let (_helper_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);

let mut agg_job_resp = agg_job_resp.unwrap_ready();
// Helper sent a transition with an unrecognized report ID. Simulate this by flipping the
// first bit of the report ID.
agg_job_resp.transitions[0].report_id.0[0] ^= 1;
Expand Down Expand Up @@ -631,6 +656,7 @@ mod test {

let (leader_agg_span, helper_agg_span) = {
let (helper_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);
let agg_job_resp = agg_job_resp.unwrap_ready();
let leader_agg_span = t.consume_agg_job_resp(leader_state, agg_job_resp);

(leader_agg_span, helper_agg_span)
Expand Down Expand Up @@ -683,6 +709,7 @@ mod test {
.map(|r| r.report_share.report_metadata.id)
.collect::<Vec<_>>();
let (helper_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req);
let agg_job_resp = agg_job_resp.unwrap_ready();

assert_eq!(2, helper_agg_span.report_count());
assert_eq!(3, agg_job_resp.transitions.len());
Expand Down
2 changes: 1 addition & 1 deletion crates/daphne/src/roles/helper/handle_agg_job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ impl HandleAggJob<InitializedReports> {
0, /* vdaf step */
);

return Ok(agg_job_resp);
return Ok(agg_job_resp.into());
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/daphne/src/roles/leader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,13 @@ async fn run_agg_job<A: DapLeader>(
AggregationJobResp::get_decoded_with_param(&task_config.version, &resp.payload)
.map_err(|e| DapAbort::from_codec_error(e, *task_id))?;

let agg_job_resp = match agg_job_resp {
AggregationJobResp::Ready { transitions } => {
crate::protocol::ReadyAggregationJobResp { transitions }
}
AggregationJobResp::Processing => todo!("polling not implemented yet"),
};

// Handle AggregationJobResp.
let agg_span =
task_config.consume_agg_job_resp(task_id, agg_job_state, agg_job_resp, metrics)?;
Expand Down
4 changes: 2 additions & 2 deletions crates/daphne/src/roles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ mod test {
.payload,
)
.unwrap();
let transition = &agg_job_resp.transitions[0];
let transition = agg_job_resp.unwrap_ready().transitions.remove(0);

// Expect failure due to invalid ciphertext.
assert_matches!(
Expand Down Expand Up @@ -878,7 +878,7 @@ mod test {
.payload,
)
.unwrap();
let transition = &agg_job_resp.transitions[0];
let transition = agg_job_resp.unwrap_ready().transitions.remove(0);

// Expect success due to valid ciphertext.
assert_matches!(transition.var, TransitionVar::Continued(_));
Expand Down
Loading

0 comments on commit 3df9a4d

Please sign in to comment.