Skip to content

Commit

Permalink
Eliminate tuples in preprocessing
Browse files Browse the repository at this point in the history
This started as an optimization patch, but my first optimization
revealed a bug. In chasing the bug, I found more optimization.

Changes:
1. Eliminate tuples in preprocessing. (opt)
2. Handle CStore in tuple elimination pass. (bugfix)
3. Use tuples instead of arrays in a few more extension ops: (opt)
  * GCD for vanishing polynomials and their derivatives
  * sorting in transcript checking
4. A few logging revisions
  • Loading branch information
alex-ozdemir committed Jun 27, 2024
1 parent 913da60 commit aaab803
Show file tree
Hide file tree
Showing 15 changed files with 90 additions and 92 deletions.
2 changes: 1 addition & 1 deletion examples/ZoKrates/pf/mem/2024_05_31_benny_bug_tr.zok.vin
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
(set_default_modulus 52435875175126190479447740508185965837690552500527637822603658699938581184513
(let (
(x #f6)
(return #f6)
(return #f0)
) false ; ignored
))

12 changes: 6 additions & 6 deletions examples/circ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ fn main() {
Backend::Smt { .. } => Mode::Proof,
};
let language = determine_language(&options.frontend.language, &options.path);
println!("Running frontend");
let cs = match language {
#[cfg(all(feature = "smt", feature = "zok"))]
DeterminedLanguage::Zsharp => {
Expand Down Expand Up @@ -233,6 +234,7 @@ fn main() {
panic!("Missing feature: c");
}
};
println!("Running IR optimizations");
let cs = match mode {
Mode::Opt => opt(
cs,
Expand Down Expand Up @@ -295,8 +297,7 @@ fn main() {
opt(cs, opts)
}
};
println!("Done with IR optimization");

println!("Running backend");
match options.backend {
#[cfg(feature = "r1cs")]
Backend::R1cs {
Expand All @@ -306,15 +307,14 @@ fn main() {
proof_impl,
..
} => {
println!("Converting to r1cs");
let cs = cs.get("main");
trace!("IR: {}", circ::ir::term::text::serialize_computation(cs));
let mut r1cs = to_r1cs(cs, cfg());
if cfg().r1cs.profile {
println!("R1CS stats: {:#?}", r1cs.stats());
}

println!("Pre-opt R1cs size: {}", r1cs.constraints().len());
println!("Running r1cs optimizations ");
r1cs = reduce_linearities(r1cs, cfg());

println!("Final R1cs size: {}", r1cs.constraints().len());
Expand All @@ -326,7 +326,7 @@ fn main() {
ProofAction::Count => (),
#[cfg(feature = "bellman")]
ProofAction::Setup => {
println!("Generating Parameters");
println!("Running Setup");
match proof_impl {
ProofImpl::Groth16 => Bellman::<Bls12>::setup_fs(
prover_data,
Expand All @@ -348,7 +348,7 @@ fn main() {
ProofAction::Setup => panic!("Missing feature: bellman"),
#[cfg(feature = "bellman")]
ProofAction::CpSetup => {
println!("Generating Parameters");
println!("Running CpSetup");
match proof_impl {
ProofImpl::Groth16 => panic!("Groth16 is not CP"),
ProofImpl::Mirage => Mirage::<Bls12>::cp_setup_fs(
Expand Down
5 changes: 3 additions & 2 deletions scripts/zokrates_test.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ function pf_test_isolate {
done
}

pf_test 2024_05_24_benny_bug
pf_test 2024_05_31_benny_bug
r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
Expand Down Expand Up @@ -110,4 +108,7 @@ pf_test var_idx_arr_str_arr_str
pf_test mm
pf_test unused_var

pf_test 2024_05_24_benny_bug
pf_test 2024_05_31_benny_bug

scripts/zx_tests/run_tests.sh
13 changes: 4 additions & 9 deletions src/ir/opt/mem/ram/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ fn range_check(
debug_assert!(values.iter().all(|v| check(v) == f_sort));
let mut ms_hash_inputs = values.clone();
values.extend(f_sort.elems_iter().take(n));
let sorted_term = unmake_array(
term![Op::ExtOp(ExtOp::Sort); make_array(f_sort.clone(), f_sort.clone(), values.clone())],
);
let sorted_term = tuple_terms(term![Op::ExtOp(ExtOp::Sort); term(Op::Tuple, values.clone())]);
let sorted: Vec<Term> = sorted_term
.into_iter()
.enumerate()
Expand Down Expand Up @@ -302,10 +300,7 @@ fn derivative_gcd(
let ns = ns.subspace("uniq");
let fs = Sort::Field(f.clone());
let pairs = term(
Op::Array(Box::new(ArrayOp {
key: fs.clone(),
val: Sort::new_tuple(vec![fs.clone(), Sort::Bool]),
})),
Op::Tuple,
values
.clone()
.into_iter()
Expand All @@ -314,8 +309,8 @@ fn derivative_gcd(
.collect(),
);
let two_polys = term![Op::ExtOp(ExtOp::UniqDeriGcd); pairs];
let s_coeffs = unmake_array(term![Op::Field(0); two_polys.clone()]);
let t_coeffs = unmake_array(term![Op::Field(1); two_polys]);
let s_coeffs = tuple_terms(term![Op::Field(0); two_polys.clone()]);
let t_coeffs = tuple_terms(term![Op::Field(1); two_polys]);
let mut decl_poly = |coeffs: Vec<Term>, poly_name: &str| -> Vec<Term> {
coeffs
.into_iter()
Expand Down
24 changes: 9 additions & 15 deletions src/ir/opt/mem/ram/checker/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,15 @@ pub(super) fn waksman(
val_sort: &Sort,
new_var: &mut impl FnMut(&str, Term) -> Term,
) -> Vec<Access> {
let f = &cfg.field;
let f_s = Sort::Field(f.clone());
// (1) sort the transcript
let field_tuples: Vec<Term> = accesses.iter().map(|a| a.to_field_tuple(cfg)).collect();
let switch_settings_tuple = term![Op::ExtOp(ExtOp::Waksman); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())];
let n = check(&switch_settings_tuple).as_tuple().len();
let mut switch_settings: VecDeque<Term> = (0..n)
.map(|i| {
new_var(
&format!("sw{}", i),
term![Op::Field(i); switch_settings_tuple.clone()],
)
})
let switch_settings_tuple =
term![Op::ExtOp(ExtOp::Waksman); term(Op::Tuple, field_tuples.clone())];
let switch_settings_terms = tuple_terms(switch_settings_tuple);
let mut switch_settings: VecDeque<Term> = switch_settings_terms
.into_iter()
.enumerate()
.map(|(i, t)| new_var(&format!("sw{}", i), t))
.collect();

let sorted_field_tuple_values: Vec<Term> =
Expand Down Expand Up @@ -69,12 +65,10 @@ pub(super) fn msh(
assertions: &mut Vec<Term>,
) -> Vec<Access> {
let f = &cfg.field;
let f_s = Sort::Field(f.clone());
// (1) sort the transcript
let field_tuples: Vec<Term> = accesses.iter().map(|a| a.to_field_tuple(cfg)).collect();
let sorted_field_tuple_values: Vec<Term> = unmake_array(
term![Op::ExtOp(ExtOp::Sort); make_array(f_s.clone(), check(&field_tuples[0]), field_tuples.clone())],
);
let sorted_field_tuple_values: Vec<Term> =
tuple_terms(term![Op::ExtOp(ExtOp::Sort); term(Op::Tuple, field_tuples.clone())]);
let mut sorted_accesses: Vec<Access> = sorted_field_tuple_values
.into_iter()
.enumerate()
Expand Down
2 changes: 1 addition & 1 deletion src/ir/opt/mem/ram/volatile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl ArrayGraph {
.collect();
while let Some(top) = stack.pop() {
if ram_terms.insert(top.clone()) {
trace!("Maybe RAM: {}", top);
trace!("Maybe RAM: {}", top.op());
for p in ps.get(&top).unwrap() {
if right_sort(p, field) {
stack.push(p.clone());
Expand Down
29 changes: 21 additions & 8 deletions src/ir/opt/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl TupleTree {
}
}
}
#[track_caller]
fn unwrap_non_tuple(self) -> Term {
match self {
TupleTree::NonTuple(t) => t,
Expand Down Expand Up @@ -245,9 +246,13 @@ fn tuple_free(t: Term) -> bool {
/// Run the tuple elimination pass.
pub fn eliminate_tuples(cs: &mut Computation) {
let mut lifted: TermMap<TupleTree> = TermMap::default();
let terms =
PostOrderIter::from_roots_and_skips(cs.outputs().iter().cloned(), Default::default());
// .chain(cs.precomputes.outputs().values().cloned()),
let terms = PostOrderIter::from_roots_and_skips(
cs.outputs()
.iter()
.cloned()
.chain(cs.precomputes.outputs().values().cloned()),
Default::default(),
);
for t in terms {
let mut cs: Vec<TupleTree> = t
.cs()
Expand All @@ -270,6 +275,14 @@ pub fn eliminate_tuples(cs: &mut Computation) {
let eqs = zip_eq(a.flatten(), b.flatten()).map(|(a, b)| term![Op::Eq; a, b]);
TupleTree::NonTuple(term(AND, eqs.collect()))
}
Op::CStore => {
let c = cs.pop().unwrap().unwrap_non_tuple();
let v = cs.pop().unwrap();
let i = cs.pop().unwrap().unwrap_non_tuple();
let a = cs.pop().unwrap();
debug_assert!(cs.is_empty());
a.bimap(|a, v| term![Op::CStore; a, i.clone(), v, c.clone()], &v)
}
Op::Store => {
let v = cs.pop().unwrap();
let i = cs.pop().unwrap().unwrap_non_tuple();
Expand Down Expand Up @@ -321,11 +334,11 @@ pub fn eliminate_tuples(cs: &mut Computation) {
.into_iter()
.flat_map(|o| lifted.get(&o).unwrap().clone().flatten())
.collect();
// let os = cs.precomputes.outputs().clone();
// for (name, old_term) in os {
// let new_term = lifted.get(&old_term).unwrap().clone().as_term();
// cs.precomputes.change_output(&name, new_term);
// }
let os = cs.precomputes.outputs().clone();
for (name, old_term) in os {
let new_term = lifted.get(&old_term).unwrap().clone().as_term();
cs.precomputes.change_output(&name, new_term);
}
#[cfg(debug_assertions)]
for o in &cs.outputs {
if let Some(t) = find_tuple_term(o.clone()) {
Expand Down
4 changes: 2 additions & 2 deletions src/ir/term/ext/haboeck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use crate::ir::term::ty::*;
use crate::ir::term::*;

/// Type-check [super::ExtOp::UniqDeriGcd].
/// Type-check [super::ExtOp::Haboeck].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
let &[haystack, needles] = ty::count_or_ref(arg_sorts)?;
let (_n, value0) = ty::homogenous_tuple_or(haystack, "haystack must be a tuple")?;
Expand All @@ -21,7 +21,7 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
Ok(haystack.clone())
}

/// Evaluate [super::ExtOp::UniqDeriGcd].
/// Evaluate [super::ExtOp::Haboeck].
pub fn eval(args: &[&Value]) -> Value {
let haystack: Vec<FieldV> = args[0]
.as_tuple()
Expand Down
38 changes: 11 additions & 27 deletions src/ir/term/ext/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,24 @@ use crate::ir::term::*;

/// Type-check [super::ExtOp::UniqDeriGcd].
pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
if let &[pairs] = arg_sorts {
let (key, value, size) = ty::array_or(pairs, "UniqDeriGcd pairs")?;
let f = pf_or(key, "UniqDeriGcd pairs: indices must be field")?;
let value_tup = ty::tuple_or(value, "UniqDeriGcd entries: value must be a tuple")?;
if let &[root, cond] = &value_tup {
eq_or(f, root, "UniqDeriGcd pairs: first element must be a field")?;
eq_or(
cond,
&Sort::Bool,
"UniqDeriGcd pairs: second element must be a bool",
)?;
let arr = Sort::new_array(f.clone(), f.clone(), size);
Ok(Sort::new_tuple(vec![arr.clone(), arr]))
} else {
// non-pair entries value
Err(TypeErrorReason::Custom(
"UniqDeriGcd: pairs value must be a pair".into(),
))
}
} else {
// wrong arg count
Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len()))
}
let [pairs] = ty::count_or_ref(arg_sorts)?;
let (size, value) = ty::homogenous_tuple_or(pairs, "UniqDeriGcd")?;
let [root, cond] = ty::count_or(ty::tuple_or(value, "UniqDeriGcd")?)?;
let f = pf_or(root, "UniqDeriGcd: first is field")?;
eq_or(cond, &Sort::Bool, "UniqDeriGcd pairs: second is bool")?;
let coeffs = Sort::new_tuple(vec![f.clone(); size]);
Ok(Sort::new_tuple(vec![coeffs.clone(), coeffs]))
}

/// Evaluate [super::ExtOp::UniqDeriGcd].
#[cfg(feature = "poly")]
pub fn eval(args: &[&Value]) -> Value {
use rug_polynomial::ModPoly;
let sort = args[0].sort().as_array().0.clone();
let sort = args[0].sort().as_tuple()[0].as_tuple()[0].clone();
let field = sort.as_pf().clone();
let mut roots: Vec<Integer> = Vec::new();
let deg = args[0].as_array().size;
for t in args[0].as_array().values() {
let deg = args[0].as_tuple().len();
for t in args[0].as_tuple() {
let tuple = t.as_tuple();
let cond = tuple[1].as_bool();
if cond {
Expand All @@ -60,7 +44,7 @@ pub fn eval(args: &[&Value]) -> Value {
let v: Vec<Value> = (0..deg)
.map(|i| Value::Field(field.new_v(s.get_coefficient(i))))
.collect();
Value::Array(Array::from_vec(sort.clone(), sort.clone(), v))
Value::Tuple(v.into())
};
let s_cs = coeff_arr(s);
let t_cs = coeff_arr(t);
Expand Down
23 changes: 16 additions & 7 deletions src/ir/term/ext/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,21 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
/// Evaluate [super::ExtOp::Sort].
pub fn eval(args: &[&Value]) -> Value {
let sort = args[0].sort();
let (key_sort, value_sort, _) = sort.as_array();
let mut values: Vec<Value> = args[0].as_array().values();
let is_array = sort.is_array();
let mut values: Vec<Value> = if is_array {
args[0].as_array().values()
} else {
args[0].as_tuple().to_vec()
};
values.sort();
Value::Array(Array::from_vec(
key_sort.clone(),
value_sort.clone(),
values,
))
if is_array {
let (key_sort, value_sort, _) = sort.as_array();
Value::Array(Array::from_vec(
key_sort.clone(),
value_sort.clone(),
values,
))
} else {
Value::Tuple(values.into())
}
}
14 changes: 7 additions & 7 deletions src/ir/term/ext/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn uniq_deri_gcd_eval() {
let t = text::parse_term(
b"
(declare (
(pairs (array (mod 17) (tuple (mod 17) bool) 5))
(pairs (tuple 5 (tuple (mod 17) bool)))
)
(uniq_deri_gcd pairs))",
);
Expand All @@ -35,7 +35,7 @@ fn uniq_deri_gcd_eval() {
(set_default_modulus 17
(let
(
(pairs (#l (mod 17) ( (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
(pairs (#t (#t #f0 false) (#t #f1 false) (#t #f2 true) (#t #f3 false) (#t #f4 true) ))
) false))
",
);
Expand All @@ -46,8 +46,8 @@ fn uniq_deri_gcd_eval() {
(let
(
(output (#t
(#l (mod 17) ( #f16 #f0 #f0 #f0 #f0 ) ) ; s, from sage
(#l (mod 17) ( #f7 #f9 #f0 #f0 #f0 ) ) ; t, from sage
(#t #f16 #f0 #f0 #f0 #f0 ) ; s, from sage
(#t #f7 #f9 #f0 #f0 #f0 ) ; t, from sage
))
) false))
",
Expand All @@ -59,7 +59,7 @@ fn uniq_deri_gcd_eval() {
(set_default_modulus 17
(let
(
(pairs (#l (mod 17) ( (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true) )))
(pairs (#t (#t #f0 true) (#t #f1 true) (#t #f2 true) (#t #f3 false) (#t #f4 true)))
) false))
",
);
Expand All @@ -70,8 +70,8 @@ fn uniq_deri_gcd_eval() {
(let
(
(output (#t
(#l (mod 17) ( #f8 #f9 #f16 #f0 #f0 ) ) ; s, from sage
(#l (mod 17) ( #f2 #f16 #f9 #f13 #f0 ) ) ; t, from sage
(#t #f8 #f9 #f16 #f0 #f0 ) ; s, from sage
(#t #f2 #f16 #f9 #f13 #f0 ) ; t, from sage
))
) false))
",
Expand Down
Loading

0 comments on commit aaab803

Please sign in to comment.