Skip to content

Commit

Permalink
Fix issue PLTools#173
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitrii Kosarev a.k.a. Kakadu <[email protected]>
  • Loading branch information
Kakadu authored and Dmitrii.Kosarev a.k.a. Kakadu committed Dec 24, 2024
1 parent 79e0390 commit 62c7b28
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 19 deletions.
35 changes: 34 additions & 1 deletion regression_ppx/test014diseq.ml
Original file line number Diff line number Diff line change
@@ -1,17 +1,50 @@
open OCanren
open Tester

let debug_line line =
debug_var !!1 OCanren.reify (function _ ->
Format.printf "%d\n%!" line;
success)
;;

let trace_index msg var =
debug_var var OCanren.reify (function
| [ Var (n, _) ] ->
Printf.printf "%s = _.%d\n" msg n;
success
| _ -> assert false)
;;

let trace fmt =
Format.kasprintf
(fun s ->
debug_var !!1 OCanren.reify (function _ ->
Format.printf "%s\n%!" s;
success))
fmt
;;

let rel list1 =
let open OCanren.Std in
fresh
(list2 hd1 tl1 hd2 tl2)
(trace_index "hd1" hd1)
(trace_index "hd2" hd2)
(trace_index "tl2" tl2)
(list1 =/= list2)
(list1 === hd1 % tl1)
(list2 === hd2 % tl2)
trace_diseq
(trace " hd2 === 1")
(hd2 === !!1)
trace_diseq
(trace " tl2 === []") (* bad behaviour starts now *)
(tl2 === nil ())
(hd1 === !!1)
(tl1 === nil ())
(debug_line __LINE__)
trace_diseq
(tl1 === nil ()) (* crashes here *)
(debug_line __LINE__)
;;

(* let () = [%tester run_r [%show GT.int GT.list] (Std.List.reify reify) 1 (fun q -> rel q)] *)
Expand Down
70 changes: 52 additions & 18 deletions src/core/Disequality.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
(* to avoid clash with Std.List (i.e. logic list) *)
module List = Stdlib.List

let log fmt =
if false
then Format.kasprintf (Format.printf "%s\n%!") fmt
else Format.ifprintf Format.std_formatter fmt

module Answer =
struct
module S = Set.Make(Term)
Expand Down Expand Up @@ -125,14 +130,15 @@ module Disjunct :
if Term.VarMap.is_empty d then Format.fprintf ppf "<empty>"
else
Format.fprintf ppf "[| ";
Term.VarMap.iter (fun k v ->
Format.fprintf ppf "@[%d =/= %s@], @," k.Term.Var.index (Term.show v)
Term.VarMap.iteri (fun i k v ->
if i<>0 then Format.fprintf ppf ", ";
Format.fprintf ppf " @[%d =/= %s@]" k.Term.Var.index (Term.show v)
) d;
Format.fprintf ppf " |]"

let update t =
ListLabels.fold_left ~init:t
~f:(let open Subst.Binding in fun acc {var; term} ->
let update : t -> _ -> t = fun init ->
ListLabels.fold_left ~init
~f:(fun acc {Subst.Binding.var; term} ->
if Term.VarMap.mem var acc then
(* in this case we have subformula of the form (x =/= t1) \/ (x =/= t2) which is always SAT *)
raise Disequality_fulfilled
Expand Down Expand Up @@ -161,12 +167,32 @@ module Disjunct :
| Fulfiled -> raise Disequality_fulfilled
| Violated -> raise Disequality_violated

let rec recheck env subst t =
let rec recheck env subst (t: t): t =
(* log "Disjunct.recheck: %a" pp t; *)
let var, term = Term.VarMap.max_binding t in
(* log " max bind index = %d" var.Term.Var.index; *)
let unchecked = Term.VarMap.remove var t in
(* log " unchecked: %a" pp unchecked; *)
match refine env subst (Obj.magic var) term with
| Fulfiled -> raise Disequality_fulfilled
| Refined delta -> update unchecked delta
| Fulfiled ->
raise Disequality_fulfilled
| Refined delta -> (
(* When leading terms are reified into something new, we still need to
do whole unification, beacuse other pairs may need walking ---
(we postponed walking, so som einformation may be lost.)
See issue #173
*)
(* log "Refined into: %a" (Format.pp_print_list Subst.Binding.pp) delta; *)
match Subst.unify_map env subst t with
| None ->
(* not unifiable --- always distinct *)
raise Disequality_fulfilled
| Some ([], _) -> raise Disequality_violated
| Some (bnds, _subst) ->
(* TODO(Kakadu): reconstruction of map from binding list could hurt performance *)
let rez = Subst.varmap_of_bindings bnds in
(* log "Disjunct.recheck returns %a" pp rez; *)
rez)
| Violated ->
if Term.VarMap.is_empty unchecked then
raise Disequality_violated
Expand Down Expand Up @@ -254,9 +280,12 @@ module Conjunct :
if M.is_empty map
then Format.fprintf ppf "{}"
else
let idx = ref 0 in
Format.fprintf ppf "{ ";
M.iter (fun k v ->
Format.fprintf ppf "@[%d: %a@],@ " k Disjunct.pp v
if !idx <> 0 then Format.fprintf ppf " ,";
Format.fprintf ppf "@[%d: %a@]" k Disjunct.pp v;
incr idx
) map;
Format.fprintf ppf " }"

Expand All @@ -280,11 +309,14 @@ module Conjunct :
) t Term.VarMap.empty

let recheck env subst t =
M.fold (fun id disj acc ->
log "Conjunct.recheck. %a" pp t;
let rez = M.fold (fun id disj acc ->
try
M.add id (Disjunct.recheck env subst disj) acc
with Disequality_fulfilled -> acc
) t M.empty
) t M.empty in
log "rechecked = %a" pp rez;
rez

let merge_disjoint env subst =
M.union (fun _ _ _ ->
Expand Down Expand Up @@ -375,6 +407,11 @@ type t = Conjunct.t Term.VarMap.t

let empty = Term.VarMap.empty

let pp ppf : t -> unit =
Term.VarMap.iter (fun k v ->
Format.fprintf ppf "@[%d: %a@]@," k.Term.Var.index Conjunct.pp v
)

(* merges all conjuncts (linked to different variables) into one *)
let combine env subst cstore =
Term.VarMap.fold (fun _ -> Conjunct.merge_disjoint env subst) cstore Conjunct.empty
Expand All @@ -394,17 +431,19 @@ let add env subst cstore x y =
| Disequality_violated -> None

let recheck env subst cstore bs =
let helper var cstore =
let helper var cstore : t =
try
let conj = Term.VarMap.find var cstore in
let cstore = Term.VarMap.remove var cstore in
update env subst (Conjunct.recheck env subst conj) cstore

with Not_found -> cstore
in
try
let cstore = ListLabels.fold_left bs ~init:cstore
~f:(let open Subst.Binding in fun cstore {var; term} ->
~f:(fun cstore {Subst.Binding.var; term} ->
let cstore = helper var cstore in
(* log "cstore = %a" pp cstore; *)
match Env.var env term with
| Some u -> helper u cstore
| None -> cstore
Expand All @@ -418,8 +457,3 @@ let project env subst cstore fv =

let reify env subst cstore x =
Conjunct.reify env subst (combine env subst cstore) x

let pp ppf : t -> unit =
Term.VarMap.iter (fun k v ->
Format.fprintf ppf "@[%d: %a@]@," k.Term.Var.index Conjunct.pp v
)
25 changes: 25 additions & 0 deletions src/core/Subst.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,18 @@ module Binding =
if res <> 0 then res else Term.compare t p

let hash {var; term} = Hashtbl.hash (Term.Var.hash var, Term.hash term)

let pp ppf {var; term} =
Format.fprintf ppf "{ var.idx = %d; term=%s }" var.Term.Var.index (Term.show term)
end

let varmap_of_bindings : Binding.t list -> Term.t Term.VarMap.t =
Stdlib.List.fold_left (fun (acc: _ Term.VarMap.t) Binding.{var;term} ->
assert (not (Term.VarMap.mem var acc));
Term.VarMap.add var term acc
)
Term.VarMap.empty

type t = Term.t Term.VarMap.t

let empty = Term.VarMap.empty
Expand Down Expand Up @@ -145,13 +155,19 @@ let extend ~scope env subst var term =

exception Unification_failed

let log fmt =
if false
then Format.kasprintf (Format.printf "%s\n%!") fmt
else Format.ifprintf Format.std_formatter fmt

let unify ?(subsume=false) ?(scope=Term.Var.non_local_scope) env subst x y =
(* The idea is to do the unification and collect the unification prefix during the process *)
let extend var term (prefix, subst) =
let subst = extend ~scope env subst var term in
(Binding.({var; term})::prefix, subst)
in
let rec helper x y acc =
log "unify '%s' and '%s'" (Term.show x) (Term.show y);
let open Term in
fold2 x y ~init:acc
~fvar:(fun ((_, subst) as acc) x y ->
Expand Down Expand Up @@ -183,6 +199,15 @@ let apply env subst x = Obj.magic @@
~fvar:(fun v -> Term.repr v)
~fval:(fun x -> Term.repr x)

let unify_map env subst map =
let vars, terms =
Term.VarMap.fold (fun v term acc -> (v :: fst acc, term :: snd acc)) map ([],[])
in
log "var = %s" (Term.show (Obj.magic (apply env subst vars)));
log "terms = %s" (Term.show (Obj.magic (apply env subst terms)));
unify env subst (Obj.magic vars) (Obj.magic terms)


let freevars env subst x =
Env.freevars env @@ apply env subst x

Expand Down
5 changes: 5 additions & 0 deletions src/core/Subst.mli
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ module Binding :
val equal : t -> t -> bool
val compare : t -> t -> int
val hash : t -> int
val pp: Format.formatter -> t -> unit
end

val varmap_of_bindings: Binding.t list -> Term.t Term.VarMap.t

type t

val empty : t
Expand Down Expand Up @@ -64,6 +67,8 @@ val freevars : Env.t -> t -> 'a -> Term.VarSet.t
*)
val unify : ?subsume:bool -> ?scope:Term.Var.scope -> Env.t -> t -> 'a -> 'a -> (Binding.t list * t) option

val unify_map: Env.t -> t -> Term.t Term.VarMap.t -> (Binding.t list * t) option

val merge_disjoint : Env.t -> t -> t -> t

(* [merge env s1 s2] merges two substituions *)
Expand Down
4 changes: 4 additions & 0 deletions src/core/Term.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ module VarMap =
match f (try Some (find k m) with Not_found -> None) with
| Some x -> add k x m
| None -> remove k m

let iteri f m =
let i = ref 0 in
iter (fun k v -> f !i k v; incr i) m
end

type t = Obj.t
Expand Down
3 changes: 3 additions & 0 deletions src/core/Term.mli
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ module VarMap :
include Map.S with type key = Var.t

val update : key -> ('a option -> 'a option) -> 'a t -> 'a t

val iteri: (int -> key -> 'a -> unit) -> 'a t -> unit

end

(* [t] type of untyped OCaml term *)
Expand Down

0 comments on commit 62c7b28

Please sign in to comment.