Skip to content

Commit

Permalink
Variable-time prover (#19)
Browse files Browse the repository at this point in the history
Prover operations make some attempt at avoiding timing side-channel attacks through the use of constant-time multiscalar multiplication operations that are used in several places. While this is useful in general, some callers might not need this, and would prefer a speedier prover.

This PR adds a `prove_vartime` function that uses variable-time multiscalar multiplication, which cuts the proving time by about half. A simple refactoring means it and the existing `prove` function are now trivial wrappers to avoid code duplication.

Tests and benchmarks are updated to account for this.
  • Loading branch information
AaronFeickert authored Jan 5, 2024
1 parent 0d68dff commit 1599b97
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 12 deletions.
51 changes: 50 additions & 1 deletion benches/triptych.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,55 @@ fn generate_proof(c: &mut Criterion) {
group.finish();
}

#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
fn generate_proof_vartime(c: &mut Criterion) {
let mut group = c.benchmark_group("generate_proof_vartime");
let mut rng = ChaCha12Rng::seed_from_u64(8675309);

for n in N_VALUES {
for m in M_VALUES {
// Generate parameters
let params = Arc::new(Parameters::new(n, m).unwrap());

let label = format!(
"Generate proof (variable time): n = {}, m = {} (N = {})",
n,
m,
params.get_N()
);
group.bench_function(&label, |b| {
// Generate witness
let witness = Witness::random(&params, &mut rng);

// Generate input set
let M = (0..params.get_N())
.map(|i| {
if i == witness.get_l() {
witness.compute_verification_key()
} else {
RistrettoPoint::random(&mut rng)
}
})
.collect::<Vec<RistrettoPoint>>();
let input_set = Arc::new(InputSet::new(&M));

// Generate statement
let J = witness.compute_linking_tag();
let statement = Statement::new(&params, &input_set, &J).unwrap();

// Start the benchmark
b.iter(|| {
// Generate the proof
let _proof =
Proof::prove_vartime(&witness, &statement, Some("Proof message".as_bytes()), &mut rng).unwrap();
})
});
}
}
group.finish();
}

#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
fn verify_proof(c: &mut Criterion) {
Expand Down Expand Up @@ -116,7 +165,7 @@ fn verify_proof(c: &mut Criterion) {
criterion_group! {
name = generate;
config = Criterion::default();
targets = generate_proof
targets = generate_proof, generate_proof_vartime
}

criterion_group! {
Expand Down
15 changes: 13 additions & 2 deletions src/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ use alloc::vec::Vec;
use core::iter::once;

use blake3::Hasher;
use curve25519_dalek::{constants::RISTRETTO_BASEPOINT_POINT, traits::MultiscalarMul, RistrettoPoint, Scalar};
use curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT,
traits::{MultiscalarMul, VartimeMultiscalarMul},
RistrettoPoint,
Scalar,
};
use snafu::prelude::*;

/// Public parameters used for generating and verifying Triptych proofs.
Expand Down Expand Up @@ -126,10 +131,12 @@ impl Parameters {
/// Commit to a matrix.
///
/// This requires that `matrix` be an `m x n` scalar matrix.
/// You can decide if you want to use variable-time operations via the `vartime` flag.
pub(crate) fn commit_matrix(
&self,
matrix: &[Vec<Scalar>],
mask: &Scalar,
vartime: bool,
) -> Result<RistrettoPoint, ParameterError> {
// Check that the matrix dimensions are valid
if matrix.len() != (self.m as usize) || matrix.iter().any(|m| m.len() != (self.n as usize)) {
Expand All @@ -140,7 +147,11 @@ impl Parameters {
let scalars = matrix.iter().flatten().chain(once(mask)).collect::<Vec<&Scalar>>();
let points = self.get_CommitmentG().iter().chain(once(self.get_CommitmentH()));

Ok(RistrettoPoint::multiscalar_mul(scalars, points))
if vartime {
Ok(RistrettoPoint::vartime_multiscalar_mul(scalars, points))
} else {
Ok(RistrettoPoint::multiscalar_mul(scalars, points))
}
}

/// Get the group generator `G` from these parameters.
Expand Down
73 changes: 64 additions & 9 deletions src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ fn xi_powers(transcript: &mut Transcript, m: u32) -> Result<Vec<Scalar>, ProofEr
}

impl Proof {
/// Generate a Triptych proof, throwing constant-time operations out the window.
///
/// The proof is generated by supplying a witness `witness` and corresponding statement `statement`.
/// If the witness and statement do not share the same parameters, or if the statement is invalid for the witness,
/// returns an error.
///
/// You must also supply a cryptographically-secure random number generator `rng`.
///
/// You may optionally provide a byte slice `message` that is bound to the proof's Fiat-Shamir transcript.
/// The verifier must provide the same message in order for the proof to verify.
///
/// This function specifically avoids constant-time operations for efficiency.
/// If you want any attempt at avoiding timing side-channel attacks, use `prove` instead.
pub fn prove_vartime<R: CryptoRngCore>(
witness: &Witness,
statement: &Statement,
message: Option<&[u8]>,
rng: &mut R,
) -> Result<Self, ProofError> {
Self::prove_internal(witness, statement, message, rng, true)
}

/// Generate a Triptych proof.
///
/// The proof is generated by supplying a witness `witness` and corresponding statement `statement`.
Expand All @@ -93,12 +115,26 @@ impl Proof {
///
/// You may optionally provide a byte slice `message` that is bound to the proof's Fiat-Shamir transcript.
/// The verifier must provide the same message in order for the proof to verify.
#[allow(clippy::too_many_lines, non_snake_case)]
///
/// This function makes some attempt at avoiding timing side-channel attacks.
/// If you know you don't need this, you can use `prove_vartime` for speedier operations.
pub fn prove<R: CryptoRngCore>(
witness: &Witness,
statement: &Statement,
message: Option<&[u8]>,
rng: &mut R,
) -> Result<Self, ProofError> {
Self::prove_internal(witness, statement, message, rng, false)
}

/// The actual prover functionality.
#[allow(clippy::too_many_lines, non_snake_case)]
fn prove_internal<R: CryptoRngCore>(
witness: &Witness,
statement: &Statement,
message: Option<&[u8]>,
rng: &mut R,
vartime: bool,
) -> Result<Self, ProofError> {
// Check that the witness and statement have identical parameters
if witness.get_params() != statement.get_params() {
Expand Down Expand Up @@ -150,7 +186,7 @@ impl Proof {
a[j][0] = -a[j][1..].iter().sum::<Scalar>();
}
let A = params
.commit_matrix(&a, &r_A)
.commit_matrix(&a, &r_A, vartime)
.map_err(|_| ProofError::InvalidParameter)?;

// Compute the `B` matrix commitment
Expand All @@ -165,7 +201,7 @@ impl Proof {
})
.collect::<Vec<Vec<Scalar>>>();
let B = params
.commit_matrix(&sigma, &r_B)
.commit_matrix(&sigma, &r_B, vartime)
.map_err(|_| ProofError::InvalidParameter)?;

// Compute the `C` matrix commitment
Expand All @@ -179,7 +215,7 @@ impl Proof {
})
.collect::<Vec<Vec<Scalar>>>();
let C = params
.commit_matrix(&a_sigma, &r_C)
.commit_matrix(&a_sigma, &r_C, vartime)
.map_err(|_| ProofError::InvalidParameter)?;

// Compute the `D` matrix commitment
Expand All @@ -192,7 +228,7 @@ impl Proof {
})
.collect::<Vec<Vec<Scalar>>>();
let D = params
.commit_matrix(&a_square, &r_D)
.commit_matrix(&a_square, &r_D, vartime)
.map_err(|_| ProofError::InvalidParameter)?;

// Random masks
Expand Down Expand Up @@ -251,7 +287,11 @@ impl Proof {
let X_points = M.iter().chain(once(params.get_G()));
let X_scalars = p.iter().map(|p| &p[j]).chain(once(rho));

RistrettoPoint::multiscalar_mul(X_scalars, X_points)
if vartime {
RistrettoPoint::vartime_multiscalar_mul(X_scalars, X_points)
} else {
RistrettoPoint::multiscalar_mul(X_scalars, X_points)
}
})
.collect::<Vec<RistrettoPoint>>();

Expand Down Expand Up @@ -544,6 +584,21 @@ mod test {
assert!(proof.verify(&statement, Some(message)));
}

#[test]
#[allow(non_snake_case, non_upper_case_globals)]
fn test_prove_verify_vartime() {
// Generate data
const n: u32 = 2;
const m: u32 = 4;
let mut rng = ChaCha12Rng::seed_from_u64(8675309);
let (witness, statement) = generate_data(n, m, &mut rng);

// Generate and verify a proof
let message = "Proof messsage".as_bytes();
let proof = Proof::prove_vartime(&witness, &statement, Some(message), &mut rng).unwrap();
assert!(proof.verify(&statement, Some(message)));
}

#[test]
#[allow(non_snake_case, non_upper_case_globals)]
fn test_evil_message() {
Expand All @@ -555,7 +610,7 @@ mod test {

// Generate a proof
let message = "Proof messsage".as_bytes();
let proof = Proof::prove(&witness, &statement, Some(message), &mut rng).unwrap();
let proof = Proof::prove_vartime(&witness, &statement, Some(message), &mut rng).unwrap();

// Attempt to verify the proof against a different message, which should fail
let evil_message = "Evil proof message".as_bytes();
Expand All @@ -573,7 +628,7 @@ mod test {

// Generate a proof
let message = "Proof messsage".as_bytes();
let proof = Proof::prove(&witness, &statement, Some(message), &mut rng).unwrap();
let proof = Proof::prove_vartime(&witness, &statement, Some(message), &mut rng).unwrap();

// Generate a statement with a modified input set
let mut M = statement.get_input_set().get_keys().to_vec();
Expand All @@ -597,7 +652,7 @@ mod test {

// Generate a proof
let message = "Proof messsage".as_bytes();
let proof = Proof::prove(&witness, &statement, Some(message), &mut rng).unwrap();
let proof = Proof::prove_vartime(&witness, &statement, Some(message), &mut rng).unwrap();

// Generate a statement with a modified linking tag
let evil_statement = Statement::new(
Expand Down

0 comments on commit 1599b97

Please sign in to comment.