From 6a7a91672ac80c5d7b345cc2307bd3c332873c53 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Fri, 3 Jan 2025 08:08:29 -0800 Subject: [PATCH] mastic: Align XOF computations with the spec (#1182) --- benches/speed_tests.rs | 21 +++- src/flp/szk.rs | 159 ++++++++++++++++------------- src/vdaf/mastic.rs | 146 ++++++++++++++++++--------- src/vidpf.rs | 220 +++++++++++++++++++++++------------------ 4 files changed, 330 insertions(+), 216 deletions(-) diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index b2356f56..f36e6e7e 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -859,10 +859,12 @@ fn vidpf(c: &mut Criterion) { let input = VidpfInput::from_bools(&bits); let weight = VidpfWeight::from(vec![Field255::one(), Field255::one()]); - let vidpf = Vidpf::>::new(2); + let vidpf = Vidpf::>::new(bits.len(), 2).unwrap(); b.iter(|| { - let _ = vidpf.gen(&input, &weight, NONCE).unwrap(); + let _ = vidpf + .gen(b"some application", &input, &weight, NONCE) + .unwrap(); }); }); } @@ -875,13 +877,22 @@ fn vidpf(c: &mut Criterion) { let bits = iter::repeat_with(random).take(size).collect::>(); let input = VidpfInput::from_bools(&bits); let weight = VidpfWeight::from(vec![Field255::one(), Field255::one()]); - let vidpf = Vidpf::>::new(2); + let vidpf = Vidpf::>::new(bits.len(), 2).unwrap(); - let (public, keys) = vidpf.gen(&input, &weight, NONCE).unwrap(); + let (public, keys) = vidpf + .gen(b"some application", &input, &weight, NONCE) + .unwrap(); b.iter(|| { let _ = vidpf - .eval(VidpfServerId::S0, &keys[0], &public, &input, NONCE) + .eval( + b"some application", + VidpfServerId::S0, + &keys[0], + &public, + &input, + NONCE, + ) .unwrap(); }); }); diff --git a/src/flp/szk.rs b/src/flp/szk.rs index ee3633d3..9c49f5a5 100644 --- a/src/flp/szk.rs +++ b/src/flp/szk.rs @@ -16,7 +16,10 @@ use crate::{ field::{decode_fieldvec, encode_fieldvec, FieldElement}, flp::{FlpError, Type}, prng::{Prng, PrngError}, - vdaf::xof::{IntoFieldVec, Seed, Xof, XofTurboShake128}, + vdaf::{ + mastic::{self, USAGE_PROOF_SHARE}, + xof::{IntoFieldVec, Seed, Xof, XofTurboShake128}, + }, }; use std::borrow::Cow; use std::ops::BitAnd; @@ -24,14 +27,6 @@ use std::{io::Cursor, marker::PhantomData}; use subtle::{Choice, ConstantTimeEq}; // Domain separation tags -const DST_PROVE_RANDOMNESS: u16 = 0; -const DST_PROOF_SHARE: u16 = 1; -const DST_QUERY_RANDOMNESS: u16 = 2; -const DST_JOINT_RAND_SEED: u16 = 3; -const DST_JOINT_RAND_PART: u16 = 4; -const DST_JOINT_RANDOMNESS: u16 = 5; - -const MASTIC_VERSION: u8 = 0; /// Errors propagated by methods in this module. #[derive(Debug, thiserror::Error)] @@ -308,14 +303,13 @@ where { /// The Type representing the specific FLP system used to prove validity of an input. pub(crate) typ: T, - algorithm_id: u32, phantom: PhantomData

, } impl Szk { /// Create an instance of [`Szk`] using [`XofTurboShake128`]. - pub fn new_turboshake128(typ: T, algorithm_id: u32) -> Self { - Szk::new(typ, algorithm_id) + pub fn new_turboshake128(typ: T) -> Self { + Szk::new(typ) } } @@ -326,29 +320,19 @@ where { /// Construct an instance of this sharedZK proof system with the underlying /// FLP. - pub fn new(typ: T, algorithm_id: u32) -> Self { + pub fn new(typ: T) -> Self { Self { typ, - algorithm_id, phantom: PhantomData, } } - fn domain_separation_tag(&self, usage: u16) -> [u8; 8] { - let mut dst = [0u8; 8]; - dst[0] = MASTIC_VERSION; - dst[1] = 0; // algorithm class - dst[2..6].copy_from_slice(&(self.algorithm_id).to_be_bytes()); - dst[6..8].copy_from_slice(&usage.to_be_bytes()); - dst - } - /// Derive a vector of random field elements for consumption by the FLP /// prover. - fn derive_prove_rand(&self, prove_rand_seed: &Seed) -> Vec { + fn derive_prove_rand(&self, prove_rand_seed: &Seed, ctx: &[u8]) -> Vec { P::seed_stream( prove_rand_seed, - &[&self.domain_separation_tag(DST_PROVE_RANDOMNESS)], + &[&mastic::dst_usage(mastic::USAGE_PROVE_RAND), ctx], &[], ) .into_field_vec(self.typ.prove_rand_len()) @@ -359,10 +343,11 @@ where aggregator_blind: &Seed, measurement_share: &[T::Field], nonce: &[u8; 16], + ctx: &[u8], ) -> Result, SzkError> { let mut xof = P::init( aggregator_blind.as_ref(), - &[&self.domain_separation_tag(DST_JOINT_RAND_PART)], + &[&mastic::dst_usage(mastic::USAGE_JOINT_RAND_PART), ctx], ); xof.update(nonce); // Encode measurement_share (currently an array of field elements) into @@ -380,10 +365,11 @@ where &self, leader_joint_rand_part: &Seed, helper_joint_rand_part: &Seed, + ctx: &[u8], ) -> Seed { let mut xof = P::init( &[0; SEED_SIZE], - &[&self.domain_separation_tag(DST_JOINT_RAND_SEED)], + &[&mastic::dst_usage(mastic::USAGE_JOINT_RAND_SEED), ctx], ); xof.update(&leader_joint_rand_part.0); xof.update(&helper_joint_rand_part.0); @@ -394,12 +380,13 @@ where &self, leader_joint_rand_part: &Seed, helper_joint_rand_part: &Seed, + ctx: &[u8], ) -> (Seed, Vec) { let joint_rand_seed = - self.derive_joint_rand_seed(leader_joint_rand_part, helper_joint_rand_part); + self.derive_joint_rand_seed(leader_joint_rand_part, helper_joint_rand_part, ctx); let joint_rand = P::seed_stream( &joint_rand_seed, - &[&self.domain_separation_tag(DST_JOINT_RANDOMNESS)], + &[&mastic::dst_usage(mastic::USAGE_JOINT_RAND), ctx], &[], ) .into_field_vec(self.typ.joint_rand_len()); @@ -407,22 +394,33 @@ where (joint_rand_seed, joint_rand) } - fn derive_helper_proof_share(&self, proof_share_seed: &Seed) -> Vec { + fn derive_helper_proof_share( + &self, + proof_share_seed: &Seed, + ctx: &[u8], + ) -> Vec { Prng::from_seed_stream(P::seed_stream( proof_share_seed, - &[&self.domain_separation_tag(DST_PROOF_SHARE)], + &[&mastic::dst_usage(USAGE_PROOF_SHARE), ctx], &[], )) .take(self.typ.proof_len()) .collect() } - fn derive_query_rand(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec { + fn derive_query_rand( + &self, + verify_key: &[u8; SEED_SIZE], + nonce: &[u8; 16], + level: u16, + ctx: &[u8], + ) -> Vec { let mut xof = P::init( verify_key, - &[&self.domain_separation_tag(DST_QUERY_RANDOMNESS)], + &[&mastic::dst_usage(mastic::USAGE_QUERY_RAND), ctx], ); xof.update(nonce); + xof.update(&level.to_le_bytes()); xof.into_seed_stream() .into_field_vec(self.typ.query_rand_len()) } @@ -439,8 +437,10 @@ where /// joint randomness. /// In this case, the helper uses the same seed to derive its proof share and /// joint randomness. + #[allow(clippy::too_many_arguments)] pub(crate) fn prove( &self, + ctx: &[u8], leader_input_share: &[T::Field], helper_input_share: &[T::Field], encoded_measurement: &[T::Field], @@ -457,11 +457,14 @@ where let (leader_blind_and_helper_joint_rand_part_opt, leader_joint_rand_part_opt, joint_rand) = if let Some(leader_seed) = leader_seed_opt { let leader_joint_rand_part = - self.derive_joint_rand_part(&leader_seed, leader_input_share, nonce)?; + self.derive_joint_rand_part(&leader_seed, leader_input_share, nonce, ctx)?; let helper_joint_rand_part = - self.derive_joint_rand_part(&helper_seed, helper_input_share, nonce)?; - let (_joint_rand_seed, joint_rand) = self - .derive_joint_rand_and_seed(&leader_joint_rand_part, &helper_joint_rand_part); + self.derive_joint_rand_part(&helper_seed, helper_input_share, nonce, ctx)?; + let (_joint_rand_seed, joint_rand) = self.derive_joint_rand_and_seed( + &leader_joint_rand_part, + &helper_joint_rand_part, + ctx, + ); ( Some((leader_seed, helper_joint_rand_part)), Some(leader_joint_rand_part), @@ -471,7 +474,7 @@ where (None, None, Vec::new()) }; - let prove_rand = self.derive_prove_rand(&prove_rand_seed); + let prove_rand = self.derive_prove_rand(&prove_rand_seed, ctx); let mut leader_proof_share = self.typ .prove(encoded_measurement, &prove_rand, &joint_rand)?; @@ -479,7 +482,7 @@ where // Generate the proof shares. for (x, y) in leader_proof_share .iter_mut() - .zip(self.derive_helper_proof_share(&helper_seed)) + .zip(self.derive_helper_proof_share(&helper_seed, ctx)) { *x -= y; } @@ -498,12 +501,14 @@ where pub(crate) fn query( &self, + ctx: &[u8], + level: u16, // level of the prefix tree input_share: &[T::Field], proof_share: &SzkProofShare, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16], ) -> Result<(SzkQueryShare, SzkQueryState), SzkError> { - let query_rand = self.derive_query_rand(verify_key, nonce); + let query_rand = self.derive_query_rand(verify_key, nonce, level, ctx); let flp_proof_share = match proof_share { SzkProofShare::Leader { ref uncompressed_proof_share, @@ -512,7 +517,7 @@ where SzkProofShare::Helper { ref proof_share_seed_and_blind, .. - } => Cow::Owned(self.derive_helper_proof_share(proof_share_seed_and_blind)), + } => Cow::Owned(self.derive_helper_proof_share(proof_share_seed_and_blind, ctx)), }; let (joint_rand, joint_rand_seed, joint_rand_part) = if self.requires_joint_rand() { @@ -522,11 +527,12 @@ where leader_blind_and_helper_joint_rand_part_opt, } => match leader_blind_and_helper_joint_rand_part_opt { Some((seed, helper_joint_rand_part)) => { - match self.derive_joint_rand_part(seed, input_share, nonce) { + match self.derive_joint_rand_part(seed, input_share, nonce, ctx) { Ok(leader_joint_rand_part) => ( self.derive_joint_rand_and_seed( &leader_joint_rand_part, helper_joint_rand_part, + ctx, ), leader_joint_rand_part, ), @@ -547,11 +553,13 @@ where proof_share_seed_and_blind, input_share, nonce, + ctx, ) { Ok(helper_joint_rand_part) => ( self.derive_joint_rand_and_seed( leader_joint_rand_part, &helper_joint_rand_part, + ctx, ), helper_joint_rand_part, ), @@ -590,6 +598,7 @@ where pub(crate) fn merge_query_shares( &self, + ctx: &[u8], mut leader_share: SzkQueryShare, helper_share: SzkQueryShare, ) -> Result, SzkError> { @@ -606,7 +615,7 @@ where helper_share.joint_rand_part_opt, ) { (Some(ref leader_part), Some(ref helper_part)) => Ok(SzkJointShare(Some( - self.derive_joint_rand_seed(leader_part, helper_part), + self.derive_joint_rand_seed(leader_part, helper_part, ctx), ))), (None, None) => Ok(SzkJointShare(None)), _ => Err(SzkError::Decide( @@ -689,10 +698,10 @@ mod tests { use rand::{thread_rng, Rng}; fn generic_szk_test(typ: T, encoded_measurement: &[T::Field], valid: bool) { + let ctx = b"some application context"; let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 32]; - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(typ.clone(), algorithm_id); + let szk_typ = Szk::new_turboshake128(typ.clone()); thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); let prove_rand_seed = Seed::generate().unwrap(); @@ -707,6 +716,7 @@ mod tests { } let proof_shares = szk_typ.prove( + ctx, &leader_input_share, &helper_input_share, encoded_measurement, @@ -718,6 +728,8 @@ mod tests { let [leader_proof_share, helper_proof_share] = proof_shares.unwrap(); let (leader_query_share, leader_query_state) = szk_typ .query( + ctx, + 0, &leader_input_share, &leader_proof_share, &verify_key, @@ -726,6 +738,8 @@ mod tests { .unwrap(); let (helper_query_share, helper_query_state) = szk_typ .query( + ctx, + 0, &helper_input_share, &helper_proof_share, &verify_key, @@ -734,7 +748,7 @@ mod tests { .unwrap(); let joint_share_result = - szk_typ.merge_query_shares(leader_query_share.clone(), helper_query_share.clone()); + szk_typ.merge_query_shares(ctx, leader_query_share.clone(), helper_query_share.clone()); let joint_share = match joint_share_result { Ok(joint_share) => { let leader_decision = szk_typ @@ -776,7 +790,7 @@ mod tests { } let joint_share_res = - szk_typ.merge_query_shares(mutated_query_share, helper_query_share.clone()); + szk_typ.merge_query_shares(ctx, mutated_query_share, helper_query_share.clone()); let leader_decision = match joint_share_res { Ok(joint_share) => szk_typ .decide(leader_query_state.clone(), joint_share) @@ -790,11 +804,18 @@ mod tests { mutated_input[0] *= T::Field::from(::Integer::try_from(23).unwrap()); let (mutated_query_share, mutated_query_state) = szk_typ - .query(&mutated_input, &leader_proof_share, &verify_key, &nonce) + .query( + ctx, + 0, + &mutated_input, + &leader_proof_share, + &verify_key, + &nonce, + ) .unwrap(); let joint_share_res = - szk_typ.merge_query_shares(mutated_query_share, helper_query_share.clone()); + szk_typ.merge_query_shares(ctx, mutated_query_share, helper_query_share.clone()); let leader_decision = match joint_share_res { Ok(joint_share) => szk_typ.decide(mutated_query_state, joint_share).is_ok(), @@ -822,6 +843,8 @@ mod tests { }; let (leader_query_share, leader_query_state) = szk_typ .query( + ctx, + 0, &leader_input_share, &mutated_proof_share, &verify_key, @@ -829,7 +852,7 @@ mod tests { ) .unwrap(); let joint_share_res = - szk_typ.merge_query_shares(leader_query_share, helper_query_share.clone()); + szk_typ.merge_query_shares(ctx, leader_query_share, helper_query_share.clone()); let leader_decision = match joint_share_res { Ok(joint_share) => szk_typ @@ -847,8 +870,7 @@ mod tests { thread_rng().fill(&mut nonce[..]); let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let szk_typ = Szk::new_turboshake128(sum); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = Some(Seed::generate().unwrap()); @@ -860,6 +882,7 @@ mod tests { let [leader_proof_share, _] = szk_typ .prove( + b"some application", &leader_input_share, &helper_input_share, &encoded_measurement[..], @@ -882,8 +905,7 @@ mod tests { let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let szk_typ = Szk::new_turboshake128(sumvec); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = Some(Seed::generate().unwrap()); @@ -895,6 +917,7 @@ mod tests { let [l_proof_share, _] = szk_typ .prove( + b"some application", &leader_input_share, &helper_input_share, &encoded_measurement[..], @@ -916,8 +939,7 @@ mod tests { thread_rng().fill(&mut nonce[..]); let count = Count::::new(); let encoded_measurement = count.encode_measurement(&true).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let szk_typ = Szk::new_turboshake128(count); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = Some(Seed::generate().unwrap()); @@ -929,6 +951,7 @@ mod tests { let [l_proof_share, _] = szk_typ .prove( + b"some application", &leader_input_share, &helper_input_share, &encoded_measurement[..], @@ -951,8 +974,7 @@ mod tests { thread_rng().fill(&mut nonce[..]); let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let szk_typ = Szk::new_turboshake128(sum); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = None; @@ -964,6 +986,7 @@ mod tests { let [l_proof_share, _] = szk_typ .prove( + b"some application", &leader_input_share, &helper_input_share, &encoded_measurement[..], @@ -992,8 +1015,7 @@ mod tests { thread_rng().fill(&mut nonce[..]); let sum = Sum::::new(max_measurement).unwrap(); let encoded_measurement = sum.encode_measurement(&9).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let szk_typ = Szk::new_turboshake128(sum); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = None; @@ -1005,6 +1027,7 @@ mod tests { let [_, h_proof_share] = szk_typ .prove( + b"some application", &leader_input_share, &helper_input_share, &encoded_measurement[..], @@ -1032,8 +1055,7 @@ mod tests { thread_rng().fill(&mut nonce[..]); let count = Count::::new(); let encoded_measurement = count.encode_measurement(&true).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let szk_typ = Szk::new_turboshake128(count); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = None; @@ -1045,6 +1067,7 @@ mod tests { let [l_proof_share, _] = szk_typ .prove( + b"some application", &leader_input_share, &helper_input_share, &encoded_measurement[..], @@ -1072,8 +1095,7 @@ mod tests { thread_rng().fill(&mut nonce[..]); let count = Count::::new(); let encoded_measurement = count.encode_measurement(&true).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let szk_typ = Szk::new_turboshake128(count); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = None; @@ -1085,6 +1107,7 @@ mod tests { let [_, h_proof_share] = szk_typ .prove( + b"some application", &leader_input_share, &helper_input_share, &encoded_measurement[..], @@ -1113,8 +1136,7 @@ mod tests { let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let szk_typ = Szk::new_turboshake128(sumvec); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = Some(Seed::generate().unwrap()); @@ -1126,6 +1148,7 @@ mod tests { let [l_proof_share, _] = szk_typ .prove( + b"some application", &leader_input_share, &helper_input_share, &encoded_measurement[..], @@ -1154,8 +1177,7 @@ mod tests { let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); - let algorithm_id = 5; - let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let szk_typ = Szk::new_turboshake128(sumvec); let prove_rand_seed = Seed::generate().unwrap(); let helper_seed = Seed::generate().unwrap(); let leader_seed_opt = Some(Seed::generate().unwrap()); @@ -1167,6 +1189,7 @@ mod tests { let [_, h_proof_share] = szk_typ .prove( + b"some applicqation", &leader_input_share, &helper_input_share, &encoded_measurement[..], diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index fd0fca91..d16fbb09 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -19,19 +19,48 @@ use crate::{ PrepareTransition, Vdaf, VdafError, }, vidpf::{ - xor_proof, Vidpf, VidpfError, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId, - VidpfWeight, ONEHOT_PROOF_INIT, + xor_proof, Vidpf, VidpfError, VidpfInput, VidpfKey, VidpfProof, VidpfPublicShare, + VidpfServerId, VidpfWeight, }, }; +use rand::RngCore; use std::io::{Cursor, Read}; use std::ops::BitAnd; use std::slice::from_ref; use std::{collections::VecDeque, fmt::Debug}; use subtle::{Choice, ConstantTimeEq}; +use super::xof::XofTurboShake128; + const NONCE_SIZE: usize = 16; +// draft-jimouris-cfrg-mastic: +// +// ONEHOT_PROOF_INIT = XofTurboShake128(zeros(XofTurboShake128.SEED_SIZE), +// dst(b'', USAGE_ONEHOT_PROOF_INIT), +// b'').next(PROOF_SIZE) +pub(crate) const ONEHOT_PROOF_INIT: [u8; 32] = [ + 186, 76, 128, 104, 116, 50, 149, 133, 2, 164, 82, 118, 128, 155, 163, 239, 117, 95, 162, 196, + 173, 31, 244, 180, 171, 86, 176, 209, 12, 221, 28, 204, +]; + +pub(crate) const USAGE_PROVE_RAND: u8 = 0; +pub(crate) const USAGE_PROOF_SHARE: u8 = 1; +pub(crate) const USAGE_QUERY_RAND: u8 = 2; +pub(crate) const USAGE_JOINT_RAND_SEED: u8 = 3; +pub(crate) const USAGE_JOINT_RAND_PART: u8 = 4; +pub(crate) const USAGE_JOINT_RAND: u8 = 5; +pub(crate) const USAGE_ONEHOT_PROOF_HASH: u8 = 7; +pub(crate) const USAGE_NODE_PROOF: u8 = 8; +pub(crate) const USAGE_EVAL_PROOF: u8 = 9; +pub(crate) const USAGE_EXTEND: u8 = 10; +pub(crate) const USAGE_CONVERT: u8 = 11; + +pub(crate) fn dst_usage(usage: u8) -> [u8; 11] { + [b'm', b'a', b's', b't', b'i', b'c', 0, 0, 0, 0, usage] +} + /// The main struct implementing the Mastic VDAF. /// Composed of a shared zero knowledge proof system and a verifiable incremental /// distributed point function. @@ -54,15 +83,15 @@ where P: Xof, { /// Creates a new instance of Mastic, with a specific attribute length and weight type. - pub fn new(algorithm_id: u32, typ: T, bits: usize) -> Self { - let vidpf = Vidpf::new(typ.input_len() + 1); - let szk = Szk::new(typ, algorithm_id); - Self { + pub fn new(algorithm_id: u32, typ: T, bits: usize) -> Result { + let vidpf = Vidpf::new(bits, typ.input_len() + 1)?; + let szk = Szk::new(typ); + Ok(Self { algorithm_id, szk, vidpf, bits, - } + }) } } @@ -286,29 +315,44 @@ where { fn shard_with_random( &self, - alpha: &VidpfInput, - weight: &T::Measurement, + ctx: &[u8], + (alpha, weight): &(VidpfInput, T::Measurement), nonce: &[u8; 16], vidpf_keys: [VidpfKey; 2], szk_random: [Seed; 2], joint_random_opt: Option>, ) -> Result<(::PublicShare, Vec<::InputShare>), VdafError> { + if alpha.len() != self.bits { + return Err(VdafError::Vidpf(VidpfError::InvalidInputLength)); + } + // The output with which we program the VIDPF is a counter and the encoded measurement. let mut beta = VidpfWeight(self.szk.typ.encode_measurement(weight)?); beta.0.insert(0, T::Field::one()); // Compute the measurement shares for each aggregator by generating VIDPF // keys for the measurement and evaluating each of them. - let public_share = self.vidpf.gen_with_keys(&vidpf_keys, alpha, &beta, nonce)?; + let public_share = self + .vidpf + .gen_with_keys(ctx, &vidpf_keys, alpha, &beta, nonce)?; - let leader_beta_share = - self.vidpf - .get_beta_share(VidpfServerId::S0, &public_share, &vidpf_keys[0], nonce)?; - let helper_beta_share = - self.vidpf - .get_beta_share(VidpfServerId::S1, &public_share, &vidpf_keys[1], nonce)?; + let leader_beta_share = self.vidpf.get_beta_share( + ctx, + VidpfServerId::S0, + &public_share, + &vidpf_keys[0], + nonce, + )?; + let helper_beta_share = self.vidpf.get_beta_share( + ctx, + VidpfServerId::S1, + &public_share, + &vidpf_keys[1], + nonce, + )?; let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove( + ctx, &leader_beta_share.as_ref()[1..], &helper_beta_share.as_ref()[1..], &beta.as_ref()[1..], @@ -336,14 +380,10 @@ where { fn shard( &self, - _ctx: &[u8], - (input, weight): &(VidpfInput, T::Measurement), + ctx: &[u8], + measurement: &(VidpfInput, T::Measurement), nonce: &[u8; 16], ) -> Result<(Self::PublicShare, Vec), VdafError> { - if input.len() != self.bits { - return Err(VdafError::Vidpf(VidpfError::InvalidAttributeLength)); - } - let vidpf_keys = [VidpfKey::generate()?, VidpfKey::generate()?]; let joint_random_opt = if self.szk.requires_joint_rand() { Some(Seed::::generate()?) @@ -353,8 +393,8 @@ where let szk_random = [Seed::generate()?, Seed::generate()?]; self.shard_with_random( - input, - weight, + ctx, + measurement, nonce, vidpf_keys, szk_random, @@ -483,7 +523,7 @@ where fn prepare_init( &self, verify_key: &[u8; SEED_SIZE], - _ctx: &[u8], + ctx: &[u8], agg_id: usize, agg_param: &MasticAggregationParam, nonce: &[u8; NONCE_SIZE], @@ -507,6 +547,7 @@ where let mut prefix_tree = BinaryTree::default(); let out_shares = self.vidpf.eval_prefix_tree_with_siblings( + ctx, id, public_share, &input_share.vidpf_key, @@ -525,9 +566,8 @@ where // Traverse the prefix tree breadth-first. // - // TODO spec: Adjust the onehot proof computation accordingly so that we always - // traverse the left node then the right node. Currently we visit the on-path child - // then its sibling. + // TODO spec: Adjust the onehot and payload checks accordingly. For the onehot check, + // we need to make sure to always visit the left node before the right. let mut q = VecDeque::with_capacity(100); q.push_back(root.left.as_ref().unwrap()); q.push_back(root.right.as_ref().unwrap()); @@ -535,10 +575,7 @@ where // Update onehot proof. onehot_proof = xor_proof( onehot_proof, - &Vidpf::>::hash_proof(xor_proof( - onehot_proof, - &node.value.state.node_proof, - )), + &hash_proof(xor_proof(onehot_proof, &node.value.state.node_proof), ctx), ); // Update payload check. @@ -580,7 +617,8 @@ where }; let eval_proof = { - let mut eval_proof_xof = P::init(&[0; SEED_SIZE], &[]); + // TODO spec: Use a zero seed. + let mut eval_proof_xof = P::init(&[0; SEED_SIZE], &[&dst_usage(USAGE_EVAL_PROOF), ctx]); eval_proof_xof.update(&onehot_proof); eval_proof_xof.update(&payload_check); eval_proof_xof.update(&counter_check); @@ -591,8 +629,14 @@ where // Range check. let VidpfWeight(beta_share) = self.vidpf - .get_beta_share(id, public_share, &input_share.vidpf_key, nonce)?; + .get_beta_share(ctx, id, public_share, &input_share.vidpf_key, nonce)?; let (szk_query_share, szk_query_state) = self.szk.query( + ctx, + agg_param + .level_and_prefixes + .level() + .try_into() + .map_err(|_| VdafError::Vidpf(VidpfError::InvalidInputLength))?, &beta_share[1..], &input_share.proof_share, verify_key, @@ -630,7 +674,7 @@ where M: IntoIterator>, >( &self, - _ctx: &[u8], + ctx: &[u8], _agg_param: &MasticAggregationParam, inputs: M, ) -> Result, VdafError> { @@ -658,7 +702,7 @@ where // The SZK is only used once, during the first round of aggregation. (Some(leader_query_share), Some(helper_query_share)) => Ok(self .szk - .merge_query_shares(leader_query_share, helper_query_share)?), + .merge_query_shares(ctx, leader_query_share, helper_query_share)?), (None, None) => Ok(SzkJointShare::none()), (_, _) => Err(VdafError::Uncategorized( "Only one of leader and helper query shares is present".to_string(), @@ -749,6 +793,14 @@ where } } +fn hash_proof(mut proof: VidpfProof, ctx: &[u8]) -> VidpfProof { + let mut xof = + XofTurboShake128::from_seed_slice(&[], &[&dst_usage(USAGE_ONEHOT_PROOF_HASH), ctx]); + xof.update(&proof); + xof.into_seed_stream().fill_bytes(&mut proof); + proof +} + #[cfg(test)] mod tests { use super::*; @@ -766,7 +818,7 @@ mod tests { let algorithm_id = 6; let max_measurement = 29; let sum_typ = Sum::::new(max_measurement).unwrap(); - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -847,7 +899,7 @@ mod tests { let algorithm_id = 6; let max_measurement = 29; let sum_typ = Sum::::new(max_measurement).unwrap(); - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -900,7 +952,7 @@ mod tests { let algorithm_id = 6; let max_measurement = 29; let sum_typ = Sum::::new(max_measurement).unwrap(); - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sum_typ, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -923,7 +975,7 @@ mod tests { fn test_mastic_count() { let algorithm_id = 6; let count = Count::::new(); - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -1002,7 +1054,7 @@ mod tests { fn test_public_share_encoded_len() { let algorithm_id = 6; let count = Count::::new(); - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -1022,7 +1074,7 @@ mod tests { fn test_public_share_roundtrip_count() { let algorithm_id = 6; let count = Count::::new(); - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, count, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -1044,7 +1096,7 @@ mod tests { let algorithm_id = 6; let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -1134,7 +1186,7 @@ mod tests { let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); let measurement = vec![1, 16, 0]; - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -1165,7 +1217,7 @@ mod tests { let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); let measurement = vec![1, 16, 0]; - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -1198,7 +1250,7 @@ mod tests { let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); let measurement = vec![1, 16, 0]; - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; @@ -1223,7 +1275,7 @@ mod tests { let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); let measurement = vec![1, 16, 0]; - let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32); + let mastic = Mastic::<_, XofTurboShake128, 32>::new(algorithm_id, sumvec, 32).unwrap(); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; diff --git a/src/vidpf.rs b/src/vidpf.rs index 500bb855..9693387f 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -26,28 +26,25 @@ use crate::{ codec::{CodecError, Decode, Encode, ParameterizedDecode}, field::FieldElement, idpf::{conditional_swap_seed, conditional_xor_seeds, xor_seeds, IdpfInput, IdpfValue}, - vdaf::xof::{Seed, Xof, XofFixedKeyAes128, XofTurboShake128}, + vdaf::{ + mastic, + xof::{Seed, Xof, XofFixedKeyAes128, XofTurboShake128}, + }, }; -pub(crate) const ONEHOT_PROOF_INIT: [u8; VIDPF_PROOF_SIZE] = [ - 186, 76, 128, 104, 116, 50, 149, 133, 2, 164, 82, 118, 128, 155, 163, 239, 117, 95, 162, 196, - 173, 31, 244, 180, 171, 86, 176, 209, 12, 221, 28, 204, -]; - /// VIDPF errors. #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum VidpfError { /// Input is too long to be represented. - #[error("input too long")] - InputTooLong, + #[error("bit length too long")] + BitLengthTooLong, - /// Error when input attribute has too few or many bits to be a path in an initialized - /// VIDPF tree. - #[error("invalid attribute length")] - InvalidAttributeLength, + /// Error when an input has an unexpected bit length. + #[error("invalid input length")] + InvalidInputLength, - /// Error when weight's length mismatches the length in weight's parameter. + /// Error when a weight has an unexpected length. #[error("invalid weight length")] InvalidWeightLength, @@ -65,7 +62,7 @@ pub trait VidpfValue: IdpfValue + Clone + Debug + PartialEq + ConstantTimeEq {} #[derive(Clone, Debug)] /// An instance of the VIDPF. pub struct Vidpf { - /// Any parameters required to instantiate a weight value. + pub(crate) bits: u16, pub(crate) weight_parameter: W::ValueParameter, } @@ -74,9 +71,14 @@ impl Vidpf { /// /// # Arguments /// - /// * `weight_parameter`, any parameters required to instantiate a weight value. - pub const fn new(weight_parameter: W::ValueParameter) -> Self { - Self { weight_parameter } + /// * `bits`, the length of the input in bits. + /// * `weight_parameter`, the length of the weight in number of field elements. + pub fn new(bits: usize, weight_parameter: W::ValueParameter) -> Result { + let bits = u16::try_from(bits).map_err(|_| VidpfError::BitLengthTooLong)?; + Ok(Self { + bits, + weight_parameter, + }) } /// Splits an incremental point function `F` into two private keys @@ -99,18 +101,20 @@ impl Vidpf { /// APIs. pub fn gen( &self, + ctx: &[u8], input: &VidpfInput, weight: &W, nonce: &[u8], ) -> Result<(VidpfPublicShare, [VidpfKey; 2]), VidpfError> { let keys = [VidpfKey::generate()?, VidpfKey::generate()?]; - let public = self.gen_with_keys(&keys, input, weight, nonce)?; + let public = self.gen_with_keys(ctx, &keys, input, weight, nonce)?; Ok((public, keys)) } /// Produce the public share for the given keys, input, and weight. pub(crate) fn gen_with_keys( &self, + ctx: &[u8], keys: &[VidpfKey; 2], input: &VidpfInput, weight: &W, @@ -123,11 +127,14 @@ impl Vidpf { ]; let mut cw = Vec::with_capacity(input.len()); - for idx in input.index_iter()? { + for idx in self.index_iter(input)? { let bit = idx.bit; // Extend. - let e = [Self::extend(&seed[0], nonce), Self::extend(&seed[1], nonce)]; + let e = [ + Self::extend(seed[0], ctx, nonce), + Self::extend(seed[1], ctx, nonce), + ]; // Select the seed and control bit. let (seed_keep_0, seed_lose_0) = &mut (e[0].seed_right, e[0].seed_left); @@ -152,8 +159,8 @@ impl Vidpf { // Convert. let weight_0; let weight_1; - (seed[0], weight_0) = self.convert(seed_keep_0, nonce); - (seed[1], weight_1) = self.convert(seed_keep_1, nonce); + (seed[0], weight_0) = self.convert(seed_keep_0, ctx, nonce); + (seed[1], weight_1) = self.convert(seed_keep_1, ctx, nonce); ctrl[0] = ctrl_keep_0; ctrl[1] = ctrl_keep_1; @@ -162,7 +169,10 @@ impl Vidpf { cw_weight.conditional_negate(ctrl[1]); // Compute the correction word node proof. - let cw_proof = xor_proof(idx.node_proof(&seed[0]), &idx.node_proof(&seed[1])); + let cw_proof = xor_proof( + idx.node_proof(&seed[0], ctx), + &idx.node_proof(&seed[1], ctx), + ); cw.push(VidpfCorrectionWord { seed: cw_seed, @@ -181,6 +191,7 @@ impl Vidpf { /// root to the prefix. pub fn eval( &self, + ctx: &[u8], id: VidpfServerId, key: &VidpfKey, public: &VidpfPublicShare, @@ -195,12 +206,12 @@ impl Vidpf { }; if input.len() > public.cw.len() { - return Err(VidpfError::InvalidAttributeLength); + return Err(VidpfError::InvalidInputLength); } let mut hash = Sha3_256::new(); - for (idx, cw) in input.index_iter()?.zip(public.cw.iter()) { - r = self.eval_next(cw, idx, &r.state, nonce); + for (idx, cw) in self.index_iter(input)?.zip(public.cw.iter()) { + r = self.eval_next(ctx, cw, idx, &r.state, nonce); hash.update(r.state.node_proof); } @@ -213,6 +224,7 @@ impl Vidpf { /// state, and returns a new state and a share of the input's weight at that level. fn eval_next( &self, + ctx: &[u8], cw: &VidpfCorrectionWord, idx: VidpfEvalIndex<'_>, state: &VidpfEvalState, @@ -221,7 +233,7 @@ impl Vidpf { let bit = idx.bit; // Extend. - let e = Self::extend(&state.seed, nonce); + let e = Self::extend(state.seed, ctx, nonce); // Select the seed and control bit. let (seed_keep, seed_lose) = &mut (e.seed_right, e.seed_left); @@ -234,7 +246,7 @@ impl Vidpf { let next_ctrl = ctrl_keep ^ (state.control_bit & cw_ctrl_keep); // Convert and correct the payload. - let (next_seed, w) = self.convert(seed_keep, nonce); + let (next_seed, w) = self.convert(seed_keep, ctx, nonce); let mut weight = ::conditional_select( &::zero(&self.weight_parameter), &cw.weight, @@ -243,7 +255,8 @@ impl Vidpf { weight += w; // Compute and correct the node proof. - let node_proof = conditional_xor_proof(idx.node_proof(&next_seed), &cw.proof, next_ctrl); + let node_proof = + conditional_xor_proof(idx.node_proof(&next_seed, ctx), &cw.proof, next_ctrl); let next_state = VidpfEvalState { seed: next_seed, @@ -259,35 +272,39 @@ impl Vidpf { pub(crate) fn get_beta_share( &self, + ctx: &[u8], id: VidpfServerId, public: &VidpfPublicShare, key: &VidpfKey, nonce: &[u8], ) -> Result { - let cw = public.cw.first().ok_or(VidpfError::InputTooLong)?; + let cw = public.cw.first().ok_or(VidpfError::InvalidInputLength)?; let state = VidpfEvalState::init_from_key(id, key); let input_left = VidpfInput::from_bools(&[false]); - let idx_left = VidpfEvalIndex::try_from_input(&input_left)?; + let idx_left = self.index(&input_left)?; let VidpfEvalResult { state: _, share: mut weight_share_left, - } = self.eval_next(cw, idx_left, &state, nonce); + } = self.eval_next(ctx, cw, idx_left, &state, nonce); let VidpfEvalResult { state: _, share: mut weight_share_right, - } = self.eval_next(cw, idx_left.right_sibling(), &state, nonce); + } = self.eval_next(ctx, cw, idx_left.right_sibling(), &state, nonce); weight_share_left.conditional_negate(Choice::from(id)); weight_share_right.conditional_negate(Choice::from(id)); Ok(weight_share_left + weight_share_right) } - fn extend(seed: &VidpfSeed, nonce: &[u8]) -> ExtendedSeed { - let mut rng = - XofFixedKeyAes128::seed_stream(&Seed(*seed), &[VidpfDomainSepTag::PRG], &[nonce]); + fn extend(seed: VidpfSeed, ctx: &[u8], nonce: &[u8]) -> ExtendedSeed { + let mut rng = XofFixedKeyAes128::seed_stream( + &Seed(seed), + &[&mastic::dst_usage(mastic::USAGE_EXTEND), ctx], + &[nonce], + ); let mut seed_left = VidpfSeed::default(); let mut seed_right = VidpfSeed::default(); @@ -309,9 +326,12 @@ impl Vidpf { } } - fn convert(&self, seed: VidpfSeed, nonce: &[u8]) -> (VidpfSeed, W) { - let mut rng = - XofFixedKeyAes128::seed_stream(&Seed(seed), &[VidpfDomainSepTag::CONVERT], &[nonce]); + fn convert(&self, seed: VidpfSeed, ctx: &[u8], nonce: &[u8]) -> (VidpfSeed, W) { + let mut rng = XofFixedKeyAes128::seed_stream( + &Seed(seed), + &[&mastic::dst_usage(mastic::USAGE_CONVERT), ctx], + &[nonce], + ); let mut out_seed = VidpfSeed::default(); rng.fill_bytes(&mut out_seed); @@ -320,15 +340,36 @@ impl Vidpf { (out_seed, value) } - pub(crate) fn hash_proof(mut proof: VidpfProof) -> VidpfProof { - let mut rng = XofTurboShake128::seed_stream( - &Seed(Default::default()), - &[VidpfDomainSepTag::NODE_PROOF_ADJUST], - &[&proof], - ); - rng.fill_bytes(&mut proof); + fn index_iter<'a>( + &'a self, + input: &'a VidpfInput, + ) -> Result>, VidpfError> { + let n = u16::try_from(input.len()).map_err(|_| VidpfError::InvalidInputLength)?; + if n > self.bits { + return Err(VidpfError::InvalidInputLength); + } + Ok(Box::new((0..n).zip(input.iter()).map( + move |(level, bit)| VidpfEvalIndex { + bit: Choice::from(u8::from(bit)), + input, + level, + bits: self.bits, + }, + ))) + } - proof + fn index<'a>(&self, input: &'a VidpfInput) -> Result, VidpfError> { + let level = u16::try_from(input.len()).map_err(|_| VidpfError::InvalidInputLength)? - 1; + if level >= self.bits { + return Err(VidpfError::InvalidInputLength); + } + let bit = Choice::from(u8::from(input.get(usize::from(level)).unwrap())); + Ok(VidpfEvalIndex { + bit, + input, + level, + bits: self.bits, + }) } } @@ -336,8 +377,10 @@ impl Vidpf> { /// Ensure `prefix_tree` contains the prefix tree for `prefixes`, as well as the sibling of /// each node in the prefix tree. The return value is the weights for the prefixes /// concatenated together. + #[allow(clippy::too_many_arguments)] pub(crate) fn eval_prefix_tree_with_siblings( &self, + ctx: &[u8], id: VidpfServerId, public: &VidpfPublicShare>, key: &VidpfKey, @@ -349,7 +392,7 @@ impl Vidpf> { for prefix in prefixes { if prefix.len() > public.cw.len() { - return Err(VidpfError::InvalidAttributeLength); + return Err(VidpfError::InvalidInputLength); } let mut sub_tree = prefix_tree.root.get_or_insert_with(|| { @@ -359,9 +402,10 @@ impl Vidpf> { })) }); - for (idx, cw) in prefix.index_iter()?.zip(public.cw.iter()) { + for (idx, cw) in self.index_iter(prefix)?.zip(public.cw.iter()) { let left = sub_tree.left.get_or_insert_with(|| { Box::new(Node::new(self.eval_next( + ctx, cw, idx.left_sibling(), &sub_tree.value.state, @@ -370,6 +414,7 @@ impl Vidpf> { }); let right = sub_tree.right.get_or_insert_with(|| { Box::new(Node::new(self.eval_next( + ctx, cw, idx.right_sibling(), &sub_tree.value.state, @@ -396,17 +441,6 @@ impl Vidpf> { } } -/// VIDPF domain separation tag. -/// -/// Contains the domain separation tags for invoking different oracles. -struct VidpfDomainSepTag; -impl VidpfDomainSepTag { - const PRG: &'static [u8] = b"Prg"; - const CONVERT: &'static [u8] = b"Convert"; - const NODE_PROOF: &'static [u8] = b"NodeProof"; - const NODE_PROOF_ADJUST: &'static [u8] = b"NodeProofAdjust"; -} - /// VIDPF key. /// /// Private key of an aggregation server. @@ -574,7 +608,7 @@ const VIDPF_PROOF_SIZE: usize = 32; const VIDPF_SEED_SIZE: usize = 16; /// Allows to validate user input and shares after evaluation. -type VidpfProof = [u8; VIDPF_PROOF_SIZE]; +pub(crate) type VidpfProof = [u8; VIDPF_PROOF_SIZE]; pub(crate) fn xor_proof(mut lhs: VidpfProof, rhs: &VidpfProof) -> VidpfProof { zip(&mut lhs, rhs).for_each(|(a, b)| a.bitxor_assign(b)); @@ -738,20 +772,16 @@ struct VidpfEvalIndex<'a> { bit: Choice, input: &'a VidpfInput, level: u16, + bits: u16, } -impl<'a> VidpfEvalIndex<'a> { - fn try_from_input(input: &'a VidpfInput) -> Result { - let level = u16::try_from(input.len()).map_err(|_| VidpfError::InputTooLong)? - 1; - let bit = Choice::from(u8::from(input.get(usize::from(level)).unwrap())); - Ok(Self { bit, input, level }) - } - +impl VidpfEvalIndex<'_> { fn left_sibling(&self) -> Self { Self { bit: Choice::from(0), input: self.input, level: self.level, + bits: self.bits, } } @@ -760,12 +790,16 @@ impl<'a> VidpfEvalIndex<'a> { bit: Choice::from(1), input: self.input, level: self.level, + bits: self.bits, } } - fn node_proof(&self, seed: &VidpfSeed) -> VidpfProof { - let mut xof = - XofTurboShake128::from_seed_slice(&seed[..], &[VidpfDomainSepTag::NODE_PROOF]); + fn node_proof(&self, seed: &VidpfSeed, ctx: &[u8]) -> VidpfProof { + let mut xof = XofTurboShake128::from_seed_slice( + &seed[..], + &[&mastic::dst_usage(mastic::USAGE_NODE_PROOF), ctx], + ); + xof.update(&self.bits.to_le_bytes()); xof.update(&self.level.to_le_bytes()); for byte in self @@ -791,17 +825,6 @@ impl<'a> VidpfEvalIndex<'a> { } } -impl VidpfInput { - fn index_iter(&self) -> Result>, VidpfError> { - let n = u16::try_from(self.len()).map_err(|_| VidpfError::InputTooLong)?; - Ok((0..n).zip(self.iter()).map(|(level, bit)| VidpfEvalIndex { - bit: Choice::from(u8::from(bit)), - input: self, - level, - })) - } -} - #[cfg(test)] mod tests { @@ -825,9 +848,10 @@ mod tests { #[test] fn roundtrip_codec() { + let ctx = b"appliction context"; let input = VidpfInput::from_bytes(&[0xFF]); let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); - let (_, public, _, _) = vidpf_gen_setup(&input, &weight); + let (_, public, _, _) = vidpf_gen_setup(ctx, &input, &weight); let bytes = public.get_encoded().unwrap(); assert_eq!(public.encoded_len().unwrap(), bytes.len()); @@ -841,6 +865,7 @@ mod tests { } fn vidpf_gen_setup( + ctx: &[u8], input: &VidpfInput, weight: &TestWeight, ) -> ( @@ -849,22 +874,23 @@ mod tests { [VidpfKey; 2], [u8; TEST_NONCE_SIZE], ) { - let vidpf = Vidpf::new(TEST_WEIGHT_LEN); - let (public, keys) = vidpf.gen(input, weight, TEST_NONCE).unwrap(); + let vidpf = Vidpf::new(input.len(), TEST_WEIGHT_LEN).unwrap(); + let (public, keys) = vidpf.gen(ctx, input, weight, TEST_NONCE).unwrap(); (vidpf, public, keys, *TEST_NONCE) } #[test] fn correctness_at_last_level() { + let ctx = b"some application"; let input = VidpfInput::from_bytes(&[0xFF]); let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); - let (vidpf, public, [key_0, key_1], nonce) = vidpf_gen_setup(&input, &weight); + let (vidpf, public, [key_0, key_1], nonce) = vidpf_gen_setup(ctx, &input, &weight); let (value_share_0, onehot_proof_0) = vidpf - .eval(VidpfServerId::S0, &key_0, &public, &input, &nonce) + .eval(ctx, VidpfServerId::S0, &key_0, &public, &input, &nonce) .unwrap(); let (value_share_1, onehot_proof_1) = vidpf - .eval(VidpfServerId::S1, &key_1, &public, &input, &nonce) + .eval(ctx, VidpfServerId::S1, &key_1, &public, &input, &nonce) .unwrap(); assert_eq!( @@ -878,10 +904,10 @@ mod tests { let bad_input = VidpfInput::from_bytes(&[0x00]); let zero = TestWeight::zero(&TEST_WEIGHT_LEN); let (value_share_0, onehot_proof_0) = vidpf - .eval(VidpfServerId::S0, &key_0, &public, &bad_input, &nonce) + .eval(ctx, VidpfServerId::S0, &key_0, &public, &bad_input, &nonce) .unwrap(); let (value_share_1, onehot_proof_1) = vidpf - .eval(VidpfServerId::S1, &key_1, &public, &bad_input, &nonce) + .eval(ctx, VidpfServerId::S1, &key_1, &public, &bad_input, &nonce) .unwrap(); assert_eq!( @@ -895,20 +921,22 @@ mod tests { #[test] fn correctness_at_each_level() { + let ctx = b"application context"; let input = VidpfInput::from_bytes(&[0xFF]); let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); - let (vidpf, public, keys, nonce) = vidpf_gen_setup(&input, &weight); + let (vidpf, public, keys, nonce) = vidpf_gen_setup(ctx, &input, &weight); - assert_eval_at_each_level(&vidpf, &keys, &public, &input, &weight, &nonce); + assert_eval_at_each_level(&vidpf, ctx, &keys, &public, &input, &weight, &nonce); let bad_input = VidpfInput::from_bytes(&[0x00]); let zero = TestWeight::zero(&TEST_WEIGHT_LEN); - assert_eval_at_each_level(&vidpf, &keys, &public, &bad_input, &zero, &nonce); + assert_eval_at_each_level(&vidpf, ctx, &keys, &public, &bad_input, &zero, &nonce); } fn assert_eval_at_each_level( vidpf: &Vidpf, + ctx: &[u8], [key_0, key_1]: &[VidpfKey; 2], public: &VidpfPublicShare, input: &VidpfInput, @@ -918,9 +946,9 @@ mod tests { let mut state_0 = VidpfEvalState::init_from_key(VidpfServerId::S0, key_0); let mut state_1 = VidpfEvalState::init_from_key(VidpfServerId::S1, key_1); - for (idx, cw) in input.index_iter().unwrap().zip(public.cw.iter()) { - let r0 = vidpf.eval_next(cw, idx, &state_0, nonce); - let r1 = vidpf.eval_next(cw, idx, &state_1, nonce); + for (idx, cw) in vidpf.index_iter(input).unwrap().zip(public.cw.iter()) { + let r0 = vidpf.eval_next(ctx, cw, idx, &state_0, nonce); + let r1 = vidpf.eval_next(ctx, cw, idx, &state_1, nonce); assert_eq!( r0.share - r1.share,