From ac57ced23fa580c8fe37885aa0139a312fba8983 Mon Sep 17 00:00:00 2001 From: Hendrik Bierlee Date: Thu, 16 May 2024 14:48:47 +1000 Subject: [PATCH] Fix more tests by handling missing cases --- crates/pindakaas/src/helpers.rs | 23 ++++- crates/pindakaas/src/int/bin.rs | 52 +++++++++- crates/pindakaas/src/int/con.rs | 25 ++++- crates/pindakaas/src/int/decompose.rs | 75 +++++++++++--- crates/pindakaas/src/int/model.rs | 142 +++++++++++++++++--------- crates/pindakaas/src/int/term.rs | 22 +++- crates/pindakaas/src/linear/adder.rs | 21 +--- 7 files changed, 269 insertions(+), 91 deletions(-) diff --git a/crates/pindakaas/src/helpers.rs b/crates/pindakaas/src/helpers.rs index a5f197d2b8..a598f6b3ac 100644 --- a/crates/pindakaas/src/helpers.rs +++ b/crates/pindakaas/src/helpers.rs @@ -9,8 +9,8 @@ use std::{ use itertools::Itertools; use crate::{ - linear::PosCoeff, trace::emit_clause, CheckError, Checker, ClauseDatabase, Coeff, Encoder, - LinExp, Lit, Result, Unsatisfiable, Valuation, Var, + int::LitOrConst, linear::PosCoeff, trace::emit_clause, CheckError, Checker, ClauseDatabase, + Coeff, Encoder, LinExp, Lit, Result, Unsatisfiable, Valuation, Var, }; #[allow(unused_macros)] @@ -154,6 +154,25 @@ pub(crate) fn add_clauses_for( Ok(()) } +pub(crate) fn emit_filtered_clause>( + db: &mut DB, + lits: I, +) -> Result { + if let Ok(clause) = lits + .into_iter() + .filter_map(|lit| match lit { + LitOrConst::Lit(lit) => Some(Ok(lit)), + LitOrConst::Const(true) => Some(Err(())), // clause satisfied + LitOrConst::Const(false) => None, // literal falsified + }) + .collect::, ()>>() + { + emit_clause!(db, clause) + } else { + Ok(()) + } +} + /// Negates CNF (flipping between empty clause and formula) pub(crate) fn negate_cnf(clauses: Vec>) -> Vec> { if clauses.is_empty() { diff --git a/crates/pindakaas/src/int/bin.rs b/crates/pindakaas/src/int/bin.rs index 47f5a02501..397084a9ba 100644 --- a/crates/pindakaas/src/int/bin.rs +++ b/crates/pindakaas/src/int/bin.rs @@ -8,7 +8,9 @@ use itertools::Itertools; use super::{required_lits, Dom, LitOrConst}; use crate::{ - helpers::{add_clauses_for, as_binary, negate_cnf, pow2, unsigned_binary_range}, + helpers::{ + add_clauses_for, as_binary, emit_filtered_clause, negate_cnf, pow2, unsigned_binary_range, + }, int::{helpers::remove_red, model::PRINT_COUPLING}, linear::{lex_geq_const, lex_leq_const, PosCoeff}, trace::{emit_clause, new_var}, @@ -42,6 +44,54 @@ impl BinEnc { } } + /// Encode x:B <=/>= y:B + pub(crate) fn lex( + &self, + db: &mut DB, + cmp: Comparator, + other: Self, + ) -> crate::Result { + let n = std::cmp::max(self.bits(), other.bits()) as usize; + + fn bit(x: &[LitOrConst], i: usize) -> LitOrConst { + *x.get(i).unwrap_or(&LitOrConst::Const(false)) + } + + let (x, y, c) = ( + &self.xs(), + &other.xs(), + &(0..n) + .map(|_i| LitOrConst::Lit(new_var!(db, crate::trace::subscripted_name("c", _i)))) + .chain(std::iter::once(LitOrConst::Const(true))) + .collect_vec(), + ); + + // higher i -> more significant + for i in 0..n { + // c = all more significant bits are equal AND current one is + // if up to i is equal, all preceding must be equal + emit_filtered_clause(db, [!bit(c, i), bit(c, i + 1)])?; + // if up to i is equal, x<->y + emit_filtered_clause(db, [!bit(c, i), !bit(x, i), bit(y, i)])?; + emit_filtered_clause(db, [!bit(c, i), !bit(y, i), bit(x, i)])?; + + // if not up to i is equal, either preceding bit was not equal, or x!=y + emit_filtered_clause(db, [bit(c, i), !bit(c, i + 1), bit(x, i), bit(y, i)])?; + emit_filtered_clause(db, [bit(c, i), !bit(c, i + 1), !bit(x, i), !bit(y, i)])?; + + // if preceding bits are equal, then x<=y (or x>=y) + match cmp { + Comparator::LessEq => { + emit_filtered_clause(db, [!bit(c, i + 1), !bit(x, i), bit(y, i)]) + } + Comparator::GreaterEq => { + emit_filtered_clause(db, [!bit(c, i + 1), bit(x, i), !bit(y, i)]) + } + Comparator::Equal => unreachable!(), + }?; + } + Ok(()) + } /// Returns conjunction for x>=k, given x>=b pub(crate) fn geqs(&self, k: Coeff, a: Coeff) -> Vec> { let (range_lb, range_ub) = self.range(); diff --git a/crates/pindakaas/src/int/con.rs b/crates/pindakaas/src/int/con.rs index 2c6f5d7357..f9f90da0a0 100644 --- a/crates/pindakaas/src/int/con.rs +++ b/crates/pindakaas/src/int/con.rs @@ -31,6 +31,7 @@ pub(crate) enum LinCase { Couple(Term, Term), Fixed(Lin), Unary(Term, Comparator, Coeff), + Binary(Term, Comparator, Term), // just for binary ineqs Scm(Term, IntVarRef), Rca(Term, Term, Term), Order, @@ -57,6 +58,11 @@ impl TryFrom<&Lin> for LinCase { { LinCase::Unary((*t).clone().encode_bin(None, cmp, None)?, cmp, con.k) } + ( + [(x, Some(IntVarEnc::Bin(_))), (y, Some(IntVarEnc::Bin(_)))], + Comparator::LessEq | Comparator::GreaterEq, + 0, + ) => LinCase::Binary((*x).clone(), con.cmp, (*y).clone()), // VIEW COUPLING // TODO this makes single literal comparisons views if possible // ([(t, Some(IntVarEnc::Ord(_))), (y, Some(IntVarEnc::Bin(None)))], _) @@ -83,7 +89,14 @@ impl TryFrom<&Lin> for LinCase { { LinCase::Couple((*t).clone(), (*y).clone()) } - + // ([(x, Some(IntVarEnc::Bin(_)))], Comparator::Equal, k) => { + // LinCase::Rca((*x).clone(), Term::from(0), Term::from(k)) + // } + ( + [(x, Some(IntVarEnc::Bin(_))), (y, Some(IntVarEnc::Bin(_)))], + Comparator::Equal, + k, + ) => LinCase::Rca((*x).clone(), (*y).clone(), Term::from(-k)), ( [(x, Some(IntVarEnc::Bin(_))), (y, Some(IntVarEnc::Bin(_))), (z, Some(IntVarEnc::Bin(_)))], Comparator::Equal, @@ -344,6 +357,16 @@ impl Lin { let x_enc = x.clone().borrow_mut().encode_bin(db)?; x_enc.encode_unary_constraint(db, &cmp, k, &dom, false) } + LinCase::Binary(t_x, cmp, t_y) => { + println!("self = {}", self); + + t_x.x.borrow_mut().encode_bin(db)?; + t_y.x.borrow_mut().encode_bin(db)?; + + let x_enc = t_x.x.borrow_mut().encode_bin(db)?; + let y_enc = (t_y * -1).x.borrow_mut().encode_bin(db)?; + x_enc.lex(db, cmp, y_enc) + } LinCase::Couple(t_x, t_y) => { t_x.x.borrow_mut().encode_ord(db)?; if !t_x.x.borrow().add_consistency { diff --git a/crates/pindakaas/src/int/decompose.rs b/crates/pindakaas/src/int/decompose.rs index 8cc84b0bbb..519e4d8942 100644 --- a/crates/pindakaas/src/int/decompose.rs +++ b/crates/pindakaas/src/int/decompose.rs @@ -16,15 +16,14 @@ pub trait Decompose { pub struct EqualizeTernsDecomposer {} impl Decompose for EqualizeTernsDecomposer { - fn decompose(&self, model: Model) -> Result { - const REMOVE_GAPS: bool = true; - + fn decompose(&self, mut model: Model) -> Result { let cons = model.cons.iter().cloned().collect_vec(); Ok(Model { cons: cons .into_iter() - .map(|con| { - if REMOVE_GAPS && con.exp.terms.len() >= 2 && con.cmp.is_ineq() { + .with_position() + .flat_map(|(pos, con)| { + if con.exp.terms.len() >= 2 && con.cmp.is_ineq() { if con .exp .terms @@ -40,15 +39,51 @@ impl Decompose for EqualizeTernsDecomposer { Comparator::GreaterEq => (std::cmp::max(-last.ub(), lb), ub), Comparator::Equal => unreachable!(), }; + let dom = Dom::from_bounds(lb, ub); + if matches!(pos, Position::First | Position::Middle) { + last.x.borrow_mut().dom = dom; - last.x.borrow_mut().dom = Dom::from_bounds(lb, ub); + vec![Lin { + exp: LinExp { + terms: firsts.iter().chain([last]).cloned().collect(), + }, + cmp: Comparator::Equal, + ..con + }] + } else if con.exp.terms.len() >= 3 { + // x+y<=z == x+y=z' /\ z' <= z + let y = model + .new_aux_var( + dom, + true, + Some(IntVarEnc::Bin(None)), + Some(String::from("last")), + ) + .unwrap(); - Lin { - exp: LinExp { - terms: firsts.iter().chain([last]).cloned().collect(), - }, - cmp: Comparator::Equal, - ..con + vec![ + Lin { + exp: LinExp { + terms: firsts + .iter() + .chain([&Term::new(-1, y.clone())]) + .cloned() + .collect(), + }, + cmp: Comparator::Equal, + k: 0, + lbl: Some(String::from("last")), + }, + Lin { + exp: LinExp { + terms: vec![Term::from(y), last.clone()], + }, + cmp: con.cmp, + ..con + }, + ] + } else { + vec![con] } } else { unreachable!() @@ -67,12 +102,22 @@ impl Decompose for EqualizeTernsDecomposer { ) { con.exp.terms[0].x.borrow_mut().dom = con.exp.terms[1].x.borrow().dom.clone(); - con + vec![con] } else { - con + vec![con] } + // } else if con.exp.terms.len() == 2 && con.cmp == Comparator::Equal && false { + // let z = con.exp.terms[0] + // .clone() + // .add(con.exp.terms[1].clone(), &mut model) + // .unwrap(); + // vec![Lin { + // exp: LinExp { terms: vec![z] }, + // cmp: con.cmp, + // ..con + // }] } else { - con + vec![con] } }) .collect(), diff --git a/crates/pindakaas/src/int/model.rs b/crates/pindakaas/src/int/model.rs index ba3f184c79..4457ecb3e6 100644 --- a/crates/pindakaas/src/int/model.rs +++ b/crates/pindakaas/src/int/model.rs @@ -715,10 +715,10 @@ mod tests { // Decomposer::Rca ], [Consistency::None], - [false], // consistency - [true], // equalize terns - [None, Some(0)], // cutoffs: [None, Some(0), Some(2)] - [false] // equalize_uniform_bin_ineqs + [false], // consistency + // [true], // equalize terns + [Some(0)], // cutoffs: [None, Some(0), Some(2)] + [false] // equalize_uniform_bin_ineqs ) .map( |( @@ -726,7 +726,7 @@ mod tests { decomposer, propagate, add_consistency, - equalize_ternaries, + // equalize_ternaries, cutoff, equalize_uniform_bin_ineqs, )| { @@ -735,7 +735,7 @@ mod tests { decomposer: decomposer.clone(), propagate: propagate.clone(), add_consistency, - equalize_ternaries, + equalize_ternaries: cutoff == Some(0), cutoff, equalize_uniform_bin_ineqs, ..ModelConfig::default() @@ -912,7 +912,11 @@ mod tests { .clone(), )] } else { - var_encs_gen.into_iter().enumerate().collect_vec() + if var_encs_gen.is_empty() { + vec![(0, HashMap::default())] + } else { + var_encs_gen.into_iter().enumerate().collect_vec() + } } } { let spec = if VAR_ENCS.is_empty() { @@ -1147,21 +1151,21 @@ End ); } - // #[test] - // fn test_lp_le_double_w_const() { - // test_lp_for_configs( - // r" - // Subject To - // c0: + 2 x1 + 3 x2 - 1 x3 <= 0 - // bounds - // 0 <= x1 <= 1 - // 0 <= x2 <= 1 - // 4 <= x3 <= 4 - // End - // ", - // None, - // ); - // } + #[test] + fn test_lp_le_double_w_const() { + test_lp_for_configs( + r" + Subject To + c0: + 2 x1 + 3 x2 - 1 x3 <= 0 + bounds + 0 <= x1 <= 1 + 0 <= x2 <= 1 + 4 <= x3 <= 4 + End + ", + None, + ); + } #[test] fn test_int_lin_ge_single() { @@ -1177,6 +1181,42 @@ End ); } + #[test] + fn test_int_lin_binary_constraint_le() { + test_lp_for_configs( + r" +Subject To +c0: + 1 x1 - 1 x2 <= 0 +Bounds +0 <= x1 <= 3 +0 <= x2 <= 3 +Encs + x1 B + x2 B +End +", + None, + ); + } + + #[test] + fn test_int_lin_binary_constraint_ge() { + test_lp_for_configs( + r" +Subject To +c0: + 1 x1 - 1 x2 >= 0 +Bounds +0 <= x1 <= 3 +0 <= x2 <= 3 +Encs + x1 B + x2 B +End +", + None, + ); + } + #[test] fn test_int_lin_le_1() { test_lp_for_configs( @@ -1294,35 +1334,35 @@ End ); } - // #[test] - // fn test_int_lin_eq_tmp() { - // test_lp_for_configs( - // r" - // Subject To - // c0: + 1 x1 - 1 x2 <= 0 - // Bounds - // 0 <= x1 <= 3 - // 0 <= x2 <= 3 - // End - // ", - // None, - // ); - // } + #[test] + fn test_int_lin_eq_tmp() { + test_lp_for_configs( + r" + Subject To + c0: + 1 x1 - 1 x2 <= 0 + Bounds + 0 <= x1 <= 3 + 0 <= x2 <= 3 + End + ", + None, + ); + } - // #[test] - // fn test_int_lin_eq_3() { - // test_lp_for_configs( - // r" - // Subject To - // c0: + 1 x1 + 1 x2 = 2 - // Bounds - // 0 <= x1 <= 1 - // 0 <= x2 <= 1 - // End - // ", - // None, - // ); - // } + #[test] + fn test_int_lin_eq_3() { + test_lp_for_configs( + r" + Subject To + c0: + 1 x1 + 1 x2 = 2 + Bounds + 0 <= x1 <= 1 + 0 <= x2 <= 1 + End + ", + None, + ); + } #[test] fn test_int_lin_ge_1() { @@ -1386,7 +1426,7 @@ End ); } - // #[test] + #[test] fn _test_lp_ge_neg() { test_lp_for_configs( r" diff --git a/crates/pindakaas/src/int/term.rs b/crates/pindakaas/src/int/term.rs index c8b2af126a..b5dd39f335 100644 --- a/crates/pindakaas/src/int/term.rs +++ b/crates/pindakaas/src/int/term.rs @@ -15,7 +15,7 @@ use crate::{ Cse, LitOrConst, }, linear::PosCoeff, - Coeff, Comparator, IntLinExp as LinExp, IntVar, IntVarRef, Lin, Lit, Model, Scm, + Coeff, Comparator, IntLinExp as LinExp, IntVar, IntVarRef, Lin, Lit, Model, Scm, Unsatisfiable, }; /// A linear term (constant times integer variable) @@ -459,6 +459,26 @@ impl Term { pub(crate) fn size(&self) -> Coeff { self.x.borrow().size() } + + pub(crate) fn _add(&self, other: Self, model: &mut Model) -> Result { + let (x, y) = (self, other); + let z = Term::from(model.new_aux_var( + Dom::from_bounds(x.lb() + y.lb(), x.ub() + y.ub()), + false, + Some(IntVarEnc::Bin(None)), + None, + )?); + + model.add_constraint(Lin { + exp: LinExp { + terms: vec![x.clone(), y.clone(), z.clone() * -1], + }, + cmp: Comparator::Equal, + k: 0, + lbl: None, + })?; + Ok(z) + } } #[cfg(test)] diff --git a/crates/pindakaas/src/linear/adder.rs b/crates/pindakaas/src/linear/adder.rs index 410a8f94b2..b2fe79fceb 100644 --- a/crates/pindakaas/src/linear/adder.rs +++ b/crates/pindakaas/src/linear/adder.rs @@ -3,7 +3,7 @@ use rustc_hash::FxHashMap; use super::PosCoeff; use crate::{ - helpers::{as_binary, XorConstraint, XorEncoder}, + helpers::{as_binary, emit_filtered_clause, XorConstraint, XorEncoder}, int::LitOrConst, linear::LimitComp, trace::{emit_clause, new_var}, @@ -350,25 +350,6 @@ fn bit(x: &[LitOrConst], i: usize) -> LitOrConst { *x.get(i).unwrap_or(&LitOrConst::Const(false)) } -fn emit_filtered_clause>( - db: &mut DB, - lits: I, -) -> Result { - if let Ok(clause) = lits - .into_iter() - .filter_map(|lit| match lit { - LitOrConst::Lit(lit) => Some(Ok(lit)), - LitOrConst::Const(true) => Some(Err(())), // clause satisfied - LitOrConst::Const(false) => None, // literal falsified - }) - .collect::, ()>>() - { - emit_clause!(db, clause) - } else { - Ok(()) - } -} - /// Encode the adder sum circuit /// /// This function accepts either 2 literals as `input` (half adder) or 3