From 82a3e3377df4bd156566b172303cc69bec3f3b2f Mon Sep 17 00:00:00 2001 From: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com> Date: Mon, 8 Jan 2024 16:27:10 -0600 Subject: [PATCH] Refactor benchmarks --- benches/triptych.rs | 155 +++++++++++++++++--------------------------- 1 file changed, 60 insertions(+), 95 deletions(-) diff --git a/benches/triptych.rs b/benches/triptych.rs index c0bdb6a..d36dc46 100644 --- a/benches/triptych.rs +++ b/benches/triptych.rs @@ -12,7 +12,7 @@ use alloc::sync::Arc; use criterion::Criterion; use curve25519_dalek::{RistrettoPoint, Scalar}; use rand_chacha::ChaCha12Rng; -use rand_core::SeedableRng; +use rand_core::{CryptoRngCore, SeedableRng}; use triptych::{ parameters::Parameters, proof::Proof, @@ -23,7 +23,45 @@ use triptych::{ // Parameters const N_VALUES: [u32; 1] = [2]; const M_VALUES: [u32; 4] = [2, 4, 8, 10]; -const BATCH_SIZES: [usize; 2] = [2, 4]; +const BATCH_SIZES: [usize; 1] = [2]; + +// Generate a batch of witnesses and corresponding statements +#[allow(non_snake_case)] +fn generate_batch_data( + params: &Arc, + b: usize, + rng: &mut R, +) -> (Vec, Vec) { + // Generate witnesses; for this test, we use adjacent indexes for simplicity + // This means the batch size must not exceed the input set size! + assert!(b <= params.get_N() as usize); + let mut witnesses = Vec::with_capacity(b); + witnesses.push(Witness::random(params, rng)); + for _ in 1..b { + let r = Scalar::random(rng); + let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N(); + witnesses.push(Witness::new(params, l, &r).unwrap()); + } + + // Generate input set from all witnesses + let mut M = (0..params.get_N()) + .map(|_| RistrettoPoint::random(rng)) + .collect::>(); + for witness in &witnesses { + M[witness.get_l() as usize] = witness.compute_verification_key(); + } + let input_set = Arc::new(InputSet::new(&M)); + + // Generate statements + let mut statements = Vec::with_capacity(b); + for witness in &witnesses { + let J = witness.compute_linking_tag(); + let message = "Proof message".as_bytes(); + statements.push(Statement::new(params, &input_set, &J, Some(message)).unwrap()); + } + + (witnesses, statements) +} #[allow(non_snake_case)] #[allow(non_upper_case_globals)] @@ -38,30 +76,13 @@ fn generate_proof(c: &mut Criterion) { let label = format!("Generate proof: n = {}, m = {} (N = {})", n, m, params.get_N()); group.bench_function(&label, |b| { - // Generate witness - let witness = Witness::random(¶ms, &mut rng); - - // Generate input set - let M = (0..params.get_N()) - .map(|i| { - if i == witness.get_l() { - witness.compute_verification_key() - } else { - RistrettoPoint::random(&mut rng) - } - }) - .collect::>(); - let input_set = Arc::new(InputSet::new(&M)); - - // Generate statement - let J = witness.compute_linking_tag(); - let message = "Proof message".as_bytes(); - let statement = Statement::new(¶ms, &input_set, &J, Some(message)).unwrap(); + // Generate data + let (witnesses, statements) = generate_batch_data(¶ms, 1, &mut rng); // Start the benchmark b.iter(|| { // Generate the proof - let _proof = Proof::prove(&witness, &statement, &mut rng).unwrap(); + let _proof = Proof::prove(&witnesses[0], &statements[0], &mut rng).unwrap(); }) }); } @@ -87,30 +108,13 @@ fn generate_proof_vartime(c: &mut Criterion) { params.get_N() ); group.bench_function(&label, |b| { - // Generate witness - let witness = Witness::random(¶ms, &mut rng); - - // Generate input set - let M = (0..params.get_N()) - .map(|i| { - if i == witness.get_l() { - witness.compute_verification_key() - } else { - RistrettoPoint::random(&mut rng) - } - }) - .collect::>(); - let input_set = Arc::new(InputSet::new(&M)); - - // Generate statement - let J = witness.compute_linking_tag(); - let message = "Proof message".as_bytes(); - let statement = Statement::new(¶ms, &input_set, &J, Some(message)).unwrap(); + // Generate data + let (witnesses, statements) = generate_batch_data(¶ms, 1, &mut rng); // Start the benchmark b.iter(|| { // Generate the proof - let _proof = Proof::prove_vartime(&witness, &statement, &mut rng).unwrap(); + let _proof = Proof::prove_vartime(&witnesses[0], &statements[0], &mut rng).unwrap(); }) }); } @@ -131,31 +135,16 @@ fn verify_proof(c: &mut Criterion) { let label = format!("Verify proof: n = {}, m = {} (N = {})", n, m, params.get_N()); group.bench_function(&label, |b| { - // Generate witness - let witness = Witness::random(¶ms, &mut rng); - - // Generate input set - let M = (0..params.get_N()) - .map(|i| { - if i == witness.get_l() { - witness.compute_verification_key() - } else { - RistrettoPoint::random(&mut rng) - } - }) - .collect::>(); - let input_set = Arc::new(InputSet::new(&M)); + // Generate data + let (witnesses, statements) = generate_batch_data(¶ms, 1, &mut rng); - // Generate statement - let J = witness.compute_linking_tag(); - let message = "Proof message".as_bytes(); - let statement = Statement::new(¶ms, &input_set, &J, Some(message)).unwrap(); - - let proof = Proof::prove(&witness, &statement, &mut rng).unwrap(); + // Generate the proof + let proof = Proof::prove(&witnesses[0], &statements[0], &mut rng).unwrap(); // Start the benchmark b.iter(|| { - assert!(proof.verify(&statement)); + // Verify the proof + assert!(proof.verify(&statements[0])); }) }); } @@ -171,10 +160,10 @@ fn verify_batch_proof(c: &mut Criterion) { for n in N_VALUES { for m in M_VALUES { - for batch in BATCH_SIZES { - // Generate parameters - let params = Arc::new(Parameters::new(n, m).unwrap()); + // Generate parameters + let params = Arc::new(Parameters::new(n, m).unwrap()); + for batch in BATCH_SIZES { let label = format!( "Verify batch proof: n = {}, m = {} (N = {}), {}-batch", n, @@ -183,35 +172,10 @@ fn verify_batch_proof(c: &mut Criterion) { batch ); group.bench_function(&label, |b| { - // Generate witnesses; for this test, we use adjacent indexes for simplicity - // This means the batch size must not exceed the input set size! - assert!(batch <= params.get_N() as usize); - let mut witnesses = Vec::with_capacity(batch); - witnesses.push(Witness::random(¶ms, &mut rng)); - for _ in 1..batch { - let r = Scalar::random(&mut rng); - let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N(); - witnesses.push(Witness::new(¶ms, l, &r).unwrap()); - } - - // Generate input set from all witnesses - let mut M = (0..params.get_N()) - .map(|_| RistrettoPoint::random(&mut rng)) - .collect::>(); - for witness in &witnesses { - M[witness.get_l() as usize] = witness.compute_verification_key(); - } - let input_set = Arc::new(InputSet::new(&M)); - - // Generate statements - let mut statements = Vec::with_capacity(batch); - for witness in &witnesses { - let J = witness.compute_linking_tag(); - let message = "Proof message".as_bytes(); - statements.push(Statement::new(¶ms, &input_set, &J, Some(message)).unwrap()); - } - - // Generate proofs + // Generate data + let (witnesses, statements) = generate_batch_data(¶ms, 1, &mut rng); + + // Generate the proofs let proofs = witnesses .iter() .zip(statements.iter()) @@ -220,6 +184,7 @@ fn verify_batch_proof(c: &mut Criterion) { // Start the benchmark b.iter(|| { + // Verify the proofs in a batch assert!(Proof::verify_batch(&statements, &proofs)); }) });