diff --git a/Cargo.lock b/Cargo.lock index cfae1f3fc..05f7cc9c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -982,6 +982,8 @@ dependencies = [ "axum", "axum-extra", "bytes", + "capnp", + "capnpc", "chrono", "constcat", "daphne", @@ -1014,6 +1016,7 @@ dependencies = [ "tracing-core", "tracing-subscriber", "url", + "wasm-bindgen", "webpki", "worker", ] diff --git a/Cargo.toml b/Cargo.toml index c7ba9a577..d6c38be41 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,8 +88,9 @@ tracing = "0.1.40" tracing-core = "0.1.32" tracing-subscriber = "0.3.18" url = { version = "2.5.4", features = ["serde"] } +wasm-bindgen = "0.2.99" webpki = "0.22.4" -worker = { version = "0.5", features = ["http"] } +worker = "0.5" x509-parser = "0.15.1" [workspace.dependencies.sentry] diff --git a/crates/daphne-server/src/roles/aggregator.rs b/crates/daphne-server/src/roles/aggregator.rs index 26b1a2d55..fa199bee3 100644 --- a/crates/daphne-server/src/roles/aggregator.rs +++ b/crates/daphne-server/src/roles/aggregator.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
// SPDX-License-Identifier: BSD-3-Clause use std::{future::ready, num::NonZeroUsize, ops::Range, time::SystemTime}; @@ -79,6 +79,7 @@ impl DapAggregator for crate::App { #[tracing::instrument(skip(self))] async fn get_agg_share( &self, + _version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result { diff --git a/crates/daphne-server/src/roles/helper.rs b/crates/daphne-server/src/roles/helper.rs index 5257d04ce..c22b4f36d 100644 --- a/crates/daphne-server/src/roles/helper.rs +++ b/crates/daphne-server/src/roles/helper.rs @@ -3,7 +3,8 @@ use axum::async_trait; use daphne::{ - messages::{request::AggregationJobRequestHash, AggregationJobId, TaskId}, + fatal_error, + messages::{request::AggregationJobRequestHash, AggregationJobId, AggregationJobResp, TaskId}, roles::DapHelper, DapError, DapVersion, }; @@ -20,4 +21,13 @@ impl DapHelper for crate::App { // the server implementation can't check for this Ok(()) } + + async fn poll_aggregated( + &self, + _version: DapVersion, + _task_id: &TaskId, + _agg_job_id: &AggregationJobId, + ) -> Result { + Err(fatal_error!(err = "polling not implemented")) + } } diff --git a/crates/daphne-service-utils/build.rs b/crates/daphne-service-utils/build.rs index 317e18715..0f3c692f7 100644 --- a/crates/daphne-service-utils/build.rs +++ b/crates/daphne-service-utils/build.rs @@ -11,7 +11,10 @@ fn main() { #[cfg(feature = "durable_requests")] compiler .file("./src/durable_requests/durable_request.capnp") - .file("./src/durable_requests/bindings/aggregation_job_store.capnp"); + .file("./src/durable_requests/bindings/aggregation_job_store.capnp") + .file("./src/durable_requests/bindings/aggregate_store_v2.capnp") + .file("./src/durable_requests/bindings/agg_job_response_store.capnp") + .file("./src/durable_requests/bindings/replay_checker.capnp"); #[cfg(feature = "compute-offload")] compiler.file("./src/compute_offload/compute_offload.capnp"); diff --git a/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp 
b/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp index dacd261f0..cc9708bfe 100644 --- a/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp +++ b/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp @@ -3,8 +3,6 @@ @0xd932f3d934afce3b; -# Utilities - using Base = import "../capnproto/base.capnp"; using VdafConfig = Text; # json encoded diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.capnp b/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.capnp new file mode 100644 index 000000000..7cd79c84a --- /dev/null +++ b/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.capnp @@ -0,0 +1,38 @@ +# Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +@0xd30da336463f3205; + +using Base = import "../../capnproto/base.capnp"; + +struct AggregationJobResponse { + enum ReportError { + reserved @0; + batchCollected @1; + reportReplayed @2; + reportDropped @3; + hpkeUnknownConfigId @4; + hpkeDecryptError @5; + vdafPrepError @6; + batchSaturated @7; + taskExpired @8; + invalidMessage @9; + reportTooEarly @10; + taskNotStarted @11; + } + + struct TransitionVar { + union { + continued @0 :Data; + failed @1 :ReportError; + } + } + + struct Transition { + reportId @0 :Base.ReportId; + var @1 :TransitionVar; + } + + transitions @0 :List(Transition); +} + diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.rs new file mode 100644 index 000000000..e5a893b45 --- /dev/null +++ b/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.rs @@ -0,0 +1,159 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
+// SPDX-License-Identifier: BSD-3-Clause + +use daphne::{ + messages::{AggregationJobId, ReadyAggregationJobResp, TaskId, Transition, TransitionVar}, + DapVersion, +}; + +use crate::{ + agg_job_response_store_capnp::aggregation_job_response, + capnproto::{ + decode_list, encode_list, usize_to_capnp_len, CapnprotoPayloadDecode, + CapnprotoPayloadEncode, + }, + durable_requests::ObjectIdFrom, +}; + +super::define_do_binding! { + const BINDING = "AGGREGATE_JOB_RESULT_STORE"; + enum Command { + Get = "/get", + Put = "/put", + } + + fn name( + (version, task_id, agg_job_id): + (DapVersion, &'n TaskId, &'n AggregationJobId) + ) -> ObjectIdFrom { + ObjectIdFrom::Name(format!("{version}/task/{task_id}/agg_job/{agg_job_id}")) + } +} + +impl CapnprotoPayloadEncode for ReadyAggregationJobResp { + type Builder<'a> = aggregation_job_response::Builder<'a>; + + fn encode_to_builder(&self, builder: Self::Builder<'_>) { + let Self { transitions } = self; + encode_list( + transitions, + builder.init_transitions(usize_to_capnp_len(transitions.len())), + ); + } +} + +impl CapnprotoPayloadEncode for Transition { + type Builder<'a> = aggregation_job_response::transition::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + let Self { report_id, var } = self; + report_id.encode_to_builder(builder.reborrow().init_report_id()); + let mut builder = builder.init_var(); + match var { + TransitionVar::Continued(vec) => builder.set_continued(vec), + TransitionVar::Failed(report_error) => builder.set_failed((*report_error).into()), + } + } +} + +impl CapnprotoPayloadDecode for Transition { + type Reader<'a> = aggregation_job_response::transition::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + Ok(Self { + report_id: <_>::decode_from_reader(reader.get_report_id()?)?, + var: match reader.get_var()?.which()? 
{ + aggregation_job_response::transition_var::Which::Continued(data) => { + TransitionVar::Continued(data?.to_vec()) + } + aggregation_job_response::transition_var::Which::Failed(report_error) => { + TransitionVar::Failed(report_error?.into()) + } + }, + }) + } +} + +impl CapnprotoPayloadDecode for ReadyAggregationJobResp { + type Reader<'a> = aggregation_job_response::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + Ok(Self { + transitions: decode_list::(reader.get_transitions()?)?, + }) + } +} + +impl From for aggregation_job_response::ReportError { + fn from(error: daphne::messages::ReportError) -> Self { + match error { + daphne::messages::ReportError::Reserved => Self::Reserved, + daphne::messages::ReportError::BatchCollected => Self::BatchCollected, + daphne::messages::ReportError::ReportReplayed => Self::ReportReplayed, + daphne::messages::ReportError::ReportDropped => Self::ReportDropped, + daphne::messages::ReportError::HpkeUnknownConfigId => Self::HpkeUnknownConfigId, + daphne::messages::ReportError::HpkeDecryptError => Self::HpkeDecryptError, + daphne::messages::ReportError::VdafPrepError => Self::VdafPrepError, + daphne::messages::ReportError::BatchSaturated => Self::BatchSaturated, + daphne::messages::ReportError::TaskExpired => Self::TaskExpired, + daphne::messages::ReportError::InvalidMessage => Self::InvalidMessage, + daphne::messages::ReportError::ReportTooEarly => Self::ReportTooEarly, + daphne::messages::ReportError::TaskNotStarted => Self::TaskNotStarted, + } + } +} + +impl From for daphne::messages::ReportError { + fn from(error: aggregation_job_response::ReportError) -> Self { + match error { + aggregation_job_response::ReportError::Reserved => Self::Reserved, + aggregation_job_response::ReportError::BatchCollected => Self::BatchCollected, + aggregation_job_response::ReportError::ReportReplayed => Self::ReportReplayed, + aggregation_job_response::ReportError::ReportDropped => 
Self::ReportDropped, + aggregation_job_response::ReportError::HpkeUnknownConfigId => Self::HpkeUnknownConfigId, + aggregation_job_response::ReportError::HpkeDecryptError => Self::HpkeDecryptError, + aggregation_job_response::ReportError::VdafPrepError => Self::VdafPrepError, + aggregation_job_response::ReportError::BatchSaturated => Self::BatchSaturated, + aggregation_job_response::ReportError::TaskExpired => Self::TaskExpired, + aggregation_job_response::ReportError::InvalidMessage => Self::InvalidMessage, + aggregation_job_response::ReportError::ReportTooEarly => Self::ReportTooEarly, + aggregation_job_response::ReportError::TaskNotStarted => Self::TaskNotStarted, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::capnproto::{CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _}; + use daphne::messages::ReportId; + use rand::{thread_rng, Rng}; + + fn gen_agg_job_resp() -> ReadyAggregationJobResp { + ReadyAggregationJobResp { + transitions: vec![ + Transition { + report_id: ReportId(thread_rng().gen()), + var: TransitionVar::Continued(vec![1, 2, 3]), + }, + Transition { + report_id: ReportId(thread_rng().gen()), + var: TransitionVar::Failed(daphne::messages::ReportError::InvalidMessage), + }, + ], + } + } + + #[test] + fn serialization_deserialization_round_trip() { + let this = gen_agg_job_resp(); + let other = ReadyAggregationJobResp::decode_from_bytes(&this.encode_to_bytes()).unwrap(); + assert_eq!(this, other); + } +} diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs index 0f5b1181d..960fbb94f 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
// SPDX-License-Identifier: BSD-3-Clause use std::collections::HashSet; @@ -93,66 +93,7 @@ impl CapnprotoPayloadEncode for AggregateStoreMergeReq { report.set_high(high); } } - { - let mut agg_share_delta_packet = builder.reborrow().init_agg_share_delta(); - agg_share_delta_packet.set_report_count(agg_share_delta.report_count); - agg_share_delta_packet.set_min_time(agg_share_delta.min_time); - agg_share_delta_packet.set_max_time(agg_share_delta.max_time); - { - let checksum = agg_share_delta_packet - .reborrow() - .init_checksum(agg_share_delta.checksum.len().try_into().unwrap()); - checksum.copy_from_slice(&agg_share_delta.checksum); - } - { - macro_rules! make_encode { - ($func_name:ident, $agg_share_type:ident, $field_trait:ident) => { - fn $func_name<'b, F, B, const ENCODED_SIZE: usize>( - field: &$agg_share_type, - get_bytes: B, - ) where - F: $field_trait + Into<[u8; ENCODED_SIZE]>, - B: FnOnce(u32) -> &'b mut [u8], - { - let mut bytes = get_bytes( - (F::ENCODED_SIZE * field.as_ref().len()) - .try_into() - .expect("trying to encode a buffer longer than u32::MAX"), - ); - for f in field.as_ref() { - let f: [u8; ENCODED_SIZE] = (*f).into(); - bytes[..ENCODED_SIZE].copy_from_slice(&f); - bytes = &mut bytes[ENCODED_SIZE..]; - } - } - }; - } - make_encode!(encode_draft09, AggregateShareDraft09, FieldElementDraft09); - make_encode!(encode, AggregateShare, FieldElement); - let mut data = agg_share_delta_packet.init_data(); - match &self.agg_share_delta.data { - Some(VdafAggregateShare::Field64Draft09(field)) => { - encode_draft09(field, |len| data.init_field64_draft09(len)); - } - Some(VdafAggregateShare::Field128Draft09(field)) => { - encode_draft09(field, |len| data.init_field128_draft09(len)); - } - Some(VdafAggregateShare::Field32Draft09(field)) => { - encode_draft09(field, |len| data.init_field_prio2_draft09(len)); - } - Some(VdafAggregateShare::Field64(field)) => { - encode(field, |len| data.init_field64(len)); - } - Some(VdafAggregateShare::Field128(field)) 
=> { - encode(field, |len| data.init_field128(len)); - } - Some(VdafAggregateShare::Field32(field)) => { - encode(field, |len| data.init_field_prio2(len)); - } - None => data.set_none(()), - }; - } - } + agg_share_delta.encode_to_builder(builder.reborrow().init_agg_share_delta()); { let AggregateStoreMergeOptions { skip_replay_protection, @@ -167,76 +108,8 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq { type Reader<'a> = aggregate_store_merge_req::Reader<'a>; fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result { - let agg_share_delta = { - let agg_share_delta = reader.get_agg_share_delta()?; - let data = { - macro_rules! make_decode { - ($func_name:ident, $agg_share_type:ident, $field_trait:ident, $field_error:ident) => { - fn $func_name(fields: &[u8]) -> capnp::Result<$agg_share_type> - where - F: $field_trait + for<'s> TryFrom<&'s [u8], Error = $field_error>, - { - let iter = fields.chunks_exact(F::ENCODED_SIZE); - if let length @ 1.. = iter.remainder().len() { - return Err(capnp::Error { - kind: capnp::ErrorKind::Failed, - extra: format!( - "leftover bytes still present in buffer: {length}" - ), - }); - } - Ok($agg_share_type::from( - iter.map(|f| f.try_into().unwrap()).collect::>(), - )) - } - }; - } - make_decode!( - decode_draft09, - AggregateShareDraft09, - FieldElementDraft09, - FieldErrorDraft09 - ); - make_decode!(decode, AggregateShare, FieldElement, FieldError); - match agg_share_delta.get_data().which()? 
{ - dap_aggregate_share::data::Which::Field64Draft09(field) => { - Some(VdafAggregateShare::Field64Draft09(decode_draft09(field?)?)) - } - dap_aggregate_share::data::Which::Field128Draft09(field) => { - Some(VdafAggregateShare::Field128Draft09(decode_draft09(field?)?)) - } - dap_aggregate_share::data::Which::FieldPrio2Draft09(field) => { - Some(VdafAggregateShare::Field32Draft09(decode_draft09(field?)?)) - } - - dap_aggregate_share::data::Which::Field64(field) => { - Some(VdafAggregateShare::Field64(decode(field?)?)) - } - dap_aggregate_share::data::Which::Field128(field) => { - Some(VdafAggregateShare::Field128(decode(field?)?)) - } - dap_aggregate_share::data::Which::FieldPrio2(field) => { - Some(VdafAggregateShare::Field32(decode(field?)?)) - } - dap_aggregate_share::data::Which::None(()) => None, - } - }; - DapAggregateShare { - report_count: agg_share_delta.get_report_count(), - min_time: agg_share_delta.get_min_time(), - max_time: agg_share_delta.get_max_time(), - checksum: agg_share_delta - .get_checksum()? - .try_into() - .map_err(|_| capnp::Error { - kind: capnp::ErrorKind::Failed, - extra: "checksum had unexpected size".into(), - })?, - data, - } - }; - let contained_reports = { - reader + Ok(Self { + contained_reports: reader .get_contained_reports()? 
.into_iter() .map(|report| { @@ -248,11 +121,8 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq { buffer[8..].copy_from_slice(&high.to_le_bytes()); ReportId(buffer) }) - .collect() - }; - Ok(Self { - contained_reports, - agg_share_delta, + .collect(), + agg_share_delta: <_>::decode_from_reader(reader.get_agg_share_delta()?)?, options: AggregateStoreMergeOptions { skip_replay_protection: reader.get_options()?.get_skip_replay_protection(), }, @@ -260,6 +130,138 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq { } } +impl CapnprotoPayloadEncode for DapAggregateShare { + type Builder<'a> = dap_aggregate_share::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + builder.set_report_count(self.report_count); + builder.set_min_time(self.min_time); + builder.set_max_time(self.max_time); + builder.set_checksum(&self.checksum); + { + macro_rules! make_encode { + ($func_name:ident, $agg_share_type:ident, $field_trait:ident) => { + fn $func_name<'b, F, B, const ENCODED_SIZE: usize>( + field: &$agg_share_type, + get_bytes: B, + ) where + F: $field_trait + Into<[u8; ENCODED_SIZE]>, + B: FnOnce(u32) -> &'b mut [u8], + { + let mut bytes = get_bytes( + (F::ENCODED_SIZE * field.as_ref().len()) + .try_into() + .expect("trying to encode a buffer longer than u32::MAX"), + ); + for f in field.as_ref() { + let f: [u8; ENCODED_SIZE] = (*f).into(); + bytes[..ENCODED_SIZE].copy_from_slice(&f); + bytes = &mut bytes[ENCODED_SIZE..]; + } + } + }; + } + make_encode!(encode_draft09, AggregateShareDraft09, FieldElementDraft09); + make_encode!(encode, AggregateShare, FieldElement); + let mut data = builder.init_data(); + match &self.data { + Some(VdafAggregateShare::Field64Draft09(field)) => { + encode_draft09(field, |len| data.init_field64_draft09(len)); + } + Some(VdafAggregateShare::Field128Draft09(field)) => { + encode_draft09(field, |len| data.init_field128_draft09(len)); + } + Some(VdafAggregateShare::Field32Draft09(field)) => { + 
encode_draft09(field, |len| data.init_field_prio2_draft09(len)); + } + Some(VdafAggregateShare::Field64(field)) => { + encode(field, |len| data.init_field64(len)); + } + Some(VdafAggregateShare::Field128(field)) => { + encode(field, |len| data.init_field128(len)); + } + Some(VdafAggregateShare::Field32(field)) => { + encode(field, |len| data.init_field_prio2(len)); + } + None => data.set_none(()), + }; + } + } +} + +impl CapnprotoPayloadDecode for DapAggregateShare { + type Reader<'a> = dap_aggregate_share::Reader<'a>; + + fn decode_from_reader(agg_share_delta: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + let data = { + macro_rules! make_decode { + ($func_name:ident, $agg_share_type:ident, $field_trait:ident, $field_error:ident) => { + fn $func_name(fields: &[u8]) -> capnp::Result<$agg_share_type> + where + F: $field_trait + for<'s> TryFrom<&'s [u8], Error = $field_error>, + { + let iter = fields.chunks_exact(F::ENCODED_SIZE); + if let length @ 1.. = iter.remainder().len() { + return Err(capnp::Error { + kind: capnp::ErrorKind::Failed, + extra: format!("leftover bytes still present in buffer: {length}"), + }); + } + Ok($agg_share_type::from( + iter.map(|f| f.try_into().unwrap()).collect::>(), + )) + } + }; + } + make_decode!( + decode_draft09, + AggregateShareDraft09, + FieldElementDraft09, + FieldErrorDraft09 + ); + make_decode!(decode, AggregateShare, FieldElement, FieldError); + match agg_share_delta.get_data().which()? 
{ + dap_aggregate_share::data::Which::Field64Draft09(field) => { + Some(VdafAggregateShare::Field64Draft09(decode_draft09(field?)?)) + } + dap_aggregate_share::data::Which::Field128Draft09(field) => { + Some(VdafAggregateShare::Field128Draft09(decode_draft09(field?)?)) + } + dap_aggregate_share::data::Which::FieldPrio2Draft09(field) => { + Some(VdafAggregateShare::Field32Draft09(decode_draft09(field?)?)) + } + + dap_aggregate_share::data::Which::Field64(field) => { + Some(VdafAggregateShare::Field64(decode(field?)?)) + } + dap_aggregate_share::data::Which::Field128(field) => { + Some(VdafAggregateShare::Field128(decode(field?)?)) + } + dap_aggregate_share::data::Which::FieldPrio2(field) => { + Some(VdafAggregateShare::Field32(decode(field?)?)) + } + dap_aggregate_share::data::Which::None(()) => None, + } + }; + Ok(Self { + report_count: agg_share_delta.get_report_count(), + min_time: agg_share_delta.get_min_time(), + max_time: agg_share_delta.get_max_time(), + checksum: agg_share_delta + .get_checksum()? + .try_into() + .map_err(|_| capnp::Error { + kind: capnp::ErrorKind::Failed, + extra: "checksum had unexpected size".into(), + })?, + data, + }) + } +} + #[derive(Debug, Serialize, Deserialize)] pub enum AggregateStoreMergeResp { Ok, diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.capnp b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.capnp new file mode 100644 index 000000000..6cfa4fee7 --- /dev/null +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.capnp @@ -0,0 +1,11 @@ +# Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +@0x822b0e344bf68531; + +using Base = import "../../capnproto/base.capnp"; + +struct PutRequest { + aggShareDelta @0 :import "../durable_request.capnp".DapAggregateShare; + aggJobId @1 :Base.AggregationJobId; +} diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.rs new file mode 100644 index 000000000..dee4aeab4 --- /dev/null +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.rs @@ -0,0 +1,194 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use daphne::{ + messages::{AggregationJobId, TaskId}, + DapAggregateShare, DapBatchBucket, DapVersion, +}; + +use crate::{ + aggregate_store_v2_capnp, + capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadEncode}, + durable_requests::ObjectIdFrom, +}; + +super::define_do_binding! { + const BINDING = "AGGREGATE_STORE"; + enum Command { + Get = "/get", + Put = "/put", + MarkCollected = "/mark-collected", + CheckCollected = "/check-collected", + } + + fn name( + (version, task_id, bucket): + (DapVersion, &'n TaskId, &'n DapBatchBucket) + ) -> ObjectIdFrom { + ObjectIdFrom::Name(format!("{version}/task/{task_id}/batch_bucket/{bucket}")) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct PutRequest { + pub agg_share_delta: DapAggregateShare, + pub agg_job_id: AggregationJobId, +} + +impl CapnprotoPayloadEncode for PutRequest { + type Builder<'a> = aggregate_store_v2_capnp::put_request::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + let Self { + agg_share_delta, + agg_job_id, + } = self; + agg_share_delta.encode_to_builder(builder.reborrow().init_agg_share_delta()); + agg_job_id.encode_to_builder(builder.reborrow().init_agg_job_id()); + } +} + +impl CapnprotoPayloadDecode for PutRequest { + type Reader<'a> = 
aggregate_store_v2_capnp::put_request::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result { + Ok(Self { + agg_share_delta: <_>::decode_from_reader(reader.get_agg_share_delta()?)?, + agg_job_id: <_>::decode_from_reader(reader.get_agg_job_id()?)?, + }) + } +} + +#[cfg(test)] +mod test { + use prio::{ + codec::Decode, + field::{Field128, Field64, FieldElement, FieldPrio2}, + vdaf::AggregateShare, + }; + use prio_draft09::{ + codec::Decode as DecodeDraft09, + field::{ + Field128 as Field128Draft09, Field64 as Field64Draft09, + FieldElement as FieldElementDraft09, FieldPrio2 as FieldPrio2Draft09, + }, + vdaf::AggregateShare as AggregateShareDraft09, + }; + use rand::{thread_rng, Rng}; + + use crate::capnproto::{CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _}; + + use super::*; + use daphne::vdaf::VdafAggregateShare; + + #[test] + fn serialization_deserialization_round_trip_draft09() { + let mut rng = thread_rng(); + for len in 0..20 { + let test_data = [ + VdafAggregateShare::Field64Draft09(AggregateShareDraft09::from( + (0..len) + .map(|_| { + Field64Draft09::get_decoded( + &rng.gen::<[_; Field64Draft09::ENCODED_SIZE]>(), + ) + .unwrap() + }) + .collect::>(), + )), + VdafAggregateShare::Field128Draft09(AggregateShareDraft09::from( + (0..len) + .map(|_| { + Field128Draft09::get_decoded( + &rng.gen::<[_; Field128Draft09::ENCODED_SIZE]>(), + ) + .unwrap() + }) + .collect::>(), + )), + VdafAggregateShare::Field32Draft09(AggregateShareDraft09::from( + (0..len) + .map(|_| { + // idk how to consistently generate a valid FieldPrio2 value, so I just + // retry until I hit a valid one. Doesn't usually take too long. + (0..) 
+ .find_map(|_| { + FieldPrio2Draft09::get_decoded(&rng.gen::<[_; 4]>()).ok() + }) + .unwrap() + }) + .collect::>(), + )), + ] + .map(Some) + .into_iter() + .chain([None]); + for data in test_data { + let this = PutRequest { + agg_job_id: AggregationJobId(rng.gen()), + agg_share_delta: DapAggregateShare { + report_count: rng.gen(), + min_time: rng.gen(), + max_time: rng.gen(), + checksum: rng.gen(), + data, + }, + }; + let other = PutRequest::decode_from_bytes(&this.encode_to_bytes()).unwrap(); + assert_eq!(this, other); + } + } + } + + #[test] + fn serialization_deserialization_round_trip() { + let mut rng = thread_rng(); + for len in 0..20 { + let test_data = [ + VdafAggregateShare::Field64(AggregateShare::from( + (0..len) + .map(|_| { + Field64::get_decoded(&rng.gen::<[_; Field64::ENCODED_SIZE]>()).unwrap() + }) + .collect::>(), + )), + VdafAggregateShare::Field128(AggregateShare::from( + (0..len) + .map(|_| { + Field128::get_decoded(&rng.gen::<[_; Field128::ENCODED_SIZE]>()) + .unwrap() + }) + .collect::>(), + )), + VdafAggregateShare::Field32(AggregateShare::from( + (0..len) + .map(|_| { + // idk how to consistently generate a valid FieldPrio2 value, so I just + // retry until I hit a valid one. Doesn't usually take too long. + (0..) 
+ .find_map(|_| FieldPrio2::get_decoded(&rng.gen::<[_; 4]>()).ok()) + .unwrap() + }) + .collect::>(), + )), + ] + .map(Some) + .into_iter() + .chain([None]); + for data in test_data { + let this = PutRequest { + agg_job_id: AggregationJobId(rng.gen()), + agg_share_delta: DapAggregateShare { + report_count: rng.gen(), + min_time: rng.gen(), + max_time: rng.gen(), + checksum: rng.gen(), + data, + }, + }; + let other = PutRequest::decode_from_bytes(&this.encode_to_bytes()).unwrap(); + assert_eq!(this, other); + } + } + } +} diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs index 9271f707d..5eaf886b9 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs @@ -18,6 +18,7 @@ super::define_do_binding! { enum Command { NewJob = "/new-job", + ContainsJob = "/contains", ListJobIds = "/job-ids", } diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs b/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs index 8030e9b52..20d9a8425 100644 --- a/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs +++ b/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs @@ -6,8 +6,11 @@ //! //! It also defines types that are used as the body of requests sent to these objects. +pub mod agg_job_response_store; mod aggregate_store; +pub mod aggregate_store_v2; pub mod aggregation_job_store; +pub mod replay_checker; #[cfg(feature = "test-utils")] mod test_state_cleaner; @@ -48,6 +51,8 @@ macro_rules! 
define_do_binding { fn name($params:tt : $params_ty:ty) -> ObjectIdFrom $name_impl:block ) => { + $(const _: () = assert!(matches!($route.as_bytes().first(), Some(b'/')));)* + #[derive( serde::Serialize, serde::Deserialize, diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.capnp b/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.capnp new file mode 100644 index 000000000..08c80633e --- /dev/null +++ b/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.capnp @@ -0,0 +1,11 @@ +# Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +@0xaaa529cce40f45d7; + +using Base = import "../../capnproto/base.capnp"; + +struct CheckReplaysFor { + reports @0 :List(Base.ReportId); + aggregationJobId @1 :Base.AggregationJobId; +} diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.rs b/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.rs new file mode 100644 index 000000000..166e05fce --- /dev/null +++ b/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.rs @@ -0,0 +1,68 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use crate::{ + capnproto::{ + decode_list, encode_list, usize_to_capnp_len, CapnprotoPayloadDecode, + CapnprotoPayloadEncode, + }, + durable_requests::ObjectIdFrom, + replay_checker_capnp::check_replays_for, +}; +use daphne::messages::{AggregationJobId, ReportId, TaskId, Time}; +use serde::{Deserialize, Serialize}; +use std::{borrow::Cow, collections::HashSet}; + +super::define_do_binding! 
{ + const BINDING = "DAP_REPLAY_CHECK"; + + enum Command { + Check = "/check", + } + + fn name((task_id, epoch, shard): (&'n TaskId, Time, usize)) -> ObjectIdFrom { + ObjectIdFrom::Name(format!("replay-checker/{task_id}/epoch/{epoch}/shard/{shard}")) + } +} + +pub struct Request<'s> { + pub report_ids: Cow<'s, [ReportId]>, + pub aggregation_job_id: AggregationJobId, +} + +impl CapnprotoPayloadEncode for Request<'_> { + type Builder<'a> = check_replays_for::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + let Self { + report_ids, + aggregation_job_id, + } = self; + encode_list( + report_ids.iter(), + builder + .reborrow() + .init_reports(usize_to_capnp_len(report_ids.len())), + ); + aggregation_job_id.encode_to_builder(builder.init_aggregation_job_id()); + } +} + +impl CapnprotoPayloadDecode for Request<'static> { + type Reader<'a> = check_replays_for::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + Ok(Self { + report_ids: decode_list::(reader.get_reports()?)?, + aggregation_job_id: <_>::decode_from_reader(reader.get_aggregation_job_id()?)?, + }) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Response { + pub duplicates: HashSet, +} diff --git a/crates/daphne-service-utils/src/lib.rs b/crates/daphne-service-utils/src/lib.rs index cc255864b..dcd9018d0 100644 --- a/crates/daphne-service-utils/src/lib.rs +++ b/crates/daphne-service-utils/src/lib.rs @@ -46,8 +46,42 @@ mod aggregation_job_store_capnp { )); } +#[cfg(feature = "durable_requests")] +mod agg_job_response_store_capnp { + #![allow(dead_code)] + #![allow(clippy::pedantic)] + #![allow(clippy::needless_lifetimes)] + include!(concat!( + env!("OUT_DIR"), + "/src/durable_requests/bindings/agg_job_response_store_capnp.rs" + )); +} + +#[cfg(feature = "durable_requests")] +mod aggregate_store_v2_capnp { + #![allow(dead_code)] + #![allow(clippy::pedantic)] + #![allow(clippy::needless_lifetimes)] + 
include!(concat!( + env!("OUT_DIR"), + "/src/durable_requests/bindings/aggregate_store_v2_capnp.rs" + )); +} + +#[cfg(feature = "durable_requests")] +mod replay_checker_capnp { + #![allow(dead_code)] + #![allow(clippy::pedantic)] + #![allow(clippy::needless_lifetimes)] + include!(concat!( + env!("OUT_DIR"), + "/src/durable_requests/bindings/replay_checker_capnp.rs" + )); +} + #[cfg(feature = "compute-offload")] -mod compute_offload_capnp { +#[doc(hidden)] +pub mod compute_offload_capnp { #![allow(dead_code)] #![allow(clippy::pedantic)] #![allow(clippy::needless_lifetimes)] diff --git a/crates/daphne-worker-test/src/durable.rs b/crates/daphne-worker-test/src/durable.rs index 6fb72a456..00fd5247a 100644 --- a/crates/daphne-worker-test/src/durable.rs +++ b/crates/daphne-worker-test/src/durable.rs @@ -18,3 +18,19 @@ instantiate_durable_object! { daphne_worker::tracing_utils::initialize_tracing(env); } } + +instantiate_durable_object! { + struct AggJobResponseStore < durable::AggJobResponseStore; + + fn init_user_data(_state: State, env: Env) { + daphne_worker::tracing_utils::initialize_tracing(env); + } +} + +instantiate_durable_object! 
{ + struct NewAggregateStore < durable::AggregateStoreV2; + + fn init_user_data(_state: State, env: Env) { + daphne_worker::tracing_utils::initialize_tracing(env); + } +} diff --git a/crates/daphne-worker-test/src/lib.rs b/crates/daphne-worker-test/src/lib.rs index a48da5f3d..2d191ed8b 100644 --- a/crates/daphne-worker-test/src/lib.rs +++ b/crates/daphne-worker-test/src/lib.rs @@ -5,7 +5,7 @@ use daphne_worker::{aggregator::App, initialize_tracing}; use futures::stream; use std::convert::Infallible; use tracing::info; -use worker::{event, Env, HttpRequest, ResponseBody}; +use worker::{event, Env, HttpRequest, MessageBatch, ResponseBody}; mod durable; mod utils; @@ -13,6 +13,11 @@ mod utils; #[global_allocator] static CAP: cap::Cap = cap::Cap::new(std::alloc::System, 65_000_000); +fn load_compute_offload_host(env: &worker::Env) -> String { + env.var("COMPUTE_OFFLOAD_HOST") + .map_or_else(|_| "localhost:5000".into(), |t| t.to_string()) +} + #[event(fetch, respond_with_errors)] pub async fn main( req: HttpRequest, @@ -39,9 +44,7 @@ pub async fn main( daphne_worker::storage_proxy::handle_request(req, env, ®istry).await } Some("aggregator") => { - let host = env - .var("COMPUTE_OFFLOAD_HOST") - .map_or_else(|_| "localhost:5000".into(), |t| t.to_string()); + let host = load_compute_offload_host(&env); daphne_worker::aggregator::handle_dap_request( App::new(env, ®istry, None, Box::new(ComputeOffload { host })).unwrap(), @@ -96,3 +99,16 @@ impl daphne_worker::aggregator::ComputeOffload for ComputeOffload { .unwrap()) } } + +#[event(queue)] +pub async fn queue( + batch: MessageBatch<()>, + env: worker::Env, + _ctx: worker::Context, +) -> worker::Result<()> { + let registry = prometheus::Registry::new(); + let host = load_compute_offload_host(&env); + let app = App::new(env, ®istry, None, Box::new(ComputeOffload { host })).unwrap(); + daphne_worker::aggregator::queues::async_aggregate(app, batch).await; + Ok(()) +} diff --git 
a/crates/daphne-worker-test/wrangler.aggregator.toml b/crates/daphne-worker-test/wrangler.aggregator.toml index 98e0b37ce..fbc3269e5 100644 --- a/crates/daphne-worker-test/wrangler.aggregator.toml +++ b/crates/daphne-worker-test/wrangler.aggregator.toml @@ -68,8 +68,17 @@ bindings = [ { name = "DAP_AGGREGATE_STORE", class_name = "AggregateStore" }, { name = "DAP_TEST_STATE_CLEANER", class_name = "TestStateCleaner" }, { name = "AGGREGATION_JOB_STORE", class_name = "AggregationJobStore" }, + { name = "AGGREGATE_JOB_RESULT_STORE", class_name = "AggregationJobStore" }, + { name = "AGGREGATE_STORE", class_name = "NewAggregateStore" }, ] +[[env.helper.queues.producers]] +queue = "async-aggregation-queue" +binding = "ASYNC_AGGREGATION_QUEUE" + +[[env.helper.queues.consumers]] +queue = "async-aggregation-queue" +max_retries = 10 [[env.helper.kv_namespaces]] binding = "DAP_CONFIG" @@ -130,6 +139,8 @@ bindings = [ { name = "DAP_AGGREGATE_STORE", class_name = "AggregateStore" }, { name = "DAP_TEST_STATE_CLEANER", class_name = "TestStateCleaner" }, { name = "AGGREGATION_JOB_STORE", class_name = "AggregationJobStore" }, + { name = "AGGREGATE_JOB_RESULT_STORE", class_name = "AggregationJobStore" }, + { name = "AGGREGATE_STORE", class_name = "NewAggregateStore" }, ] [[env.leader.kv_namespaces]] @@ -157,4 +168,6 @@ new_classes = [ "AggregateStore", "GarbageCollector", "AggregationJobStore", + "AggJobResponseStore", + "NewAggregateStore", ] diff --git a/crates/daphne-worker/Cargo.toml b/crates/daphne-worker/Cargo.toml index da77cbca5..d0014848b 100644 --- a/crates/daphne-worker/Cargo.toml +++ b/crates/daphne-worker/Cargo.toml @@ -19,6 +19,7 @@ crate-type = ["cdylib", "rlib"] async-trait = { workspace = true } axum-extra = { workspace = true, features = ["typed-header"] } bytes.workspace = true +capnp = { workspace = true } chrono = { workspace = true, default-features = false, features = ["clock", "wasmbind"] } constcat.workspace = true daphne = { path = "../daphne", features = 
["prometheus"] } @@ -49,7 +50,8 @@ tracing-core.workspace = true tracing-subscriber = { workspace = true, features = ["env-filter", "json"]} tracing.workspace = true url.workspace = true -worker.workspace = true +wasm-bindgen.workspace = true +worker = { workspace = true , features = ["http", "queue"] } [dependencies.axum] workspace = true @@ -67,6 +69,9 @@ reqwest.workspace = true # used in doc tests tokio.workspace = true webpki.workspace = true +[build-dependencies] +capnpc = { workspace = true } + [features] test-utils = ["daphne-service-utils/test-utils"] diff --git a/crates/daphne-worker/build.rs b/crates/daphne-worker/build.rs new file mode 100644 index 000000000..77c603c87 --- /dev/null +++ b/crates/daphne-worker/build.rs @@ -0,0 +1,10 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +fn main() { + ::capnpc::CompilerCommand::new() + .import_path("../daphne-service-utils/src") + .file("./src/aggregator/queues/queue_messages.capnp") + .run() + .expect("compiling schema"); +} diff --git a/crates/daphne-worker/src/aggregator/mod.rs b/crates/daphne-worker/src/aggregator/mod.rs index 1e0254112..063c7e682 100644 --- a/crates/daphne-worker/src/aggregator/mod.rs +++ b/crates/daphne-worker/src/aggregator/mod.rs @@ -3,6 +3,7 @@ mod config; mod metrics; +pub mod queues; mod roles; mod router; @@ -31,6 +32,7 @@ use router::DaphneService; use std::sync::{Arc, LazyLock, Mutex}; use worker::send::SendWrapper; +use queues::Queue; pub use router::handle_dap_request; #[async_trait::async_trait(?Send)] @@ -194,4 +196,12 @@ impl App { fn bearer_tokens(&self) -> BearerTokens<'_> { BearerTokens::from(Kv::new(&self.env, &self.kv_state)) } + + fn async_aggregation_queue(&self) -> Queue { + Queue::from( + self.env + .get_binding::("ASYNC_AGGREGATION_QUEUE") + .unwrap(), + ) + } } diff --git a/crates/daphne-worker/src/aggregator/queues/async_aggregator.rs b/crates/daphne-worker/src/aggregator/queues/async_aggregator.rs new 
file mode 100644 index 000000000..99eab38ba --- /dev/null +++ b/crates/daphne-worker/src/aggregator/queues/async_aggregator.rs @@ -0,0 +1,260 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use crate::{ + aggregator::App, + queue_messages_capnp, + storage::{self, kv, Do}, +}; +use daphne::{ + messages::{AggregationJobId, PartialBatchSelector, ReportId, ReportMetadata, TaskId, Time}, + roles::helper::handle_agg_job::ToInitializedReportsTransition, + DapVersion, +}; +use daphne_service_utils::{ + capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadDecodeExt, CapnprotoPayloadEncode}, + compute_offload, + durable_requests::bindings::{ + agg_job_response_store, aggregate_store_v2, + replay_checker::{self, Command}, + }, +}; +use futures::{stream::FuturesUnordered, StreamExt, TryStreamExt}; +use std::{ + collections::{HashMap, HashSet}, + num::NonZeroUsize, +}; +use worker::{MessageBatch, MessageExt, RawMessage}; + +fn deserialize(message: &RawMessage) -> worker::Result { + let buf: worker::js_sys::Uint8Array = message.body().into(); + T::decode_from_bytes(&buf.to_vec()).map_err(|e| worker::Error::RustError(e.to_string())) +} + +pub struct AsyncAggregationMessage<'s> { + pub version: DapVersion, + pub part_batch_sel: PartialBatchSelector, + pub agg_job_id: AggregationJobId, + pub initialize_reports: compute_offload::InitializeReports<'s>, +} + +impl CapnprotoPayloadEncode for AsyncAggregationMessage<'_> { + type Builder<'a> = queue_messages_capnp::async_aggregation_message::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + builder.set_version(self.version.into()); + self.part_batch_sel + .encode_to_builder(builder.reborrow().init_partial_batch_selector()); + self.agg_job_id + .encode_to_builder(builder.reborrow().init_aggregation_job_id()); + self.initialize_reports + .encode_to_builder(builder.reborrow().init_initialize_reports()); + } +} + +impl CapnprotoPayloadDecode for 
AsyncAggregationMessage<'static> { + type Reader<'a> = queue_messages_capnp::async_aggregation_message::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + Ok(Self { + version: reader.get_version()?.into(), + agg_job_id: <_>::decode_from_reader(reader.get_aggregation_job_id()?)?, + part_batch_sel: <_>::decode_from_reader(reader.get_partial_batch_selector()?)?, + initialize_reports: <_>::decode_from_reader(reader.get_initialize_reports()?)?, + }) + } +} + +async fn shard_reports<'i>( + durable: Do<'_>, + task_id: &TaskId, + time_precision: Time, + agg_job_id: AggregationJobId, + reports: impl Iterator, +) -> Result, storage::Error> { + let mut shards = HashMap::<_, Vec<_>>::new(); + for r in reports { + let epoch = r.time - (r.time % time_precision); + let shard = r.id.shard(NonZeroUsize::new(1024).unwrap()); + shards.entry((epoch, shard)).or_default().push(r.id); + } + + futures::stream::iter(shards) + .map(|((epoch, shard), report_ids)| async move { + durable + .with_retry() + .request(Command::Check, (task_id, epoch, shard)) + .encode(&replay_checker::Request { + report_ids: report_ids.into(), + aggregation_job_id: agg_job_id, + }) + .send::() + .await + .map(|r| r.duplicates) + }) + .buffer_unordered(6) + .try_fold(HashSet::new(), |mut acc, dups| async move { + acc.extend(dups); + Ok(acc) + }) + .await +} + +macro_rules! bail { + (retry $m:ident, err = $error:expr, $msg:literal) => {{ + $m.retry(); + bail!(err = $error, $msg) + }}; + (err = $error:expr, $msg:literal) => {{ + tracing::error!(error = ?$error, $msg); + return; + }} +} + +/// Perform an aggregation job. +/// +/// ## Note +/// There is a worst case scenario that this handler can't deal with. 
+/// +/// Messages will be replayed if they fail, but the workers runtime will eventually give up on a +/// message after it's been retried a bunch of times, in that case the helper will never respond +/// positively to the poll request, possibly leaving the leader in an infinite loop state. The +/// leader can resubmit the work as many times as it wants to get out of this situation, but +/// implementers of the leader must be made aware of this. +// +// ----- +// +// All of the IO in this function is idempotent. They can be spotted by looking at the `.await` +// expressions. The explanation is as follows: +// 1. Getting a task config. No writes performed. +// 2. Initializing the reports. Stateless. +// 3. The same aggregation job may replay it's own reports, this means replay checking is +// idempotent. See the [ReplayChecker] durable object for more details. +// 4. Storing the aggregate share does not merge with any other aggregate share and simply +// replaces the previous one, which will be identical. +// 5. Storing the aggregate response simply overwrites the previous response, which will be +// identical. +// +pub async fn async_aggregate_one(app: &App, message: RawMessage) { + let AsyncAggregationMessage { + version, + agg_job_id, + part_batch_sel, + initialize_reports: aggregate_message, + } = match deserialize(&message) { + Ok(m) => m, + Err(e) => bail!(err = e, "failed to deserialize replay queue message"), + }; + + // 1. + let task_config = match app + .kv() + .get_cloned::( + &aggregate_message.task_id, + &kv::KvGetOptions { + cache_not_found: true, + }, + ) + .await + { + Ok(Some(t)) => t, + Ok(None) => return, + Err(e) => bail!(retry message, err = e, "failed to fetch task config from kv"), + }; + + // 2. 
+ let initialized_reports = match app + .compute_offload + .compute::<_, compute_offload::InitializedReports>( + "/compute_offload/initialize_reports", + &aggregate_message, + ) + .await + { + Ok(init) => init, + Err(e) => bail!(retry message, err = e, "failed to initialize reports"), + }; + + let time_precision = task_config.time_precision; + let state_machine = ToInitializedReportsTransition { + task_id: aggregate_message.task_id, + part_batch_sel, + task_config, + } + .with_initialized_reports(initialized_reports.reports); + + // 3. + let state_machine = match state_machine + .check_for_replays(|report_ids| { + let report_ids = report_ids.cloned().collect::>(); + shard_reports( + app.durable(), + &aggregate_message.task_id, + time_precision, + agg_job_id, + report_ids.into_iter(), + ) + }) + .await + { + Ok(st) => st, + Err(e) => bail!(retry message, err = e, "failed to check replays"), + }; + + let (span, agg_job_response) = match state_machine.finish() { + Ok(output) => output, + // this error is always caused by a bug in the code, it's not recoverable + Err(e) => bail!(err = e, "failed to finish aggregation"), + }; + + for (bucket, (share, _)) in span { + let request = aggregate_store_v2::PutRequest { + agg_job_id, + agg_share_delta: share, + }; + // 4. + let response = app + .durable() + .with_retry() + .request( + aggregate_store_v2::Command::Put, + (version, &aggregate_message.task_id, &bucket), + ) + .encode(&request) + .send::<()>() + .await; + match response { + Ok(()) => {} + Err(e) => bail!(retry message, err = e, "failed to store aggregate share"), + } + } + + // 5. 
+ let result = app + .durable() + .with_retry() + .request( + agg_job_response_store::Command::Put, + (version, &aggregate_message.task_id, &agg_job_id), + ) + .encode(&agg_job_response) + .send::<()>() + .await; + + match result { + Ok(()) => {} + Err(e) => bail!(retry message, err = e, "failed to store aggregation response"), + } +} + +pub async fn async_aggregate(app: App, message_batch: MessageBatch<()>) { + message_batch + .raw_iter() + .map(|m| async_aggregate_one(&app, m)) + .collect::>() + .collect::<()>() + .await; +} diff --git a/crates/daphne-worker/src/aggregator/queues/mod.rs b/crates/daphne-worker/src/aggregator/queues/mod.rs new file mode 100644 index 000000000..97cb914f5 --- /dev/null +++ b/crates/daphne-worker/src/aggregator/queues/mod.rs @@ -0,0 +1,37 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +mod async_aggregator; + +pub use async_aggregator::{async_aggregate, AsyncAggregationMessage}; +use daphne_service_utils::capnproto::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _}; +use std::marker::PhantomData; +use worker::RawMessageBuilder; + +pub struct Queue { + queue: worker::Queue, + _message_type: PhantomData, +} + +impl Queue { + pub async fn send(&self, message: &T) -> worker::Result<()> { + let bytes = worker::js_sys::Uint8Array::from(message.encode_to_bytes().as_slice()); + self.queue + .send_raw( + RawMessageBuilder::new(bytes.into()) + .build_with_content_type(worker::QueueContentType::V8), + ) + .await?; + + Ok(()) + } +} + +impl From for Queue { + fn from(queue: worker::Queue) -> Self { + Self { + queue, + _message_type: PhantomData, + } + } +} diff --git a/crates/daphne-worker/src/aggregator/queues/queue_messages.capnp b/crates/daphne-worker/src/aggregator/queues/queue_messages.capnp new file mode 100644 index 000000000..2553bb8a5 --- /dev/null +++ b/crates/daphne-worker/src/aggregator/queues/queue_messages.capnp @@ -0,0 +1,15 @@ +# Copyright (c) 2025 Cloudflare, 
Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +@0x8240fbeac47031a3; + +using Base = import "/capnproto/base.capnp"; +using ComputeOffload = import "/compute_offload/compute_offload.capnp"; + +struct AsyncAggregationMessage { + version @0 :Base.DapVersion; + reports @1 :List(Base.ReportId); + aggregationJobId @2 :Base.AggregationJobId; + partialBatchSelector @3 :Base.PartialBatchSelector; + initializeReports @4 :ComputeOffload.InitializeReports; +} diff --git a/crates/daphne-worker/src/aggregator/roles/aggregator.rs b/crates/daphne-worker/src/aggregator/roles/aggregator.rs index cbe3ae3bb..3e3183f6d 100644 --- a/crates/daphne-worker/src/aggregator/roles/aggregator.rs +++ b/crates/daphne-worker/src/aggregator/roles/aggregator.rs @@ -21,9 +21,10 @@ use daphne::{ DapVersion, }; use daphne_service_utils::durable_requests::bindings::{ - self, AggregateStoreMergeOptions, AggregateStoreMergeReq, AggregateStoreMergeResp, + self, aggregate_store_v2, AggregateStoreMergeOptions, AggregateStoreMergeReq, + AggregateStoreMergeResp, }; -use futures::{future::try_join_all, StreamExt as _, TryFutureExt as _, TryStreamExt as _}; +use futures::{future::try_join_all, StreamExt, TryFutureExt as _, TryStreamExt}; use mappable_rc::Marc; use std::{num::NonZeroUsize, ops::Range}; use worker::send::SendFuture; @@ -75,40 +76,18 @@ impl DapAggregator for App { .await } + // this implementation is hardcoded to for the helper #[tracing::instrument(skip(self))] async fn get_agg_share( &self, + version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result { - let task_config = self - .get_task_config_for(task_id) - .await? - .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { - task_id: *task_id, - }))?; - - let durable = self.durable(); - let mut requests = Vec::new(); - for bucket in task_config.as_ref().batch_span_for_sel(batch_sel)? 
{ - requests.push( - durable - .request( - bindings::AggregateStore::Get, - (task_config.as_ref().version, task_id, &bucket), - ) - .send(), - ); + match version { + DapVersion::Latest => self.get_agg_share_draft_latest(task_id, batch_sel).await, + DapVersion::Draft09 => self.get_agg_share_draft_09(task_id, batch_sel).await, } - let responses: Vec = try_join_all(requests) - .await - .map_err(|e| fatal_error!(err = ?e, "failed to get agg shares from durable objects"))?; - let mut agg_share = DapAggregateShare::default(); - for agg_share_delta in responses { - agg_share.merge(agg_share_delta)?; - } - - Ok(agg_share) } #[tracing::instrument(skip(self))] @@ -434,3 +413,74 @@ impl hpke::HpkeProvider for App { )) } } + +impl App { + async fn get_agg_share_draft_09( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result { + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + + let durable = self.durable(); + let mut requests = Vec::new(); + for bucket in task_config.as_ref().batch_span_for_sel(batch_sel)? { + requests.push( + durable + .request( + bindings::AggregateStore::Get, + (task_config.as_ref().version, task_id, &bucket), + ) + .send(), + ); + } + let responses: Vec = try_join_all(requests) + .await + .map_err(|e| fatal_error!(err = ?e, "failed to get agg shares from durable objects"))?; + let mut agg_share = DapAggregateShare::default(); + for agg_share_delta in responses { + agg_share.merge(agg_share_delta)?; + } + + Ok(agg_share) + } + + async fn get_agg_share_draft_latest( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result { + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + + let durable = self.durable(); + let agg_share = futures::stream::iter(task_config.as_ref().batch_span_for_sel(batch_sel)?) 
+ .map(|bucket| { + durable + .request( + aggregate_store_v2::Command::Get, + (task_config.as_ref().version, task_id, &bucket), + ) + .send() + .map_err( + |e| fatal_error!(err = ?e, "failed to get agg shares from durable objects"), + ) + }) + .buffer_unordered(6) + .try_fold(DapAggregateShare::default(), |mut acc, share| async move { + acc.merge(share).map(|()| acc) + }) + .await?; + + Ok(agg_share) + } +} diff --git a/crates/daphne-worker/src/aggregator/roles/helper.rs b/crates/daphne-worker/src/aggregator/roles/helper.rs index 42d836d2a..8c6d19dfa 100644 --- a/crates/daphne-worker/src/aggregator/roles/helper.rs +++ b/crates/daphne-worker/src/aggregator/roles/helper.rs @@ -5,11 +5,16 @@ use crate::aggregator::App; use daphne::{ error::DapAbort, fatal_error, - messages::{request::AggregationJobRequestHash, AggregationJobId, TaskId}, + messages::{ + request::AggregationJobRequestHash, AggregationJobId, AggregationJobResp, + ReadyAggregationJobResp, TaskId, + }, roles::DapHelper, DapError, DapVersion, }; -use daphne_service_utils::durable_requests::bindings::aggregation_job_store; +use daphne_service_utils::durable_requests::bindings::{ + agg_job_response_store, aggregation_job_store, +}; use std::borrow::Cow; #[axum::async_trait] @@ -41,4 +46,46 @@ impl DapHelper for App { ), } } + + async fn poll_aggregated( + &self, + version: DapVersion, + task_id: &TaskId, + agg_job_id: &AggregationJobId, + ) -> Result { + let valid_agg_job_id = self + .durable() + .with_retry() + .request( + aggregation_job_store::Command::ContainsJob, + (version, task_id), + ) + .encode(agg_job_id) + .send::() + .await + .map_err(|e| fatal_error!(err = ?e, "failed to query the validity of the aggregation job id"))?; + + if !valid_agg_job_id { + return Err(DapError::Abort(DapAbort::UnrecognizedAggregationJob { + task_id: *task_id, + agg_job_id: *agg_job_id, + })); + } + + let response = self + .durable() + .with_retry() + .request( + agg_job_response_store::Command::Get, + (version, task_id, 
agg_job_id), + ) + .send::>() + .await + .map_err(|e| fatal_error!(err = ?e, "failed to poll for aggregation job response"))?; + + match response { + Some(ready) => Ok(ready.into()), + None => Ok(AggregationJobResp::Processing), + } + } } diff --git a/crates/daphne-worker/src/aggregator/router/helper.rs b/crates/daphne-worker/src/aggregator/router/helper.rs index ab74c6af1..adb9c0b1d 100644 --- a/crates/daphne-worker/src/aggregator/router/helper.rs +++ b/crates/daphne-worker/src/aggregator/router/helper.rs @@ -5,7 +5,7 @@ use super::{ super::roles::fetch_replay_protection_override, extractor::dap_sender::FROM_LEADER, App, AxumDapResponse, DapRequestExtractor, DaphneService, }; -use crate::elapsed; +use crate::{aggregator::queues, elapsed}; use axum::{ extract::State, routing::{post, put}, @@ -13,9 +13,9 @@ use axum::{ use daphne::{ fatal_error, hpke::HpkeProvider, - messages::{request::HashedAggregationJobReq, AggregateShareReq}, + messages::{request::HashedAggregationJobReq, AggregateShareReq, AggregationJobResp}, roles::{helper, DapAggregator, DapHelper}, - DapError, DapResponse, + DapError, DapResponse, DapVersion, }; use daphne_service_utils::compute_offload; use http::StatusCode; @@ -34,13 +34,78 @@ pub(super) fn add_helper_routes(router: super::Router) -> super::Router>, + req: DapRequestExtractor, +) -> AxumDapResponse { + match req.0.version { + DapVersion::Draft09 => agg_job_draft9(state, req).await, + DapVersion::Latest => agg_job_draft_latest(state, req).await, + } +} + +async fn agg_job_draft_latest( + State(app): State>, + DapRequestExtractor(req): DapRequestExtractor, +) -> AxumDapResponse { + let now = worker::Date::now(); + let version = req.version; + + let queue_result = async { + let (transition, req) = helper::handle_agg_job::start(req) + .check_aggregation_job_legality(&*app) + .await? + .resolve_task_config(&*app) + .await? 
+ .into_parts(fetch_replay_protection_override(app.kv()).await)?; + + let hpke_receiver_configs = app.get_hpke_receiver_configs(req.version).await?; + + app.async_aggregation_queue() + .send(&queues::AsyncAggregationMessage { + version, + part_batch_sel: transition.part_batch_sel, + agg_job_id: req.resource_id, + initialize_reports: compute_offload::InitializeReports { + hpke_keys: Cow::Borrowed(hpke_receiver_configs.as_ref()), + valid_report_range: app.valid_report_time_range(), + task_id: req.task_id, + task_config: (&transition.task_config).into(), + agg_param: Cow::Borrowed(&req.payload.agg_param), + prep_inits: req.payload.prep_inits, + }, + }) + .await + .map_err(|e| fatal_error!(err = ?e, "failed to queue response")) + } + .await; + + let elapsed = elapsed(&now); + + app.server_metrics().aggregate_job_latency(elapsed); + + AxumDapResponse::from_result_with_success_code( + queue_result.and_then(|()| { + Ok(DapResponse { + version, + media_type: daphne::constants::DapMediaType::AggregationJobResp, + payload: AggregationJobResp::Processing + .get_encoded_with_param(&version) + .map_err(DapError::encoding)?, + }) + }), + app.server_metrics(), + StatusCode::CREATED, + ) +} + +async fn agg_job_draft9( State(app): State>, DapRequestExtractor(req): DapRequestExtractor, ) -> AxumDapResponse { diff --git a/crates/daphne-worker/src/durable/agg_job_response_store.rs b/crates/daphne-worker/src/durable/agg_job_response_store.rs new file mode 100644 index 000000000..9175d8828 --- /dev/null +++ b/crates/daphne-worker/src/durable/agg_job_response_store.rs @@ -0,0 +1,118 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Durable Object for storing the result of an aggregation job. 
+ +use super::{req_parse, GcDurableObject}; +use crate::int_err; +use daphne::messages::ReadyAggregationJobResp; +use daphne_service_utils::durable_requests::bindings::{ + self, agg_job_response_store, DurableMethod as _, +}; +use std::{sync::OnceLock, time::Duration}; +use worker::{js_sys, Env, Request, Response, Result, ScheduledTime, State}; + +const AGGREGATE_RESPONSE_CHUNK_KEY_PREFIX: &str = "dap_agg_response_chunk"; + +super::mk_durable_object! { + /// Where the aggregate share is stored. For the binding name see its + /// [`BINDING`](bindings::AggregateStore::BINDING) + struct AggJobResponseStore { + state: State, + env: Env, + agg_job_resp: Option, + } +} + +impl AggJobResponseStore { + async fn get_agg_job_response(&mut self) -> Result> { + let agg_job_resp = if let Some(agg_job_resp) = self.agg_job_resp.take() { + agg_job_resp + } else { + let Some(agg_job_resp) = self + .load_chuncked_value(AGGREGATE_RESPONSE_CHUNK_KEY_PREFIX) + .await? + else { + return Ok(None); + }; + agg_job_resp + }; + + self.agg_job_resp = Some(agg_job_resp); + + Ok(self.agg_job_resp.as_ref()) + } + + fn put_agg_job_response(&mut self, resp: ReadyAggregationJobResp) -> Result { + let obj = self.serialize_chunked_value(AGGREGATE_RESPONSE_CHUNK_KEY_PREFIX, &resp, None)?; + self.agg_job_resp = Some(resp); + Ok(obj) + } +} + +impl GcDurableObject for AggJobResponseStore { + type DurableMethod = bindings::AggregateStore; + + fn with_state_and_env(state: State, env: Env) -> Self { + Self { + state, + env, + agg_job_resp: None, + } + } + + async fn handle(&mut self, mut req: Request) -> Result { + match agg_job_response_store::Command::try_from_uri(&req.path()) { + // Store an aggregate share and aggregation job response. 
+ // + // Idempotent + // Input: `agg_share_dellta: agg_job_result_store::FinishRequest` + // Output: `agg_job_result_store::FinishResponse` + Some(agg_job_response_store::Command::Put) => { + let response = req_parse::(&mut req).await?; + + self.state + .storage() + .put_multiple_raw(self.put_agg_job_response(response)?) + .await?; + + Response::from_json(&()) + } + + // Get the AggregationJobResp + // + // Idempotent + // Output: `Option` + Some(agg_job_response_store::Command::Get) => { + let response = self.get_agg_job_response().await?; + Response::from_json(&response) + } + + None => Err(int_err(format!( + "AggregatesStore: unexpected request: method={:?}; path={:?}", + req.method(), + req.path() + ))), + } + } + + fn should_cleanup_at(&self) -> Option { + const VAR_NAME: &str = "DAP_DURABLE_AGGREGATE_STORE_GC_AFTER_SECS"; + static SELF_DELETE_AFTER: OnceLock = OnceLock::new(); + + let duration = SELF_DELETE_AFTER.get_or_init(|| { + Duration::from_secs( + self.env + .var(VAR_NAME) + .map(|v| { + v.to_string().parse().unwrap_or_else(|e| { + panic!("{VAR_NAME} could not be parsed as a number of seconds: {e}") + }) + }) + .unwrap_or(60 * 60 * 24 * 7), // one week + ) + }); + + Some(ScheduledTime::from(*duration)) + } +} diff --git a/crates/daphne-worker/src/durable/aggregate_store_v2.rs b/crates/daphne-worker/src/durable/aggregate_store_v2.rs new file mode 100644 index 000000000..44ec4e01f --- /dev/null +++ b/crates/daphne-worker/src/durable/aggregate_store_v2.rs @@ -0,0 +1,166 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Durable Object for storing the result of an aggregation job. 
+ +use super::{req_parse, GcDurableObject}; +use crate::int_err; +use daphne::DapAggregateShare; +use daphne_service_utils::durable_requests::bindings::{ + self, aggregate_store_v2, DurableMethod as _, +}; +use futures::{StreamExt, TryStreamExt}; +use std::{sync::OnceLock, time::Duration}; +use worker::{js_sys, wasm_bindgen::JsValue, Env, Request, Response, Result, ScheduledTime, State}; + +const AGGREGATION_JOB_IDS_KEY: &str = "agg-job-ids"; + +super::mk_durable_object! { + /// Where the aggregate share is stored. For the binding name see its + /// [`BINDING`](bindings::AggregateStore::BINDING) + struct AggregateStoreV2 { + state: State, + env: Env, + collected: Option, + } +} + +impl AggregateStoreV2 { + async fn get_agg_share(&self, agg_job_id: &str) -> Result> { + self.load_chuncked_value(agg_job_id).await + } + + fn put_agg_share( + &mut self, + agg_job_id: &str, + share: DapAggregateShare, + obj: js_sys::Object, + ) -> Result { + self.serialize_chunked_value(agg_job_id, &share, obj) + } + + async fn is_collected(&mut self) -> Result { + Ok(if let Some(collected) = self.collected { + collected + } else { + let collected = self.get_or_default("collected").await?; + self.collected = Some(collected); + collected + }) + } +} + +impl GcDurableObject for AggregateStoreV2 { + type DurableMethod = bindings::AggregateStore; + + fn with_state_and_env(state: State, env: Env) -> Self { + Self { + state, + env, + collected: None, + } + } + + async fn handle(&mut self, mut req: Request) -> Result { + match aggregate_store_v2::Command::try_from_uri(&req.path()) { + // Store an aggregate share and aggregation job response. 
+ // + // Idempotent + // Input: `agg_share_dellta: agg_job_result_store::FinishRequest` + // Output: `agg_job_result_store::FinishResponse` + Some(aggregate_store_v2::Command::Put) => { + let aggregate_store_v2::PutRequest { + agg_job_id, + agg_share_delta, + } = req_parse(&mut req).await?; + + let mut agg_job_ids = self + .get_or_default::>(AGGREGATION_JOB_IDS_KEY) + .await?; + + let chunks_map = js_sys::Object::default(); + + let agg_job_id = agg_job_id.to_string(); + + let chunks_map = self.put_agg_share(&agg_job_id, agg_share_delta, chunks_map)?; + + agg_job_ids.push(agg_job_id); + js_sys::Reflect::set( + &chunks_map, + &JsValue::from_str(AGGREGATION_JOB_IDS_KEY), + &serde_wasm_bindgen::to_value(&agg_job_ids)?, + )?; + + self.state.storage().put_multiple_raw(chunks_map).await?; + + Response::from_json(&()) + } + + // Get the current aggregate share. + // + // Idempotent + // Output: `DapAggregateShare` + Some(aggregate_store_v2::Command::Get) => { + let ids = self + .get_or_default::>(AGGREGATION_JOB_IDS_KEY) + .await?; + let this = &self; + let share = futures::stream::iter(ids) + .map(|id| async move { this.get_agg_share(&id).await }) + .buffer_unordered(8) + .filter_map(|share| async move { share.transpose() }) + .try_fold(DapAggregateShare::default(), |mut acc, share| async move { + acc.merge(share) + .map(|()| acc) + .map_err(|e| worker::Error::RustError(e.to_string())) + }) + .await?; + Response::from_json(&share) + } + + // Mark this bucket as collected. + // + // Idempotent + // Output: `()` + Some(aggregate_store_v2::Command::MarkCollected) => { + self.state.storage().put("collected", true).await?; + self.collected = Some(true); + Response::from_json(&()) + } + + // Get the value of the flag indicating whether this bucket has been collected. + // + // Idempotent + // Output: `bool` + Some(aggregate_store_v2::Command::CheckCollected) => { + Response::from_json(&self.is_collected().await?) 
+ } + + None => Err(int_err(format!( + "AggregatesStore: unexpected request: method={:?}; path={:?}", + req.method(), + req.path() + ))), + } + } + + fn should_cleanup_at(&self) -> Option { + const VAR_NAME: &str = "DAP_DURABLE_AGGREGATE_STORE_GC_AFTER_SECS"; + static SELF_DELETE_AFTER: OnceLock = OnceLock::new(); + + let duration = SELF_DELETE_AFTER.get_or_init(|| { + Duration::from_secs( + self.env + .var(VAR_NAME) + .map(|v| { + v.to_string().parse().unwrap_or_else(|e| { + panic!("{VAR_NAME} could not be parsed as a number of seconds: {e}") + }) + }) + .unwrap_or(60 * 60 * 24 * 7), // one week + ) + }); + + Some(ScheduledTime::from(*duration)) + } +} diff --git a/crates/daphne-worker/src/durable/aggregation_job_store.rs b/crates/daphne-worker/src/durable/aggregation_job_store.rs index d2a0d3e54..93b255098 100644 --- a/crates/daphne-worker/src/durable/aggregation_job_store.rs +++ b/crates/daphne-worker/src/durable/aggregation_job_store.rs @@ -56,6 +56,11 @@ impl GcDurableObject for AggregationJobStore { Response::from_json(&response) } + Some(aggregation_job_store::Command::ContainsJob) => { + let agg_job_id = req_parse::(&mut req).await?; + let has = self.has(&agg_job_id.to_string()).await?; + Response::from_json(&has) + } Some(aggregation_job_store::Command::ListJobIds) => { Response::from_json(&self.load_seen_agg_job_ids().await?) } diff --git a/crates/daphne-worker/src/durable/mod.rs b/crates/daphne-worker/src/durable/mod.rs index b719e8344..a9cd1fb50 100644 --- a/crates/daphne-worker/src/durable/mod.rs +++ b/crates/daphne-worker/src/durable/mod.rs @@ -19,8 +19,11 @@ //! To know what values to provide to the `name` and `class_name` fields see each type exported by //! this module as well as the [`instantiate_durable_object`] macro, respectively. 
+pub(crate) mod agg_job_response_store; pub(crate) mod aggregate_store; +pub(crate) mod aggregate_store_v2; pub(crate) mod aggregation_job_store; +pub(crate) mod replay_checker; #[cfg(feature = "test-utils")] pub(crate) mod test_state_cleaner; @@ -33,8 +36,11 @@ use serde::{Deserialize, Serialize}; use tracing::info_span; use worker::{Env, Error, Request, Response, Result, ScheduledTime, State}; +pub use agg_job_response_store::AggJobResponseStore; pub use aggregate_store::AggregateStore; +pub use aggregate_store_v2::AggregateStoreV2; pub use aggregation_job_store::AggregationJobStore; +pub use replay_checker::ReportIdReplayCheck; const ERR_NO_VALUE: &str = "No such value in storage."; @@ -164,12 +170,93 @@ macro_rules! mk_durable_object { } #[allow(dead_code)] + /// Set a key/value pair unless the key already exists. If the key exists, then return the current + /// value. Otherwise return nothing. async fn put_if_not_exists(&self, key: &str, val: &T) -> ::worker::Result> where T: ::serde::de::DeserializeOwned + ::serde::Serialize, { $crate::durable::state_set_if_not_exists(&self.state, key, val).await } + + #[allow(dead_code)] + async fn has(&self, key: &str) -> ::worker::Result { + $crate::durable::state_contains_key(&self.state, key).await + } + + #[allow(dead_code)] + async fn load_chuncked_value(&self, prefix: &str) -> ::worker::Result> + where + T: daphne_service_utils::capnproto::CapnprotoPayloadDecode, + { + let Some(count) = self.get::(&format!("{prefix}_count")).await? 
else { + return Ok(None); + }; + + let keys = &$crate::durable::calculate_chunk_keys(count, prefix); + let map = self.state.storage().get_multiple(keys.clone()).await?; + let bytes = keys + .iter() + .map(|k| wasm_bindgen::JsValue::from_str(k.as_ref())) + .filter(|k| map.has(k)) + .map(|k| map.get(&k)) + .map(|js_v| { + serde_wasm_bindgen::from_value::>(js_v).expect("expect an array of bytes") + }) + .reduce(|mut buf, v| { + buf.extend_from_slice(&v); + buf + }); + + let Some(bytes) = bytes else { + return Ok(None); + }; + + ::decode_from_bytes(&bytes) + .map(Some) + .map_err(|e| worker::Error::RustError(e.to_string())) + } + + #[allow(dead_code)] + fn serialize_chunked_value( + &self, + prefix: &str, + value: &T, + object_to_fill: impl Into> + ) -> ::worker::Result + where + T: daphne_service_utils::capnproto::CapnprotoPayloadEncode, + { + let object_to_fill = object_to_fill.into().unwrap_or_default(); + use daphne_service_utils::capnproto::CapnprotoPayloadEncodeExt; + let bytes = value.encode_to_bytes(); + let chunk_keys = $crate::durable::chunk_keys_for(&bytes, prefix); + let mut base_idx = 0; + for key in &chunk_keys { + let end = usize::min(base_idx + $crate::durable::MAX_CHUNK_SIZE, bytes.len()); + let chunk = &bytes[base_idx..end]; + + // unwrap cannot fail because chunk len is bounded by MAX_CHUNK_SIZE which is smaller than + // u32::MAX + let value = worker::js_sys::Uint8Array::new_with_length(u32::try_from(chunk.len()).unwrap()); + value.copy_from(chunk); + + worker::js_sys::Reflect::set( + &object_to_fill, + &wasm_bindgen::JsValue::from_str(key.as_ref()), + &value.into(), + )?; + + base_idx = end; + } + worker::js_sys::Reflect::set( + &object_to_fill, + &wasm_bindgen::JsValue::from_str(&format!("{prefix}_count")), + &chunk_keys.len().into(), + )?; + + Ok(object_to_fill) + } } }; } @@ -220,6 +307,20 @@ pub(crate) async fn state_set_if_not_exists Deserialize<'a> + Seriali Ok(None) } +pub(crate) async fn state_contains_key(state: &State, key: &str) -> 
Result { + struct DevNull; + impl<'de> Deserialize<'de> for DevNull { + fn deserialize(_: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + Ok(Self) + } + } + + Ok(state_get::(state, key).await?.is_some()) +} + async fn req_parse(req: &mut Request) -> Result where T: CapnprotoPayloadDecode, @@ -234,6 +335,31 @@ fn create_span_from_request(req: &Request) -> tracing::Span { span } +/// The maximum chunk size as documented in +/// [the worker docs](https://developers.cloudflare.com/durable-objects/platform/limits/) +const MAX_CHUNK_SIZE: usize = 128_000; + +fn chunk_keys_for(bytes: &[u8], prefix: &str) -> Vec { + // stolen from + // https://doc.rust-lang.org/std/primitive.usize.html#method.div_ceil + // because it's nightly only + fn div_ceil(lhs: usize, rhs: usize) -> usize { + let d = lhs / rhs; + let r = lhs % rhs; + if r > 0 && rhs > 0 { + d + 1 + } else { + d + } + } + + calculate_chunk_keys(div_ceil(bytes.len(), MAX_CHUNK_SIZE), prefix) +} + +fn calculate_chunk_keys(count: usize, prefix: &str) -> Vec { + (0..count).map(|i| format!("{prefix}_{i:04}")).collect() +} + /// Instantiate a durable object. /// /// # Syntax diff --git a/crates/daphne-worker/src/durable/replay_checker.rs b/crates/daphne-worker/src/durable/replay_checker.rs new file mode 100644 index 000000000..7b3076f52 --- /dev/null +++ b/crates/daphne-worker/src/durable/replay_checker.rs @@ -0,0 +1,118 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::{req_parse, GcDurableObject}; +use crate::int_err; +use daphne::messages::{AggregationJobId, ReportId}; +use daphne_service_utils::durable_requests::bindings::{self, replay_checker, DurableMethod}; +use std::{ + collections::{HashMap, HashSet}, + iter::zip, + sync::OnceLock, + time::Duration, +}; +use wasm_bindgen::JsValue; +use worker::{js_sys, Env, Request, Response, Result, ScheduledTime, State}; + +super::mk_durable_object! 
{ + /// Where report ids are stored for replay protection. + struct ReportIdReplayCheck { + state: State, + env: Env, + seen: HashMap, + } +} + +impl GcDurableObject for ReportIdReplayCheck { + type DurableMethod = bindings::AggregateStore; + + fn with_state_and_env(state: State, env: Env) -> Self { + Self { + state, + env, + seen: Default::default(), + } + } + + async fn handle(&mut self, mut req: Request) -> Result { + match replay_checker::Command::try_from_uri(&req.path()) { + Some(replay_checker::Command::Check) => { + let replay_checker::Request { + report_ids, + aggregation_job_id, + } = req_parse(&mut req).await?; + + let mut duplicates = HashSet::new(); + + let report_ids_as_string = report_ids + .iter() + .filter(|r| match self.seen.get(r) { + Some(cached_agg_job_id) => { + if *cached_agg_job_id != aggregation_job_id { + duplicates.insert(**r); + } + false // skip checking + } + None => true, // check against disk + }) + .map(ToString::to_string) + .collect::>(); + + let aggregation_job_id_as_str = aggregation_job_id.to_string(); + + let result = self + .state + .storage() + .get_multiple(report_ids_as_string.clone()) + .await?; + + let obj_to_update = js_sys::Object::new(); + for (id, as_str) in zip(report_ids.iter(), &report_ids_as_string) { + self.seen.insert(*id, aggregation_job_id); + + let v = result.get(&JsValue::from_str(as_str)); + if let Some(stored_agg_job_id) = v.as_string() { + if stored_agg_job_id != aggregation_job_id_as_str { + duplicates.insert(*id); + } + } else { + js_sys::Reflect::set( + &obj_to_update, + &JsValue::from_str(as_str), + &JsValue::from_str(aggregation_job_id_as_str.as_ref()), + )?; + } + } + + self.state.storage().put_multiple_raw(obj_to_update).await?; + + Response::from_json(&replay_checker::Response { duplicates }) + } + None => Err(int_err(format!( + "AggregatesStore: unexpected request: method={:?}; path={:?}", + req.method(), + req.path() + ))), + } + } + + fn should_cleanup_at(&self) -> Option { + const VAR_NAME: 
&str = "DO_REPLAY_CHECKER_GC_SECS"; + static SELF_DELETE_AFTER: OnceLock = OnceLock::new(); + + let duration = SELF_DELETE_AFTER.get_or_init(|| { + Duration::from_secs( + self.env + .var(VAR_NAME) + .map(|v| { + v.to_string().parse().unwrap_or_else(|e| { + panic!("{VAR_NAME} could not be parsed as a number of seconds: {e}") + }) + }) + .unwrap_or(60 * 60 * 24 * 7), // one week + ) + }); + + Some(ScheduledTime::from(*duration)) + } +} diff --git a/crates/daphne-worker/src/lib.rs b/crates/daphne-worker/src/lib.rs index 6b8fb25ea..5d7975245 100644 --- a/crates/daphne-worker/src/lib.rs +++ b/crates/daphne-worker/src/lib.rs @@ -30,3 +30,16 @@ pub(crate) fn int_err(s: S) -> Error { pub(crate) fn elapsed(date: &worker::Date) -> Duration { Duration::from_millis(worker::Date::now().as_millis() - date.as_millis()) } + +pub(crate) use daphne_service_utils::base_capnp; +pub(crate) use daphne_service_utils::compute_offload_capnp; + +mod queue_messages_capnp { + #![allow(dead_code)] + #![allow(clippy::pedantic)] + #![allow(clippy::needless_lifetimes)] + include!(concat!( + env!("OUT_DIR"), + "/src/aggregator/queues/queue_messages_capnp.rs" + )); +} diff --git a/crates/daphne/src/lib.rs b/crates/daphne/src/lib.rs index d7499078a..9661b62ac 100644 --- a/crates/daphne/src/lib.rs +++ b/crates/daphne/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause //! This crate implements the core protocol logic for the Distributed Aggregation Protocol @@ -354,7 +354,9 @@ impl IntoIterator for DapAggregateSpan { } impl ReportId { - fn shard(&self, num_shards: NonZeroUsize) -> usize { + /// Deterministically calculate a number between 0 and `num_shards` based on the report id. + /// Useful for sharding datastores. + pub fn shard(&self, num_shards: NonZeroUsize) -> usize { // NOTE This sharding scheme does not evenly distribute reports across all shards.
// // First, the clients are supposed to choose the report ID at random; by finding collisions diff --git a/crates/daphne/src/messages/mod.rs b/crates/daphne/src/messages/mod.rs index 959a96c1e..eac07d9e9 100644 --- a/crates/daphne/src/messages/mod.rs +++ b/crates/daphne/src/messages/mod.rs @@ -938,18 +938,57 @@ impl std::fmt::Display for ReportError { } /// An aggregate response sent from the Helper to the Leader. -#[derive(Debug, PartialEq, Eq, Default, Serialize, Deserialize)] -pub struct AggregationJobResp { +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] +pub enum AggregationJobResp { + Ready { transitions: Vec }, + Processing, +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] +pub struct ReadyAggregationJobResp { pub transitions: Vec, } +impl From for AggregationJobResp { + fn from(value: ReadyAggregationJobResp) -> Self { + Self::Ready { + transitions: value.transitions, + } + } +} + +impl AggregationJobResp { + #[cfg(any(test, feature = "test-utils"))] + #[track_caller] + pub fn unwrap_ready(self) -> ReadyAggregationJobResp { + match self { + Self::Ready { transitions } => ReadyAggregationJobResp { transitions }, + Self::Processing => panic!("unwraped a Processing value"), + } + } +} + impl ParameterizedEncode for AggregationJobResp { fn encode_with_param( &self, version: &DapVersion, bytes: &mut Vec, ) -> Result<(), CodecError> { - encode_u32_items(bytes, version, &self.transitions) + match (self, version) { + (Self::Ready { transitions }, DapVersion::Draft09) => { + encode_u32_items(bytes, version, transitions) + } + (Self::Ready { transitions }, DapVersion::Latest) => { + 1u8.encode(bytes)?; + encode_u32_items(bytes, version, transitions) + } + (Self::Processing, DapVersion::Draft09) => Err(CodecError::Other( + "AggregationJobResp::Processing not supported in 
draft-09".into(), + )), + (Self::Processing, DapVersion::Latest) => 0u8.encode(bytes), + } } } @@ -958,9 +997,18 @@ impl ParameterizedDecode for AggregationJobResp { version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { - Ok(Self { - transitions: decode_u32_items(version, bytes)?, - }) + match version { + DapVersion::Draft09 => Ok(Self::Ready { + transitions: decode_u32_items(version, bytes)?, + }), + DapVersion::Latest => match u8::decode(bytes)? { + 0 => Ok(Self::Processing), + 1 => Ok(Self::Ready { + transitions: decode_u32_items(version, bytes)?, + }), + _ => Err(CodecError::UnexpectedValue), + }, + } } } @@ -1810,7 +1858,7 @@ mod test { 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 2, 7, ]; - let want = AggregationJobResp { + let want = AggregationJobResp::Ready { transitions: vec![ Transition { report_id: ReportId([22; 16]), @@ -1828,14 +1876,18 @@ mod test { }, ], }; - println!( - "want {:?}", - want.get_encoded_with_param(&DapVersion::Latest).unwrap() - ); - let got = - AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, TEST_DATA).unwrap(); + AggregationJobResp::get_decoded_with_param(&DapVersion::Draft09, TEST_DATA).unwrap(); assert_eq!(got, want); + let draft_latest_data = [&[1], TEST_DATA].concat(); + let got = + AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, &draft_latest_data) + .unwrap(); + assert_eq!(got, want); + assert_eq!( + AggregationJobResp::Processing, + AggregationJobResp::get_decoded_with_param(&DapVersion::Latest, &[0]).unwrap(), + ); } #[test] diff --git a/crates/daphne/src/messages/request.rs b/crates/daphne/src/messages/request.rs index aeeca629b..45dd9b700 100644 --- a/crates/daphne/src/messages/request.rs +++ b/crates/daphne/src/messages/request.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
// SPDX-License-Identifier: BSD-3-Clause use std::ops::Deref; @@ -14,6 +14,9 @@ pub trait RequestBody { type ResourceId; } +/// A poll request has no body, but requires a `AggregationJobId`. +pub struct PollAggregationJob; + /// A poll request has no body, but requires a `CollectionJobId`. pub struct CollectionPollReq; @@ -94,6 +97,7 @@ impl_req_body! { Report | () AggregationJobInitReq | AggregationJobId HashedAggregationJobReq | AggregationJobId + PollAggregationJob | AggregationJobId AggregateShareReq | () CollectionReq | CollectionJobId CollectionPollReq | CollectionJobId diff --git a/crates/daphne/src/protocol/aggregator.rs b/crates/daphne/src/protocol/aggregator.rs index 479d42312..a49a8dad9 100644 --- a/crates/daphne/src/protocol/aggregator.rs +++ b/crates/daphne/src/protocol/aggregator.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use super::{ @@ -12,8 +12,8 @@ use crate::{ hpke::{info_and_aad, HpkeConfig, HpkeDecrypter}, messages::{ self, encode_u32_bytes, AggregationJobInitReq, AggregationJobResp, Base64Encode, - BatchSelector, HpkeCiphertext, PartialBatchSelector, PrepareInit, Report, ReportError, - ReportId, ReportShare, TaskId, Transition, TransitionVar, + BatchSelector, HpkeCiphertext, PartialBatchSelector, PrepareInit, ReadyAggregationJobResp, + Report, ReportError, ReportId, ReportShare, TaskId, Transition, TransitionVar, }, metrics::{DaphneMetrics, ReportStatus}, protocol::{decode_ping_pong_framed, PingPongMessageType}, @@ -279,7 +279,7 @@ impl DapTaskConfig { report_status: &HashMap, part_batch_sel: &PartialBatchSelector, initialized_reports: &[InitializedReport], - ) -> Result<(DapAggregateSpan, AggregationJobResp), DapError> { + ) -> Result<(DapAggregateSpan, ReadyAggregationJobResp), DapError> { let num_reports = initialized_reports.len(); let mut agg_span = DapAggregateSpan::default(); let mut transitions = 
Vec::with_capacity(num_reports); @@ -355,7 +355,7 @@ impl DapTaskConfig { }); } - Ok((agg_span, AggregationJobResp { transitions })) + Ok((agg_span, ReadyAggregationJobResp { transitions })) } /// Leader: Consume the `AggregationJobResp` message sent by the Helper and compute the @@ -367,11 +367,14 @@ impl DapTaskConfig { agg_job_resp: AggregationJobResp, metrics: &dyn DaphneMetrics, ) -> Result, DapError> { - if agg_job_resp.transitions.len() != state.seq.len() { + let AggregationJobResp::Ready { transitions } = agg_job_resp else { + todo!("polling from the leader not implemented yet") + }; + if transitions.len() != state.seq.len() { return Err(DapAbort::InvalidMessage { detail: format!( "aggregation job response has {} reports; expected {}", - agg_job_resp.transitions.len(), + transitions.len(), state.seq.len(), ), task_id: *task_id, @@ -380,7 +383,7 @@ impl DapTaskConfig { } let mut agg_span = DapAggregateSpan::default(); - for (helper, leader) in zip(agg_job_resp.transitions, state.seq) { + for (helper, leader) in zip(transitions, state.seq) { if helper.report_id != leader.report_id { return Err(DapAbort::InvalidMessage { detail: format!( diff --git a/crates/daphne/src/protocol/mod.rs b/crates/daphne/src/protocol/mod.rs index 3dbb9832e..6841af1b3 100644 --- a/crates/daphne/src/protocol/mod.rs +++ b/crates/daphne/src/protocol/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
// SPDX-License-Identifier: BSD-3-Clause use prio::codec::{CodecError, Decode as _}; @@ -67,8 +67,9 @@ mod test { error::DapAbort, hpke::{HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId}, messages::{ - AggregationJobInitReq, BatchSelector, Extension, Interval, PartialBatchSelector, - PrepareInit, Report, ReportError, ReportId, ReportShare, Transition, TransitionVar, + AggregationJobInitReq, AggregationJobResp, BatchSelector, Extension, Interval, + PartialBatchSelector, PrepareInit, Report, ReportError, ReportId, ReportShare, + Transition, TransitionVar, }, test_versions, testing::AggregationJobTest, @@ -249,8 +250,11 @@ mod test { } let (agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let AggregationJobResp::Ready { transitions } = agg_job_resp else { + panic!("expected a ready response, got processing") + }; assert_eq!(agg_span.report_count(), 3); - assert_eq!(agg_job_resp.transitions.len(), 3); + assert_eq!(transitions.len(), 3); } test_versions! { produce_agg_job_req } @@ -376,9 +380,12 @@ mod test { t.produce_agg_job_req(&DapAggregationParam::Empty, reports.clone()); let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); - assert_eq!(agg_job_resp.transitions.len(), 1); + let AggregationJobResp::Ready { transitions } = agg_job_resp else { + panic!("expected a ready response, got processing") + }; + assert_eq!(transitions.len(), 1); assert_matches!( - agg_job_resp.transitions[0].var, + transitions[0].var, TransitionVar::Failed(ReportError::HpkeDecryptError) ); } @@ -411,9 +418,12 @@ mod test { }; let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); - assert_eq!(agg_job_resp.transitions.len(), 1); + let AggregationJobResp::Ready { transitions } = agg_job_resp else { + panic!("expected a ready response, got processing") + }; + assert_eq!(transitions.len(), 1); assert_matches!( - agg_job_resp.transitions[0].var, + transitions[0].var, TransitionVar::Failed(ReportError::ReportDropped) ); } @@ -446,9 +456,12 @@ mod 
test { }; let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); - assert_eq!(agg_job_resp.transitions.len(), 1); + let AggregationJobResp::Ready { transitions } = agg_job_resp else { + panic!("expected a ready response, got processing") + }; + assert_eq!(transitions.len(), 1); assert_matches!( - agg_job_resp.transitions[0].var, + transitions[0].var, TransitionVar::Failed(ReportError::ReportTooEarly) ); } @@ -466,9 +479,12 @@ mod test { t.produce_agg_job_req(&DapAggregationParam::Empty, reports.clone()); let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); - assert_eq!(agg_job_resp.transitions.len(), 1); + let AggregationJobResp::Ready { transitions } = agg_job_resp else { + panic!("expected a ready response, got processing") + }; + assert_eq!(transitions.len(), 1); assert_matches!( - agg_job_resp.transitions[0].var, + transitions[0].var, TransitionVar::Failed(ReportError::HpkeUnknownConfigId) ); } @@ -511,13 +527,17 @@ mod test { let (_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); - assert_eq!(agg_job_resp.transitions.len(), 2); + let AggregationJobResp::Ready { transitions } = agg_job_resp else { + panic!("expected a ready response, got processing") + }; + + assert_eq!(transitions.len(), 2); assert_matches!( - agg_job_resp.transitions[0].var, + transitions[0].var, TransitionVar::Failed(ReportError::VdafPrepError) ); assert_matches!( - agg_job_resp.transitions[1].var, + transitions[1].var, TransitionVar::Failed(ReportError::VdafPrepError) ); } @@ -534,10 +554,12 @@ mod test { t.produce_agg_job_req(&DapAggregationParam::Empty, reports); let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let AggregationJobResp::Ready { transitions } = &mut agg_job_resp else { + panic!("expected a ready response, got processing") + }; + // Helper sends transitions out of order. 
- let tmp = agg_job_resp.transitions[0].clone(); - agg_job_resp.transitions[0] = agg_job_resp.transitions[1].clone(); - agg_job_resp.transitions[1] = tmp; + transitions.swap(0, 1); assert_matches!( t.consume_agg_job_resp_expect_err(leader_state, agg_job_resp), @@ -557,9 +579,13 @@ mod test { t.produce_agg_job_req(&DapAggregationParam::Empty, reports); let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let AggregationJobResp::Ready { transitions } = &mut agg_job_resp else { + panic!("expected a ready response, got processing") + }; + // Helper sends a transition twice. - let repeated_transition = agg_job_resp.transitions[0].clone(); - agg_job_resp.transitions.push(repeated_transition); + let repeated_transition = transitions[0].clone(); + transitions.push(repeated_transition); assert_matches!( t.consume_agg_job_resp_expect_err(leader_state, agg_job_resp), @@ -580,8 +606,11 @@ mod test { t.produce_agg_job_req(&DapAggregationParam::Empty, reports); let (_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let AggregationJobResp::Ready { transitions } = &mut agg_job_resp else { + panic!("expected a ready response, got processing") + }; // Helper sent a transition with an unrecognized report ID. - agg_job_resp.transitions.push(Transition { + transitions.push(Transition { report_id: ReportId(rng.gen()), var: TransitionVar::Continued(b"whatever".to_vec()), }); @@ -601,9 +630,12 @@ mod test { t.produce_agg_job_req(&DapAggregationParam::Empty, reports); let (_helper_agg_span, mut agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let AggregationJobResp::Ready { transitions } = &mut agg_job_resp else { + panic!("expected a ready response, got processing") + }; // Helper sent a transition with an unrecognized report ID. Simulate this by flipping the // first bit of the report ID. 
- agg_job_resp.transitions[0].report_id.0[0] ^= 1; + transitions[0].report_id.0[0] ^= 1; assert_matches!( t.consume_agg_job_resp_expect_err(leader_state, agg_job_resp), @@ -681,9 +713,12 @@ mod test { .collect::>(); let (helper_agg_span, agg_job_resp) = t.handle_agg_job_req(agg_job_init_req); + let AggregationJobResp::Ready { transitions } = &agg_job_resp else { + panic!("expected a ready response, got processing") + }; assert_eq!(2, helper_agg_span.report_count()); - assert_eq!(3, agg_job_resp.transitions.len()); - for (transition, prep_init_id) in zip(&agg_job_resp.transitions, prep_init_ids) { + assert_eq!(3, transitions.len()); + for (transition, prep_init_id) in zip(transitions, prep_init_ids) { assert_eq!(transition.report_id, prep_init_id); } diff --git a/crates/daphne/src/protocol/report_init.rs b/crates/daphne/src/protocol/report_init.rs index c38b5fdd9..b88010810 100644 --- a/crates/daphne/src/protocol/report_init.rs +++ b/crates/daphne/src/protocol/report_init.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use crate::{ @@ -256,7 +256,7 @@ impl

InitializedReport

{ } } - pub(crate) fn metadata(&self) -> &ReportMetadata { + pub fn metadata(&self) -> &ReportMetadata { match self { Self::Ready { metadata, .. } | Self::Rejected { metadata, .. } => metadata, } diff --git a/crates/daphne/src/roles/aggregator.rs b/crates/daphne/src/roles/aggregator.rs index d9dd4bac9..c90cd56ce 100644 --- a/crates/daphne/src/roles/aggregator.rs +++ b/crates/daphne/src/roles/aggregator.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use std::{collections::HashSet, ops::Range}; @@ -113,6 +113,7 @@ pub trait DapAggregator: HpkeProvider + Sized { /// Fetch the aggregate share for the given batch. async fn get_agg_share( &self, + version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result; diff --git a/crates/daphne/src/roles/helper/handle_agg_job.rs b/crates/daphne/src/roles/helper/handle_agg_job.rs index 412e01c7d..4dc1d522e 100644 --- a/crates/daphne/src/roles/helper/handle_agg_job.rs +++ b/crates/daphne/src/roles/helper/handle_agg_job.rs @@ -6,14 +6,20 @@ use crate::{ error::DapAbort, messages::{ request::HashedAggregationJobReq, AggregationJobInitReq, AggregationJobResp, - PartialBatchSelector, ReportError, TaskId, TransitionVar, + PartialBatchSelector, ReadyAggregationJobResp, ReportError, ReportId, ReportMetadata, + TaskId, TransitionVar, }, metrics::ReportStatus, protocol::aggregator::ReportProcessedStatus, roles::{aggregator::MergeAggShareError, resolve_task_config}, - DapError, DapRequest, DapTaskConfig, InitializedReport, WithPeerPrepShare, + DapAggregateShare, DapAggregateSpan, DapError, DapRequest, DapTaskConfig, InitializedReport, + WithPeerPrepShare, +}; +use std::{ + collections::{HashMap, HashSet}, + future::Future, + sync::Once, }; -use std::{collections::HashMap, sync::Once}; /// A state machine for the handling of aggregation jobs. 
pub struct HandleAggJob { @@ -38,7 +44,6 @@ pub struct WithTaskConfig { /// /// This type is returned by [`HandleAggJob::into_parts`] and [`Self::with_initialized_reports`] /// can be used to return to the [`HandleAggJob`] state machine flow. -#[non_exhaustive] pub struct ToInitializedReportsTransition { pub task_id: TaskId, pub part_batch_sel: PartialBatchSelector, @@ -53,6 +58,8 @@ pub struct InitializedReports { reports: Vec>, } +pub struct UniqueInitializedReports(InitializedReports); + macro_rules! impl_from { ($($t:ty),*$(,)?) => { $(impl From<$t> for HandleAggJob<$t> { @@ -336,7 +343,7 @@ impl HandleAggJob { 0, /* vdaf step */ ); - return Ok(agg_job_resp); + return Ok(agg_job_resp.into()); } } @@ -344,4 +351,51 @@ impl HandleAggJob { // enabling an DOS attack. Err(DapAbort::BadRequest("aggregation job contained too many replays".into()).into()) } + + pub async fn check_for_replays( + mut self, + replay_check: F, + ) -> Result, E> + where + F: FnOnce(&mut dyn Iterator) -> Fut, + Fut: Future, E>>, + { + let replays = replay_check( + &mut self + .state + .reports + .iter() + .filter(|r| matches!(r, InitializedReport::Ready { .. 
})) + .map(|r| r.metadata()), + ) + .await?; + + for r in &mut self.state.reports { + if replays.contains(&r.metadata().id) { + *r = InitializedReport::Rejected { + metadata: r.metadata().clone(), + report_err: ReportError::ReportReplayed, + } + } + } + + Ok(HandleAggJob { + state: UniqueInitializedReports(self.state), + }) + } +} + +impl HandleAggJob { + pub fn finish( + self, + ) -> Result<(DapAggregateSpan, ReadyAggregationJobResp), DapError> { + let InitializedReports { + task_id, + part_batch_sel, + task_config, + reports, + } = self.state.0; + + task_config.produce_agg_job_resp(task_id, &Default::default(), &part_batch_sel, &reports) + } } diff --git a/crates/daphne/src/roles/helper/mod.rs b/crates/daphne/src/roles/helper/mod.rs index 052a326d1..0c086fcec 100644 --- a/crates/daphne/src/roles/helper/mod.rs +++ b/crates/daphne/src/roles/helper/mod.rs @@ -12,8 +12,9 @@ use crate::{ error::DapAbort, messages::{ constant_time_eq, - request::{AggregationJobRequestHash, HashedAggregationJobReq}, - AggregateShare, AggregateShareReq, AggregationJobId, PartialBatchSelector, TaskId, + request::{AggregationJobRequestHash, HashedAggregationJobReq, PollAggregationJob}, + AggregateShare, AggregateShareReq, AggregationJobId, AggregationJobResp, + PartialBatchSelector, TaskId, }, metrics::{DaphneRequestType, ReportStatus}, protocol::aggregator::ReplayProtection, @@ -34,6 +35,14 @@ pub trait DapHelper: DapAggregator { task_id: &TaskId, req: &AggregationJobRequestHash, ) -> Result<(), DapError>; + + /// Polls for the completion of an aggregation job. 
+ async fn poll_aggregated( + &self, + version: DapVersion, + task_id: &TaskId, + agg_job_id: &AggregationJobId, + ) -> Result; } pub async fn handle_agg_job_init_req( @@ -65,9 +74,25 @@ pub async fn handle_agg_job_init_req( }) } +pub async fn handle_agg_job_poll_req( + aggregator: &A, + req: DapRequest, +) -> Result { + let response = aggregator + .poll_aggregated(req.version, &req.task_id, &req.resource_id) + .await?; + Ok(DapResponse { + version: req.version, + media_type: DapMediaType::AggregationJobResp, + payload: response + .get_encoded_with_param(&req.version) + .map_err(DapError::encoding)?, + }) +} + /// Handle a request for an aggregate share. This is called by the Leader to complete a /// collection job. -pub async fn handle_agg_share_req<'req, A: DapHelper>( +pub async fn handle_agg_share_req( aggregator: &A, req: DapRequest, ) -> Result { @@ -96,7 +121,7 @@ pub async fn handle_agg_share_req<'req, A: DapHelper>( .await?; let agg_share = aggregator - .get_agg_share(&task_id, &req.payload.batch_sel) + .get_agg_share(req.version, &task_id, &req.payload.batch_sel) .await?; // Check that we have aggreagted the same set of reports as the Leader. diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs index aed51c207..fb19b88a6 100644 --- a/crates/daphne/src/roles/leader/mod.rs +++ b/crates/daphne/src/roles/leader/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
// SPDX-License-Identifier: BSD-3-Clause pub mod in_memory_leader; @@ -402,7 +402,9 @@ async fn run_coll_job( let metrics = aggregator.metrics(); debug!("collecting id {coll_job_id}"); - let leader_agg_share = aggregator.get_agg_share(task_id, batch_sel).await?; + let leader_agg_share = aggregator + .get_agg_share(task_config.version, task_id, batch_sel) + .await?; let taskprov_advertisement = task_config.resolve_taskprove_advertisement()?; diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index 3a9deaf7d..ab644d7d0 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause //! Trait definitions for Daphne backends. @@ -762,7 +762,7 @@ mod test { .payload, ) .unwrap(); - let transition = &agg_job_resp.transitions[0]; + let transition = agg_job_resp.unwrap_ready().transitions.remove(0); // Expect failure due to invalid ciphertext. assert_matches!( @@ -795,7 +795,7 @@ mod test { .payload, ) .unwrap(); - let transition = &agg_job_resp.transitions[0]; + let transition = agg_job_resp.unwrap_ready().transitions.remove(0); // Expect success due to valid ciphertext. 
assert_matches!(transition.var, TransitionVar::Continued(_)); diff --git a/crates/daphne/src/testing/mod.rs b/crates/daphne/src/testing/mod.rs index f324bfda8..a3c3e09e6 100644 --- a/crates/daphne/src/testing/mod.rs +++ b/crates/daphne/src/testing/mod.rs @@ -14,7 +14,8 @@ use crate::{ messages::{ self, request::AggregationJobRequestHash, AggregationJobId, AggregationJobInitReq, AggregationJobResp, Base64Encode, BatchId, BatchSelector, Collection, CollectionJobId, - HpkeCiphertext, Interval, PartialBatchSelector, Report, ReportId, TaskId, Time, + HpkeCiphertext, Interval, PartialBatchSelector, ReadyAggregationJobResp, Report, ReportId, + TaskId, Time, }, metrics::{prometheus::DaphnePromMetrics, DaphneMetrics}, roles::{ @@ -208,7 +209,8 @@ impl AggregationJobTest { &self, agg_job_init_req: AggregationJobInitReq, ) -> (DapAggregateSpan, AggregationJobResp) { - self.task_config + let (span, resp) = self + .task_config .produce_agg_job_resp( self.task_id, &HashMap::default(), @@ -224,7 +226,8 @@ impl AggregationJobTest { ) .unwrap(), ) - .unwrap() + .unwrap(); + (span, resp.into()) } /// Leader: Handle `AggregationJobResp`, produce `AggregationJobContinueReq`. 
@@ -519,6 +522,7 @@ pub struct InMemoryAggregator { // Helper: aggregation jobs processed_jobs: Mutex>, + finished_jobs: Mutex>, } impl DeepSizeOf for InMemoryAggregator { @@ -535,6 +539,7 @@ impl DeepSizeOf for InMemoryAggregator { taskprov_vdaf_verify_key_init, peer, processed_jobs, + finished_jobs, } = self; global_config.deep_size_of_children(context) + tasks.deep_size_of_children(context) @@ -545,6 +550,7 @@ impl DeepSizeOf for InMemoryAggregator { + taskprov_vdaf_verify_key_init.deep_size_of_children(context) + peer.deep_size_of_children(context) + processed_jobs.deep_size_of_children(context) + + finished_jobs.deep_size_of_children(context) } } @@ -569,6 +575,7 @@ impl InMemoryAggregator { taskprov_vdaf_verify_key_init, peer: None, processed_jobs: Default::default(), + finished_jobs: Default::default(), } } @@ -593,6 +600,7 @@ impl InMemoryAggregator { taskprov_vdaf_verify_key_init, peer: peer.into(), processed_jobs: Default::default(), + finished_jobs: Default::default(), } } @@ -814,6 +822,7 @@ impl DapAggregator for InMemoryAggregator { async fn get_agg_share( &self, + _version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result { @@ -898,6 +907,26 @@ impl DapHelper for InMemoryAggregator { } } } + + async fn poll_aggregated( + &self, + _version: DapVersion, + task_id: &TaskId, + agg_job_id: &AggregationJobId, + ) -> Result { + self.finished_jobs + .lock() + .unwrap() + .get(agg_job_id) + .cloned() + .map(Into::into) + .ok_or_else(|| { + DapError::Abort(DapAbort::UnrecognizedAggregationJob { + task_id: *task_id, + agg_job_id: *agg_job_id, + }) + }) + } } #[async_trait]