Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mastic: Add encoding for prep state #1185

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 112 additions & 2 deletions src/vdaf/mastic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,55 @@ where
/// parameters of Mastic used for encoding.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MasticPrepareState<F: FieldElement, const SEED_SIZE: usize> {
/// Includes output shares for eventual aggregation.
/// The counter and truncated weight for each candidate prefix.
output_shares: MasticOutputShare<F>,
/// If [`Szk`]` verification is being performed, we also store the relevant state for that operation.
szk_query_state: SzkQueryState<SEED_SIZE>,
verifier_len: Option<usize>,
}

impl<F: FieldElement, const SEED_SIZE: usize> Encode for MasticPrepareState<F, SEED_SIZE> {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
self.output_shares.encode(bytes)?;
if let Some(joint_rand_seed) = &self.szk_query_state {
joint_rand_seed.encode(bytes)?;
}
Ok(())
}

fn encoded_len(&self) -> Option<usize> {
Some(
self.output_shares.as_ref().len() * F::ENCODED_SIZE
+ self.szk_query_state.as_ref().map_or(0, |_| SEED_SIZE),
)
}
}

impl<'a, T: Type, P: Xof<SEED_SIZE>, const SEED_SIZE: usize>
ParameterizedDecode<(&'a Mastic<T, P, SEED_SIZE>, &'a MasticAggregationParam)>
for MasticPrepareState<T::Field, SEED_SIZE>
{
fn decode_with_param(
(mastic, agg_param): &(&Mastic<T, P, SEED_SIZE>, &MasticAggregationParam),
bytes: &mut Cursor<&[u8]>,
) -> Result<Self, CodecError> {
let output_shares = MasticOutputShare::decode_with_param(&(*mastic, *agg_param), bytes)?;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be a lot of data to store, since the length of output_shares is linear in the number of candidate prefixes. In most cases, it would be more compact to store the VIDPF seed and correction words along with the candidate prefixes. However, re-inflating would cost CPU.

let szk_query_state = (mastic.szk.typ.joint_rand_len() > 0
&& agg_param.require_weight_check)
.then(|| Seed::decode(bytes))
.transpose()?;
let verifier_len = agg_param
.require_weight_check
.then_some(mastic.szk.typ.verifier_len());

Ok(Self {
output_shares,
szk_query_state,
verifier_len,
})
}
}

/// Mastic preparation share.
///
/// Broadcast message from an aggregator preparing Mastic output shares. Includes the
Expand Down Expand Up @@ -809,7 +851,7 @@ mod tests {
use super::*;
use crate::field::{Field128, Field64};
use crate::flp::gadgets::{Mul, ParallelSum};
use crate::flp::types::{Count, Sum, SumVec};
use crate::flp::types::{Count, Histogram, Sum, SumVec};
use crate::vdaf::test_utils::run_vdaf;
use crate::vdaf::xof::XofTurboShake128;
use rand::{thread_rng, Rng};
Expand Down Expand Up @@ -1297,6 +1339,74 @@ mod tests {
assert_eq!(public, decoded_public_share);
}

mod prep_state {
use super::*;

fn test_prep_state_roundtrip<T: Type>(
typ: T,
weight: T::Measurement,
require_weight_check: bool,
) {
let mastic: Mastic<T, XofTurboShake128, 32> = Mastic::new(0, typ, 256).unwrap();
let ctx = b"some application";
let verify_key = [0u8; 32];
let nonce = [0u8; 16];
let alpha = VidpfInput::from_bools(&[false; 256][..]);
let (public_share, input_shares) =
mastic.shard(ctx, &(alpha.clone(), weight), &nonce).unwrap();
let agg_param = MasticAggregationParam::new(
vec![alpha, VidpfInput::from_bools(&[true; 256][..])],
require_weight_check,
)
.unwrap();

// Test both aggregators.
for agg_id in [0, 1] {
let (prep_state, _prep_share) = mastic
.prepare_init(
&verify_key,
ctx,
agg_id,
&agg_param,
&nonce,
&public_share,
&input_shares[agg_id],
)
.unwrap();

let encoded = prep_state.get_encoded().unwrap();
assert_eq!(Some(encoded.len()), prep_state.encoded_len());
assert_eq!(
MasticPrepareState::get_decoded_with_param(&(&mastic, &agg_param), &encoded)
.unwrap(),
prep_state
);
}
}

#[test]
fn without_joint_rand() {
// The Count type doesn't use joint randomness, which means the prep share won't carry the
// aggregator's joint randomness part in the weight check.
test_prep_state_roundtrip(Count::<Field64>::new(), true, true);
}

#[test]
fn without_weight_check() {
let histogram: Histogram<Field128, ParallelSum<_, Mul<_>>> =
Histogram::new(10, 3).unwrap();
// The agg param doesn't request a weight check, so the prep share won't include it.
test_prep_state_roundtrip(histogram, 0, false);
}

#[test]
fn with_weight_check_and_joint_rand() {
let histogram: Histogram<Field128, ParallelSum<_, Mul<_>>> =
Histogram::new(10, 3).unwrap();
test_prep_state_roundtrip(histogram, 0, true);
}
}

mod test_vec {
use serde::Deserialize;
use std::collections::HashMap;
Expand Down
Loading