From 79e55f46bc732b9af5c33a8836eef2ba0177f04f Mon Sep 17 00:00:00 2001 From: Rudi Grinberg Date: Tue, 14 Jan 2025 22:48:44 +0000 Subject: [PATCH] refactor: remove oop in sat (#11288) Signed-off-by: Rudi Grinberg --- src/0install-solver/sat.ml | 363 ++++++++++++++++++------------------- 1 file changed, 178 insertions(+), 185 deletions(-) diff --git a/src/0install-solver/sat.ml b/src/0install-solver/sat.ml index c70990f888d..c85a52a940c 100644 --- a/src/0install-solver/sat.ml +++ b/src/0install-solver/sat.ml @@ -88,17 +88,10 @@ let log_debug p = module Make (User : USER) = struct type clause = - < (* [lit] is now [True]. Add any new deductions. - @return false if there is a conflict. *) - propagate : lit -> bool - ; (* Why are we causing a conflict? - @return a list of literals which caused the problem by all being True. *) - calc_reason : lit list - ; (* Which literals caused [lit] to have its current value? - @return a list of literals which caused the problem by all being True. *) - calc_reason_for : lit -> lit list - ; (* For debugging *) - pp : User_message.Style.t Pp.t > + | Union of t * lit array + | At_most_one of at_most_one_clause + + and at_most_one_clause = lit option ref * t * lit list (** The reason why a literal is set. *) and reason = @@ -121,6 +114,19 @@ module Make (User : USER) = struct and lit = sign * var + and t = + { id_maker : VarID.mint + ; (* Propagation *) + mutable vars : var list + ; propQ : lit Queue.t (* propagation queue *) + ; (* Assignments *) + mutable trail : lit list (* order of assignments, most recent first *) + ; mutable trail_lim : int list (* decision levels (len(trail) at each decision) *) + ; mutable toplevel_conflict : bool + ; mutable set_to_false : bool + (* we are finishing up by setting everything else to False *) + } + let lit_equal (s1, v1) (s2, v2) = s1 == s2 && v1 == v2 module C = Comparable.Make (VarID) @@ -161,24 +167,6 @@ module Make (User : USER) = struct } ;; - type t = - { id_maker : VarID.mint - ; (* Propagation *) - mutable vars : var list - ; propQ : lit Queue.t (* propagation queue *) - ; (* Assignments *) - mutable trail : lit list (* order of assignments, most recent first *) - ; mutable trail_lim : int list (* decision levels (len(trail) at each decision) *) - ; mutable toplevel_conflict : bool - ; mutable set_to_false : bool - (* we are finishing up by setting everything else to False *) - } - - let pp_reason = function - | Clause clause -> clause#pp - | External msg -> Pp.text msg - ;; - let neg = function | Pos, var -> Neg, var | Neg, var -> Pos, var @@ -244,21 +232,17 @@ module Make (User : USER) = struct User.pp info.obj ++ Pp.textf "=%s" (Var_value.to_string info.value) ;; - let pp_lit_reason lit = - match (var_of_lit lit).reason with - | None -> Pp.text "no reason (BUG)" - | Some (External reason) -> Pp.text reason - | Some (Clause c) -> - let reason = c#calc_reason_for lit in - Pp.concat_map ~sep:(Pp.text " && ") reason ~f:pp_lit_assignment + let pp_clause = function + | Union (_, lits) -> + Pp.text "' + | At_most_one (_, _, lits) -> Pp.text "' ;; - (* Why is [lit] assigned the way it is? For debugging. *) - let explain_reason lit = - let value = lit_value lit in - if value = Undecided - then Pp.text "undecided!" - else Pp.hovbox (pp_lit_reason lit ++ Pp.text " => " ++ pp_lit_assignment lit) + let pp_reason = function + | Clause clause -> pp_clause clause + | External msg -> Pp.text msg ;; let get_decision_level problem = List.length problem.trail_lim @@ -336,61 +320,73 @@ module Make (User : USER) = struct done ;; - (** Process the propQ. - Returns None when done, or the clause that caused a conflict. *) - let propagate problem = - (* if debug then log_debug "propagate: queue length = %d" (Queue.length problem.propQ); *) - try - while not (Queue.is_empty problem.propQ) do - let lit = Queue.pop_exn problem.propQ in - let old_watches = Queue.create () in - let watches = watch_queue lit in - Queue.transfer watches old_watches; - (* if debug then log_debug "%s -> True : watches: %d" (name_lit lit) (Queue.length old_watches); *) - - (* Notifiy all watchers *) - while not (Queue.is_empty old_watches) do - let clause = Queue.pop_exn old_watches in - if not (clause#propagate lit) - then ( - (* Conflict *) - - (* Re-add remaining watches *) - Queue.transfer old_watches watches; - (* No point processing the rest of the queue as - we'll have to backtrack now. *) - Queue.clear problem.propQ; - raise (ConflictingClause clause)) - done - done; - None - with - | ConflictingClause c -> Some c - ;; - let impossible problem = problem.toplevel_conflict <- true - (* Call [clause#propagate lit] when lit becomes True *) + (* Call [Clause.propagate lit] when lit becomes True *) let watch_lit lit clause = (* if debug then log_debug "%s is watching for %s to become True" clause#to_string (name_lit lit); *) Queue.push (watch_queue lit) clause ;; - let union_clause problem lits = - object (self : clause) - (* Try to infer new facts. - We can do this only when all of our literals are False except one, - which is undecided. That is, - False... or X or False... = True => X = True - - To get notified when this happens, we tell the solver to - watch two of our undecided literals. Watching two undecided - literals is sufficient. When one changes we check the state - again. If we still have two or more undecided then we switch - to watching them, otherwise we propagate. - - Returns false on conflict. *) - method propagate lit = + exception Conflict + + module Clause = struct + (* Why are we causing a conflict? + @return a list of literals which caused the problem by all being True. *) + let calc_reason = function + | Union (_, lits) -> List.map ~f:neg (Array.to_list lits) + | At_most_one (_, _, lits) -> + (* If we caused a conflict, it's because two of our literals became true. *) + (* Find two True literals *) + let rec find_two found = function + | [] -> assert false (* Don't know why! *) + | x :: xs when lit_value x <> True -> find_two found xs + | x :: xs -> + (match found with + | None -> find_two (Some x) xs + | Some first -> [ first; x ]) + in + find_two None lits + ;; + + (* Which literals caused [lit] to have its current value? + @return a list of literals which caused the problem by all being True. *) + let calc_reason_for t lit = + match t with + | Union (_, lits) -> + (* Which literals caused [lit] to have its current value? *) + assert (lit_equal lit lits.(0)); + (* The cause is everything except lit. *) + let rec get_cause i = + if i = Array.length lits + then [] + else ( + let l = lits.(i) in + if lit_equal l lit then get_cause (i + 1) else neg l :: get_cause (i + 1)) + in + get_cause 0 + | At_most_one (_, _, lits) -> + (* Find the True literal. Any true literal other than [lit] would do. *) + [ List.find_exn lits ~f:(fun l -> (not (lit_equal l lit)) && lit_value l = True) ] + ;; + + (* [lit] is now [True]. Add any new deductions. @return false if there is a + conflict. *) + let propagate t lit = + match t with + | Union (problem, lits) -> + (* Try to infer new facts. + We can do this only when all of our literals are False except one, + which is undecided. That is, + False... or X or False... = True => X = True + + To get notified when this happens, we tell the solver to + watch two of our undecided literals. Watching two undecided + literals is sufficient. When one changes we check the state + again. If we still have two or more undecided then we switch + to watching them, otherwise we propagate. + + Returns false on conflict. *) (* [neg lit] has just become False *) (*if debug then log_debug("%s: noticed %s has become False" % (self, self.solver.name_lit(neg(lit)))) *) @@ -404,7 +400,7 @@ module Make (User : USER) = struct if lit_value lits.(0) = True then ( (* We're already satisfied. Do nothing. *) - watch_lit lit (self :> clause); + watch_lit lit t; true) else ( assert (lit_value lits.(1) = False); @@ -414,58 +410,24 @@ module Make (User : USER) = struct if i = Array.length lits then ( (* Only lits[0], is now undefined, so set it to True. *) - watch_lit lit (self :> clause); - enqueue problem lits.(0) (Clause (self :> clause))) + watch_lit lit t; + enqueue problem lits.(0) (Clause t)) else ( match lit_value lits.(i) with | Undecided | True -> (* If it's True then we've already done our job, so this means we don't get notified unless we backtrack, which is fine. *) Array.swap lits 1 i; - watch_lit (neg lits.(1)) (self :> clause); + watch_lit (neg lits.(1)) t; true | False -> find_not_false (i + 1)) in find_not_false 2) - - (* We can only cause a conflict if all our lits are False, so they're all the cause. - e.g. if we are "A or B or not(C)" then "not(A) and not(B) and C" causes a conflict. *) - method calc_reason = List.map ~f:neg (Array.to_list lits) - - (** Which literals caused [lit] to have its current value? *) - method calc_reason_for lit = - assert (lit_equal lit lits.(0)); - (* The cause is everything except lit. *) - let rec get_cause i = - if i = Array.length lits - then [] - else ( - let l = lits.(i) in - if lit_equal l lit then get_cause (i + 1) else neg l :: get_cause (i + 1)) - in - get_cause 0 - - method pp = - Pp.text "' - end - ;; - - exception Conflict - - (* If one literal in the list becomes True, all the others must be False. - Preferred literals should be listed first. *) - class at_most_one_clause problem lits = - (* The single literal from our set that is True. - We store this explicitly because the decider needs to know quickly. *) - let current = ref None in - object (self) - method propagate lit = + | At_most_one (current, problem, lits) -> (* Re-add ourselves to the watch list. (we we won't get any more notifications unless we backtrack, in which case we'd need to get back on the list anyway) *) - watch_lit lit (self :> clause); + watch_lit lit t; (* value[lit] has just become true *) assert (lit_value lit = True); (* if debug then log_debug("%s: noticed %s has become True" % (self, self.solver.name_lit(lit))) *) @@ -491,58 +453,71 @@ module Make (User : USER) = struct in let var_info = var_of_lit lit in var_info.undo <- undo :: var_info.undo; - try - (* We set all other literals to False. *) - List.iter lits ~f:(fun l -> - match lit_value l with - | True when not (lit_equal l lit) -> - (* Due to queuing, we might get called with current = None - and two versions already selected. *) - if debug then log_debug (Pp.text "CONFLICT: already selected " ++ name_lit l); - raise Conflict - | Undecided -> - (* Since one of our lits is already true, all unknown ones - can be set to False. *) - if not (enqueue problem (neg l) (Clause (self :> clause))) - then ( - if debug - then - log_debug (Pp.text "CONFLICT: enqueue failed for " ++ name_lit (neg l)); - raise - Conflict (* Can't happen, since we already checked we're Undecided *)) - | _ -> ()); - true - with - | Conflict -> false - - (** If we caused a conflict, it's because two of our literals became true. *) - method calc_reason = - (* Find two True literals *) - let rec find_two found = function - | [] -> assert false (* Don't know why! *) - | x :: xs when lit_value x <> True -> find_two found xs - | x :: xs -> - (match found with - | None -> find_two (Some x) xs - | Some first -> [ first; x ]) - in - find_two None lits + (try + (* We set all other literals to False. *) + List.iter lits ~f:(fun l -> + match lit_value l with + | True when not (lit_equal l lit) -> + (* Due to queuing, we might get called with current = None + and two versions already selected. *) + if debug + then log_debug (Pp.text "CONFLICT: already selected " ++ name_lit l); + raise Conflict + | Undecided -> + (* Since one of our lits is already true, all unknown ones + can be set to False. *) + if not (enqueue problem (neg l) (Clause t)) + then ( + if debug + then + log_debug (Pp.text "CONFLICT: enqueue failed for " ++ name_lit (neg l)); + raise + Conflict (* Can't happen, since we already checked we're Undecided *)) + | _ -> ()); + true + with + | Conflict -> false) + ;; + end - (** Which literals caused [lit] to have its current value? *) - method calc_reason_for lit = - (* Find the True literal. Any true literal other than [lit] would do. *) - [ List.find_exn lits ~f:(fun l -> (not (lit_equal l lit)) && lit_value l = True) ] + (** Process the propQ. + Returns None when done, or the clause that caused a conflict. *) + let propagate problem = + (* if debug then log_debug "propagate: queue length = %d" (Queue.length problem.propQ); *) + try + while not (Queue.is_empty problem.propQ) do + let lit = Queue.pop_exn problem.propQ in + let old_watches = Queue.create () in + let watches = watch_queue lit in + Queue.transfer watches old_watches; + (* if debug then log_debug "%s -> True : watches: %d" (name_lit lit) (Queue.length old_watches); *) + + (* Notifiy all watchers *) + while not (Queue.is_empty old_watches) do + let clause = Queue.pop_exn old_watches in + if not (Clause.propagate clause lit) + then ( + (* Conflict *) - method best_undecided = - (* if debug then log_debug "best_undecided: %s" (string_of_lits lits); *) - List.find lits ~f:(fun l -> lit_value l = Undecided) + (* Re-add remaining watches *) + Queue.transfer old_watches watches; + (* No point processing the rest of the queue as + we'll have to backtrack now. *) + Queue.clear problem.propQ; + raise (ConflictingClause clause)) + done + done; + None + with + | ConflictingClause c -> Some c + ;; - method get_selected = !current - method pp = Pp.text "' - end + let get_best_undecided (_, _, lits) = + (* if debug then log_debug "best_undecided: %s" (string_of_lits lits); *) + List.find lits ~f:(fun l -> lit_value l = Undecided) + ;; - let get_best_undecided clause = clause#best_undecided - let get_selected clause = clause#get_selected + let get_selected (current, _, _) = !current (* Returns the new clause if one was added, [AddedFact true] if none was added because this clause is trivially True, or [AddedFact false] if the clause @@ -556,7 +531,7 @@ module Make (User : USER) = struct AddedFact (enqueue problem lit reason) | lits -> let lits = Array.of_list lits in - let clause = union_clause problem lits in + let clause = Union (problem, lits) in if learnt then ( (* lits[0] is Undecided because we just backtracked. @@ -635,8 +610,8 @@ module Make (User : USER) = struct If any are True then they're enqueued and we'll process them soon. *) let lits = List.filter lits ~f:(fun l -> lit_value l <> False) in - let clause = new at_most_one_clause problem lits in - List.iter lits ~f:(fun l -> watch_lit l (clause :> clause)); + let clause = ref None, problem, lits in + List.iter lits ~f:(fun l -> watch_lit l (At_most_one clause)); clause ;; @@ -774,13 +749,13 @@ module Make (User : USER) = struct | None -> Code_error.raise "No reason!" [] in (* Can't happen *) - let p_reason = cause#calc_reason_for p in + let p_reason = Clause.calc_reason_for cause p in let outcome = name_lit p in if debug then log_debug (Pp.text "why did " - ++ cause#pp + ++ pp_clause cause ++ Pp.text " lead to " ++ outcome ++ Pp.char '?'); @@ -796,8 +771,9 @@ module Make (User : USER) = struct (* Start with all the literals involved in the conflict. *) if debug then - log_debug (Pp.text "why did " ++ original_cause#pp ++ Pp.text " lead to conflict?"); - let p = follow_causes original_cause#calc_reason (Pp.text "conflict") in + log_debug + (Pp.text "why did " ++ pp_clause original_cause ++ Pp.text " lead to conflict?"); + let p = follow_causes (Clause.calc_reason original_cause) (Pp.text "conflict") in assert (!counter = 0); (* p is the literal we decided to stop processing on. It's either a derived variable at the current level, or the decision that @@ -898,4 +874,21 @@ module Make (User : USER) = struct with | SolveDone x -> x) ;; + + let pp_lit_reason lit = + match (var_of_lit lit).reason with + | None -> Pp.text "no reason (BUG)" + | Some (External reason) -> Pp.text reason + | Some (Clause c) -> + let reason = Clause.calc_reason_for c lit in + Pp.concat_map ~sep:(Pp.text " && ") reason ~f:pp_lit_assignment + ;; + + (* Why is [lit] assigned the way it is? For debugging. *) + let explain_reason lit = + let value = lit_value lit in + if value = Undecided + then Pp.text "undecided!" + else Pp.hovbox (pp_lit_reason lit ++ Pp.text " => " ++ pp_lit_assignment lit) + ;; end