Skip to content

Commit

Permalink
refactor(sat): stop saving the sat problem for every clause (#11427)
Browse files Browse the repository at this point in the history
Signed-off-by: Rudi Grinberg <[email protected]>
  • Loading branch information
rgrinberg authored Feb 1, 2025
1 parent be00b7d commit a8886c2
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 27 deletions.
15 changes: 6 additions & 9 deletions src/dune_pkg/opam_solver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ module Solver = struct
let impl_clause =
match impls with
| [] -> None
| _ :: _ -> Some (S.at_most_one sat (List.map impls ~f:(fun s -> s.var)))
| _ :: _ -> Some (S.at_most_one (List.map impls ~f:(fun s -> s.var)))
in
{ role; clause = impl_clause; vars = impls }
in
Expand All @@ -670,12 +670,9 @@ module Solver = struct
module Conflict_classes = struct
module Map = Input.Conflict_class.Map

type t =
{ sat : S.t
; mutable groups : S.lit list ref Map.t
}
type t = { mutable groups : S.lit list ref Map.t }

let create sat = { sat; groups = Map.empty }
let create () = { groups = Map.empty }

let var t name =
match Map.find t.groups name with
Expand All @@ -700,7 +697,7 @@ module Solver = struct
Map.iter t.groups ~f:(fun impls ->
match !impls with
| _ :: _ :: _ ->
let (_ : S.at_most_one_clause) = S.at_most_one t.sat !impls in
let (_ : S.at_most_one_clause) = S.at_most_one !impls in
()
| _ -> ())
;;
Expand All @@ -711,7 +708,7 @@ module Solver = struct
let build_problem context root_req sat ~dummy_impl =
(* For each (iface, source) we have a list of implementations. *)
let impl_cache = ref Input.Role.Map.empty in
let conflict_classes = Conflict_classes.create sat in
let conflict_classes = Conflict_classes.create () in
let+ () =
let rec lookup_impl expand_deps role =
match Input.Role.Map.find !impl_cache role with
Expand Down Expand Up @@ -768,7 +765,7 @@ module Solver = struct
the [essential] case, because we must select a good version and we can't
select two. *)
(try
let (_ : S.at_most_one_clause) = S.at_most_one sat (user_var :: fail) in
let (_ : S.at_most_one_clause) = S.at_most_one (user_var :: fail) in
()
with
| Invalid_argument reason ->
Expand Down
34 changes: 17 additions & 17 deletions src/sat/sat.ml
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ let log_debug p =

module Make (User : USER) = struct
type clause =
| Union of t * lit array
| Union of lit array
| At_most_one of at_most_one_clause

and at_most_one_clause = lit option ref * t * lit array
and at_most_one_clause = lit option ref * lit array

(** The reason why a literal is set. *)
and reason =
Expand Down Expand Up @@ -276,11 +276,11 @@ module Make (User : USER) = struct
;;

let pp_clause = function
| Union (_, lits) ->
| Union lits ->
Pp.text "<some: "
++ Pp.concat_map ~sep:(Pp.text ", ") ~f:name_lit (Array.to_list lits)
++ Pp.char '>'
| At_most_one (_, _, lits) ->
| At_most_one (_, lits) ->
Pp.text "<at most one: " ++ pp_lits (Array.to_list lits) ++ Pp.char '>'
;;

Expand Down Expand Up @@ -382,8 +382,8 @@ module Make (User : USER) = struct
(* Why are we causing a conflict?
@return a list of literals which caused the problem by all being True. *)
let calc_reason = function
| Union (_, lits) -> List.map ~f:neg (Array.to_list lits)
| At_most_one (_, _, lits) ->
| Union lits -> List.map ~f:neg (Array.to_list lits)
| At_most_one (_, lits) ->
(* If we caused a conflict, it's because two of our literals became true. *)
(* Find two True literals *)
let rec find_two found lits i =
Expand All @@ -405,7 +405,7 @@ module Make (User : USER) = struct
@return a list of literals which caused the problem by all being True. *)
let calc_reason_for t lit =
match t with
| Union (_, lits) ->
| Union lits ->
(* Which literals caused [lit] to have its current value? *)
assert (lit_equal lit lits.(0));
(* The cause is everything except lit. *)
Expand All @@ -417,7 +417,7 @@ module Make (User : USER) = struct
if lit_equal l lit then get_cause (i + 1) else neg l :: get_cause (i + 1))
in
get_cause 0
| At_most_one (_, _, lits) ->
| At_most_one (_, lits) ->
(* Find the True literal. Any true literal other than [lit] would do. *)
[ Array.find_opt lits ~f:(fun l -> (not (lit_equal l lit)) && lit_value l = True)
|> Option.value_exn
Expand All @@ -426,9 +426,9 @@ module Make (User : USER) = struct

(* [lit] is now [True]. Add any new deductions. @return false if there is a
conflict. *)
let propagate t lit =
let propagate problem t lit =
match t with
| Union (problem, lits) ->
| Union lits ->
(* Try to infer new facts.
We can do this only when all of our literals are False except one,
which is undecided. That is,
Expand Down Expand Up @@ -477,7 +477,7 @@ module Make (User : USER) = struct
true)
in
find_not_false 2)
| At_most_one (current, problem, lits) ->
| At_most_one (current, lits) ->
(* Re-add ourselves to the watch list.
(we we won't get any more notifications unless we backtrack,
in which case we'd need to get back on the list anyway) *)
Expand Down Expand Up @@ -540,7 +540,7 @@ module Make (User : USER) = struct
(* Notify all watchers *)
while not (Queue.is_empty old_watches) do
let clause = Queue.pop_exn old_watches in
if not (Clause.propagate clause lit)
if not (Clause.propagate problem clause lit)
then (
(* Conflict *)

Expand All @@ -557,12 +557,12 @@ module Make (User : USER) = struct
| ConflictingClause c -> Some c
;;

let get_best_undecided (_, _, lits) =
let get_best_undecided (_, lits) =
(* if debug then log_debug "best_undecided: %s" (string_of_lits lits); *)
Array.find_opt lits ~f:(fun l -> lit_value l = Undecided)
;;

let get_selected (current, _, _) = !current
let get_selected (current, _) = !current

(* Returns the new clause if one was added, [AddedFact true] if none was added
because this clause is trivially True, or [AddedFact false] if the clause
Expand Down Expand Up @@ -591,7 +591,7 @@ module Make (User : USER) = struct
best_i := i)
done;
Array.swap lits 1 !best_i);
let clause = Union (problem, lits) in
let clause = Union lits in
(* Watch the first two literals in the clause (both must be
undefined at this point). *)
let watch i = watch_lit (neg lits.(i)) clause in
Expand Down Expand Up @@ -640,7 +640,7 @@ module Make (User : USER) = struct

let implies problem ?reason first rest = at_least_one problem ?reason (neg first :: rest)

let at_most_one problem lits =
let at_most_one lits : at_most_one_clause =
assert (List.length lits > 0);
(* if debug then log_debug "at_most_one(%s)" (string_of_lits lits); *)

Expand All @@ -663,7 +663,7 @@ module Make (User : USER) = struct
If any are True then they're enqueued and we'll process them
soon. *)
let lits = List.filter lits ~f:(fun l -> lit_value l <> False) |> Array.of_list in
let clause = ref None, problem, lits in
let clause = ref None, lits in
(let clause = At_most_one clause in
Array.iter lits ~f:(fun l -> watch_lit l clause));
clause
Expand Down
2 changes: 1 addition & 1 deletion src/sat/sat.mli
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ module Make (User : USER) : sig

(** Add a clause preventing more than one literal in the list from being [True].
@raise Invalid_argument if the list contains duplicates. *)
val at_most_one : t -> lit list -> at_most_one_clause
val at_most_one : lit list -> at_most_one_clause

(** [run_solver decider] tries to solve the SAT problem. It simplifies it as much as possible first. When it
has two paths which both appear possible, it calls [decider ()] to choose which to explore first. If this
Expand Down

0 comments on commit a8886c2

Please sign in to comment.