diff --git a/src/provider/mlkzg.rs b/src/provider/mlkzg.rs index 893926ab4..4e6b50ff3 100644 --- a/src/provider/mlkzg.rs +++ b/src/provider/mlkzg.rs @@ -7,6 +7,7 @@ use crate::{ non_hiding_kzg::{KZGProverKey, KZGVerifierKey, UniversalKZGParam}, pedersen::Commitment, traits::DlogGroup, + util::iterators::DoubleEndedIteratorExt as _, }, spartan::polys::univariate::UniPoly, traits::{ @@ -164,30 +165,14 @@ where .to_affine() }; - let kzg_open_batch = |f: &[Vec], + let kzg_open_batch = |f: Vec>, u: &[E::Fr], transcript: &mut ::TE| -> (Vec, Vec>) { - let scalar_vector_muladd = |a: &mut Vec, v: &Vec, s: E::Fr| { - assert!(a.len() >= v.len()); - #[allow(clippy::disallowed_methods)] - a.par_iter_mut() - .zip(v.par_iter()) - .for_each(|(c, v)| *c += s * v); - }; - - let kzg_compute_batch_polynomial = |f: &[Vec], q: E::Fr| -> Vec { - let k = f.len(); // Number of polynomials we're batching - - let q_powers = Self::batch_challenge_powers(q, k); - - // Compute B(x) = f[0] + q*f[1] + q^2 * f[2] + ... q^(k-1) * f[k-1] - let mut B = f[0].clone(); - for i in 1..k { - scalar_vector_muladd(&mut B, &f[i], q_powers[i]); // B += q_powers[i] * f[i] - } - - B + let kzg_compute_batch_polynomial = |f: Vec>, q: E::Fr| -> Vec { + // Compute B(x) = f_0(x) + q * f_1(x) + ... + q^(k-1) * f_{k-1}(x) + let B: UniPoly = f.into_iter().map(UniPoly::new).rlc(&q); + B.coeffs }; ///////// END kzg_open_batch closure helpers @@ -199,7 +184,7 @@ where let mut v = vec![vec!(E::Fr::ZERO; k); t]; v.par_iter_mut().enumerate().for_each(|(i, v_i)| { // for each point u - v_i.par_iter_mut().zip_eq(f).for_each(|(v_ij, f)| { + v_i.par_iter_mut().zip_eq(&f).for_each(|(v_ij, f)| { // for each poly f (except the last one - since it is constant) *v_ij = UniPoly::ref_cast(f).evaluate(&u[i]); }); @@ -259,7 +244,7 @@ where let u = vec![r, -r, r * r]; // Phase 3 -- create response - let (w, evals) = kzg_open_batch(&polys, &u, transcript); + let (w, evals) = kzg_open_batch(polys, &u, transcript); Ok(EvaluationArgument { comms, w, evals }) } diff --git a/src/provider/mod.rs b/src/provider/mod.rs index 7f12139c6..aabdd8678 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -15,7 +15,7 @@ pub(crate) mod traits; // a non-hiding variant of {kzg, zeromorph} mod kzg_commitment; mod non_hiding_kzg; -mod util; +pub(crate) mod util; // crate-private modules mod keccak; diff --git a/src/provider/util/mod.rs b/src/provider/util/mod.rs index 94a76af98..1b630c5b3 100644 --- a/src/provider/util/mod.rs +++ b/src/provider/util/mod.rs @@ -11,6 +11,35 @@ pub mod msm { } } +pub mod iterators { + use std::borrow::Borrow; + use std::iter::DoubleEndedIterator; + use std::ops::{AddAssign, MulAssign}; + + pub trait DoubleEndedIteratorExt: DoubleEndedIterator { + /// This function employs Horner's scheme and core traits to create a combination of an iterator input with the powers + /// of a provided coefficient. + fn rlc(&mut self, coefficient: &F) -> T + where + T: Clone + for<'a> MulAssign<&'a F> + for<'r> AddAssign<&'r T>, + Self::Item: Borrow, + { + let mut iter = self.rev(); + let Some(fst) = iter.next() else { + panic!("input iterator should not be empty") + }; + + iter.fold(fst.borrow().clone(), |mut acc, item| { + acc *= coefficient; + acc += item.borrow(); + acc + }) + } + } + + impl DoubleEndedIteratorExt for I {} +} + #[cfg(test)] pub mod test_utils { //! Contains utilities for testing and benchmarking. diff --git a/src/spartan/polys/univariate.rs b/src/spartan/polys/univariate.rs index da846439b..472d7f837 100644 --- a/src/spartan/polys/univariate.rs +++ b/src/spartan/polys/univariate.rs @@ -11,7 +11,10 @@ use rayon::prelude::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelI use ref_cast::RefCast; use serde::{Deserialize, Serialize}; -use crate::traits::{Group, TranscriptReprTrait}; +use crate::{ + provider::util::iterators::DoubleEndedIteratorExt as _, + traits::{Group, TranscriptReprTrait}, +}; // ax^2 + bx + c stored as vec![c, b, a] // ax^3 + bx^2 + cx + d stored as vec![d, c, b, a] @@ -131,13 +134,7 @@ impl UniPoly { } pub fn evaluate(&self, r: &Scalar) -> Scalar { - let mut eval = self.coeffs[0]; - let mut power = *r; - for coeff in self.coeffs.iter().skip(1) { - eval += power * coeff; - power *= r; - } - eval + self.coeffs.iter().rlc(r) } pub fn compress(&self) -> CompressedUniPoly {