Skip to content

Commit

Permalink
Merge pull request #2247 from o1-labs/volhovm/2231-ivc-array-on-stack…
Browse files Browse the repository at this point in the history
…-handling

IVC: Add array allocators/helpers: reduce stack usage
  • Loading branch information
dannywillems authored May 28, 2024
2 parents 2c692af + cde3a90 commit aafb54c
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 73 deletions.
87 changes: 31 additions & 56 deletions ivc/src/ivc/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,9 @@ where
Ff: PrimeField,
Env: MultiRowReadCap<F, IVCColumn> + LookupCap<F, IVCColumn, IVCLookupTable<Ff>>,
{
let mut comms_limbs: [[Vec<Vec<F>>; 3]; 3] =
std::array::from_fn(|_| std::array::from_fn(|_| vec![]));
let mut comms_limbs_s: [Vec<Vec<F>>; 3] = std::array::from_fn(|_| vec![]);
let mut comms_limbs_l: [Vec<Vec<F>>; 3] = std::array::from_fn(|_| vec![]);
let mut comms_limbs_xl: [Vec<Vec<F>>; 3] = std::array::from_fn(|_| vec![]);

for _block_row_i in 0..(3 * N_COL_TOTAL) {
let row_num = env.curr_row();
Expand All @@ -274,41 +275,19 @@ where
let (limbs_small, limbs_large, limbs_xlarge) =
write_inputs_row(env, target_comms, row_num_local);

comms_limbs[0][comtype].push(limbs_small);
comms_limbs[1][comtype].push(limbs_large);
comms_limbs[2][comtype].push(limbs_xlarge);
comms_limbs_s[comtype].push(limbs_small);
comms_limbs_l[comtype].push(limbs_large);
comms_limbs_xl[comtype].push(limbs_xlarge);

constrain_inputs(env);

env.next_row();
}

// Transforms nested Vec<Vec<_>> into fixed-size arrays. Returns
// Left-Right-Output for a given limb size.
fn repack_output<F: PrimeField, const TWO_LIMB_SIZE: usize, const N_COL_TOTAL: usize>(
input: [Vec<Vec<F>>; 3],
) -> Box<[[[F; TWO_LIMB_SIZE]; N_COL_TOTAL]; 3]> {
Box::new(
input
.into_iter()
.map(|vector: Vec<Vec<_>>| {
vector
.into_iter()
.map(|subvec: Vec<_>| subvec.try_into().unwrap())
.collect::<Vec<_>>()
.try_into()
.unwrap()
})
.collect::<Vec<_>>()
.try_into()
.unwrap(),
)
}

(
repack_output(comms_limbs[0].clone()),
repack_output(comms_limbs[1].clone()),
repack_output(comms_limbs[2].clone()),
o1_utils::array::vec_to_boxed_array3(comms_limbs_s.to_vec()),
o1_utils::array::vec_to_boxed_array3(comms_limbs_l.to_vec()),
o1_utils::array::vec_to_boxed_array3(comms_limbs_xl.to_vec()),
)
}

Expand Down Expand Up @@ -641,32 +620,28 @@ pub fn process_ecadds<F, Ff, Env, const N_COL_TOTAL: usize>(
let r_hat_large: Box<[[F; 2 * N_LIMBS_LARGE]; N_COL_TOTAL]> = Box::new(comms_large[1]);

// Compute error and t terms limbs.
let error_terms_large: [[F; 2 * N_LIMBS_LARGE]; 3] = error_terms
.iter()
.map(|(x, y)| {
limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(x)
.into_iter()
.chain(limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(y))
.collect::<Vec<_>>()
.try_into()
.unwrap()
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let t_terms_large: [[F; 2 * N_LIMBS_LARGE]; 2] = t_terms
.iter()
.map(|(x, y)| {
limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(x)
.into_iter()
.chain(limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(y))
.collect::<Vec<_>>()
.try_into()
.unwrap()
})
.collect::<Vec<_>>()
.try_into()
.unwrap();
let error_terms_large: Box<[[F; 2 * N_LIMBS_LARGE]; 3]> = o1_utils::array::vec_to_boxed_array2(
error_terms
.iter()
.map(|(x, y)| {
limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(x)
.into_iter()
.chain(limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(y))
.collect()
})
.collect(),
);
let t_terms_large: Box<[[F; 2 * N_LIMBS_LARGE]; 2]> = o1_utils::array::vec_to_boxed_array2(
t_terms
.iter()
.map(|(x, y)| {
limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(x)
.into_iter()
.chain(limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(y))
.collect()
})
.collect(),
);

// E_R' = r·T_0 + r^2·T_1 + r^3·E_R
// FIXME for now stubbed and just equal to E_L
Expand Down
32 changes: 15 additions & 17 deletions ivc/src/ivc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod tests {
witness::Witness,
BaseSponge, Ff1, Fp, OpeningProof, ScalarSponge, BN254,
};
use o1_utils::box_array;
use poly_commitment::pairing_proof::PairingSRS;
use rand::{CryptoRng, RngCore};

Expand All @@ -50,8 +51,6 @@ mod tests {
0,
LT,
>;
//type IVCWitnessBuilderEnv = IVCWitnessBuilderEnvRaw<IVCLookupTable<Ff1>>;
//type IVCWitnessBuilderEnvDummy = IVCWitnessBuilderEnvRaw<DummyLookupTable>;

impl PoseidonParams<Fp, IVC_POSEIDON_STATE_SIZE, IVC_POSEIDON_NB_FULL_ROUND>
for PoseidonBN254Parameters
Expand All @@ -78,32 +77,31 @@ mod tests {
) -> IVCWitnessBuilderEnvRaw<LT> {
let mut witness_env = IVCWitnessBuilderEnvRaw::<LT>::create();

// To support less rows than domain_size we need to have selectors.
//let row_num = rng.gen_range(0..domain_size);
let mut comms_left: Box<_> = box_array![(Ff1::zero(),Ff1::zero()); TEST_N_COL_TOTAL];
let mut comms_right: Box<_> = box_array![(Ff1::zero(),Ff1::zero()); TEST_N_COL_TOTAL];
let mut comms_output: Box<_> = box_array![(Ff1::zero(),Ff1::zero()); TEST_N_COL_TOTAL];

let comms_left: Box<[_; TEST_N_COL_TOTAL]> = Box::new(core::array::from_fn(|_i| {
(
for i in 0..TEST_N_COL_TOTAL {
comms_left[i] = (
<Ff1 as UniformRand>::rand(rng),
<Ff1 as UniformRand>::rand(rng),
)
}));
let comms_right: Box<[_; TEST_N_COL_TOTAL]> = Box::new(core::array::from_fn(|_i| {
(
);
comms_right[i] = (
<Ff1 as UniformRand>::rand(rng),
<Ff1 as UniformRand>::rand(rng),
)
}));
let comms_output: Box<[_; TEST_N_COL_TOTAL]> = Box::new(core::array::from_fn(|_i| {
(
);
comms_output[i] = (
<Ff1 as UniformRand>::rand(rng),
<Ff1 as UniformRand>::rand(rng),
)
}));
);
}

println!("Building fixed selectors");
let fixed_selectors: Vec<Vec<Fp>> =
build_selectors::<_, TEST_N_COL_TOTAL, TEST_N_CHALS>(domain_size);
witness_env.set_fixed_selectors(fixed_selectors.to_vec());

println!("Calling the IVC circuit");
// TODO add nonzero E/T values.
ivc_circuit::<_, _, _, _, TEST_N_COL_TOTAL>(
&mut SubEnvLookup::new(&mut witness_env, lt_lens),
Expand All @@ -113,7 +111,7 @@ mod tests {
[(Ff1::zero(), Ff1::zero()); 3],
[(Ff1::zero(), Ff1::zero()); 2],
Fp::zero(),
vec![Fp::zero(); 200],
vec![Fp::zero(); TEST_N_CHALS],
&PoseidonBN254Parameters,
TEST_DOMAIN_SIZE,
);
Expand Down
124 changes: 124 additions & 0 deletions utils/src/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
//! This module provides different helpers in creating constant sized
//! arrays and converting them to different formats.
//!
//! Functions in this module are not necessarily optimal in terms of
//! allocations, as they tend to create intermediate vectors. For
//! better performance, either optimise this code, or use
//! (non-fixed-sized) vectors.
/// Converts a two-dimensional vector to a constant sized two-dimensional array.
pub fn vec_to_boxed_array2<T, const N: usize, const M: usize>(
vec: Vec<Vec<T>>,
) -> Box<[[T; N]; M]> {
let vec_of_slices2: Vec<[T; N]> = vec
.into_iter()
.map(|x: Vec<T>| {
let y: Box<[T]> = x.into_boxed_slice();
let z: Box<[T; N]> = y
.try_into()
.unwrap_or_else(|_| panic!("vec_to_boxed_array2: length mismatch inner array"));
*z
})
.collect();
let array: Box<[[T; N]; M]> = vec_of_slices2
.into_boxed_slice()
.try_into()
.unwrap_or_else(|_| panic!("vec_to_boxed_array2: length mismatch outer array"));

array
}

/// Converts a three-dimensional vector to a constant sized two-dimensional array.
pub fn vec_to_boxed_array3<T, const N: usize, const M: usize, const K: usize>(
vec: Vec<Vec<Vec<T>>>,
) -> Box<[[[T; N]; M]; K]> {
let vec_of_slices2: Vec<[[T; N]; M]> =
vec.into_iter().map(|v| *vec_to_boxed_array2(v)).collect();
vec_of_slices2
.into_boxed_slice()
.try_into()
.unwrap_or_else(|_| panic!("vec_to_boxed_array3: length mismatch outer array"))
}

/// A macro similar to `vec![$elem; $size]` which returns a boxed
/// array, allocated directly on the heap (via a vector, with reallocations).
///
/// ```rustc
/// let _: Box<[u8; 1024]> = box_array![0; 1024];
/// ```
///
/// See
/// <https://stackoverflow.com/questions/25805174/creating-a-fixed-size-array-on-heap-in-rust/68122278#68122278>
#[macro_export]
macro_rules! box_array {
($val:expr ; $len:expr) => {{
// Use a generic function so that the pointer cast remains type-safe
fn vec_to_boxed_array<T>(vec: Vec<T>) -> Box<[T; $len]> {
(vec.into_boxed_slice())
.try_into()
.unwrap_or_else(|_| panic!("box_array: length mismatch"))
}

vec_to_boxed_array(vec![$val; $len])
}};
}

/// A macro similar to `vec![vec![$elem; $size1]; $size2]` which
/// returns a two-dimensional boxed array, allocated directly on the
/// heap (via a vector, with reallocations).
///
/// ```rustc
/// let _: Box<[[u8; 1024]; 512]> = box_array![0; 1024; 512];
/// ```
///
#[macro_export]
macro_rules! box_array2 {
($val:expr; $len1:expr; $len2:expr) => {{
pub fn vec_to_boxed_array2<T>(vec: Vec<Vec<T>>) -> Box<[[T; $len1]; $len2]> {
let vec_of_slices2: Vec<[T; $len1]> = vec
.into_iter()
.map(|x: Vec<T>| {
let y: Box<[T]> = x.into_boxed_slice();
let z: Box<[T; $len1]> = y
.try_into()
.unwrap_or_else(|_| panic!("box_array2: length mismatch inner array"));
*z
})
.collect();
let array: Box<[[T; $len1]; $len2]> = vec_of_slices2
.into_boxed_slice()
.try_into()
.unwrap_or_else(|_| panic!("box_array2: length mismatch outer array"));

array
}

vec_to_boxed_array2(vec![vec![$val; $len1]; $len2])
}};
}

#[cfg(test)]
mod tests {
use super::*;

use ark_ec::AffineCurve;
use ark_ff::Zero;
use mina_curves::pasta::Pallas as CurvePoint;

pub type BaseField = <CurvePoint as AffineCurve>::BaseField;

#[test]
/// Tests whether initialising different arrays creates a stack
/// overflow. The usual default size of the stack is 128kB.
fn test_boxed_stack_overflow() {
// Each point is assumed to be 256 bits, so 512 points is
// 16MB. This often overflows the stack if created as an
// array.
let _boxed: Box<[[BaseField; 256]; 1]> =
vec_to_boxed_array2(vec![vec![BaseField::zero(); 256]; 1]);
let _boxed: Box<[[BaseField; 64]; 4]> =
vec_to_boxed_array2(vec![vec![BaseField::zero(); 64]; 4]);
let _boxed: Box<[BaseField; 256]> = box_array![BaseField::zero(); 256];
let _boxed: Box<[[BaseField; 256]; 1]> = box_array2![BaseField::zero(); 256; 1];
}
}
1 change: 1 addition & 0 deletions utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! A collection of utility functions and constants that can be reused from multiple projects
pub mod adjacent_pairs;
pub mod array;
pub mod biguint_helpers;
pub mod bitwise_operations;
pub mod chunked_evaluations;
Expand Down

0 comments on commit aafb54c

Please sign in to comment.