Skip to content

Commit

Permalink
Refactor benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
AaronFeickert committed Jan 8, 2024
1 parent 65c5910 commit 82a3e33
Showing 1 changed file with 60 additions and 95 deletions.
155 changes: 60 additions & 95 deletions benches/triptych.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use alloc::sync::Arc;
use criterion::Criterion;
use curve25519_dalek::{RistrettoPoint, Scalar};
use rand_chacha::ChaCha12Rng;
use rand_core::SeedableRng;
use rand_core::{CryptoRngCore, SeedableRng};
use triptych::{
parameters::Parameters,
proof::Proof,
Expand All @@ -23,7 +23,45 @@ use triptych::{
// Parameters
const N_VALUES: [u32; 1] = [2];
const M_VALUES: [u32; 4] = [2, 4, 8, 10];
const BATCH_SIZES: [usize; 2] = [2, 4];
const BATCH_SIZES: [usize; 1] = [2];

// Generate a batch of witnesses and corresponding statements
#[allow(non_snake_case)]
fn generate_batch_data<R: CryptoRngCore>(
params: &Arc<Parameters>,
b: usize,
rng: &mut R,
) -> (Vec<Witness>, Vec<Statement>) {
// Generate witnesses; for this test, we use adjacent indexes for simplicity
// This means the batch size must not exceed the input set size!
assert!(b <= params.get_N() as usize);
let mut witnesses = Vec::with_capacity(b);
witnesses.push(Witness::random(params, rng));
for _ in 1..b {
let r = Scalar::random(rng);
let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N();
witnesses.push(Witness::new(params, l, &r).unwrap());
}

// Generate input set from all witnesses
let mut M = (0..params.get_N())
.map(|_| RistrettoPoint::random(rng))
.collect::<Vec<RistrettoPoint>>();
for witness in &witnesses {
M[witness.get_l() as usize] = witness.compute_verification_key();
}
let input_set = Arc::new(InputSet::new(&M));

// Generate statements
let mut statements = Vec::with_capacity(b);
for witness in &witnesses {
let J = witness.compute_linking_tag();
let message = "Proof message".as_bytes();
statements.push(Statement::new(params, &input_set, &J, Some(message)).unwrap());
}

(witnesses, statements)
}

#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
Expand All @@ -38,30 +76,13 @@ fn generate_proof(c: &mut Criterion) {

let label = format!("Generate proof: n = {}, m = {} (N = {})", n, m, params.get_N());
group.bench_function(&label, |b| {
// Generate witness
let witness = Witness::random(&params, &mut rng);

// Generate input set
let M = (0..params.get_N())
.map(|i| {
if i == witness.get_l() {
witness.compute_verification_key()
} else {
RistrettoPoint::random(&mut rng)
}
})
.collect::<Vec<RistrettoPoint>>();
let input_set = Arc::new(InputSet::new(&M));

// Generate statement
let J = witness.compute_linking_tag();
let message = "Proof message".as_bytes();
let statement = Statement::new(&params, &input_set, &J, Some(message)).unwrap();
// Generate data
let (witnesses, statements) = generate_batch_data(&params, 1, &mut rng);

// Start the benchmark
b.iter(|| {
// Generate the proof
let _proof = Proof::prove(&witness, &statement, &mut rng).unwrap();
let _proof = Proof::prove(&witnesses[0], &statements[0], &mut rng).unwrap();
})
});
}
Expand All @@ -87,30 +108,13 @@ fn generate_proof_vartime(c: &mut Criterion) {
params.get_N()
);
group.bench_function(&label, |b| {
// Generate witness
let witness = Witness::random(&params, &mut rng);

// Generate input set
let M = (0..params.get_N())
.map(|i| {
if i == witness.get_l() {
witness.compute_verification_key()
} else {
RistrettoPoint::random(&mut rng)
}
})
.collect::<Vec<RistrettoPoint>>();
let input_set = Arc::new(InputSet::new(&M));

// Generate statement
let J = witness.compute_linking_tag();
let message = "Proof message".as_bytes();
let statement = Statement::new(&params, &input_set, &J, Some(message)).unwrap();
// Generate data
let (witnesses, statements) = generate_batch_data(&params, 1, &mut rng);

// Start the benchmark
b.iter(|| {
// Generate the proof
let _proof = Proof::prove_vartime(&witness, &statement, &mut rng).unwrap();
let _proof = Proof::prove_vartime(&witnesses[0], &statements[0], &mut rng).unwrap();
})
});
}
Expand All @@ -131,31 +135,16 @@ fn verify_proof(c: &mut Criterion) {

let label = format!("Verify proof: n = {}, m = {} (N = {})", n, m, params.get_N());
group.bench_function(&label, |b| {
// Generate witness
let witness = Witness::random(&params, &mut rng);

// Generate input set
let M = (0..params.get_N())
.map(|i| {
if i == witness.get_l() {
witness.compute_verification_key()
} else {
RistrettoPoint::random(&mut rng)
}
})
.collect::<Vec<RistrettoPoint>>();
let input_set = Arc::new(InputSet::new(&M));
// Generate data
let (witnesses, statements) = generate_batch_data(&params, 1, &mut rng);

// Generate statement
let J = witness.compute_linking_tag();
let message = "Proof message".as_bytes();
let statement = Statement::new(&params, &input_set, &J, Some(message)).unwrap();

let proof = Proof::prove(&witness, &statement, &mut rng).unwrap();
// Generate the proof
let proof = Proof::prove(&witnesses[0], &statements[0], &mut rng).unwrap();

// Start the benchmark
b.iter(|| {
assert!(proof.verify(&statement));
// Verify the proof
assert!(proof.verify(&statements[0]));
})
});
}
Expand All @@ -171,10 +160,10 @@ fn verify_batch_proof(c: &mut Criterion) {

for n in N_VALUES {
for m in M_VALUES {
for batch in BATCH_SIZES {
// Generate parameters
let params = Arc::new(Parameters::new(n, m).unwrap());
// Generate parameters
let params = Arc::new(Parameters::new(n, m).unwrap());

for batch in BATCH_SIZES {
let label = format!(
"Verify batch proof: n = {}, m = {} (N = {}), {}-batch",
n,
Expand All @@ -183,35 +172,10 @@ fn verify_batch_proof(c: &mut Criterion) {
batch
);
group.bench_function(&label, |b| {
// Generate witnesses; for this test, we use adjacent indexes for simplicity
// This means the batch size must not exceed the input set size!
assert!(batch <= params.get_N() as usize);
let mut witnesses = Vec::with_capacity(batch);
witnesses.push(Witness::random(&params, &mut rng));
for _ in 1..batch {
let r = Scalar::random(&mut rng);
let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N();
witnesses.push(Witness::new(&params, l, &r).unwrap());
}

// Generate input set from all witnesses
let mut M = (0..params.get_N())
.map(|_| RistrettoPoint::random(&mut rng))
.collect::<Vec<RistrettoPoint>>();
for witness in &witnesses {
M[witness.get_l() as usize] = witness.compute_verification_key();
}
let input_set = Arc::new(InputSet::new(&M));

// Generate statements
let mut statements = Vec::with_capacity(batch);
for witness in &witnesses {
let J = witness.compute_linking_tag();
let message = "Proof message".as_bytes();
statements.push(Statement::new(&params, &input_set, &J, Some(message)).unwrap());
}

// Generate proofs
// Generate data
let (witnesses, statements) = generate_batch_data(&params, 1, &mut rng);

// Generate the proofs
let proofs = witnesses
.iter()
.zip(statements.iter())
Expand All @@ -220,6 +184,7 @@ fn verify_batch_proof(c: &mut Criterion) {

// Start the benchmark
b.iter(|| {
// Verify the proofs in a batch
assert!(Proof::verify_batch(&statements, &proofs));
})
});
Expand Down

0 comments on commit 82a3e33

Please sign in to comment.