From aaab8031a3ceafe3f91dbe2828744cea5f5b4b3b Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Thu, 27 Jun 2024 11:26:14 -0700 Subject: [PATCH] Eliminate tuples in preprocessing This started as an optimization patch, but my first optimization revealed a bug. In chasing the bug, I found more optimization. Changes: 1. Eliminate tuples in preprocessing. (opt) 2. Handle CStore in tuple elimination pass. (bugfix) 3. Use tuples instead of arrays in a few more extension ops: (opt) * GCD for vanishing polynomials and their derivatives * sorting in transcript checking 4. A few logging revisions --- .../pf/mem/2024_05_31_benny_bug_tr.zok.vin | 2 +- examples/circ.rs | 12 +++--- scripts/zokrates_test.zsh | 5 ++- src/ir/opt/mem/ram/checker.rs | 13 ++----- src/ir/opt/mem/ram/checker/permutation.rs | 24 +++++------- src/ir/opt/mem/ram/volatile.rs | 2 +- src/ir/opt/tuple.rs | 29 ++++++++++---- src/ir/term/ext/haboeck.rs | 4 +- src/ir/term/ext/poly.rs | 38 ++++++------------- src/ir/term/ext/sort.rs | 23 +++++++---- src/ir/term/ext/test.rs | 14 +++---- src/ir/term/ext/waksman.rs | 11 +++--- src/ir/term/text/mod.rs | 2 +- src/target/r1cs/mirage.rs | 1 + src/target/r1cs/wit_comp.rs | 2 +- 15 files changed, 90 insertions(+), 92 deletions(-) diff --git a/examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok.vin b/examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok.vin index 6669752de..3afec9501 100644 --- a/examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok.vin +++ b/examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok.vin @@ -1,7 +1,7 @@ (set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513 (let ( (x #f6) - (return #f6) + (return #f0) ) false ; ignored )) diff --git a/examples/circ.rs b/examples/circ.rs index 05fb38e36..6f5e7efa5 100644 --- a/examples/circ.rs +++ b/examples/circ.rs @@ -195,6 +195,7 @@ fn main() { Backend::Smt { .. } => Mode::Proof, }; let language = determine_language(&options.frontend.language, &options.path); + println!("Running frontend"); let cs = match language { #[cfg(all(feature = "smt", feature = "zok"))] DeterminedLanguage::Zsharp => { @@ -233,6 +234,7 @@ fn main() { panic!("Missing feature: c"); } }; + println!("Running IR optimizations"); let cs = match mode { Mode::Opt => opt( cs, @@ -295,8 +297,7 @@ fn main() { opt(cs, opts) } }; - println!("Done with IR optimization"); - + println!("Running backend"); match options.backend { #[cfg(feature = "r1cs")] Backend::R1cs { @@ -306,7 +307,6 @@ fn main() { proof_impl, .. } => { - println!("Converting to r1cs"); let cs = cs.get("main"); trace!("IR: {}", circ::ir::term::text::serialize_computation(cs)); let mut r1cs = to_r1cs(cs, cfg()); @@ -314,7 +314,7 @@ fn main() { println!("R1CS stats: {:#?}", r1cs.stats()); } - println!("Pre-opt R1cs size: {}", r1cs.constraints().len()); + println!("Running r1cs optimizations "); r1cs = reduce_linearities(r1cs, cfg()); println!("Final R1cs size: {}", r1cs.constraints().len()); @@ -326,7 +326,7 @@ fn main() { ProofAction::Count => (), #[cfg(feature = "bellman")] ProofAction::Setup => { - println!("Generating Parameters"); + println!("Running Setup"); match proof_impl { ProofImpl::Groth16 => Bellman::::setup_fs( prover_data, @@ -348,7 +348,7 @@ fn main() { ProofAction::Setup => panic!("Missing feature: bellman"), #[cfg(feature = "bellman")] ProofAction::CpSetup => { - println!("Generating Parameters"); + println!("Running CpSetup"); match proof_impl { ProofImpl::Groth16 => panic!("Groth16 is not CP"), ProofImpl::Mirage => Mirage::::cp_setup_fs( diff --git a/scripts/zokrates_test.zsh b/scripts/zokrates_test.zsh index 5af8f6ddc..52aa6fa17 100755 --- a/scripts/zokrates_test.zsh +++ b/scripts/zokrates_test.zsh @@ -72,8 +72,6 @@ function pf_test_isolate { done } -pf_test 2024_05_24_benny_bug -pf_test 2024_05_31_benny_bug r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120 r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok @@ -110,4 +108,7 @@ pf_test var_idx_arr_str_arr_str pf_test mm pf_test unused_var +pf_test 2024_05_24_benny_bug +pf_test 2024_05_31_benny_bug + scripts/zx_tests/run_tests.sh diff --git a/src/ir/opt/mem/ram/checker.rs b/src/ir/opt/mem/ram/checker.rs index d6d93c49e..156ba1507 100644 --- a/src/ir/opt/mem/ram/checker.rs +++ b/src/ir/opt/mem/ram/checker.rs @@ -213,9 +213,7 @@ fn range_check( debug_assert!(values.iter().all(|v| check(v) == f_sort)); let mut ms_hash_inputs = values.clone(); values.extend(f_sort.elems_iter().take(n)); - let sorted_term = unmake_array( - term![Op::ExtOp(ExtOp::Sort); make_array(f_sort.clone(), f_sort.clone(), values.clone())], - ); + let sorted_term = tuple_terms(term![Op::ExtOp(ExtOp::Sort); term(Op::Tuple, values.clone())]); let sorted: Vec = sorted_term .into_iter() .enumerate() @@ -302,10 +300,7 @@ fn derivative_gcd( let ns = ns.subspace("uniq"); let fs = Sort::Field(f.clone()); let pairs = term( - Op::Array(Box::new(ArrayOp { - key: fs.clone(), - val: Sort::new_tuple(vec![fs.clone(), Sort::Bool]), - })), + Op::Tuple, values .clone() .into_iter() @@ -314,8 +309,8 @@ fn derivative_gcd( .collect(), ); let two_polys = term![Op::ExtOp(ExtOp::UniqDeriGcd); pairs]; - let s_coeffs = unmake_array(term![Op::Field(0); two_polys.clone()]); - let t_coeffs = unmake_array(term![Op::Field(1); two_polys]); + let s_coeffs = tuple_terms(term![Op::Field(0); two_polys.clone()]); + let t_coeffs = tuple_terms(term![Op::Field(1); two_polys]); let mut decl_poly = |coeffs: Vec, poly_name: &str| -> Vec { coeffs .into_iter() diff --git a/src/ir/opt/mem/ram/checker/permutation.rs b/src/ir/opt/mem/ram/checker/permutation.rs index b88fbdf4d..8a0779b88 100644 --- a/src/ir/opt/mem/ram/checker/permutation.rs +++ b/src/ir/opt/mem/ram/checker/permutation.rs @@ -11,19 +11,15 @@ pub(super) fn waksman( val_sort: &Sort, new_var: &mut impl FnMut(&str, Term) -> Term, ) -> Vec { - let f = &cfg.field; - let f_s = Sort::Field(f.clone()); // (1) sort the transcript let field_tuples: Vec = accesses.iter().map(|a| a.to_field_tuple(cfg)).collect(); - let switch_settings_tuple = term![Op::ExtOp(ExtOp::Waksman); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())]; - let n = check(&switch_settings_tuple).as_tuple().len(); - let mut switch_settings: VecDeque = (0..n) - .map(|i| { - new_var( - &format!("sw{}", i), - term![Op::Field(i); switch_settings_tuple.clone()], - ) - }) + let switch_settings_tuple = + term![Op::ExtOp(ExtOp::Waksman); term(Op::Tuple, field_tuples.clone())]; + let switch_settings_terms = tuple_terms(switch_settings_tuple); + let mut switch_settings: VecDeque = switch_settings_terms + .into_iter() + .enumerate() + .map(|(i, t)| new_var(&format!("sw{}", i), t)) .collect(); let sorted_field_tuple_values: Vec = @@ -69,12 +65,10 @@ pub(super) fn msh( assertions: &mut Vec, ) -> Vec { let f = &cfg.field; - let f_s = Sort::Field(f.clone()); // (1) sort the transcript let field_tuples: Vec = accesses.iter().map(|a| a.to_field_tuple(cfg)).collect(); - let sorted_field_tuple_values: Vec = unmake_array( - term![Op::ExtOp(ExtOp::Sort); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())], - ); + let sorted_field_tuple_values: Vec = + tuple_terms(term![Op::ExtOp(ExtOp::Sort); term(Op::Tuple, field_tuples.clone())]); let mut sorted_accesses: Vec = sorted_field_tuple_values .into_iter() .enumerate() diff --git a/src/ir/opt/mem/ram/volatile.rs b/src/ir/opt/mem/ram/volatile.rs index 1df3db202..e243a9966 100644 --- a/src/ir/opt/mem/ram/volatile.rs +++ b/src/ir/opt/mem/ram/volatile.rs @@ -106,7 +106,7 @@ impl ArrayGraph { .collect(); while let Some(top) = stack.pop() { if ram_terms.insert(top.clone()) { - trace!("Maybe RAM: {}", top); + trace!("Maybe RAM: {}", top.op()); for p in ps.get(&top).unwrap() { if right_sort(p, field) { stack.push(p.clone()); diff --git a/src/ir/opt/tuple.rs b/src/ir/opt/tuple.rs index 82f507a9a..32cd23085 100644 --- a/src/ir/opt/tuple.rs +++ b/src/ir/opt/tuple.rs @@ -143,6 +143,7 @@ impl TupleTree { } } } + #[track_caller] fn unwrap_non_tuple(self) -> Term { match self { TupleTree::NonTuple(t) => t, @@ -245,9 +246,13 @@ fn tuple_free(t: Term) -> bool { /// Run the tuple elimination pass. pub fn eliminate_tuples(cs: &mut Computation) { let mut lifted: TermMap = TermMap::default(); - let terms = - PostOrderIter::from_roots_and_skips(cs.outputs().iter().cloned(), Default::default()); - // .chain(cs.precomputes.outputs().values().cloned()), + let terms = PostOrderIter::from_roots_and_skips( + cs.outputs() + .iter() + .cloned() + .chain(cs.precomputes.outputs().values().cloned()), + Default::default(), + ); for t in terms { let mut cs: Vec = t .cs() @@ -270,6 +275,14 @@ pub fn eliminate_tuples(cs: &mut Computation) { let eqs = zip_eq(a.flatten(), b.flatten()).map(|(a, b)| term![Op::Eq; a, b]); TupleTree::NonTuple(term(AND, eqs.collect())) } + Op::CStore => { + let c = cs.pop().unwrap().unwrap_non_tuple(); + let v = cs.pop().unwrap(); + let i = cs.pop().unwrap().unwrap_non_tuple(); + let a = cs.pop().unwrap(); + debug_assert!(cs.is_empty()); + a.bimap(|a, v| term![Op::CStore; a, i.clone(), v, c.clone()], &v) + } Op::Store => { let v = cs.pop().unwrap(); let i = cs.pop().unwrap().unwrap_non_tuple(); @@ -321,11 +334,11 @@ pub fn eliminate_tuples(cs: &mut Computation) { .into_iter() .flat_map(|o| lifted.get(&o).unwrap().clone().flatten()) .collect(); - // let os = cs.precomputes.outputs().clone(); - // for (name, old_term) in os { - // let new_term = lifted.get(&old_term).unwrap().clone().as_term(); - // cs.precomputes.change_output(&name, new_term); - // } + let os = cs.precomputes.outputs().clone(); + for (name, old_term) in os { + let new_term = lifted.get(&old_term).unwrap().clone().as_term(); + cs.precomputes.change_output(&name, new_term); + } #[cfg(debug_assertions)] for o in &cs.outputs { if let Some(t) = find_tuple_term(o.clone()) { diff --git a/src/ir/term/ext/haboeck.rs b/src/ir/term/ext/haboeck.rs index 135ad8556..3899e6aed 100644 --- a/src/ir/term/ext/haboeck.rs +++ b/src/ir/term/ext/haboeck.rs @@ -10,7 +10,7 @@ use crate::ir::term::ty::*; use crate::ir::term::*; -/// Type-check [super::ExtOp::UniqDeriGcd]. +/// Type-check [super::ExtOp::Haboeck]. pub fn check(arg_sorts: &[&Sort]) -> Result { let &[haystack, needles] = ty::count_or_ref(arg_sorts)?; let (_n, value0) = ty::homogenous_tuple_or(haystack, "haystack must be a tuple")?; @@ -21,7 +21,7 @@ pub fn check(arg_sorts: &[&Sort]) -> Result { Ok(haystack.clone()) } -/// Evaluate [super::ExtOp::UniqDeriGcd]. +/// Evaluate [super::ExtOp::Haboeck]. pub fn eval(args: &[&Value]) -> Value { let haystack: Vec = args[0] .as_tuple() diff --git a/src/ir/term/ext/poly.rs b/src/ir/term/ext/poly.rs index 8312efc69..9cdbd371b 100644 --- a/src/ir/term/ext/poly.rs +++ b/src/ir/term/ext/poly.rs @@ -11,40 +11,24 @@ use crate::ir::term::*; /// Type-check [super::ExtOp::UniqDeriGcd]. pub fn check(arg_sorts: &[&Sort]) -> Result { - if let &[pairs] = arg_sorts { - let (key, value, size) = ty::array_or(pairs, "UniqDeriGcd pairs")?; - let f = pf_or(key, "UniqDeriGcd pairs: indices must be field")?; - let value_tup = ty::tuple_or(value, "UniqDeriGcd entries: value must be a tuple")?; - if let &[root, cond] = &value_tup { - eq_or(f, root, "UniqDeriGcd pairs: first element must be a field")?; - eq_or( - cond, - &Sort::Bool, - "UniqDeriGcd pairs: second element must be a bool", - )?; - let arr = Sort::new_array(f.clone(), f.clone(), size); - Ok(Sort::new_tuple(vec![arr.clone(), arr])) - } else { - // non-pair entries value - Err(TypeErrorReason::Custom( - "UniqDeriGcd: pairs value must be a pair".into(), - )) - } - } else { - // wrong arg count - Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len())) - } + let [pairs] = ty::count_or_ref(arg_sorts)?; + let (size, value) = ty::homogenous_tuple_or(pairs, "UniqDeriGcd")?; + let [root, cond] = ty::count_or(ty::tuple_or(value, "UniqDeriGcd")?)?; + let f = pf_or(root, "UniqDeriGcd: first is field")?; + eq_or(cond, &Sort::Bool, "UniqDeriGcd pairs: second is bool")?; + let coeffs = Sort::new_tuple(vec![f.clone(); size]); + Ok(Sort::new_tuple(vec![coeffs.clone(), coeffs])) } /// Evaluate [super::ExtOp::UniqDeriGcd]. #[cfg(feature = "poly")] pub fn eval(args: &[&Value]) -> Value { use rug_polynomial::ModPoly; - let sort = args[0].sort().as_array().0.clone(); + let sort = args[0].sort().as_tuple()[0].as_tuple()[0].clone(); let field = sort.as_pf().clone(); let mut roots: Vec = Vec::new(); - let deg = args[0].as_array().size; - for t in args[0].as_array().values() { + let deg = args[0].as_tuple().len(); + for t in args[0].as_tuple() { let tuple = t.as_tuple(); let cond = tuple[1].as_bool(); if cond { @@ -60,7 +44,7 @@ pub fn eval(args: &[&Value]) -> Value { let v: Vec = (0..deg) .map(|i| Value::Field(field.new_v(s.get_coefficient(i)))) .collect(); - Value::Array(Array::from_vec(sort.clone(), sort.clone(), v)) + Value::Tuple(v.into()) }; let s_cs = coeff_arr(s); let t_cs = coeff_arr(t); diff --git a/src/ir/term/ext/sort.rs b/src/ir/term/ext/sort.rs index 36a059d71..6990085f3 100644 --- a/src/ir/term/ext/sort.rs +++ b/src/ir/term/ext/sort.rs @@ -17,12 +17,21 @@ pub fn check(arg_sorts: &[&Sort]) -> Result { /// Evaluate [super::ExtOp::Sort]. pub fn eval(args: &[&Value]) -> Value { let sort = args[0].sort(); - let (key_sort, value_sort, _) = sort.as_array(); - let mut values: Vec = args[0].as_array().values(); + let is_array = sort.is_array(); + let mut values: Vec = if is_array { + args[0].as_array().values() + } else { + args[0].as_tuple().to_vec() + }; values.sort(); - Value::Array(Array::from_vec( - key_sort.clone(), - value_sort.clone(), - values, - )) + if is_array { + let (key_sort, value_sort, _) = sort.as_array(); + Value::Array(Array::from_vec( + key_sort.clone(), + value_sort.clone(), + values, + )) + } else { + Value::Tuple(values.into()) + } } diff --git a/src/ir/term/ext/test.rs b/src/ir/term/ext/test.rs index 345feb9ae..11ec7311d 100644 --- a/src/ir/term/ext/test.rs +++ b/src/ir/term/ext/test.rs @@ -25,7 +25,7 @@ fn uniq_deri_gcd_eval() { let t = text::parse_term( b" (declare ( - (pairs (array (mod 17) (tuple (mod 17) bool) 5)) + (pairs (tuple 5 (tuple (mod 17) bool))) ) (uniq_deri_gcd pairs))", ); @@ -35,7 +35,7 @@ fn uniq_deri_gcd_eval() { (set_default_modulus 17 (let ( - (pairs (#l (mod 17) ( (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) ))) + (pairs (#t (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) )) ) false)) ", ); @@ -46,8 +46,8 @@ fn uniq_deri_gcd_eval() { (let ( (output (#t - (#l (mod 17) ( #f16 #f0 #f0 #f0 #f0 ) ) ; s, from sage - (#l (mod 17) ( #f7 #f9 #f0 #f0 #f0 ) ) ; t, from sage + (#t #f16 #f0 #f0 #f0 #f0 ) ; s, from sage + (#t #f7 #f9 #f0 #f0 #f0 ) ; t, from sage )) ) false)) ", @@ -59,7 +59,7 @@ fn uniq_deri_gcd_eval() { (set_default_modulus 17 (let ( - (pairs (#l (mod 17) ( (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true) ))) + (pairs (#t (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true))) ) false)) ", ); @@ -70,8 +70,8 @@ fn uniq_deri_gcd_eval() { (let ( (output (#t - (#l (mod 17) ( #f8 #f9 #f16 #f0 #f0 ) ) ; s, from sage - (#l (mod 17) ( #f2 #f16 #f9 #f13 #f0 ) ) ; t, from sage + (#t #f8 #f9 #f16 #f0 #f0 ) ; s, from sage + (#t #f2 #f16 #f9 #f13 #f0 ) ; t, from sage )) ) false)) ", diff --git a/src/ir/term/ext/waksman.rs b/src/ir/term/ext/waksman.rs index c14adf02f..b65534814 100644 --- a/src/ir/term/ext/waksman.rs +++ b/src/ir/term/ext/waksman.rs @@ -9,15 +9,16 @@ use std::iter::FromIterator; /// Type-check [super::ExtOp::Waksman]. pub fn check(arg_sorts: &[&Sort]) -> Result { - array_or(arg_sorts[0], "sort argument") - .map(|(_, _, n_flows)| Sort::Tuple(vec![Sort::Bool; n_switches(n_flows)].into())) + let &[values] = ty::count_or_ref(arg_sorts)?; + let (n_flows, _v_sort) = ty::homogenous_tuple_or(values, "Waksman argument")?; + Ok(Sort::Tuple(vec![Sort::Bool; n_switches(n_flows)].into())) } /// Evaluate [super::ExtOp::Waksman]. pub fn eval(args: &[&Value]) -> Value { - let len = args[0].as_array().size; - let cfg = Config::for_sorting(args[0].as_array().values()); + let values = args[0].as_tuple(); + let cfg = Config::for_sorting(values.to_vec()); let switch_bools = Vec::from_iter(cfg.switches().into_iter().map(Value::Bool)); - assert_eq!(switch_bools.len(), n_switches(len)); + assert_eq!(switch_bools.len(), n_switches(values.len())); Value::Tuple(switch_bools.into()) } diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index af62e4400..cbd470961 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -1381,7 +1381,7 @@ mod test { let t = parse_term( b" (declare ( - (pairs (array (mod 17) (tuple (mod 17) bool) 5)) + (pairs (tuple 5 (tuple (mod 17) bool))) ) (uniq_deri_gcd pairs))", ); diff --git a/src/target/r1cs/mirage.rs b/src/target/r1cs/mirage.rs index efe4d1925..7dd81835d 100644 --- a/src/target/r1cs/mirage.rs +++ b/src/target/r1cs/mirage.rs @@ -463,6 +463,7 @@ where ) -> Self::Proof { assert_eq!(rand.len(), pk.data.num_commitments()); let rng = &mut rand::thread_rng(); + #[cfg(debug_assertions)] pk.data.check_all(witness); let rands: Vec = rand.iter().map(|r| r.0).collect(); let mut rng = &mut rand::thread_rng(); diff --git a/src/target/r1cs/wit_comp.rs b/src/target/r1cs/wit_comp.rs index b280788ae..95a267ad8 100644 --- a/src/target/r1cs/wit_comp.rs +++ b/src/target/r1cs/wit_comp.rs @@ -215,7 +215,7 @@ impl<'a> StagedWitCompEvaluator<'a> { rows.sort_by_key(|t| t.1); println!("time,op,nanos,counts,nanos_per,arg_sorts"); for (op, nanos, counts, nanos_per, arg_sorts) in &rows { - println!("time,{op},{nanos},{counts},{nanos_per},{arg_sorts}"); + println!("time,{op},{nanos},{counts},{nanos_per},\"{arg_sorts}\""); } } }