Skip to content

Commit

Permalink
daphne: Add support for the Mastic VDAF
Browse files Browse the repository at this point in the history
Mastic (https://datatracker.ietf.org/doc/draft-mouris-cfrg-mastic/) is a
VDAF that enables a richer set of functionalities than the VDAFs we
support so so far. The `daphne::vdaf::mastic` module contains a "dummy"
version of Mastic intended to exercise the DAP protocol logic we would
need in order to fully support this VDAF. The `prio` crate now
implements Mastic, so upgrade to a version of the crate that supports it
and replace the dummy VDAF with the real one.

In addition, to complete aggregation of a report, it is necessary to
know the aggregation parameter, which currently is only plumbed to
report initialization. In particular,
`DapTaskConfig::produce_agg_job_resp()` needs the aggregation parameter
from the aggregation job request message. (Likewise,
`ToInitializedReportsTransition::with_initialized_reports()` needs the
aggregation parameter.)

Finally, clean up some API cruft in `daphne::vdaf`:

1. Encapsulate variants of Mastic behind a `MasticConfig` as we've done
   for other VDAFs.

2. Modify `prep_finish_from_shares()` to not take the aggregator ID.
   This is a relic of when we supported DAP-02, when this method may
   have been called by either the Leader or the Helper. Now it's always
   called by the Helper.

3. Implement state encoding for Mastic, as required by the new async
   Helper implementation.

4. Generalize the `prep_init()` function in `daphne::prio3` to be used
   for Mastic as well.
  • Loading branch information
cjpatton committed Jan 28, 2025
1 parent dac3bcc commit bfe5b34
Show file tree
Hide file tree
Showing 14 changed files with 328 additions and 389 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ matchit = "0.7.3"
p256 = { version = "0.13.2", features = ["ecdsa-core", "ecdsa", "pem"] }
paste = "1.0.15"
prio_draft09 = { package = "prio", version = "0.16.7" }
prio = { git = "https://github.com/divviup/libprio-rs.git", rev = "c50bb9a47b396ad6a08a3fec36b98bcc2d9217a1" }
# TODO Point to version `0.17.0` once release. This revision is one commit ahead of `0.17.0-alpha.0`.
prio = { git = "https://github.com/divviup/libprio-rs.git", rev = "e5e8a47ee4567f7588d0b5c8d20f75dde4061b2f" }
prometheus = "0.13.4"
rand = "0.8.5"
rayon = "1.10.0"
Expand Down
13 changes: 10 additions & 3 deletions crates/daphne-worker/src/aggregator/router/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ use axum::{
routing::{post, put},
};
use daphne::{
error::DapAbort,
fatal_error,
hpke::HpkeProvider,
messages::AggregateShareReq,
roles::{
helper::{self, HashedAggregationJobReq},
DapAggregator, DapHelper,
},
DapError, DapResponse,
DapAggregationParam, DapError, DapResponse,
};
use daphne_service_utils::compute_offload;
use http::StatusCode;
use prio::codec::ParameterizedEncode;
use prio::codec::{ParameterizedDecode, ParameterizedEncode};
use std::{borrow::Cow, sync::Arc};

pub(super) fn add_helper_routes(router: super::Router<App>) -> super::Router<App> {
Expand Down Expand Up @@ -60,6 +61,12 @@ async fn agg_job(

let hpke_receiver_configs = app.get_hpke_receiver_configs(req.version).await?;

let agg_param = DapAggregationParam::get_decoded_with_param(
&transition.task_config.vdaf,
&req.payload.agg_param,
)
.map_err(|e| DapAbort::from_codec_error(e, req.task_id))?;

let initialized_reports: compute_offload::InitializedReports = app
.compute_offload
.compute(
Expand All @@ -77,7 +84,7 @@ async fn agg_job(
.map_err(|e| fatal_error!(err = ?e, "failed to offload report initialization"))?;

transition
.with_initialized_reports(initialized_reports.reports)
.with_initialized_reports(agg_param, initialized_reports.reports)
.finish_and_aggregate(&*app)
.await
}
Expand Down
2 changes: 1 addition & 1 deletion crates/daphne/src/hpke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ pub mod info_and_aad {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AggregateShare<'s> {
// info
pub version: DapVersion,
Expand Down
8 changes: 4 additions & 4 deletions crates/daphne/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ use error::FatalDapError;
use hpke::{HpkeConfig, HpkeKemId};
use messages::taskprov::TaskprovAdvertisement;
#[cfg(feature = "experimental")]
use prio::{codec::Decode, vdaf::poplar1::Poplar1AggregationParam};
use prio::{codec::Decode, vdaf::mastic::MasticAggregationParam};
use prio::{
codec::{CodecError, Encode, ParameterizedDecode},
vdaf::Aggregatable as AggregatableTrait,
Expand Down Expand Up @@ -813,11 +813,11 @@ pub enum DapMeasurement {
}

/// An aggregation parameter.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
pub enum DapAggregationParam {
Empty,
#[cfg(feature = "experimental")]
Mastic(Poplar1AggregationParam),
Mastic(MasticAggregationParam),
}

#[cfg(any(test, feature = "test-utils"))]
Expand Down Expand Up @@ -877,7 +877,7 @@ impl ParameterizedDecode<VdafConfig> for DapAggregationParam {
let _ = bytes;
match vdaf_config {
#[cfg(feature = "experimental")]
VdafConfig::Mastic { .. } => Ok(Self::Mastic(Poplar1AggregationParam::decode(bytes)?)),
VdafConfig::Mastic(_) => Ok(Self::Mastic(MasticAggregationParam::decode(bytes)?)),
_ => Ok(Self::Empty),
}
}
Expand Down
16 changes: 12 additions & 4 deletions crates/daphne/src/protocol/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,13 @@ impl DapTaskConfig {
task_id: &TaskId,
agg_job_init_req: AggregationJobInitReq,
replay_protection: ReplayProtection,
) -> Result<Vec<InitializedReport<WithPeerPrepShare>>, DapError>
) -> Result<
(
DapAggregationParam,
Vec<InitializedReport<WithPeerPrepShare>>,
),
DapError,
>
where
H: HpkeDecrypter + Sync,
{
Expand All @@ -260,7 +266,7 @@ impl DapTaskConfig {
agg_job_init_req.prep_inits.len()
);

agg_job_init_req
let initialized_reports = agg_job_init_req
.prep_inits
.into_par_iter()
.map(|prep_init| {
Expand All @@ -274,7 +280,8 @@ impl DapTaskConfig {
&agg_param,
)
})
.collect()
.collect::<Result<Vec<_>, _>>()?;
Ok((agg_param, initialized_reports))
}

/// Helper -> Leader: Produce the `AggregationJobResp` message to send to the Leader and
Expand All @@ -283,6 +290,7 @@ impl DapTaskConfig {
pub(crate) fn produce_agg_job_resp(
&self,
task_id: TaskId,
agg_param: &DapAggregationParam,
report_status: &HashMap<ReportId, ReportProcessedStatus>,
part_batch_sel: &PartialBatchSelector,
initialized_reports: &[InitializedReport<WithPeerPrepShare>],
Expand All @@ -305,8 +313,8 @@ impl DapTaskConfig {
} => {
let res = self.vdaf.prep_finish_from_shares(
self.version,
1,
task_id,
agg_param,
helper_prep_state.clone(),
helper_prep_share.clone(),
leader_prep_share,
Expand Down
10 changes: 8 additions & 2 deletions crates/daphne/src/roles/helper/handle_agg_job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
metrics::ReportStatus,
protocol::aggregator::ReportProcessedStatus,
roles::{aggregator::MergeAggShareError, resolve_task_config},
DapError, DapRequest, DapTaskConfig, InitializedReport, WithPeerPrepShare,
DapAggregationParam, DapError, DapRequest, DapTaskConfig, InitializedReport, WithPeerPrepShare,
};
use std::{collections::HashMap, sync::Once};

Expand Down Expand Up @@ -48,6 +48,7 @@ pub struct ToInitializedReportsTransition {
/// The reports have been initialized and are ready for aggregation.
pub struct InitializedReports {
task_id: TaskId,
agg_param: DapAggregationParam,
part_batch_sel: PartialBatchSelector,
task_config: DapTaskConfig,
reports: Vec<InitializedReport<WithPeerPrepShare>>,
Expand Down Expand Up @@ -142,7 +143,7 @@ impl HandleAggJob<WithTaskConfig> {
} = self.state;
let task_id = request.task_id;
let part_batch_sel = request.payload.part_batch_sel.clone();
let initialized_reports = task_config.consume_agg_job_req(
let (agg_param, initialized_reports) = task_config.consume_agg_job_req(
&aggregator
.get_hpke_receiver_configs(task_config.version)
.await?,
Expand All @@ -155,6 +156,7 @@ impl HandleAggJob<WithTaskConfig> {
Ok(HandleAggJob {
state: InitializedReports {
task_id,
agg_param,
task_config,
part_batch_sel,
reports: initialized_reports,
Expand Down Expand Up @@ -207,6 +209,7 @@ impl ToInitializedReportsTransition {
/// Provide the initialized reports that should be aggregated.
pub fn with_initialized_reports(
self,
agg_param: DapAggregationParam,
reports: Vec<InitializedReport<WithPeerPrepShare>>,
) -> HandleAggJob<InitializedReports> {
let Self {
Expand All @@ -217,6 +220,7 @@ impl ToInitializedReportsTransition {
HandleAggJob {
state: InitializedReports {
task_id,
agg_param,
part_batch_sel,
task_config,
reports,
Expand All @@ -236,6 +240,7 @@ impl HandleAggJob<InitializedReports> {
state:
InitializedReports {
task_id,
agg_param,
part_batch_sel,
task_config,
reports,
Expand All @@ -257,6 +262,7 @@ impl HandleAggJob<InitializedReports> {
for _ in 0..RETRY_COUNT {
let (agg_span, agg_job_resp) = task_config.produce_agg_job_resp(
task_id,
&agg_param,
&report_status,
&part_batch_sel,
&reports,
Expand Down
25 changes: 15 additions & 10 deletions crates/daphne/src/roles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async fn resolve_task_config(
mod test {
use super::{aggregator, helper, leader, DapLeader};
#[cfg(feature = "experimental")]
use crate::vdaf::{mastic::MasticWeight, MasticWeightConfig};
use crate::vdaf::mastic::{MasticConfig, MasticWeight, MasticWeightConfig};
use crate::{
assert_metrics_include, async_test_versions,
constants::DapMediaType,
Expand All @@ -151,7 +151,7 @@ mod test {
use assert_matches::assert_matches;
use prio::codec::{Encode, ParameterizedDecode};
#[cfg(feature = "experimental")]
use prio::{idpf::IdpfInput, vdaf::poplar1::Poplar1AggregationParam};
use prio::idpf::IdpfInput;
use rand::{thread_rng, Rng};
use std::{
collections::HashMap,
Expand Down Expand Up @@ -275,10 +275,10 @@ mod test {

#[cfg(feature = "experimental")]
{
let mastic = VdafConfig::Mastic {
input_size: 1,
let mastic = VdafConfig::Mastic(MasticConfig {
bits: 8,
weight_config: MasticWeightConfig::Count,
};
});
tasks.insert(
mastic_task_id,
DapTaskConfig {
Expand Down Expand Up @@ -2018,6 +2018,8 @@ mod test {
#[cfg(feature = "experimental")]
#[tokio::test]
async fn mastic() {
use prio::vdaf::mastic::MasticAggregationParam;

let t = Test::new(DapVersion::Latest);
let task_id = &t.mastic_task_id;
let task_config = t
Expand All @@ -2043,11 +2045,14 @@ mod test {
// Collector: Request result from the Leader.
let query = task_config.query_for_current_batch_window(t.now);
let agg_param = DapAggregationParam::Mastic(
Poplar1AggregationParam::try_from_prefixes(vec![
IdpfInput::from_bytes(&[0]),
IdpfInput::from_bytes(&[1]),
IdpfInput::from_bytes(&[7]),
])
MasticAggregationParam::new(
vec![
IdpfInput::from_bytes(&[0]),
IdpfInput::from_bytes(&[1]),
IdpfInput::from_bytes(&[7]),
],
true,
)
.unwrap(),
);
leader::handle_coll_job_req(
Expand Down
26 changes: 15 additions & 11 deletions crates/daphne/src/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,22 +209,26 @@ impl AggregationJobTest {
&self,
agg_job_init_req: AggregationJobInitReq,
) -> (DapAggregateSpan<DapAggregateShare>, AggregationJobResp) {
let part_batch_sel = agg_job_init_req.part_batch_sel.clone();
let (agg_param, initialized_reports) = self
.task_config
.consume_agg_job_req(
&self.helper_hpke_receiver_config,
self.valid_report_time_range(),
&self.task_id,
agg_job_init_req,
self.replay_protection,
)
.unwrap();

let (span, resp) = self
.task_config
.produce_agg_job_resp(
self.task_id,
&agg_param,
&HashMap::default(),
&agg_job_init_req.part_batch_sel.clone(),
&self
.task_config
.consume_agg_job_req(
&self.helper_hpke_receiver_config,
self.valid_report_time_range(),
&self.task_id,
agg_job_init_req,
self.replay_protection,
)
.unwrap(),
&part_batch_sel,
&initialized_reports,
)
.unwrap();
(span, resp.into())
Expand Down
12 changes: 2 additions & 10 deletions crates/daphne/src/vdaf/draft09.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,18 @@ where

pub(crate) fn prep_finish_from_shares<V, const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>(
vdaf: &V,
agg_id: usize,
host_state: V::PrepareState,
host_share: V::PrepareShare,
peer_share_data: &[u8],
) -> Result<(V::OutputShare, Vec<u8>), VdafError>
where
V: Vdaf<AggregationParam = ()> + Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
// Decode the Helper's inbound message.
// Decode the peer's inbound message.
let peer_share = V::PrepareShare::get_decoded_with_param(&host_state, peer_share_data)?;

// Preprocess the inbound messages.
let message = vdaf.prepare_shares_to_prepare_message(
&(),
if agg_id == 0 {
[host_share, peer_share]
} else {
[peer_share, host_share]
},
)?;
let message = vdaf.prepare_shares_to_prepare_message(&(), [peer_share, host_share])?;
let message_data = message.get_encoded()?;

// Compute the host's output share.
Expand Down
Loading

0 comments on commit bfe5b34

Please sign in to comment.