Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow solvers to output specific objects on solution and failure #47

Merged
merged 3 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 100 additions & 30 deletions crates/pindakaas-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
};

let assumptions = if opts.assumptions {
let fail_ident = format_ident!("{}Failed", ident);
quote! {
impl crate::solver::SolveAssuming for #ident {
type FailFn = #fail_ident;

fn solve_assuming<
I: IntoIterator<Item = crate::Lit>,
SolCb: FnOnce(&dyn crate::Valuation),
FailCb: FnOnce(&crate::solver::FailFn<'_>),
SolCb: FnOnce(&Self::ValueFn),
FailCb: FnOnce(&Self::FailFn),
>(
&mut self,
assumptions: I,
Expand All @@ -60,18 +63,25 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
}
match self.solve(on_sol) {
crate::solver::SolveResult::Unsat => {
let fail_fn = |lit: crate::Lit| {
let lit: i32 = lit.into();
let failed = unsafe { #krate::ipasir_failed(#ptr, lit) };
failed != 0
};
let fail_fn = #fail_ident { ptr: #ptr };
on_fail(&fail_fn);
crate::solver::SolveResult::Unsat
}
r => r,
}
}
}

pub struct #fail_ident {
ptr: *mut std::ffi::c_void,
}
impl crate::solver::FailedAssumtions for #fail_ident {
fn fail(&self, lit: crate::Lit) -> bool {
let lit: i32 = lit.into();
let failed = unsafe { #krate::ipasir_failed(#ptr, lit) };
failed != 0
}
}
}
} else {
quote!()
Expand Down Expand Up @@ -142,6 +152,7 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
quote!()
};

let sol_ident = format_ident!("{}Sol", ident);
let ipasir_up = if opts.ipasir_up {
let prop_ident = format_ident!("{}Prop", ident);
let prop_member = match opts.prop {
Expand Down Expand Up @@ -215,13 +226,6 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
}
}

fn propagator<P: crate::solver::Propagator + 'static>(&self) -> Option<&P> {
#prop_member.as_ref().map(|p| p.prop.prop.as_any().downcast_ref()).flatten()
}
fn propagator_mut<P: crate::solver::Propagator + 'static>(&mut self) -> Option<&mut P> {
#prop_member.as_mut().map(|p| p.prop.prop.as_mut_any().downcast_mut()).flatten()
}

fn add_observed_var(&mut self, var: crate::Var){
unsafe { #krate::ipasir_add_observed_var( #ptr, var.0.get()) };
}
Expand All @@ -233,6 +237,20 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
}
}

#[cfg(feature = "ipasir-up")]
impl crate::solver::PropagatorAccess for #ident {
fn propagator<P: crate::solver::Propagator + 'static>(&self) -> Option<&P> {
#prop_member.as_ref().map(|p| p.prop.prop.as_any().downcast_ref()).flatten()
}
}

#[cfg(feature = "ipasir-up")]
impl crate::solver::MutPropagatorAccess for #ident {
fn propagator_mut<P: crate::solver::Propagator + 'static>(&mut self) -> Option<&mut P> {
#prop_member.as_mut().map(|p| p.prop.prop.as_mut_any().downcast_mut()).flatten()
}
}

#[cfg(feature = "ipasir-up")]
impl crate::solver::SolvingActions for #ident {
fn new_var(&mut self) -> crate::Var {
Expand All @@ -245,9 +263,55 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
unsafe { #krate::ipasir_is_decision( #ptr, lit.0.get() ) }
}
}

pub struct #sol_ident {
ptr: *mut std::ffi::c_void,
#[cfg(feature = "ipasir-up")]
prop: Option<*mut std::ffi::c_void>,
}
impl #ident {
fn solver_solution_obj(&mut self) -> #sol_ident {
#sol_ident {
ptr: self.ptr,
#[cfg(feature = "ipasir-up")]
prop: if let Some(p) = &mut #prop_member { Some((&mut (p.prop)) as *mut _ as *mut std::ffi::c_void) } else { None },
}
}
}
#[cfg(feature = "ipasir-up")]
impl crate::solver::PropagatorAccess for #sol_ident {
fn propagator<P: crate::solver::Propagator + 'static>(&self) -> Option<&P> {
if let Some(prop) = self.prop {
let prop = unsafe { &*(prop as *const crate::solver::libloading::IpasirPropStore) };
prop.prop.as_any().downcast_ref()
} else {
None
}
}
}
#[cfg(feature = "ipasir-up")]
impl crate::solver::MutPropagatorAccess for #sol_ident {
fn propagator_mut<P: crate::solver::Propagator + 'static>(&mut self) -> Option<&mut P> {
if let Some(prop) = self.prop {
let prop = unsafe { &mut *(prop as *mut crate::solver::libloading::IpasirPropStore) };
prop.prop.as_mut_any().downcast_mut()
} else {
None
}
}
}
}
} else {
quote!()
quote! {
pub struct #sol_ident {
ptr: *mut std::ffi::c_void,
}
impl #ident {
fn solver_solution_obj(&self) -> #sol_ident {
#sol_ident { ptr: self.ptr }
}
}
}
};

let from_cnf = if opts.has_default {
Expand Down Expand Up @@ -306,34 +370,24 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
}

impl crate::solver::Solver for #ident {
type ValueFn = #sol_ident;

fn signature(&self) -> &str {
unsafe { std::ffi::CStr::from_ptr(#krate::ipasir_signature()) }
.to_str()
.unwrap()
}

fn solve<SolCb: FnOnce(&dyn crate::Valuation)>(
fn solve<SolCb: FnOnce(&Self::ValueFn)>(
&mut self,
on_sol: SolCb,
) -> crate::solver::SolveResult {
let res = unsafe { #krate::ipasir_solve( #ptr ) };
match res {
10 => {
// 10 -> Sat
let val_fn = |lit: crate::Lit| {
let var: i32 = lit.var().into();
// WARN: Always ask about variable (positive) literal, otherwise solvers sometimes seem incorrect
let ret = unsafe { #krate::ipasir_val( #ptr , var) };
match ret {
_ if ret == var => Some(!lit.is_negated()),
_ if ret == -var => Some(lit.is_negated()),
_ => {
debug_assert_eq!(ret, 0); // zero according to spec, both value are valid
None
}
}
};
on_sol(&val_fn);
let model = self.solver_solution_obj();
on_sol(&model);
crate::solver::SolveResult::Sat
}
20 => crate::solver::SolveResult::Unsat, // 20 -> Unsat
Expand All @@ -345,6 +399,22 @@ pub fn ipasir_solver_derive(input: TokenStream) -> TokenStream {
}
}

impl crate::Valuation for #sol_ident {
fn value(&self, lit: crate::Lit) -> Option<bool> {
let var: i32 = lit.var().into();
// WARN: Always ask about variable (positive) literal, otherwise solvers sometimes seem incorrect
let ret = unsafe { #krate::ipasir_val(self.ptr, var) };
match ret {
_ if ret == var => Some(!lit.is_negated()),
_ if ret == -var => Some(lit.is_negated()),
_ => {
debug_assert_eq!(ret, 0); // zero according to spec, both value are valid
None
}
}
}
}

#from_cnf
#assumptions
#term_callback
Expand Down
2 changes: 1 addition & 1 deletion crates/pindakaas/src/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl Cardinality {
}

impl Checker for Cardinality {
fn check<F: Valuation>(&self, value: F) -> Result<(), CheckError> {
fn check<F: Valuation + ?Sized>(&self, value: &F) -> Result<(), CheckError> {
Linear::from(self.clone()).check(value)
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/pindakaas/src/cardinality_one.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl CardinalityOne {
}

impl Checker for CardinalityOne {
fn check<F: Valuation>(&self, value: F) -> Result<(), CheckError> {
fn check<F: Valuation + ?Sized>(&self, value: &F) -> Result<(), CheckError> {
Linear::from(self.clone()).check(value)
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/pindakaas/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl<'a> XorConstraint<'a> {
}

impl<'a> Checker for XorConstraint<'a> {
fn check<F: Valuation>(&self, value: F) -> Result<(), CheckError> {
fn check<F: Valuation + ?Sized>(&self, value: &F) -> Result<(), CheckError> {
let count = LinExp::from_terms(self.lits.iter().map(|&l| (l, 1)).collect_vec().as_slice())
.value(value)?;
if count % 2 == 1 {
Expand Down
8 changes: 4 additions & 4 deletions crates/pindakaas/src/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ use self::enc::GROUND_BINARY_AT_LB;
use crate::{linear::Constraint, CheckError, Coeff, LinExp, Result, Unsatisfiable, Valuation};

impl LinExp {
pub(crate) fn value<F: Valuation>(&self, value: F) -> Result<Coeff, CheckError> {
pub(crate) fn value<F: Valuation + ?Sized>(&self, sol: &F) -> Result<Coeff, CheckError> {
let mut total = self.add;
for (constraint, terms) in self.iter() {
// Calculate sum for constraint
let sum = terms
.iter()
.filter(|(lit, _)| value(*lit).expect("missing assignment to literal"))
.filter(|(lit, _)| sol.value(*lit).expect("missing assignment to literal"))
.map(|(_, i)| i)
.sum();
match constraint {
Some(Constraint::AtMostOne) => {
if sum != 0
&& terms
.iter()
.filter(|&(l, _)| value(*l).unwrap_or(true))
.filter(|&(l, _)| sol.value(*l).unwrap_or(true))
.count() > 1
{
return Err(Unsatisfiable.into());
Expand All @@ -38,7 +38,7 @@ impl LinExp {
.iter()
.map(|(l, _)| *l)
.tuple_windows()
.any(|(a, b)| !value(a).unwrap_or(false) & value(b).unwrap_or(true))
.any(|(a, b)| !sol.value(a).unwrap_or(false) & sol.value(b).unwrap_or(true))
{
return Err(Unsatisfiable.into());
}
Expand Down
10 changes: 5 additions & 5 deletions crates/pindakaas/src/int/constrain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ impl<'a> TernLeConstraint<'a> {
}

impl<'a> Checker for TernLeConstraint<'a> {
fn check<F: Valuation>(&self, value: F) -> Result<(), CheckError> {
let x = LinExp::from(self.x).value(&value)?;
let y = LinExp::from(self.y).value(&value)?;
let z = LinExp::from(self.z).value(&value)?;
fn check<F: Valuation + ?Sized>(&self, sol: &F) -> Result<(), CheckError> {
let x = LinExp::from(self.x).value(sol)?;
let y = LinExp::from(self.y).value(sol)?;
let z = LinExp::from(self.z).value(sol)?;
if Self::check(x, y, &self.cmp, z) {
Ok(())
} else {
Expand Down Expand Up @@ -105,7 +105,7 @@ impl<'a, DB: ClauseDatabase> Encoder<DB, TernLeConstraint<'a>> for TernLeEncoder

return match (x, y, z) {
(IntVarEnc::Const(_), IntVarEnc::Const(_), IntVarEnc::Const(_)) => {
if tern.check(|_| None).is_ok() {
if tern.check(&|_| None).is_ok() {
Ok(())
} else {
Err(Unsatisfiable)
Expand Down
4 changes: 2 additions & 2 deletions crates/pindakaas/src/int/enc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,9 @@ impl ImplicationChainEncoder {
}

impl Checker for ImplicationChainConstraint {
fn check<F: Valuation>(&self, value: F) -> Result<(), CheckError> {
fn check<F: Valuation + ?Sized>(&self, sol: &F) -> Result<(), CheckError> {
for (a, b) in self.lits.iter().copied().tuple_windows() {
if value(a).unwrap_or(true) & !value(b).unwrap_or(false) {
if sol.value(a).unwrap_or(true) & !sol.value(b).unwrap_or(false) {
return Err(Unsatisfiable.into());
}
}
Expand Down
23 changes: 15 additions & 8 deletions crates/pindakaas/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,20 @@ impl fmt::Display for Unsatisfiable {
/// an empty value, or the [`Unsatisfiable`] error type.
pub type Result<T = (), E = Unsatisfiable> = std::result::Result<T, E>;

/// A function that gives the valuation/truth-value for a given literal in the
/// current solution/model.
///
/// Note that the function can return None if the model/solution is independent
/// of the given literal.
pub trait Valuation: Fn(Lit) -> Option<bool> {}
impl<F: Fn(Lit) -> Option<bool>> Valuation for F {}
/// A trait implemented by types that can be used to represent a solution/model
pub trait Valuation {
/// Returns the valuation/truth-value for a given literal in the
/// current solution/model.
///
/// Note that the function can return None if the model/solution is independent
/// of the given literal.
fn value(&self, lit: Lit) -> Option<bool>;
}
impl<F: Fn(Lit) -> Option<bool>> Valuation for F {
fn value(&self, lit: Lit) -> Option<bool> {
self(lit)
}
}

/// Encoder is the central trait implemented for all the encoding algorithms
pub trait Encoder<DB: ClauseDatabase, Constraint> {
Expand All @@ -206,7 +213,7 @@ pub trait Checker {
/// the constraint,
/// - it returns [`Unsatisfiable`] when the assignment violates the
/// constraint
fn check<F: Valuation>(&self, value: F) -> Result<(), CheckError>;
fn check<F: Valuation + ?Sized>(&self, value: &F) -> Result<(), CheckError>;
}

/// Incomplete is a error type returned by a [`Checker`] type when the
Expand Down
6 changes: 3 additions & 3 deletions crates/pindakaas/src/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ impl LinearConstraint {
}

impl Checker for LinearConstraint {
fn check<F: Valuation>(&self, value: F) -> Result<(), CheckError> {
fn check<F: Valuation + ?Sized>(&self, value: &F) -> Result<(), CheckError> {
let lhs = self.exp.value(value)?;
if match self.cmp {
Comparator::LessEq => lhs <= self.k,
Expand Down Expand Up @@ -561,10 +561,10 @@ impl Mul<Coeff> for LinExp {
}

impl Checker for Linear {
fn check<F: Valuation>(&self, value: F) -> Result<(), CheckError> {
fn check<F: Valuation + ?Sized>(&self, sol: &F) -> Result<(), CheckError> {
let mut sum = 0;
for (lit, coef) in self.terms.iter().flat_map(|p| p.iter().copied()) {
match value(lit) {
match sol.value(lit) {
Some(true) => sum += *coef,
None if self.cmp == LimitComp::LessEq => sum += *coef,
Some(false) => {}
Expand Down
Loading