Skip to content

Commit

Permalink
Refactor pre-processing SNARK (#220)
Browse files Browse the repository at this point in the history
* refactor: Refactor and add test module for `padded` function

- Simplified the `padded` function in `ppsnark.rs`
- test case for the `padded` function

* refactor: clean up some unnnecessary wraps

* refactor: enforce stricter linting rules

- Updated clippy configuration by removing and adding specific lints for stricter code checks.

* refactor:  refactor sparse matrix iteration

- Refactor `R1CSShapeSparkRepr` in `ppsnark.rs`
- Restructure the main loop for simultaneous population of `row`, `col`, and `val_X` vectors.

* clippy::use_self

* feat: Move all Sumcheck engine related items to the sumcheck module

- Introduced a new module `src/spartan/sumcheck/engine.rs`,
- Changed the visibility status of the `sumcheck` module, making it visible within the `spartan` module.
- Refactored the file structure as `src/spartan/sumcheck.rs` was renamed to `src/spartan/sumcheck/mod.rs` and a new submodule `engine` was added within the `spartan` crate.
- Refactored the import structures in `ppsnark` and `sumcheck::engine`, moving several items to `sumcheck::engine`.

* fixup! clippy::use_self
  • Loading branch information
huitseeker authored Jan 1, 2024
1 parent 7754a70 commit c11e890
Show file tree
Hide file tree
Showing 36 changed files with 943 additions and 914 deletions.
16 changes: 15 additions & 1 deletion .cargo/config
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,38 @@
xclippy = [
"clippy", "--all-targets", "--",
"-Wclippy::all",
"-Wclippy::match_same_arms",
"-Wclippy::cast_lossless",
"-Wclippy::checked_conversions",
"-Wclippy::dbg_macro",
"-Wclippy::disallowed_methods",
"-Wclippy::derive_partial_eq_without_eq",
"-Wclippy::filter_map_next",
"-Wclippy::flat_map_option",
"-Wclippy::from_iter_instead_of_collect",
"-Wclippy::inefficient_to_string",
"-Wclippy::large_stack_arrays",
"-Wclippy::large_types_passed_by_value",
"-Wclippy::macro_use_imports",
"-Wclippy::manual_assert",
"-Wclippy::manual_ok_or",
"-Wclippy::map_flatten",
"-Wclippy::map_unwrap_or",
"-Wclippy::match_same_arms",
"-Wclippy::match_wild_err_arm",
"-Wclippy::needless_borrow",
"-Wclippy::needless_continue",
"-Wclippy::needless_for_each",
"-Wclippy::needless_pass_by_value",
"-Wclippy::option_option",
"-Wclippy::same_functions_in_if_condition",
"-Wclippy::single_match_else",
"-Wclippy::trait_duplication_in_bounds",
"-Wclippy::unnecessary_mut_passed",
"-Wclippy::unnecessary_wraps",
"-Wclippy::use_self",
"-Wnonstandard_style",
"-Wrust_2018_idioms",
"-Wtrivial_numeric_casts",
"-Wunused_lifetimes",
"-Wunused_qualifications",
]
4 changes: 2 additions & 2 deletions src/bellpepper/shape_cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ where
impl<E: Engine> ShapeCS<E> {
/// Create a new, default `ShapeCS`,
pub fn new() -> Self {
ShapeCS::default()
Self::default()
}

/// Returns the number of constraints defined for this `ShapeCS`.
Expand All @@ -43,7 +43,7 @@ impl<E: Engine> ShapeCS<E> {

impl<E: Engine> Default for ShapeCS<E> {
fn default() -> Self {
ShapeCS {
Self {
constraints: vec![],
inputs: 1,
aux: 0,
Expand Down
8 changes: 4 additions & 4 deletions src/bellpepper/test_shape_cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ enum NamedObject {

impl Eq for OrderedVariable {}
impl PartialEq for OrderedVariable {
fn eq(&self, other: &OrderedVariable) -> bool {
fn eq(&self, other: &Self) -> bool {
match (self.0.get_unchecked(), other.0.get_unchecked()) {
(Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => a == b,
_ => false,
Expand Down Expand Up @@ -95,7 +95,7 @@ where
#[allow(unused)]
/// Create a new, default `TestShapeCS`,
pub fn new() -> Self {
TestShapeCS::default()
Self::default()
}

/// Returns the number of constraints defined for this `TestShapeCS`.
Expand Down Expand Up @@ -216,8 +216,8 @@ where
impl<E: Engine> Default for TestShapeCS<E> {
fn default() -> Self {
let mut map = HashMap::new();
map.insert("ONE".into(), NamedObject::Var(TestShapeCS::<E>::one()));
TestShapeCS {
map.insert("ONE".into(), NamedObject::Var(Self::one()));
Self {
named_objects: map,
current_namespace: vec![],
constraints: vec![],
Expand Down
7 changes: 2 additions & 5 deletions src/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ impl<'a, F: PrimeField, T: Digestible> DigestComputer<'a, F, T> {
/// Compute the digest of a `Digestible` instance.
pub fn digest(&self) -> Result<F, io::Error> {
let mut hasher = Self::hasher();
self
.inner
.write_bytes(&mut hasher)
.expect("Serialization error");
self.inner.write_bytes(&mut hasher)?;
let bytes: [u8; 32] = hasher.finalize().into();
Ok(Self::map_to_field(&bytes))
}
Expand All @@ -99,7 +96,7 @@ mod tests {

impl<E: Engine> S<E> {
fn new(i: usize) -> Self {
S {
Self {
i,
digest: OnceCell::new(),
}
Expand Down
22 changes: 11 additions & 11 deletions src/gadgets/ecc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ where
|lc| lc,
);

Ok(AllocatedPoint { x, y, is_infinity })
Ok(Self { x, y, is_infinity })
}

/// checks if `self` is on the curve or if it is infinity
Expand Down Expand Up @@ -108,7 +108,7 @@ where
let zero = alloc_zero(cs.namespace(|| "zero"));
let one = alloc_one(cs.namespace(|| "one"));

Ok(AllocatedPoint {
Ok(Self {
x: zero.clone(),
y: zero,
is_infinity: one,
Expand Down Expand Up @@ -148,7 +148,7 @@ where
pub fn add<CS: ConstraintSystem<E::Base>>(
&self,
mut cs: CS,
other: &AllocatedPoint<E>,
other: &Self,
) -> Result<Self, SynthesisError> {
// Compute boolean equal indicating if self = other

Expand Down Expand Up @@ -178,13 +178,13 @@ where
// return add(self, other)
// }
// }
let result_for_equal_x = AllocatedPoint::select_point_or_infinity(
let result_for_equal_x = Self::select_point_or_infinity(
cs.namespace(|| "equal_y ? result_from_double : infinity"),
&result_from_double,
&Boolean::from(equal_y),
)?;

AllocatedPoint::conditionally_select(
Self::conditionally_select(
cs.namespace(|| "equal ? result_from_double : result_from_add"),
&result_for_equal_x,
&result_from_add,
Expand All @@ -197,7 +197,7 @@ where
pub fn add_internal<CS: ConstraintSystem<E::Base>>(
&self,
mut cs: CS,
other: &AllocatedPoint<E>,
other: &Self,
equal_x: &AllocatedBit,
) -> Result<Self, SynthesisError> {
//************************************************************************/
Expand Down Expand Up @@ -501,7 +501,7 @@ where
acc.add(cs.namespace(|| "res minus self"), &neg)
}?;

AllocatedPoint::conditionally_select(
Self::conditionally_select(
cs.namespace(|| "remove slack if necessary"),
&acc,
&acc_minus_initial,
Expand All @@ -527,7 +527,7 @@ where
)?;

// we now perform the remaining scalar mul using complete addition law
let mut acc = AllocatedPoint {
let mut acc = Self {
x,
y,
is_infinity: res.is_infinity,
Expand All @@ -536,7 +536,7 @@ where

for (i, bit) in complete_bits.iter().enumerate() {
let temp = acc.add(cs.namespace(|| format!("add_complete {i}")), &p_complete)?;
acc = AllocatedPoint::conditionally_select(
acc = Self::conditionally_select(
cs.namespace(|| format!("acc_complete_iteration_{i}")),
&temp,
&acc,
Expand Down Expand Up @@ -826,7 +826,7 @@ mod tests {
}

/// Add any two points
pub fn add(&self, other: &Point<E>) -> Self {
pub fn add(&self, other: &Self) -> Self {
if self.x == other.x {
// If self == other then call double
if self.y == other.y {
Expand All @@ -845,7 +845,7 @@ mod tests {
}

/// Add two different points
pub fn add_internal(&self, other: &Point<E>) -> Self {
pub fn add_internal(&self, other: &Self) -> Self {
if self.is_infinity {
return other.clone();
}
Expand Down
36 changes: 18 additions & 18 deletions src/gadgets/nonnative/bignat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl BigNatParams {
pub fn new(limb_width: usize, n_limbs: usize) -> Self {
let mut max_word = BigInt::from(1) << limb_width as u32;
max_word -= 1;
BigNatParams {
Self {
max_word,
n_limbs,
limb_width,
Expand Down Expand Up @@ -100,8 +100,8 @@ impl<Scalar: PrimeField> PartialEq for BigNat<Scalar> {
impl<Scalar: PrimeField> Eq for BigNat<Scalar> {}

impl<Scalar: PrimeField> From<BigNat<Scalar>> for Polynomial<Scalar> {
fn from(other: BigNat<Scalar>) -> Polynomial<Scalar> {
Polynomial {
fn from(other: BigNat<Scalar>) -> Self {
Self {
coefficients: other.limbs,
values: other.limb_values,
}
Expand Down Expand Up @@ -450,7 +450,7 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
self_grouped.equal_when_carried(cs.namespace(|| "grouped"), &other_grouped)
}

pub fn add(&self, other: &Self) -> Result<BigNat<Scalar>, SynthesisError> {
pub fn add(&self, other: &Self) -> Result<Self, SynthesisError> {
self.enforce_limb_width_agreement(other, "add")?;
let n_limbs = max(self.params.n_limbs, other.params.n_limbs);
let max_word = &self.params.max_word + &other.params.max_word;
Expand Down Expand Up @@ -500,12 +500,12 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
mut cs: CS,
other: &Self,
modulus: &Self,
) -> Result<(BigNat<Scalar>, BigNat<Scalar>), SynthesisError> {
) -> Result<(Self, Self), SynthesisError> {
self.enforce_limb_width_agreement(other, "mult_mod")?;
let limb_width = self.params.limb_width;
let quotient_bits = (self.n_bits() + other.n_bits()).saturating_sub(modulus.params.min_bits);
let quotient_limbs = quotient_bits.saturating_sub(1) / limb_width + 1;
let quotient = BigNat::alloc_from_nat(
let quotient = Self::alloc_from_nat(
cs.namespace(|| "quotient"),
|| {
Ok({
Expand All @@ -519,7 +519,7 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
quotient_limbs,
)?;
quotient.assert_well_formed(cs.namespace(|| "quotient rangecheck"))?;
let remainder = BigNat::alloc_from_nat(
let remainder = Self::alloc_from_nat(
cs.namespace(|| "remainder"),
|| {
Ok({
Expand Down Expand Up @@ -559,8 +559,8 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
x
};

let left_int = BigNat::from_poly(left, limb_width, left_max_word);
let right_int = BigNat::from_poly(right, limb_width, right_max_word);
let left_int = Self::from_poly(left, limb_width, left_max_word);
let right_int = Self::from_poly(right, limb_width, right_max_word);
left_int.equal_when_carried_regroup(cs.namespace(|| "carry"), &right_int)?;
Ok((quotient, remainder))
}
Expand All @@ -570,19 +570,19 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
&self,
mut cs: CS,
modulus: &Self,
) -> Result<BigNat<Scalar>, SynthesisError> {
) -> Result<Self, SynthesisError> {
self.enforce_limb_width_agreement(modulus, "red_mod")?;
let limb_width = self.params.limb_width;
let quotient_bits = self.n_bits().saturating_sub(modulus.params.min_bits);
let quotient_limbs = quotient_bits.saturating_sub(1) / limb_width + 1;
let quotient = BigNat::alloc_from_nat(
let quotient = Self::alloc_from_nat(
cs.namespace(|| "quotient"),
|| Ok(self.value.grab()? / modulus.value.grab()?),
self.params.limb_width,
quotient_limbs,
)?;
quotient.assert_well_formed(cs.namespace(|| "quotient rangecheck"))?;
let remainder = BigNat::alloc_from_nat(
let remainder = Self::alloc_from_nat(
cs.namespace(|| "remainder"),
|| Ok(self.value.grab()? % modulus.value.grab()?),
self.params.limb_width,
Expand All @@ -605,13 +605,13 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
x
};

let right_int = BigNat::from_poly(right, limb_width, right_max_word);
let right_int = Self::from_poly(right, limb_width, right_max_word);
self.equal_when_carried_regroup(cs.namespace(|| "carry"), &right_int)?;
Ok(remainder)
}

/// Combines limbs into groups.
pub fn group_limbs(&self, limbs_per_group: usize) -> BigNat<Scalar> {
pub fn group_limbs(&self, limbs_per_group: usize) -> Self {
let n_groups = (self.limbs.len() - 1) / limbs_per_group + 1;
let limb_values = self.limb_values.as_ref().map(|vs| {
let mut values: Vec<Scalar> = vec![Scalar::ZERO; n_groups];
Expand Down Expand Up @@ -653,7 +653,7 @@ impl<Scalar: PrimeField> BigNat<Scalar> {
acc.set_bit((i * self.params.limb_width) as u64, true);
acc
}) * &self.params.max_word;
BigNat {
Self {
params: BigNatParams {
min_bits: self.params.min_bits,
limb_width: self.params.limb_width * limbs_per_group,
Expand Down Expand Up @@ -682,7 +682,7 @@ impl<Scalar: PrimeField> Polynomial<Scalar> {
&self,
mut cs: CS,
other: &Self,
) -> Result<Polynomial<Scalar>, SynthesisError> {
) -> Result<Self, SynthesisError> {
let n_product_coeffs = self.coefficients.len() + other.coefficients.len() - 1;
let values = self.values.as_ref().and_then(|self_vs| {
other.values.as_ref().map(|other_vs| {
Expand All @@ -704,7 +704,7 @@ impl<Scalar: PrimeField> Polynomial<Scalar> {
Ok(LinearCombination::zero() + cs.alloc(|| format!("prod {i}"), || Ok(values.grab()?[i]))?)
})
.collect::<Result<Vec<LinearCombination<Scalar>>, SynthesisError>>()?;
let product = Polynomial {
let product = Self {
coefficients,
values,
};
Expand Down Expand Up @@ -773,7 +773,7 @@ impl<Scalar: PrimeField> Polynomial<Scalar> {
lc
})
.collect();
Polynomial {
Self {
coefficients,
values,
}
Expand Down
2 changes: 1 addition & 1 deletion src/gadgets/nonnative/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl<Scalar: PrimeField> Num<Scalar> {
},
)?;

Ok(Num {
Ok(Self {
value: new_value,
num: LinearCombination::zero() + var,
})
Expand Down
Loading

0 comments on commit c11e890

Please sign in to comment.