From 57cf7806d3afd42afb3dfc8a59ef33fb1e91158a Mon Sep 17 00:00:00 2001 From: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com> Date: Mon, 8 Jan 2024 11:54:20 -0600 Subject: [PATCH] Batch verification --- benches/triptych.rs | 75 +++++++- src/lib.rs | 6 +- src/proof.rs | 428 ++++++++++++++++++++++++++++++-------------- 3 files changed, 372 insertions(+), 137 deletions(-) diff --git a/benches/triptych.rs b/benches/triptych.rs index a864053..c0bdb6a 100644 --- a/benches/triptych.rs +++ b/benches/triptych.rs @@ -10,7 +10,7 @@ extern crate alloc; use alloc::sync::Arc; use criterion::Criterion; -use curve25519_dalek::RistrettoPoint; +use curve25519_dalek::{RistrettoPoint, Scalar}; use rand_chacha::ChaCha12Rng; use rand_core::SeedableRng; use triptych::{ @@ -21,8 +21,9 @@ use triptych::{ }; // Parameters -static N_VALUES: [u32; 1] = [2]; -static M_VALUES: [u32; 4] = [2, 4, 8, 10]; +const N_VALUES: [u32; 1] = [2]; +const M_VALUES: [u32; 4] = [2, 4, 8, 10]; +const BATCH_SIZES: [usize; 2] = [2, 4]; #[allow(non_snake_case)] #[allow(non_upper_case_globals)] @@ -162,6 +163,72 @@ fn verify_proof(c: &mut Criterion) { group.finish(); } +#[allow(non_snake_case)] +#[allow(non_upper_case_globals)] +fn verify_batch_proof(c: &mut Criterion) { + let mut group = c.benchmark_group("verify_batch_proof"); + let mut rng = ChaCha12Rng::seed_from_u64(8675309); + + for n in N_VALUES { + for m in M_VALUES { + for batch in BATCH_SIZES { + // Generate parameters + let params = Arc::new(Parameters::new(n, m).unwrap()); + + let label = format!( + "Verify batch proof: n = {}, m = {} (N = {}), {}-batch", + n, + m, + params.get_N(), + batch + ); + group.bench_function(&label, |b| { + // Generate witnesses; for this test, we use adjacent indexes for simplicity + // This means the batch size must not exceed the input set size! + assert!(batch <= params.get_N() as usize); + let mut witnesses = Vec::with_capacity(batch); + witnesses.push(Witness::random(¶ms, &mut rng)); + for _ in 1..batch { + let r = Scalar::random(&mut rng); + let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N(); + witnesses.push(Witness::new(¶ms, l, &r).unwrap()); + } + + // Generate input set from all witnesses + let mut M = (0..params.get_N()) + .map(|_| RistrettoPoint::random(&mut rng)) + .collect::>(); + for witness in &witnesses { + M[witness.get_l() as usize] = witness.compute_verification_key(); + } + let input_set = Arc::new(InputSet::new(&M)); + + // Generate statements + let mut statements = Vec::with_capacity(batch); + for witness in &witnesses { + let J = witness.compute_linking_tag(); + let message = "Proof message".as_bytes(); + statements.push(Statement::new(¶ms, &input_set, &J, Some(message)).unwrap()); + } + + // Generate proofs + let proofs = witnesses + .iter() + .zip(statements.iter()) + .map(|(w, s)| Proof::prove_vartime(w, s, &mut rng).unwrap()) + .collect::>(); + + // Start the benchmark + b.iter(|| { + assert!(Proof::verify_batch(&statements, &proofs)); + }) + }); + } + } + } + group.finish(); +} + criterion_group! { name = generate; config = Criterion::default(); @@ -171,7 +238,7 @@ criterion_group! { criterion_group! { name = verify; config = Criterion::default(); - targets = verify_proof + targets = verify_proof, verify_batch_proof } criterion_main!(generate, verify); diff --git a/src/lib.rs b/src/lib.rs index 42db61d..3663de0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,9 @@ //! the same linking tag, they were produced using the same signing key. However, it is not possible to determine the //! signing key associated to a linking tag, nor the corresponding verification key. //! +//! Triptych proofs scale nicely, with their size increasingly only logarithmically with the size of the verification +//! key set. Proofs sharing the same verification key set can also be verified efficiently in batches to save time. +//! //! More formally, let `G` and `U` be fixed independent generators of the Ristretto group. //! Let `N = n**m`, where `n, m > 1` are fixed parameters. //! The Triptych proving system protocol is a sigma protocol for the following relation, where `M` is an `N`-vector of @@ -50,7 +53,8 @@ //! //! # Example //! -//! Here's a complete example of how to generate and verify a Triptych proof. +//! Here's a complete example of how to generate and verify a Triptych proof; see the documentation for additional +//! functionality. //! //! ``` //! # extern crate alloc; diff --git a/src/proof.rs b/src/proof.rs index 9c50d79..7b8073f 100644 --- a/src/proof.rs +++ b/src/proof.rs @@ -340,179 +340,287 @@ impl Proof { /// Verify a Triptych proof. /// - /// Verification requires that the statement `statement` match that used when the - /// proof was generated. + /// Verification requires that the statement `statement` match that used when the proof was generated. /// - /// Returns a boolean that is `true` if and only if the proof is valid. - #[allow(clippy::too_many_lines, non_snake_case)] + /// Returns a boolean that is `true` if and only if the above requirement is met and the proof is valid. pub fn verify(&self, statement: &Statement) -> bool { - // Extract statement values for convenience - let M = statement.get_input_set().get_keys(); - let params = statement.get_params(); - let J = statement.get_J(); + // Verify as a trivial batch + Self::verify_batch(&[statement.clone()], &[self.clone()]) + } - // Check that the proof semantics are valid for the statement - if self.X.len() != params.get_m() as usize { + /// Verify a batch of Triptych proofs that share a common input set and parameters. + /// + /// Verification requires that the statements `statements` match those used when the proofs were generated, and that + /// they share a common input set and parameters. + /// + /// Returns a boolean that is `true` if and only if the above requirements are met and each proof is valid. + /// If the batch is empty, returns `true`. + #[allow(clippy::too_many_lines, non_snake_case)] + pub fn verify_batch(statements: &[Statement], proofs: &[Proof]) -> bool { + // Check that we have the same number of statements and proofs + if statements.len() != proofs.len() { return false; } - if self.Y.len() != params.get_m() as usize { + + // An empty batch is considered trivially valid + let first_statement = match statements.first() { + Some(statement) => statement, + None => { + return false; + }, + }; + + // Each statement must use the same input set (checked using the hash for efficiency) + if !statements + .iter() + .map(|s| s.get_input_set().get_hash()) + .all(|h| h == first_statement.get_input_set().get_hash()) + { return false; } - if self.f.len() != params.get_m() as usize { + + // Each statement must use the same parameters (checked using the hash for efficiency) + if !statements + .iter() + .map(|s| s.get_params().get_hash()) + .all(|h| h == first_statement.get_params().get_hash()) + { return false; } - for f_row in &self.f { - if f_row.len() != (params.get_n() - 1) as usize { - return false; - } - } - // Generate the verifier challenge - let mut transcript = Transcript::new("Triptych proof".as_bytes()); - transcript.append_u64("version".as_bytes(), VERSION); - if let Some(message) = statement.get_message() { - transcript.append_message("message".as_bytes(), message); - } - transcript.append_message("params".as_bytes(), params.get_hash()); - transcript.append_message("M".as_bytes(), statement.get_input_set().get_hash()); - transcript.append_message("J".as_bytes(), J.compress().as_bytes()); + // Extract common values for convenience + let M = first_statement.get_input_set().get_keys(); + let params = first_statement.get_params(); - transcript.append_message("A".as_bytes(), self.A.compress().as_bytes()); - transcript.append_message("B".as_bytes(), self.B.compress().as_bytes()); - transcript.append_message("C".as_bytes(), self.C.compress().as_bytes()); - transcript.append_message("D".as_bytes(), self.D.compress().as_bytes()); - for item in &self.X { - transcript.append_message("X".as_bytes(), item.compress().as_bytes()); - } - for item in &self.Y { - transcript.append_message("Y".as_bytes(), item.compress().as_bytes()); + // Check that all proof semantics are valid for the statement + for proof in proofs { + if proof.X.len() != params.get_m() as usize { + return false; + } + if proof.Y.len() != params.get_m() as usize { + return false; + } + if proof.f.len() != params.get_m() as usize { + return false; + } + for f_row in &proof.f { + if f_row.len() != (params.get_n() - 1) as usize { + return false; + } + } } - // Get challenge powers - let xi_powers = match xi_powers(&mut transcript, params.get_m()) { - Ok(xi_powers) => xi_powers, + // Determine the size of the final check vector, which must not overflow `usize` + let batch_size = match u32::try_from(proofs.len()) { + Ok(batch) => batch, _ => { return false; }, }; - - // Finish the transcript for pseudorandom number generation - for f_row in &self.f { - for f in f_row { - transcript.append_message("f".as_bytes(), f.as_bytes()); - } - } - transcript.append_message("z_A".as_bytes(), self.z_A.as_bytes()); - transcript.append_message("z_C".as_bytes(), self.z_C.as_bytes()); - transcript.append_message("z".as_bytes(), self.z.as_bytes()); - let mut transcript_rng = transcript.build_rng().finalize(&mut DangerousRng); - - // Reconstruct the remaining `f` terms - let f = (0..params.get_m()) - .map(|j| { - let mut f_j = Vec::with_capacity(params.get_n() as usize); - f_j.push(xi_powers[1] - self.f[j as usize].iter().sum::()); - f_j.extend(self.f[j as usize].iter()); - f_j - }) - .collect::>>(); - - // Check that `f` does not contain zero, which breaks batch inversion - for f_row in &f { - if f_row.contains(&Scalar::ZERO) { + let final_size = match usize::try_from( + 1 // G + + params.get_n() * params.get_m() // CommitmentG + + 1 // CommitmentH + + params.get_N() // M + + 1 // U + + batch_size * ( + 4 // A, B, C, D + + 1 // J + + 2 * params.get_m() // X, Y + ), + ) { + Ok(size) => size, + _ => { return false; - } - } - - // Generate weights for verification equations - // We implicitly set `w3 = 1` to avoid unnecessary constant-time multiplication - let w1 = Scalar::random(&mut transcript_rng); - let w2 = Scalar::random(&mut transcript_rng); - let w4 = Scalar::random(&mut transcript_rng); + }, + }; - // Set up the point iterator for the final check - let points = once(params.get_G()) + // Set up the point vector for the final check + let points = proofs + .iter() + .zip(statements.iter()) + .flat_map(|(p, s)| { + once(&p.A) + .chain(once(&p.B)) + .chain(once(&p.C)) + .chain(once(&p.D)) + .chain(once(s.get_J())) + .chain(p.X.iter()) + .chain(p.Y.iter()) + }) + .chain(once(params.get_G())) .chain(params.get_CommitmentG().iter()) .chain(once(params.get_CommitmentH())) - .chain(once(&self.A)) - .chain(once(&self.B)) - .chain(once(&self.C)) - .chain(once(&self.D)) - .chain(once(J)) - .chain(self.X.iter()) - .chain(self.Y.iter()) .chain(M.iter()) - .chain(once(params.get_U())); + .chain(once(params.get_U())) + .collect::>(); + + // Start the scalar vector, putting the common elements last + let mut scalars = Vec::with_capacity(final_size); - // Set up the scalar vector for the final check, matching the point iterator - let mut scalars = - Vec::with_capacity((params.get_N() + 2 * params.get_m() + params.get_n() * params.get_m() + 8) as usize); + // Set up common scalars + let mut G_scalar = Scalar::ZERO; + let mut CommitmentG_scalars = vec![Scalar::ZERO; params.get_CommitmentG().len()]; + let mut CommitmentH_scalar = Scalar::ZERO; + let mut M_scalars = vec![Scalar::ZERO; M.len()]; let mut U_scalar = Scalar::ZERO; - // G - scalars.push(-self.z); + // Set up a transcript generator for use in weighting + let mut transcript_weights = Transcript::new("Triptych verifier weights".as_bytes()); + + // Generate all verifier challenges + let mut xi_powers_all = Vec::with_capacity(proofs.len()); + for (statement, proof) in statements.iter().zip(proofs.iter()) { + // Generate the verifier challenge + let mut transcript = Transcript::new("Triptych proof".as_bytes()); + transcript.append_u64("version".as_bytes(), VERSION); + if let Some(message) = statement.get_message() { + transcript.append_message("message".as_bytes(), message); + } + transcript.append_message("params".as_bytes(), params.get_hash()); + transcript.append_message("M".as_bytes(), statement.get_input_set().get_hash()); + transcript.append_message("J".as_bytes(), statement.get_J().compress().as_bytes()); + + transcript.append_message("A".as_bytes(), proof.A.compress().as_bytes()); + transcript.append_message("B".as_bytes(), proof.B.compress().as_bytes()); + transcript.append_message("C".as_bytes(), proof.C.compress().as_bytes()); + transcript.append_message("D".as_bytes(), proof.D.compress().as_bytes()); + for item in &proof.X { + transcript.append_message("X".as_bytes(), item.compress().as_bytes()); + } + for item in &proof.Y { + transcript.append_message("Y".as_bytes(), item.compress().as_bytes()); + } - // CommitmentG - for f_row in &f { - for f_item in f_row { - scalars.push(w1 * f_item + w2 * f_item * (xi_powers[1] - f_item)); + // Get challenge powers + let xi_powers = match xi_powers(&mut transcript, params.get_m()) { + Ok(xi_powers) => xi_powers, + _ => { + return false; + }, + }; + + xi_powers_all.push(xi_powers); + + // Finish the transcript for pseudorandom number generation + for f_row in &proof.f { + for f in f_row { + transcript.append_message("f".as_bytes(), f.as_bytes()); + } } + transcript.append_message("z_A".as_bytes(), proof.z_A.as_bytes()); + transcript.append_message("z_C".as_bytes(), proof.z_C.as_bytes()); + transcript.append_message("z".as_bytes(), proof.z.as_bytes()); + let mut transcript_rng = transcript.build_rng().finalize(&mut DangerousRng); + + transcript_weights.append_u64("proof".as_bytes(), transcript_rng.as_rngcore().next_u64()); } - // CommitmentH - scalars.push(w1 * self.z_A + w2 * self.z_C); + // Finalize the weighting transcript into a pseudorandom number generator + let mut transcript_weights_rng = transcript_weights.build_rng().finalize(&mut DangerousRng); + + // Process each proof + for (proof, xi_powers) in proofs.iter().zip(xi_powers_all.iter()) { + // Reconstruct the remaining `f` terms + let f = (0..params.get_m()) + .map(|j| { + let mut f_j = Vec::with_capacity(params.get_n() as usize); + f_j.push(xi_powers[1] - proof.f[j as usize].iter().sum::()); + f_j.extend(proof.f[j as usize].iter()); + f_j + }) + .collect::>>(); + + // Check that `f` does not contain zero, which breaks batch inversion + for f_row in &f { + if f_row.contains(&Scalar::ZERO) { + return false; + } + } - // A - scalars.push(-w1); + // Generate weights for this proof's verification equations + let w1 = Scalar::random(&mut transcript_weights_rng); + let w2 = Scalar::random(&mut transcript_weights_rng); + let w3 = Scalar::random(&mut transcript_weights_rng); + let w4 = Scalar::random(&mut transcript_weights_rng); - // B - scalars.push(-w1 * xi_powers[1]); + // Get the challenge for convenience + let xi = xi_powers[1]; - // C - scalars.push(-w2 * xi_powers[1]); + // G + G_scalar -= w3 * proof.z; - // D - scalars.push(-w2); + // CommitmentG + for (CommitmentG_scalar, f_item) in CommitmentG_scalars + .iter_mut() + .zip(f.iter().flatten().map(|f| w1 * f + w2 * f * (xi - f))) + { + *CommitmentG_scalar += f_item; + } - // J - scalars.push(-w4 * self.z); + // CommitmentH + CommitmentH_scalar += w1 * proof.z_A + w2 * proof.z_C; - // X - for xi_power in &xi_powers[0..(params.get_m() as usize)] { - scalars.push(-xi_power); - } + // A + scalars.push(-w1); - // Y - for xi_power in &xi_powers[0..(params.get_m() as usize)] { - scalars.push(-w4 * xi_power); - } + // B + scalars.push(-w1 * xi_powers[1]); - // Set up the initial `f` product and Gray iterator - let mut f_product = f.iter().map(|f_row| f_row[0]).product::(); - let gray_iterator = if let Some(gray_iterator) = GrayIterator::new(params.get_n(), params.get_m()) { - gray_iterator - } else { - return false; - }; + // C + scalars.push(-w2 * xi_powers[1]); + + // D + scalars.push(-w2); + + // J + scalars.push(-w4 * proof.z); - // Invert each element of `f` for efficiency - let mut f_inverse_flat = f.iter().flatten().copied().collect::>(); - Scalar::batch_invert(&mut f_inverse_flat); - let f_inverse = f_inverse_flat - .chunks_exact(params.get_n() as usize) - .collect::>(); + // X + for xi_power in &xi_powers[0..(params.get_m() as usize)] { + scalars.push(-w3 * xi_power); + } + + // Y + for xi_power in &xi_powers[0..(params.get_m() as usize)] { + scalars.push(-w4 * xi_power); + } - // M - for (gray_index, gray_old, gray_new) in gray_iterator { - // Update the `f` product - f_product *= f_inverse[gray_index][gray_old as usize] * f[gray_index][gray_new as usize]; + // Set up the initial `f` product and Gray iterator + let mut f_product = f.iter().map(|f_row| f_row[0]).product::(); + let gray_iterator = if let Some(gray_iterator) = GrayIterator::new(params.get_n(), params.get_m()) { + gray_iterator + } else { + return false; + }; + + // Invert each element of `f` for efficiency + let mut f_inverse_flat = f.iter().flatten().copied().collect::>(); + Scalar::batch_invert(&mut f_inverse_flat); + let f_inverse = f_inverse_flat + .chunks_exact(params.get_n() as usize) + .collect::>(); + + // M + let mut U_scalar_proof = Scalar::ZERO; + for (M_scalar, (gray_index, gray_old, gray_new)) in M_scalars.iter_mut().zip(gray_iterator) { + // Update the `f` product + f_product *= f_inverse[gray_index][gray_old as usize] * f[gray_index][gray_new as usize]; + + *M_scalar += w3 * f_product; + U_scalar_proof += f_product; + } - scalars.push(f_product); - U_scalar += f_product; + // U + U_scalar += w4 * U_scalar_proof; } - // U - scalars.push(w4 * U_scalar); + // Add all common elements to the scalar vector + scalars.push(G_scalar); + scalars.extend(CommitmentG_scalars); + scalars.push(CommitmentH_scalar); + scalars.extend(M_scalars); + scalars.push(U_scalar); // Perform the final check; this can be done in variable time since it holds no secrets RistrettoPoint::vartime_multiscalar_mul(scalars.iter(), points) == RistrettoPoint::identity() @@ -523,7 +631,7 @@ impl Proof { mod test { use alloc::{sync::Arc, vec::Vec}; - use curve25519_dalek::RistrettoPoint; + use curve25519_dalek::{RistrettoPoint, Scalar}; use rand_chacha::ChaCha12Rng; use rand_core::{CryptoRngCore, SeedableRng}; @@ -563,6 +671,43 @@ mod test { (witness, statement) } + // Generate a batch of witnesses and corresponding statements + #[allow(non_snake_case)] + fn generate_batch_data(n: u32, m: u32, b: usize, rng: &mut R) -> (Vec, Vec) { + // Generate parameters + let params = Arc::new(Parameters::new(n, m).unwrap()); + + // Generate witnesses; for this test, we use adjacent indexes for simplicity + // This means the batch size must not exceed the input set size! + assert!(b <= params.get_N() as usize); + let mut witnesses = Vec::with_capacity(b); + witnesses.push(Witness::random(¶ms, rng)); + for _ in 1..b { + let r = Scalar::random(rng); + let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N(); + witnesses.push(Witness::new(¶ms, l, &r).unwrap()); + } + + // Generate input set from all witnesses + let mut M = (0..params.get_N()) + .map(|_| RistrettoPoint::random(rng)) + .collect::>(); + for witness in &witnesses { + M[witness.get_l() as usize] = witness.compute_verification_key(); + } + let input_set = Arc::new(InputSet::new(&M)); + + // Generate statements + let mut statements = Vec::with_capacity(b); + for witness in &witnesses { + let J = witness.compute_linking_tag(); + let message = "Proof message".as_bytes(); + statements.push(Statement::new(¶ms, &input_set, &J, Some(message)).unwrap()); + } + + (witnesses, statements) + } + #[test] #[allow(non_snake_case, non_upper_case_globals)] fn test_prove_verify() { @@ -591,6 +736,25 @@ mod test { assert!(proof.verify(&statement)); } + #[test] + #[allow(non_snake_case, non_upper_case_globals)] + fn test_prove_verify_batch() { + // Generate data + const n: u32 = 2; + const m: u32 = 4; + const b: usize = 3; // batch size + let mut rng = ChaCha12Rng::seed_from_u64(8675309); + let (witnesses, statements) = generate_batch_data(n, m, b, &mut rng); + + // Generate the proofs and verify as a batch + let proofs = witnesses + .iter() + .zip(statements.iter()) + .map(|(w, s)| Proof::prove_vartime(w, s, &mut rng).unwrap()) + .collect::>(); + assert!(Proof::verify_batch(&statements, &proofs)); + } + #[test] #[allow(non_snake_case, non_upper_case_globals)] fn test_evil_message() {