From 94abde303573aa051140fa3e950ef2cbdecbc2c7 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Wed, 29 Nov 2023 15:11:52 -0800 Subject: [PATCH] taskprov: Enable Prio3SumVecField64MultiproofHmacSha256Aes128 in draft09 Enable the Prio3SumVec variant in taskprov. Don't enable it in draft02, since the version of Prio3 we need (draft-irtf-cfrg-vdaf-08) is incompatible with draft02. While at it, unify the encoding logic across draft02 and draft09 so that the same method is called to encode the query, VDAF, and DP parameters. (The length prefix is only added in draft09.) --- daphne/src/lib.rs | 2 +- daphne/src/messages/mod.rs | 25 ----- daphne/src/messages/taskprov.rs | 159 ++++++++++++++++++++++++++------ daphne/src/roles/mod.rs | 35 ++++++- daphne/src/taskprov.rs | 56 +++++++++-- daphne/src/vdaf/prio3.rs | 3 +- 6 files changed, 214 insertions(+), 66 deletions(-) diff --git a/daphne/src/lib.rs b/daphne/src/lib.rs index 6f1828f75..c7cbc627b 100644 --- a/daphne/src/lib.rs +++ b/daphne/src/lib.rs @@ -478,7 +478,7 @@ impl DapTaskParameters { task_expiration: now + 86400 * 14, // expires in two weeks vdaf_config: messages::taskprov::VdafConfig { dp_config: messages::taskprov::DpConfig::None, - var: messages::taskprov::VdafTypeVar::Prio2 { dimension: 10 }, + var: (&self.vdaf).try_into()?, }, }; diff --git a/daphne/src/messages/mod.rs b/daphne/src/messages/mod.rs index 198d190e6..ee3aaf80f 100644 --- a/daphne/src/messages/mod.rs +++ b/daphne/src/messages/mod.rs @@ -1288,31 +1288,6 @@ fn decode_u16_prefixed( Ok(decoded) } -fn encode_u16_item_for_version>( - bytes: &mut Vec, - version: DapVersion, - item: &E, -) { - match version { - DapVersion::DraftLatest => encode_u16_prefixed(version, bytes, |version, bytes| { - item.encode_with_param(&version, bytes); - }), - DapVersion::Draft02 => item.encode_with_param(&version, bytes), - } -} - -fn decode_u16_item_for_version>( - version: DapVersion, - bytes: &mut Cursor<&[u8]>, -) -> Result { - match version { - DapVersion::DraftLatest => decode_u16_prefixed(version, bytes, |version, inner| { - D::decode_with_param(&version, inner) - }), - DapVersion::Draft02 => D::decode_with_param(&version, bytes), - } -} - #[cfg(test)] mod test { use super::*; diff --git a/daphne/src/messages/taskprov.rs b/daphne/src/messages/taskprov.rs index d942580ef..41a179197 100644 --- a/daphne/src/messages/taskprov.rs +++ b/daphne/src/messages/taskprov.rs @@ -4,8 +4,8 @@ //! draft-wang-ppm-dap-taskprov: Messages for the taskrpov extension for DAP. use crate::messages::{ - decode_u16_bytes, decode_u16_item_for_version, encode_u16_bytes, encode_u16_item_for_version, - Duration, Time, QUERY_TYPE_FIXED_SIZE, QUERY_TYPE_TIME_INTERVAL, + decode_u16_bytes, encode_u16_bytes, Duration, Time, QUERY_TYPE_FIXED_SIZE, + QUERY_TYPE_TIME_INTERVAL, }; use crate::DapVersion; use prio::codec::{ @@ -15,8 +15,11 @@ use prio::codec::{ use serde::{Deserialize, Serialize}; use std::io::Cursor; +use super::{decode_u16_prefixed, encode_u16_prefixed}; + // VDAF type codes. const VDAF_TYPE_PRIO2: u32 = 0xFFFF_0000; +pub(crate) const VDAF_TYPE_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMAC_SHA256_AES128: u32 = 0xFFFF_1003; // Differential privacy mechanism types. const DP_MECHANISM_NONE: u8 = 0x01; @@ -24,8 +27,19 @@ const DP_MECHANISM_NONE: u8 = 0x01; /// A VDAF type along with its type-specific data. #[derive(Clone, Deserialize, Serialize, Debug, PartialEq, Eq)] pub enum VdafTypeVar { - Prio2 { dimension: u32 }, - NotImplemented { typ: u32, param: Vec }, + Prio2 { + dimension: u32, + }, + Prio3SumVecField64MultiproofHmacSha256Aes128 { + bits: u8, + length: u32, + chunk_length: u32, + num_proofs: u8, + }, + NotImplemented { + typ: u32, + param: Vec, + }, } impl ParameterizedEncode for VdafTypeVar { @@ -33,14 +47,29 @@ impl ParameterizedEncode for VdafTypeVar { match self { Self::Prio2 { dimension } => { VDAF_TYPE_PRIO2.encode(bytes); - encode_u16_item_for_version(bytes, *version, dimension); + taskprov_encode_u16_prefixed(*version, bytes, |_version, inner| { + dimension.encode(inner); + }); + } + Self::Prio3SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + } => { + VDAF_TYPE_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMAC_SHA256_AES128.encode(bytes); + taskprov_encode_u16_prefixed(*version, bytes, |_version, inner| { + bits.encode(inner); + length.encode(inner); + chunk_length.encode(inner); + num_proofs.encode(inner); + }); } Self::NotImplemented { typ, param } => { typ.encode(bytes); - match version { - DapVersion::DraftLatest => encode_u16_bytes(bytes, param), - DapVersion::Draft02 => bytes.extend_from_slice(param), - } + taskprov_encode_u16_prefixed(*version, bytes, |_version, inner| { + inner.extend_from_slice(param); + }); } } } @@ -53,9 +82,23 @@ impl ParameterizedDecode for VdafTypeVar { ) -> Result { let vdaf_type = u32::decode(bytes)?; match (version, vdaf_type) { - (.., VDAF_TYPE_PRIO2) => Ok(Self::Prio2 { - dimension: decode_u16_item_for_version(*version, bytes)?, - }), + (.., VDAF_TYPE_PRIO2) => { + taskprov_decode_u16_prefixed(*version, bytes, |_version, inner| { + Ok(Self::Prio2 { + dimension: u32::decode(inner)?, + }) + }) + } + (.., VDAF_TYPE_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMAC_SHA256_AES128) => { + taskprov_decode_u16_prefixed(*version, bytes, |_version, inner| { + Ok(Self::Prio3SumVecField64MultiproofHmacSha256Aes128 { + bits: u8::decode(inner)?, + length: u32::decode(inner)?, + chunk_length: u32::decode(inner)?, + num_proofs: u8::decode(inner)?, + }) + }) + } (DapVersion::DraftLatest, ..) => Ok(Self::NotImplemented { typ: vdaf_type, param: decode_u16_bytes(bytes)?, @@ -79,15 +122,14 @@ impl ParameterizedEncode for DpConfig { match self { Self::None => { DP_MECHANISM_NONE.encode(bytes); - encode_u16_item_for_version(bytes, *version, &()); + taskprov_encode_u16_prefixed(*version, bytes, |_, _| ()); } Self::NotImplemented { typ, param } => { typ.encode(bytes); - match version { - DapVersion::DraftLatest => encode_u16_bytes(bytes, param), - DapVersion::Draft02 => bytes.extend_from_slice(param), - } + taskprov_encode_u16_prefixed(*version, bytes, |_version, inner| { + inner.extend_from_slice(param); + }); } } } @@ -101,7 +143,9 @@ impl ParameterizedDecode for DpConfig { let dp_mechanism = u8::decode(bytes)?; match (version, dp_mechanism) { (.., DP_MECHANISM_NONE) => { - decode_u16_item_for_version::<()>(*version, bytes)?; + taskprov_decode_u16_prefixed::<()>(*version, bytes, |_version, inner| { + <()>::decode(inner) + })?; Ok(Self::None) } (DapVersion::DraftLatest, ..) => Ok(Self::NotImplemented { @@ -207,18 +251,19 @@ impl ParameterizedEncode for QueryConfig { match &self.var { QueryConfigVar::TimeInterval => { QUERY_TYPE_TIME_INTERVAL.encode(bytes); - encode_u16_item_for_version(bytes, *version, &()); + taskprov_encode_u16_prefixed(*version, bytes, |_, _| ()); } QueryConfigVar::FixedSize { max_batch_size } => { QUERY_TYPE_FIXED_SIZE.encode(bytes); - encode_u16_item_for_version(bytes, *version, max_batch_size); + taskprov_encode_u16_prefixed(*version, bytes, |_version, inner| { + max_batch_size.encode(inner); + }); } QueryConfigVar::NotImplemented { typ, param } => { typ.encode(bytes); - match version { - DapVersion::DraftLatest => encode_u16_bytes(bytes, param), - DapVersion::Draft02 => bytes.extend_from_slice(param), - } + taskprov_encode_u16_prefixed(*version, bytes, |_version, inner| { + inner.extend_from_slice(param); + }); } } } @@ -239,12 +284,18 @@ impl ParameterizedDecode for QueryConfig { let query_type = query_type.unwrap_or(u8::decode(bytes)?); let var = match (version, query_type) { (.., QUERY_TYPE_TIME_INTERVAL) => { - decode_u16_item_for_version::<()>(*version, bytes)?; + taskprov_decode_u16_prefixed::<()>(*version, bytes, |_version, inner| { + <()>::decode(inner) + })?; QueryConfigVar::TimeInterval } - (.., QUERY_TYPE_FIXED_SIZE) => QueryConfigVar::FixedSize { - max_batch_size: decode_u16_item_for_version(*version, bytes)?, - }, + (.., QUERY_TYPE_FIXED_SIZE) => { + taskprov_decode_u16_prefixed(*version, bytes, |_version, inner| { + Ok(QueryConfigVar::FixedSize { + max_batch_size: u32::decode(inner)?, + }) + })? + } (DapVersion::DraftLatest, ..) => QueryConfigVar::NotImplemented { typ: query_type, param: decode_u16_bytes(bytes)?, @@ -318,6 +369,30 @@ impl ParameterizedDecode for TaskConfig { } } +fn taskprov_encode_u16_prefixed( + version: DapVersion, + bytes: &mut Vec, + e: impl Fn(DapVersion, &mut Vec), +) { + match version { + DapVersion::DraftLatest => encode_u16_prefixed(version, bytes, e), + // draft02 compatibility: No length prefix is used. + DapVersion::Draft02 => e(version, bytes), + } +} + +fn taskprov_decode_u16_prefixed( + version: DapVersion, + bytes: &mut Cursor<&[u8]>, + d: impl Fn(DapVersion, &mut Cursor<&[u8]>) -> Result, +) -> Result { + match version { + DapVersion::DraftLatest => decode_u16_prefixed(version, bytes, d), + // draft02 compatibility: No length prefix is used. + DapVersion::Draft02 => d(version, bytes), + } +} + #[cfg(test)] mod tests { use crate::test_versions; @@ -443,7 +518,7 @@ mod tests { .is_err()); } - fn roundtrip_vdaf_config(version: DapVersion) { + fn roundtrip_vdaf_config_prio2(version: DapVersion) { let vdaf_config = VdafConfig { dp_config: DpConfig::None, var: VdafTypeVar::Prio2 { dimension: 1337 }, @@ -458,7 +533,31 @@ mod tests { ); } - test_versions! { roundtrip_vdaf_config } + test_versions! { roundtrip_vdaf_config_prio2 } + + fn roundtrip_vdaf_config_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( + version: DapVersion, + ) { + let vdaf_config = VdafConfig { + dp_config: DpConfig::None, + var: VdafTypeVar::Prio3SumVecField64MultiproofHmacSha256Aes128 { + bits: 23, + length: 1337, + chunk_length: 42, + num_proofs: 99, + }, + }; + assert_eq!( + VdafConfig::get_decoded_with_param( + &version, + &vdaf_config.get_encoded_with_param(&version) + ) + .unwrap(), + vdaf_config + ); + } + + test_versions! { roundtrip_vdaf_config_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128 } #[test] fn roundtrip_vdaf_config_not_implemented_draft09() { diff --git a/daphne/src/roles/mod.rs b/daphne/src/roles/mod.rs index fdba13383..78f5345cc 100644 --- a/daphne/src/roles/mod.rs +++ b/daphne/src/roles/mod.rs @@ -1752,7 +1752,11 @@ mod test { async_test_versions! { e2e_fixed_size } - async fn e2e_taskprov(version: DapVersion) { + async fn e2e_taskprov( + version: DapVersion, + vdaf_config: VdafConfig, + test_measurement: DapMeasurement, + ) { let t = Test::new(version); let (task_config, task_id, taskprov_advertisement, taskprov_report_extension_payload) = @@ -1760,6 +1764,7 @@ mod test { version, min_batch_size: 1, query: DapQueryConfig::FixedSize { max_batch_size: 2 }, + vdaf: vdaf_config, ..Default::default() } .to_config_with_taskprov( @@ -1792,7 +1797,7 @@ mod test { &hpke_config_list, t.now, &task_id, - DapMeasurement::U32Vec(vec![1; 10]), + test_measurement.clone(), vec![Extension::Taskprov { draft02_payload: match version { DapVersion::DraftLatest => None, @@ -1850,7 +1855,31 @@ mod test { }); } - async_test_versions! { e2e_taskprov } + async fn e2e_taskprov_prio2(version: DapVersion) { + e2e_taskprov( + version, + VdafConfig::Prio2 { dimension: 10 }, + DapMeasurement::U32Vec(vec![1; 10]), + ) + .await; + } + + async_test_versions! { e2e_taskprov_prio2 } + + #[tokio::test] + async fn e2e_taskprov_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128_draft09() { + e2e_taskprov( + DapVersion::DraftLatest, + VdafConfig::Prio3(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + bits: 1, + length: 10, + chunk_length: 2, + num_proofs: 4, + }), + DapMeasurement::U64Vec(vec![1; 10]), + ) + .await; + } fn early_metadata_checks(version: DapVersion) { let t = Test::new(version); diff --git a/daphne/src/taskprov.rs b/daphne/src/taskprov.rs index 1b1689d62..1ed99ea46 100644 --- a/daphne/src/taskprov.rs +++ b/daphne/src/taskprov.rs @@ -15,7 +15,7 @@ use crate::{ }, vdaf::VdafVerifyKey, DapAbort, DapError, DapQueryConfig, DapRequest, DapTaskConfig, DapTaskConfigMethod, DapVersion, - VdafConfig, + Prio3Config, VdafConfig, }; use prio::codec::ParameterizedDecode; use ring::{ @@ -230,6 +230,32 @@ impl VdafConfig { task_id: *task_id, })?, }), + ( + DapVersion::DraftLatest, + VdafTypeVar::Prio3SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + }, + ) => Ok(VdafConfig::Prio3( + Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + bits: bits.into(), + length: length.try_into().map_err(|_| DapAbort::InvalidTask { + detail: "length is larger than the system's word size".to_string(), + task_id: *task_id, + })?, + chunk_length: chunk_length.try_into().map_err(|_| DapAbort::InvalidTask { + detail: "chunk_length is larger than the system's word size".to_string(), + task_id: *task_id, + })?, + num_proofs, + }, + )), + (DapVersion::Draft02, var) => Err(DapAbort::InvalidTask { + detail: format!("draft02: unsupported VDAF: {var:?}"), + task_id: *task_id, + }), (.., VdafTypeVar::NotImplemented { typ, .. }) => Err(DapAbort::InvalidTask { detail: format!("unimplemented VDAF type ({typ})"), task_id: *task_id, @@ -312,12 +338,30 @@ impl TryFrom<&VdafConfig> for messages::taskprov::VdafTypeVar { fn try_from(vdaf_config: &VdafConfig) -> Result { match vdaf_config { VdafConfig::Prio2 { dimension } => Ok(Self::Prio2 { - dimension: (*dimension) - .try_into() - .map_err(|_| fatal_error!(err = "Prio2 dimension is too large for taskprov"))?, + dimension: (*dimension).try_into().map_err(|_| { + fatal_error!(err = "{vdaf_config}: dimension is too large for taskprov") + })?, + }), + VdafConfig::Prio3(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 { + bits, + length, + chunk_length, + num_proofs, + }) => Ok(Self::Prio3SumVecField64MultiproofHmacSha256Aes128 { + bits: (*bits).try_into().map_err(|_| { + fatal_error!(err = format!("{vdaf_config}: bits is too large for taskprov")) + })?, + length: (*length).try_into().map_err(|_| { + fatal_error!(err = format!("{vdaf_config}: bits is too large for taskprov")) + })?, + + chunk_length: (*chunk_length).try_into().map_err(|_| { + fatal_error!(err = format!("{vdaf_config}: bits is too large for taskprov")) + })?, + num_proofs: *num_proofs, }), - VdafConfig::Prio3 { .. } => Err(fatal_error!( - err = "Prio3 is not currently supported for taskprov" + VdafConfig::Prio3(..) => Err(fatal_error!( + err = format!("{vdaf_config} is not currently supported for taskprov") )), } } diff --git a/daphne/src/vdaf/prio3.rs b/daphne/src/vdaf/prio3.rs index 5a5baba5b..0f85e934f 100644 --- a/daphne/src/vdaf/prio3.rs +++ b/daphne/src/vdaf/prio3.rs @@ -4,6 +4,7 @@ //! Parameters for the [Prio3 VDAF](https://datatracker.ietf.org/doc/draft-patton-cfrg-vdaf/). use crate::{ + messages::taskprov::VDAF_TYPE_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMAC_SHA256_AES128, vdaf::{xof::XofHmacSha256Aes128, VdafError, VdafVerifyKey}, DapAggregateResult, DapMeasurement, Prio3Config, VdafAggregateShare, VdafPrepMessage, VdafPrepState, @@ -41,7 +42,7 @@ fn new_prio3_sum_vec_field64_multiproof_hmac_sha256_aes128( Ok(Prio3::new( 2, num_proofs, - 0xFFFF_1003, + VDAF_TYPE_PRIO3_SUM_VEC_FIELD64_MULTIPROOF_HMAC_SHA256_AES128, SumVec::new(bits, length, chunk_length).map_err(vdaf::VdafError::from)?, )?) }