From da251428b89aa166aff3bcf1e541b7da8ada4276 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Sun, 29 Dec 2024 12:46:35 -0500 Subject: [PATCH] vidpf: Remove `eval_with_cache()`, modify `eval()` Mastic uses `eval_prefix_tree_with_siblings()`. This method caches the prefix tree just like `eval_with_cache()` does, but it doesn't try to compute the onehot proof. It also concatenates the weight shares into the output shares for us. The only other use case for `eval_with_cache()` is for computing the shares of beta during sharding. Replace this code with a simpler implementation and remove `eval_with_cache()`. Finally, `eval()` can't be used to correctly compute the onehot check for Mastic. Instead, simply hash the node proofs together so that the user can check that the DPF invariant holds. This is useful primarily for testing. --- src/vdaf/mastic.rs | 55 ++++------- src/vidpf.rs | 221 ++++++++------------------------------------- 2 files changed, 54 insertions(+), 222 deletions(-) diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index 8d5cfde7..218a4458 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -301,20 +301,12 @@ where // keys for the measurement and evaluating each of them. let public_share = self.vidpf.gen_with_keys(&vidpf_keys, alpha, &beta, nonce)?; - let leader_beta_share = self.vidpf.eval_root( - VidpfServerId::S0, - &vidpf_keys[0], - &public_share, - &mut BinaryTree::default(), - nonce, - )?; - let helper_beta_share = self.vidpf.eval_root( - VidpfServerId::S1, - &vidpf_keys[1], - &public_share, - &mut BinaryTree::default(), - nonce, - )?; + let leader_beta_share = + self.vidpf + .get_beta_share(VidpfServerId::S0, &public_share, &vidpf_keys[0], nonce)?; + let helper_beta_share = + self.vidpf + .get_beta_share(VidpfServerId::S1, &public_share, &vidpf_keys[1], nonce)?; let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove( &leader_beta_share.as_ref()[1..], @@ -393,7 +385,7 @@ pub struct MasticPrepareState { #[derive(Clone, Debug)] pub struct MasticPrepareShare { /// [`Vidpf`] evaluation proof, which guarantees one-hotness and payload consistency. - vidpf_proof: Seed, + eval_proof: Seed, /// If [`Szk`]` verification of the root weight is needed, a verification message. szk_query_share_opt: Option>, @@ -401,7 +393,7 @@ pub struct MasticPrepareShare { impl Encode for MasticPrepareShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { - self.vidpf_proof.encode(bytes)?; + self.eval_proof.encode(bytes)?; match &self.szk_query_share_opt { Some(query_share) => query_share.encode(bytes), None => Ok(()), @@ -410,7 +402,7 @@ impl Encode for MasticPrepareShare Option { Some( - self.vidpf_proof.encoded_len()? + self.eval_proof.encoded_len()? + match &self.szk_query_share_opt { Some(query_share) => query_share.encoded_len()?, None => 0, @@ -426,7 +418,7 @@ impl ParameterizedDecode, bytes: &mut Cursor<&[u8]>, ) -> Result { - let vidpf_proof = Seed::decode(bytes)?; + let eval_proof = Seed::decode(bytes)?; let requires_joint_rand = prep_state.szk_query_state.is_some(); let szk_query_share_opt = prep_state .verifier_len @@ -438,7 +430,7 @@ impl ParameterizedDecode>(); - if id == VidpfServerId::S1 { - for b in beta_share.iter_mut() { - *b = -*b; - } - } - beta_share - }; - // Range check. + let VidpfWeight(beta_share) = + self.vidpf + .get_beta_share(id, public_share, &input_share.vidpf_key, nonce)?; let (szk_query_share, szk_query_state) = self.szk.query( &beta_share[1..], &input_share.proof_share, @@ -628,7 +607,7 @@ where verifier_len: Some(verifier_len), }, MasticPrepareShare { - vidpf_proof: eval_proof, + eval_proof, szk_query_share_opt: Some(szk_query_share), }, ) @@ -640,7 +619,7 @@ where verifier_len: None, }, MasticPrepareShare { - vidpf_proof: eval_proof, + eval_proof, szk_query_share_opt: None, }, ) @@ -667,7 +646,7 @@ where "Received more than two prepare shares".to_string(), )); }; - if leader_share.vidpf_proof != helper_share.vidpf_proof { + if leader_share.eval_proof != helper_share.eval_proof { return Err(VdafError::Uncategorized( "Vidpf proof verification failed".to_string(), )); diff --git a/src/vidpf.rs b/src/vidpf.rs index 0ec16994..7ffc5165 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -176,7 +176,9 @@ impl Vidpf { Ok(VidpfPublicShare { cw }) } - /// Evaluate a given VIDPF (comprised of the key and public share) at a given input. + /// Evaluate a given VIDPF (comprised of the key and public share) at a given prefix. Return + /// the weight for that prefix along with a hash of the node proofs along the path from the + /// root to the prefix. pub fn eval( &self, id: VidpfServerId, @@ -185,6 +187,8 @@ impl Vidpf { input: &VidpfInput, nonce: &[u8], ) -> Result<(W, VidpfProof), VidpfError> { + use sha3::{Digest, Sha3_256}; + let mut r = VidpfEvalResult { state: VidpfEvalState::init_from_key(id, key), share: W::zero(&self.weight_parameter), // not used @@ -194,73 +198,15 @@ impl Vidpf { return Err(VidpfError::InvalidAttributeLength); } - let mut onehot_proof = ONEHOT_PROOF_INIT; + let mut hash = Sha3_256::new(); for (idx, cw) in input.index_iter()?.zip(public.cw.iter()) { r = self.eval_next(cw, idx, &r.state, nonce); - onehot_proof = xor_proof( - onehot_proof, - &Self::hash_proof(xor_proof(onehot_proof, &r.state.node_proof)), - ); + hash.update(r.state.node_proof); } let mut weight = r.share; weight.conditional_negate(Choice::from(id)); - Ok((weight, onehot_proof)) - } - - /// Evaluates the entire `input` and produces a share of the - /// input's weight. It reuses computation from previous levels available in the - /// cache. - pub(crate) fn eval_with_cache( - &self, - id: VidpfServerId, - key: &VidpfKey, - public: &VidpfPublicShare, - input: &VidpfInput, - cache_tree: &mut BinaryTree>, - nonce: &[u8], - ) -> Result<(W, VidpfProof), VidpfError> { - if input.len() > public.cw.len() { - return Err(VidpfError::InvalidAttributeLength); - } - - let mut sub_tree = cache_tree.root.get_or_insert_with(|| { - Box::new(Node::new(VidpfEvalResult { - state: VidpfEvalState::init_from_key(id, key), - share: W::zero(&self.weight_parameter), // not used - })) - }); - - let mut onehot_proof = ONEHOT_PROOF_INIT; - for (idx, cw) in input.index_iter()?.zip(public.cw.iter()) { - sub_tree = if idx.bit.unwrap_u8() == 0 { - sub_tree.left.get_or_insert_with(|| { - Box::new(Node::new(self.eval_next( - cw, - idx, - &sub_tree.value.state, - nonce, - ))) - }) - } else { - sub_tree.right.get_or_insert_with(|| { - Box::new(Node::new(self.eval_next( - cw, - idx, - &sub_tree.value.state, - nonce, - ))) - }) - }; - onehot_proof = xor_proof( - onehot_proof, - &Self::hash_proof(xor_proof(onehot_proof, &sub_tree.value.state.node_proof)), - ); - } - - let mut weight = sub_tree.value.to_share(); - weight.conditional_negate(Choice::from(id)); - Ok((weight, onehot_proof)) + Ok((weight, hash.finalize().into())) } /// Evaluates the `input` at the given level using the provided initial @@ -311,32 +257,31 @@ impl Vidpf { } } - pub(crate) fn eval_root( + pub(crate) fn get_beta_share( &self, id: VidpfServerId, + public: &VidpfPublicShare, key: &VidpfKey, - public_share: &VidpfPublicShare, - cache_tree: &mut BinaryTree>, nonce: &[u8], ) -> Result { - let (weight_share_left, _onehot_proof_left) = self.eval_with_cache( - id, - key, - public_share, - &VidpfInput::from_bools(&[false]), - cache_tree, - nonce, - )?; - - let (weight_share_right, _onehot_proof_right) = self.eval_with_cache( - id, - key, - public_share, - &VidpfInput::from_bools(&[true]), - cache_tree, - nonce, - )?; + let cw = public.cw.first().ok_or(VidpfError::InputTooLong)?; + let state = VidpfEvalState::init_from_key(id, key); + let input_left = VidpfInput::from_bools(&[false]); + let idx_left = VidpfEvalIndex::try_from_input(&input_left)?; + + let VidpfEvalResult { + state: _, + share: mut weight_share_left, + } = self.eval_next(cw, idx_left, &state, nonce); + + let VidpfEvalResult { + state: _, + share: mut weight_share_right, + } = self.eval_next(cw, idx_left.right_sibling(), &state, nonce); + + weight_share_left.conditional_negate(Choice::from(id)); + weight_share_right.conditional_negate(Choice::from(id)); Ok(weight_share_left + weight_share_right) } @@ -624,12 +569,6 @@ pub(crate) struct VidpfEvalResult { pub(crate) share: W, } -impl VidpfEvalResult { - fn to_share(&self) -> W { - self.share.clone() - } -} - const VIDPF_PROOF_SIZE: usize = 32; const VIDPF_SEED_SIZE: usize = 16; @@ -800,8 +739,14 @@ struct VidpfEvalIndex<'a> { level: u16, } -impl VidpfEvalIndex<'_> { - pub(crate) fn left_sibling(&self) -> Self { +impl<'a> VidpfEvalIndex<'a> { + fn try_from_input(input: &'a VidpfInput) -> Result { + let level = u16::try_from(input.len()).map_err(|_| VidpfError::InputTooLong)? - 1; + let bit = Choice::from(u8::from(input.get(usize::from(level)).unwrap())); + Ok(Self { bit, input, level }) + } + + fn left_sibling(&self) -> Self { Self { bit: Choice::from(0), input: self.input, @@ -809,7 +754,7 @@ impl VidpfEvalIndex<'_> { } } - pub(crate) fn right_sibling(&self) -> Self { + fn right_sibling(&self) -> Self { Self { bit: Choice::from(1), input: self.input, @@ -869,13 +814,9 @@ mod tests { mod vidpf { use crate::{ - bt::BinaryTree, codec::{Encode, ParameterizedDecode}, idpf::IdpfValue, - vidpf::{ - Vidpf, VidpfEvalResult, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, - VidpfServerId, - }, + vidpf::{Vidpf, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId}, }; use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN}; @@ -996,94 +937,6 @@ mod tests { state_1 = r1.state; } } - - #[test] - fn caching_at_each_level() { - let input = VidpfInput::from_bytes(&[0xFF]); - let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); - let (vidpf, public, keys, nonce) = vidpf_gen_setup(&input, &weight); - - test_equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce); - } - - /// Ensures that VIDPF outputs match regardless of whether the path to - /// each node is recomputed or cached during evaluation. - fn test_equivalence_of_eval_with_caching( - vidpf: &Vidpf, - [key_0, key_1]: &[VidpfKey; 2], - public: &VidpfPublicShare, - input: &VidpfInput, - nonce: &[u8], - ) { - let mut cache_tree_0 = BinaryTree::>::default(); - let mut cache_tree_1 = BinaryTree::>::default(); - - let n = input.len(); - for level in 0..n { - let val_share_0 = vidpf - .eval( - VidpfServerId::S0, - key_0, - public, - &input.prefix(level), - nonce, - ) - .unwrap(); - let val_share_1 = vidpf - .eval( - VidpfServerId::S1, - key_1, - public, - &input.prefix(level), - nonce, - ) - .unwrap(); - let val_share_0_cached = vidpf - .eval_with_cache( - VidpfServerId::S0, - key_0, - public, - &input.prefix(level), - &mut cache_tree_0, - nonce, - ) - .unwrap(); - let val_share_1_cached = vidpf - .eval_with_cache( - VidpfServerId::S1, - key_1, - public, - &input.prefix(level), - &mut cache_tree_1, - nonce, - ) - .unwrap(); - - assert_eq!( - val_share_0, val_share_0_cached, - "shares must be computed equally with or without caching: {:?}", - level - ); - - assert_eq!( - val_share_1, val_share_1_cached, - "shares must be computed equally with or without caching: {:?}", - level - ); - - assert_eq!( - val_share_0, val_share_0_cached, - "proofs must be equal with or without caching: {:?}", - level - ); - - assert_eq!( - val_share_1, val_share_1_cached, - "proofs must be equal with or without caching: {:?}", - level - ); - } - } } mod weight {