Skip to content

Commit

Permalink
encapsulate sparse API
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Feb 28, 2024
1 parent 02d9fe5 commit ff4253b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 14 deletions.
6 changes: 3 additions & 3 deletions src/bellpepper/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ fn add_constraint<S: PrimeField>(
) {
let (A, B, C, nn) = X;
let n = **nn;
assert_eq!(n + 1, A.indptr.len(), "A: invalid shape");
assert_eq!(n + 1, B.indptr.len(), "B: invalid shape");
assert_eq!(n + 1, C.indptr.len(), "C: invalid shape");
assert_eq!(n, A.num_rows(), "A: invalid shape");
assert_eq!(n, B.num_rows(), "B: invalid shape");
assert_eq!(n, C.num_rows(), "C: invalid shape");

let add_constraint_component = |index: Index, coeff: &S, M: &mut SparseMatrix<S>| {
// we add constraints to the matrix only if the associated coefficient is non-zero
Expand Down
38 changes: 38 additions & 0 deletions src/r1cs/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use ff::PrimeField;
use itertools::Itertools as _;
use rand_core::{CryptoRng, RngCore};
use rayon::prelude::*;
use ref_cast::RefCast;
use serde::{Deserialize, Serialize};

/// CSR format sparse matrix, We follow the names used by scipy.
Expand All @@ -31,6 +32,11 @@ pub struct SparseMatrix<F: PrimeField> {
pub cols: usize,
}

/// Wrapper type for encode rows of [`SparseMatrix`]
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct RowData([usize; 2]);

/// [`SparseMatrix`]s are often large, and this helps with cloning bottlenecks
impl<F: PrimeField> Clone for SparseMatrix<F> {
fn clone(&self) -> Self {
Expand Down Expand Up @@ -111,6 +117,30 @@ impl<F: PrimeField> SparseMatrix<F> {
Self::new(&matrix, rows, cols)
}

/// Returns an iterator into the rows
pub fn iter_rows(&self) -> impl Iterator<Item = &RowData> {
self
.indptr
.windows(2)
.map(|ptrs| RowData::ref_cast(ptrs.try_into().unwrap()))
}

/// Returns a parallel iterator into the rows
pub fn par_iter_rows(&self) -> impl IndexedParallelIterator<Item = &RowData> {
self
.indptr
.par_windows(2)
.map(|ptrs| RowData::ref_cast(ptrs.try_into().unwrap()))
}

/// Retrieves the data for row slice [i..j] from `row`.
/// [`RowData`] **must** be created from unmodified `self` previously to guarentee safety.
pub fn get_row(&self, row: &RowData) -> impl Iterator<Item = (&F, &usize)> {
self.data[row.0[0]..row.0[1]]
.iter()
.zip_eq(&self.indices[row.0[0]..row.0[1]])
}

/// Retrieves the data for row slice [i..j] from `ptrs`.
/// We assume that `ptrs` is indexed from `indptrs` and do not check if the
/// returned slice is actually a valid row.
Expand Down Expand Up @@ -226,6 +256,14 @@ impl<F: PrimeField> SparseMatrix<F> {
nnz: *self.indptr.last().unwrap(),
}
}

pub fn num_rows(&self) -> usize {
self.indptr.len() - 1
}

pub fn num_cols(&self) -> usize {
self.cols
}
}

/// Iterator for sparse matrix
Expand Down
8 changes: 3 additions & 5 deletions src/spartan/batched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,11 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
r_y: &[E::Scalar]|
-> Vec<E::Scalar> {
let evaluate_with_table =
// TODO(@winston-h-zhang): review
|M: &SparseMatrix<E::Scalar>, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar {
M.indptr
.par_windows(2)
M.par_iter_rows()
.enumerate()
.map(|(row_idx, ptrs)| {
M.get_row_unchecked(ptrs.try_into().unwrap())
.map(|(row_idx, row)| {
M.get_row(row)
.map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val)
.sum::<E::Scalar>()
})
Expand Down
5 changes: 3 additions & 2 deletions src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ fn compute_eval_table_sparse<E: Engine>(
assert_eq!(rx.len(), S.num_cons);

let inner = |M: &SparseMatrix<E::Scalar>, M_evals: &mut Vec<E::Scalar>| {
for (row_idx, ptrs) in M.indptr.windows(2).enumerate() {
for (val, col_idx) in M.get_row_unchecked(ptrs.try_into().unwrap()) {
for (row_idx, row) in M.iter_rows().enumerate() {
for (val, col_idx) in M.get_row(row) {
// TODO(@winston-h-zhang): Parallelize? Will need more complicated locking
M_evals[*col_idx] += rx[row_idx] * val;
}
}
Expand Down
7 changes: 3 additions & 4 deletions src/spartan/snark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,10 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARKTrait<E> for Relax
-> Vec<E::Scalar> {
let evaluate_with_table =
|M: &SparseMatrix<E::Scalar>, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar {
M.indptr
.par_windows(2)
M.par_iter_rows()
.enumerate()
.map(|(row_idx, ptrs)| {
M.get_row_unchecked(ptrs.try_into().unwrap())
.map(|(row_idx, row)| {
M.get_row(row)
.map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val)
.sum::<E::Scalar>()
})
Expand Down

0 comments on commit ff4253b

Please sign in to comment.