Skip to content

Commit

Permalink
Eliminate tuples in preprocessing (#202)
Browse files Browse the repository at this point in the history
This started as an optimization patch, but my first optimization
revealed a bug. In chasing the bug, I found more optimization.

Changes:
1. Eliminate tuples in preprocessing. (opt)
2. Handle CStore in tuple elimination pass. (bugfix)
3. Use tuples instead of arrays in a few more extension ops: (opt)
  * GCD for vanishing polynomials and their derivatives
  * sorting in transcript checking
4. A few logging revisions
  • Loading branch information
alex-ozdemir authored Jun 27, 2024
1 parent 913da60 commit 3479265
Show file tree
Hide file tree
Showing 15 changed files with 90 additions and 92 deletions.
2 changes: 1 addition & 1 deletion examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok.vin
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(x #f6)
(return #f6)
(return #f0)
) false ; ignored
))

12 changes: 6 additions & 6 deletions examples/circ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ fn main() {
Backend::Smt { .. } => Mode::Proof,
};
let language = determine_language(&options.frontend.language, &options.path);
println!("Running frontend");
let cs = match language {
#[cfg(all(feature = "smt", feature = "zok"))]
DeterminedLanguage::Zsharp => {
Expand Down Expand Up @@ -233,6 +234,7 @@ fn main() {
panic!("Missing feature: c");
}
};
println!("Running IR optimizations");
let cs = match mode {
Mode::Opt => opt(
cs,
Expand Down Expand Up @@ -295,8 +297,7 @@ fn main() {
opt(cs, opts)
}
};
println!("Done with IR optimization");

println!("Running backend");
match options.backend {
#[cfg(feature = "r1cs")]
Backend::R1cs {
Expand All @@ -306,15 +307,14 @@ fn main() {
proof_impl,
..
} => {
println!("Converting to r1cs");
let cs = cs.get("main");
trace!("IR: {}", circ::ir::term::text::serialize_computation(cs));
let mut r1cs = to_r1cs(cs, cfg());
if cfg().r1cs.profile {
println!("R1CS stats: {:#?}", r1cs.stats());
}

println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
println!("Running r1cs optimizations ");
r1cs = reduce_linearities(r1cs, cfg());

println!("Final R1cs size: {}", r1cs.constraints().len());
Expand All @@ -326,7 +326,7 @@ fn main() {
ProofAction::Count => (),
#[cfg(feature = "bellman")]
ProofAction::Setup => {
println!("Generating Parameters");
println!("Running Setup");
match proof_impl {
ProofImpl::Groth16 => Bellman::<Bls12>::setup_fs(
prover_data,
Expand All @@ -348,7 +348,7 @@ fn main() {
ProofAction::Setup => panic!("Missing feature: bellman"),
#[cfg(feature = "bellman")]
ProofAction::CpSetup => {
println!("Generating Parameters");
println!("Running CpSetup");
match proof_impl {
ProofImpl::Groth16 => panic!("Groth16 is not CP"),
ProofImpl::Mirage => Mirage::<Bls12>::cp_setup_fs(
Expand Down
5 changes: 3 additions & 2 deletions scripts/zokrates_test.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ function pf_test_isolate {
done
}

pf_test 2024_05_24_benny_bug
pf_test 2024_05_31_benny_bug
r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
Expand Down Expand Up @@ -110,4 +108,7 @@ pf_test var_idx_arr_str_arr_str
pf_test mm
pf_test unused_var

pf_test 2024_05_24_benny_bug
pf_test 2024_05_31_benny_bug

scripts/zx_tests/run_tests.sh
13 changes: 4 additions & 9 deletions src/ir/opt/mem/ram/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ fn range_check(
debug_assert!(values.iter().all(|v| check(v) == f_sort));
let mut ms_hash_inputs = values.clone();
values.extend(f_sort.elems_iter().take(n));
let sorted_term = unmake_array(
term![Op::ExtOp(ExtOp::Sort); make_array(f_sort.clone(), f_sort.clone(), values.clone())],
);
let sorted_term = tuple_terms(term![Op::ExtOp(ExtOp::Sort); term(Op::Tuple, values.clone())]);
let sorted: Vec<Term> = sorted_term
.into_iter()
.enumerate()
Expand Down Expand Up @@ -302,10 +300,7 @@ fn derivative_gcd(
let ns = ns.subspace("uniq");
let fs = Sort::Field(f.clone());
let pairs = term(
Op::Array(Box::new(ArrayOp {
key: fs.clone(),
val: Sort::new_tuple(vec![fs.clone(), Sort::Bool]),
})),
Op::Tuple,
values
.clone()
.into_iter()
Expand All @@ -314,8 +309,8 @@ fn derivative_gcd(
.collect(),
);
let two_polys = term![Op::ExtOp(ExtOp::UniqDeriGcd); pairs];
let s_coeffs = unmake_array(term![Op::Field(0); two_polys.clone()]);
let t_coeffs = unmake_array(term![Op::Field(1); two_polys]);
let s_coeffs = tuple_terms(term![Op::Field(0); two_polys.clone()]);
let t_coeffs = tuple_terms(term![Op::Field(1); two_polys]);
let mut decl_poly = |coeffs: Vec<Term>, poly_name: &str| -> Vec<Term> {
coeffs
.into_iter()
Expand Down
24 changes: 9 additions & 15 deletions src/ir/opt/mem/ram/checker/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,15 @@ pub(super) fn waksman(
val_sort: &Sort,
new_var: &mut impl FnMut(&str, Term) -> Term,
) -> Vec<Access> {
let f = &cfg.field;
let f_s = Sort::Field(f.clone());
// (1) sort the transcript
let field_tuples: Vec<Term> = accesses.iter().map(|a| a.to_field_tuple(cfg)).collect();
let switch_settings_tuple = term![Op::ExtOp(ExtOp::Waksman); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())];
let n = check(&switch_settings_tuple).as_tuple().len();
let mut switch_settings: VecDeque<Term> = (0..n)
.map(|i| {
new_var(
&format!("sw{}", i),
term![Op::Field(i); switch_settings_tuple.clone()],
)
})
let switch_settings_tuple =
term![Op::ExtOp(ExtOp::Waksman); term(Op::Tuple, field_tuples.clone())];
let switch_settings_terms = tuple_terms(switch_settings_tuple);
let mut switch_settings: VecDeque<Term> = switch_settings_terms
.into_iter()
.enumerate()
.map(|(i, t)| new_var(&format!("sw{}", i), t))
.collect();

let sorted_field_tuple_values: Vec<Term> =
Expand Down Expand Up @@ -69,12 +65,10 @@ pub(super) fn msh(
assertions: &mut Vec<Term>,
) -> Vec<Access> {
let f = &cfg.field;
let f_s = Sort::Field(f.clone());
// (1) sort the transcript
let field_tuples: Vec<Term> = accesses.iter().map(|a| a.to_field_tuple(cfg)).collect();
let sorted_field_tuple_values: Vec<Term> = unmake_array(
term![Op::ExtOp(ExtOp::Sort); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())],
);
let sorted_field_tuple_values: Vec<Term> =
tuple_terms(term![Op::ExtOp(ExtOp::Sort); term(Op::Tuple, field_tuples.clone())]);
let mut sorted_accesses: Vec<Access> = sorted_field_tuple_values
.into_iter()
.enumerate()
Expand Down
2 changes: 1 addition & 1 deletion src/ir/opt/mem/ram/volatile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl ArrayGraph {
.collect();
while let Some(top) = stack.pop() {
if ram_terms.insert(top.clone()) {
trace!("Maybe RAM: {}", top);
trace!("Maybe RAM: {}", top.op());
for p in ps.get(&top).unwrap() {
if right_sort(p, field) {
stack.push(p.clone());
Expand Down
29 changes: 21 additions & 8 deletions src/ir/opt/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl TupleTree {
}
}
}
#[track_caller]
fn unwrap_non_tuple(self) -> Term {
match self {
TupleTree::NonTuple(t) => t,
Expand Down Expand Up @@ -245,9 +246,13 @@ fn tuple_free(t: Term) -> bool {
/// Run the tuple elimination pass.
pub fn eliminate_tuples(cs: &mut Computation) {
let mut lifted: TermMap<TupleTree> = TermMap::default();
let terms =
PostOrderIter::from_roots_and_skips(cs.outputs().iter().cloned(), Default::default());
// .chain(cs.precomputes.outputs().values().cloned()),
let terms = PostOrderIter::from_roots_and_skips(
cs.outputs()
.iter()
.cloned()
.chain(cs.precomputes.outputs().values().cloned()),
Default::default(),
);
for t in terms {
let mut cs: Vec<TupleTree> = t
.cs()
Expand All @@ -270,6 +275,14 @@ pub fn eliminate_tuples(cs: &mut Computation) {
let eqs = zip_eq(a.flatten(), b.flatten()).map(|(a, b)| term![Op::Eq; a, b]);
TupleTree::NonTuple(term(AND, eqs.collect()))
}
Op::CStore => {
let c = cs.pop().unwrap().unwrap_non_tuple();
let v = cs.pop().unwrap();
let i = cs.pop().unwrap().unwrap_non_tuple();
let a = cs.pop().unwrap();
debug_assert!(cs.is_empty());
a.bimap(|a, v| term![Op::CStore; a, i.clone(), v, c.clone()], &v)
}
Op::Store => {
let v = cs.pop().unwrap();
let i = cs.pop().unwrap().unwrap_non_tuple();
Expand Down Expand Up @@ -321,11 +334,11 @@ pub fn eliminate_tuples(cs: &mut Computation) {
.into_iter()
.flat_map(|o| lifted.get(&o).unwrap().clone().flatten())
.collect();
// let os = cs.precomputes.outputs().clone();
// for (name, old_term) in os {
// let new_term = lifted.get(&old_term).unwrap().clone().as_term();
// cs.precomputes.change_output(&name, new_term);
// }
let os = cs.precomputes.outputs().clone();
for (name, old_term) in os {
let new_term = lifted.get(&old_term).unwrap().clone().as_term();
cs.precomputes.change_output(&name, new_term);
}
#[cfg(debug_assertions)]
for o in &cs.outputs {
if let Some(t) = find_tuple_term(o.clone()) {
Expand Down
4 changes: 2 additions & 2 deletions src/ir/term/ext/haboeck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use crate::ir::term::ty::*;
use crate::ir::term::*;

/// Type-check [super::ExtOp::UniqDeriGcd].
/// Type-check [super::ExtOp::Haboeck].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
let &[haystack, needles] = ty::count_or_ref(arg_sorts)?;
let (_n, value0) = ty::homogenous_tuple_or(haystack, "haystack must be a tuple")?;
Expand All @@ -21,7 +21,7 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
Ok(haystack.clone())
}

/// Evaluate [super::ExtOp::UniqDeriGcd].
/// Evaluate [super::ExtOp::Haboeck].
pub fn eval(args: &[&Value]) -> Value {
let haystack: Vec<FieldV> = args[0]
.as_tuple()
Expand Down
38 changes: 11 additions & 27 deletions src/ir/term/ext/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,24 @@ use crate::ir::term::*;

/// Type-check [super::ExtOp::UniqDeriGcd].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
if let &[pairs] = arg_sorts {
let (key, value, size) = ty::array_or(pairs, "UniqDeriGcd pairs")?;
let f = pf_or(key, "UniqDeriGcd pairs: indices must be field")?;
let value_tup = ty::tuple_or(value, "UniqDeriGcd entries: value must be a tuple")?;
if let &[root, cond] = &value_tup {
eq_or(f, root, "UniqDeriGcd pairs: first element must be a field")?;
eq_or(
cond,
&Sort::Bool,
"UniqDeriGcd pairs: second element must be a bool",
)?;
let arr = Sort::new_array(f.clone(), f.clone(), size);
Ok(Sort::new_tuple(vec![arr.clone(), arr]))
} else {
// non-pair entries value
Err(TypeErrorReason::Custom(
"UniqDeriGcd: pairs value must be a pair".into(),
))
}
} else {
// wrong arg count
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
}
let [pairs] = ty::count_or_ref(arg_sorts)?;
let (size, value) = ty::homogenous_tuple_or(pairs, "UniqDeriGcd")?;
let [root, cond] = ty::count_or(ty::tuple_or(value, "UniqDeriGcd")?)?;
let f = pf_or(root, "UniqDeriGcd: first is field")?;
eq_or(cond, &Sort::Bool, "UniqDeriGcd pairs: second is bool")?;
let coeffs = Sort::new_tuple(vec![f.clone(); size]);
Ok(Sort::new_tuple(vec![coeffs.clone(), coeffs]))
}

/// Evaluate [super::ExtOp::UniqDeriGcd].
#[cfg(feature = "poly")]
pub fn eval(args: &[&Value]) -> Value {
use rug_polynomial::ModPoly;
let sort = args[0].sort().as_array().0.clone();
let sort = args[0].sort().as_tuple()[0].as_tuple()[0].clone();
let field = sort.as_pf().clone();
let mut roots: Vec<Integer> = Vec::new();
let deg = args[0].as_array().size;
for t in args[0].as_array().values() {
let deg = args[0].as_tuple().len();
for t in args[0].as_tuple() {
let tuple = t.as_tuple();
let cond = tuple[1].as_bool();
if cond {
Expand All @@ -60,7 +44,7 @@ pub fn eval(args: &[&Value]) -> Value {
let v: Vec<Value> = (0..deg)
.map(|i| Value::Field(field.new_v(s.get_coefficient(i))))
.collect();
Value::Array(Array::from_vec(sort.clone(), sort.clone(), v))
Value::Tuple(v.into())
};
let s_cs = coeff_arr(s);
let t_cs = coeff_arr(t);
Expand Down
23 changes: 16 additions & 7 deletions src/ir/term/ext/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,21 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
/// Evaluate [super::ExtOp::Sort].
pub fn eval(args: &[&Value]) -> Value {
let sort = args[0].sort();
let (key_sort, value_sort, _) = sort.as_array();
let mut values: Vec<Value> = args[0].as_array().values();
let is_array = sort.is_array();
let mut values: Vec<Value> = if is_array {
args[0].as_array().values()
} else {
args[0].as_tuple().to_vec()
};
values.sort();
Value::Array(Array::from_vec(
key_sort.clone(),
value_sort.clone(),
values,
))
if is_array {
let (key_sort, value_sort, _) = sort.as_array();
Value::Array(Array::from_vec(
key_sort.clone(),
value_sort.clone(),
values,
))
} else {
Value::Tuple(values.into())
}
}
14 changes: 7 additions & 7 deletions src/ir/term/ext/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn uniq_deri_gcd_eval() {
let t = text::parse_term(
b"
(declare (
(pairs (array (mod 17) (tuple (mod 17) bool) 5))
(pairs (tuple 5 (tuple (mod 17) bool)))
)
(uniq_deri_gcd pairs))",
);
Expand All @@ -35,7 +35,7 @@ fn uniq_deri_gcd_eval() {
(set_default_modulus 17
(let
(
(pairs (#l (mod 17) ( (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
(pairs (#t (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) ))
) false))
",
);
Expand All @@ -46,8 +46,8 @@ fn uniq_deri_gcd_eval() {
(let
(
(output (#t
(#l (mod 17) ( #f16 #f0 #f0 #f0 #f0 ) ) ; s, from sage
(#l (mod 17) ( #f7 #f9 #f0 #f0 #f0 ) ) ; t, from sage
(#t #f16 #f0 #f0 #f0 #f0 ) ; s, from sage
(#t #f7 #f9 #f0 #f0 #f0 ) ; t, from sage
))
) false))
",
Expand All @@ -59,7 +59,7 @@ fn uniq_deri_gcd_eval() {
(set_default_modulus 17
(let
(
(pairs (#l (mod 17) ( (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
(pairs (#t (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true)))
) false))
",
);
Expand All @@ -70,8 +70,8 @@ fn uniq_deri_gcd_eval() {
(let
(
(output (#t
(#l (mod 17) ( #f8 #f9 #f16 #f0 #f0 ) ) ; s, from sage
(#l (mod 17) ( #f2 #f16 #f9 #f13 #f0 ) ) ; t, from sage
(#t #f8 #f9 #f16 #f0 #f0 ) ; s, from sage
(#t #f2 #f16 #f9 #f13 #f0 ) ; t, from sage
))
) false))
",
Expand Down
Loading

0 comments on commit 3479265

Please sign in to comment.