Skip to content

Commit

Permalink
refactor: Replace zip with zip_eq for iterator safety (#149)
Browse files Browse the repository at this point in the history
* refactor: Replace zip with zip_eq for iterator safety

- Replaced `.zip` method with `.zip_eq` from Itertools across multiple files to ensure equal length iteration between zipped arrays,
- Updated the clippy configuration to disallow zip methods which do not check the size of their iterator arguments.
- Removed redundant assertions and parallel processing that became unnecessary after the inclusion of `zip_eq`.

- Updated import statements in numerous files to include `Itertools` from itertools for the efficient usage of iterator tools.
- Added the `itertools` dependency version `0.12.0` to the `Cargo.toml` file.

* refactor: Allow zip for unequal lengths in msm

- Modified the `for` loop in the source provider `msm.rs`, replacing `zip_eq` with `zip` to populate variables.
- Addressed Clippy warning for using disallowed methods by adding explicit permissions.
  • Loading branch information
huitseeker authored Nov 29, 2023
1 parent 8057e85 commit 61ff80a
Show file tree
Hide file tree
Showing 20 changed files with 66 additions and 47 deletions.
7 changes: 6 additions & 1 deletion .clippy.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
type-complexity-threshold = 9999
too-many-arguments-threshold = 20
too-many-arguments-threshold = 20
disallowed-methods = [
# we are strict about size checks in iterators
{ path = "core::iter::traits::iterator::Iterator::zip", reason = "use itertools::zip_eq instead" },
{ path = "rayon::iter::IndexedParallelIterator::zip", reason = "use rayon::iter::IndexedParallelIterator::zip_eq instead" },
]
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ tracing-texray = "0.2.0"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
cfg-if = "1.0.0"
once_cell = "1.18.0"
itertools = "0.12.0"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
pasta-msm = { git="https://github.com/lurk-lab/pasta-msm", branch="dev", version = "0.1.4" }
Expand Down
5 changes: 3 additions & 2 deletions src/gadgets/nonnative/bignat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::{
};
use bellpepper_core::{ConstraintSystem, LinearCombination, SynthesisError};
use ff::PrimeField;
use itertools::Itertools as _;
use num_bigint::BigInt;
use num_traits::cast::ToPrimitive;
use std::borrow::Borrow;
Expand Down Expand Up @@ -267,7 +268,7 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
// swap the option and iterator
let limb_values_split =
(0..self.limbs.len()).map(|i| self.limb_values.as_ref().map(|vs| vs[i]));
for (i, (limb, limb_value)) in self.limbs.iter().zip(limb_values_split).enumerate() {
for (i, (limb, limb_value)) in self.limbs.iter().zip_eq(limb_values_split).enumerate() {
Num::new(limb_value, limb.clone())
.fits_in_bits(cs.namespace(|| format!("{i}")), self.params.limb_width)?;
}
Expand All @@ -284,7 +285,7 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
let bitvectors: Vec<Bitvector<Scalar>> = self
.limbs
.iter()
.zip(limb_values_split)
.zip_eq(limb_values_split)
.enumerate()
.map(|(i, (limb, limb_value))| {
Num::new(limb_value, limb.clone()).decompose(
Expand Down
3 changes: 2 additions & 1 deletion src/gadgets/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{
use bellpepper::gadgets::{boolean::Boolean, num::AllocatedNum, Assignment};
use bellpepper_core::{ConstraintSystem, SynthesisError};
use ff::Field;
use itertools::Itertools as _;

/// An Allocated R1CS Instance
#[derive(Clone)]
Expand Down Expand Up @@ -390,7 +391,7 @@ pub fn conditionally_select_vec_allocated_relaxed_r1cs_instance<
) -> Result<Vec<AllocatedRelaxedR1CSInstance<E>>, SynthesisError> {
a.iter()
.enumerate()
.zip(b.iter())
.zip_eq(b.iter())
.map(|((i, a), b)| {
a.conditionally_select(
cs.namespace(|| format!("cond ? a[{}]: b[{}]", i, i)),
Expand Down
3 changes: 2 additions & 1 deletion src/gadgets/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use bellpepper_core::{
ConstraintSystem, LinearCombination, SynthesisError,
};
use ff::{Field, PrimeField, PrimeFieldBits};
use itertools::Itertools as _;
use num_bigint::BigInt;

/// Gets as input the little indian representation of a number and spits out the number
Expand Down Expand Up @@ -212,7 +213,7 @@ pub fn conditionally_select_vec<F: PrimeField, CS: ConstraintSystem<F>>(
condition: &Boolean,
) -> Result<Vec<AllocatedNum<F>>, SynthesisError> {
a.iter()
.zip(b.iter())
.zip_eq(b.iter())
.enumerate()
.map(|(i, (a, b))| {
conditionally_select(cs.namespace(|| format!("select_{i}")), a, b, condition)
Expand Down
4 changes: 2 additions & 2 deletions src/provider/ipa_pc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,13 @@ where
// fold the left half and the right half
let a_vec_folded = a_vec[0..n / 2]
.par_iter()
.zip(a_vec[n / 2..n].par_iter())
.zip_eq(a_vec[n / 2..n].par_iter())
.map(|(a_L, a_R)| *a_L * r + r_inverse * *a_R)
.collect::<Vec<E::Scalar>>();

let b_vec_folded = b_vec[0..n / 2]
.par_iter()
.zip(b_vec[n / 2..n].par_iter())
.zip_eq(b_vec[n / 2..n].par_iter())
.map(|(b_L, b_R)| *b_L * r_inverse + r * *b_R)
.collect::<Vec<E::Scalar>>();

Expand Down
3 changes: 2 additions & 1 deletion src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ mod tests {
use digest::{ExtendableOutput, Update};
use group::{ff::Field, Curve, Group};
use halo2curves::{CurveAffine, CurveExt};
use itertools::Itertools as _;
use pasta_curves::{pallas, vesta};
use rand_core::OsRng;
use sha3::Shake256;
Expand Down Expand Up @@ -164,7 +165,7 @@ mod tests {
.collect::<Vec<_>>();
let naive = coeffs
.iter()
.zip(bases.iter())
.zip_eq(bases.iter())
.fold(A::CurveExt::identity(), |acc, (coeff, base)| {
acc + *base * coeff
});
Expand Down
9 changes: 6 additions & 3 deletions src/provider/msm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! This module provides a multi-scalar multiplication routine
/// Adapted from zcash/halo2
use ff::PrimeField;
use itertools::Itertools as _;
use pasta_curves::{self, arithmetic::CurveAffine, group::Group as AnotherGroup};
use rayon::{current_num_threads, prelude::*};

Expand All @@ -22,6 +23,7 @@ fn cpu_msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve
}

let mut v = [0; 8];
#[allow(clippy::disallowed_methods)]
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
*v = *o;
}
Expand Down Expand Up @@ -67,7 +69,7 @@ fn cpu_msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve

let mut buckets = vec![Bucket::None; (1 << c) - 1];

for (coeff, base) in coeffs.iter().zip(bases.iter()) {
for (coeff, base) in coeffs.iter().zip_eq(bases.iter()) {
let coeff = get_at::<C::Scalar>(segment, c, &coeff.to_repr());
if coeff != 0 {
buckets[coeff - 1].add_assign(base);
Expand Down Expand Up @@ -101,7 +103,7 @@ pub(crate) fn cpu_best_msm<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) ->
let chunk = coeffs.len() / num_threads;
coeffs
.par_chunks(chunk)
.zip(bases.par_chunks(chunk))
.zip_eq(bases.par_chunks(chunk))
.map(|(coeffs, bases)| cpu_msm_serial(coeffs, bases))
.reduce(C::Curve::identity, |sum, evl| sum + evl)
} else {
Expand All @@ -119,6 +121,7 @@ mod tests {
};
use group::{ff::Field, Group};
use halo2curves::CurveAffine;
use itertools::Itertools as _;
use pasta_curves::{pallas, vesta};
use rand_core::OsRng;

Expand All @@ -130,7 +133,7 @@ mod tests {
.collect::<Vec<_>>();
let naive = coeffs
.iter()
.zip(bases.iter())
.zip_eq(bases.iter())
.fold(A::CurveExt::identity(), |acc, (coeff, base)| {
acc + *base * coeff
});
Expand Down
12 changes: 6 additions & 6 deletions src/r1cs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,9 @@ impl<E: Engine> R1CSShape<E> {
let T = tracing::trace_span!("T").in_scope(|| {
AZ_1_circ_BZ_2
.par_iter()
.zip(&AZ_2_circ_BZ_1)
.zip(&u_1_cdot_CZ_2)
.zip(&u_2_cdot_CZ_1)
.zip_eq(&AZ_2_circ_BZ_1)
.zip_eq(&u_1_cdot_CZ_2)
.zip_eq(&u_2_cdot_CZ_1)
.map(|(((a, b), c), d)| *a + *b - *c - *d)
.collect::<Vec<E::Scalar>>()
});
Expand Down Expand Up @@ -479,12 +479,12 @@ impl<E: Engine> RelaxedR1CSWitness<E> {

let W = W1
.par_iter()
.zip(W2)
.zip_eq(W2)
.map(|(a, b)| *a + *r * *b)
.collect::<Vec<E::Scalar>>();
let E = E1
.par_iter()
.zip(T)
.zip_eq(T)
.map(|(a, b)| *a + *r * *b)
.collect::<Vec<E::Scalar>>();
Ok(RelaxedR1CSWitness { W, E })
Expand Down Expand Up @@ -557,7 +557,7 @@ impl<E: Engine> RelaxedR1CSInstance<E> {
// weighted sum of X, comm_W, comm_E, and u
let X = X1
.par_iter()
.zip(X2)
.zip_eq(X2)
.map(|(a, b)| *a + *r * *b)
.collect::<Vec<E::Scalar>>();
let comm_W = *comm_W_1 + *comm_W_2 * *r;
Expand Down
3 changes: 2 additions & 1 deletion src/r1cs/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::cmp::Ordering;
use abomonation::Abomonation;
use abomonation_derive::Abomonation;
use ff::PrimeField;
use itertools::Itertools as _;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -90,7 +91,7 @@ impl<F: PrimeField> SparseMatrix<F> {
pub fn get_row_unchecked(&self, ptrs: &[usize; 2]) -> impl Iterator<Item = (&F, &usize)> {
self.data[ptrs[0]..ptrs[1]]
.iter()
.zip(&self.indices[ptrs[0]..ptrs[1]])
.zip_eq(&self.indices[ptrs[0]..ptrs[1]])
}

/// Multiply by a dense vector; uses rayon to parallelize.
Expand Down
9 changes: 5 additions & 4 deletions src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod sumcheck;

use crate::{traits::Engine, Commitment};
use ff::Field;
use itertools::Itertools as _;
use polys::multilinear::SparsePolynomial;
use rayon::{iter::IntoParallelRefIterator, prelude::*};

Expand Down Expand Up @@ -66,7 +67,7 @@ impl<E: Engine> PolyEvalWitness<E> {

let p = p_vec
.par_iter()
.zip(powers_of_s.par_iter())
.zip_eq(powers_of_s.par_iter())
.map(|(v, &weight)| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * weight).collect::<Vec<E::Scalar>>()
Expand All @@ -75,7 +76,7 @@ impl<E: Engine> PolyEvalWitness<E> {
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
acc.into_iter().zip(v).map(|(x, y)| x + y).collect()
acc.into_iter().zip_eq(v).map(|(x, y)| x + y).collect()
},
);

Expand Down Expand Up @@ -115,12 +116,12 @@ impl<E: Engine> PolyEvalInstance<E> {
let powers_of_s = powers::<E>(s, c_vec.len());
let e = e_vec
.par_iter()
.zip(powers_of_s.par_iter())
.zip_eq(powers_of_s.par_iter())
.map(|(e, p)| *e * p)
.sum();
let c = c_vec
.par_iter()
.zip(powers_of_s.par_iter())
.zip_eq(powers_of_s.par_iter())
.map(|(c, p)| *c * *p)
.reduce(Commitment::<E>::default, |acc, item| acc + item);

Expand Down
2 changes: 1 addition & 1 deletion src/spartan/polys/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl<Scalar: PrimeField> EqPolynomial<Scalar> {

evals_left
.par_iter_mut()
.zip(evals_right.par_iter_mut())
.zip_eq(evals_right.par_iter_mut())
.for_each(|(x, y)| {
*y = *x * r;
*x -= &*y;
Expand Down
7 changes: 4 additions & 3 deletions src/spartan/polys/multilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use std::ops::{Add, Index};

use ff::PrimeField;
use itertools::Itertools as _;
use rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
Expand Down Expand Up @@ -67,7 +68,7 @@ impl<Scalar: PrimeField> MultilinearPolynomial<Scalar> {

left
.par_iter_mut()
.zip(right.par_iter())
.zip_eq(right.par_iter())
.for_each(|(a, b)| {
*a += *r * (*b - *a);
});
Expand Down Expand Up @@ -97,7 +98,7 @@ impl<Scalar: PrimeField> MultilinearPolynomial<Scalar> {
EqPolynomial::new(r.to_vec())
.evals()
.into_par_iter()
.zip(Z.into_par_iter())
.zip_eq(Z.into_par_iter())
.map(|(a, b)| a * b)
.sum()
}
Expand Down Expand Up @@ -170,7 +171,7 @@ impl<Scalar: PrimeField> Add for MultilinearPolynomial<Scalar> {
let sum: Vec<Scalar> = self
.Z
.iter()
.zip(other.Z.iter())
.zip_eq(other.Z.iter())
.map(|(a, b)| *a + *b)
.collect();

Expand Down
9 changes: 5 additions & 4 deletions src/spartan/ppsnark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use abomonation::Abomonation;
use abomonation_derive::Abomonation;
use core::cmp::max;
use ff::{Field, PrimeField};
use itertools::Itertools as _;
use once_cell::sync::OnceCell;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -352,7 +353,7 @@ impl<E: Engine> MemorySumcheckInstance<E> {
Ok(
inv
.par_iter()
.zip(TS.par_iter())
.zip_eq(TS.par_iter())
.map(|(e1, e2)| *e1 * *e2)
.collect::<Vec<_>>(),
)
Expand Down Expand Up @@ -871,7 +872,7 @@ where
// compute the joint claim
let claim = claims
.iter()
.zip(coeffs.iter())
.zip_eq(coeffs.iter())
.map(|(c_1, c_2)| *c_1 * c_2)
.sum();

Expand Down Expand Up @@ -1115,8 +1116,8 @@ where
.S_repr
.val_A
.par_iter()
.zip(pk.S_repr.val_B.par_iter())
.zip(pk.S_repr.val_C.par_iter())
.zip_eq(pk.S_repr.val_B.par_iter())
.zip_eq(pk.S_repr.val_C.par_iter())
.map(|((v_a, v_b), v_c)| *v_a + c * *v_b + c * c * *v_c)
.collect::<Vec<E::Scalar>>();
let inner_sc_inst = InnerSumcheckInstance {
Expand Down
Loading

0 comments on commit 61ff80a

Please sign in to comment.