Refactor PolyEval{Instance, Witness} (#315)
* test: PolyEvalWitness batch

* test: PolyEvalInstance batch

* chore: C-CALLER-CONTROL

* refactor: shrink batching code

* chore: clippy
huitseeker authored Feb 12, 2024
1 parent c1af06f commit 26fd303
Showing 8 changed files with 206 additions and 78 deletions.
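
For orientation, both `PolyEvalWitness::batch` and `PolyEvalInstance::batch` are folded into the more general `*_diff_size` batching in this change. As a sketch (following the comments that appear in the diff below, with shorter polynomials read as zero-padded to the largest size):

```latex
% Witness batching under challenge s (each p_i zero-padded to a common length):
p = \sum_i s^i \, p_i
% Instance batching under challenge \gamma, where each evaluation is first
% rescaled for padding via v_i = L_0(x_{\mathrm{lo}}) \cdot P_i(x_{\mathrm{hi}}):
C = \sum_i \gamma^i C_i, \qquad v = \sum_i \gamma^i v_i
```
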
10 changes: 6 additions & 4 deletions Cargo.toml
@@ -40,10 +40,10 @@ abomonation_derive = { version = "0.1.0", package = "abomonation_derive_ng" }
tracing = "0.1.37"
cfg-if = "1.0.0"
once_cell = "1.18.0"
itertools = "0.12.0"
itertools = "0.12.0" # zip_eq
rand = "0.8.5"
ref-cast = "1.0.20"
derive_more = "0.99.17"
ref-cast = "1.0.20" # allocation-less conversion in multilinear polys
derive_more = "0.99.17" # lightens impl macros for pasta
static_assertions = "1.1.0"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
@@ -55,8 +55,10 @@ grumpkin-msm = { git = "https://github.com/lurk-lab/grumpkin-msm", branch = "dev
# see https://github.com/rust-random/rand/pull/948
getrandom = { version = "0.2.0", default-features = false, features = ["js"] }

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
proptest = "1.2.0"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
pprof = { version = "0.13" }
criterion = { version = "0.5", features = ["html_reports"] }

1 change: 0 additions & 1 deletion src/r1cs/mod.rs
@@ -1,6 +1,5 @@
//! This module defines R1CS related types and a folding scheme for Relaxed R1CS
mod sparse;
#[cfg(test)]
pub(crate) mod util;

use crate::{
24 changes: 24 additions & 0 deletions src/r1cs/util.rs
@@ -1,4 +1,5 @@
use ff::PrimeField;
use group::Group;
#[cfg(not(target_arch = "wasm32"))]
use proptest::prelude::*;

@@ -24,3 +25,26 @@ impl<F: PrimeField> Arbitrary for FWrap<F> {
strategy.boxed()
}
}

/// Wrapper struct around a Group element that implements additional traits
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GWrap<G>(pub G);

impl<G: Group> Copy for GWrap<G> {}

#[cfg(not(target_arch = "wasm32"))]
/// Trait implementation for generating `GWrap<G>` instances with proptest
impl<G: Group> Arbitrary for GWrap<G> {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;

fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
use rand::rngs::StdRng;
use rand_core::SeedableRng;

let strategy = any::<[u8; 32]>()
.prop_map(|seed| Self(G::random(StdRng::from_seed(seed))))
.no_shrink();
strategy.boxed()
}
}
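
A hypothetical usage sketch for the new `GWrap` strategy (not part of this commit; it assumes the crate's existing `proptest`, `group`, and `pasta_curves` dependencies and a non-wasm target):

```rust
use group::Group; // brings `double` into scope
use pasta_curves::pallas::Point as PallasPoint;
use proptest::prelude::*;

use crate::r1cs::util::GWrap;

proptest! {
    #[test]
    fn gwrap_points_behave_like_group_elements(g in any::<GWrap<PallasPoint>>()) {
        // Doubling must agree with repeated addition for any sampled point.
        prop_assert_eq!(g.0 + g.0, g.0.double());
    }
}
```
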
2 changes: 1 addition & 1 deletion src/spartan/batched.rs
@@ -347,7 +347,7 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
};

let (batched_u, batched_w, sc_proof_batch, claims_batch_left) =
batch_eval_prove(u_vec, w_vec, &mut transcript)?;
batch_eval_prove(u_vec, &w_vec, &mut transcript)?;

let eval_arg = EE::prove(
ck,
7 changes: 4 additions & 3 deletions src/spartan/batched_ppsnark.rs
@@ -375,7 +375,7 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
|comm_Az_Bz_Cz, evals_Az_Bz_Cz_at_tau| {
let u = PolyEvalInstance::<E>::batch(
comm_Az_Bz_Cz.as_slice(),
&[], // ignored by the function
vec![], // ignored by the function
evals_Az_Bz_Cz_at_tau.as_slice(),
&c,
);
@@ -701,7 +701,8 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
let num_vars_u = w_vec.iter().map(|w| w.p.len().log_2()).collect::<Vec<_>>();
let u_batch =
PolyEvalInstance::<E>::batch_diff_size(&comms_vec, &evals_vec, &num_vars_u, rand_sc, c);
let w_batch = PolyEvalWitness::<E>::batch_diff_size(w_vec, c);
let w_batch =
PolyEvalWitness::<E>::batch_diff_size(&w_vec.iter().by_ref().collect::<Vec<_>>(), c);

let eval_arg = EE::prove(
ck,
@@ -819,7 +820,7 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
|comm_Az_Bz_Cz, evals_Az_Bz_Cz_at_tau| {
let u = PolyEvalInstance::<E>::batch(
comm_Az_Bz_Cz.as_slice(),
&tau_coords,
tau_coords.clone(),
evals_Az_Bz_Cz_at_tau.as_slice(),
&c,
);
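
The `tau_coords.clone()` above reflects the C-CALLER-CONTROL item from the commit message: `PolyEvalInstance::batch` now takes the evaluation point by value because the instance keeps it, so any copying is decided by the caller. A minimal, self-contained illustration of that convention (hypothetical `Instance` type, not from this crate):

```rust
/// Hypothetical stand-in for a type that stores its evaluation point.
struct Instance {
    x: Vec<u64>,
}

impl Instance {
    /// Takes `x` by value because it is stored as-is; no hidden clone inside.
    fn new(x: Vec<u64>) -> Self {
        Self { x }
    }
}

fn main() {
    let coords = vec![1, 2, 3];
    let a = Instance::new(coords.clone()); // caller clones only if it still needs `coords`
    let b = Instance::new(coords);         // otherwise ownership simply moves
    assert_eq!(a.x, b.x);
}
```
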
224 changes: 162 additions & 62 deletions src/spartan/mod.rs
@@ -24,7 +24,9 @@ use crate::{
use ff::Field;
use itertools::Itertools as _;
use polys::multilinear::SparsePolynomial;

use rayon::{iter::IntoParallelRefIterator, prelude::*};
use ref_cast::RefCast;

// Creates a vector of the first `n` powers of `s`.
fn powers<E: Engine>(s: &E::Scalar, n: usize) -> Vec<E::Scalar> {
@@ -35,7 +37,8 @@ fn powers<E: Engine>(s: &E::Scalar, n: usize) -> Vec<E::Scalar> {
}

/// A type that holds a witness to a polynomial evaluation instance
#[derive(Debug)]
#[repr(transparent)]
#[derive(Debug, RefCast)]
struct PolyEvalWitness<E: Engine> {
p: Vec<E::Scalar>, // polynomial
}
@@ -47,39 +50,43 @@ impl<E: Engine> PolyEvalWitness<E> {
///
/// We allow the input polynomials to have different sizes, and interpret smaller ones as
/// being padded with 0 to the maximum size of all polynomials.
fn batch_diff_size(W: Vec<Self>, s: E::Scalar) -> Self {
fn batch_diff_size(W: &[&Self], s: E::Scalar) -> Self {
let powers = powers::<E>(&s, W.len());

let size_max = W.iter().map(|w| w.p.len()).max().unwrap();
let p_vec = W.par_iter().map(|w| &w.p);
// Scale the input polynomials by the power of s
let p = W
.into_par_iter()
.zip_eq(powers.par_iter())
.map(|(mut w, s)| {
if *s != E::Scalar::ONE {
w.p.par_iter_mut().for_each(|e| *e *= s);
}
w.p
})
.reduce(
|| vec![E::Scalar::ZERO; size_max],
|left, right| {
// Sum into the largest polynomial
let (mut big, small) = if left.len() > right.len() {
(left, right)
let p = zip_with!((p_vec, powers.par_iter()), |v, weight| {
// compute the weighted sum for each vector
v.iter()
.map(|&x| {
if *weight != E::Scalar::ONE {
x * *weight
} else {
(right, left)
};
x
}
})
.collect::<Vec<_>>()
})
.reduce(
|| vec![E::Scalar::ZERO; size_max],
|left, right| {
// Sum into the largest polynomial
let (mut big, small) = if left.len() > right.len() {
(left, right)
} else {
(right, left)
};

#[allow(clippy::disallowed_methods)]
big
.par_iter_mut()
.zip(small.par_iter())
.for_each(|(b, s)| *b += s);
#[allow(clippy::disallowed_methods)]
big
.par_iter_mut()
.zip(small.par_iter())
.for_each(|(b, s)| *b += s);

big
},
);
big
},
);

Self { p }
}
@@ -95,22 +102,8 @@ impl<E: Engine> PolyEvalWitness<E> {
.iter()
.skip(1)
.for_each(|p| assert_eq!(p.len(), p_vec[0].len()));

let powers_of_s = powers::<E>(s, p_vec.len());

let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * *weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
acc.into_iter().zip_eq(v).map(|(x, y)| x + y).collect()
},
);

Self { p }
let instances = p_vec.iter().map(|p| Self::ref_cast(p)).collect::<Vec<_>>();
Self::batch_diff_size(&instances, *s)
}
}

@@ -150,15 +143,14 @@ impl<E: Engine> PolyEvalInstance<E> {

// vᵢ = L₀(x_lo)⋅Pᵢ(x_hi)
lagrange_eval * eval
})
.collect::<Vec<_>>();
});

// C = ∑ᵢ γⁱ⋅Cᵢ
let comm_joint = zip_with!(iter, (c_vec, powers), |c, g_i| *c * *g_i)
.fold(Commitment::<E>::default(), |acc, item| acc + item);

// v = ∑ᵢ γⁱ⋅vᵢ
let eval_joint = zip_with!((evals_scaled.into_iter(), powers.iter()), |e, g_i| e * g_i).sum();
let eval_joint = zip_with!((evals_scaled, powers.iter()), |e, g_i| e * g_i).sum();

Self {
c: comm_joint,
@@ -167,22 +159,9 @@ impl<E: Engine> PolyEvalInstance<E> {
}
}

fn batch(c_vec: &[Commitment<E>], x: &[E::Scalar], e_vec: &[E::Scalar], s: &E::Scalar) -> Self {
let num_instances = c_vec.len();
assert_eq!(e_vec.len(), num_instances);

let powers_of_s = powers::<E>(s, num_instances);
// Weighted sum of evaluations
let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum();
// Weighted sum of commitments
let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p)
.reduce(Commitment::<E>::default, |acc, item| acc + item);

Self {
c,
x: x.to_vec(),
e,
}
fn batch(c_vec: &[Commitment<E>], x: Vec<E::Scalar>, e_vec: &[E::Scalar], s: &E::Scalar) -> Self {
let sizes = vec![x.len(); e_vec.len()];
Self::batch_diff_size(c_vec, e_vec, &sizes, x, *s)
}
}

@@ -225,3 +204,124 @@ fn compute_eval_table_sparse<E: Engine>(

(A_evals, B_evals, C_evals)
}

#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
use super::*;
use crate::provider::PallasEngine;
use crate::r1cs::util::{FWrap, GWrap};
use pasta_curves::pallas::Point as PallasPoint;
use pasta_curves::Fq as Scalar;
use proptest::collection::vec;
use proptest::prelude::*;

impl<E: Engine> PolyEvalWitness<E> {
fn alt_batch(p_vec: &[&Vec<E::Scalar>], s: &E::Scalar) -> Self {
p_vec
.iter()
.skip(1)
.for_each(|p| assert_eq!(p.len(), p_vec[0].len()));

let powers_of_s = powers::<E>(s, p_vec.len());

let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * *weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
acc.into_iter().zip_eq(v).map(|(x, y)| x + y).collect()
},
);

Self { p }
}
}

impl<E: Engine> PolyEvalInstance<E> {
fn alt_batch(
c_vec: &[Commitment<E>],
x: Vec<E::Scalar>,
e_vec: &[E::Scalar],
s: &E::Scalar,
) -> Self {
let num_instances = c_vec.len();
assert_eq!(e_vec.len(), num_instances);

let powers_of_s = powers::<E>(s, num_instances);
// Weighted sum of evaluations
let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum();
// Weighted sum of commitments
let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p)
.reduce(Commitment::<E>::default, |acc, item| acc + item);

Self { c, x, e }
}
}

proptest! {
#[test]
fn test_pe_witness_batch_diff_size_batch(
s in any::<FWrap<Scalar>>(),
vecs in (50usize..100).prop_flat_map(|size| vec(
vec(any::<FWrap<Scalar>>().prop_map(|f| f.0), size..=size), // equal-sized vecs
1..5))
)
{
// when the vectors are the same size, batch_diff_size and batch agree
let res = PolyEvalWitness::<PallasEngine>::alt_batch(&vecs.iter().by_ref().collect::<Vec<_>>(), &s.0);
let witnesses = vecs.iter().map(PolyEvalWitness::ref_cast).collect::<Vec<_>>();
let res2 = PolyEvalWitness::<PallasEngine>::batch_diff_size(&witnesses, s.0);

prop_assert_eq!(res.p, res2.p);
}

#[test]
fn test_pe_witness_batch_diff_size_pad_batch(
s in any::<FWrap<Scalar>>(),
vecs in (50usize..100).prop_flat_map(|size| vec(
vec(any::<FWrap<Scalar>>().prop_map(|f| f.0), size-10..=size), // vecs of varying sizes (up to 10 below the max)
1..10))
)
{
let size = vecs.iter().map(|v| v.len()).max().unwrap_or(0);
// when the vectors differ in size, batch_diff_size agrees with batch applied to the zero-padded inputs
let padded_vecs = vecs.iter().cloned().map(|mut v| {v.resize(size, Scalar::ZERO); v}).collect::<Vec<_>>();
let res = PolyEvalWitness::<PallasEngine>::alt_batch(&padded_vecs.iter().by_ref().collect::<Vec<_>>(), &s.0);
let witnesses = vecs.iter().map(PolyEvalWitness::ref_cast).collect::<Vec<_>>();
let res2 = PolyEvalWitness::<PallasEngine>::batch_diff_size(&witnesses, s.0);

prop_assert_eq!(res.p, res2.p);
}

#[test]
fn test_pe_instance_batch_diff_size_batch(
s in any::<FWrap<Scalar>>(),
vecs_tuple in (50usize..100).prop_flat_map(|size|
(vec(any::<GWrap<PallasPoint>>().prop_map(|f| f.0), size..=size),
vec(any::<FWrap<Scalar>>().prop_map(|f| f.0), size..=size),
vec(any::<FWrap<Scalar>>().prop_map(|f| f.0), size..=size)
), // equal-sized vecs
)
)
{
let (c_vec, e_vec, x_vec) = vecs_tuple;
let c_vecs = c_vec.into_iter().map(|c| Commitment::<PallasEngine>{ comm: c }).collect::<Vec<_>>();
// when poly evals are all for the max # of variables, batch_diff_size and batch agree
let res = PolyEvalInstance::<PallasEngine>::alt_batch(
&c_vecs,
x_vec.clone(),
&e_vec,
&s.0);

let sizes = vec![x_vec.len(); x_vec.len()];
let res2 = PolyEvalInstance::<PallasEngine>::batch_diff_size(&c_vecs, &e_vec, &sizes, x_vec.clone(), s.0);

prop_assert_eq!(res.c, res2.c);
prop_assert_eq!(res.x, res2.x);
prop_assert_eq!(res.e, res2.e);
}
}
}
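
As a sketch of why the property tests above should pass: when every input already has the maximum number of variables there is nothing to pad, so the Lagrange correction used by the `*_diff_size` paths is trivial and both implementations reduce to the same weighted sums (assuming `L₀` is the multilinear Lagrange basis at 0, as the in-code comment `vᵢ = L₀(x_lo)⋅Pᵢ(x_hi)` suggests):

```latex
% Equal-size case: x_lo is empty, so the per-instance rescaling is an empty product
L_0(x_{\mathrm{lo}}) = \prod_{j \in \emptyset} \bigl(1 - (x_{\mathrm{lo}})_j\bigr) = 1
% and both batching paths compute the same plain weighted sums:
p = \sum_i s^i p_i, \qquad C = \sum_i \gamma^i C_i, \qquad v = \sum_i \gamma^i v_i
```
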

1 comment on commit 26fd303

@github-actions (Contributor)

Benchmarks

Overview

This benchmark report shows the Arecibo GPU benchmarks.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/arecibo/actions/runs/7875839506

Benchmark Results

RecursiveSNARK-NIVC-2

|                        | ref=c1af06f          | ref=26fd303                 |
|:-----------------------|:---------------------|:----------------------------|
| Prove-NumCons-6540     | 52.96 ms (✅ 1.00x)   | 52.95 ms (✅ 1.00x faster)   |
| Verify-NumCons-6540    | 32.77 ms (✅ 1.00x)   | 33.19 ms (✅ 1.01x slower)   |
| Prove-NumCons-1028888  | 343.18 ms (✅ 1.00x)  | 344.69 ms (✅ 1.00x slower)  |
| Verify-NumCons-1028888 | 253.29 ms (✅ 1.00x)  | 256.84 ms (✅ 1.01x slower)  |

CompressedSNARK-NIVC-Commitments-2

|                        | ref=c1af06f          | ref=26fd303                 |
|:-----------------------|:---------------------|:----------------------------|
| Prove-NumCons-6540     | 13.89 s (✅ 1.00x)    | 13.84 s (✅ 1.00x faster)    |
| Verify-NumCons-6540    | 77.77 ms (✅ 1.00x)   | 81.89 ms (✅ 1.05x slower)   |
| Prove-NumCons-1028888  | 110.80 s (✅ 1.00x)   | 110.50 s (✅ 1.00x faster)   |
| Verify-NumCons-1028888 | 772.30 ms (✅ 1.00x)  | 777.70 ms (✅ 1.01x slower)  |

Made with criterion-table
