From 1599b97932284e8bad212f6b4596607c0802fee7 Mon Sep 17 00:00:00 2001 From: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:18:36 -0600 Subject: [PATCH] Variable-time prover (#19) Prover operations make some attempt at avoiding timing side-channel attacks through the use of constant-time multiscalar multiplication operations that are used in several places. While this is useful in general, some callers might not need this, and would prefer a speedier prover. This PR adds a `prove_vartime` function that uses variable-time multiscalar multiplication, which cuts the proving time by about half. A simple refactoring means it and the existing `prove` function are now trivial wrappers to avoid code duplication. Tests and benchmarks are updated to account for this. --- benches/triptych.rs | 51 ++++++++++++++++++++++++++++++- src/parameters.rs | 15 ++++++++-- src/proof.rs | 73 +++++++++++++++++++++++++++++++++++++++------ 3 files changed, 127 insertions(+), 12 deletions(-) diff --git a/benches/triptych.rs b/benches/triptych.rs index 0967e5d..2e6a218 100644 --- a/benches/triptych.rs +++ b/benches/triptych.rs @@ -68,6 +68,55 @@ fn generate_proof(c: &mut Criterion) { group.finish(); } +#[allow(non_snake_case)] +#[allow(non_upper_case_globals)] +fn generate_proof_vartime(c: &mut Criterion) { + let mut group = c.benchmark_group("generate_proof_vartime"); + let mut rng = ChaCha12Rng::seed_from_u64(8675309); + + for n in N_VALUES { + for m in M_VALUES { + // Generate parameters + let params = Arc::new(Parameters::new(n, m).unwrap()); + + let label = format!( + "Generate proof (variable time): n = {}, m = {} (N = {})", + n, + m, + params.get_N() + ); + group.bench_function(&label, |b| { + // Generate witness + let witness = Witness::random(¶ms, &mut rng); + + // Generate input set + let M = (0..params.get_N()) + .map(|i| { + if i == witness.get_l() { + witness.compute_verification_key() + } else { + RistrettoPoint::random(&mut rng) + } + }) + .collect::>(); + let input_set = Arc::new(InputSet::new(&M)); + + // Generate statement + let J = witness.compute_linking_tag(); + let statement = Statement::new(¶ms, &input_set, &J).unwrap(); + + // Start the benchmark + b.iter(|| { + // Generate the proof + let _proof = + Proof::prove_vartime(&witness, &statement, Some("Proof message".as_bytes()), &mut rng).unwrap(); + }) + }); + } + } + group.finish(); +} + #[allow(non_snake_case)] #[allow(non_upper_case_globals)] fn verify_proof(c: &mut Criterion) { @@ -116,7 +165,7 @@ fn verify_proof(c: &mut Criterion) { criterion_group! { name = generate; config = Criterion::default(); - targets = generate_proof + targets = generate_proof, generate_proof_vartime } criterion_group! { diff --git a/src/parameters.rs b/src/parameters.rs index dea383d..9786f25 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -5,7 +5,12 @@ use alloc::vec::Vec; use core::iter::once; use blake3::Hasher; -use curve25519_dalek::{constants::RISTRETTO_BASEPOINT_POINT, traits::MultiscalarMul, RistrettoPoint, Scalar}; +use curve25519_dalek::{ + constants::RISTRETTO_BASEPOINT_POINT, + traits::{MultiscalarMul, VartimeMultiscalarMul}, + RistrettoPoint, + Scalar, +}; use snafu::prelude::*; /// Public parameters used for generating and verifying Triptych proofs. @@ -126,10 +131,12 @@ impl Parameters { /// Commit to a matrix. /// /// This requires that `matrix` be an `m x n` scalar matrix. + /// You can decide if you want to use variable-time operations via the `vartime` flag. pub(crate) fn commit_matrix( &self, matrix: &[Vec], mask: &Scalar, + vartime: bool, ) -> Result { // Check that the matrix dimensions are valid if matrix.len() != (self.m as usize) || matrix.iter().any(|m| m.len() != (self.n as usize)) { @@ -140,7 +147,11 @@ impl Parameters { let scalars = matrix.iter().flatten().chain(once(mask)).collect::>(); let points = self.get_CommitmentG().iter().chain(once(self.get_CommitmentH())); - Ok(RistrettoPoint::multiscalar_mul(scalars, points)) + if vartime { + Ok(RistrettoPoint::vartime_multiscalar_mul(scalars, points)) + } else { + Ok(RistrettoPoint::multiscalar_mul(scalars, points)) + } } /// Get the group generator `G` from these parameters. diff --git a/src/proof.rs b/src/proof.rs index 52c492b..5bf1630 100644 --- a/src/proof.rs +++ b/src/proof.rs @@ -83,6 +83,28 @@ fn xi_powers(transcript: &mut Transcript, m: u32) -> Result, ProofEr } impl Proof { + /// Generate a Triptych proof, throwing constant-time operations out the window. + /// + /// The proof is generated by supplying a witness `witness` and corresponding statement `statement`. + /// If the witness and statement do not share the same parameters, or if the statement is invalid for the witness, + /// returns an error. + /// + /// You must also supply a cryptographically-secure random number generator `rng`. + /// + /// You may optionally provide a byte slice `message` that is bound to the proof's Fiat-Shamir transcript. + /// The verifier must provide the same message in order for the proof to verify. + /// + /// This function specifically avoids constant-time operations for efficiency. + /// If you want any attempt at avoiding timing side-channel attacks, use `prove` instead. + pub fn prove_vartime( + witness: &Witness, + statement: &Statement, + message: Option<&[u8]>, + rng: &mut R, + ) -> Result { + Self::prove_internal(witness, statement, message, rng, true) + } + /// Generate a Triptych proof. /// /// The proof is generated by supplying a witness `witness` and corresponding statement `statement`. @@ -93,12 +115,26 @@ impl Proof { /// /// You may optionally provide a byte slice `message` that is bound to the proof's Fiat-Shamir transcript. /// The verifier must provide the same message in order for the proof to verify. - #[allow(clippy::too_many_lines, non_snake_case)] + /// + /// This function makes some attempt at avoiding timing side-channel attacks. + /// If you know you don't need this, you can use `prove_vartime` for speedier operations. pub fn prove( witness: &Witness, statement: &Statement, message: Option<&[u8]>, rng: &mut R, + ) -> Result { + Self::prove_internal(witness, statement, message, rng, false) + } + + /// The actual prover functionality. + #[allow(clippy::too_many_lines, non_snake_case)] + fn prove_internal( + witness: &Witness, + statement: &Statement, + message: Option<&[u8]>, + rng: &mut R, + vartime: bool, ) -> Result { // Check that the witness and statement have identical parameters if witness.get_params() != statement.get_params() { @@ -150,7 +186,7 @@ impl Proof { a[j][0] = -a[j][1..].iter().sum::(); } let A = params - .commit_matrix(&a, &r_A) + .commit_matrix(&a, &r_A, vartime) .map_err(|_| ProofError::InvalidParameter)?; // Compute the `B` matrix commitment @@ -165,7 +201,7 @@ impl Proof { }) .collect::>>(); let B = params - .commit_matrix(&sigma, &r_B) + .commit_matrix(&sigma, &r_B, vartime) .map_err(|_| ProofError::InvalidParameter)?; // Compute the `C` matrix commitment @@ -179,7 +215,7 @@ impl Proof { }) .collect::>>(); let C = params - .commit_matrix(&a_sigma, &r_C) + .commit_matrix(&a_sigma, &r_C, vartime) .map_err(|_| ProofError::InvalidParameter)?; // Compute the `D` matrix commitment @@ -192,7 +228,7 @@ impl Proof { }) .collect::>>(); let D = params - .commit_matrix(&a_square, &r_D) + .commit_matrix(&a_square, &r_D, vartime) .map_err(|_| ProofError::InvalidParameter)?; // Random masks @@ -251,7 +287,11 @@ impl Proof { let X_points = M.iter().chain(once(params.get_G())); let X_scalars = p.iter().map(|p| &p[j]).chain(once(rho)); - RistrettoPoint::multiscalar_mul(X_scalars, X_points) + if vartime { + RistrettoPoint::vartime_multiscalar_mul(X_scalars, X_points) + } else { + RistrettoPoint::multiscalar_mul(X_scalars, X_points) + } }) .collect::>(); @@ -544,6 +584,21 @@ mod test { assert!(proof.verify(&statement, Some(message))); } + #[test] + #[allow(non_snake_case, non_upper_case_globals)] + fn test_prove_verify_vartime() { + // Generate data + const n: u32 = 2; + const m: u32 = 4; + let mut rng = ChaCha12Rng::seed_from_u64(8675309); + let (witness, statement) = generate_data(n, m, &mut rng); + + // Generate and verify a proof + let message = "Proof messsage".as_bytes(); + let proof = Proof::prove_vartime(&witness, &statement, Some(message), &mut rng).unwrap(); + assert!(proof.verify(&statement, Some(message))); + } + #[test] #[allow(non_snake_case, non_upper_case_globals)] fn test_evil_message() { @@ -555,7 +610,7 @@ mod test { // Generate a proof let message = "Proof messsage".as_bytes(); - let proof = Proof::prove(&witness, &statement, Some(message), &mut rng).unwrap(); + let proof = Proof::prove_vartime(&witness, &statement, Some(message), &mut rng).unwrap(); // Attempt to verify the proof against a different message, which should fail let evil_message = "Evil proof message".as_bytes(); @@ -573,7 +628,7 @@ mod test { // Generate a proof let message = "Proof messsage".as_bytes(); - let proof = Proof::prove(&witness, &statement, Some(message), &mut rng).unwrap(); + let proof = Proof::prove_vartime(&witness, &statement, Some(message), &mut rng).unwrap(); // Generate a statement with a modified input set let mut M = statement.get_input_set().get_keys().to_vec(); @@ -597,7 +652,7 @@ mod test { // Generate a proof let message = "Proof messsage".as_bytes(); - let proof = Proof::prove(&witness, &statement, Some(message), &mut rng).unwrap(); + let proof = Proof::prove_vartime(&witness, &statement, Some(message), &mut rng).unwrap(); // Generate a statement with a modified linking tag let evil_statement = Statement::new(