diff --git a/Cargo.lock b/Cargo.lock index 6a65675ba..e7049f1bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,13 +30,14 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.7.6" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ - "getrandom 0.2.10", + "cfg-if", "once_cell", "version_check", + "zerocopy", ] [[package]] @@ -48,6 +49,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" + [[package]] name = "anstream" version = "0.5.0" @@ -858,19 +865,14 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.12.3" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" dependencies = [ "ahash", + "allocator-api2", ] -[[package]] -name = "hashbrown" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" - [[package]] name = "heck" version = "0.4.1" @@ -925,7 +927,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown", ] [[package]] @@ -1037,11 +1039,11 @@ dependencies = [ [[package]] name = "lru" -version = "0.7.8" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a" +checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" dependencies = [ - "hashbrown 0.12.3", + "hashbrown", ] [[package]] @@ -1846,6 +1848,26 @@ version = "0.8.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bab77e97b50aee93da431f2cee7cd0f43b4d1da3c408042f2d7d164187774f0a" +[[package]] +name = "zerocopy" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + [[package]] name = "zeroize" version = "1.3.0" diff --git a/circ_hc/Cargo.toml b/circ_hc/Cargo.toml index 57c67cd42..18664ecf2 100644 --- a/circ_hc/Cargo.toml +++ b/circ_hc/Cargo.toml @@ -14,7 +14,7 @@ default = ["hashconsing", "lru", "raw", "rc"] fxhash = "0.2.1" hashconsing = { git = "https://github.com/alex-ozdemir/hashconsing.git", branch = "phash", optional = true } log = "0.4" -lru = { version = "0.7.2", optional = true } +lru = { version = "0.12", optional = true } [dev-dependencies] quickcheck = "1" diff --git a/circ_hc/src/collections/lru.rs b/circ_hc/src/collections/lru.rs index ee2f93410..1d4b5fce7 100644 --- a/circ_hc/src/collections/lru.rs +++ b/circ_hc/src/collections/lru.rs @@ -1,5 +1,7 @@ //! A LRU cache from terms to values which does not retain its keys. +use std::num::NonZero; + use crate::Table; /// A LRU cache from terms to values which does not retain its keys. @@ -11,7 +13,7 @@ impl<Op, T: Table<Op>, V> NodeLruCache<Op, T, V> { /// Create an empty cache with room for `n` items. pub fn with_capacity(n: usize) -> Self { Self { - inner: lru::LruCache::new(n), + inner: lru::LruCache::new(NonZero::new(n).unwrap()), } } } diff --git a/examples/circ.rs b/examples/circ.rs index b062528b8..05fb38e36 100644 --- a/examples/circ.rs +++ b/examples/circ.rs @@ -404,7 +404,7 @@ fn main() { .ordered_inputs() .iter() .map(|term| match term.op() { - Op::Var(n, s) => (n.clone(), s.clone()), + Op::Var(v) => (v.name.to_string(), v.sort.clone()), _ => unreachable!(), }) .collect(); diff --git a/examples/opa_bench.rs b/examples/opa_bench.rs index ffe07b5b3..c4fbd56c2 100644 --- a/examples/opa_bench.rs +++ b/examples/opa_bench.rs @@ -20,7 +20,7 @@ fn main() { .format_timestamp(None) .init(); let options = Options::parse(); - let v = leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))); + let v = var("a".to_owned(), Sort::BitVector(32)); let mut t = v.clone(); for _i in 0..options.n_mults { t = term![BV_MUL; t.clone(), t.clone()]; diff --git a/src/circify/mem.rs b/src/circify/mem.rs index 8d14b0315..900b0da22 100644 --- a/src/circify/mem.rs +++ b/src/circify/mem.rs @@ -55,11 +55,11 @@ impl MemManager { /// Allocate a new stack array, equal to `array`. pub fn allocate(&mut self, array: Term) -> AllocId { let s = check(&array); - if let Sort::Array(box_addr_width, box_val_width, size) = s { - if let Sort::BitVector(addr_width) = *box_addr_width { - if let Sort::BitVector(val_width) = *box_val_width { + if let Sort::Array(a) = s { + if let Sort::BitVector(addr_width) = &a.key { + if let Sort::BitVector(val_width) = &a.val { let id = self.take_next_id(); - let alloc = Alloc::new(addr_width, val_width, size, array); + let alloc = Alloc::new(*addr_width, *val_width, a.size, array); self.allocs.insert(id, alloc); id } else { @@ -85,11 +85,11 @@ impl MemManager { /// /// Returns a (concrete) allocation identifier which can be used to access this allocation. pub fn zero_allocate(&mut self, size: usize, addr_width: usize, val_width: usize) -> AllocId { - self.allocate(term![Op::Const(Value::Array(Array::default( + self.allocate(const_(Value::Array(Array::default( Sort::BitVector(addr_width), &Sort::BitVector(val_width), - size - )))]) + size, + )))) } /// Load the value of index `offset` from the allocation `id`. @@ -130,7 +130,7 @@ impl MemManager { } } -#[cfg(all(feature = "smt", feature = "test", feature = "zok"))] +#[cfg(all(feature = "smt", test, feature = "zok"))] mod test { use super::*; use crate::target::smt::check_sat; @@ -138,20 +138,16 @@ mod test { use std::rc::Rc; fn bv_var(s: &str, w: usize) -> Term { - leaf_term(Op::Var(s.to_owned(), Sort::BitVector(w))) + var(s.to_owned(), Sort::BitVector(w)) } + #[test] fn sat_test() { - let cs = Rc::new(RefCell::new(Computation::new(false))); + let cs = Rc::new(RefCell::new(Computation::new())); let mut mem = MemManager::default(); let id0 = mem.zero_allocate(6, 4, 8); let _id1 = mem.zero_allocate(6, 4, 8); - mem.store( - id0, - bv_lit(3, 4), - bv_lit(2, 8), - leaf_term(Op::Const(Value::Bool(true))), - ); + mem.store(id0, bv_lit(3, 4), bv_lit(2, 8), const_(Value::Bool(true))); let a = mem.load(id0, bv_lit(3, 4)); let b = mem.load(id0, bv_lit(1, 4)); let t = term![Op::BvBinPred(BvBinPred::Ugt); a, b]; @@ -163,17 +159,13 @@ mod test { assert!(check_sat(&sys)) } + #[test] fn unsat_test() { - let cs = Rc::new(RefCell::new(Computation::new(false))); + let cs = Rc::new(RefCell::new(Computation::new())); let mut mem = MemManager::default(); let id0 = mem.zero_allocate(6, 4, 8); let _id1 = mem.zero_allocate(6, 4, 8); - mem.store( - id0, - bv_lit(3, 4), - bv_var("a", 8), - leaf_term(Op::Const(Value::Bool(true))), - ); + mem.store(id0, bv_lit(3, 4), bv_var("a", 8), const_(Value::Bool(true))); let a = mem.load(id0, bv_lit(3, 4)); let b = mem.load(id0, bv_lit(3, 4)); let t = term![Op::Not; term![Op::Eq; a, b]]; diff --git a/src/circify/mod.rs b/src/circify/mod.rs index b3e7d9cdb..19d9e09dd 100644 --- a/src/circify/mod.rs +++ b/src/circify/mod.rs @@ -301,7 +301,7 @@ impl<Ty: Display> FnFrame<Ty> { StateEntry::Break(name_, ref mut break_conds) => { if name_ == name { break_conds.push(if break_if.is_empty() { - leaf_term(Op::Const(Value::Bool(true))) + bool_lit(true) } else { term(Op::BoolNaryOp(BoolNaryOp::And), break_if) }); @@ -437,7 +437,7 @@ impl<E: Embeddable> Circify<E> { mem: Rc::new(RefCell::new(mem::MemManager::default())), cs, }, - condition: leaf_term(Op::Const(Value::Bool(true))), + condition: bool_lit(true), typedefs: HashMap::default(), } } @@ -677,7 +677,7 @@ impl<E: Embeddable> Circify<E> { // TODO: more precise conditions, depending on lex scopes. let cs: Vec<_> = self.fn_stack.iter().flat_map(|f| f.conditions()).collect(); if cs.is_empty() { - leaf_term(Op::Const(Value::Bool(true))) + bool_lit(true) } else { term(Op::BoolNaryOp(BoolNaryOp::And), cs) } @@ -935,7 +935,7 @@ mod test { impl Ty { fn default(&self) -> T { match self { - Ty::Bool => T::Base(leaf_term(Op::Const(Value::Bool(false)))), + Ty::Bool => T::Base(bool_lit(false)), Ty::Pair(a, b) => T::Pair(Box::new(a.default()), Box::new(b.default())), } } diff --git a/src/front/c/mod.rs b/src/front/c/mod.rs index 9c54215ee..8f4bb485b 100644 --- a/src/front/c/mod.rs +++ b/src/front/c/mod.rs @@ -54,10 +54,11 @@ impl FrontEnd for C { let main_comp = g.circify().consume().borrow().clone(); cs.comps.insert("main".to_string(), main_comp); while let Some(call_term) = g.function_queue.pop() { - if let Op::Call(name, arg_sorts, rets) = call_term.op() { - g.fn_call(name, arg_sorts, rets); + if let Op::Call(call) = call_term.op() { + let name = call.name.to_string(); + g.fn_call(&name, &call.arg_sorts, &call.ret_sort); let comp = g.circify().consume().borrow().clone(); - cs.comps.insert(name.to_string(), comp); + cs.comps.insert(name, comp); } else { panic!("Non-call term added to function queue."); } @@ -1265,11 +1266,11 @@ impl CGen { assert!(p_sort == arg_sorts[i]); let p_ty = match ¶m.ty { Ty::Ptr(_, t) => { - if let Sort::Array(_, _, len) = p_sort { - let dims = vec![len]; + if let Sort::Array(a) = p_sort { + let dims = vec![a.size]; // Add reference ret_names.push(p_name.clone()); - Ty::Array(len, dims, t.clone()) + Ty::Array(a.size, dims, t.clone()) } else { panic!("Ptr type does not match with Array sort: {}", p_sort) } diff --git a/src/front/c/term.rs b/src/front/c/term.rs index d616816f4..2424593b3 100644 --- a/src/front/c/term.rs +++ b/src/front/c/term.rs @@ -499,8 +499,8 @@ pub fn uge(a: CTerm, b: CTerm) -> Result<CTerm, String> { pub fn const_int(a: CTerm) -> Integer { let s = match &a.term { - CTermData::Int(s, _, i) => match &i.op() { - Op::Const(Value::BitVector(f)) => { + CTermData::Int(s, _, i) => match i.as_value_opt() { + Some(Value::BitVector(f)) => { if *s { f.as_sint() } else { @@ -622,7 +622,7 @@ impl Embeddable for Ct { }; for (i, t) in v.iter().enumerate() { let val = t.term.term(ctx); - let t_term = leaf_term(Op::Const(Value::Bool(true))); + let t_term = bool_lit(true); mem.store(id, bv_lit(i, 32), val, t_term); } arr diff --git a/src/front/c/types.rs b/src/front/c/types.rs index a806a31b9..0694f064e 100644 --- a/src/front/c/types.rs +++ b/src/front/c/types.rs @@ -82,9 +82,7 @@ impl Ty { Self::Void => Sort::Bool, Self::Bool => Sort::Bool, Self::Int(_s, w) => Sort::BitVector(*w), - Self::Array(n, _, b) => { - Sort::Array(Box::new(Sort::BitVector(32)), Box::new(b.sort()), *n) - } + Self::Array(n, _, b) => Sort::new_array(Sort::BitVector(32), b.sort(), *n), Self::Struct(_name, fs) => { Sort::Tuple(fs.fields().map(|(_f_name, f_ty)| f_ty.sort()).collect()) } diff --git a/src/front/datalog/term.rs b/src/front/datalog/term.rs index b7226016c..92509cd10 100644 --- a/src/front/datalog/term.rs +++ b/src/front/datalog/term.rs @@ -33,7 +33,7 @@ impl T { match (ty, ir) { (Ty::Bool, Sort::Bool) | (Ty::Field, Sort::Field(_)) => {} (Ty::Uint(w), Sort::BitVector(w2)) if *w as usize == *w2 => {} - (Ty::Array(l, t), Sort::Array(_, t2, l2)) if l == l2 => Self::check_ty(t2, t), + (Ty::Array(l, t), Sort::Array(a)) if *l == a.size => Self::check_ty(&a.val, t), _ => panic!("IR sort {} doesn't match datalog type {}", ir, ty), } } @@ -63,12 +63,12 @@ pub fn pf_ir_lit<I>(i: I) -> Term where Integer: From<I>, { - leaf_term(Op::Const(Value::Field(cfg().field().new_v(i)))) + const_(Value::Field(cfg().field().new_v(i))) } /// Initialize a boolean literal pub fn bool_lit(b: bool) -> T { - T::new(leaf_term(Op::Const(Value::Bool(b))), Ty::Bool) + T::new(const_(Value::Bool(b)), Ty::Bool) } /// Initialize an unsigned integer literal @@ -82,11 +82,7 @@ impl Ty { Self::Bool => Sort::Bool, Self::Uint(w) => Sort::BitVector(*w as usize), Self::Field => Sort::Field(cfg().field().clone()), - Self::Array(n, b) => Sort::Array( - Box::new(Sort::Field(cfg().field().clone())), - Box::new(b.sort()), - *n, - ), + Self::Array(n, b) => Sort::new_array(Sort::Field(cfg().field().clone()), b.sort(), *n), } } fn default_ir_term(&self) -> Term { @@ -318,7 +314,7 @@ pub fn or(s: &T, t: &T) -> Result<T> { pub fn uint_to_field(s: &T) -> Result<T> { match &s.ty { Ty::Uint(_) => Ok(T::new( - term![Op::UbvToPf(cfg().field().clone()); s.ir.clone()], + term![Op::new_ubv_to_pf(cfg().field().clone()); s.ir.clone()], Ty::Field, )), _ => Err(ErrorKind::InvalidUnOp("to_field".into(), s.clone())), diff --git a/src/front/zsharp/interp.rs b/src/front/zsharp/interp.rs index fd2d0da97..fd5bf70f4 100644 --- a/src/front/zsharp/interp.rs +++ b/src/front/zsharp/interp.rs @@ -21,7 +21,7 @@ pub fn extract( let ir_val = scalar_input_values .remove(name) .ok_or_else(|| format!("Could not find scalar variable {name} in the input map"))?; - Ok(T::new(ty.clone(), leaf_term(Op::Const(ir_val)))) + Ok(T::new(ty.clone(), const_(ir_val))) } Ty::Array(elem_count, elem_ty) => T::new_array( (0..*elem_count) diff --git a/src/front/zsharp/mod.rs b/src/front/zsharp/mod.rs index e2c33fa36..38eb719ee 100644 --- a/src/front/zsharp/mod.rs +++ b/src/front/zsharp/mod.rs @@ -405,7 +405,7 @@ impl<'ast> ZGen<'ast> { let map = term![Op::ExtOp(ExtOp::ArrayToMap); array.term.clone()]; let flip = term![Op::ExtOp(ExtOp::MapFlip); map]; let key = term![Op::ExtOp(ExtOp::MapSelect); flip.clone(), value.term.clone()]; - let key_witness = term![Op::Witness("rlook".into()); key]; + let key_witness = term![Op::new_witness("rlook".into()); key]; if !self.in_witness_gen.get() { let eq_lookup = term![EQ; value.term, term![Op::Select; array.term, key_witness.clone()]]; self.assert(eq_lookup)?; @@ -1384,7 +1384,7 @@ impl<'ast> ZGen<'ast> { )); } let mut e = wit_e; - e.term = term![Op::Witness("wit".into()); e.term]; + e.term = term![Op::new_witness("wit".into()); e.term]; self.declare_init_impl_::<IS_CNST>(d.id.value.clone(), decl_ty, e)?; Ok(()) } diff --git a/src/front/zsharp/term.rs b/src/front/zsharp/term.rs index 59572ab4d..3ac4f3050 100644 --- a/src/front/zsharp/term.rs +++ b/src/front/zsharp/term.rs @@ -68,14 +68,8 @@ impl Ty { Self::Bool => Sort::Bool, Self::Uint(w) => Sort::BitVector(*w), Self::Field => default_field_sort(), - Self::Array(n, b) => { - Sort::Array(Box::new(default_field_sort()), Box::new(b.sort()), *n) - } - Self::MutArray(n) => Sort::Array( - Box::new(default_field_sort()), - Box::new(default_field_sort()), - *n, - ), + Self::Array(n, b) => Sort::new_array(default_field_sort(), b.sort(), *n), + Self::MutArray(n) => Sort::new_array(default_field_sort(), default_field_sort(), *n), Self::Struct(_name, fs) => { Sort::Tuple(fs.fields().map(|(_f_name, f_ty)| f_ty.sort()).collect()) } @@ -127,8 +121,8 @@ impl T { fn terms_tail(term: &Term, output: &mut Vec<Term>) { match check(term) { Sort::Bool | Sort::BitVector(_) | Sort::Field(_) => output.push(term.clone()), - Sort::Array(_k, _v, size) => { - for i in 0..size { + Sort::Array(a) => { + for i in 0..a.size { terms_tail(&term![Op::Select; term.clone(), pf_lit_ir(i)], output) } } @@ -234,7 +228,7 @@ impl T { Op::Const(v) => Ok(v), _ => Err(Error::new(ErrorKind::Other, "not a const val")), }?; - match val { + match &**val { Value::Bool(b) => write!(f, "{b}"), Value::Field(fe) => write!(f, "{}f", fe.i()), Value::BitVector(bv) => match bv.width() { @@ -256,7 +250,7 @@ impl T { write!(f, "{n} {{ ")?; fl.fields().zip(vs.iter()).try_for_each(|((n, ty), v)| { write!(f, "{n}: ")?; - T::new(ty.clone(), leaf_term(Op::Const(v.clone()))).pretty(f)?; + T::new(ty.clone(), const_(v.clone())).pretty(f)?; write!(f, ", ") })?; write!(f, "}}") @@ -277,7 +271,7 @@ impl T { .try_for_each(|idx| { T::new( *inner_ty.clone(), - leaf_term(Op::Const(arr.select(idx.as_value_opt().unwrap()))), + const_(arr.select(idx.as_value_opt().unwrap())), ) .pretty(f)?; write!(f, ", ") @@ -387,11 +381,15 @@ pub fn div(a: T, b: T) -> Result<T, String> { wrap_bin_op("/", Some(div_uint), Some(div_field), None, a, b) } +fn to_dflt_f(t: Term) -> Term { + term![Op::new_ubv_to_pf(default_field()); t] +} + fn rem_field(a: Term, b: Term) -> Term { let len = cfg().field().modulus().significant_bits() as usize; let a_bv = term![Op::PfToBv(len); a]; let b_bv = term![Op::PfToBv(len); b]; - term![Op::UbvToPf(default_field()); term![Op::BvBinOp(BvBinOp::Urem); a_bv, b_bv]] + to_dflt_f(term![Op::BvBinOp(BvBinOp::Urem); a_bv, b_bv]) } fn rem_uint(a: Term, b: Term) -> Term { @@ -600,7 +598,7 @@ pub fn const_bool(a: T) -> Option<bool> { pub fn const_val(a: T) -> Result<T, String> { match const_value(&a.term) { - Some(v) => Ok(T::new(a.ty, leaf_term(Op::Const(v)))), + Some(v) => Ok(T::new(a.ty, const_(v))), _ => Err(format!("{} is not a constant value", &a)), } } @@ -608,7 +606,7 @@ pub fn const_val(a: T) -> Result<T, String> { fn const_value(t: &Term) -> Option<Value> { let folded = constant_fold(t, &[]); match &folded.op() { - Op::Const(v) => Some(v.clone()), + Op::Const(v) => Some((**v).clone()), _ => None, } } @@ -652,7 +650,7 @@ pub fn pf_lit_ir<I>(i: I) -> Term where Integer: From<I>, { - leaf_term(Op::Const(pf_val(i))) + const_(pf_val(i)) } fn pf_val<I>(i: I) -> Value @@ -670,7 +668,7 @@ where } pub fn z_bool_lit(v: bool) -> T { - T::new(Ty::Bool, leaf_term(Op::Const(Value::Bool(v)))) + T::new(Ty::Bool, bool_lit(v)) } pub fn uint_lit<I>(v: I, bits: usize) -> T @@ -737,7 +735,7 @@ pub fn field_store(struct_: T, field: &str, val: T) -> Result<T, String> { fn coerce_to_field(i: T) -> Result<Term, String> { match &i.ty { - Ty::Uint(_) => Ok(term![Op::UbvToPf(default_field()); i.term]), + Ty::Uint(_) => Ok(to_dflt_f(i.term)), Ty::Field => Ok(i.term), _ => Err(format!("Cannot coerce {} to a field element", &i)), } @@ -772,7 +770,7 @@ pub fn array_store(array: T, idx: T, val: T) -> Result<T, String> { if matches!(&array.ty, Ty::Array(_, _)) && matches!(&idx.ty, Ty::Uint(_) | Ty::Field) { // XXX(q) typecheck here? let iterm = if matches!(idx.ty, Ty::Uint(_)) { - term![Op::UbvToPf(default_field()); idx.term] + to_dflt_f(idx.term) } else { idx.term }; @@ -787,13 +785,19 @@ pub fn array_store(array: T, idx: T, val: T) -> Result<T, String> { fn ir_array<I: IntoIterator<Item = Term>>(value_sort: Sort, elems: I) -> Term { let key_sort = Sort::Field(cfg().field().clone()); - term(Op::Array(key_sort, value_sort), elems.into_iter().collect()) + term( + Op::Array(Box::new(ArrayOp { + key: key_sort, + val: value_sort, + })), + elems.into_iter().collect(), + ) } pub fn fill_array(value: T, size: usize) -> Result<T, String> { Ok(T::new( Ty::Array(size, Box::new(value.ty)), - term![Op::Fill(default_field_sort(), size); value.term], + term![Op::new_fill(default_field_sort(), size); value.term], )) } pub fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> { @@ -816,10 +820,7 @@ pub fn array<I: IntoIterator<Item = T>>(elems: I) -> Result<T, String> { pub fn uint_to_field(u: T) -> Result<T, String> { match &u.ty { - Ty::Uint(_) => Ok(T::new( - Ty::Field, - term![Op::UbvToPf(default_field()); u.term], - )), + Ty::Uint(_) => Ok(T::new(Ty::Field, to_dflt_f(u.term))), u => Err(format!("Cannot do uint-to-field on {u}")), } } @@ -915,7 +916,7 @@ pub fn sample_challenge(a: T, number: usize) -> Result<T, String> { Ok(T::new( Ty::Field, term( - Op::PfChallenge(format!("zx_chall_{number}"), default_field()), + Op::new_chall(format!("zx_chall_{number}"), default_field()), a.unwrap_array_ir()?, ), )) diff --git a/src/front/zsharp/zvisit/zgenericinf.rs b/src/front/zsharp/zvisit/zgenericinf.rs index b6e7ee41c..2c694e5d6 100644 --- a/src/front/zsharp/zvisit/zgenericinf.rs +++ b/src/front/zsharp/zvisit/zgenericinf.rs @@ -2,7 +2,7 @@ use super::super::term::{cond, const_val, Ty, T}; use super::super::{span_to_string, ZGen}; -use crate::ir::term::{bv_lit, leaf_term, term, BoolNaryOp, Op, Sort, Term, Value}; +use crate::ir::term::{bv_lit, const_, term, var, BoolNaryOp, Op, Sort, Term, Value}; #[cfg(feature = "smt")] use crate::target::smt::find_unique_model; @@ -184,7 +184,7 @@ impl<'ast, 'gen, const IS_CNST: bool> ZGenericInf<'ast, 'gen, IS_CNST> { g_name.truncate(self.gens[idx].value.len()); g_name.shrink_to_fit(); assert!(res - .insert(g_name, T::new(Ty::Uint(32), term![Op::Const(g_val)])) + .insert(g_name, T::new(Ty::Uint(32), const_(g_val))) .is_none()); } }); @@ -469,5 +469,5 @@ fn make_varname_str(id: &str, sfx: &str) -> String { fn make_varname(id: &str, sfx: &str) -> Term { let tmp = make_varname_str(id, sfx); - term![Op::Var(tmp, Sort::BitVector(32))] + var(tmp, Sort::BitVector(32)) } diff --git a/src/ir/opt/binarize.rs b/src/ir/opt/binarize.rs index 0ff34fbcb..65d8fb9f9 100644 --- a/src/ir/opt/binarize.rs +++ b/src/ir/opt/binarize.rs @@ -61,7 +61,7 @@ mod test { use quickcheck_macros::quickcheck; fn bool(b: bool) -> Term { - leaf_term(Op::Const(Value::Bool(b))) + bool_lit(b) } fn is_binary(t: Term) -> bool { @@ -84,7 +84,7 @@ mod test { #[test] fn simple_bool() { - for o in vec![AND, OR, XOR] { + for o in [AND, OR, XOR] { let t = term![o.clone(); bool(true), term![o.clone(); bool(false), bool(true)]]; let tt = term![o.clone(); bool(true), bool(false), bool(true)]; assert_eq!(t, binarize_nary_ops(tt)); @@ -93,7 +93,7 @@ mod test { #[test] fn simple_bv() { - for o in vec![BV_AND, BV_OR, BV_XOR, BV_ADD, BV_MUL] { + for o in [BV_AND, BV_OR, BV_XOR, BV_ADD, BV_MUL] { let t = term![o.clone(); bv_lit(3,5), term![o.clone(); bv_lit(3,5), bv_lit(3,5)]]; let tt = term![o.clone(); bv_lit(3, 5), bv_lit(3, 5), bv_lit(3, 5)]; assert_eq!(t, binarize_nary_ops(tt)); diff --git a/src/ir/opt/cfold.rs b/src/ir/opt/cfold.rs index 5cf5a94e0..5de56134f 100644 --- a/src/ir/opt/cfold.rs +++ b/src/ir/opt/cfold.rs @@ -9,6 +9,7 @@ use itertools::Itertools; use rug::Integer; use std::cell::RefCell; use std::cmp::Ordering; +use std::num::NonZero; thread_local! { static FOLDS: RefCell<TermCache<TTerm>> = RefCell::new(TermCache::with_capacity(TERM_CACHE_LIMIT)); @@ -31,12 +32,12 @@ pub(in super::super) fn collect() { /// Create a constant boolean fn cbool(b: bool) -> Option<Term> { - Some(leaf_term(Op::Const(Value::Bool(b)))) + Some(bool_lit(b)) } /// Create a constant bit-vector fn cbv(b: BitVector) -> Option<Term> { - Some(leaf_term(Op::Const(Value::BitVector(b)))) + Some(const_(Value::BitVector(b))) } /// Fold away operators over constants. @@ -46,7 +47,7 @@ pub fn fold(node: &Term, ignore: &[Op]) -> Term { // make the cache unbounded during the fold_cache call let old_capacity = cache.cap(); - cache.resize(usize::MAX); + cache.resize(NonZero::new(usize::MAX).unwrap()); let ret = fold_cache(node, &mut cache, ignore); // shrink cache to its max size @@ -94,11 +95,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T let args: Vec<Term> = t.cs().iter().map(&mut c_get).collect(); let const_args: Vec<&Value> = args.iter().filter_map(|a| a.as_value_opt()).collect(); let new_t_opt = if const_args.len() == args.len() && !t.is_var() { - Some(leaf_term(Op::Const(eval_op( - t.op(), - &const_args, - &Default::default(), - )))) + Some(const_(eval_op(t.op(), &const_args, &Default::default()))) } else { let mut get = |i: usize| c_get(&t.cs()[i]); match &t.op() { @@ -226,8 +223,8 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T assert!(b.uint() < &Integer::from(b.width())); let n = b.uint().to_usize().unwrap(); Some(term![BV_CONCAT; - term![Op::BvExtract(b.width()-n-1, 0); c0], - leaf_term(Op::Const(Value::BitVector(BitVector::zeros(n)))) + term![Op::new_bv_extract(b.width()-n-1, 0); c0], + const_(Value::BitVector(BitVector::zeros(n))) ]) } (Ashr, Some(a), Some(b)) => cbv(a.clone().ashr(b)), @@ -235,14 +232,14 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T assert!(b.uint() < &Integer::from(b.width())); let n = b.uint().to_usize().unwrap(); Some(term![Op::BvSext(n); - term![Op::BvExtract(b.width()-1, n); c0]]) + term![Op::new_bv_extract(b.width()-1, n); c0]]) } (Lshr, Some(a), Some(b)) => cbv(a.clone().lshr(b)), (Lshr, _, Some(b)) => { assert!(b.uint() < &Integer::from(b.width())); let n = b.uint().to_usize().unwrap(); Some(term![Op::BvUext(n); - term![Op::BvExtract(b.width()-1, n); c0]]) + term![Op::new_bv_extract(b.width()-1, n); c0]]) } _ => None, } @@ -250,7 +247,7 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T Op::BvNaryOp(o) => Some(o.flatten(t.cs().iter().map(c_get))), Op::BvBinPred(p) => { if let (Some(a), Some(b)) = (get(0).as_bv_opt(), get(1).as_bv_opt()) { - Some(leaf_term(Op::Const(Value::Bool(match p { + Some(bool_lit(match p { BvBinPred::Uge => a.uint() >= b.uint(), BvBinPred::Ugt => a.uint() > b.uint(), BvBinPred::Ule => a.uint() <= b.uint(), @@ -259,16 +256,16 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T BvBinPred::Sgt => a.as_sint() > b.as_sint(), BvBinPred::Sle => a.as_sint() <= b.as_sint(), BvBinPred::Slt => a.as_sint() < b.as_sint(), - })))) + })) } else { None } } Op::BvUnOp(o) => get(0).as_bv_opt().map(|bv| { - leaf_term(Op::Const(Value::BitVector(match o { + const_(Value::BitVector(match o { BvUnOp::Not => !bv.clone(), BvUnOp::Neg => -bv.clone(), - }))) + })) }), Op::Ite => { let c = get(0); @@ -300,27 +297,25 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T } Op::PfNaryOp(o) => Some(o.flatten(t.cs().iter().map(c_get))), Op::PfUnOp(o) => get(0).as_pf_opt().map(|pf| { - leaf_term(Op::Const(Value::Field(match o { + pf_lit(match o { PfUnOp::Recip => pf.clone().recip(), PfUnOp::Neg => -pf.clone(), - }))) + }) }), Op::IntNaryOp(o) => Some(o.flatten(t.cs().iter().map(c_get))), Op::IntBinPred(p) => { if let (Some(a), Some(b)) = (get(0).as_bv_opt(), get(1).as_bv_opt()) { - Some(leaf_term(Op::Const(Value::Bool(match p { + Some(bool_lit(match p { IntBinPred::Ge => a >= b, IntBinPred::Gt => a > b, IntBinPred::Le => a <= b, IntBinPred::Lt => a < b, - })))) + })) } else { None } } - Op::UbvToPf(fty) => get(0) - .as_bv_opt() - .map(|bv| leaf_term(Op::Const(Value::Field(fty.new_v(bv.uint()))))), + Op::UbvToPf(fty) => get(0).as_bv_opt().map(|bv| pf_lit(fty.new_v(bv.uint()))), Op::Store => { match ( get(0).as_array_opt(), @@ -329,33 +324,33 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T ) { (Some(arr), Some(idx), Some(val)) => { let new_arr = arr.clone().store(idx.clone(), val.clone()); - Some(leaf_term(Op::Const(Value::Array(new_arr)))) + Some(const_(Value::Array(new_arr))) } _ => None, } } - Op::Array(k, v) => t + Op::Array(a) => t .cs() .iter() .map(|c| c_get(c).as_value_opt().cloned()) .collect::<Option<_>>() .map(|cs| { - leaf_term(Op::Const(Value::Array(Array::from_vec( - k.clone(), - v.clone(), + const_(Value::Array(Array::from_vec( + a.key.clone(), + a.val.clone(), cs, - )))) + ))) }), - Op::Fill(k, s) => c_get(&t.cs()[0]).as_value_opt().map(|v| { - leaf_term(Op::Const(Value::Array(Array::new( - k.clone(), + Op::Fill(f) => c_get(&t.cs()[0]).as_value_opt().map(|v| { + const_(Value::Array(Array::new( + f.key_sort.clone(), Box::new(v.clone()), Default::default(), - *s, - )))) + f.size, + ))) }), Op::Select => match (get(0).as_array_opt(), get(1).as_value_opt()) { - (Some(arr), Some(idx)) => Some(leaf_term(Op::Const(arr.select(idx)))), + (Some(arr), Some(idx)) => Some(const_(arr.select(idx))), _ => None, }, Op::Tuple => t @@ -363,16 +358,14 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T .iter() .map(|c| c_get(c).as_value_opt().cloned()) .collect::<Option<_>>() - .map(|v| leaf_term(Op::Const(Value::Tuple(v)))), - Op::Field(n) => get(0) - .as_tuple_opt() - .map(|t| leaf_term(Op::Const(t[*n].clone()))), + .map(|v| const_(Value::Tuple(v))), + Op::Field(n) => get(0).as_tuple_opt().map(|t| const_(t[*n].clone())), Op::Update(n) => match (get(0).as_tuple_opt(), get(1).as_value_opt()) { (Some(t), Some(v)) => { let mut new_vec = Vec::from(t).into_boxed_slice(); assert_eq!(new_vec[*n].sort(), v.sort()); new_vec[*n] = v.clone(); - Some(leaf_term(Op::Const(Value::Tuple(new_vec)))) + Some(const_(Value::Tuple(new_vec))) } _ => None, }, @@ -382,18 +375,15 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<TTerm>, ignore: &[Op]) -> T .map(|c| c_get(c).as_bv_opt().cloned()) .collect::<Option<Vec<_>>>() .and_then(|v| v.into_iter().reduce(BitVector::concat)) - .map(|bv| leaf_term(Op::Const(Value::BitVector(bv)))), - Op::BoolToBv => get(0).as_bool_opt().map(|b| { - leaf_term(Op::Const(Value::BitVector(BitVector::new( - Integer::from(b), - 1, - )))) - }), + .map(|bv| const_(Value::BitVector(bv))), + Op::BoolToBv => get(0) + .as_bool_opt() + .map(|b| const_(Value::BitVector(BitVector::new(Integer::from(b), 1)))), Op::BvUext(w) => get(0).as_bv_opt().map(|b| { - leaf_term(Op::Const(Value::BitVector(BitVector::new( + const_(Value::BitVector(BitVector::new( b.uint().clone(), b.width() + w, - )))) + ))) }), _ => None, } @@ -443,8 +433,8 @@ trait NaryFlat<T: Clone>: Sized { impl NaryFlat<bool> for BoolNaryOp { fn as_const(t: Term) -> Result<bool, Term> { - match t.op() { - Op::Const(Value::Bool(b)) => Ok(*b), + match t.as_value_opt() { + Some(Value::Bool(b)) => Ok(*b), _ => Err(t), } } @@ -452,18 +442,18 @@ impl NaryFlat<bool> for BoolNaryOp { match self { BoolNaryOp::Or => { if consts.iter().any(|b| *b) { - leaf_term(Op::Const(Value::Bool(true))) + bool_lit(true) } else if children.is_empty() { - leaf_term(Op::Const(Value::Bool(false))) + bool_lit(false) } else { safe_nary(OR, children) } } BoolNaryOp::And => { if consts.iter().any(|b| !*b) { - leaf_term(Op::Const(Value::Bool(false))) + bool_lit(false) } else if children.is_empty() { - leaf_term(Op::Const(Value::Bool(true))) + bool_lit(true) } else { safe_nary(AND, children) } @@ -471,7 +461,7 @@ impl NaryFlat<bool> for BoolNaryOp { BoolNaryOp::Xor => { let odd_trues = consts.into_iter().filter(|b| *b).count() % 2 == 1; if children.is_empty() { - leaf_term(Op::Const(Value::Bool(odd_trues))) + bool_lit(odd_trues) } else { let t = safe_nary(XOR, children); if odd_trues { @@ -487,8 +477,8 @@ impl NaryFlat<bool> for BoolNaryOp { impl NaryFlat<BitVector> for BvNaryOp { fn as_const(t: Term) -> Result<BitVector, Term> { - match &t.op() { - Op::Const(Value::BitVector(b)) => Ok(b.clone()), + match t.as_value_opt() { + Some(Value::BitVector(b)) => Ok(b.clone()), _ => Err(t), } } @@ -498,7 +488,7 @@ impl NaryFlat<BitVector> for BvNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::BitOr::bitor); if children.is_empty() { - leaf_term(Op::Const(Value::BitVector(c))) + const_(Value::BitVector(c)) } else if c.uint() == &Integer::from(0) { safe_nary(BV_OR, children) } else { @@ -507,7 +497,7 @@ impl NaryFlat<BitVector> for BvNaryOp { (0..c.width()) .map(|i| { term![Op::BoolToBv; if c.bit(i) { - leaf_term(Op::Const(Value::Bool(true))) + bool_lit(true) } else { safe_nary( OR, @@ -531,7 +521,7 @@ impl NaryFlat<BitVector> for BvNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::BitAnd::bitand); if children.is_empty() { - leaf_term(Op::Const(Value::BitVector(c))) + const_(Value::BitVector(c)) } else { safe_nary( BV_CONCAT, @@ -547,7 +537,7 @@ impl NaryFlat<BitVector> for BvNaryOp { .collect(), ) } else { - leaf_term(Op::Const(Value::Bool(false))) + bool_lit(false) }] }) .rev() @@ -562,7 +552,7 @@ impl NaryFlat<BitVector> for BvNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::BitXor::bitxor); if children.is_empty() { - leaf_term(Op::Const(Value::BitVector(c))) + const_(Value::BitVector(c)) } else { safe_nary( BV_CONCAT, @@ -594,7 +584,7 @@ impl NaryFlat<BitVector> for BvNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::Add::add); if c.uint() != &Integer::from(0) || children.is_empty() { - children.push(leaf_term(Op::Const(Value::BitVector(c)))); + children.push(const_(Value::BitVector(c))); } } safe_nary(BV_ADD, children) @@ -603,10 +593,10 @@ impl NaryFlat<BitVector> for BvNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::Mul::mul); if c.uint() == &Integer::from(0) { - leaf_term(Op::Const(Value::BitVector(c))) + const_(Value::BitVector(c)) } else { if c.uint() != &Integer::from(1) || children.is_empty() { - children.push(leaf_term(Op::Const(Value::BitVector(c)))); + children.push(const_(Value::BitVector(c))); } safe_nary(BV_MUL, children) } @@ -620,8 +610,8 @@ impl NaryFlat<BitVector> for BvNaryOp { impl NaryFlat<FieldV> for PfNaryOp { fn as_const(t: Term) -> Result<FieldV, Term> { - match &t.op() { - Op::Const(Value::Field(b)) => Ok(b.clone()), + match t.as_value_opt() { + Some(Value::Field(b)) => Ok(b.clone()), _ => Err(t), } } @@ -631,7 +621,7 @@ impl NaryFlat<FieldV> for PfNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::Add::add); if !c.is_zero() || children.is_empty() { - children.push(leaf_term(Op::Const(Value::Field(c)))); + children.push(pf_lit(c)); } } safe_nary(PF_ADD, children) @@ -640,10 +630,10 @@ impl NaryFlat<FieldV> for PfNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::Mul::mul); if c.is_zero() || children.is_empty() { - leaf_term(Op::Const(Value::Field(c))) + pf_lit(c) } else { if !c.is_one() { - children.push(leaf_term(Op::Const(Value::Field(c)))); + children.push(pf_lit(c)) } safe_nary(PF_MUL, children) } @@ -657,8 +647,8 @@ impl NaryFlat<FieldV> for PfNaryOp { impl NaryFlat<Integer> for IntNaryOp { fn as_const(t: Term) -> Result<Integer, Term> { - match &t.op() { - Op::Const(Value::Int(b)) => Ok(b.clone()), + match t.as_value_opt() { + Some(Value::Int(b)) => Ok(b.clone()), _ => Err(t), } } @@ -668,7 +658,7 @@ impl NaryFlat<Integer> for IntNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::Add::add); if c != 0u8 || children.is_empty() { - children.push(leaf_term(Op::Const(Value::Int(c)))); + children.push(const_(Value::Int(c))); } } safe_nary(INT_ADD, children) @@ -677,10 +667,10 @@ impl NaryFlat<Integer> for IntNaryOp { if let Some(c) = consts.pop() { let c = consts.into_iter().fold(c, std::ops::Mul::mul); if c == 0u8 || children.is_empty() { - leaf_term(Op::Const(Value::Int(c))) + const_(Value::Int(c)) } else { if c != 1u8 { - children.push(leaf_term(Op::Const(Value::Int(c)))); + children.push(const_(Value::Int(c))); } safe_nary(INT_MUL, children) } @@ -707,11 +697,11 @@ mod test { use quickcheck_macros::quickcheck; fn v_bv(n: &str, w: usize) -> Term { - leaf_term(Op::Var(n.to_owned(), Sort::BitVector(w))) + var(n.to_owned(), Sort::BitVector(w)) } fn bool(b: bool) -> Term { - leaf_term(Op::Const(Value::Bool(b))) + bool_lit(b) } #[quickcheck] diff --git a/src/ir/opt/chall.rs b/src/ir/opt/chall.rs index 13ee29b06..ca0168096 100644 --- a/src/ir/opt/chall.rs +++ b/src/ir/opt/chall.rs @@ -100,8 +100,8 @@ pub fn deskolemize_challenges(comp: &mut Computation) { .max() .unwrap_or(0); let round = match t.op() { - Op::Var(n, _) => { - if let Some(v) = comp.precomputes.outputs().get(n) { + Op::Var(v) => { + if let Some(v) = comp.precomputes.outputs().get(&*v.name) { *min_round .borrow() .get(v) @@ -149,8 +149,8 @@ pub fn deskolemize_challenges(comp: &mut Computation) { for t in terms.into_iter().rev() { let round = match t.op() { Op::PfChallenge(..) => min_round.get(&t).unwrap().checked_sub(1).unwrap(), - Op::Var(name, _) if comp.metadata.is_input_public(name) => 0, - Op::Var(name, _) if comp.metadata.lookup(name).committed => 0, + Op::Var(v) if comp.metadata.is_input_public(&v.name) => 0, + Op::Var(v) if comp.metadata.lookup(&*v.name).committed => 0, _ => parents .get(&t) .unwrap() @@ -180,15 +180,15 @@ pub fn deskolemize_challenges(comp: &mut Computation) { let mut challs = TermMap::default(); for t in comp.terms_postorder() { - if let Op::PfChallenge(name, field) = t.op() { + if let Op::PfChallenge(c) = t.op() { let round = *actual_round.get(&t).unwrap(); - debug!("challenge {name}: round = {round}"); + debug!("challenge {}: round = {round}", c.name); trace!("challenge term {t}"); let md = VariableMetadata { - name: name.clone(), + name: c.name.to_string(), random: true, vis: None, - sort: Sort::Field(field.clone()), + sort: Sort::Field(c.field.clone()), round, ..Default::default() }; diff --git a/src/ir/opt/fits_in_bits_ip.rs b/src/ir/opt/fits_in_bits_ip.rs index 5b079835c..9af04ad8c 100644 --- a/src/ir/opt/fits_in_bits_ip.rs +++ b/src/ir/opt/fits_in_bits_ip.rs @@ -63,13 +63,12 @@ pub fn fits_in_bits_ip(c: &mut Computation) { let bv = term_c![Op::PfToBv(field_bits); t]; let mut pf_summands = Vec::new(); for ii in 0..num_subranges { - let sub_bv = - term_c![Op::BvExtract(k as usize * (ii + 1) - 1, k as usize * ii); &bv]; + let sub_bv = term_c![Op::new_bv_extract(k as usize * (ii + 1) - 1, k as usize * ii); &bv]; let sub_f = c.new_var( &ns.fqn(format!("sub{}", ii)), Sort::Field(field.clone()), Some(super::super::proof::PROVER_ID), - Some(term![Op::UbvToPf(field.clone()); sub_bv]), + Some(term![Op::new_ubv_to_pf(field.clone()); sub_bv]), ); pf_summands.push( term![PF_MUL.clone(); pf_lit(field.new_v(1).pow(k as u64 * ii as u64)), sub_f.clone()], @@ -78,12 +77,12 @@ pub fn fits_in_bits_ip(c: &mut Computation) { } if end_length > 0 { let end_start = num_subranges * k as usize; - let sub_bv = term_c![Op::BvExtract(num_bits - 1, end_start); &bv]; + let sub_bv = term_c![Op::new_bv_extract(num_bits - 1, end_start); &bv]; let sub_f = c.new_var( &ns.fqn("end"), Sort::Field(field.clone()), Some(super::super::proof::PROVER_ID), - Some(term![Op::UbvToPf(field.clone()); sub_bv]), + Some(term![Op::new_ubv_to_pf(field.clone()); sub_bv]), ); pf_summands.push(term![PF_MUL.clone(); pf_lit(field.new_v(1 << end_start)), sub_f.clone()]); new_assertions.push(term![Op::PfFitsInBits(end_length); sub_f]); diff --git a/src/ir/opt/flat.rs b/src/ir/opt/flat.rs index 501a57b27..6a04c4dde 100644 --- a/src/ir/opt/flat.rs +++ b/src/ir/opt/flat.rs @@ -130,7 +130,7 @@ mod test { use quickcheck_macros::quickcheck; fn bool(b: bool) -> Term { - leaf_term(Op::Const(Value::Bool(b))) + bool_lit(b) } fn is_flat(t: Term) -> bool { @@ -155,7 +155,7 @@ mod test { #[test] fn simple_bool() { - for o in vec![AND, OR, XOR] { + for o in [AND, OR, XOR] { let t = term![o.clone(); term![o.clone(); bool(true), bool(false)], bool(true)]; let tt = term![o.clone(); bool(true), bool(false), bool(true)]; assert_eq!(flatten_nary_ops(t), tt); @@ -164,7 +164,7 @@ mod test { #[test] fn simple_bv() { - for o in vec![BV_AND, BV_OR, BV_XOR, BV_ADD, BV_MUL] { + for o in [BV_AND, BV_OR, BV_XOR, BV_ADD, BV_MUL] { let t = term![o.clone(); term![o.clone(); bv_lit(3,5), bv_lit(3,5)], bv_lit(3,5)]; let tt = term![o.clone(); bv_lit(3, 5), bv_lit(3, 5), bv_lit(3, 5)]; assert_eq!(flatten_nary_ops(t), tt); diff --git a/src/ir/opt/inline.rs b/src/ir/opt/inline.rs index 2db124105..550fff83c 100644 --- a/src/ir/opt/inline.rs +++ b/src/ir/opt/inline.rs @@ -100,17 +100,17 @@ impl<'a> Inliner<'a> { /// Will not return `v` which are protected. fn as_fresh_def(&self, t: &Term) -> Option<(Term, Term)> { if &EQ == t.op() { - if let Op::Var(name, _) = &t.cs()[0].op() { + if let Op::Var(v) = &t.cs()[0].op() { if !self.stale_vars.contains(&t.cs()[0]) - && !self.protected.contains(name) + && !self.protected.contains(&*v.name) && does_not_contain(t.cs()[1].clone(), &t.cs()[0]) { return Some((t.cs()[0].clone(), t.cs()[1].clone())); } } - if let Op::Var(name, _) = &t.cs()[1].op() { + if let Op::Var(v) = &t.cs()[1].op() { if !self.stale_vars.contains(&t.cs()[1]) - && !self.protected.contains(name) + && !self.protected.contains(&*v.name) && does_not_contain(t.cs()[0].clone(), &t.cs()[1]) { return Some((t.cs()[1].clone(), t.cs()[0].clone())); @@ -175,13 +175,13 @@ pub fn inline(assertions: &mut Vec<Term>, public_inputs: &FxHashSet<String>) { *assertions = new_assertions; } -#[cfg(all(feature = "smt", feature = "test"))] +#[cfg(all(feature = "smt", test))] mod test { use super::*; use crate::target::smt::{check_sat, find_model}; fn b_var(b: &str) -> Term { - leaf_term(Op::Var(b.to_string(), Sort::Bool)) + var(b.to_string(), Sort::Bool) } fn sub_test(xs: Vec<Term>, n: usize) { diff --git a/src/ir/opt/link.rs b/src/ir/opt/link.rs index d28eb92a0..96feccbe0 100644 --- a/src/ir/opt/link.rs +++ b/src/ir/opt/link.rs @@ -33,7 +33,7 @@ pub fn link_one(callee: &Computation, values: Vec<Term>) -> Term { assert_eq!(names.len(), values.len()); for (name, value) in names.into_iter().zip(values) { let sort = callee.metadata.input_sort(&name).clone(); - substitution_map.insert(leaf_term(Op::Var(name, sort)), value); + substitution_map.insert(var(name, sort), value); } term( Op::Tuple, @@ -51,8 +51,8 @@ impl<'f> Linker<'f> { if !self.cache.contains_key(name) { let mut c = self.cs.get(name).clone(); for t in c.terms_postorder() { - if let Op::Call(callee_name, ..) = &t.op() { - self.link_all(callee_name); + if let Op::Call(c) = &t.op() { + self.link_all(&c.name); } } @@ -73,8 +73,8 @@ impl<'f> RewritePass for Linker<'f> { orig: &Term, rewritten_children: F, ) -> Option<Term> { - if let Op::Call(fn_name, _, _) = &orig.op() { - let callee = self.cache.get(fn_name).expect("missing inlined callee"); + if let Op::Call(c) = &orig.op() { + let callee = self.cache.get(&c.name).expect("missing inlined callee"); let term = link_one(callee, rewritten_children()); Some(term) } else { diff --git a/src/ir/opt/mem/lin.rs b/src/ir/opt/mem/lin.rs index f795feca6..51c332a5f 100644 --- a/src/ir/opt/mem/lin.rs +++ b/src/ir/opt/mem/lin.rs @@ -35,36 +35,38 @@ impl RewritePass for Linearizer { rewritten_children: F, ) -> Option<Term> { match &orig.op() { - Op::Const(v) => Some(leaf_term(Op::Const(arr_val_to_tup(v)))), - Op::Var(name, Sort::Array(..)) => { + Op::Const(v) => Some(const_(arr_val_to_tup(v))), + Op::Var(v) if v.sort.is_array() => { let precomp = extras::array_to_tuple(orig); - let new_name = format!("{name}.tup"); + let new_name = format!("{}.tup", v.name); let new_sort = check(&precomp); computation.extend_precomputation(new_name.clone(), precomp); - Some(leaf_term(Op::Var(new_name, new_sort))) + Some(var(new_name, new_sort)) } Op::Array(..) => Some(term(Op::Tuple, rewritten_children())), - Op::Fill(_, size) => Some(term( + Op::Fill(f) => Some(term( Op::Tuple, - vec![rewritten_children().pop().unwrap(); *size], + vec![rewritten_children().pop().unwrap(); f.size], )), Op::Select => { let cs = rewritten_children(); let idx = &cs[1]; let tup = &cs[0]; - if let Sort::Array(key_sort, val_sort, sz) = check(&orig.cs()[0]) { - assert!(sz > 0); + if let Sort::Array(a) = check(&orig.cs()[0]) { + assert!(a.size > 0); if idx.is_const() { Some( extras::as_uint_constant(idx) .and_then(|cidx| cidx.to_usize()) - .and_then(|u| (u < sz).then_some(term![Op::Field(u); tup.clone()])) - .unwrap_or_else(|| val_sort.default_term()), + .and_then(|u| { + (u < a.size).then_some(term![Op::Field(u); tup.clone()]) + }) + .unwrap_or_else(|| a.val.default_term()), ) } else { - let mut fields = (0..sz).map(|idx| term![Op::Field(idx); tup.clone()]); + let mut fields = (0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]); let first = fields.next().unwrap(); - Some(key_sort.elems_iter().take(sz).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| { + Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| { term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc] })) } @@ -77,23 +79,23 @@ impl RewritePass for Linearizer { let tup = &cs[0]; let idx = &cs[1]; let val = &cs[2]; - if let Sort::Array(key_sort, _, sz) = check(&orig.cs()[0]) { - assert!(sz > 0); + if let Sort::Array(a) = check(&orig.cs()[0]) { + assert!(a.size > 0); if idx.is_const() { Some( extras::as_uint_constant(idx) .and_then(|cidx| cidx.to_usize()) .and_then(|u| { - (u < sz) + (u < a.size) .then_some(term![Op::Update(u); tup.clone(), val.clone()]) }) .unwrap_or_else(|| tup.clone()), ) } else { let mut updates = - (0..sz).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]); + (0..a.size).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]); let first = updates.next().unwrap(); - Some(key_sort.elems_iter().take(sz).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| { + Some(a.key.elems_iter().take(a.size).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| { term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], update, acc] })) } @@ -107,23 +109,23 @@ impl RewritePass for Linearizer { let idx = &cs[1]; let val = &cs[2]; let cond = &cs[3]; - if let Sort::Array(key_sort, _, sz) = check(&orig.cs()[0]) { - assert!(sz > 0); + if let Sort::Array(a) = check(&orig.cs()[0]) { + assert!(a.size > 0); if idx.is_const() { Some( extras::as_uint_constant(idx) .and_then(|cidx| cidx.to_usize()) .and_then(|u| { - (u < sz) + (u < a.size) .then_some(term![Op::Ite; cond.clone(), term![Op::Update(u); tup.clone(), val.clone()], tup.clone()]) }) .unwrap_or_else(|| tup.clone()), ) } else { let mut updates = - (0..sz).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]); + (0..a.size).map(|idx| term![Op::Update(idx); tup.clone(), val.clone()]); let first = updates.next().unwrap(); - Some(key_sort.elems_iter().take(sz).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| { + Some(a.key.elems_iter().take(a.size).skip(1).zip(updates).fold(first, |acc, (idx_c, update)| { term![Op::Ite; term![AND; term![Op::Eq; idx.clone(), idx_c], cond.clone()], update, acc] })) } diff --git a/src/ir/opt/mem/obliv.rs b/src/ir/opt/mem/obliv.rs index bd5a961e0..649f27bae 100644 --- a/src/ir/opt/mem/obliv.rs +++ b/src/ir/opt/mem/obliv.rs @@ -44,25 +44,32 @@ impl OblivRewriter { } fn visit(&mut self, t: &Term) { let (tup_opt, term_opt) = match t.op() { - Op::Var(_, sort) if sort.is_scalar() => (Some(t.clone()), None), - Op::Const(v @ Value::Array(a)) => { - if a.size <= OBLIV_SIZE_THRESH { - (Some(leaf_term(Op::Const(arr_val_to_tup(v)))), None) + Op::Var(v) if v.sort.is_scalar() => (Some(t.clone()), None), + Op::Const(v) => { + if let Value::Array(a) = &**v { + if a.size <= OBLIV_SIZE_THRESH { + (Some(const_(arr_val_to_tup(v))), None) + } else { + (None, None) + } } else { (None, None) } } - Op::Array(_k, _v) => ( + Op::Array(_) => ( Some(term( Op::Tuple, t.cs().iter().map(|c| self.get_t(c)).cloned().collect(), )), None, ), - Op::Fill(_k, size) => { - if *size < OBLIV_SIZE_THRESH { + Op::Fill(f) => { + if f.size < OBLIV_SIZE_THRESH { ( - Some(term(Op::Tuple, vec![self.get_t(&t.cs()[0]).clone(); *size])), + Some(term( + Op::Tuple, + vec![self.get_t(&t.cs()[0]).clone(); f.size], + )), None, ) } else { @@ -240,7 +247,7 @@ mod test { use super::*; fn v_bv(n: &str, w: usize) -> Term { - leaf_term(Op::Var(n.to_owned(), Sort::BitVector(w))) + var(n.to_owned(), Sort::BitVector(w)) } fn array_free(t: &Term) -> bool { @@ -260,15 +267,15 @@ mod test { #[test] fn obliv() { - let z = term![Op::Const(Value::Array(Array::new( + let z = const_(Value::Array(Array::new( Sort::BitVector(4), Box::new(Sort::BitVector(4).default_value()), Default::default(), - 6 - )))]; + 6, + ))); let t = term![Op::Select; term![Op::Ite; - leaf_term(Op::Const(Value::Bool(true))), + bool_lit(true), term![Op::Store; z.clone(), bv_lit(3, 4), bv_lit(1, 4)], term![Op::Store; z, bv_lit(2, 4), bv_lit(1, 4)] ], @@ -282,15 +289,15 @@ mod test { #[test] fn not_obliv() { - let z = term![Op::Const(Value::Array(Array::new( + let z = const_(Value::Array(Array::new( Sort::BitVector(4), Box::new(Sort::BitVector(4).default_value()), Default::default(), - 6 - )))]; + 6, + ))); let t = term![Op::Select; term![Op::Ite; - leaf_term(Op::Const(Value::Bool(true))), + bool_lit(true), term![Op::Store; z.clone(), v_bv("a", 4), bv_lit(1, 4)], term![Op::Store; z, bv_lit(2, 4), bv_lit(1, 4)] ], @@ -304,18 +311,18 @@ mod test { #[test] fn mix_diff_constant() { - let z0 = term![Op::Const(Value::Array(Array::new( + let z0 = const_(Value::Array(Array::new( Sort::BitVector(4), Box::new(Sort::BitVector(4).default_value()), Default::default(), - 6 - )))]; - let z1 = term![Op::Const(Value::Array(Array::new( + 6, + ))); + let z1 = const_(Value::Array(Array::new( Sort::BitVector(4), Box::new(Sort::BitVector(4).default_value()), Default::default(), - 5 - )))]; + 5, + ))); let t0 = term![Op::Select; term![Op::Store; z0, v_bv("a", 4), bv_lit(1, 4)], bv_lit(3, 4) @@ -334,12 +341,12 @@ mod test { #[test] fn mix_same_constant() { - let z = term![Op::Const(Value::Array(Array::new( + let z = const_(Value::Array(Array::new( Sort::BitVector(4), Box::new(Sort::BitVector(4).default_value()), Default::default(), - 6 - )))]; + 6, + ))); let t0 = term![Op::Select; term![Op::Store; z.clone(), v_bv("a", 4), bv_lit(1, 4)], bv_lit(3, 4) diff --git a/src/ir/opt/mem/ram.rs b/src/ir/opt/mem/ram.rs index 420c5770a..3e6485a6c 100644 --- a/src/ir/opt/mem/ram.rs +++ b/src/ir/opt/mem/ram.rs @@ -121,7 +121,7 @@ impl AccessCfg { fn val_sort_len(s: &Sort) -> usize { match s { Sort::Tuple(t) => t.iter().map(Self::val_sort_len).sum(), - Sort::Array(_, v, size) => *size * Self::val_sort_len(v), + Sort::Array(a) => a.size * Self::val_sort_len(&a.val), _ => 1, } } @@ -155,7 +155,7 @@ fn scalar_to_field(scalar: &Term, c: &AccessCfg) -> Term { } } Sort::Bool => c.bool2pf(scalar.clone()), - Sort::BitVector(_) => term![Op::UbvToPf(c.field.clone()); scalar.clone()], + Sort::BitVector(_) => term![Op::new_ubv_to_pf(c.field.clone()); scalar.clone()], s => panic!("non-scalar sort {}", s), } } @@ -274,9 +274,9 @@ impl Access { Self::sort_subnames(s, &format!("{}_{}", prefix, i), out); } } - Sort::Array(_, v, size) => { - for i in 0..*size { - Self::sort_subnames(v, &format!("{}_{}", prefix, i), out); + Sort::Array(a) => { + for i in 0..a.size { + Self::sort_subnames(&a.val, &format!("{}_{}", prefix, i), out); } } _ => unreachable!(), @@ -292,8 +292,8 @@ impl Access { Self::val_to_field_elements(&term![Op::Field(i); val.clone()], c, out); } } - Sort::Array(_, _, size) => { - for i in 0..size { + Sort::Array(a) => { + for i in 0..a.size { Self::val_to_field_elements( &term![Op::Select; val.clone(), c.pf_lit(i)], c, @@ -316,10 +316,13 @@ impl Access { .map(|s| Self::val_from_field_elements_trusted(s, next)) .collect(), ), - Sort::Array(k, v, size) => term( - Op::Array(*k.clone(), *v.clone()), - (0..*size) - .map(|_| Self::val_from_field_elements_trusted(v, next)) + Sort::Array(a) => term( + Op::Array(Box::new(ArrayOp { + key: a.key.clone(), + val: a.val.clone(), + })), + (0..a.size) + .map(|_| Self::val_from_field_elements_trusted(&a.val, next)) .collect(), ), _ => unreachable!(), @@ -509,7 +512,7 @@ fn hashable(s: &Sort, f: &FieldT) -> bool { Sort::Tuple(ss) => ss.iter().all(|s| hashable(s, f)), Sort::BitVector(_) => true, Sort::Bool => true, - Sort::Array(_k, v, size) => *size < 20 && hashable(v, f), + Sort::Array(a) => a.size < 20 && hashable(&a.val, f), _ => false, } } diff --git a/src/ir/opt/mem/ram/checker.rs b/src/ir/opt/mem/ram/checker.rs index 43ce8b054..c74f5b162 100644 --- a/src/ir/opt/mem/ram/checker.rs +++ b/src/ir/opt/mem/ram/checker.rs @@ -302,7 +302,10 @@ fn derivative_gcd( let ns = ns.subspace("uniq"); let fs = Sort::Field(f.clone()); let pairs = term( - Op::Array(fs.clone(), Sort::Tuple(Box::new([fs.clone(), Sort::Bool]))), + Op::Array(Box::new(ArrayOp { + key: fs.clone(), + val: Sort::Tuple(Box::new([fs.clone(), Sort::Bool])), + })), values .clone() .into_iter() @@ -337,7 +340,7 @@ fn derivative_gcd( terms_that_define_all_polys.extend(t_coeffs_skolem.iter().cloned()); let n = values.len(); let x = term( - Op::PfChallenge(ns.fqn("x"), f.clone()), + Op::new_chall(ns.fqn("x"), f.clone()), terms_that_define_all_polys, ); let r = values; diff --git a/src/ir/opt/mem/ram/checker/rom.rs b/src/ir/opt/mem/ram/checker/rom.rs index c6da30ac6..b42d3011e 100644 --- a/src/ir/opt/mem/ram/checker/rom.rs +++ b/src/ir/opt/mem/ram/checker/rom.rs @@ -44,7 +44,7 @@ pub fn lookup(c: &mut Computation, ns: Namespace, haystack: Vec<Term>, needles: }) .collect(); let key = term( - Op::PfChallenge(ns.fqn("key"), f.clone()), + Op::new_chall(ns.fqn("key"), f.clone()), haystack .iter() .chain(&needles) diff --git a/src/ir/opt/mem/ram/hash.rs b/src/ir/opt/mem/ram/hash.rs index 0b33c7118..40faa8b2f 100644 --- a/src/ir/opt/mem/ram/hash.rs +++ b/src/ir/opt/mem/ram/hash.rs @@ -17,7 +17,7 @@ impl MsHasher { pub fn new(key_name: String, f: &FieldT, inputs: Vec<Term>) -> Self { Self { f: f.clone(), - key: term(Op::PfChallenge(key_name, f.clone()), inputs), + key: term(Op::new_chall(key_name, f.clone()), inputs), } } /// Hash some `data`, as a multi-set. @@ -43,7 +43,7 @@ impl UniversalHasher { /// * `f` is the field used. /// * `len` is the data length. pub fn new(key_name: String, f: &FieldT, inputs: Vec<Term>, len: usize) -> Self { - let key = term(Op::PfChallenge(key_name, f.clone()), inputs); + let key = term(Op::new_chall(key_name, f.clone()), inputs); let key_powers: Vec<Term> = std::iter::successors(Some(key.clone()), |p| { Some(term![PF_MUL; p.clone(), key.clone()]) }) diff --git a/src/ir/opt/mem/ram/persistent.rs b/src/ir/opt/mem/ram/persistent.rs index adc28a45e..421de6041 100644 --- a/src/ir/opt/mem/ram/persistent.rs +++ b/src/ir/opt/mem/ram/persistent.rs @@ -26,7 +26,7 @@ pub fn persistent_to_ram(c: &mut Computation, cfg: &AccessCfg) -> Vec<Ram> { let key_sort = sort.as_array().0.clone(); let value_sort = sort.as_array().1.clone(); let size = sort.as_array().2; - let init_term = leaf_term(Op::Var(name.clone(), sort)); + let init_term = var(name.clone(), sort); // create a new var for each initial value let names: Vec<String> = (0..size).map(|i| format!("{name}.init.{i}")).collect(); @@ -108,7 +108,7 @@ pub fn check_ram(c: &mut Computation, mut ram: Ram, cfg: &AccessCfg) { let mut uhf_inputs = inital_terms.clone(); uhf_inputs.extend(final_terms.iter().cloned()); let uhf_key = term( - Op::PfChallenge(format!("__uhf_key.{j}"), field.clone()), + Op::new_chall(format!("__uhf_key.{j}"), field.clone()), uhf_inputs, ); let uhf = |idx: Term, val: Term| term![PF_ADD; val, term![PF_MUL; uhf_key.clone(), idx]]; diff --git a/src/ir/opt/mem/ram/set.rs b/src/ir/opt/mem/ram/set.rs index 69911d2ae..b744fc1a2 100644 --- a/src/ir/opt/mem/ram/set.rs +++ b/src/ir/opt/mem/ram/set.rs @@ -43,8 +43,7 @@ pub fn apply(c: &mut Computation) { .map .keys() .cloned() - .map(Op::Const) - .map(leaf_term) + .map(const_) .collect(); to_assert.push(super::checker::rom::lookup( c, diff --git a/src/ir/opt/mem/ram/volatile.rs b/src/ir/opt/mem/ram/volatile.rs index 8f496ad4e..1df3db202 100644 --- a/src/ir/opt/mem/ram/volatile.rs +++ b/src/ir/opt/mem/ram/volatile.rs @@ -41,7 +41,7 @@ fn hashable(s: &Sort, f: &FieldT) -> bool { Sort::Tuple(ss) => ss.iter().all(|s| hashable(s, f)), Sort::BitVector(_) => true, Sort::Bool => true, - Sort::Array(_k, v, size) => *size < 20 && hashable(v, f), + Sort::Array(a) => a.size < 20 && hashable(&a.val, f), _ => false, } } @@ -49,9 +49,9 @@ fn hashable(s: &Sort, f: &FieldT) -> bool { /// Does this array have a sort compatible with our RAM machinery? fn right_sort(t: &Term, f: &FieldT) -> bool { let s = check(t); - if let Sort::Array(k, v, _) = &s { - if let Sort::Field(k) = &**k { - k == f && hashable(v, f) + if let Sort::Array(a) = &s { + if let Sort::Field(k) = &a.key { + k == f && hashable(&a.val, f) } else { false } @@ -197,14 +197,10 @@ impl Extactor { let value = &t.cs()[0]; ram.boundary_conditions = BoundaryConditions::Default(value.clone()); } - Op::Const(Value::Array(a)) => { + Op::Const(v) if v.is_array() => { // for a constant: add (constant) writes - for (k, v) in &a.map { - ram.new_write( - leaf_term(Op::Const(k.clone())), - leaf_term(Op::Const(v.clone())), - self.cfg.true_.clone(), - ); + for (k, v) in &v.as_array().map { + ram.new_write(const_(k.clone()), const_(v.clone()), self.cfg.true_.clone()); } } Op::Array(..) => { @@ -538,7 +534,7 @@ mod test { let field = FieldT::from(rug::Integer::from(11)); let rams = extract(&mut cs2, AccessCfg::default_from_field(field.clone())); extras::assert_all_vars_declared(&cs2); - let a = leaf_term(Op::Var("a".to_string(), Sort::Bool)); + let a = var("a".to_string(), Sort::Bool); assert_ne!(cs, cs2); assert_eq!(1, rams.len()); assert_eq!(3, rams[0].accesses.len()); @@ -557,7 +553,7 @@ mod test { #[test] fn mix_store_chain() { - let a = leaf_term(Op::Var("a".to_string(), Sort::Bool)); + let a = var("a".to_string(), Sort::Bool); let cs = text::parse_computation( b" (computation diff --git a/src/ir/opt/mod.rs b/src/ir/opt/mod.rs index 031829f30..0f0a98686 100644 --- a/src/ir/opt/mod.rs +++ b/src/ir/opt/mod.rs @@ -13,6 +13,8 @@ pub mod sha; pub mod tuple; mod visit; +use std::num::NonZero; + use super::term::*; use log::{debug, info, trace}; @@ -85,7 +87,7 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I) } Opt::ConstantFold(ignore) => { let mut cache = TermCache::with_capacity(TERM_CACHE_LIMIT); - cache.resize(usize::MAX); + cache.resize(NonZero::new(usize::MAX).unwrap()); for a in &mut c.outputs { *a = cfold::fold_cache(a, &mut cache, &ignore.clone()); } diff --git a/src/ir/opt/scalarize_vars.rs b/src/ir/opt/scalarize_vars.rs index 2d070e932..7083eec44 100644 --- a/src/ir/opt/scalarize_vars.rs +++ b/src/ir/opt/scalarize_vars.rs @@ -32,18 +32,18 @@ fn create_vars( }) .collect(), ), - Sort::Array(key_s, val_s, size) => { + Sort::Array(a) => { let array_elements = extras::array_elements(&prefix_term); make_array( - (**key_s).clone(), - (**val_s).clone(), - (0..*size) + a.key.clone(), + a.val.clone(), + (0..a.size) .zip(array_elements) .map(|(i, element)| { create_vars( &format!("{prefix}.{i}"), element, - val_s, + &a.val, new_var_requests, false, ) @@ -57,7 +57,7 @@ fn create_vars( trace!("New scalar var: {}", prefix); new_var_requests.push((prefix.into(), prefix_term)); } - leaf_term(Op::Var(prefix.into(), sort.clone())) + var(prefix.into(), sort.clone()) } } } @@ -78,18 +78,18 @@ fn create_wits(prefix: &str, prefix_term: Term, sort: &Sort) -> Term { }) .collect(), ), - Sort::Array(key_s, val_s, size) => { + Sort::Array(a) => { let array_elements = extras::array_elements(&prefix_term); make_array( - (**key_s).clone(), - (**val_s).clone(), - (0..*size) + a.key.clone(), + a.val.clone(), + (0..a.size) .zip(array_elements) - .map(|(i, element)| create_wits(&format!("{prefix}.{i}"), element, val_s)) + .map(|(i, element)| create_wits(&format!("{prefix}.{i}"), element, &a.val)) .collect(), ) } - _ => term![Op::Witness(prefix.to_owned()); prefix_term], + _ => term![Op::new_witness(prefix.into()); prefix_term], } } @@ -100,11 +100,11 @@ impl RewritePass for Pass { orig: &Term, rewritten_children: F, ) -> Option<Term> { - if let Op::Var(name, sort) = &orig.op() { - trace!("Considering var: {}", name); - if !computation.metadata.lookup(name).committed { + if let Op::Var(v) = &orig.op() { + trace!("Considering var: {}", v.name); + if !computation.metadata.lookup(&*v.name).committed { let mut new_var_reqs = Vec::new(); - let new = create_vars(name, orig.clone(), sort, &mut new_var_reqs, true); + let new = create_vars(&v.name, orig.clone(), &v.sort, &mut new_var_reqs, true); for (name, term) in new_var_reqs { computation.extend_precomputation(name, term); } @@ -139,9 +139,9 @@ pub fn scalarize_inputs(cs: &mut Computation) { /// Check that every variables is a scalar (or committed) pub fn assert_all_vars_are_scalars(cs: &Computation) { for t in cs.terms_postorder() { - if let Op::Var(name, sort) = &t.op() { - if !cs.metadata.lookup(name).committed { - match sort { + if let Op::Var(v) = &t.op() { + if !cs.metadata.lookup(&*v.name).committed { + match &v.sort { Sort::Array(..) | Sort::Tuple(..) => { panic!("Variable {} is non-scalar", t); } diff --git a/src/ir/opt/tuple.rs b/src/ir/opt/tuple.rs index 9849fc198..82f507a9a 100644 --- a/src/ir/opt/tuple.rs +++ b/src/ir/opt/tuple.rs @@ -60,7 +60,7 @@ //! fast vector type, instead of standard terms. This allows for log-time updates. use crate::ir::term::{ - bv_lit, check, leaf_term, term, Array, Computation, Node, Op, PostOrderIter, Sort, Term, + bv_lit, check, const_, term, Array, ArrayOp, Computation, Node, Op, PostOrderIter, Sort, Term, TermMap, Value, AND, }; use std::collections::BTreeMap; @@ -122,7 +122,7 @@ impl TupleTree { TupleTree::NonTuple(cs) => { if let Sort::Tuple(_) = check(cs) { TupleTree::NonTuple(term![Op::Field(i); cs.clone()]) - } else if let Sort::Array(_, _, _) = check(cs) { + } else if let Sort::Array(_) = check(cs) { TupleTree::NonTuple(term![Op::Select; cs.clone(), bv_lit(i, 32)]) } else { panic!("Get ({}) on non-tuple {:?}", i, self) @@ -196,7 +196,7 @@ fn termify_val_tuples(v: Value) -> TupleTree { if let Value::Tuple(vs) = v { TupleTree::Tuple(Vec::from(vs).into_iter().map(termify_val_tuples).collect()) } else { - TupleTree::NonTuple(leaf_term(Op::Const(v))) + TupleTree::NonTuple(const_(v)) } } @@ -277,14 +277,20 @@ pub fn eliminate_tuples(cs: &mut Computation) { debug_assert!(cs.is_empty()); a.bimap(|a, v| term![Op::Store; a, i.clone(), v], &v) } - Op::Array(k, _v) => TupleTree::transpose_map(cs, |children| { + Op::Array(a) => TupleTree::transpose_map(cs, |children| { assert!(!children.is_empty()); let v_s = check(&children[0]); - term(Op::Array(k.clone(), v_s), children) + term( + Op::Array(Box::new(ArrayOp { + key: a.key.clone(), + val: v_s, + })), + children, + ) }), - Op::Fill(key_sort, size) => { + Op::Fill(_) => { let values = cs.pop().unwrap(); - values.map(|v| term![Op::Fill(key_sort.clone(), *size); v]) + values.map(|v| term![t.op().clone(); v]) } Op::Select => { let i = cs.pop().unwrap().unwrap_non_tuple(); diff --git a/src/ir/proof.rs b/src/ir/proof.rs index a83c1c7bb..ba3ed4487 100644 --- a/src/ir/proof.rs +++ b/src/ir/proof.rs @@ -48,8 +48,8 @@ impl Constraints for Computation { let public_inputs_set: FxHashSet<String> = public_inputs .iter() .filter_map(|t| { - if let Op::Var(n, _) = &t.op() { - Some(n.clone()) + if let Op::Var(v) = &t.op() { + Some(v.name.to_string()) } else { None } @@ -57,16 +57,16 @@ impl Constraints for Computation { .collect(); for v in public_inputs { - if let Op::Var(n, s) = &v.op() { - metadata.new_input(n.to_owned(), None, s.clone()); + if let Op::Var(var) = &v.op() { + metadata.new_input(var.name.to_string(), None, var.sort.clone()); } else { panic!() } } for v in all_vars { - if let Op::Var(n, s) = &v.op() { - if !public_inputs_set.contains(n) { - metadata.new_input(n.to_owned(), Some(PROVER_ID), s.clone()); + if let Op::Var(var) = &v.op() { + if !public_inputs_set.contains(&*var.name) { + metadata.new_input(var.name.to_string(), Some(PROVER_ID), var.sort.clone()); } } else { panic!() diff --git a/src/ir/term/bv.rs b/src/ir/term/bv.rs index d3572f1e1..91c358038 100644 --- a/src/ir/term/bv.rs +++ b/src/ir/term/bv.rs @@ -180,10 +180,10 @@ impl BitVector { /// Gets the bits from `high` to `low`, inclusive. Zero-indexed. /// /// The number of bits yielded is `high-low+1`. - pub fn extract(self, high: usize, low: usize) -> Self { + pub fn extract(self, high: u32, low: u32) -> Self { let r = BitVector { - uint: (self.uint >> low as u32).keep_bits((high - low + 1) as u32), - width: high - low + 1, + uint: (self.uint >> low).keep_bits(high - low + 1), + width: (high - low + 1) as usize, }; r.check("extract"); r diff --git a/src/ir/term/dist.rs b/src/ir/term/dist.rs index b4d839371..c271a8ddc 100644 --- a/src/ir/term/dist.rs +++ b/src/ir/term/dist.rs @@ -37,8 +37,8 @@ impl rand::distributions::Distribution<Vec<usize>> for Sum { impl rand::distributions::Distribution<Term> for PureBoolDist { fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Term { let ops = &[ - Op::Const(Value::Bool(rng.gen())), - Op::Var( + Op::new_const(Value::Bool(rng.gen())), + Op::new_var( std::str::from_utf8(&[b'a' + rng.gen_range(0..26)]) .unwrap() .to_owned(), @@ -91,14 +91,14 @@ impl FixedSizeDist { format!("{}_{}", prefix, (b'a' + rng.gen_range(0..26)) as char) } fn sample_value<R: Rng + ?Sized>(&self, sort: &Sort, rng: &mut R) -> Op { - Op::Const(UniformValue(sort).sample(rng)) + Op::new_const(UniformValue(sort).sample(rng)) } fn sample_op<R: Rng + ?Sized>(&self, sort: &Sort, rng: &mut R) -> Op { let mut ops = match sort { Sort::Bool => { let mut ops = vec![ self.sample_value(sort, rng), - Op::Var(self.sample_ident("b_", rng), sort.clone()), + Op::new_var(self.sample_ident("b_", rng), sort.clone()), Op::Not, // 2 Op::Implies, Op::Eq, @@ -120,7 +120,7 @@ impl FixedSizeDist { } Sort::BitVector(w) => vec![ self.sample_value(sort, rng), - Op::Var(self.sample_ident(&format!("bv{w}"), rng), sort.clone()), + Op::new_var(self.sample_ident(&format!("bv{w}"), rng), sort.clone()), Op::BvUnOp(BvUnOp::Neg), Op::BvUnOp(BvUnOp::Not), Op::BvUext(rng.gen_range(0..*w)), @@ -137,7 +137,7 @@ impl FixedSizeDist { Sort::Field(_) => { vec![ self.sample_value(sort, rng), - Op::Var(self.sample_ident("pf", rng), sort.clone()), + Op::new_var(self.sample_ident("pf", rng), sort.clone()), Op::PfUnOp(PfUnOp::Neg), // Can error // Op::PfUnOp(PfUnOp::Recip), @@ -150,7 +150,7 @@ impl FixedSizeDist { Op::Tuple, self.sample_value(sort, rng), // No variables! - Op::Var( + Op::new_var( self.sample_ident( &format!("tp_{sort}") .replace('(', "[") @@ -330,8 +330,8 @@ pub mod test { let t = PureBoolDist(g.size()).sample(&mut rng); let values: FxHashMap<String, Value> = PostOrderIter::new(t.clone()) .filter_map(|c| { - if let Op::Var(n, _) = &c.op() { - Some((n.clone(), Value::Bool(bool::arbitrary(g)))) + if let Op::Var(v) = &c.op() { + Some((v.name.to_string(), Value::Bool(bool::arbitrary(g)))) } else { None } @@ -395,7 +395,9 @@ pub mod test { let t = d.sample(&mut rng); let values: HashMap<String, Value> = PostOrderIter::new(t.clone()) .filter_map(|c| match &c.op() { - Op::Var(n, Sort::Bool) => Some((n.clone(), Value::Bool(bool::arbitrary(g)))), + Op::Var(var) if matches!(&var.sort, Sort::Bool) => { + Some((var.name.to_string(), Value::Bool(bool::arbitrary(g)))) + } _ => None, }) .collect(); @@ -436,7 +438,9 @@ pub mod test { let t = d.sample(&mut rng); let values: HashMap<String, Value> = PostOrderIter::new(t.clone()) .filter_map(|c| match &c.op() { - Op::Var(n, s) => Some((n.clone(), UniformValue(s).sample(&mut rng))), + Op::Var(v) => { + Some((v.name.to_string(), UniformValue(&v.sort).sample(&mut rng))) + } _ => None, }) .collect(); diff --git a/src/ir/term/eval.rs b/src/ir/term/eval.rs index 4eee91eeb..0d8eb22a6 100644 --- a/src/ir/term/eval.rs +++ b/src/ir/term/eval.rs @@ -1,7 +1,7 @@ //! IR Evaluation use super::{ - check, extras, leaf_term, term, Array, BitVector, BoolNaryOp, BvBinOp, BvBinPred, BvNaryOp, + check, const_, extras, term, Array, BitVector, BoolNaryOp, BvBinOp, BvBinPred, BvNaryOp, BvUnOp, FieldToBv, FxHashMap, IntBinPred, IntNaryOp, Integer, Node, Op, PfNaryOp, PfUnOp, Sort, Term, TermMap, Value, }; @@ -67,9 +67,9 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, t: Term) -> #[allow(clippy::uninlined_format_args)] pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> Value { match op { - Op::Var(n, _) => var_vals - .get(n) - .unwrap_or_else(|| panic!("Missing var: {} in {:?}", n, var_vals)) + Op::Var(var) => var_vals + .get(&*var.name) + .unwrap_or_else(|| panic!("Missing var: {} in {:?}", var.name, var_vals)) .clone(), Op::Eq => Value::Bool(args[0] == args[1]), Op::Not => Value::Bool(!args[0].as_bool()), @@ -94,7 +94,7 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> it.fold(f, BitVector::concat) }), Op::BvExtract(h, l) => Value::BitVector(args[0].as_bv().clone().extract(*h, *l)), - Op::Const(v) => v.clone(), + Op::Const(v) => (**v).clone(), Op::BvBinOp(o) => Value::BitVector({ let a = args[0].as_bv().clone(); let b = args[1].as_bv().clone(); @@ -216,7 +216,7 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> ) }), Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())), - Op::PfChallenge(name, field) => Value::Field(pf_challenge(name, field)), + Op::PfChallenge(c) => Value::Field(eval_pf_challenge(&c.name, &c.field)), Op::Witness(_) => args[0].clone(), Op::PfFitsInBits(n_bits) => { Value::Bool(args[0].as_pf().i().signed_bits() <= *n_bits as u32) @@ -254,18 +254,18 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> Value::Array(a) } } - Op::Fill(key_sort, size) => { + Op::Fill(f) => { let v = args[0].clone(); Value::Array(Array::new( - key_sort.clone(), + f.key_sort.clone(), Box::new(v), Default::default(), - *size, + f.size, )) } - Op::Array(key, value) => Value::Array(Array::from_vec( - key.clone(), - value.clone(), + Op::Array(a) => Value::Array(Array::from_vec( + a.key.clone(), + a.val.clone(), args.iter().cloned().cloned().collect(), )), Op::Select => { @@ -280,7 +280,7 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> for arg in args { let arr = arg.as_array().clone(); let iter = match arg.sort() { - Sort::Array(k, _, s) => (*k).clone().elems_iter_values().take(s).enumerate(), + Sort::Array(a) => a.key.clone().elems_iter_values().take(a.size).enumerate(), _ => panic!("Input type should be Array"), }; for (j, jval) in iter { @@ -289,14 +289,12 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> } let term = term( op.clone(), - args.iter() - .map(|a| leaf_term(Op::Const((*a).clone()))) - .collect(), + args.iter().map(|a| const_((*a).clone())).collect(), ); let (mut res, iter) = match check(&term) { - Sort::Array(k, v, n) => ( - Array::default((*k).clone(), &v, n), - (*k).clone().elems_iter_values().take(n).enumerate(), + Sort::Array(a) => ( + Array::default(a.key.clone(), &a.val, a.size), + a.key.clone().elems_iter_values().take(a.size).enumerate(), ), _ => panic!("Output type of map should be array"), }; @@ -311,10 +309,10 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> Op::Rot(i) => { let a = args[0].as_array().clone(); let (mut res, iter, len) = match args[0].sort() { - Sort::Array(k, v, n) => ( - Array::default((*k).clone(), &v, n), - (*k).clone().elems_iter_values().take(n).enumerate(), - n, + Sort::Array(a) => ( + Array::default(a.key.clone(), &a.val, a.size), + a.key.clone().elems_iter_values().take(a.size).enumerate(), + a.size, ), _ => panic!("Input type should be Array"), }; @@ -341,7 +339,7 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) -> } /// Compute a (deterministic) prime-field challenge. -pub fn pf_challenge(name: &str, field: &FieldT) -> FieldV { +pub fn eval_pf_challenge(name: &str, field: &FieldT) -> FieldV { use rand::SeedableRng; use rand_chacha::ChaChaRng; use std::hash::{Hash, Hasher}; diff --git a/src/ir/term/ext/map.rs b/src/ir/term/ext/map.rs index 49f67e529..53fe3ca38 100644 --- a/src/ir/term/ext/map.rs +++ b/src/ir/term/ext/map.rs @@ -7,7 +7,7 @@ use crate::ir::term::*; pub fn check_array_to_map(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> { let [array] = ty::count_or_ref(arg_sorts)?; let (k, v, _size) = ty::array_or(array, "ArrayToMap expects array")?; - Ok(Sort::Map(Box::new(k.clone()), Box::new(v.clone()))) + Ok(Sort::new_map(k.clone(), v.clone())) } /// Evaluate [super::ExtOp::ArrayToMap]. @@ -22,7 +22,7 @@ pub fn eval_array_to_map(args: &[&Value]) -> Value { pub fn check_map_flip(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> { let [map] = ty::count_or_ref(arg_sorts)?; let (k, v) = ty::map_or(map, "MapFlip expects map")?; - Ok(Sort::Map(Box::new(k.clone()), Box::new(v.clone()))) + Ok(Sort::new_map(v.clone(), k.clone())) } /// Evaluate [super::ExtOp::MapFlip]. diff --git a/src/ir/term/ext/poly.rs b/src/ir/term/ext/poly.rs index b544bf556..a48daee92 100644 --- a/src/ir/term/ext/poly.rs +++ b/src/ir/term/ext/poly.rs @@ -22,8 +22,7 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> { &Sort::Bool, "UniqDeriGcd pairs: second element must be a bool", )?; - let box_f = Box::new(f.clone()); - let arr = Sort::Array(box_f.clone(), box_f, size); + let arr = Sort::new_array(f.clone(), f.clone(), size); Ok(Sort::Tuple(Box::new([arr.clone(), arr]))) } else { // non-pair entries value diff --git a/src/ir/term/extras.rs b/src/ir/term/extras.rs index bcc1f55c3..dd638d0cb 100644 --- a/src/ir/term/extras.rs +++ b/src/ir/term/extras.rs @@ -9,7 +9,7 @@ pub fn to_width(t: &Term, w: usize) -> Term { match old_w.cmp(&w) { Ordering::Less => term(Op::BvUext(w - old_w), vec![t.clone()]), Ordering::Equal => t.clone(), - Ordering::Greater => term(Op::BvExtract(w - 1, 0), vec![t.clone()]), + Ordering::Greater => term(Op::BvExtract(w as u32 - 1, 0), vec![t.clone()]), } } @@ -76,7 +76,7 @@ pub fn contains(haystack: Term, needle: &Term) -> bool { pub fn free_in(v: &str, t: Term) -> bool { for n in PostOrderIter::new(t) { match &n.op() { - Op::Var(name, _) if v == name => { + Op::Var(var) if v == &*var.name => { return true; } _ => {} @@ -89,7 +89,7 @@ pub fn free_in(v: &str, t: Term) -> bool { pub fn free_variables(t: Term) -> FxHashSet<String> { PostOrderIter::new(t) .filter_map(|n| match &n.op() { - Op::Var(name, _) => Some(name.into()), + Op::Var(var) => Some(var.name.to_string()), _ => None, }) .collect() @@ -99,7 +99,7 @@ pub fn free_variables(t: Term) -> FxHashSet<String> { pub fn free_variables_with_sorts(t: Term) -> FxHashSet<(String, Sort)> { PostOrderIter::new(t) .filter_map(|n| match &n.op() { - Op::Var(name, sort) => Some((name.into(), sort.clone())), + Op::Var(var) => Some((var.name.to_string(), var.sort.clone())), _ => None, }) .collect() @@ -107,12 +107,12 @@ pub fn free_variables_with_sorts(t: Term) -> FxHashSet<(String, Sort)> { /// If this term is a constant field or bit-vector, get the unsigned int value. pub fn as_uint_constant(t: &Term) -> Option<Integer> { - match &t.op() { - Op::Const(Value::BitVector(bv)) => Some(bv.uint().clone()), - Op::Const(Value::Field(f)) => Some(f.i()), - Op::Const(Value::Bool(b)) => Some((*b).into()), + t.as_value_opt().and_then(|v| match v { + Value::BitVector(bv) => Some(bv.uint().clone()), + Value::Field(f) => Some(f.i()), + Value::Bool(b) => Some((*b).into()), _ => None, - } + }) } /// Assert that all variables in the term graph are declared in the metadata. @@ -143,10 +143,10 @@ pub fn parents_map(c: &Computation) -> TermMap<Vec<Term>> { /// The elements in this array (select terms) as a vector. pub fn array_elements(t: &Term) -> Vec<Term> { - if let Sort::Array(key_sort, _, size) = check(t) { - key_sort + if let Sort::Array(a) = check(t) { + a.key .elems_iter() - .take(size) + .take(a.size) .map(|key| term(Op::Select, vec![t.clone(), key])) .collect() } else { diff --git a/src/ir/term/fmt.rs b/src/ir/term/fmt.rs index abe6234a5..385fb1dcc 100644 --- a/src/ir/term/fmt.rs +++ b/src/ir/term/fmt.rs @@ -280,20 +280,20 @@ impl DisplayIr for Sort { Sort::F32 => write!(f, "f32"), Sort::F64 => write!(f, "f64"), Sort::Field(fty) => write!(f, "(mod {})", fty.modulus()), - Sort::Array(k, v, n) => { + Sort::Array(a) => { // we could make our own write macro. write!(f, "(array ")?; - k.ir_fmt(f)?; + a.key.ir_fmt(f)?; write!(f, " ")?; - v.ir_fmt(f)?; - write!(f, " {n})") + a.val.ir_fmt(f)?; + write!(f, " {})", a.size) } - Sort::Map(k, v) => { + Sort::Map(m) => { // we could make our own write macro. write!(f, "(map ")?; - k.ir_fmt(f)?; + m.key.ir_fmt(f)?; write!(f, " ")?; - v.ir_fmt(f)?; + m.val.ir_fmt(f)?; write!(f, ")") } Sort::Tuple(fields) => { @@ -319,7 +319,7 @@ impl DisplayIr for Op { match self { Op::Ite => write!(f, "ite"), Op::Eq => write!(f, "="), - Op::Var(n, _) => write!(f, "{n}"), + Op::Var(v) => write!(f, "{}", v.name), Op::Const(c) => c.ir_fmt(f), Op::BvBinOp(a) => write!(f, "{a}"), Op::BvBinPred(a) => write!(f, "{a}"), @@ -350,22 +350,22 @@ impl DisplayIr for Op { Op::IntNaryOp(a) => write!(f, "{a}"), Op::IntBinPred(a) => write!(f, "{a}"), Op::UbvToPf(a) => write!(f, "(bv2pf {})", a.modulus()), - Op::PfChallenge(n, m) => write!(f, "(challenge {} {})", n, m.modulus()), + Op::PfChallenge(c) => write!(f, "(challenge {} {})", c.name, c.field.modulus()), Op::Witness(n) => write!(f, "(witness {})", n), Op::PfFitsInBits(n) => write!(f, "(pf_fits_in_bits {})", n), Op::Select => write!(f, "select"), Op::Store => write!(f, "store"), Op::CStore => write!(f, "cstore"), - Op::Fill(key_sort, size) => { + Op::Fill(fill) => { write!(f, "(fill ")?; - key_sort.ir_fmt(f)?; - write!(f, " {})", *size) + fill.key_sort.ir_fmt(f)?; + write!(f, " {})", fill.size) } - Op::Array(k, v) => { + Op::Array(a) => { write!(f, "(array ")?; - k.ir_fmt(f)?; + a.key.ir_fmt(f)?; write!(f, " ")?; - v.ir_fmt(f)?; + a.val.ir_fmt(f)?; write!(f, ")") } Op::Tuple => write!(f, "tuple"), @@ -376,9 +376,9 @@ impl DisplayIr for Op { op.ir_fmt(f)?; write!(f, "))") } - Op::Call(name, a, r) => { - let arg_sorts = a.iter().map(|x| x.to_string()).join(" "); - write!(f, "(call {name} ({arg_sorts}) {r})") + Op::Call(c) => { + let arg_sorts = c.arg_sorts.iter().map(|x| x.to_string()).join(" "); + write!(f, "(call {} ({}) {})", c.name, arg_sorts, c.ret_sort) } Op::Rot(i) => write!(f, "(rot {i})"), Op::PfToBoolTrusted => write!(f, "pf2bool_trusted"), @@ -578,13 +578,7 @@ impl DisplayIr for ComputationMetadata { fn fmt_term_with_bindings(t: &Term, f: &mut IrFormatter) -> FmtResult { let close_dft_f = if f.cfg.use_default_field && f.default_field.is_none() { let fields: HashSet<FieldT> = PostOrderIter::new(t.clone()) - .filter_map(|c| { - if let Op::Const(Value::Field(f)) = &c.op() { - Some(f.ty()) - } else { - None - } - }) + .filter_map(|c| c.as_pf_opt().map(|f| f.ty())) .collect(); if fields.len() == 1 && !f.cfg.hide_field { f.default_field = fields.into_iter().next(); @@ -612,9 +606,9 @@ fn fmt_term_with_bindings(t: &Term, f: &mut IrFormatter) -> FmtResult { n_bindings += 1; } } - if let Op::Var(name, sort) = &t.op() { - write!(f, " ({name} ")?; - sort.ir_fmt(f)?; + if let Op::Var(v) = &t.op() { + write!(f, " ({} ", v.name)?; + v.sort.ir_fmt(f)?; writeln!(f, ")")?; } } diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index c1883b415..edf958d4e 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -46,7 +46,7 @@ pub mod text; pub mod ty; pub use bv::BitVector; -pub use eval::{eval, eval_cached, eval_op, pf_challenge}; +pub use eval::{eval, eval_cached, eval_op, eval_pf_challenge}; pub use ext::ExtOp; pub use ty::{check, check_rec, TypeError, TypeErrorReason}; @@ -54,9 +54,9 @@ pub use ty::{check, check_rec, TypeError, TypeErrorReason}; /// An operator pub enum Op { /// a variable - Var(String, Sort), + Var(Box<Var>), /// a constant - Const(Value), + Const(Box<Value>), /// if-then-else: ternary Ite, @@ -76,7 +76,7 @@ pub enum Op { /// Get bits (high) through (low) from the underlying bit-vector. /// /// Zero-indexed and inclusive. - BvExtract(usize, usize), + BvExtract(u32, u32), /// bit-vector concatenation. n-ary. Low-index arguements map to high-order bits BvConcat, /// add this many zero bits @@ -126,14 +126,14 @@ pub enum Op { /// Unsigned bit-vector to prime-field /// /// Takes the modulus. - UbvToPf(FieldT), + UbvToPf(Box<FieldT>), /// A random value, sampled uniformly and independently of its arguments. /// /// Takes a name (if deterministically sampled, challenges of different names are sampled /// differentely) and a field to sample from. /// /// In IR evaluation, we sample deterministically based on a hash of the name. - PfChallenge(String, FieldT), + PfChallenge(Box<ChallengeOp>), /// Requires the input pf element to fit in this many (unsigned) bits. PfFitsInBits(usize), /// Prime-field division @@ -141,7 +141,8 @@ pub enum Op { /// Receive a value from the prover (in a proof) /// The string is a name for it; does not need to be unique. - Witness(String), + /// The double box is to get a thin pointer. + Witness(Box<Box<str>>), /// Integer n-ary operator IntNaryOp(IntNaryOp), @@ -162,9 +163,9 @@ pub enum Op { /// Otherwise, oupputs `array`. CStore, /// Makes an array of the indicated key sort with the indicated size, filled with the argument. - Fill(Sort, usize), + Fill(Box<FillOp>), /// Create an array from (contiguous) values. - Array(Sort, Sort), + Array(Box<ArrayOp>), /// Assemble n things into a tuple Tuple, @@ -177,7 +178,7 @@ pub enum Op { Map(Box<Op>), /// Call a function (name, argument sorts, return sort) - Call(String, Vec<Sort>, Sort), + Call(Box<CallOp>), /// Cyclic right rotation of an array /// i.e. (Rot(1) [1,2,3,4]) --> ([4,1,2,3]) @@ -190,6 +191,53 @@ pub enum Op { ExtOp(ext::ExtOp), } +/// Variable +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct Var { + /// Variable name + pub name: Box<str>, + /// Variable sort + pub sort: Sort, +} + +/// A function call operator +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ChallengeOp { + /// The key sort + pub name: Box<str>, + /// The size + pub field: FieldT, +} + +/// A function call operator +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct FillOp { + /// The key sort + pub key_sort: Sort, + /// The size + pub size: usize, +} + +/// A function call operator +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct CallOp { + /// The function name + pub name: String, + /// Argument sorts + pub arg_sorts: Vec<Sort>, + /// Return sorts + pub ret_sort: Sort, +} + +/// An array creation operator +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ArrayOp { + /// The key sort + pub key: Sort, + /// The value sort + pub val: Sort, +} + /// Boolean AND pub const AND: Op = Op::BoolNaryOp(BoolNaryOp::And); /// Boolean OR @@ -279,7 +327,7 @@ impl Op { match self { Op::Ite => Some(3), Op::Eq => Some(2), - Op::Var(_, _) => Some(0), + Op::Var(_) => Some(0), Op::Const(_) => Some(0), Op::BvBinOp(_) => Some(2), Op::BvBinPred(_) => Some(2), @@ -307,7 +355,7 @@ impl Op { Op::PfUnOp(_) => Some(1), Op::PfDiv => Some(2), Op::PfNaryOp(_) => None, - Op::PfChallenge(_, _) => None, + Op::PfChallenge(_) => None, Op::Witness(_) => Some(1), Op::PfFitsInBits(..) => Some(1), Op::IntNaryOp(_) => None, @@ -322,12 +370,53 @@ impl Op { Op::Field(_) => Some(1), Op::Update(_) => Some(2), Op::Map(op) => op.arity(), - Op::Call(_, args, _) => Some(args.len()), + Op::Call(c) => Some(c.arg_sorts.len()), Op::Rot(_) => Some(1), Op::ExtOp(o) => o.arity(), Op::PfToBoolTrusted => Some(1), } } + + /// Create a new [Op::Fill]. + pub fn new_fill(key_sort: Sort, size: usize) -> Self { + Op::Fill(Box::new(FillOp { key_sort, size })) + } + + /// Create a new [Op::PfChallenge]. + pub fn new_chall(name: String, field: FieldT) -> Self { + Op::PfChallenge(Box::new(ChallengeOp { + name: name.into_boxed_str(), + field, + })) + } + + /// Create a new [Op::Var]. + pub fn new_var(name: String, sort: Sort) -> Self { + Op::Var(Box::new(Var { + name: name.into_boxed_str(), + sort, + })) + } + + /// Create a new [Op::Const]. + pub fn new_const(value: Value) -> Self { + Op::Const(Box::new(value)) + } + + /// Create a new [Op::Const]. + pub fn new_witness(name: String) -> Self { + Op::Witness(Box::new(name.into_boxed_str())) + } + + /// Create a new [Op::Const]. + pub fn new_ubv_to_pf(field: FieldT) -> Self { + Op::UbvToPf(Box::new(field)) + } + + /// Create a new [Op::BvExtract]. + pub fn new_bv_extract(hi: usize, lo: usize) -> Self { + Op::BvExtract(hi as u32, lo as u32) + } } #[derive(Clone, PartialEq, Eq, Hash, Debug, Copy, Serialize, Deserialize)] @@ -709,13 +798,33 @@ pub enum Sort { /// Array from one sort to another, of fixed size. /// /// size presumes an order, and a zero, for the key sort. - Array(Box<Sort>, Box<Sort>, usize), + Array(Box<ArraySort>), /// Map from one sort to another. - Map(Box<Sort>, Box<Sort>), + Map(Box<MapSort>), /// A tuple Tuple(Box<[Sort]>), } +#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +/// Array sort +pub struct ArraySort { + /// key sort + pub key: Sort, + /// value sort + pub val: Sort, + /// size + pub size: usize, +} + +#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +/// Map sort +pub struct MapSort { + /// key sort + pub key: Sort, + /// value sort + pub val: Sort, +} + impl Default for Sort { fn default() -> Self { Self::Bool @@ -733,6 +842,11 @@ impl Sort { } } + /// Is this a bit-vector? + pub fn is_bv(&self) -> bool { + matches!(self, Sort::BitVector(..)) + } + #[track_caller] /// Unwrap the modulus of this prime field, panicking otherwise. pub fn as_pf(&self) -> &FieldT { @@ -743,6 +857,11 @@ impl Sort { } } + /// Is this a prime field? + pub fn is_pf(&self) -> bool { + matches!(self, Sort::Field(..)) + } + #[track_caller] /// Unwrap the constituent sorts of this tuple, panicking otherwise. pub fn as_tuple(&self) -> &[Sort] { @@ -756,13 +875,18 @@ impl Sort { #[track_caller] /// Unwrap the constituent sorts of this array, panicking otherwise. pub fn as_array(&self) -> (&Sort, &Sort, usize) { - if let Sort::Array(k, v, s) = self { - (k, v, *s) + if let Sort::Array(a) = self { + (&a.key, &a.val, a.size) } else { panic!("{} is not an array", self) } } + /// Create a new array sort + pub fn new_array(key: Sort, val: Sort, size: usize) -> Self { + Self::Array(Box::new(ArraySort { key, val, size })) + } + /// Is this an array? pub fn is_array(&self) -> bool { matches!(self, Sort::Array(..)) @@ -771,8 +895,8 @@ impl Sort { #[track_caller] /// Unwrap the constituent sorts of this array, panicking otherwise. pub fn as_map(&self) -> (&Sort, &Sort) { - if let Sort::Map(k, v) = self { - (k, v) + if let Sort::Map(m) = self { + (&m.key, &m.val) } else { panic!("{} is not a map", self) } @@ -783,6 +907,11 @@ impl Sort { matches!(self, Sort::Map(..)) } + /// Create a new map sort + pub fn new_map(key: Sort, val: Sort) -> Self { + Self::Map(Box::new(MapSort { key, val })) + } + /// The nth element of this sort. /// Only defined for booleans, bit-vectors, and field elements. #[track_caller] @@ -808,7 +937,7 @@ impl Sort { /// Only defined for booleans, bit-vectors, and field elements. #[track_caller] pub fn elems_iter(&self) -> Box<dyn Iterator<Item = Term>> { - Box::new(self.elems_iter_values().map(|v| leaf_term(Op::Const(v)))) + Box::new(self.elems_iter_values().map(const_)) } /// An iterator over the elements of this sort (as IR values). @@ -859,7 +988,7 @@ impl Sort { /// * floats: zero /// * tuples/arrays: recursively default pub fn default_term(&self) -> Term { - leaf_term(Op::Const(self.default_value())) + const_(self.default_value()) } /// Compute the default value for this sort. @@ -878,10 +1007,10 @@ impl Sort { Sort::F32 => Value::F32(0.0f32), Sort::F64 => Value::F64(0.0), Sort::Tuple(t) => Value::Tuple(t.iter().map(Sort::default_value).collect()), - Sort::Array(k, v, n) => Value::Array(Array::default((**k).clone(), v, *n)), - Sort::Map(k, v) => Value::Map(map::Map::new( - (**k).clone(), - (**v).clone(), + Sort::Array(a) => Value::Array(Array::default(a.key.clone(), &a.val, a.size)), + Sort::Map(m) => Value::Map(map::Map::new( + m.key.clone(), + m.val.clone(), std::iter::empty(), )), } @@ -989,7 +1118,7 @@ fn collect_types() { impl Term { /// Get the underlying boolean constant, if possible. pub fn as_bool_opt(&self) -> Option<bool> { - if let Op::Const(Value::Bool(b)) = &self.op() { + if let Some(Value::Bool(b)) = self.as_value_opt() { Some(*b) } else { None @@ -997,7 +1126,7 @@ impl Term { } /// Get the underlying bit-vector constant, if possible. pub fn as_bv_opt(&self) -> Option<&BitVector> { - if let Op::Const(Value::BitVector(b)) = &self.op() { + if let Some(Value::BitVector(b)) = self.as_value_opt() { Some(b) } else { None @@ -1005,7 +1134,7 @@ impl Term { } /// Get the underlying prime field constant, if possible. pub fn as_pf_opt(&self) -> Option<&FieldV> { - if let Op::Const(Value::Field(b)) = &self.op() { + if let Some(Value::Field(b)) = self.as_value_opt() { Some(b) } else { None @@ -1014,7 +1143,7 @@ impl Term { /// Get the underlying tuple constant, if possible. pub fn as_tuple_opt(&self) -> Option<&[Value]> { - if let Op::Const(Value::Tuple(t)) = &self.op() { + if let Some(Value::Tuple(t)) = self.as_value_opt() { Some(t) } else { None @@ -1023,7 +1152,7 @@ impl Term { /// Get the underlying array constant, if possible. pub fn as_array_opt(&self) -> Option<&Array> { - if let Op::Const(Value::Array(a)) = &self.op() { + if let Some(Value::Array(a)) = self.as_value_opt() { Some(a) } else { None @@ -1032,7 +1161,7 @@ impl Term { /// Get the underlying map constant, if possible. pub fn as_map_opt(&self) -> Option<&map::Map> { - if let Op::Const(Value::Map(a)) = &self.op() { + if let Some(Value::Map(a)) = self.as_value_opt() { Some(a) } else { None @@ -1061,8 +1190,8 @@ impl Term { /// Get the variable name; panic if not a variable. #[track_caller] pub fn as_var_name(&self) -> &str { - if let Op::Var(n, _) = &self.op() { - n + if let Op::Var(v) = &self.op() { + &v.name } else { panic!("not a variable") } @@ -1084,12 +1213,12 @@ impl Value { default, size, .. - }) => Sort::Array(Box::new(key_sort.clone()), Box::new(default.sort()), *size), + }) => Sort::new_array(key_sort.clone(), default.sort(), *size), Value::Map(map::Map { key_sort, value_sort, .. - }) => Sort::Map(Box::new(key_sort.clone()), Box::new(value_sort.clone())), + }) => Sort::new_map(key_sort.clone(), value_sort.clone()), Value::Tuple(v) => Sort::Tuple(v.iter().map(Value::sort).collect()), } } @@ -1149,6 +1278,12 @@ impl Value { } } + #[track_caller] + /// Unwrap the constituent value of this array, panicking otherwise. + pub fn is_array(&self) -> bool { + matches!(self, Value::Array(_)) + } + #[track_caller] /// Unwrap the constituent value of this map, panicking otherwise. pub fn as_map(&self) -> &map::Map { @@ -1167,6 +1302,7 @@ impl Value { None } } + /// Get the underlying bit-vector constant, if possible. pub fn as_bv_opt(&self) -> Option<&BitVector> { if let Value::BitVector(b) = self { @@ -1195,7 +1331,13 @@ impl Value { /// * a key sort, as all arrays do. This sort must be iterable (i.e., bool, int, bit-vector, or field). /// * a value sort, for the array's default pub fn make_array(key_sort: Sort, value_sort: Sort, i: Vec<Term>) -> Term { - term(Op::Array(key_sort, value_sort), i) + term( + Op::Array(Box::new(ArrayOp { + key: key_sort, + val: value_sort, + })), + i, + ) } /// Make a sequence of terms from an array. @@ -1231,6 +1373,16 @@ pub fn leaf_term(op: Op) -> Term { term(op, Vec::new()) } +/// Make a variable term. +pub fn var(name: String, sort: Sort) -> Term { + leaf_term(Op::new_var(name, sort)) +} + +/// Make a constant term. +pub fn const_(value: Value) -> Term { + leaf_term(Op::new_const(value)) +} + /// Make a term with arguments. #[track_caller] pub fn term(op: Op, cs: Vec<Term>) -> Term { @@ -1243,7 +1395,7 @@ pub fn term(op: Op, cs: Vec<Term>) -> Term { /// Make a prime-field constant term. pub fn pf_lit(elem: FieldV) -> Term { - leaf_term(Op::Const(Value::Field(elem))) + const_(Value::Field(elem)) } /// Make a bit-vector constant term. @@ -1251,15 +1403,12 @@ pub fn bv_lit<T>(uint: T, width: usize) -> Term where Integer: From<T>, { - leaf_term(Op::Const(Value::BitVector(BitVector::new( - uint.into(), - width, - )))) + const_(Value::BitVector(BitVector::new(uint.into(), width))) } /// Make a bit-vector constant term. pub fn bool_lit(b: bool) -> Term { - leaf_term(Op::Const(Value::Bool(b))) + const_(Value::Bool(b)) } #[macro_export] @@ -1394,7 +1543,7 @@ pub struct VariableMetadata { impl VariableMetadata { /// term (cached) pub fn term(&self) -> Term { - leaf_term(Op::Var(self.name.clone(), self.sort.clone())) + var(self.name.clone(), self.sort.clone()) } } @@ -1645,9 +1794,16 @@ impl ComputationMetadata { .iter() .map(|name| args.get(name).expect("Argument not found: {}").clone()) .collect::<Vec<Term>>(); - let ordered_sorts = ordered_args.iter().map(check).collect::<Vec<Sort>>(); - - term(Op::Call(name, ordered_sorts, ret_sort), ordered_args) + let arg_sorts = ordered_args.iter().map(check).collect::<Vec<Sort>>(); + + term( + Op::Call(Box::new(CallOp { + name, + arg_sorts, + ret_sort, + })), + ordered_args, + ) } } @@ -1726,7 +1882,7 @@ impl Computation { assert_eq!(&s, &check(&p), "precompute {} doesn't match sort {}", p, s); self.precomputes.add_output(name.to_owned(), p); } - leaf_term(Op::Var(name.to_owned(), s)) + var(name.to_owned(), s) } /// Create a new variable with the given metadata. @@ -1751,7 +1907,7 @@ impl Computation { assert_eq!(&sort, &check(&p)); self.precomputes.add_output(name.clone(), p); } - leaf_term(Op::Var(name, sort)) + var(name, sort) } /// Add a new input `new_input_var` to this computation, @@ -1789,7 +1945,7 @@ impl Computation { party: PartyId, ) -> Term { let f = Sort::Field(field); - let s = Sort::Array(Box::new(f.clone()), Box::new(f), size); + let s = Sort::new_array(f.clone(), f, size); let md = VariableMetadata { name: var.to_owned(), vis: Some(party), @@ -1880,7 +2036,7 @@ impl Computation { for v in self.metadata.vars.values() { if v.random { let field = v.sort.as_pf(); - let value = Value::Field(eval::pf_challenge(&v.name, field)); + let value = Value::Field(eval::eval_pf_challenge(&v.name, field)); values.insert(v.name.clone(), value); } } diff --git a/src/ir/term/precomp.rs b/src/ir/term/precomp.rs index 5e3be4b82..bbf644b45 100644 --- a/src/ir/term/precomp.rs +++ b/src/ir/term/precomp.rs @@ -63,8 +63,8 @@ impl PreComp { let o_tuple = term(Op::Tuple, os.values().cloned().collect()); let to_remove = &mut TermSet::default(); for t in PostOrderIter::new(o_tuple) { - if let Op::Var(ref name, _) = &t.op() { - if !known.contains(name) { + if let Op::Var(var) = &t.op() { + if !known.contains(&*var.name) { to_remove.insert(t); } } else if t.cs().iter().any(|c| to_remove.contains(c)) { @@ -120,8 +120,8 @@ impl PreComp { fn recompute_inputs(&mut self) { let mut inputs = FxHashSet::default(); for t in PostOrderIter::new(self.tuple()) { - if let Op::Var(name, sort) = &t.op() { - inputs.insert((name.clone(), sort.clone())); + if let Op::Var(var) = &t.op() { + inputs.insert((var.name.to_string(), var.sort.clone())); } } self.inputs = inputs; @@ -145,7 +145,7 @@ impl PreComp { let mut cache: TermMap<Term> = Default::default(); for (name, sort) in &self.sequence { let term = extras::substitute_cache(self.outputs.get(name).unwrap(), &mut cache); - let var_term = leaf_term(Op::Var(name.clone(), sort.clone())); + let var_term = var(name.clone(), sort.clone()); out.insert(name.into(), term.clone()); cache.insert(var_term, term); } @@ -158,7 +158,7 @@ impl PreComp { let mut stack: Vec<Term> = self .outputs .iter() - .map(|(name, t)| leaf_term(Op::Var(name.clone(), check(t)))) + .map(|(name, t)| var(name.clone(), check(t))) .collect(); let mut post_visited: TermSet = Default::default(); let mut pre_visited: TermSet = Default::default(); @@ -169,8 +169,8 @@ impl PreComp { if pre_visited.insert(t.clone()) { // children not yet pushed stack.push(t.clone()); - if let Op::Var(name, _) = t.op() { - if let Some(c) = self.outputs.get(name) { + if let Op::Var(var) = t.op() { + if let Some(c) = self.outputs.get(&*var.name) { if !post_visited.contains(c) { assert!(!pre_visited.contains(c), "loop on {} {}", c.id(), c); stack.push(c.clone()); @@ -186,8 +186,8 @@ impl PreComp { } } else { post_visited.insert(t.clone()); - if let Op::Var(name, _) = t.op() { - order.insert(name.clone(), order.len()); + if let Op::Var(var) = t.op() { + order.insert(var.name.to_string(), order.len()); } } } @@ -210,12 +210,11 @@ impl PreComp { let defined: TermSet = self .sequence .iter() - .map(|(n, s)| leaf_term(Op::Var(n.clone(), s.clone()))) + .map(|(n, s)| var(n.clone(), s.clone())) .collect(); let seen = RefCell::new(TermSet::default()); for (name, sort) in &self.inputs { - seen.borrow_mut() - .insert(leaf_term(Op::Var(name.clone(), sort.clone()))); + seen.borrow_mut().insert(var(name.clone(), sort.clone())); } for (name, sort) in &self.sequence { let t = self.outputs.get(name).unwrap(); @@ -226,8 +225,7 @@ impl PreComp { } seen.borrow_mut().insert(desc); } - seen.borrow_mut() - .insert(leaf_term(Op::Var(name.clone(), sort.clone()))); + seen.borrow_mut().insert(var(name.clone(), sort.clone())); } } @@ -237,7 +235,7 @@ impl PreComp { let mut stack: Vec<Term> = self .outputs .iter() - .map(|(name, t)| leaf_term(Op::Var(name.clone(), check(t)))) + .map(|(name, t)| var(name.clone(), check(t))) .collect(); let mut post_visited: TermSet = Default::default(); let mut pre_visited: TermSet = Default::default(); @@ -248,8 +246,8 @@ impl PreComp { if pre_visited.insert(t.clone()) { // children not yet pushed stack.push(t.clone()); - if let Op::Var(name, _) = t.op() { - if let Some(c) = self.outputs.get(name) { + if let Op::Var(var) = t.op() { + if let Some(c) = self.outputs.get(&*var.name) { if !post_visited.contains(c) { assert!(!pre_visited.contains(c), "loop on {} {}", c.id(), c); stack.push(c.clone()); diff --git a/src/ir/term/test.rs b/src/ir/term/test.rs index 28406096f..ecfb8c246 100644 --- a/src/ir/term/test.rs +++ b/src/ir/term/test.rs @@ -6,9 +6,9 @@ use fxhash::FxHashMap; #[test] fn eq() { - let v = leaf_term(Op::Var("a".to_owned(), Sort::Bool)); - let u = leaf_term(Op::Var("a".to_owned(), Sort::Bool)); - let w = leaf_term(Op::Var("b".to_owned(), Sort::Bool)); + let v = var("a".to_owned(), Sort::Bool); + let u = var("a".to_owned(), Sort::Bool); + let w = var("b".to_owned(), Sort::Bool); assert_eq!(v, u); assert!(v != w); assert!(u != w); @@ -17,17 +17,17 @@ fn eq() { #[test] fn bv2pf() { assert_eq!( - leaf_term(Op::Const(eval( + const_(eval( &text::parse_term(b"(bvshl #b0001 #b0010)"), &FxHashMap::default() - ))), + )), text::parse_term(b" #b0100 ") ); assert_eq!( - leaf_term(Op::Const(eval( + const_(eval( &text::parse_term(b" (set_default_modulus 17 ((pf2bv 4) #f1)) "), &FxHashMap::default() - ))), + )), text::parse_term(b" #b0001 ") ); } @@ -180,22 +180,22 @@ mod type_ { use super::*; fn t() -> Term { - let v = leaf_term(Op::Var("b".to_owned(), Sort::BitVector(4))); + let v = var("b".to_owned(), Sort::BitVector(4)); term![ Op::BvBit(4); term![ Op::BvConcat; v, - term![Op::BoolToBv; leaf_term(Op::Var("c".to_owned(), Sort::Bool))] + term![Op::BoolToBv; var("c".to_owned(), Sort::Bool)] ] ] } #[test] fn vars() { - let v = leaf_term(Op::Var("a".to_owned(), Sort::Bool)); + let v = var("a".to_owned(), Sort::Bool); assert_eq!(check(&v), Sort::Bool); - let v = leaf_term(Op::Var("b".to_owned(), Sort::BitVector(4))); + let v = var("b".to_owned(), Sort::BitVector(4)); assert_eq!(check(&v), Sort::BitVector(4)); let v = t(); assert_eq!(check(&v), Sort::Bool); @@ -206,9 +206,9 @@ mod type_ { let tt = t(); assert_eq!( vec![ - Op::Var("c".to_owned(), Sort::Bool), + Op::new_var("c".to_owned(), Sort::Bool), Op::BoolToBv, - Op::Var("b".to_owned(), Sort::BitVector(4)), + Op::new_var("b".to_owned(), Sort::BitVector(4)), Op::BvConcat, Op::BvBit(4), ], @@ -220,7 +220,7 @@ mod type_ { } fn bool(b: bool) -> Term { - leaf_term(Op::Const(Value::Bool(b))) + bool_lit(b) } pub fn bool_and_tests() -> Vec<Term> { diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index 44db7ac02..af62e4400 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -307,7 +307,7 @@ impl<'src> IrInterp<'src> { } } List(tts) => match &tts[..] { - [Leaf(Ident, b"extract"), a, b] => Ok(Op::BvExtract(self.usize(a), self.usize(b))), + [Leaf(Ident, b"extract"), a, b] => Ok(Op::BvExtract(self.u32(a), self.u32(b))), [Leaf(Ident, b"uext"), a] => Ok(Op::BvUext(self.usize(a))), [Leaf(Ident, b"sext"), a] => Ok(Op::BvSext(self.usize(a))), [Leaf(Ident, b"pf2bv"), a] => Ok(Op::PfToBv(self.usize(a))), @@ -316,22 +316,29 @@ impl<'src> IrInterp<'src> { [Leaf(Ident, b"ubv2fp"), a] => Ok(Op::UbvToFp(self.usize(a))), [Leaf(Ident, b"sbv2fp"), a] => Ok(Op::SbvToFp(self.usize(a))), [Leaf(Ident, b"fp2fp"), a] => Ok(Op::FpToFp(self.usize(a))), - [Leaf(Ident, b"challenge"), name, field] => Ok(Op::PfChallenge( + [Leaf(Ident, b"challenge"), name, field] => Ok(Op::new_chall( self.ident_string(name), FieldT::from(self.int(field)), )), - [Leaf(Ident, b"array"), k, v] => Ok(Op::Array(self.sort(k), self.sort(v))), - [Leaf(Ident, b"bv2pf"), a] => Ok(Op::UbvToPf(FieldT::from(self.int(a)))), + [Leaf(Ident, b"array"), k, v] => Ok(Op::Array(Box::new(ArrayOp { + key: self.sort(k), + val: self.sort(v), + }))), + [Leaf(Ident, b"bv2pf"), a] => Ok(Op::new_ubv_to_pf(FieldT::from(self.int(a)))), [Leaf(Ident, b"field"), a] => Ok(Op::Field(self.usize(a))), [Leaf(Ident, b"update"), a] => Ok(Op::Update(self.usize(a))), [Leaf(Ident, b"call"), Leaf(Ident, name), arg_sorts, ret_sort] => { let name = from_utf8(name).unwrap().to_owned(); let arg_sorts = self.sorts(arg_sorts); let ret_sort = self.sort(ret_sort); - Ok(Op::Call(name, arg_sorts, ret_sort)) + Ok(Op::Call(Box::new(CallOp { + name, + arg_sorts, + ret_sort, + }))) } [Leaf(Ident, b"fill"), key_sort, size] => { - Ok(Op::Fill(self.sort(key_sort), self.usize(size))) + Ok(Op::new_fill(self.sort(key_sort), self.usize(size))) } _ => todo!("Unparsed op: {}", tt), }, @@ -341,7 +348,7 @@ impl<'src> IrInterp<'src> { fn value(&mut self, tt: &TokTree<'src>) -> Value { let t = self.term(tt); match &t.op() { - Op::Const(v) => v.clone(), + Op::Const(v) => (**v).clone(), _ => panic!("Expected value, found term {}", t), } } @@ -357,14 +364,10 @@ impl<'src> IrInterp<'src> { match &ls[..] { [Leaf(Ident, b"mod"), m] => Sort::Field(FieldT::from(self.int(m))), [Leaf(Ident, b"bv"), w] => Sort::BitVector(self.usize(w)), - [Leaf(Ident, b"array"), k, v, s] => Sort::Array( - Box::new(self.sort(k)), - Box::new(self.sort(v)), - self.usize(s), - ), - [Leaf(Ident, b"map"), k, v] => { - Sort::Map(Box::new(self.sort(k)), Box::new(self.sort(v))) + [Leaf(Ident, b"array"), k, v, s] => { + Sort::new_array(self.sort(k), self.sort(v), self.usize(s)) } + [Leaf(Ident, b"map"), k, v] => Sort::new_map(self.sort(k), self.sort(v)), [Leaf(Ident, b"tuple"), ..] => { if ls.len() > 1 { if let Some(size) = self.maybe_usize(&ls[1]) { @@ -419,6 +422,15 @@ impl<'src> IrInterp<'src> { _ => None, } } + fn u32(&self, tt: &TokTree) -> u32 { + self.maybe_u32(tt).unwrap() + } + fn maybe_u32(&self, tt: &TokTree) -> Option<u32> { + match tt { + Leaf(Token::Int, s) => u32::from_str(from_utf8(s).ok()?).ok(), + _ => None, + } + } /// Parse lets, returning bindings, in-order. fn let_list(&mut self, tt: &TokTree<'src>) -> Vec<Vec<u8>> { if let List(tts) = tt { @@ -470,7 +482,7 @@ impl<'src> IrInterp<'src> { List(ls) => match &ls[..] { [Leaf(Token::Ident, name), s] => { let sort = self.sort(s); - let t = leaf_term(Op::Var(from_utf8(name).unwrap().to_owned(), sort)); + let t = var(from_utf8(name).unwrap().to_owned(), sort); self.bind(name, t); name.to_vec() } @@ -486,13 +498,9 @@ impl<'src> IrInterp<'src> { fn term(&mut self, tt: &TokTree<'src>) -> Term { use Token::*; match tt { - Leaf(Bin, s) => leaf_term(Op::Const(Value::BitVector( - BitVector::from_bin_lit(s).unwrap(), - ))), - Leaf(Hex, s) => leaf_term(Op::Const(Value::BitVector( - BitVector::from_hex_lit(s).unwrap(), - ))), - Leaf(Int, s) => leaf_term(Op::Const(Value::Int(Integer::parse(s).unwrap().into()))), + Leaf(Bin, s) => const_(Value::BitVector(BitVector::from_bin_lit(s).unwrap())), + Leaf(Hex, s) => const_(Value::BitVector(BitVector::from_hex_lit(s).unwrap())), + Leaf(Int, s) => const_(Value::Int(Integer::parse(s).unwrap().into())), Leaf(Field, s) => { let (v, m) = if let Some(i) = s.iter().position(|b| *b == b'm') { ( @@ -509,7 +517,7 @@ impl<'src> IrInterp<'src> { .clone(); (Integer::parse(&s[2..]).unwrap().into(), m) }; - leaf_term(Op::Const(Value::Field(FieldV::new::<Integer>(v, m)))) + pf_lit(FieldV::new::<Integer>(v, m)) } Leaf(Ident, b"false") => bool_lit(false), Leaf(Ident, b"true") => bool_lit(true), @@ -545,39 +553,37 @@ impl<'src> IrInterp<'src> { let default = self.value(&tts[2]); let size = self.usize(&tts[3]); let vals = self.value_alist(&tts[4]); - leaf_term(Op::Const(Value::Array(Array::new( + const_(Value::Array(Array::new( key_sort, Box::new(default), vals.into_iter().collect(), size, - )))) + ))) } Err(CtrlOp::MapValue) => { assert_eq!(tts.len(), 4); let key_sort = self.sort(&tts[1]); let value_sort = self.sort(&tts[2]); let vals = self.value_alist(&tts[3]); - leaf_term(Op::Const(Value::Map(map::Map::new( - key_sort, value_sort, vals, - )))) + const_(Value::Map(map::Map::new(key_sort, value_sort, vals))) } Err(CtrlOp::ListValue) => { assert_eq!(tts.len(), 3); let key_sort = self.sort(&tts[1]); let vals = self.value_list(&tts[2]); - leaf_term(Op::Const(Value::Array(Array::from_vec( + const_(Value::Array(Array::from_vec( key_sort, vals.first().unwrap().sort(), vals, - )))) + ))) } - Err(CtrlOp::TupleValue) => leaf_term(Op::Const(Value::Tuple( + Err(CtrlOp::TupleValue) => const_(Value::Tuple( tts[1..] .iter() .map(|tti| self.value(tti)) .collect::<Vec<_>>() .into(), - ))), + )), Err(CtrlOp::SetDefaultModulus) => { assert_eq!( tts.len(), @@ -900,7 +906,7 @@ pub fn parse_value_map(src: &[u8]) -> HashMap<String, Value> { .map(|(name, term)| { let name = std::str::from_utf8(name).unwrap().to_string(); let val = match term[0].op() { - Op::Const(v) => v.clone(), + Op::Const(v) => (**v).clone(), _ => panic!("Non-value binding {} associated with {}", term[0], name), }; (name, val) diff --git a/src/ir/term/ty.rs b/src/ir/term/ty.rs index 1ede58cb7..ce4bd8a9e 100644 --- a/src/ir/term/ty.rs +++ b/src/ir/term/ty.rs @@ -30,7 +30,7 @@ fn check_dependencies(t: &Term) -> Vec<Term> { match &t.op() { Op::Ite => vec![t.cs()[1].clone()], Op::Eq => Vec::new(), - Op::Var(_, _) => Vec::new(), + Op::Var(_) => Vec::new(), Op::Const(_) => Vec::new(), Op::BvBinOp(_) => vec![t.cs()[0].clone()], Op::BvBinPred(_) => Vec::new(), @@ -61,7 +61,7 @@ fn check_dependencies(t: &Term) -> Vec<Term> { Op::IntNaryOp(_) => Vec::new(), Op::IntBinPred(_) => Vec::new(), Op::UbvToPf(_) => Vec::new(), - Op::PfChallenge(_, _) => Vec::new(), + Op::PfChallenge(_) => Vec::new(), Op::Witness(_) => vec![t.cs()[0].clone()], Op::PfFitsInBits(_) => Vec::new(), Op::Select => vec![t.cs()[0].clone()], @@ -73,7 +73,7 @@ fn check_dependencies(t: &Term) -> Vec<Term> { Op::Field(_) => vec![t.cs()[0].clone()], Op::Update(_i) => vec![t.cs()[0].clone()], Op::Map(_) => t.cs().to_vec(), - Op::Call(_, _, _) => Vec::new(), + Op::Call(_) => Vec::new(), Op::Rot(_) => vec![t.cs()[0].clone()], Op::PfToBoolTrusted => Vec::new(), Op::ExtOp(o) => o.check_dependencies(t), @@ -87,14 +87,14 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> { match &t.op() { Op::Ite => Ok(get_ty(&t.cs()[1]).clone()), Op::Eq => Ok(Sort::Bool), - Op::Var(_, s) => Ok(s.clone()), + Op::Var(v) => Ok(v.sort.clone()), Op::Const(c) => Ok(c.sort()), Op::BvBinOp(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::BvBinPred(_) => Ok(Sort::Bool), Op::BvNaryOp(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::BvUnOp(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::BoolToBv => Ok(Sort::BitVector(1)), - Op::BvExtract(a, b) => Ok(Sort::BitVector(a - b + 1)), + Op::BvExtract(a, b) => Ok(Sort::BitVector(*a as usize - *b as usize + 1)), Op::BvConcat => t .cs() .iter() @@ -138,22 +138,18 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> { Op::PfNaryOp(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::IntNaryOp(_) => Ok(Sort::Int), Op::IntBinPred(_) => Ok(Sort::Bool), - Op::UbvToPf(m) => Ok(Sort::Field(m.clone())), - Op::PfChallenge(_, m) => Ok(Sort::Field(m.clone())), + Op::UbvToPf(m) => Ok(Sort::Field((**m).clone())), + Op::PfChallenge(c) => Ok(Sort::Field(c.field.clone())), Op::Witness(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::PfFitsInBits(_) => Ok(Sort::Bool), Op::Select => array_or(get_ty(&t.cs()[0]), "select").map(|(_, v, _)| v.clone()), Op::Store => Ok(get_ty(&t.cs()[0]).clone()), - Op::Array(k, v) => Ok(Sort::Array( - Box::new(k.clone()), - Box::new(v.clone()), - t.cs().len(), - )), + Op::Array(a) => Ok(Sort::new_array(a.key.clone(), a.val.clone(), t.cs().len())), Op::CStore => Ok(get_ty(&t.cs()[0]).clone()), - Op::Fill(key_sort, size) => Ok(Sort::Array( - Box::new(key_sort.clone()), - Box::new(get_ty(&t.cs()[0]).clone()), - *size, + Op::Fill(fill) => Ok(Sort::new_array( + fill.key_sort.clone(), + get_ty(&t.cs()[0]).clone(), + fill.size, )), Op::Tuple => Ok(Sort::Tuple(t.cs().iter().map(get_ty).cloned().collect())), Op::Field(i) => { @@ -200,11 +196,11 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> { Some(e) => Err(e), None => { let value_sort = rec_check_raw_helper(op, &arg_sorts_to_inner_op)?; - Ok(Sort::Array(Box::new(key_sort), Box::new(value_sort), size)) + Ok(Sort::new_array(key_sort, value_sort, size)) } } } - Op::Call(_, _, ret) => Ok(ret.clone()), + Op::Call(c) => Ok(c.ret_sort.clone()), Op::Rot(_) => Ok(get_ty(&t.cs()[0]).clone()), Op::PfToBoolTrusted => Ok(Sort::Bool), Op::ExtOp(o) => { @@ -263,7 +259,7 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea match (oper, a) { (Op::Eq, &[a, b]) => eq_or(a, b, "=").map(|_| Sort::Bool), (Op::Ite, &[&Sort::Bool, b, c]) => eq_or(b, c, "ITE").map(|_| b.clone()), - (Op::Var(_, s), &[]) => Ok(s.clone()), + (Op::Var(v), &[]) => Ok(v.sort.clone()), (Op::Const(c), &[]) => Ok(c.sort()), (Op::BvBinOp(_), &[a, b]) => { let ctx = "bv binary op"; @@ -286,8 +282,8 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea (Op::BvUnOp(_), &[a]) => bv_or(a, "bv unary op").cloned(), (Op::BoolToBv, &[Sort::Bool]) => Ok(Sort::BitVector(1)), (Op::BvExtract(high, low), &[Sort::BitVector(w)]) => { - if low <= high && high < w { - Ok(Sort::BitVector(high - low + 1)) + if low <= high && *high < *w as u32 { + Ok(Sort::BitVector((high - low + 1) as usize)) } else { Err(TypeErrorReason::OutOfBounds(format!( "Cannot slice from {high} to {low} in a bit-vector of width {w}" @@ -360,8 +356,8 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea .and_then(|t| pf_or(t, ctx)) .cloned() } - (Op::UbvToPf(m), &[a]) => bv_or(a, "ubv-to-pf").map(|_| Sort::Field(m.clone())), - (Op::PfChallenge(_, m), _) => Ok(Sort::Field(m.clone())), + (Op::UbvToPf(m), &[a]) => bv_or(a, "ubv-to-pf").map(|_| Sort::Field((**m).clone())), + (Op::PfChallenge(f), _) => Ok(Sort::Field(f.field.clone())), (Op::Witness(_), &[a]) => Ok(a.clone()), (Op::PfFitsInBits(_), &[a]) => pf_or(a, "pf fits in bits").map(|_| Sort::Bool), (Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").cloned(), @@ -375,24 +371,22 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea (Op::IntBinPred(_), &[a, b]) => int_or(a, "int bin pred") .and_then(|_| int_or(b, "int bin pred")) .map(|_| Sort::Bool), - (Op::Select, &[Sort::Array(k, v, _), a]) => eq_or(k, a, "select").map(|_| (**v).clone()), - (Op::Store, &[Sort::Array(k, v, n), a, b]) => eq_or(k, a, "store") - .and_then(|_| eq_or(v, b, "store")) - .map(|_| Sort::Array(k.clone(), v.clone(), *n)), - (Op::CStore, &[Sort::Array(k, v, n), a, b, c]) => eq_or(k, a, "cstore") - .and_then(|_| eq_or(v, b, "cstore")) + (Op::Select, &[Sort::Array(arr), a]) => { + eq_or(&arr.key, a, "select").map(|_| arr.val.clone()) + } + (Op::Store, &[s @ Sort::Array(arr), a, b]) => eq_or(&arr.key, a, "store") + .and_then(|_| eq_or(&arr.val, b, "store")) + .map(|_| s.clone()), + (Op::CStore, &[s @ Sort::Array(arr), a, b, c]) => eq_or(&arr.key, a, "cstore") + .and_then(|_| eq_or(&arr.val, b, "cstore")) .and_then(|_| bool_or(c, "cstore")) - .map(|_| Sort::Array(k.clone(), v.clone(), *n)), - (Op::Fill(key_sort, size), &[v]) => Ok(Sort::Array( - Box::new(key_sort.clone()), - Box::new(v.clone()), - *size, - )), - (Op::Array(k, v), a) => { + .map(|_| s.clone()), + (Op::Fill(f), &[v]) => Ok(Sort::new_array(f.key_sort.clone(), v.clone(), f.size)), + (Op::Array(arr), a) => { let ctx = "array op"; a.iter() - .try_fold((), |(), ai| eq_or(v, ai, ctx).map(|_| ())) - .map(|_| Sort::Array(Box::new(k.clone()), Box::new(v.clone()), a.len())) + .try_fold((), |(), ai| eq_or(&arr.val, ai, ctx).map(|_| ())) + .map(|_| Sort::new_array(arr.key.clone(), arr.val.clone(), a.len())) } (Op::Tuple, a) => Ok(Sort::Tuple(a.iter().map(|a| (*a).clone()).collect())), (Op::Field(i), &[a]) => tuple_or(a, "tuple field access").and_then(|t| { @@ -420,25 +414,29 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea // recursively call helper to get value type of mapped array // then return Ok(...) - let (key_sort, size) = match a[0].clone() { - Sort::Array(k, _, s) => (*k, s), - s => return Err(TypeErrorReason::ExpectedArray(s, "map")), + let (key_sort, size) = match a[0] { + Sort::Array(arr) => (&arr.key, arr.size), + s => return Err(TypeErrorReason::ExpectedArray(s.clone(), "map")), }; let mut val_sorts = Vec::new(); for a_i in a { match (*a_i).clone() { - Sort::Array(k, v, s) => { - if *k != key_sort { - return Err(TypeErrorReason::NotEqual(*k, key_sort, "map: key sorts")); + Sort::Array(arr) => { + if &arr.key != key_sort { + return Err(TypeErrorReason::NotEqual( + arr.key.clone(), + key_sort.clone(), + "map: key sorts", + )); } - if s != size { + if arr.size != size { return Err(TypeErrorReason::Custom( "map: array lengths unequal".to_string(), )); } - val_sorts.push((*v).clone()); + val_sorts.push(arr.val.clone()); } s => return Err(TypeErrorReason::ExpectedArray(s, "map")), }; @@ -449,21 +447,24 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea new_a.push(ptr); } rec_check_raw_helper(&op.clone(), &new_a[..]) - .map(|val_sort| Sort::Array(Box::new(key_sort), Box::new(val_sort), size)) + .map(|val_sort| Sort::new_array(key_sort.clone(), val_sort.clone(), size)) } - (Op::Call(_, ex_args, ret), act_args) => { - if ex_args.len() != act_args.len() { - Err(TypeErrorReason::ExpectedArgs(ex_args.len(), act_args.len())) + (Op::Call(c), act_args) => { + if c.arg_sorts.len() != act_args.len() { + Err(TypeErrorReason::ExpectedArgs( + c.arg_sorts.len(), + act_args.len(), + )) } else { - for (e, a) in ex_args.iter().zip(act_args) { + for (e, a) in c.arg_sorts.iter().zip(act_args) { eq_or(e, a, "in function call")?; } - Ok(ret.clone()) + Ok(c.ret_sort.clone()) } } - (Op::Rot(_), &[Sort::Array(k, v, n)]) => bv_or(k, "rot key") - .and_then(|_| bv_or(v, "rot val")) - .map(|_| Sort::Array(k.clone(), v.clone(), *n)), + (Op::Rot(_), &[s @ Sort::Array(a)]) => bv_or(&a.key, "rot key") + .and_then(|_| bv_or(&a.val, "rot val")) + .map(|_| s.clone()), (Op::PfToBoolTrusted, &[k]) => pf_or(k, "pf to bool argument").map(|_| Sort::Bool), (Op::ExtOp(o), _) => o.check(a), (_, _) => Err(TypeErrorReason::Custom("other".to_string())), @@ -580,8 +581,8 @@ pub(super) fn array_or<'a>( a: &'a Sort, ctx: &'static str, ) -> Result<(&'a Sort, &'a Sort, usize), TypeErrorReason> { - if let Sort::Array(k, v, size) = a { - Ok((k, v, *size)) + if let Sort::Array(arr) = a { + Ok((&arr.key, &arr.val, arr.size)) } else { Err(TypeErrorReason::ExpectedArray(a.clone(), ctx)) } @@ -591,8 +592,8 @@ pub(super) fn map_or<'a>( a: &'a Sort, ctx: &'static str, ) -> Result<(&'a Sort, &'a Sort), TypeErrorReason> { - if let Sort::Map(k, v) = a { - Ok((k, v)) + if let Sort::Map(m) = a { + Ok((&m.key, &m.val)) } else { Err(TypeErrorReason::ExpectedMap(a.clone(), ctx)) } @@ -602,8 +603,8 @@ fn arrmap_or<'a>( a: &'a Sort, ctx: &'static str, ) -> Result<(&'a Sort, &'a Sort, &'a usize), TypeErrorReason> { - if let Sort::Array(k, v, s) = a { - Ok((k, v, s)) + if let Sort::Array(arr) = a { + Ok((&arr.key, &arr.val, &arr.size)) } else { Err(TypeErrorReason::ExpectedArray(a.clone(), ctx)) } diff --git a/src/target/aby/assignment/ilp.rs b/src/target/aby/assignment/ilp.rs index ec79167a6..d71ec7657 100644 --- a/src/target/aby/assignment/ilp.rs +++ b/src/target/aby/assignment/ilp.rs @@ -223,13 +223,23 @@ mod tests { var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR") ); let costs = CostModel::from_opa_cost_file(&p); - let cs = Computation { - outputs: vec![term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - leaf_term(Op::Var("b".to_owned(), Sort::BitVector(32))) - ]], - ..Default::default() - }; + let cs = text::parse_computation( + b" + (computation + (metadata (parties) + (inputs + (a (bv 32)) + (b (bv 32)) + ) + (commitments) + ) + (precompute () () (#t )) + (declare ((a (bv 32)) (b (bv 32))) + (bvmul a b) + ) + ) + ", + ); let _assignment = build_ilp(&cs, &costs); } @@ -240,34 +250,31 @@ mod tests { var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR") ); let costs = CostModel::from_opa_cost_file(&p); - let cs = Computation { - outputs: vec![term![Op::Eq; - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))) - ] - ] - ] - ] - ] - ] - ], - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))) - ]], - ..Default::default() - }; + let cs = text::parse_computation( + b" + (computation + (metadata (parties) + (inputs + (a (bv 32)) + ) + (commitments) + ) + (precompute () () (#t )) + (declare ((a (bv 32))) + (= + (bvmul a + (bvmul a + (bvmul a + (bvmul a + (bvmul a + (bvmul a + (bvmul a a))))))) + a + ) + ) + ) + ", + ); let assignment = build_ilp(&cs, &costs); // Big enough to do the math with arith assert_eq!( @@ -285,25 +292,28 @@ mod tests { var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR") ); let costs = CostModel::from_opa_cost_file(&p); - let cs = Computation { - outputs: vec![term![Op::Eq; - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))) - ] - ] - ] - ], - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))) - ]], - ..Default::default() - }; + let cs = text::parse_computation( + b" + (computation + (metadata (parties) + (inputs + (a (bv 32)) + ) + (commitments) + ) + (precompute () () (#t )) + (declare ((a (bv 32))) + (= + (bvmul a + (bvmul a + (bvmul a + (bvmul a a)))) + a + ) + ) + ) + ", + ); let assignment = build_ilp(&cs, &costs); // All yao assert_eq!( diff --git a/src/target/aby/trans.rs b/src/target/aby/trans.rs index 660db14c8..fcc585007 100644 --- a/src/target/aby/trans.rs +++ b/src/target/aby/trans.rs @@ -178,14 +178,14 @@ impl<'a> ToABY<'a> { .unwrap(); let op = "CONS"; - match &t.op() { - Op::Const(Value::BitVector(b)) => { + match t.as_value_opt() { + Some(Value::BitVector(b)) => { let value = b.as_sint(); let bitlen = 32; let line = format!("2 1 {value} {bitlen} {output_share} {op}\n"); self.const_output.push(line); } - Op::Const(Value::Bool(b)) => { + Some(Value::Bool(b)) => { let value = *b as i32; let bitlen = 1; let line = format!("2 1 {value} {bitlen} {output_share} {op}\n"); @@ -239,7 +239,7 @@ impl<'a> ToABY<'a> { len += match s { Sort::Bool => 1, Sort::BitVector(_) => 1, - Sort::Array(_, _, n) => *n, + Sort::Array(a) => a.size, Sort::Tuple(sorts) => { let mut inner_len = 0; for inner_s in sorts.iter() { @@ -280,19 +280,19 @@ impl<'a> ToABY<'a> { fn embed_bool(&mut self, t: Term) { let to_share_type = self.get_term_share_type(&t); match &t.op() { - Op::Var(name, Sort::Bool) => { + Op::Var(v) if matches!(&v.sort, Sort::Bool) => { let md = self.get_md(); - if !self.inputs.contains(&t) && md.is_input(name) { - let vis = self.unwrap_vis(name); + if !self.inputs.contains(&t) && md.is_input(&v.name) { + let vis = self.unwrap_vis(&v.name); let s = self.get_share(&t, to_share_type); let op = "IN"; if vis == PUBLIC { let bitlen = 1; - let line = format!("3 1 {name} {vis} {bitlen} {s} {op}\n"); + let line = format!("3 1 {} {vis} {bitlen} {s} {op}\n", v.name); self.bytecode_input.push(line); } else { - let line = format!("2 1 {name} {vis} {s} {op}\n"); + let line = format!("2 1 {} {vis} {s} {op}\n", v.name); self.bytecode_input.push(line); } self.inputs.push(t.clone()); @@ -407,25 +407,26 @@ impl<'a> ToABY<'a> { fn embed_bv(&mut self, t: Term) { let to_share_type = self.get_term_share_type(&t); match &t.op() { - Op::Var(name, Sort::BitVector(_)) => { + Op::Var(v) if matches!(&v.sort, Sort::BitVector(_)) => { let md = self.get_md(); - if !self.inputs.contains(&t) && md.is_input(name) { - let vis = self.unwrap_vis(name); + if !self.inputs.contains(&t) && md.is_input(&v.name) { + let vis = self.unwrap_vis(&v.name); let s = self.get_share(&t, to_share_type); let op = "IN"; if vis == PUBLIC { let bitlen = 32; - let line = format!("3 1 {name} {vis} {bitlen} {s} {op}\n"); + let line = format!("3 1 {} {vis} {bitlen} {s} {op}\n", v.name); self.bytecode_input.push(line); } else { - let line = format!("2 1 {name} {vis} {s} {op}\n"); + let line = format!("2 1 {} {vis} {s} {op}\n", v.name); self.bytecode_input.push(line); } self.inputs.push(t.clone()); } } - Op::Const(Value::BitVector(_)) => { + Op::Const(_) => { + assert!(t.as_bv_opt().is_some()); // create all three shares self.insert_const(&t); } @@ -537,7 +538,7 @@ impl<'a> ToABY<'a> { let select_share = self.get_share(&t, to_share_type); let array_share = self.get_share(&t.cs()[0], to_share_type); - let line = if let Op::Const(Value::BitVector(bv)) = &t.cs()[1].op() { + let line = if let Some(Value::BitVector(bv)) = t.cs()[1].as_value_opt() { let op = "SELECT_CONS"; let idx = bv.uint().to_usize().unwrap(); let len = self.get_sort_len(&check(&t.cs()[0])); @@ -557,51 +558,22 @@ impl<'a> ToABY<'a> { fn embed_vector(&mut self, t: Term) { let to_share_type = self.get_term_share_type(&t); - match &t.op() { - Op::Const(Value::Array(arr)) => { - let array_share = self.get_share(&t, to_share_type); - let mut shares: Vec<i32> = Vec::new(); - for i in 0..arr.size { - // TODO: sort of index might not be a 32-bit bitvector - let idx = Value::BitVector(BitVector::new(Integer::from(i), 32)); - let v = match arr.map.get(&idx) { - Some(c) => c, - None => &*arr.default, - }; - - // TODO: sort of value might not be a 32-bit bitvector - let v_term = leaf_term(Op::Const(v.clone())); - if self.const_cache.contains_key(&v_term) { - // existing const - let s = self.get_share(&v_term, to_share_type); - shares.push(s); - } else { - // new const - self.insert_const(&v_term); - let s = self.get_share(&v_term, to_share_type); - shares.push(s); - } - } - assert!(shares.len() == arr.size); - - let op = "CONS_ARRAY"; - let line = format!( - "{} 1 {} {} {}\n", - arr.size, - self.shares_to_string(shares), - array_share, - op - ); - self.const_output.push(line); - self.term_to_shares.insert(t.clone(), array_share); - } - Op::Const(Value::Tuple(tup)) => { - let tuple_share = self.get_share(&t, to_share_type); - let mut shares: Vec<i32> = Vec::new(); - for val in tup.iter() { - match val { - Value::BitVector(b) => { - let v_term: Term = bv_lit(b.as_sint(), 32); + match t.op() { + Op::Const(v) => { + match &**v { + Value::Array(arr) => { + let array_share = self.get_share(&t, to_share_type); + let mut shares: Vec<i32> = Vec::new(); + for i in 0..arr.size { + // TODO: sort of index might not be a 32-bit bitvector + let idx = Value::BitVector(BitVector::new(Integer::from(i), 32)); + let v = match arr.map.get(&idx) { + Some(c) => c, + None => &*arr.default, + }; + + // TODO: sort of value might not be a 32-bit bitvector + let v_term = const_(v.clone()); if self.const_cache.contains_key(&v_term) { // existing const let s = self.get_share(&v_term, to_share_type); @@ -613,21 +585,55 @@ impl<'a> ToABY<'a> { shares.push(s); } } - _ => todo!(), + assert!(shares.len() == arr.size); + + let op = "CONS_ARRAY"; + let line = format!( + "{} 1 {} {} {}\n", + arr.size, + self.shares_to_string(shares), + array_share, + op + ); + self.const_output.push(line); + self.term_to_shares.insert(t.clone(), array_share); } + Value::Tuple(tup) => { + let tuple_share = self.get_share(&t, to_share_type); + let mut shares: Vec<i32> = Vec::new(); + for val in tup.iter() { + match val { + Value::BitVector(b) => { + let v_term: Term = bv_lit(b.as_sint(), 32); + if self.const_cache.contains_key(&v_term) { + // existing const + let s = self.get_share(&v_term, to_share_type); + shares.push(s); + } else { + // new const + self.insert_const(&v_term); + let s = self.get_share(&v_term, to_share_type); + shares.push(s); + } + } + _ => todo!(), + } + } + assert!(shares.len() == tup.len()); + + let op = "CONS_TUPLE"; + let line = format!( + "{} 1 {} {} {}\n", + tup.len(), + self.shares_to_string(shares.clone()), + tuple_share, + op + ); + self.const_output.push(line); + self.term_to_shares.insert(t.clone(), tuple_share); + } + _ => unimplemented!("{}", t.op()), } - assert!(shares.len() == tup.len()); - - let op = "CONS_TUPLE"; - let line = format!( - "{} 1 {} {} {}\n", - tup.len(), - self.shares_to_string(shares.clone()), - tuple_share, - op - ); - self.const_output.push(line); - self.term_to_shares.insert(t.clone(), tuple_share); } Op::Ite => { let op = "MUX"; @@ -648,7 +654,7 @@ impl<'a> ToABY<'a> { let value_share = self.get_share(&t.cs()[2], to_share_type); let store_share = self.get_share(&t, to_share_type); - let line = if let Op::Const(Value::BitVector(bv)) = &t.cs()[1].op() { + let line = if let Some(Value::BitVector(bv)) = t.cs()[1].as_value_opt() { let op = "STORE_CONS"; let idx = bv.uint().to_usize().unwrap(); let len = self.get_sort_len(&check(&t.cs()[0])); @@ -725,9 +731,9 @@ impl<'a> ToABY<'a> { self.bytecode_output.push(line); self.term_to_shares.insert(t.clone(), tuple_share); } - Op::Call(name, ..) => { + Op::Call(call) => { let call_share = self.get_share(&t, to_share_type); - let op = format!("CALL({name})"); + let op = format!("CALL({})", call.name); let mut arg_shares: Vec<i32> = Vec::new(); for c in t.cs().iter() { diff --git a/src/target/ilp/trans.rs b/src/target/ilp/trans.rs index fbae69b9e..64092e7e3 100644 --- a/src/target/ilp/trans.rs +++ b/src/target/ilp/trans.rs @@ -193,8 +193,8 @@ impl ToMilp { debug_assert!(check(&c) == Sort::Bool); if !self.cache.contains_key(&c) { let lc = match &c.op() { - Op::Var(name, Sort::Bool) => self.bit(name.to_string()), - Op::Const(Value::Bool(b)) => Expression::from(*b as i32), + Op::Var(v) => self.bit(v.name.to_string()), + Op::Const(b) => Expression::from(b.as_bool() as i32), Op::Eq => self.embed_eq(&c.cs()[0], &c.cs()[1]), Op::Ite => { let a = self.get_bool(&c.cs()[0]).clone(); @@ -256,12 +256,13 @@ impl ToMilp { fn embed_bv(&mut self, bv: Term) { if let Sort::BitVector(n) = check(&bv) { if !self.cache.contains_key(&bv) { - match &bv.op() { - Op::Var(name, Sort::BitVector(n_bits)) => { - let var = self.bv_lit(name.clone(), *n_bits); + match bv.op() { + Op::Var(v) => { + let var = self.bv_lit(v.name.to_string(), v.sort.as_bv()); self.set_bv_uint(bv.clone(), var, n); } - Op::Const(Value::BitVector(b)) => { + Op::Const(c) => { + let b = c.as_bv(); let bit_lcs = (0..b.width()) .map(|i| Expression::from(b.uint().get_bit(i as u32) as i32)) .collect(); @@ -436,8 +437,8 @@ impl ToMilp { let bits = self .get_bv_bits(&bv.cs()[0]) .into_iter() - .skip(*low) - .take(*high - *low + 1) + .skip(*low as usize) + .take((*high - *low + 1) as usize) .collect(); self.set_bv_bits(bv, bits); } @@ -701,12 +702,12 @@ mod test { fn bool_test() { let cs = Computation { outputs: vec![ - leaf_term(Op::Var("a".to_owned(), Sort::Bool)), - term![Op::Not; leaf_term(Op::Var("b".to_owned(), Sort::Bool))], + var("a".to_owned(), Sort::Bool), + term![Op::Not; var("b".to_owned(), Sort::Bool)], // max this term![AND; - leaf_term(Op::Var("a".to_owned(), Sort::Bool)), - leaf_term(Op::Var("b".to_owned(), Sort::Bool))], + var("a".to_owned(), Sort::Bool), + var("b".to_owned(), Sort::Bool)], ], ..Default::default() }; @@ -725,7 +726,7 @@ mod test { term![Op::Not; t] }; let cs = Computation::from_constraint_system_parts( - vec![t, leaf_term(Op::Const(Value::Bool(true)))], + vec![t, const_(Value::Bool(true))], Vec::new(), ); let mut ilp = to_ilp(cs); @@ -769,7 +770,7 @@ mod test { init(); let mut cs = Computation::new(); cs.assert(term.clone()); - cs.assert(leaf_term(Op::Const(Value::Bool(true)))); + cs.assert(const_(Value::Bool(true))); let ilp = to_ilp(cs); let r = ilp.solve(default_solver); if r.is_err() { @@ -848,7 +849,7 @@ mod test { #[test] fn trivial_bv_opt() { let cs = Computation { - outputs: vec![leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4)))], + outputs: vec![var("a".to_owned(), Sort::BitVector(4))], ..Default::default() }; let ilp = to_ilp(cs); @@ -861,7 +862,7 @@ mod test { fn mul1_bv_opt() { let cs = Computation { outputs: vec![term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))), + var("a".to_owned(), Sort::BitVector(4)), bv_lit(1,4) ]], ..Default::default() @@ -875,7 +876,7 @@ mod test { fn mul2_bv_opt() { let cs = Computation { outputs: vec![term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))), + var("a".to_owned(), Sort::BitVector(4)), bv_lit(2,4) ]], ..Default::default() @@ -889,11 +890,11 @@ mod test { let cs = Computation { outputs: vec![term![BV_ADD; term![BV_MUL; - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))), + var("a".to_owned(), Sort::BitVector(4)), bv_lit(2,4) ], - leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))) + var("a".to_owned(), Sort::BitVector(4)) ]], ..Default::default() }; @@ -904,8 +905,8 @@ mod test { } #[test] fn ite_bv_opt() { - let a = leaf_term(Op::Var("a".to_owned(), Sort::BitVector(4))); - let c = leaf_term(Op::Var("c".to_owned(), Sort::Bool)); + let a = var("a".to_owned(), Sort::BitVector(4)); + let c = var("c".to_owned(), Sort::Bool); let cs = Computation { outputs: vec![term![BV_ADD; term![ITE; c, bv_lit(2,4), bv_lit(1,4)], diff --git a/src/target/r1cs/mod.rs b/src/target/r1cs/mod.rs index 3603377c4..8e4c7a70e 100644 --- a/src/target/r1cs/mod.rs +++ b/src/target/r1cs/mod.rs @@ -603,7 +603,7 @@ impl ProverData { } let var = self.r1cs.vars[next_var_i]; let name = self.r1cs.names.get(&var).unwrap().clone(); - let val = pf_challenge(&name, &self.r1cs.field); + let val = eval_pf_challenge(&name, &self.r1cs.field); var_values.insert(var, val.clone()); inputs.insert(name, Value::Field(val)); } @@ -1005,8 +1005,8 @@ impl R1cs { Default::default(), ) .filter_map(|t| { - if let Op::Var(n, s) = t.op() { - Some((n.clone(), s.clone())) + if let Op::Var(v) = t.op() { + Some((v.name.to_string(), v.sort.clone())) } else { None } @@ -1064,8 +1064,8 @@ impl R1cs { let vars: HashMap<String, Sort> = { PostOrderIter::new(precompute.tuple()) .filter_map(|t| { - if let Op::Var(n, s) = t.op() { - Some((n.clone(), s.clone())) + if let Op::Var(v) = t.op() { + Some((v.name.to_string(), v.sort.clone())) } else { None } @@ -1116,7 +1116,7 @@ impl R1cs { /// Get an IR term that represents this system. pub fn lc_ir_term(&self, lc: &Lc) -> Term { term(PF_ADD, - std::iter::once(pf_lit(lc.constant.clone())).chain(lc.monomials.iter().map(|(i, coeff)| term![PF_MUL; pf_lit(coeff.clone()), leaf_term(Op::Var(self.idx_to_sig.get_fwd(i).unwrap().into(), Sort::Field(self.modulus.clone())))])).collect()) + std::iter::once(pf_lit(lc.constant.clone())).chain(lc.monomials.iter().map(|(i, coeff)| term![PF_MUL; pf_lit(coeff.clone()), var(self.idx_to_sig.get_fwd(i).unwrap().into(), Sort::Field(self.modulus.clone()))])).collect()) } /// Get an IR term that represents this system. diff --git a/src/target/r1cs/opt.rs b/src/target/r1cs/opt.rs index f6dda8d56..aa24c5cbc 100644 --- a/src/target/r1cs/opt.rs +++ b/src/target/r1cs/opt.rs @@ -249,7 +249,7 @@ mod test { for v in &vars { let var = r1cs.add_var( v.clone(), - leaf_term(Op::Var(v.clone(), Sort::Field(field.clone()))), + var(v.clone(), Sort::Field(field.clone())), VarType::FinalWit, ); let val = field.random_v(&mut rng); diff --git a/src/target/r1cs/trans.rs b/src/target/r1cs/trans.rs index 3c4079cf5..43aec59df 100644 --- a/src/target/r1cs/trans.rs +++ b/src/target/r1cs/trans.rs @@ -137,7 +137,7 @@ impl<'cfg> ToR1cs<'cfg> { self.r1cs.add_committed_witness(elements.clone()); for (name, value) in elements { let lc = self.r1cs.signal_lc(&name); - let var = leaf_term(Op::Var(name, check(&value))); + let var = var(name, check(&value)); self.embed.borrow_mut().insert(var.clone()); self.cache .insert(var, EmbeddedTerm::Field(TermLc(value, lc))); @@ -357,30 +357,30 @@ impl<'cfg> ToR1cs<'cfg> { self.profile_start_term(var.clone()); let public = matches!(ty, VarType::Inst); match var.op() { - Op::Var(name, Sort::Bool) => { + Op::Var(v) if matches!(&v.sort, Sort::Bool) => { let comp = term![Op::Ite; var.clone(), self.one.0.clone(), self.zero.0.clone()]; - let lc = self.fresh_var(name, comp, ty); + let lc = self.fresh_var(&v.name, comp, ty); if !public { self.enforce_bit(lc.clone()); } self.cache.insert(var.clone(), EmbeddedTerm::Bool(lc)); self.embed.borrow_mut().insert(var.clone()); } - Op::Var(name, Sort::BitVector(n_bits)) => { + Op::Var(v) if v.sort.is_bv() => { let public = matches!(ty, VarType::Inst); let lc = self.fresh_var( - name, - term![Op::UbvToPf(self.field.clone()); var.clone()], + &v.name, + term![Op::new_ubv_to_pf(self.field.clone()); var.clone()], ty, ); - self.set_bv_uint(var.clone(), lc, *n_bits); + self.set_bv_uint(var.clone(), lc, v.sort.as_bv()); if !public { self.get_bv_bits(var); } } - Op::Var(name, Sort::Field(f)) => { - assert_eq!(f, &self.field); - let lc = self.fresh_var(name, var.clone(), ty); + Op::Var(v) if v.sort.is_pf() => { + assert_eq!(v.sort.as_pf(), &self.field); + let lc = self.fresh_var(&v.name, var.clone(), ty); self.cache.insert(var.clone(), EmbeddedTerm::Field(lc)); self.embed.borrow_mut().insert(var.clone()); } @@ -494,7 +494,7 @@ impl<'cfg> ToR1cs<'cfg> { if !self.cache.contains_key(&c) { let lc = match &c.op() { Op::Var(..) => panic!("call embed_var instead"), - Op::Const(Value::Bool(b)) => self.zero.clone() + *b as isize, + Op::Const(v) => self.zero.clone() + v.as_bool() as isize, Op::Eq => self.embed_eq(&c.cs()[0], &c.cs()[1]), Op::Ite => { let a = self.get_bool(&c.cs()[0]).clone(); @@ -708,7 +708,8 @@ impl<'cfg> ToR1cs<'cfg> { if !self.cache.contains_key(&bv) { match &bv.op() { Op::Var(..) => panic!("call embed_var instead"), - Op::Const(Value::BitVector(b)) => { + Op::Const(v) => { + let b = v.as_bv(); let bit_lcs = (0..b.width()) .map(|i| self.zero.clone() + b.uint().get_bit(i as u32) as isize) .collect(); @@ -839,8 +840,8 @@ impl<'cfg> ToR1cs<'cfg> { BvBinOp::Udiv | BvBinOp::Urem => { let a_bv_term = term![Op::PfToBv(n); a.0.clone()]; let b_bv_term = term![Op::PfToBv(n); b.0.clone()]; - let q_term = term![Op::UbvToPf(self.field.clone()); term![BV_UDIV; a_bv_term.clone(), b_bv_term.clone()]]; - let r_term = term![Op::UbvToPf(self.field.clone()); term![BV_UREM; a_bv_term, b_bv_term]]; + let q_term = term![Op::new_ubv_to_pf(self.field.clone()); term![BV_UDIV; a_bv_term.clone(), b_bv_term.clone()]]; + let r_term = term![Op::new_ubv_to_pf(self.field.clone()); term![BV_UREM; a_bv_term, b_bv_term]]; let q = self.fresh_wit("div_q", q_term); let r = self.fresh_wit("div_r", r_term); let qb = self.bitify("div_q", &q, n, false); @@ -899,8 +900,8 @@ impl<'cfg> ToR1cs<'cfg> { let bits = self .get_bv_bits(&bv.cs()[0]) .into_iter() - .skip(*low) - .take(*high - *low + 1) + .skip(*low as usize) + .take((*high - *low + 1) as usize) .collect(); self.set_bv_bits(bv, bits); } @@ -1014,9 +1015,9 @@ impl<'cfg> ToR1cs<'cfg> { debug!("embed_pf {}", c); let lc = match &c.op() { Op::Var(..) => panic!("call embed_var instead"), - Op::Const(Value::Field(r)) => TermLc( + Op::Const(v) => TermLc( c.clone(), - self.r1cs.constant(r.as_ty_ref(&self.r1cs.modulus)), + self.r1cs.constant(v.as_pf().as_ty_ref(&self.r1cs.modulus)), ), Op::Ite => { let cond = self.get_bool(&c.cs()[0]).clone(); @@ -1201,12 +1202,12 @@ pub mod test { .collect(); let cs = Computation::from_constraint_system_parts( vec![ - leaf_term(Op::Var("a".to_owned(), Sort::Bool)), - term![Op::Not; leaf_term(Op::Var("b".to_owned(), Sort::Bool))], + var("a".to_owned(), Sort::Bool), + term![Op::Not; var("b".to_owned(), Sort::Bool)], ], vec![ - leaf_term(Op::Var("a".to_owned(), Sort::Bool)), - leaf_term(Op::Var("b".to_owned(), Sort::Bool)), + var("a".to_owned(), Sort::Bool), + var("b".to_owned(), Sort::Bool), ], ); let r1cs = to_r1cs_mod17(cs); @@ -1228,7 +1229,7 @@ pub mod test { #[quickcheck] fn random_bool(ArbitraryTermEnv(t, values): ArbitraryTermEnv) { let v = eval(&t, &values); - let t = term![Op::Eq; t, leaf_term(Op::Const(v))]; + let t = term![Op::Eq; t, const_(v)]; let mut cs = Computation::from_constraint_system_parts(vec![t], Vec::new()); crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs); crate::ir::opt::tuple::eliminate_tuples(&mut cs); @@ -1240,7 +1241,7 @@ pub mod test { #[quickcheck] fn random_pure_bool_opt(ArbitraryBoolEnv(t, values): ArbitraryBoolEnv) { let v = eval(&t, &values); - let t = term![Op::Eq; t, leaf_term(Op::Const(v))]; + let t = term![Op::Eq; t, const_(v)]; let cs = Computation::from_constraint_system_parts(vec![t], Vec::new()); let cfg = CircCfg::default(); let r1cs = to_r1cs(&cs, &cfg); @@ -1252,7 +1253,7 @@ pub mod test { #[quickcheck] fn random_bool_opt(ArbitraryTermEnv(t, values): ArbitraryTermEnv) { let v = eval(&t, &values); - let t = term![Op::Eq; t, leaf_term(Op::Const(v))]; + let t = term![Op::Eq; t, const_(v)]; let mut cs = Computation::from_constraint_system_parts(vec![t], Vec::new()); crate::ir::opt::scalarize_vars::scalarize_inputs(&mut cs); crate::ir::opt::tuple::eliminate_tuples(&mut cs); @@ -1274,8 +1275,8 @@ pub mod test { let cs = Computation::from_constraint_system_parts( vec![term![Op::Not; term![Op::Eq; bv_lit(0b10110, 8), - term![Op::BvUnOp(BvUnOp::Neg); leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))]]]], - vec![leaf_term(Op::Var("b".to_owned(), Sort::BitVector(8)))], + term![Op::BvUnOp(BvUnOp::Neg); var("b".to_owned(), Sort::BitVector(8))]]]], + vec![var("b".to_owned(), Sort::BitVector(8))], ); let r1cs = to_r1cs_dflt(cs); r1cs.check_all(&values); @@ -1284,12 +1285,12 @@ pub mod test { #[test] fn not_opt_test() { init(); - let t = term![Op::Not; leaf_term(Op::Var("b".to_owned(), Sort::Bool))]; + let t = term![Op::Not; var("b".to_owned(), Sort::Bool)]; let values: FxHashMap<String, Value> = vec![("b".to_owned(), Value::Bool(true))] .into_iter() .collect(); let v = eval(&t, &values); - let t = term![Op::Eq; t, leaf_term(Op::Const(v))]; + let t = term![Op::Eq; t, const_(v)]; let cs = Computation::from_constraint_system_parts(vec![t], vec![]); let cfg = CircCfg::default(); let r1cs = to_r1cs(&cs, &cfg); @@ -1299,7 +1300,7 @@ pub mod test { } fn pf_dflt(i: isize) -> Term { - leaf_term(Op::Const(Value::Field(CircCfg::default().field().new_v(i)))) + pf_lit(CircCfg::default().field().new_v(i)) } fn const_test(term: Term) { @@ -1419,12 +1420,12 @@ pub mod test { .collect(); let mut cs = Computation::from_constraint_system_parts( vec![ - term![Op::Field(0); term![Op::Tuple; leaf_term(Op::Var("a".to_owned(), Sort::Bool)), leaf_term(Op::Const(Value::Bool(false)))]], - term![Op::Not; leaf_term(Op::Var("b".to_owned(), Sort::Bool))], + term![Op::Field(0); term![Op::Tuple; var("a".to_owned(), Sort::Bool), bool_lit(false)]], + term![Op::Not; var("b".to_owned(), Sort::Bool)], ], vec![ - leaf_term(Op::Var("a".to_owned(), Sort::Bool)), - leaf_term(Op::Var("b".to_owned(), Sort::Bool)), + var("a".to_owned(), Sort::Bool), + var("b".to_owned(), Sort::Bool), ], ); crate::ir::opt::tuple::eliminate_tuples(&mut cs); diff --git a/src/target/r1cs/wit_comp.rs b/src/target/r1cs/wit_comp.rs index f52024a7c..b280788ae 100644 --- a/src/target/r1cs/wit_comp.rs +++ b/src/target/r1cs/wit_comp.rs @@ -60,8 +60,8 @@ impl StagedWitComp { fn add_step(&mut self, term: Term) { debug_assert!(!self.term_to_step.contains_key(&term)); let step_idx = self.steps.len(); - if let Op::Var(name, _) = term.op() { - debug_assert!(self.vars.contains(name)); + if let Op::Var(var) = term.op() { + debug_assert!(self.vars.contains(&*var.name)); } for child in term.cs() { let child_step = self.term_to_step.get(child).unwrap(); @@ -302,8 +302,8 @@ mod test { let field = FieldT::from(Integer::from(7)); comp.add_stage(mk_inputs(vec![("a".into(), Sort::Bool), ("b".into(), Sort::Field(field.clone()))]), vec![ - leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))), - term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], + var("b".into(), Sort::Field(field.clone())), + term![Op::Ite; var("a".into(), Sort::Bool), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], ]); let mut evaluator = StagedWitCompEvaluator::new(&comp); @@ -331,16 +331,16 @@ mod test { let field = FieldT::from(Integer::from(7)); comp.add_stage(mk_inputs(vec![("a".into(), Sort::Bool), ("b".into(), Sort::Field(field.clone()))]), vec![ - leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))), - term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], + var("b".into(), Sort::Field(field.clone())), + term![Op::Ite; var("a".into(), Sort::Bool), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], ]); comp.add_stage(mk_inputs(vec![("c".into(), Sort::Field(field.clone()))]), vec![ term![PF_ADD; - leaf_term(Op::Var("b".into(), Sort::Field(field.clone()))), - leaf_term(Op::Var("c".into(), Sort::Field(field.clone())))], - term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], - term![Op::Ite; leaf_term(Op::Var("a".into(), Sort::Bool)), pf_lit(field.new_v(0)), pf_lit(field.new_v(1))], + var("b".into(), Sort::Field(field.clone())), + var("c".into(), Sort::Field(field.clone()))], + term![Op::Ite; var("a".into(), Sort::Bool), pf_lit(field.new_v(1)), pf_lit(field.new_v(0))], + term![Op::Ite; var("a".into(), Sort::Bool), pf_lit(field.new_v(0)), pf_lit(field.new_v(1))], ]); let mut evaluator = StagedWitCompEvaluator::new(&comp); diff --git a/src/target/smt/mod.rs b/src/target/smt/mod.rs index 86566aa2a..579d5cbce 100644 --- a/src/target/smt/mod.rs +++ b/src/target/smt/mod.rs @@ -81,8 +81,8 @@ impl Expr2Smt<()> for Value { for _ in 0..map.len() { write!(w, "(store ")?; } - let val_s = check(&leaf_term(Op::Const((**default).clone()))); - let s = Sort::Array(Box::new(key_sort.clone()), Box::new(val_s), *size); + let val_s = check(&const_((**default).clone())); + let s = Sort::new_array(key_sort.clone(), val_s, *size); write!( w, "((as const {}) {})", @@ -109,8 +109,8 @@ impl Expr2Smt<()> for Value { impl Expr2Smt<()> for Term { fn expr_to_smt2<W: Write>(&self, w: &mut W, (): ()) -> SmtRes<()> { let s_expr_children = match &self.op() { - Op::Var(n, _) => { - write!(w, "{n}")?; + Op::Var(v) => { + write!(w, "{}", v.name)?; false } Op::Eq => { @@ -138,7 +138,7 @@ impl Expr2Smt<()> for Term { true } Op::Const(c) => { - write!(w, "{}", SmtDisp(c))?; + write!(w, "{}", SmtDisp(&**c))?; false } Op::Store => { @@ -197,8 +197,8 @@ impl Sort2Smt for Sort { fn sort_to_smt2<W: Write>(&self, w: &mut W) -> SmtRes<()> { match self { Sort::BitVector(b) => write!(w, "(_ BitVec {b})")?, - Sort::Array(k, v, _size) => { - write!(w, "(Array {} {})", SmtSortDisp(&**k), SmtSortDisp(&**v))?; + Sort::Array(a) => { + write!(w, "(Array {} {})", SmtSortDisp(&a.key), SmtSortDisp(&a.val))?; } Sort::F64 => write!(w, "Float64")?, Sort::F32 => write!(w, "Float32")?, @@ -228,9 +228,9 @@ impl Expr2Smt<()> for BitVector { } } -struct SmtSymDisp<'a, T>(pub &'a T); +struct SmtSymDisp<'a, T: ?Sized>(pub &'a T); -impl<'a, T: Display + 'a> Sym2Smt<()> for SmtSymDisp<'a, T> { +impl<'a, T: Display + 'a + ?Sized> Sym2Smt<()> for SmtSymDisp<'a, T> { fn sym_to_smt2<W: Write>(&self, w: &mut W, (): ()) -> SmtRes<()> { write!(w, "{}", self.0)?; Ok(()) @@ -345,11 +345,11 @@ fn make_solver<P>(parser: P, models: bool, inc: bool) -> rsmt2::Solver<P> { /// Write SMT2 the encodes this terms satisfiability to a file pub fn write_smt2<W: Write>(mut w: W, t: &Term) { for c in PostOrderIter::new(t.clone()) { - if let Op::Var(n, s) = &c.op() { + if let Op::Var(v) = &c.op() { write!(w, "(declare-const ").unwrap(); - SmtSymDisp(n).sym_to_smt2(&mut w, ()).unwrap(); + SmtSymDisp(&*v.name).sym_to_smt2(&mut w, ()).unwrap(); write!(w, " ").unwrap(); - s.sort_to_smt2(&mut w).unwrap(); + v.sort.sort_to_smt2(&mut w).unwrap(); writeln!(w, ")").unwrap(); } } @@ -364,8 +364,10 @@ pub fn write_smt2<W: Write>(mut w: W, t: &Term) { pub fn check_sat(t: &Term) -> bool { let mut solver = make_solver((), false, false); for c in PostOrderIter::new(t.clone()) { - if let Op::Var(n, s) = &c.op() { - solver.declare_const(&SmtSymDisp(n), s).unwrap(); + if let Op::Var(v) = &c.op() { + solver + .declare_const(&SmtSymDisp(&*v.name), &v.sort) + .unwrap(); } } assert!(check(t) == Sort::Bool); @@ -377,8 +379,10 @@ fn get_model_solver(t: &Term, inc: bool) -> rsmt2::Solver<Parser> { let mut solver = make_solver(Parser, true, inc); //solver.path_tee("solver_com").unwrap(); for c in PostOrderIter::new(t.clone()) { - if let Op::Var(n, s) = &c.op() { - solver.declare_const(&SmtSymDisp(n), s).unwrap(); + if let Op::Var(v) = &c.op() { + solver + .declare_const(&SmtSymDisp(&*v.name), &v.sort) + .unwrap(); } } assert!(check(t) == Sort::Bool); @@ -424,7 +428,7 @@ pub fn find_unique_model(t: &Term, uniqs: Vec<String>) -> Option<HashMap<String, .flat_map(|n| { model .get(&n) - .map(|v| term![EQ; term![Op::Var(n, v.sort())], term![Op::Const(v.clone())]]) + .map(|v| term![EQ; term![Op::new_var(n, v.sort())], const_(v.clone())]) }) .reduce(|l, r| term![AND; l, r]) .map(|t| term![NOT; t]) @@ -451,13 +455,13 @@ mod test { #[test] fn var_is_sat() { - let t = leaf_term(Op::Var("a".into(), Sort::Bool)); + let t = var("a".into(), Sort::Bool); assert!(check_sat(&t)); } #[test] fn var_is_sat_model() { - let t = leaf_term(Op::Var("a".into(), Sort::Bool)); + let t = var("a".into(), Sort::Bool); assert!( find_model(&t) == Some( @@ -470,14 +474,14 @@ mod test { #[test] fn var_and_not_is_unsat() { - let v = leaf_term(Op::Var("a".into(), Sort::Bool)); + let v = var("a".into(), Sort::Bool); let t = term![Op::BoolNaryOp(BoolNaryOp::And); v.clone(), term![Op::Not; v]]; assert!(!check_sat(&t)); } #[test] fn bv_is_sat() { - let t = term![Op::Eq; bv_lit(0,4), leaf_term(Op::Var("a".into(), Sort::BitVector(4)))]; + let t = term![Op::Eq; bv_lit(0,4), var("a".into(), Sort::BitVector(4))]; assert!(check_sat(&t)); } @@ -532,15 +536,15 @@ mod test { #[test] fn tuple_is_sat() { - let t = term![Op::Eq; term![Op::Field(0); term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)]], leaf_term(Op::Var("a".into(), Sort::BitVector(4)))]; + let t = term![Op::Eq; term![Op::Field(0); term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)]], var("a".into(), Sort::BitVector(4))]; assert!(check_sat(&t)); - let t = term![Op::Eq; term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)], leaf_term(Op::Var("a".into(), Sort::Tuple(vec![Sort::BitVector(4), Sort::BitVector(6)].into_boxed_slice())))]; + let t = term![Op::Eq; term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)], var("a".into(), Sort::Tuple(vec![Sort::BitVector(4), Sort::BitVector(6)].into_boxed_slice()))]; assert!(check_sat(&t)); } #[test] fn bv_is_sat_model() { - let t = term![Op::Eq; bv_lit(0,4), leaf_term(Op::Var("a".into(), Sort::BitVector(4)))]; + let t = term![Op::Eq; bv_lit(0,4), var("a".into(), Sort::BitVector(4))]; assert!( find_model(&t) == Some( @@ -557,9 +561,9 @@ mod test { #[test] fn vars_are_sat_model() { let t = term![Op::BoolNaryOp(BoolNaryOp::And); - leaf_term(Op::Var("a".into(), Sort::Bool)), - leaf_term(Op::Var("b".into(), Sort::Bool)), - leaf_term(Op::Var("c".into(), Sort::Bool)) + var("a".into(), Sort::Bool), + var("b".into(), Sort::Bool), + var("c".into(), Sort::Bool) ]; assert!( find_model(&t) @@ -584,29 +588,31 @@ mod test { /// Check that `t` evaluates consistently within the SMT solver under `vs`. pub fn smt_eval_test(t: Term, vs: &HashMap<String, Value>) -> bool { let mut solver = make_solver((), false, false); - for (var, val) in vs { + for (v, val) in vs { let s = val.sort(); - solver.declare_const(&SmtSymDisp(&var), &s).unwrap(); - solver.assert(&term![Op::Eq; leaf_term(Op::Var(var.to_owned(), s)), leaf_term(Op::Const(val.clone()))]).unwrap(); + solver.declare_const(&SmtSymDisp(&v), &s).unwrap(); + solver + .assert(&term![Op::Eq; var(v.to_string(), s), const_(val.clone())]) + .unwrap(); } let val = eval(&t, vs); - solver - .assert(&term![Op::Eq; t, leaf_term(Op::Const(val))]) - .unwrap(); + solver.assert(&term![Op::Eq; t, const_(val)]).unwrap(); solver.check_sat().unwrap() } /// Check that `t` evaluates consistently within the SMT solver under `vs`. pub fn smt_eval_alternate_solution(t: Term, vs: &HashMap<String, Value>) -> bool { let mut solver = make_solver((), false, false); - for (var, val) in vs { + for (v, val) in vs { let s = val.sort(); - solver.declare_const(&SmtSymDisp(&var), &s).unwrap(); - solver.assert(&term![Op::Eq; leaf_term(Op::Var(var.to_owned(), s)), leaf_term(Op::Const(val.clone()))]).unwrap(); + solver.declare_const(&SmtSymDisp(&v), &s).unwrap(); + solver + .assert(&term![Op::Eq; var(v.to_string(), s), const_(val.clone())]) + .unwrap(); } let val = eval(&t, vs); solver - .assert(&term![Op::Not; term![Op::Eq; t, leaf_term(Op::Const(val))]]) + .assert(&term![Op::Not; term![Op::Eq; t, const_(val)]]) .unwrap(); solver.check_sat().unwrap() }