Skip to content

Commit

Permalink
xof: Pass domain separation tag in parts (#1181)
Browse files Browse the repository at this point in the history
Modify the `Xof` trait by allowing the user to pass the domain
separation tag in parts. This saves us from allocating a `Vec` in many
cases.
  • Loading branch information
cjpatton authored Jan 3, 2025
1 parent bc74489 commit 30d4302
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 82 deletions.
12 changes: 6 additions & 6 deletions src/flp/szk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ where
fn derive_prove_rand(&self, prove_rand_seed: &Seed<SEED_SIZE>) -> Vec<T::Field> {
P::seed_stream(
prove_rand_seed,
&self.domain_separation_tag(DST_PROVE_RANDOMNESS),
&[&self.domain_separation_tag(DST_PROVE_RANDOMNESS)],
&[],
)
.into_field_vec(self.typ.prove_rand_len())
Expand All @@ -362,7 +362,7 @@ where
) -> Result<Seed<SEED_SIZE>, SzkError> {
let mut xof = P::init(
aggregator_blind.as_ref(),
&self.domain_separation_tag(DST_JOINT_RAND_PART),
&[&self.domain_separation_tag(DST_JOINT_RAND_PART)],
);
xof.update(nonce);
// Encode measurement_share (currently an array of field elements) into
Expand All @@ -383,7 +383,7 @@ where
) -> Seed<SEED_SIZE> {
let mut xof = P::init(
&[0; SEED_SIZE],
&self.domain_separation_tag(DST_JOINT_RAND_SEED),
&[&self.domain_separation_tag(DST_JOINT_RAND_SEED)],
);
xof.update(&leader_joint_rand_part.0);
xof.update(&helper_joint_rand_part.0);
Expand All @@ -399,7 +399,7 @@ where
self.derive_joint_rand_seed(leader_joint_rand_part, helper_joint_rand_part);
let joint_rand = P::seed_stream(
&joint_rand_seed,
&self.domain_separation_tag(DST_JOINT_RANDOMNESS),
&[&self.domain_separation_tag(DST_JOINT_RANDOMNESS)],
&[],
)
.into_field_vec(self.typ.joint_rand_len());
Expand All @@ -410,7 +410,7 @@ where
fn derive_helper_proof_share(&self, proof_share_seed: &Seed<SEED_SIZE>) -> Vec<T::Field> {
Prng::from_seed_stream(P::seed_stream(
proof_share_seed,
&self.domain_separation_tag(DST_PROOF_SHARE),
&[&self.domain_separation_tag(DST_PROOF_SHARE)],
&[],
))
.take(self.typ.proof_len())
Expand All @@ -420,7 +420,7 @@ where
fn derive_query_rand(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec<T::Field> {
let mut xof = P::init(
verify_key,
&self.domain_separation_tag(DST_QUERY_RANDOMNESS),
&[&self.domain_separation_tag(DST_QUERY_RANDOMNESS)],
);
xof.update(nonce);
xof.into_seed_stream()
Expand Down
10 changes: 2 additions & 8 deletions src/idpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,7 @@ fn extend(seed: &[u8; 16], xof_mode: &XofMode<'_>) -> ([[u8; 16]; 2], [Choice; 2
seed_stream.fill_bytes(&mut seeds[1]);
}
XofMode::Leaf(ctx, nonce) => {
let mut dst = Vec::with_capacity(EXTEND_DOMAIN_SEP.len() + ctx.len());
dst.extend(EXTEND_DOMAIN_SEP);
dst.extend(*ctx);
let mut xof = XofTurboShake128::from_seed_slice(seed, &dst);
let mut xof = XofTurboShake128::from_seed_slice(seed, &[EXTEND_DOMAIN_SEP, ctx]);
xof.update(nonce);
let mut seed_stream = xof.into_seed_stream();
seed_stream.fill_bytes(&mut seeds[0]);
Expand Down Expand Up @@ -284,10 +281,7 @@ where
(next_seed, V::generate(&mut seed_stream, parameter))
}
XofMode::Leaf(ctx, nonce) => {
let mut dst = Vec::with_capacity(CONVERT_DOMAIN_SEP.len() + ctx.len());
dst.extend(CONVERT_DOMAIN_SEP);
dst.extend(*ctx);
let mut xof = XofTurboShake128::from_seed_slice(seed, &dst);
let mut xof = XofTurboShake128::from_seed_slice(seed, &[CONVERT_DOMAIN_SEP, ctx]);
xof.update(nonce);
let mut seed_stream = xof.into_seed_stream();
seed_stream.fill_bytes(&mut next_seed);
Expand Down
10 changes: 5 additions & 5 deletions src/prng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,14 @@ mod tests {
.unwrap();
let expected = Field64::from(4857131209231097247);

let seed_stream = XofTurboShake128::seed_stream(&seed, b"", b"");
let seed_stream = XofTurboShake128::seed_stream(&seed, &[], &[]);
let mut prng = Prng::<Field64, _>::from_seed_stream(seed_stream);
let actual = prng.nth(13882).unwrap();
assert_eq!(actual, expected);

#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
{
let mut seed_stream = XofTurboShake128::seed_stream(&seed, b"", b"");
let mut seed_stream = XofTurboShake128::seed_stream(&seed, &[], &[]);
let mut actual = <Field64 as FieldElement>::zero();
for _ in 0..=13882 {
actual = <Field64 as crate::idpf::IdpfValue>::generate(&mut seed_stream, &());
Expand All @@ -257,11 +257,11 @@ mod tests {
let seed = Seed::generate().unwrap();

let mut prng: Prng<Field64, SeedStreamTurboShake128> =
Prng::from_seed_stream(XofTurboShake128::seed_stream(&seed, b"", b""));
Prng::from_seed_stream(XofTurboShake128::seed_stream(&seed, &[], &[]));

// Construct a `Prng` with a longer-than-usual buffer.
let mut prng_weird_buffer_size: Prng<Field64, SeedStreamTurboShake128> =
Prng::from_seed_stream(XofTurboShake128::seed_stream(&seed, b"", b""));
Prng::from_seed_stream(XofTurboShake128::seed_stream(&seed, &[], &[]));
let mut extra = [0; 7];
prng_weird_buffer_size.seed_stream.fill_bytes(&mut extra);
prng_weird_buffer_size.buffer.extend_from_slice(&extra);
Expand All @@ -278,7 +278,7 @@ mod tests {
fn into_different_field() {
let seed = Seed::generate().unwrap();
let want: Prng<Field64, SeedStreamTurboShake128> =
Prng::from_seed_stream(XofTurboShake128::seed_stream(&seed, b"", b""));
Prng::from_seed_stream(XofTurboShake128::seed_stream(&seed, &[], &[]));
let want_buffer = want.buffer.clone();

let got: Prng<Field128, _> = want.into_new_field();
Expand Down
15 changes: 6 additions & 9 deletions src/vdaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,13 @@ pub trait Vdaf: Clone + Debug {

/// Generate the domain separation tag for this VDAF. The output is used for domain separation
/// by the XOF.
fn domain_separation_tag(&self, usage: u16, ctx: &[u8]) -> Vec<u8> {
fn domain_separation_tag(&self, usage: u16) -> [u8; 8] {
// Prefix is 8 bytes and defined by the spec. Copy these values in
let mut dst = Vec::with_capacity(ctx.len() + 8);
dst.push(VERSION);
dst.push(0); // algorithm class
dst.extend_from_slice(self.algorithm_id().to_be_bytes().as_slice());
dst.extend_from_slice(usage.to_be_bytes().as_slice());
// Finally, append user-chosen `ctx`
dst.extend_from_slice(ctx);

let mut dst = [0; 8];
dst[0] = VERSION;
dst[1] = 0; // algorithm class
dst[2..6].clone_from_slice(self.algorithm_id().to_be_bytes().as_slice());
dst[6..8].clone_from_slice(usage.to_be_bytes().as_slice());
dst
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/vdaf/mastic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ where

// Onehot and payload checks
let (payload_check, onehot_proof) = {
let mut payload_check_xof = P::init(&[0; SEED_SIZE], b"");
let mut payload_check_xof = P::init(&[0; SEED_SIZE], &[]);
let mut payload_check_buf = Vec::with_capacity(T::Field::ENCODED_SIZE);
let mut onehot_proof = ONEHOT_PROOF_INIT;

Expand Down Expand Up @@ -580,7 +580,7 @@ where
};

let eval_proof = {
let mut eval_proof_xof = P::init(&[0; SEED_SIZE], b"");
let mut eval_proof_xof = P::init(&[0; SEED_SIZE], &[]);
eval_proof_xof.update(&onehot_proof);
eval_proof_xof.update(&payload_check);
eval_proof_xof.update(&counter_check);
Expand Down
2 changes: 1 addition & 1 deletion src/vdaf/poplar1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Poplar1<P, SEED_SIZE> {
P: Xof<SEED_SIZE>,
F: FieldElement,
{
let mut xof = P::init(seed, &self.domain_separation_tag(usage, ctx));
let mut xof = P::init(seed, &[&self.domain_separation_tag(usage), ctx]);
for binder_chunk in binder_chunks.into_iter() {
xof.update(binder_chunk.as_ref());
}
Expand Down
40 changes: 21 additions & 19 deletions src/vdaf/prio3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ where
fn derive_prove_rands(&self, ctx: &[u8], prove_rand_seed: &Seed<SEED_SIZE>) -> Vec<T::Field> {
P::seed_stream(
prove_rand_seed,
&self.domain_separation_tag(DST_PROVE_RANDOMNESS, ctx),
&[self.num_proofs],
&[&self.domain_separation_tag(DST_PROVE_RANDOMNESS), ctx],
&[&[self.num_proofs]],
)
.into_field_vec(self.typ.prove_rand_len() * self.num_proofs())
}
Expand All @@ -495,7 +495,7 @@ where
) -> Seed<SEED_SIZE> {
let mut xof = P::init(
&[0; SEED_SIZE],
&self.domain_separation_tag(DST_JOINT_RAND_SEED, ctx),
&[&self.domain_separation_tag(DST_JOINT_RAND_SEED), ctx],
);
for part in joint_rand_parts {
xof.update(part.as_ref());
Expand All @@ -511,8 +511,8 @@ where
let joint_rand_seed = self.derive_joint_rand_seed(ctx, joint_rand_parts);
let joint_rands = P::seed_stream(
&joint_rand_seed,
&self.domain_separation_tag(DST_JOINT_RANDOMNESS, ctx),
&[self.num_proofs],
&[&self.domain_separation_tag(DST_JOINT_RANDOMNESS), ctx],
&[&[self.num_proofs]],
)
.into_field_vec(self.typ.joint_rand_len() * self.num_proofs());

Expand All @@ -527,8 +527,8 @@ where
) -> Prng<T::Field, P::SeedStream> {
Prng::from_seed_stream(P::seed_stream(
proofs_share_seed,
&self.domain_separation_tag(DST_PROOF_SHARE, ctx),
&[self.num_proofs, agg_id],
&[&self.domain_separation_tag(DST_PROOF_SHARE), ctx],
&[&[self.num_proofs, agg_id]],
))
}

Expand All @@ -540,7 +540,7 @@ where
) -> Vec<T::Field> {
let mut xof = P::init(
verify_key,
&self.domain_separation_tag(DST_QUERY_RANDOMNESS, ctx),
&[&self.domain_separation_tag(DST_QUERY_RANDOMNESS), ctx],
);
xof.update(&[self.num_proofs]);
xof.update(nonce);
Expand Down Expand Up @@ -605,16 +605,16 @@ where
let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap();
let measurement_share_prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream(
&Seed(measurement_share_seed),
&self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx),
&[agg_id],
&[&self.domain_separation_tag(DST_MEASUREMENT_SHARE), ctx],
&[&[agg_id]],
));
let joint_rand_blind = if let Some(helper_joint_rand_parts) =
helper_joint_rand_parts.as_mut()
{
let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap();
let mut joint_rand_part_xof = P::init(
&joint_rand_blind,
&self.domain_separation_tag(DST_JOINT_RAND_PART, ctx),
&[&self.domain_separation_tag(DST_JOINT_RAND_PART), ctx],
);
joint_rand_part_xof.update(&[agg_id]); // Aggregator ID
joint_rand_part_xof.update(nonce);
Expand Down Expand Up @@ -660,7 +660,7 @@ where

let mut joint_rand_part_xof = P::init(
leader_blind.as_ref(),
&self.domain_separation_tag(DST_JOINT_RAND_PART, ctx),
&[&self.domain_separation_tag(DST_JOINT_RAND_PART), ctx],
);
joint_rand_part_xof.update(&[0]); // Aggregator ID
joint_rand_part_xof.update(nonce);
Expand Down Expand Up @@ -1242,8 +1242,8 @@ where
Share::Helper(ref seed) => Cow::Owned(
P::seed_stream(
seed,
&self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx),
&[agg_id],
&[&self.domain_separation_tag(DST_MEASUREMENT_SHARE), ctx],
&[&[agg_id]],
)
.into_field_vec(self.typ.input_len()),
),
Expand All @@ -1262,7 +1262,7 @@ where
let (joint_rand_seed, joint_rand_part, joint_rands) = if self.typ.joint_rand_len() > 0 {
let mut joint_rand_part_xof = P::init(
msg.joint_rand_blind.as_ref().unwrap().as_ref(),
&self.domain_separation_tag(DST_JOINT_RAND_PART, ctx),
&[&self.domain_separation_tag(DST_JOINT_RAND_PART), ctx],
);
joint_rand_part_xof.update(&[agg_id]);
joint_rand_part_xof.update(nonce);
Expand Down Expand Up @@ -1424,10 +1424,12 @@ where
// Compute the output share.
let measurement_share = match step.measurement_share {
Share::Leader(data) => data,
Share::Helper(seed) => {
let dst = self.domain_separation_tag(DST_MEASUREMENT_SHARE, ctx);
P::seed_stream(&seed, &dst, &[step.agg_id]).into_field_vec(self.typ.input_len())
}
Share::Helper(seed) => P::seed_stream(
&seed,
&[&self.domain_separation_tag(DST_MEASUREMENT_SHARE), ctx],
&[&[step.agg_id]],
)
.into_field_vec(self.typ.input_len()),
};

let output_share = match self.typ.truncate(measurement_share) {
Expand Down
Loading

0 comments on commit 30d4302

Please sign in to comment.