diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index 8d5cfde7..218a4458 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -301,20 +301,12 @@ where // keys for the measurement and evaluating each of them. let public_share = self.vidpf.gen_with_keys(&vidpf_keys, alpha, &beta, nonce)?; - let leader_beta_share = self.vidpf.eval_root( - VidpfServerId::S0, - &vidpf_keys[0], - &public_share, - &mut BinaryTree::default(), - nonce, - )?; - let helper_beta_share = self.vidpf.eval_root( - VidpfServerId::S1, - &vidpf_keys[1], - &public_share, - &mut BinaryTree::default(), - nonce, - )?; + let leader_beta_share = + self.vidpf + .get_beta_share(VidpfServerId::S0, &public_share, &vidpf_keys[0], nonce)?; + let helper_beta_share = + self.vidpf + .get_beta_share(VidpfServerId::S1, &public_share, &vidpf_keys[1], nonce)?; let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove( &leader_beta_share.as_ref()[1..], @@ -393,7 +385,7 @@ pub struct MasticPrepareState { #[derive(Clone, Debug)] pub struct MasticPrepareShare { /// [`Vidpf`] evaluation proof, which guarantees one-hotness and payload consistency. - vidpf_proof: Seed, + eval_proof: Seed, /// If [`Szk`]` verification of the root weight is needed, a verification message. szk_query_share_opt: Option>, @@ -401,7 +393,7 @@ pub struct MasticPrepareShare { impl Encode for MasticPrepareShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { - self.vidpf_proof.encode(bytes)?; + self.eval_proof.encode(bytes)?; match &self.szk_query_share_opt { Some(query_share) => query_share.encode(bytes), None => Ok(()), @@ -410,7 +402,7 @@ impl Encode for MasticPrepareShare Option { Some( - self.vidpf_proof.encoded_len()? + self.eval_proof.encoded_len()? + match &self.szk_query_share_opt { Some(query_share) => query_share.encoded_len()?, None => 0, @@ -426,7 +418,7 @@ impl ParameterizedDecode, bytes: &mut Cursor<&[u8]>, ) -> Result { - let vidpf_proof = Seed::decode(bytes)?; + let eval_proof = Seed::decode(bytes)?; let requires_joint_rand = prep_state.szk_query_state.is_some(); let szk_query_share_opt = prep_state .verifier_len @@ -438,7 +430,7 @@ impl ParameterizedDecode>(); - if id == VidpfServerId::S1 { - for b in beta_share.iter_mut() { - *b = -*b; - } - } - beta_share - }; - // Range check. + let VidpfWeight(beta_share) = + self.vidpf + .get_beta_share(id, public_share, &input_share.vidpf_key, nonce)?; let (szk_query_share, szk_query_state) = self.szk.query( &beta_share[1..], &input_share.proof_share, @@ -628,7 +607,7 @@ where verifier_len: Some(verifier_len), }, MasticPrepareShare { - vidpf_proof: eval_proof, + eval_proof, szk_query_share_opt: Some(szk_query_share), }, ) @@ -640,7 +619,7 @@ where verifier_len: None, }, MasticPrepareShare { - vidpf_proof: eval_proof, + eval_proof, szk_query_share_opt: None, }, ) @@ -667,7 +646,7 @@ where "Received more than two prepare shares".to_string(), )); }; - if leader_share.vidpf_proof != helper_share.vidpf_proof { + if leader_share.eval_proof != helper_share.eval_proof { return Err(VdafError::Uncategorized( "Vidpf proof verification failed".to_string(), )); diff --git a/src/vidpf.rs b/src/vidpf.rs index 0ec16994..7ffc5165 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -176,7 +176,9 @@ impl Vidpf { Ok(VidpfPublicShare { cw }) } - /// Evaluate a given VIDPF (comprised of the key and public share) at a given input. + /// Evaluate a given VIDPF (comprised of the key and public share) at a given prefix. Return + /// the weight for that prefix along with a hash of the node proofs along the path from the + /// root to the prefix. pub fn eval( &self, id: VidpfServerId, @@ -185,6 +187,8 @@ impl Vidpf { input: &VidpfInput, nonce: &[u8], ) -> Result<(W, VidpfProof), VidpfError> { + use sha3::{Digest, Sha3_256}; + let mut r = VidpfEvalResult { state: VidpfEvalState::init_from_key(id, key), share: W::zero(&self.weight_parameter), // not used @@ -194,73 +198,15 @@ impl Vidpf { return Err(VidpfError::InvalidAttributeLength); } - let mut onehot_proof = ONEHOT_PROOF_INIT; + let mut hash = Sha3_256::new(); for (idx, cw) in input.index_iter()?.zip(public.cw.iter()) { r = self.eval_next(cw, idx, &r.state, nonce); - onehot_proof = xor_proof( - onehot_proof, - &Self::hash_proof(xor_proof(onehot_proof, &r.state.node_proof)), - ); + hash.update(r.state.node_proof); } let mut weight = r.share; weight.conditional_negate(Choice::from(id)); - Ok((weight, onehot_proof)) - } - - /// Evaluates the entire `input` and produces a share of the - /// input's weight. It reuses computation from previous levels available in the - /// cache. - pub(crate) fn eval_with_cache( - &self, - id: VidpfServerId, - key: &VidpfKey, - public: &VidpfPublicShare, - input: &VidpfInput, - cache_tree: &mut BinaryTree>, - nonce: &[u8], - ) -> Result<(W, VidpfProof), VidpfError> { - if input.len() > public.cw.len() { - return Err(VidpfError::InvalidAttributeLength); - } - - let mut sub_tree = cache_tree.root.get_or_insert_with(|| { - Box::new(Node::new(VidpfEvalResult { - state: VidpfEvalState::init_from_key(id, key), - share: W::zero(&self.weight_parameter), // not used - })) - }); - - let mut onehot_proof = ONEHOT_PROOF_INIT; - for (idx, cw) in input.index_iter()?.zip(public.cw.iter()) { - sub_tree = if idx.bit.unwrap_u8() == 0 { - sub_tree.left.get_or_insert_with(|| { - Box::new(Node::new(self.eval_next( - cw, - idx, - &sub_tree.value.state, - nonce, - ))) - }) - } else { - sub_tree.right.get_or_insert_with(|| { - Box::new(Node::new(self.eval_next( - cw, - idx, - &sub_tree.value.state, - nonce, - ))) - }) - }; - onehot_proof = xor_proof( - onehot_proof, - &Self::hash_proof(xor_proof(onehot_proof, &sub_tree.value.state.node_proof)), - ); - } - - let mut weight = sub_tree.value.to_share(); - weight.conditional_negate(Choice::from(id)); - Ok((weight, onehot_proof)) + Ok((weight, hash.finalize().into())) } /// Evaluates the `input` at the given level using the provided initial @@ -311,32 +257,31 @@ impl Vidpf { } } - pub(crate) fn eval_root( + pub(crate) fn get_beta_share( &self, id: VidpfServerId, + public: &VidpfPublicShare, key: &VidpfKey, - public_share: &VidpfPublicShare, - cache_tree: &mut BinaryTree>, nonce: &[u8], ) -> Result { - let (weight_share_left, _onehot_proof_left) = self.eval_with_cache( - id, - key, - public_share, - &VidpfInput::from_bools(&[false]), - cache_tree, - nonce, - )?; - - let (weight_share_right, _onehot_proof_right) = self.eval_with_cache( - id, - key, - public_share, - &VidpfInput::from_bools(&[true]), - cache_tree, - nonce, - )?; + let cw = public.cw.first().ok_or(VidpfError::InputTooLong)?; + let state = VidpfEvalState::init_from_key(id, key); + let input_left = VidpfInput::from_bools(&[false]); + let idx_left = VidpfEvalIndex::try_from_input(&input_left)?; + + let VidpfEvalResult { + state: _, + share: mut weight_share_left, + } = self.eval_next(cw, idx_left, &state, nonce); + + let VidpfEvalResult { + state: _, + share: mut weight_share_right, + } = self.eval_next(cw, idx_left.right_sibling(), &state, nonce); + + weight_share_left.conditional_negate(Choice::from(id)); + weight_share_right.conditional_negate(Choice::from(id)); Ok(weight_share_left + weight_share_right) } @@ -624,12 +569,6 @@ pub(crate) struct VidpfEvalResult { pub(crate) share: W, } -impl VidpfEvalResult { - fn to_share(&self) -> W { - self.share.clone() - } -} - const VIDPF_PROOF_SIZE: usize = 32; const VIDPF_SEED_SIZE: usize = 16; @@ -800,8 +739,14 @@ struct VidpfEvalIndex<'a> { level: u16, } -impl VidpfEvalIndex<'_> { - pub(crate) fn left_sibling(&self) -> Self { +impl<'a> VidpfEvalIndex<'a> { + fn try_from_input(input: &'a VidpfInput) -> Result { + let level = u16::try_from(input.len()).map_err(|_| VidpfError::InputTooLong)? - 1; + let bit = Choice::from(u8::from(input.get(usize::from(level)).unwrap())); + Ok(Self { bit, input, level }) + } + + fn left_sibling(&self) -> Self { Self { bit: Choice::from(0), input: self.input, @@ -809,7 +754,7 @@ impl VidpfEvalIndex<'_> { } } - pub(crate) fn right_sibling(&self) -> Self { + fn right_sibling(&self) -> Self { Self { bit: Choice::from(1), input: self.input, @@ -869,13 +814,9 @@ mod tests { mod vidpf { use crate::{ - bt::BinaryTree, codec::{Encode, ParameterizedDecode}, idpf::IdpfValue, - vidpf::{ - Vidpf, VidpfEvalResult, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, - VidpfServerId, - }, + vidpf::{Vidpf, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId}, }; use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN}; @@ -996,94 +937,6 @@ mod tests { state_1 = r1.state; } } - - #[test] - fn caching_at_each_level() { - let input = VidpfInput::from_bytes(&[0xFF]); - let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); - let (vidpf, public, keys, nonce) = vidpf_gen_setup(&input, &weight); - - test_equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce); - } - - /// Ensures that VIDPF outputs match regardless of whether the path to - /// each node is recomputed or cached during evaluation. - fn test_equivalence_of_eval_with_caching( - vidpf: &Vidpf, - [key_0, key_1]: &[VidpfKey; 2], - public: &VidpfPublicShare, - input: &VidpfInput, - nonce: &[u8], - ) { - let mut cache_tree_0 = BinaryTree::>::default(); - let mut cache_tree_1 = BinaryTree::>::default(); - - let n = input.len(); - for level in 0..n { - let val_share_0 = vidpf - .eval( - VidpfServerId::S0, - key_0, - public, - &input.prefix(level), - nonce, - ) - .unwrap(); - let val_share_1 = vidpf - .eval( - VidpfServerId::S1, - key_1, - public, - &input.prefix(level), - nonce, - ) - .unwrap(); - let val_share_0_cached = vidpf - .eval_with_cache( - VidpfServerId::S0, - key_0, - public, - &input.prefix(level), - &mut cache_tree_0, - nonce, - ) - .unwrap(); - let val_share_1_cached = vidpf - .eval_with_cache( - VidpfServerId::S1, - key_1, - public, - &input.prefix(level), - &mut cache_tree_1, - nonce, - ) - .unwrap(); - - assert_eq!( - val_share_0, val_share_0_cached, - "shares must be computed equally with or without caching: {:?}", - level - ); - - assert_eq!( - val_share_1, val_share_1_cached, - "shares must be computed equally with or without caching: {:?}", - level - ); - - assert_eq!( - val_share_0, val_share_0_cached, - "proofs must be equal with or without caching: {:?}", - level - ); - - assert_eq!( - val_share_1, val_share_1_cached, - "proofs must be equal with or without caching: {:?}", - level - ); - } - } } mod weight {