diff --git a/benches/cycle_counts.rs b/benches/cycle_counts.rs index 5ab704cd4..2f3ede7a8 100644 --- a/benches/cycle_counts.rs +++ b/benches/cycle_counts.rs @@ -47,7 +47,10 @@ fn prio2_client(size: usize) -> Vec> { let prio2 = Prio2::new(size).unwrap(); let input = vec![0u32; size]; let nonce = [0; 16]; - prio2.shard(&black_box(input), &black_box(nonce)).unwrap().1 + prio2 + .shard(b"", &black_box(input), &black_box(nonce)) + .unwrap() + .1 } #[cfg(feature = "experimental")] @@ -70,9 +73,19 @@ fn prio2_shard_and_prepare(size: usize) -> Prio2PrepareShare { let prio2 = Prio2::new(size).unwrap(); let input = vec![0u32; size]; let nonce = [0; 16]; - let (public_share, input_shares) = prio2.shard(&black_box(input), &black_box(nonce)).unwrap(); + let (public_share, input_shares) = prio2 + .shard(b"", &black_box(input), &black_box(nonce)) + .unwrap(); prio2 - .prepare_init(&[0; 32], 0, &(), &nonce, &public_share, &input_shares[0]) + .prepare_init( + &[0; 32], + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) .unwrap() .1 } @@ -97,7 +110,7 @@ fn prio3_client_count() -> Vec> { let measurement = true; let nonce = [0; 16]; prio3 - .shard(&black_box(measurement), &black_box(nonce)) + .shard(b"", &black_box(measurement), &black_box(nonce)) .unwrap() .1 } @@ -107,7 +120,7 @@ fn prio3_client_histogram_10() -> Vec> { let measurement = 9; let nonce = [0; 16]; prio3 - .shard(&black_box(measurement), &black_box(nonce)) + .shard(b"", &black_box(measurement), &black_box(nonce)) .unwrap() .1 } @@ -117,7 +130,7 @@ fn prio3_client_sum_32() -> Vec> { let measurement = 1337; let nonce = [0; 16]; prio3 - .shard(&black_box(measurement), &black_box(nonce)) + .shard(b"", &black_box(measurement), &black_box(nonce)) .unwrap() .1 } @@ -128,7 +141,7 @@ fn prio3_client_count_vec_1000() -> Vec> { let measurement = vec![0; len]; let nonce = [0; 16]; prio3 - .shard(&black_box(measurement), &black_box(nonce)) + .shard(b"", &black_box(measurement), &black_box(nonce)) .unwrap() .1 } @@ -140,7 +153,7 @@ fn prio3_client_count_vec_multithreaded_1000() -> Vec>(); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -145,10 +145,18 @@ fn prio2(c: &mut Criterion) { .collect::>(); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 32]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap(); + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap(); }); }, ); @@ -164,7 +172,7 @@ fn prio3(c: &mut Criterion) { let vdaf = Prio3::new_count(num_shares).unwrap(); let measurement = black_box(true); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }); c.bench_function("prio3count_prepare_init", |b| { @@ -172,10 +180,18 @@ fn prio3(c: &mut Criterion) { let measurement = black_box(true); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap() + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() }); }); @@ -185,7 +201,7 @@ fn prio3(c: &mut Criterion) { let vdaf = Prio3::new_sum(num_shares, *bits).unwrap(); let measurement = (1 << bits) - 1; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }); } group.finish(); @@ -197,10 +213,18 @@ fn prio3(c: &mut Criterion) { let measurement = (1 << bits) - 1; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap() + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() }); }); } @@ -217,7 +241,7 @@ fn prio3(c: &mut Criterion) { .map(|i| i & 1) .collect::>(); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -240,7 +264,7 @@ fn prio3(c: &mut Criterion) { .map(|i| i & 1) .collect::>(); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -259,10 +283,18 @@ fn prio3(c: &mut Criterion) { .collect::>(); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap() + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() }); }, ); @@ -287,10 +319,12 @@ fn prio3(c: &mut Criterion) { .collect::>(); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -323,7 +357,7 @@ fn prio3(c: &mut Criterion) { let vdaf = Prio3::new_histogram(num_shares, *input_length, *chunk_length).unwrap(); let measurement = black_box(0); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -352,7 +386,7 @@ fn prio3(c: &mut Criterion) { .unwrap(); let measurement = black_box(0); let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -378,10 +412,18 @@ fn prio3(c: &mut Criterion) { let measurement = black_box(0); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { - vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) - .unwrap() + vdaf.prepare_init( + &verify_key, + b"", + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() }); }, ); @@ -412,10 +454,12 @@ fn prio3(c: &mut Criterion) { let measurement = black_box(0); let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -448,7 +492,7 @@ fn prio3(c: &mut Criterion) { let mut measurement = vec![FP16_ZERO; *dimension]; measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -468,7 +512,7 @@ fn prio3(c: &mut Criterion) { let mut measurement = vec![FP16_ZERO; *dimension]; measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -487,10 +531,12 @@ fn prio3(c: &mut Criterion) { measurement[0] = FP16_HALF; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -520,10 +566,11 @@ fn prio3(c: &mut Criterion) { let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = - vdaf.shard(&measurement, &nonce).unwrap(); + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -549,7 +596,7 @@ fn prio3(c: &mut Criterion) { let mut measurement = vec![FP32_ZERO; *dimension]; measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -569,7 +616,7 @@ fn prio3(c: &mut Criterion) { let mut measurement = vec![FP32_ZERO; *dimension]; measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); - b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + b.iter(|| vdaf.shard(b"", &measurement, &nonce).unwrap()); }, ); } @@ -588,10 +635,12 @@ fn prio3(c: &mut Criterion) { measurement[0] = FP32_HALF; let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -621,10 +670,11 @@ fn prio3(c: &mut Criterion) { let nonce = black_box([0u8; 16]); let verify_key = black_box([0u8; 16]); let (public_share, input_shares) = - vdaf.shard(&measurement, &nonce).unwrap(); + vdaf.shard(b"", &measurement, &nonce).unwrap(); b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &(), &nonce, @@ -724,7 +774,7 @@ fn poplar1(c: &mut Criterion) { let measurement = IdpfInput::from_bools(&bits); b.iter(|| { - vdaf.shard(&measurement, &nonce).unwrap(); + vdaf.shard(b"", &measurement, &nonce).unwrap(); }); }); } @@ -753,7 +803,7 @@ fn poplar1(c: &mut Criterion) { // We are benchmarking preparation of a single report. For this test, it doesn't matter // which measurement we generate a report for, so pick the first measurement // arbitrarily. - let (public_share, input_shares) = vdaf.shard(&measurements[0], &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(b"", &measurements[0], &nonce).unwrap(); let input_share = input_shares.into_iter().next().unwrap(); // For the aggregation paramter, we use the candidate prefixes from the prefix tree for @@ -765,6 +815,7 @@ fn poplar1(c: &mut Criterion) { b.iter(|| { vdaf.prepare_init( &verify_key, + b"", 0, &agg_param, &nonce, diff --git a/binaries/src/bin/vdaf_message_sizes.rs b/binaries/src/bin/vdaf_message_sizes.rs index dae75685a..998f15722 100644 --- a/binaries/src/bin/vdaf_message_sizes.rs +++ b/binaries/src/bin/vdaf_message_sizes.rs @@ -15,6 +15,9 @@ use prio::{ }, }; +const PRIO2_CTX_STR: &[u8] = b"prio2 ctx"; +const PRIO3_CTX_STR: &[u8] = b"prio3 ctx"; + fn main() { let num_shares = 2; let nonce = [0; 16]; @@ -23,7 +26,9 @@ fn main() { let measurement = true; println!( "prio3 count share size = {}", - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); let length = 10; @@ -32,7 +37,9 @@ fn main() { println!( "prio3 histogram ({} buckets) share size = {}", length, - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); let bits = 32; @@ -41,7 +48,9 @@ fn main() { println!( "prio3 sum ({} bits) share size = {}", bits, - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); let len = 1000; @@ -50,7 +59,9 @@ fn main() { println!( "prio3 sumvec ({} len) share size = {}", len, - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); let len = 1000; @@ -61,7 +72,7 @@ fn main() { "prio3 fixedpoint16 boundedl2 vec ({} entries) size = {}", len, vdaf_input_share_size::>, 16>( - prio3.shard(&measurement, &nonce).unwrap() + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() ) ); @@ -74,7 +85,9 @@ fn main() { println!( "prio2 ({} entries) size = {}", size, - vdaf_input_share_size::(prio2.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio2.shard(PRIO2_CTX_STR, &measurement, &nonce).unwrap() + ) ); // Prio3 @@ -83,7 +96,9 @@ fn main() { println!( "prio3 sumvec ({} entries) size = {}", size, - vdaf_input_share_size::(prio3.shard(&measurement, &nonce).unwrap()) + vdaf_input_share_size::( + prio3.shard(PRIO3_CTX_STR, &measurement, &nonce).unwrap() + ) ); } } diff --git a/src/topology/ping_pong.rs b/src/topology/ping_pong.rs index 646f18186..b3de2fe5d 100644 --- a/src/topology/ping_pong.rs +++ b/src/topology/ping_pong.rs @@ -206,6 +206,7 @@ impl< #[allow(clippy::type_complexity)] pub fn evaluate( &self, + ctx: &[u8], vdaf: &A, ) -> Result< ( @@ -220,6 +221,7 @@ impl< .map_err(PingPongError::CodecPrepMessage)?; vdaf.prepare_next( + ctx, self.previous_prepare_state.clone(), self.current_prepare_message.clone(), ) @@ -362,6 +364,7 @@ pub trait PingPongTopology Result<(Self::State, PingPongMessage), PingPongError> { self.prepare_init( verify_key, + ctx, /* Leader */ 0, agg_param, nonce, @@ -522,6 +532,7 @@ where fn helper_initialized( &self, verify_key: &[u8; VERIFY_KEY_SIZE], + ctx: &[u8], agg_param: &Self::AggregationParam, nonce: &[u8; NONCE_SIZE], public_share: &Self::PublicShare, @@ -531,6 +542,7 @@ where let (prep_state, prep_share) = self .prepare_init( verify_key, + ctx, /* Helper */ 1, agg_param, nonce, @@ -550,7 +562,7 @@ where }; let current_prepare_message = self - .prepare_shares_to_prepare_message(agg_param, [inbound_prep_share, prep_share]) + .prepare_shares_to_prepare_message(ctx, agg_param, [inbound_prep_share, prep_share]) .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?; Ok(PingPongTransition { @@ -561,20 +573,22 @@ where fn leader_continued( &self, + ctx: &[u8], leader_state: Self::State, agg_param: &Self::AggregationParam, inbound: &PingPongMessage, ) -> Result { - self.continued(true, leader_state, agg_param, inbound) + self.continued(ctx, true, leader_state, agg_param, inbound) } fn helper_continued( &self, + ctx: &[u8], helper_state: Self::State, agg_param: &Self::AggregationParam, inbound: &PingPongMessage, ) -> Result { - self.continued(false, helper_state, agg_param, inbound) + self.continued(ctx, false, helper_state, agg_param, inbound) } } @@ -585,6 +599,7 @@ where { fn continued( &self, + ctx: &[u8], is_leader: bool, host_state: Self::State, agg_param: &Self::AggregationParam, @@ -616,7 +631,7 @@ where let prep_msg = Self::PrepareMessage::get_decoded_with_param(&host_prep_state, prep_msg) .map_err(PingPongError::CodecPrepMessage)?; let host_prep_transition = self - .prepare_next(host_prep_state, prep_msg) + .prepare_next(ctx, host_prep_state, prep_msg) .map_err(PingPongError::VdafPrepareNext)?; match (host_prep_transition, next_peer_prep_share) { @@ -634,7 +649,7 @@ where prep_shares.reverse(); } let current_prepare_message = self - .prepare_shares_to_prepare_message(agg_param, prep_shares) + .prepare_shares_to_prepare_message(ctx, agg_param, prep_shares) .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?; Ok(PingPongContinuedValue::WithMessage { @@ -667,6 +682,8 @@ mod tests { use crate::vdaf::dummy; use assert_matches::assert_matches; + const CTX_STR: &[u8] = b"pingpong ctx"; + #[test] fn ping_pong_one_round() { let verify_key = []; @@ -683,6 +700,7 @@ mod tests { let (leader_state, leader_message) = leader .leader_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -694,6 +712,7 @@ mod tests { let (helper_state, helper_message) = helper .helper_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -701,14 +720,14 @@ mod tests { &leader_message, ) .unwrap() - .evaluate(&helper) + .evaluate(CTX_STR, &helper) .unwrap(); // 1 round VDAF: helper should finish immediately. assert_matches!(helper_state, PingPongState::Finished(_)); let leader_state = leader - .leader_continued(leader_state, &aggregation_param, &helper_message) + .leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message) .unwrap(); // 1 round VDAF: leader should finish when it gets helper message and emit no message. assert_matches!( @@ -733,6 +752,7 @@ mod tests { let (leader_state, leader_message) = leader .leader_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -744,6 +764,7 @@ mod tests { let (helper_state, helper_message) = helper .helper_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -751,26 +772,26 @@ mod tests { &leader_message, ) .unwrap() - .evaluate(&helper) + .evaluate(CTX_STR, &helper) .unwrap(); // 2 round VDAF, round 1: helper should continue. assert_matches!(helper_state, PingPongState::Continued(_)); let leader_state = leader - .leader_continued(leader_state, &aggregation_param, &helper_message) + .leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message) .unwrap(); // 2 round VDAF, round 1: leader should finish and emit a finish message. let leader_message = assert_matches!( leader_state, PingPongContinuedValue::WithMessage { transition } => { - let (state, message) = transition.evaluate(&leader).unwrap(); + let (state, message) = transition.evaluate(CTX_STR,&leader).unwrap(); assert_matches!(state, PingPongState::Finished(_)); message } ); let helper_state = helper - .helper_continued(helper_state, &aggregation_param, &leader_message) + .helper_continued(CTX_STR, helper_state, &aggregation_param, &leader_message) .unwrap(); // 2 round vdaf, round 1: helper should finish and emit no message. assert_matches!( @@ -795,6 +816,7 @@ mod tests { let (leader_state, leader_message) = leader .leader_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -806,6 +828,7 @@ mod tests { let (helper_state, helper_message) = helper .helper_initialized( &verify_key, + CTX_STR, &aggregation_param, &nonce, &public_share, @@ -813,38 +836,38 @@ mod tests { &leader_message, ) .unwrap() - .evaluate(&helper) + .evaluate(CTX_STR, &helper) .unwrap(); // 3 round VDAF, round 1: helper should continue. assert_matches!(helper_state, PingPongState::Continued(_)); let leader_state = leader - .leader_continued(leader_state, &aggregation_param, &helper_message) + .leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message) .unwrap(); // 3 round VDAF, round 1: leader should continue and emit a continue message. let (leader_state, leader_message) = assert_matches!( leader_state, PingPongContinuedValue::WithMessage { transition } => { - let (state, message) = transition.evaluate(&leader).unwrap(); + let (state, message) = transition.evaluate(CTX_STR,&leader).unwrap(); assert_matches!(state, PingPongState::Continued(_)); (state, message) } ); let helper_state = helper - .helper_continued(helper_state, &aggregation_param, &leader_message) + .helper_continued(CTX_STR, helper_state, &aggregation_param, &leader_message) .unwrap(); // 3 round vdaf, round 2: helper should finish and emit a finish message. let helper_message = assert_matches!( helper_state, PingPongContinuedValue::WithMessage { transition } => { - let (state, message) = transition.evaluate(&helper).unwrap(); + let (state, message) = transition.evaluate(CTX_STR,&helper).unwrap(); assert_matches!(state, PingPongState::Finished(_)); message } ); let leader_state = leader - .leader_continued(leader_state, &aggregation_param, &helper_message) + .leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message) .unwrap(); // 3 round VDAF, round 2: leader should finish and emit no message. assert_matches!( diff --git a/src/vdaf.rs b/src/vdaf.rs index 2c68f2e40..815836430 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -200,12 +200,16 @@ pub trait Vdaf: Clone + Debug { /// Generate the domain separation tag for this VDAF. The output is used for domain separation /// by the XOF. - fn domain_separation_tag(&self, usage: u16) -> [u8; 8] { - let mut dst = [0_u8; 8]; - dst[0] = VERSION; - dst[1] = 0; // algorithm class - dst[2..6].copy_from_slice(&(self.algorithm_id()).to_be_bytes()); - dst[6..8].copy_from_slice(&usage.to_be_bytes()); + fn domain_separation_tag(&self, usage: u16, ctx: &[u8]) -> Vec { + // Prefix is 8 bytes and defined by the spec. Copy these values in + let mut dst = Vec::with_capacity(ctx.len() + 8); + dst.push(VERSION); + dst.push(0); // algorithm class + dst.extend_from_slice(self.algorithm_id().to_be_bytes().as_slice()); + dst.extend_from_slice(usage.to_be_bytes().as_slice()); + // Finally, append user-chosen `ctx` + dst.extend_from_slice(ctx); + dst } } @@ -217,9 +221,10 @@ pub trait Client: Vdaf { /// /// Implements `Vdaf::shard` from [VDAF]. /// - /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.1 + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-13#section-5.1 fn shard( &self, + ctx: &[u8], measurement: &Self::Measurement, nonce: &[u8; NONCE_SIZE], ) -> Result<(Self::PublicShare, Vec), VdafError>; @@ -254,9 +259,11 @@ pub trait Aggregator: Vda /// Implements `Vdaf.prep_init` from [VDAF]. /// /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 + #[allow(clippy::too_many_arguments)] fn prepare_init( &self, verify_key: &[u8; VERIFY_KEY_SIZE], + ctx: &[u8], agg_id: usize, agg_param: &Self::AggregationParam, nonce: &[u8; NONCE_SIZE], @@ -271,6 +278,7 @@ pub trait Aggregator: Vda /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 fn prepare_shares_to_prepare_message>( &self, + ctx: &[u8], agg_param: &Self::AggregationParam, inputs: M, ) -> Result; @@ -288,6 +296,7 @@ pub trait Aggregator: Vda /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 fn prepare_next( &self, + ctx: &[u8], state: Self::PrepareState, input: Self::PrepareMessage, ) -> Result, VdafError>; @@ -489,6 +498,7 @@ pub mod test_utils { /// Execute the VDAF end-to-end and return the aggregate result. pub fn run_vdaf( + ctx: &[u8], vdaf: &V, agg_param: &V::AggregationParam, measurements: M, @@ -500,16 +510,17 @@ pub mod test_utils { let mut sharded_measurements = Vec::new(); for measurement in measurements.into_iter() { let nonce = random(); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce)?; + let (public_share, input_shares) = vdaf.shard(ctx, &measurement, &nonce)?; sharded_measurements.push((public_share, nonce, input_shares)); } - run_vdaf_sharded(vdaf, agg_param, sharded_measurements) + run_vdaf_sharded(ctx, vdaf, agg_param, sharded_measurements) } /// Execute the VDAF on sharded measurements and return the aggregate result. pub fn run_vdaf_sharded( + ctx: &[u8], vdaf: &V, agg_param: &V::AggregationParam, sharded_measurements: M, @@ -530,6 +541,7 @@ pub mod test_utils { let out_shares = run_vdaf_prepare( vdaf, &verify_key, + ctx, agg_param, &nonce, public_share, @@ -579,6 +591,7 @@ pub mod test_utils { pub fn run_vdaf_prepare( vdaf: &V, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], agg_param: &V::AggregationParam, nonce: &[u8; 16], public_share: V::PublicShare, @@ -600,6 +613,7 @@ pub mod test_utils { for (agg_id, input_share) in input_shares.enumerate() { let (state, msg) = vdaf.prepare_init( verify_key, + ctx, agg_id, agg_param, nonce, @@ -613,6 +627,7 @@ pub mod test_utils { let mut inbound = vdaf .prepare_shares_to_prepare_message( + ctx, agg_param, outbound.iter().map(|encoded| { V::PrepareShare::get_decoded_with_param(&states[0], encoded) @@ -627,6 +642,7 @@ pub mod test_utils { let mut outbound = Vec::new(); for state in states.iter_mut() { match vdaf.prepare_next( + ctx, state.clone(), V::PrepareMessage::get_decoded_with_param(state, &inbound) .expect("failed to decode prep message"), @@ -645,6 +661,7 @@ pub mod test_utils { // Another round is required before output shares are computed. inbound = vdaf .prepare_shares_to_prepare_message( + ctx, agg_param, outbound.iter().map(|encoded| { V::PrepareShare::get_decoded_with_param(&states[0], encoded) diff --git a/src/vdaf/dummy.rs b/src/vdaf/dummy.rs index 5b969bc19..1a78e3ee7 100644 --- a/src/vdaf/dummy.rs +++ b/src/vdaf/dummy.rs @@ -123,6 +123,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_init( &self, _verify_key: &[u8; 0], + _ctx: &[u8], _: usize, aggregation_param: &Self::AggregationParam, _nonce: &[u8; 16], @@ -141,6 +142,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_shares_to_prepare_message>( &self, + _ctx: &[u8], _: &Self::AggregationParam, _: M, ) -> Result { @@ -149,6 +151,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { fn prepare_next( &self, + _ctx: &[u8], state: Self::PrepareState, _: Self::PrepareMessage, ) -> Result, VdafError> { @@ -175,6 +178,7 @@ impl vdaf::Aggregator<0, 16> for Vdaf { impl vdaf::Client<16> for Vdaf { fn shard( &self, + _ctx: &[u8], measurement: &Self::Measurement, _nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec), VdafError> { @@ -361,12 +365,14 @@ mod tests { let mut sharded_measurements = Vec::new(); for measurement in measurements { let nonce = thread_rng().gen(); - let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + let (public_share, input_shares) = + vdaf.shard(b"dummy ctx", &measurement, &nonce).unwrap(); sharded_measurements.push((public_share, nonce, input_shares)); } let result = run_vdaf_sharded( + b"dummy ctx", &vdaf, &AggregationParam(aggregation_parameter), sharded_measurements.clone(), diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index 7b8d63424..afbac9331 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -341,6 +341,7 @@ where { fn shard( &self, + _ctx: &[u8], (attribute, weight): &(VidpfInput, T::Measurement), nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec), VdafError> { @@ -388,6 +389,7 @@ mod tests { use rand::{thread_rng, Rng}; const TEST_NONCE_SIZE: usize = 16; + const CTX_STR: &[u8] = b"mastic ctx"; #[test] fn test_mastic_shard_sum() { @@ -404,7 +406,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, 24u128), &nonce).unwrap(); + let (_public, _input_shares) = mastic + .shard(CTX_STR, &(first_input, 24u128), &nonce) + .unwrap(); } #[test] @@ -422,7 +426,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); - let (_, input_shares) = mastic.shard(&(first_input, 26u128), &nonce).unwrap(); + let (_, input_shares) = mastic + .shard(CTX_STR, &(first_input, 26u128), &nonce) + .unwrap(); let [leader_input_share, helper_input_share] = [&input_shares[0], &input_shares[1]]; assert_eq!( @@ -450,7 +456,7 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, true), &nonce).unwrap(); + let (_public, _input_shares) = mastic.shard(CTX_STR, &(first_input, true), &nonce).unwrap(); } #[test] @@ -470,7 +476,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let (_public, _input_shares) = mastic + .shard(CTX_STR, &(first_input, measurement), &nonce) + .unwrap(); } #[test] @@ -490,7 +498,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let (_public, input_shares) = mastic + .shard(CTX_STR, &(first_input, measurement), &nonce) + .unwrap(); let leader_input_share = &input_shares[0]; let helper_input_share = &input_shares[1]; @@ -521,7 +531,9 @@ mod tests { let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let (_public, input_shares) = mastic + .shard(CTX_STR, &(first_input, measurement), &nonce) + .unwrap(); let leader_input_share = &input_shares[0]; let helper_input_share = &input_shares[1]; diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 1d396e4da..71bae8cc9 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -69,6 +69,7 @@ impl, const SEED_SIZE: usize> Poplar1 { &self, seed: &[u8; SEED_SIZE], usage: u16, + ctx: &[u8], binder_chunks: I, ) -> Prng where @@ -77,7 +78,7 @@ impl, const SEED_SIZE: usize> Poplar1 { P: Xof, F: FieldElement, { - let mut xof = P::init(seed, &self.domain_separation_tag(usage)); + let mut xof = P::init(seed, &self.domain_separation_tag(usage, ctx)); for binder_chunk in binder_chunks.into_iter() { xof.update(binder_chunk.as_ref()); } @@ -865,6 +866,7 @@ impl, const SEED_SIZE: usize> Vdaf for Poplar1 { impl, const SEED_SIZE: usize> Poplar1 { fn shard_with_random( &self, + ctx: &[u8], input: &IdpfInput, nonce: &[u8; 16], idpf_random: &[[u8; 16]; 2], @@ -879,7 +881,7 @@ impl, const SEED_SIZE: usize> Poplar1 { // Generate the authenticator for each inner level of the IDPF tree. let mut prng = - self.init_prng::<_, _, Field64>(&poplar_random[2], DST_SHARD_RANDOMNESS, [nonce]); + self.init_prng::<_, _, Field64>(&poplar_random[2], DST_SHARD_RANDOMNESS, ctx, [nonce]); let auth_inner: Vec = (0..self.bits - 1).map(|_| prng.get()).collect(); // Generate the authenticator for the last level of the IDPF tree (i.e., the leaves). @@ -912,11 +914,13 @@ impl, const SEED_SIZE: usize> Poplar1 { let mut corr_prng_0 = self.init_prng::<_, _, Field64>( corr_seed_0, DST_CORR_INNER, + ctx, [[0].as_slice(), nonce.as_slice()], ); let mut corr_prng_1 = self.init_prng::<_, _, Field64>( corr_seed_1, DST_CORR_INNER, + ctx, [[1].as_slice(), nonce.as_slice()], ); let mut corr_inner_0 = Vec::with_capacity(self.bits - 1); @@ -933,11 +937,13 @@ impl, const SEED_SIZE: usize> Poplar1 { let mut corr_prng_0 = self.init_prng::<_, _, Field255>( corr_seed_0, DST_CORR_LEAF, + ctx, [[0].as_slice(), nonce.as_slice()], ); let mut corr_prng_1 = self.init_prng::<_, _, Field255>( corr_seed_1, DST_CORR_LEAF, + ctx, [[1].as_slice(), nonce.as_slice()], ); let (corr_leaf_0, corr_leaf_1) = @@ -967,6 +973,7 @@ impl, const SEED_SIZE: usize> Poplar1 { fn eval_and_sketch( &self, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], agg_id: usize, nonce: &[u8; 16], agg_param: &Poplar1AggregationParam, @@ -983,6 +990,7 @@ impl, const SEED_SIZE: usize> Poplar1 { let mut verify_prng = self.init_prng( verify_key, DST_VERIFY_RANDOMNESS, + ctx, [nonce.as_slice(), agg_param.level.to_be_bytes().as_slice()], ); @@ -1020,6 +1028,7 @@ impl, const SEED_SIZE: usize> Poplar1 { impl, const SEED_SIZE: usize> Client<16> for Poplar1 { fn shard( &self, + ctx: &[u8], input: &IdpfInput, nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec>), VdafError> { @@ -1031,7 +1040,7 @@ impl, const SEED_SIZE: usize> Client<16> for Poplar1, const SEED_SIZE: usize> Aggregator fn prepare_init( &self, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], agg_id: usize, agg_param: &Poplar1AggregationParam, nonce: &[u8; 16], @@ -1066,6 +1076,7 @@ impl, const SEED_SIZE: usize> Aggregator let mut corr_prng = self.init_prng::<_, _, Field64>( input_share.corr_seed.as_ref(), DST_CORR_INNER, + ctx, [[agg_id as u8].as_slice(), nonce.as_slice()], ); // Fast-forward the correlated randomness XOF to the level of the tree that we are @@ -1076,6 +1087,7 @@ impl, const SEED_SIZE: usize> Aggregator let (output_share, sketch_share) = self.eval_and_sketch::( verify_key, + ctx, agg_id, nonce, agg_param, @@ -1099,11 +1111,13 @@ impl, const SEED_SIZE: usize> Aggregator let corr_prng = self.init_prng::<_, _, Field255>( input_share.corr_seed.as_ref(), DST_CORR_LEAF, + ctx, [[agg_id as u8].as_slice(), nonce.as_slice()], ); let (output_share, sketch_share) = self.eval_and_sketch::( verify_key, + ctx, agg_id, nonce, agg_param, @@ -1128,6 +1142,7 @@ impl, const SEED_SIZE: usize> Aggregator fn prepare_shares_to_prepare_message>( &self, + _ctx: &[u8], _: &Poplar1AggregationParam, inputs: M, ) -> Result { @@ -1167,6 +1182,7 @@ impl, const SEED_SIZE: usize> Aggregator fn prepare_next( &self, + _ctx: &[u8], state: Poplar1PrepareState, msg: Poplar1PrepareMessage, ) -> Result, VdafError> { @@ -1540,6 +1556,8 @@ mod tests { use serde::Deserialize; use std::collections::HashSet; + const CTX_STR: &[u8] = b"poplar1 ctx"; + fn test_prepare, const SEED_SIZE: usize>( vdaf: &Poplar1, verify_key: &[u8; SEED_SIZE], @@ -1552,6 +1570,7 @@ mod tests { let out_shares = run_vdaf_prepare( vdaf, verify_key, + CTX_STR, agg_param, nonce, public_share.clone(), @@ -1591,7 +1610,11 @@ mod tests { .map(|measurement| { let nonce = rng.gen(); let (public_share, input_shares) = vdaf - .shard(&IdpfInput::from_bytes(measurement.as_ref()), &nonce) + .shard( + CTX_STR, + &IdpfInput::from_bytes(measurement.as_ref()), + &nonce, + ) .unwrap(); (nonce, public_share, input_shares) }) @@ -1615,6 +1638,7 @@ mod tests { let out_shares = run_vdaf_prepare( vdaf, verify_key, + CTX_STR, &agg_param, nonce, public_share.clone(), @@ -1675,7 +1699,7 @@ mod tests { let verify_key = rng.gen(); let input = IdpfInput::from_bytes(b"12341324"); let nonce = rng.gen(); - let (public_share, input_shares) = vdaf.shard(&input, &nonce).unwrap(); + let (public_share, input_shares) = vdaf.shard(CTX_STR, &input, &nonce).unwrap(); test_prepare( &vdaf, @@ -2096,6 +2120,10 @@ mod tests { } fn check_test_vec(input: &str) { + // We need to use an empty context string for these test vectors to pass. + // TODO: update test vectors to ones that use a real context string + const CTX_STR: &[u8] = b""; + let test_vector: PoplarTestVector = serde_json::from_str(input).unwrap(); assert_eq!(test_vector.prep.len(), 1); let prep = &test_vector.prep[0]; @@ -2133,13 +2161,14 @@ mod tests { // Shard measurement. let poplar = Poplar1::new_turboshake128(test_vector.bits); let (public_share, input_shares) = poplar - .shard_with_random(&measurement, &nonce, &idpf_random, &poplar_random) + .shard_with_random(CTX_STR, &measurement, &nonce, &idpf_random, &poplar_random) .unwrap(); // Run aggregation. let (init_prep_state_0, init_prep_share_0) = poplar .prepare_init( &verify_key, + CTX_STR, 0, &agg_param, &nonce, @@ -2150,6 +2179,7 @@ mod tests { let (init_prep_state_1, init_prep_share_1) = poplar .prepare_init( &verify_key, + CTX_STR, 1, &agg_param, &nonce, @@ -2160,6 +2190,7 @@ mod tests { let r1_prep_msg = poplar .prepare_shares_to_prepare_message( + CTX_STR, &agg_param, [init_prep_share_0.clone(), init_prep_share_1.clone()], ) @@ -2167,19 +2198,20 @@ mod tests { let (r1_prep_state_0, r1_prep_share_0) = assert_matches!( poplar - .prepare_next(init_prep_state_0.clone(), r1_prep_msg.clone()) + .prepare_next(CTX_STR,init_prep_state_0.clone(), r1_prep_msg.clone()) .unwrap(), PrepareTransition::Continue(state, share) => (state, share) ); let (r1_prep_state_1, r1_prep_share_1) = assert_matches!( poplar - .prepare_next(init_prep_state_1.clone(), r1_prep_msg.clone()) + .prepare_next(CTX_STR,init_prep_state_1.clone(), r1_prep_msg.clone()) .unwrap(), PrepareTransition::Continue(state, share) => (state, share) ); let r2_prep_msg = poplar .prepare_shares_to_prepare_message( + CTX_STR, &agg_param, [r1_prep_share_0.clone(), r1_prep_share_1.clone()], ) @@ -2187,13 +2219,13 @@ mod tests { let out_share_0 = assert_matches!( poplar - .prepare_next(r1_prep_state_0.clone(), r2_prep_msg.clone()) + .prepare_next(CTX_STR, r1_prep_state_0.clone(), r2_prep_msg.clone()) .unwrap(), PrepareTransition::Finish(out) => out ); let out_share_1 = assert_matches!( poplar - .prepare_next(r1_prep_state_1, r2_prep_msg.clone()) + .prepare_next(CTX_STR,r1_prep_state_1, r2_prep_msg.clone()) .unwrap(), PrepareTransition::Finish(out) => out ); diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index 680f09ea7..96a8f5a3a 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -143,6 +143,7 @@ impl Vdaf for Prio2 { impl Client<16> for Prio2 { fn shard( &self, + _ctx: &[u8], measurement: &Vec, _nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec>), VdafError> { @@ -253,6 +254,7 @@ impl Aggregator<32, 16> for Prio2 { fn prepare_init( &self, agg_key: &[u8; 32], + _ctx: &[u8], agg_id: usize, _agg_param: &Self::AggregationParam, nonce: &[u8; 16], @@ -278,6 +280,7 @@ impl Aggregator<32, 16> for Prio2 { fn prepare_shares_to_prepare_message>( &self, + _ctx: &[u8], _: &Self::AggregationParam, inputs: M, ) -> Result<(), VdafError> { @@ -300,6 +303,7 @@ impl Aggregator<32, 16> for Prio2 { fn prepare_next( &self, + _ctx: &[u8], state: Prio2PrepareState, _input: (), ) -> Result, VdafError> { @@ -406,12 +410,17 @@ mod tests { use assert_matches::assert_matches; use rand::prelude::*; + // The value of this string doesn't matter. Prio2 is not defined to use the context string for + // any computation + pub(crate) const CTX_STR: &[u8] = b"prio2 ctx"; + #[test] fn run_prio2() { let prio2 = Prio2::new(6).unwrap(); assert_eq!( run_vdaf( + CTX_STR, &prio2, &(), [ @@ -434,11 +443,12 @@ mod tests { let nonce = rng.gen::<[u8; 16]>(); let data = vec![0, 0, 1, 1, 0]; let prio2 = Prio2::new(data.len()).unwrap(); - let (public_share, input_shares) = prio2.shard(&data, &nonce).unwrap(); + let (public_share, input_shares) = prio2.shard(CTX_STR, &data, &nonce).unwrap(); for (agg_id, input_share) in input_shares.iter().enumerate() { let (prepare_state, prepare_share) = prio2 .prepare_init( &verify_key, + CTX_STR, agg_id, &(), &[0; 16], @@ -500,17 +510,21 @@ mod tests { let input_share_1 = Share::get_decoded_with_param(&(&vdaf, 0), server_1_share).unwrap(); let input_share_2 = Share::get_decoded_with_param(&(&vdaf, 1), server_2_share).unwrap(); let (prepare_state_1, prepare_share_1) = vdaf - .prepare_init(&[0; 32], 0, &(), &[0; 16], &(), &input_share_1) + .prepare_init(&[0; 32], CTX_STR, 0, &(), &[0; 16], &(), &input_share_1) .unwrap(); let (prepare_state_2, prepare_share_2) = vdaf - .prepare_init(&[0; 32], 1, &(), &[0; 16], &(), &input_share_2) - .unwrap(); - vdaf.prepare_shares_to_prepare_message(&(), [prepare_share_1, prepare_share_2]) + .prepare_init(&[0; 32], CTX_STR, 1, &(), &[0; 16], &(), &input_share_2) .unwrap(); - let transition_1 = vdaf.prepare_next(prepare_state_1, ()).unwrap(); + vdaf.prepare_shares_to_prepare_message( + CTX_STR, + &(), + [prepare_share_1, prepare_share_2], + ) + .unwrap(); + let transition_1 = vdaf.prepare_next(CTX_STR, prepare_state_1, ()).unwrap(); let output_share_1 = assert_matches!(transition_1, PrepareTransition::Finish(out) => out); - let transition_2 = vdaf.prepare_next(prepare_state_2, ()).unwrap(); + let transition_2 = vdaf.prepare_next(CTX_STR, prepare_state_2, ()).unwrap(); let output_share_2 = assert_matches!(transition_2, PrepareTransition::Finish(out) => out); leader_output_shares.push(output_share_1); diff --git a/src/vdaf/prio2/server.rs b/src/vdaf/prio2/server.rs index 6e457e51d..26aa4a621 100644 --- a/src/vdaf/prio2/server.rs +++ b/src/vdaf/prio2/server.rs @@ -205,6 +205,7 @@ mod tests { prio2::{ client::{proof_length, unpack_proof_mut}, server::test_util::Server, + tests::CTX_STR, Prio2, }, Client, Share, ShareDecodingParameter, @@ -285,7 +286,7 @@ mod tests { } let vdaf = Prio2::new(dim).unwrap(); - let (_, shares) = vdaf.shard(&data, &[0; 16]).unwrap(); + let (_, shares) = vdaf.shard(CTX_STR, &data, &[0; 16]).unwrap(); let share1_original = shares[0].get_encoded().unwrap(); let share2 = shares[1].get_encoded().unwrap(); diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index f0d482dea..3936730ec 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -379,6 +379,7 @@ impl Prio3Average { /// use rand::prelude::*; /// /// let num_shares = 2; +/// let ctx = b"my context str"; /// let vdaf = Prio3::new_count(num_shares).unwrap(); /// /// let mut out_shares = vec![vec![]; num_shares.into()]; @@ -388,7 +389,7 @@ impl Prio3Average { /// for measurement in measurements { /// // Shard /// let nonce = rng.gen::<[u8; 16]>(); -/// let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); +/// let (public_share, input_shares) = vdaf.shard(ctx, &measurement, &nonce).unwrap(); /// /// // Prepare /// let mut prep_states = vec![]; @@ -396,6 +397,7 @@ impl Prio3Average { /// for (agg_id, input_share) in input_shares.iter().enumerate() { /// let (state, share) = vdaf.prepare_init( /// &verify_key, +/// ctx, /// agg_id, /// &(), /// &nonce, @@ -405,10 +407,10 @@ impl Prio3Average { /// prep_states.push(state); /// prep_shares.push(share); /// } -/// let prep_msg = vdaf.prepare_shares_to_prepare_message(&(), prep_shares).unwrap(); +/// let prep_msg = vdaf.prepare_shares_to_prepare_message(ctx, &(), prep_shares).unwrap(); /// /// for (agg_id, state) in prep_states.into_iter().enumerate() { -/// let out_share = match vdaf.prepare_next(state, prep_msg.clone()).unwrap() { +/// let out_share = match vdaf.prepare_next(ctx, state, prep_msg.clone()).unwrap() { /// PrepareTransition::Finish(out_share) => out_share, /// _ => panic!("unexpected transition"), /// }; @@ -481,10 +483,10 @@ where self.num_proofs.into() } - fn derive_prove_rands(&self, prove_rand_seed: &Seed) -> Vec { + fn derive_prove_rands(&self, ctx: &[u8], prove_rand_seed: &Seed) -> Vec { P::seed_stream( prove_rand_seed, - &self.domain_separation_tag(DST_PROVE_RANDOMNESS), + &self.domain_separation_tag(DST_PROVE_RANDOMNESS, ctx), &[self.num_proofs], ) .into_field_vec(self.typ.prove_rand_len() * self.num_proofs()) @@ -492,11 +494,12 @@ where fn derive_joint_rand_seed<'a>( &self, + ctx: &[u8], joint_rand_parts: impl Iterator>, ) -> Seed { let mut xof = P::init( &[0; SEED_SIZE], - &self.domain_separation_tag(DST_JOINT_RAND_SEED), + &self.domain_separation_tag(DST_JOINT_RAND_SEED, ctx), ); for part in joint_rand_parts { xof.update(part.as_ref()); @@ -506,12 +509,13 @@ where fn derive_joint_rands<'a>( &self, + ctx: &[u8], joint_rand_parts: impl Iterator>, ) -> (Seed, Vec) { - let joint_rand_seed = self.derive_joint_rand_seed(joint_rand_parts); + let joint_rand_seed = self.derive_joint_rand_seed(ctx, joint_rand_parts); let joint_rands = P::seed_stream( &joint_rand_seed, - &self.domain_separation_tag(DST_JOINT_RANDOMNESS), + &self.domain_separation_tag(DST_JOINT_RANDOMNESS, ctx), &[self.num_proofs], ) .into_field_vec(self.typ.joint_rand_len() * self.num_proofs()); @@ -521,20 +525,26 @@ where fn derive_helper_proofs_share( &self, + ctx: &[u8], proofs_share_seed: &Seed, agg_id: u8, ) -> Prng { Prng::from_seed_stream(P::seed_stream( proofs_share_seed, - &self.domain_separation_tag(DST_PROOF_SHARE), + &self.domain_separation_tag(DST_PROOF_SHARE, ctx), &[self.num_proofs, agg_id], )) } - fn derive_query_rands(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec { + fn derive_query_rands( + &self, + verify_key: &[u8; SEED_SIZE], + ctx: &[u8], + nonce: &[u8; 16], + ) -> Vec { let mut xof = P::init( verify_key, - &self.domain_separation_tag(DST_QUERY_RANDOMNESS), + &self.domain_separation_tag(DST_QUERY_RANDOMNESS, ctx), ); xof.update(&[self.num_proofs]); xof.update(nonce); @@ -562,6 +572,7 @@ where #[allow(clippy::type_complexity)] pub(crate) fn shard_with_random( &self, + ctx: &[u8], measurement: &T::Measurement, nonce: &[u8; N], random: &[u8], @@ -598,7 +609,7 @@ where let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap(); let measurement_share_prng: Prng = Prng::from_seed_stream(P::seed_stream( &Seed(measurement_share_seed), - &self.domain_separation_tag(DST_MEASUREMENT_SHARE), + &self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx), &[agg_id], )); let joint_rand_blind = if let Some(helper_joint_rand_parts) = @@ -607,7 +618,7 @@ where let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap(); let mut joint_rand_part_xof = P::init( &joint_rand_blind, - &self.domain_separation_tag(DST_JOINT_RAND_PART), + &self.domain_separation_tag(DST_JOINT_RAND_PART, ctx), ); joint_rand_part_xof.update(&[agg_id]); // Aggregator ID joint_rand_part_xof.update(nonce); @@ -653,7 +664,7 @@ where let mut joint_rand_part_xof = P::init( leader_blind.as_ref(), - &self.domain_separation_tag(DST_JOINT_RAND_PART), + &self.domain_separation_tag(DST_JOINT_RAND_PART, ctx), ); joint_rand_part_xof.update(&[0]); // Aggregator ID joint_rand_part_xof.update(nonce); @@ -684,13 +695,14 @@ where let joint_rands = public_share .joint_rand_parts .as_ref() - .map(|joint_rand_parts| self.derive_joint_rands(joint_rand_parts.iter()).1) + .map(|joint_rand_parts| self.derive_joint_rands(ctx, joint_rand_parts.iter()).1) .unwrap_or_default(); // Generate the proofs. - let prove_rands = self.derive_prove_rands(&Seed::from_bytes( - random_seeds.next().unwrap().try_into().unwrap(), - )); + let prove_rands = self.derive_prove_rands( + ctx, + &Seed::from_bytes(random_seeds.next().unwrap().try_into().unwrap()), + ); let mut leader_proofs_share = Vec::with_capacity(self.typ.proof_len() * self.num_proofs()); for p in 0..self.num_proofs() { let prove_rand = @@ -707,14 +719,14 @@ where // Generate the proof shares and distribute the joint randomness seed hints. for (j, helper) in helper_shares.iter_mut().enumerate() { - for (x, y) in - leader_proofs_share - .iter_mut() - .zip(self.derive_helper_proofs_share( - &helper.proofs_share, - u8::try_from(j).unwrap() + 1, - )) - .take(self.typ.proof_len() * self.num_proofs()) + for (x, y) in leader_proofs_share + .iter_mut() + .zip(self.derive_helper_proofs_share( + ctx, + &helper.proofs_share, + u8::try_from(j).unwrap() + 1, + )) + .take(self.typ.proof_len() * self.num_proofs()) { *x -= y; } @@ -1083,12 +1095,13 @@ where #[allow(clippy::type_complexity)] fn shard( &self, + ctx: &[u8], measurement: &T::Measurement, nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec>), VdafError> { let mut random = vec![0u8; self.random_size()]; getrandom::getrandom(&mut random)?; - self.shard_with_random(measurement, nonce, &random) + self.shard_with_random(ctx, measurement, nonce, &random) } } @@ -1213,6 +1226,7 @@ where fn prepare_init( &self, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], agg_id: usize, _agg_param: &Self::AggregationParam, nonce: &[u8; 16], @@ -1232,7 +1246,7 @@ where Share::Helper(ref seed) => Cow::Owned( P::seed_stream( seed, - &self.domain_separation_tag(DST_MEASUREMENT_SHARE), + &self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx), &[agg_id], ) .into_field_vec(self.typ.input_len()), @@ -1242,7 +1256,7 @@ where let proofs_share = match msg.proofs_share { Share::Leader(ref data) => Cow::Borrowed(data), Share::Helper(ref seed) => Cow::Owned( - self.derive_helper_proofs_share(seed, agg_id) + self.derive_helper_proofs_share(ctx, seed, agg_id) .take(self.typ.proof_len() * self.num_proofs()) .collect::>(), ), @@ -1252,7 +1266,7 @@ where let (joint_rand_seed, joint_rand_part, joint_rands) = if self.typ.joint_rand_len() > 0 { let mut joint_rand_part_xof = P::init( msg.joint_rand_blind.as_ref().unwrap().as_ref(), - &self.domain_separation_tag(DST_JOINT_RAND_PART), + &self.domain_separation_tag(DST_JOINT_RAND_PART, ctx), ); joint_rand_part_xof.update(&[agg_id]); joint_rand_part_xof.update(nonce); @@ -1288,7 +1302,7 @@ where ); let (joint_rand_seed, joint_rands) = - self.derive_joint_rands(corrected_joint_rand_parts); + self.derive_joint_rands(ctx, corrected_joint_rand_parts); ( Some(joint_rand_seed), @@ -1300,7 +1314,7 @@ where }; // Run the query-generation algorithm. - let query_rands = self.derive_query_rands(verify_key, nonce); + let query_rands = self.derive_query_rands(verify_key, ctx, nonce); let mut verifiers_share = Vec::with_capacity(self.typ.verifier_len() * self.num_proofs()); for p in 0..self.num_proofs() { let query_rand = @@ -1337,6 +1351,7 @@ where M: IntoIterator>, >( &self, + ctx: &[u8], _: &Self::AggregationParam, inputs: M, ) -> Result, VdafError> { @@ -1381,7 +1396,7 @@ where } let joint_rand_seed = if self.typ.joint_rand_len() > 0 { - Some(self.derive_joint_rand_seed(joint_rand_parts.iter())) + Some(self.derive_joint_rand_seed(ctx, joint_rand_parts.iter())) } else { None }; @@ -1391,6 +1406,7 @@ where fn prepare_next( &self, + ctx: &[u8], step: Prio3PrepareState, msg: Prio3PrepareMessage, ) -> Result, VdafError> { @@ -1413,7 +1429,7 @@ where let measurement_share = match step.measurement_share { Share::Leader(data) => data, Share::Helper(seed) => { - let dst = self.domain_separation_tag(DST_MEASUREMENT_SHARE); + let dst = self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx); P::seed_stream(&seed, &dst, &[step.agg_id]).into_field_vec(self.typ.input_len()) } }; @@ -1627,12 +1643,14 @@ mod tests { }; use rand::prelude::*; + const CTX_STR: &[u8] = b"prio3 ctx"; + #[test] fn test_prio3_count() { let prio3 = Prio3::new_count(2).unwrap(); assert_eq!( - run_vdaf(&prio3, &(), [true, false, false, true, true]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [true, false, false, true, true]).unwrap(), 3 ); @@ -1641,17 +1659,41 @@ mod tests { thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); - let (public_share, input_shares) = prio3.shard(&false, &nonce).unwrap(); - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap(); + let (public_share, input_shares) = prio3.shard(CTX_STR, &false, &nonce).unwrap(); + run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ) + .unwrap(); - let (public_share, input_shares) = prio3.shard(&true, &nonce).unwrap(); - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap(); + let (public_share, input_shares) = prio3.shard(CTX_STR, &true, &nonce).unwrap(); + run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ) + .unwrap(); test_serialization(&prio3, &true, &nonce).unwrap(); let prio3_extra_helper = Prio3::new_count(3).unwrap(); assert_eq!( - run_vdaf(&prio3_extra_helper, &(), [true, false, false, true, true]).unwrap(), + run_vdaf( + CTX_STR, + &prio3_extra_helper, + &(), + [true, false, false, true, true] + ) + .unwrap(), 3, ); } @@ -1661,7 +1703,7 @@ mod tests { let prio3 = Prio3::new_sum(3, 16).unwrap(); assert_eq!( - run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), (1 << 16) + 1 ); @@ -1669,18 +1711,34 @@ mod tests { thread_rng().fill(&mut verify_key[..]); let nonce = [0; 16]; - let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); + let (public_share, mut input_shares) = prio3.shard(CTX_STR, &1, &nonce).unwrap(); assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); - let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); - let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); + let (public_share, mut input_shares) = prio3.shard(CTX_STR, &1, &nonce).unwrap(); assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); - let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); test_serialization(&prio3, &1, &nonce).unwrap(); @@ -1691,6 +1749,7 @@ mod tests { let prio3 = Prio3::new_sum_vec(2, 2, 20, 4).unwrap(); assert_eq!( run_vdaf( + CTX_STR, &prio3, &(), [ @@ -1715,6 +1774,7 @@ mod tests { assert_eq!( run_vdaf( + CTX_STR, &prio3, &(), [ @@ -1734,6 +1794,7 @@ mod tests { let prio3 = Prio3::new_sum_vec_multithreaded(2, 2, 20, 4).unwrap(); assert_eq!( run_vdaf( + CTX_STR, &prio3, &(), [ @@ -1787,7 +1848,7 @@ mod tests { let measurements = [fp_vec.clone(), fp_vec]; assert_eq!( - run_vdaf(&prio3, &(), measurements).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), measurements).unwrap(), vec![0.0; SIZE] ); } @@ -1888,21 +1949,21 @@ mod tests { // positive entries let fp_list = [fp_vec1, fp_vec2]; assert_eq!( - run_vdaf(&prio3, &(), fp_list).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), fp_list).unwrap(), vec!(0.5, 0.25, 0.125), ); // negative entries let fp_list2 = [fp_vec3, fp_vec4]; assert_eq!( - run_vdaf(&prio3, &(), fp_list2).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), fp_list2).unwrap(), vec!(-0.5, -0.25, -0.125), ); // both let fp_list3 = [fp_vec5, fp_vec6]; assert_eq!( - run_vdaf(&prio3, &(), fp_list3).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), fp_list3).unwrap(), vec!(0.5, 0.0, 0.0), ); @@ -1912,31 +1973,52 @@ mod tests { thread_rng().fill(&mut nonce); let (public_share, mut input_shares) = prio3 - .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .shard(CTX_STR, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) .unwrap(); input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255; - let result = - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); let (public_share, mut input_shares) = prio3 - .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .shard(CTX_STR, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) .unwrap(); assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); - let result = - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); let (public_share, mut input_shares) = prio3 - .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .shard(CTX_STR, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) .unwrap(); assert_matches!(input_shares[0].proofs_share, Share::Leader(ref mut data) => { data[0] += Field128::one(); }); - let result = - run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + let result = run_vdaf_prepare( + &prio3, + &verify_key, + CTX_STR, + &(), + &nonce, + public_share, + input_shares, + ); assert_matches!(result, Err(VdafError::Uncategorized(_))); test_serialization(&prio3, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce).unwrap(); @@ -1948,13 +2030,25 @@ mod tests { let prio3 = Prio3::new_histogram(2, 4, 2).unwrap(); assert_eq!( - run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [0, 1, 2, 3]).unwrap(), vec![1, 1, 1, 1] ); - assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); - assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]); - assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]); - assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [0]).unwrap(), + vec![1, 0, 0, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [1]).unwrap(), + vec![0, 1, 0, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [2]).unwrap(), + vec![0, 0, 1, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [3]).unwrap(), + vec![0, 0, 0, 1] + ); test_serialization(&prio3, &3, &[0; 16]).unwrap(); } @@ -1964,13 +2058,25 @@ mod tests { let prio3 = Prio3::new_histogram_multithreaded(2, 4, 2).unwrap(); assert_eq!( - run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [0, 1, 2, 3]).unwrap(), vec![1, 1, 1, 1] ); - assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); - assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]); - assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]); - assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [0]).unwrap(), + vec![1, 0, 0, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [1]).unwrap(), + vec![0, 1, 0, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [2]).unwrap(), + vec![0, 0, 1, 0] + ); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [3]).unwrap(), + vec![0, 0, 0, 1] + ); test_serialization(&prio3, &3, &[0; 16]).unwrap(); } @@ -1978,11 +2084,14 @@ mod tests { fn test_prio3_average() { let prio3 = Prio3::new_average(2, 64).unwrap(); - assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64); - assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); - assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64); + assert_eq!(run_vdaf(CTX_STR, &prio3, &(), [17, 8]).unwrap(), 12.5f64); + assert_eq!(run_vdaf(CTX_STR, &prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); + assert_eq!( + run_vdaf(CTX_STR, &prio3, &(), [0, 0, 0, 1]).unwrap(), + 0.25f64 + ); assert_eq!( - run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(), + run_vdaf(CTX_STR, &prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(), 207.5f64 ); } @@ -1990,7 +2099,7 @@ mod tests { #[test] fn test_prio3_input_share() { let prio3 = Prio3::new_sum(5, 16).unwrap(); - let (_public_share, input_shares) = prio3.shard(&1, &[0; 16]).unwrap(); + let (_public_share, input_shares) = prio3.shard(CTX_STR, &1, &[0; 16]).unwrap(); // Check that seed shares are distinct. for (i, x) in input_shares.iter().enumerate() { @@ -2023,7 +2132,7 @@ mod tests { { let mut verify_key = [0; SEED_SIZE]; thread_rng().fill(&mut verify_key[..]); - let (public_share, input_shares) = prio3.shard(measurement, nonce)?; + let (public_share, input_shares) = prio3.shard(CTX_STR, measurement, nonce)?; let encoded_public_share = public_share.get_encoded().unwrap(); let decoded_public_share = @@ -2050,8 +2159,15 @@ mod tests { let mut prepare_shares = Vec::new(); let mut last_prepare_state = None; for (agg_id, input_share) in input_shares.iter().enumerate() { - let (prepare_state, prepare_share) = - prio3.prepare_init(&verify_key, agg_id, &(), nonce, &public_share, input_share)?; + let (prepare_state, prepare_share) = prio3.prepare_init( + &verify_key, + CTX_STR, + agg_id, + &(), + nonce, + &public_share, + input_share, + )?; let encoded_prepare_state = prepare_state.get_encoded().unwrap(); let decoded_prepare_state = @@ -2078,7 +2194,7 @@ mod tests { } let prepare_message = prio3 - .prepare_shares_to_prepare_message(&(), prepare_shares) + .prepare_shares_to_prepare_message(CTX_STR, &(), prepare_shares) .unwrap(); let encoded_prepare_message = prepare_message.get_encoded().unwrap(); diff --git a/src/vdaf/prio3_test.rs b/src/vdaf/prio3_test.rs index e7627be72..10b72c739 100644 --- a/src/vdaf/prio3_test.rs +++ b/src/vdaf/prio3_test.rs @@ -39,6 +39,7 @@ struct TPrio3Prep { #[derive(Deserialize, Serialize)] struct TPrio3 { + ctx: TEncoded, verify_key: TEncoded, shares: u8, prep: Vec>, @@ -63,6 +64,7 @@ macro_rules! err { fn check_prep_test_vec( prio3: &Prio3, verify_key: &[u8; SEED_SIZE], + ctx: &[u8], test_num: usize, t: &TPrio3Prep, ) -> Vec> @@ -74,7 +76,7 @@ where { let nonce = <[u8; 16]>::try_from(t.nonce.clone()).unwrap(); let (public_share, input_shares) = prio3 - .shard_with_random(&t.measurement.clone().into(), &nonce, &t.rand) + .shard_with_random(ctx, &t.measurement.clone().into(), &nonce, &t.rand) .expect("failed to generate input shares"); assert_eq!( @@ -100,7 +102,15 @@ where let mut prep_shares = Vec::new(); for (agg_id, input_share) in input_shares.iter().enumerate() { let (state, prep_share) = prio3 - .prepare_init(verify_key, agg_id, &(), &nonce, &public_share, input_share) + .prepare_init( + verify_key, + ctx, + agg_id, + &(), + &nonce, + &public_share, + input_share, + ) .unwrap_or_else(|e| err!(test_num, e, "prep state init")); states.push(state); prep_shares.push(prep_share); @@ -122,14 +132,17 @@ where } let inbound = prio3 - .prepare_shares_to_prepare_message(&(), prep_shares) + .prepare_shares_to_prepare_message(ctx, &(), prep_shares) .unwrap_or_else(|e| err!(test_num, e, "prep preprocess")); assert_eq!(t.prep_messages.len(), 1); assert_eq!(inbound.get_encoded().unwrap(), t.prep_messages[0].as_ref()); let mut out_shares = Vec::new(); for state in states.iter_mut() { - match prio3.prepare_next(state.clone(), inbound.clone()).unwrap() { + match prio3 + .prepare_next(ctx, state.clone(), inbound.clone()) + .unwrap() + { PrepareTransition::Finish(out_share) => { out_shares.push(out_share); } @@ -164,10 +177,11 @@ where P: Xof, { let verify_key = t.verify_key.as_ref().try_into().unwrap(); + let ctx = t.ctx.as_ref(); let mut all_output_shares = vec![Vec::new(); prio3.num_aggregators()]; for (test_num, p) in t.prep.iter().enumerate() { - let output_shares = check_prep_test_vec(prio3, verify_key, test_num, p); + let output_shares = check_prep_test_vec(prio3, verify_key, ctx, test_num, p); for (aggregator_output_shares, output_share) in all_output_shares.iter_mut().zip(output_shares.into_iter()) { @@ -250,6 +264,11 @@ mod tests { use super::{check_test_vec, check_test_vec_custom_de, Prio3CountMeasurement}; + // All the below tests are not passing. We ignore them until the rest of the repo is in a state + // where we can regenerate the JSON test vectors. + // Tracking issue https://github.com/divviup/libprio-rs/issues/1122 + + #[ignore] #[test] fn test_vec_prio3_count() { for test_vector_str in [ @@ -263,10 +282,6 @@ mod tests { } } - // All the below tests are not passing. We ignore them until the rest of the repo is in a state - // where we can regenerate the JSON test vectors. - // Tracking issue https://github.com/divviup/libprio-rs/issues/1122 - #[ignore] #[test] fn test_vec_prio3_sum() {