Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: batch verification #24

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Batch verification
AaronFeickert committed Jan 8, 2024
commit 56267ef3e0fab3478825c43d8533ce0e46771270
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -10,6 +10,8 @@ Successful verification of a signature means that the signer knew the signing ke
It also produces a linking tag; if any two verified signatures have the same linking tag, they were produced using the same signing key.
However, it is not possible to determine the signing key associated to a linking tag, nor the corresponding verification key.

Triptych proofs scale nicely, with their size increasingly only logarithmically with the size of the verification key set. Proofs sharing the same verification key set can also be verified efficiently in batches to save time.

More formally, let `G` and `U` be fixed independent generators of the Ristretto group.
Let `N = n**m`, where `n, m > 1` are fixed parameters.
The Triptych proving system protocol is a sigma protocol for the following relation, where `M` is an `N`-vector of group elements:
75 changes: 71 additions & 4 deletions benches/triptych.rs
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@ extern crate alloc;
use alloc::sync::Arc;

use criterion::Criterion;
use curve25519_dalek::RistrettoPoint;
use curve25519_dalek::{RistrettoPoint, Scalar};
use rand_chacha::ChaCha12Rng;
use rand_core::SeedableRng;
use triptych::{
@@ -21,8 +21,9 @@ use triptych::{
};

// Parameters
static N_VALUES: [u32; 1] = [2];
static M_VALUES: [u32; 4] = [2, 4, 8, 10];
const N_VALUES: [u32; 1] = [2];
const M_VALUES: [u32; 4] = [2, 4, 8, 10];
const BATCH_SIZES: [usize; 2] = [2, 4];

#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
@@ -162,6 +163,72 @@ fn verify_proof(c: &mut Criterion) {
group.finish();
}

#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
fn verify_batch_proof(c: &mut Criterion) {
let mut group = c.benchmark_group("verify_batch_proof");
let mut rng = ChaCha12Rng::seed_from_u64(8675309);

for n in N_VALUES {
for m in M_VALUES {
for batch in BATCH_SIZES {
// Generate parameters
let params = Arc::new(Parameters::new(n, m).unwrap());

let label = format!(
"Verify batch proof: n = {}, m = {} (N = {}), {}-batch",
n,
m,
params.get_N(),
batch
);
group.bench_function(&label, |b| {
// Generate witnesses; for this test, we use adjacent indexes for simplicity
// This means the batch size must not exceed the input set size!
assert!(batch <= params.get_N() as usize);
let mut witnesses = Vec::with_capacity(batch);
witnesses.push(Witness::random(&params, &mut rng));
for _ in 1..batch {
let r = Scalar::random(&mut rng);
let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N();
witnesses.push(Witness::new(&params, l, &r).unwrap());
}

// Generate input set from all witnesses
let mut M = (0..params.get_N())
.map(|_| RistrettoPoint::random(&mut rng))
.collect::<Vec<RistrettoPoint>>();
for witness in &witnesses {
M[witness.get_l() as usize] = witness.compute_verification_key();
}
let input_set = Arc::new(InputSet::new(&M));

// Generate statements
let mut statements = Vec::with_capacity(batch);
for witness in &witnesses {
let J = witness.compute_linking_tag();
let message = "Proof message".as_bytes();
statements.push(Statement::new(&params, &input_set, &J, Some(message)).unwrap());
}

// Generate proofs
let proofs = witnesses
.iter()
.zip(statements.iter())
.map(|(w, s)| Proof::prove_vartime(w, s, &mut rng).unwrap())
.collect::<Vec<Proof>>();

// Start the benchmark
b.iter(|| {
assert!(Proof::verify_batch(&statements, &proofs));
})
});
}
}
}
group.finish();
}

criterion_group! {
name = generate;
config = Criterion::default();
@@ -171,7 +238,7 @@ criterion_group! {
criterion_group! {
name = verify;
config = Criterion::default();
targets = verify_proof
targets = verify_proof, verify_batch_proof
}

criterion_main!(generate, verify);
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -12,6 +12,9 @@
//! the same linking tag, they were produced using the same signing key. However, it is not possible to determine the
//! signing key associated to a linking tag, nor the corresponding verification key.
//!
//! Triptych proofs scale nicely, with their size increasingly only logarithmically with the size of the verification
//! key set. Proofs sharing the same verification key set can also be verified efficiently in batches to save time.
//!
//! More formally, let `G` and `U` be fixed independent generators of the Ristretto group.
//! Let `N = n**m`, where `n, m > 1` are fixed parameters.
//! The Triptych proving system protocol is a sigma protocol for the following relation, where `M` is an `N`-vector of
@@ -50,7 +53,8 @@
//!
//! # Example
//!
//! Here's a complete example of how to generate and verify a Triptych proof.
//! Here's a complete example of how to generate and verify a Triptych proof; see the documentation for additional
//! functionality.
//!
//! ```
//! # extern crate alloc;
428 changes: 296 additions & 132 deletions src/proof.rs
Original file line number Diff line number Diff line change
@@ -340,179 +340,287 @@ impl Proof {

/// Verify a Triptych proof.
///
/// Verification requires that the statement `statement` match that used when the
/// proof was generated.
/// Verification requires that the statement `statement` match that used when the proof was generated.
///
/// Returns a boolean that is `true` if and only if the proof is valid.
#[allow(clippy::too_many_lines, non_snake_case)]
/// Returns a boolean that is `true` if and only if the above requirement is met and the proof is valid.
pub fn verify(&self, statement: &Statement) -> bool {
// Extract statement values for convenience
let M = statement.get_input_set().get_keys();
let params = statement.get_params();
let J = statement.get_J();
// Verify as a trivial batch
Self::verify_batch(&[statement.clone()], &[self.clone()])
}

// Check that the proof semantics are valid for the statement
if self.X.len() != params.get_m() as usize {
/// Verify a batch of Triptych proofs that share a common input set and parameters.
///
/// Verification requires that the statements `statements` match those used when the proofs were generated, and that
/// they share a common input set and parameters.
///
/// Returns a boolean that is `true` if and only if the above requirements are met and each proof is valid.
/// If the batch is empty, returns `true`.
#[allow(clippy::too_many_lines, non_snake_case)]
pub fn verify_batch(statements: &[Statement], proofs: &[Proof]) -> bool {
// Check that we have the same number of statements and proofs
if statements.len() != proofs.len() {
return false;
}
if self.Y.len() != params.get_m() as usize {

// An empty batch is considered trivially valid
let first_statement = match statements.first() {
Some(statement) => statement,
None => {
return false;
},
};

// Each statement must use the same input set (checked using the hash for efficiency)
if !statements
.iter()
.map(|s| s.get_input_set().get_hash())
.all(|h| h == first_statement.get_input_set().get_hash())
{
return false;
}
if self.f.len() != params.get_m() as usize {

// Each statement must use the same parameters (checked using the hash for efficiency)
if !statements
.iter()
.map(|s| s.get_params().get_hash())
.all(|h| h == first_statement.get_params().get_hash())
{
return false;
}
for f_row in &self.f {
if f_row.len() != (params.get_n() - 1) as usize {
return false;
}
}

// Generate the verifier challenge
let mut transcript = Transcript::new("Triptych proof".as_bytes());
transcript.append_u64("version".as_bytes(), VERSION);
if let Some(message) = statement.get_message() {
transcript.append_message("message".as_bytes(), message);
}
transcript.append_message("params".as_bytes(), params.get_hash());
transcript.append_message("M".as_bytes(), statement.get_input_set().get_hash());
transcript.append_message("J".as_bytes(), J.compress().as_bytes());
// Extract common values for convenience
let M = first_statement.get_input_set().get_keys();
let params = first_statement.get_params();

transcript.append_message("A".as_bytes(), self.A.compress().as_bytes());
transcript.append_message("B".as_bytes(), self.B.compress().as_bytes());
transcript.append_message("C".as_bytes(), self.C.compress().as_bytes());
transcript.append_message("D".as_bytes(), self.D.compress().as_bytes());
for item in &self.X {
transcript.append_message("X".as_bytes(), item.compress().as_bytes());
}
for item in &self.Y {
transcript.append_message("Y".as_bytes(), item.compress().as_bytes());
// Check that all proof semantics are valid for the statement
for proof in proofs {
if proof.X.len() != params.get_m() as usize {
return false;
}
if proof.Y.len() != params.get_m() as usize {
return false;
}
if proof.f.len() != params.get_m() as usize {
return false;
}
for f_row in &proof.f {
if f_row.len() != (params.get_n() - 1) as usize {
return false;
}
}
}

// Get challenge powers
let xi_powers = match xi_powers(&mut transcript, params.get_m()) {
Ok(xi_powers) => xi_powers,
// Determine the size of the final check vector, which must not overflow `usize`
let batch_size = match u32::try_from(proofs.len()) {
Ok(batch) => batch,
_ => {
return false;
},
};

// Finish the transcript for pseudorandom number generation
for f_row in &self.f {
for f in f_row {
transcript.append_message("f".as_bytes(), f.as_bytes());
}
}
transcript.append_message("z_A".as_bytes(), self.z_A.as_bytes());
transcript.append_message("z_C".as_bytes(), self.z_C.as_bytes());
transcript.append_message("z".as_bytes(), self.z.as_bytes());
let mut transcript_rng = transcript.build_rng().finalize(&mut DangerousRng);

// Reconstruct the remaining `f` terms
let f = (0..params.get_m())
.map(|j| {
let mut f_j = Vec::with_capacity(params.get_n() as usize);
f_j.push(xi_powers[1] - self.f[j as usize].iter().sum::<Scalar>());
f_j.extend(self.f[j as usize].iter());
f_j
})
.collect::<Vec<Vec<Scalar>>>();

// Check that `f` does not contain zero, which breaks batch inversion
for f_row in &f {
if f_row.contains(&Scalar::ZERO) {
let final_size = match usize::try_from(
1 // G
+ params.get_n() * params.get_m() // CommitmentG
+ 1 // CommitmentH
+ params.get_N() // M
+ 1 // U
+ batch_size * (
4 // A, B, C, D
+ 1 // J
+ 2 * params.get_m() // X, Y
),
) {
Ok(size) => size,
_ => {
return false;
}
}

// Generate weights for verification equations
// We implicitly set `w3 = 1` to avoid unnecessary constant-time multiplication
let w1 = Scalar::random(&mut transcript_rng);
let w2 = Scalar::random(&mut transcript_rng);
let w4 = Scalar::random(&mut transcript_rng);
},
};

// Set up the point iterator for the final check
let points = once(params.get_G())
// Set up the point vector for the final check
let points = proofs
.iter()
.zip(statements.iter())
.flat_map(|(p, s)| {
once(&p.A)
.chain(once(&p.B))
.chain(once(&p.C))
.chain(once(&p.D))
.chain(once(s.get_J()))
.chain(p.X.iter())
.chain(p.Y.iter())
})
.chain(once(params.get_G()))
.chain(params.get_CommitmentG().iter())
.chain(once(params.get_CommitmentH()))
.chain(once(&self.A))
.chain(once(&self.B))
.chain(once(&self.C))
.chain(once(&self.D))
.chain(once(J))
.chain(self.X.iter())
.chain(self.Y.iter())
.chain(M.iter())
.chain(once(params.get_U()));
.chain(once(params.get_U()))
.collect::<Vec<&RistrettoPoint>>();

// Start the scalar vector, putting the common elements last
let mut scalars = Vec::with_capacity(final_size);

// Set up the scalar vector for the final check, matching the point iterator
let mut scalars =
Vec::with_capacity((params.get_N() + 2 * params.get_m() + params.get_n() * params.get_m() + 8) as usize);
// Set up common scalars
let mut G_scalar = Scalar::ZERO;
let mut CommitmentG_scalars = vec![Scalar::ZERO; params.get_CommitmentG().len()];
let mut CommitmentH_scalar = Scalar::ZERO;
let mut M_scalars = vec![Scalar::ZERO; M.len()];
let mut U_scalar = Scalar::ZERO;

// G
scalars.push(-self.z);
// Set up a transcript generator for use in weighting
let mut transcript_weights = Transcript::new("Triptych verifier weights".as_bytes());

// Generate all verifier challenges
let mut xi_powers_all = Vec::with_capacity(proofs.len());
for (statement, proof) in statements.iter().zip(proofs.iter()) {
// Generate the verifier challenge
let mut transcript = Transcript::new("Triptych proof".as_bytes());
transcript.append_u64("version".as_bytes(), VERSION);
if let Some(message) = statement.get_message() {
transcript.append_message("message".as_bytes(), message);
}
transcript.append_message("params".as_bytes(), params.get_hash());
transcript.append_message("M".as_bytes(), statement.get_input_set().get_hash());
transcript.append_message("J".as_bytes(), statement.get_J().compress().as_bytes());

transcript.append_message("A".as_bytes(), proof.A.compress().as_bytes());
transcript.append_message("B".as_bytes(), proof.B.compress().as_bytes());
transcript.append_message("C".as_bytes(), proof.C.compress().as_bytes());
transcript.append_message("D".as_bytes(), proof.D.compress().as_bytes());
for item in &proof.X {
transcript.append_message("X".as_bytes(), item.compress().as_bytes());
}
for item in &proof.Y {
transcript.append_message("Y".as_bytes(), item.compress().as_bytes());
}

// CommitmentG
for f_row in &f {
for f_item in f_row {
scalars.push(w1 * f_item + w2 * f_item * (xi_powers[1] - f_item));
// Get challenge powers
let xi_powers = match xi_powers(&mut transcript, params.get_m()) {
Ok(xi_powers) => xi_powers,
_ => {
return false;
},
};

xi_powers_all.push(xi_powers);

// Finish the transcript for pseudorandom number generation
for f_row in &proof.f {
for f in f_row {
transcript.append_message("f".as_bytes(), f.as_bytes());
}
}
transcript.append_message("z_A".as_bytes(), proof.z_A.as_bytes());
transcript.append_message("z_C".as_bytes(), proof.z_C.as_bytes());
transcript.append_message("z".as_bytes(), proof.z.as_bytes());
let mut transcript_rng = transcript.build_rng().finalize(&mut DangerousRng);

transcript_weights.append_u64("proof".as_bytes(), transcript_rng.as_rngcore().next_u64());
}

// CommitmentH
scalars.push(w1 * self.z_A + w2 * self.z_C);
// Finalize the weighting transcript into a pseudorandom number generator
let mut transcript_weights_rng = transcript_weights.build_rng().finalize(&mut DangerousRng);

// Process each proof
for (proof, xi_powers) in proofs.iter().zip(xi_powers_all.iter()) {
// Reconstruct the remaining `f` terms
let f = (0..params.get_m())
.map(|j| {
let mut f_j = Vec::with_capacity(params.get_n() as usize);
f_j.push(xi_powers[1] - proof.f[j as usize].iter().sum::<Scalar>());
f_j.extend(proof.f[j as usize].iter());
f_j
})
.collect::<Vec<Vec<Scalar>>>();

// Check that `f` does not contain zero, which breaks batch inversion
for f_row in &f {
if f_row.contains(&Scalar::ZERO) {
return false;
}
}

// A
scalars.push(-w1);
// Generate weights for this proof's verification equations
let w1 = Scalar::random(&mut transcript_weights_rng);
let w2 = Scalar::random(&mut transcript_weights_rng);
let w3 = Scalar::random(&mut transcript_weights_rng);
let w4 = Scalar::random(&mut transcript_weights_rng);

// B
scalars.push(-w1 * xi_powers[1]);
// Get the challenge for convenience
let xi = xi_powers[1];

// C
scalars.push(-w2 * xi_powers[1]);
// G
G_scalar -= w3 * proof.z;

// D
scalars.push(-w2);
// CommitmentG
for (CommitmentG_scalar, f_item) in CommitmentG_scalars
.iter_mut()
.zip(f.iter().flatten().map(|f| w1 * f + w2 * f * (xi - f)))
{
*CommitmentG_scalar += f_item;
}

// J
scalars.push(-w4 * self.z);
// CommitmentH
CommitmentH_scalar += w1 * proof.z_A + w2 * proof.z_C;

// X
for xi_power in &xi_powers[0..(params.get_m() as usize)] {
scalars.push(-xi_power);
}
// A
scalars.push(-w1);

// Y
for xi_power in &xi_powers[0..(params.get_m() as usize)] {
scalars.push(-w4 * xi_power);
}
// B
scalars.push(-w1 * xi_powers[1]);

// Set up the initial `f` product and Gray iterator
let mut f_product = f.iter().map(|f_row| f_row[0]).product::<Scalar>();
let gray_iterator = if let Some(gray_iterator) = GrayIterator::new(params.get_n(), params.get_m()) {
gray_iterator
} else {
return false;
};
// C
scalars.push(-w2 * xi_powers[1]);

// D
scalars.push(-w2);

// J
scalars.push(-w4 * proof.z);

// Invert each element of `f` for efficiency
let mut f_inverse_flat = f.iter().flatten().copied().collect::<Vec<Scalar>>();
Scalar::batch_invert(&mut f_inverse_flat);
let f_inverse = f_inverse_flat
.chunks_exact(params.get_n() as usize)
.collect::<Vec<&[Scalar]>>();
// X
for xi_power in &xi_powers[0..(params.get_m() as usize)] {
scalars.push(-w3 * xi_power);
}

// Y
for xi_power in &xi_powers[0..(params.get_m() as usize)] {
scalars.push(-w4 * xi_power);
}

// M
for (gray_index, gray_old, gray_new) in gray_iterator {
// Update the `f` product
f_product *= f_inverse[gray_index][gray_old as usize] * f[gray_index][gray_new as usize];
// Set up the initial `f` product and Gray iterator
let mut f_product = f.iter().map(|f_row| f_row[0]).product::<Scalar>();
let gray_iterator = if let Some(gray_iterator) = GrayIterator::new(params.get_n(), params.get_m()) {
gray_iterator
} else {
return false;
};

// Invert each element of `f` for efficiency
let mut f_inverse_flat = f.iter().flatten().copied().collect::<Vec<Scalar>>();
Scalar::batch_invert(&mut f_inverse_flat);
let f_inverse = f_inverse_flat
.chunks_exact(params.get_n() as usize)
.collect::<Vec<&[Scalar]>>();

// M
let mut U_scalar_proof = Scalar::ZERO;
for (M_scalar, (gray_index, gray_old, gray_new)) in M_scalars.iter_mut().zip(gray_iterator) {
// Update the `f` product
f_product *= f_inverse[gray_index][gray_old as usize] * f[gray_index][gray_new as usize];

*M_scalar += w3 * f_product;
U_scalar_proof += f_product;
}

scalars.push(f_product);
U_scalar += f_product;
// U
U_scalar += w4 * U_scalar_proof;
}

// U
scalars.push(w4 * U_scalar);
// Add all common elements to the scalar vector
scalars.push(G_scalar);
scalars.extend(CommitmentG_scalars);
scalars.push(CommitmentH_scalar);
scalars.extend(M_scalars);
scalars.push(U_scalar);

// Perform the final check; this can be done in variable time since it holds no secrets
RistrettoPoint::vartime_multiscalar_mul(scalars.iter(), points) == RistrettoPoint::identity()
@@ -523,7 +631,7 @@ impl Proof {
mod test {
use alloc::{sync::Arc, vec::Vec};

use curve25519_dalek::RistrettoPoint;
use curve25519_dalek::{RistrettoPoint, Scalar};
use rand_chacha::ChaCha12Rng;
use rand_core::{CryptoRngCore, SeedableRng};

@@ -563,6 +671,43 @@ mod test {
(witness, statement)
}

// Generate a batch of witnesses and corresponding statements
#[allow(non_snake_case)]
fn generate_batch_data<R: CryptoRngCore>(n: u32, m: u32, b: usize, rng: &mut R) -> (Vec<Witness>, Vec<Statement>) {
// Generate parameters
let params = Arc::new(Parameters::new(n, m).unwrap());

// Generate witnesses; for this test, we use adjacent indexes for simplicity
// This means the batch size must not exceed the input set size!
assert!(b <= params.get_N() as usize);
let mut witnesses = Vec::with_capacity(b);
witnesses.push(Witness::random(&params, rng));
for _ in 1..b {
let r = Scalar::random(rng);
let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N();
witnesses.push(Witness::new(&params, l, &r).unwrap());
}

// Generate input set from all witnesses
let mut M = (0..params.get_N())
.map(|_| RistrettoPoint::random(rng))
.collect::<Vec<RistrettoPoint>>();
for witness in &witnesses {
M[witness.get_l() as usize] = witness.compute_verification_key();
}
let input_set = Arc::new(InputSet::new(&M));

// Generate statements
let mut statements = Vec::with_capacity(b);
for witness in &witnesses {
let J = witness.compute_linking_tag();
let message = "Proof message".as_bytes();
statements.push(Statement::new(&params, &input_set, &J, Some(message)).unwrap());
}

(witnesses, statements)
}

#[test]
#[allow(non_snake_case, non_upper_case_globals)]
fn test_prove_verify() {
@@ -591,6 +736,25 @@ mod test {
assert!(proof.verify(&statement));
}

#[test]
#[allow(non_snake_case, non_upper_case_globals)]
fn test_prove_verify_batch() {
// Generate data
const n: u32 = 2;
const m: u32 = 4;
const b: usize = 3; // batch size
let mut rng = ChaCha12Rng::seed_from_u64(8675309);
let (witnesses, statements) = generate_batch_data(n, m, b, &mut rng);

// Generate the proofs and verify as a batch
let proofs = witnesses
.iter()
.zip(statements.iter())
.map(|(w, s)| Proof::prove_vartime(w, s, &mut rng).unwrap())
.collect::<Vec<Proof>>();
assert!(Proof::verify_batch(&statements, &proofs));
}

#[test]
#[allow(non_snake_case, non_upper_case_globals)]
fn test_evil_message() {