Skip to content

Commit

Permalink
refactor(BV): Use Z.t in BV solver
Browse files Browse the repository at this point in the history
The old implementation of the solver relied on the fact that each
constant fragment of a bit-vector was either all ones or all zeroes
because it defined a constant zero as the smallest bit-vector, a
constant one as the biggest bit-vector, and then computed the min and
max elements of a set of bit-vectors to check for consistency. With the
new solver based on Tarjan's union-find, this limitation is no longer
necessary, and it causes the solver to needlessly split bit-vector
variables involved in equalities with non-uniform constants [^1].

This patch is a simple refactoring of the bit-vector solver to use
integers rather than booleans to represent the constant parts of
bit-vectors, hopefully improving performance for bit-vectors with
non-uniform constant parts.

[^1]: For instance, solving `x = #b0000` can be done in a single swoop,
but solving `x = #b0101` currently first slices `x` into `a @ b @ c @ d`
and then assigns values to each of `a`, `b`, `c` and `d`.
  • Loading branch information
bclement-ocp committed Nov 16, 2023
1 parent e96d3d4 commit c9865f4
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 82 deletions.
163 changes: 81 additions & 82 deletions src/lib/reasoners/bitv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ let pp_sort ppf = function
Note that [tvar] and [defn] values must not be created directly: instead,
call the [s_var] and [s_cte] helpers. This is important due to the use of
physical equality in [union] (in particular, [union] assumes that there is a
single [tvar] for the boolean constants [true] and [false]). *)
physical equality in [union]. *)
type tvar = { mutable defn : defn }

and defn =
Expand All @@ -87,9 +86,9 @@ and defn =
(** The sort of this variable. See the type definition for {!sort_var} for
details. *)
}
| Dcte of bool
(** A [Dcte] is a variable that is forced to be equal to either all ones or
all zeroes. *)
| Dcte of Z.t * int
(** A [Dcte] is a variable that is forced to be equal to the specified integer
representation in bits. The second argument is the bit width. *)
| Dlink of tvar
(** A [Dlink] is a defined variable. All the [Dlink] should be followed to
arrive either at an unconstrained variable ([Droot]) or a constant
Expand All @@ -101,14 +100,15 @@ and defn =

let rec pp_defn ppf = function
| Droot { id; sorte; _ } -> Fmt.pf ppf "%a_%d" pp_sort sorte id
| Dcte b -> Fmt.pf ppf "%d" (if b then 1 else 0)
| Dcte (b, w) -> Fmt.pf ppf "%s" (Z.format ("%0" ^ string_of_int w ^ "b") b)
| Dlink tv -> pp_tvar ppf tv

and pp_tvar ppf { defn } = pp_defn ppf defn

let equal_tvar v1 v2 =
match v1.defn, v2.defn with
| (Droot _ | Dcte _), (Droot _ | Dcte _ ) -> v1 == v2
| Droot _, _ | _, Droot _ -> v1 == v2
| Dcte (n1, w1), Dcte (n2, w2) -> w1 = w2 && Z.equal n1 n2
| _ ->
(* [equal_tvar] should only be used before solving, i.e. before any unions
are made, and so there should be no [Dlink]. *)
Expand All @@ -124,17 +124,11 @@ let s_var =
and neg = { defn = Droot { id = -id; sorte; neg = x }} in
x

let s_cte =
(* Ensure that there is a single [tvar] for each boolean constant. *)
let v_true = { defn = Dcte true } in
let v_false = { defn = Dcte false } in
fun b -> if b then v_true else v_false
let s_cte n w = { defn = Dcte (Z.extract n 0 w, w) }

let negate_tvar = function
| { defn = Droot { neg; _ }; _ } -> neg
| { defn = Dcte b; _ } ->
(* Maintain the invariant that there is a single [tvar] for each constant *)
s_cte (not b)
| { defn = Dcte (n, w); _ } -> s_cte (Z.lognot n) w
| { defn = Dlink _; _ } -> assert false

(** Follow the defined variables [Dlink] and return the class representative as
Expand All @@ -159,22 +153,25 @@ let union v1 v2 =
| Dlink _, _ | _, Dlink _ ->
(* [find] invariant *)
assert false
| Dcte b1, Dcte b2 ->
(* Must be different because of [v1 == v2] check above *)
assert (not (Bool.equal b1 b2));
raise Util.Unsolvable
| Dcte (n1, w1), Dcte (n2, w2) ->
if w1 = w2 && Z.equal n1 n2 then (
(* We don't require physical equality of [Dcte] constructors, but we
still merge the corresponding nodes. *)
v1.defn <- Dlink v2;
) else
raise Util.Unsolvable
| Droot r1, Droot r2 ->
if r1.neg == v2 then raise Util.Unsolvable
else (
v1.defn <- Dlink v2;
r1.neg.defn <- Dlink r2.neg;
)
| Droot r1, Dcte b ->
| Droot r1, Dcte (n, w) ->
v1.defn <- Dlink v2;
r1.neg.defn <- Dlink (s_cte (not b))
| Dcte b, Droot r2 ->
r1.neg.defn <- Dlink (s_cte (Z.lognot n) w)
| Dcte (n, w), Droot r2 ->
v2.defn <- Dlink v1;
r2.neg.defn <- Dlink (s_cte (not b))
r2.neg.defn <- Dlink (s_cte (Z.lognot n) w)

type 'a alpha_term = {
bv : 'a;
Expand Down Expand Up @@ -214,23 +211,23 @@ let positive value = { value; negated = false }
let negative value = { value; negated = true }

type 'a simple_term_aux =
| Cte of bool
| Cte of Z.t
| Other of 'a signed
| Ext of 'a signed * int * int * int (*// id * size * i * j //*)

let equal_simple_term_aux eq l r =
match l, r with
| Cte b1, Cte b2 -> Bool.equal b1 b2
| Cte b1, Cte b2 -> Z.equal b1 b2
| Other o1, Other o2 -> equal_signed eq o1 o2
| Ext (o1, s1, i1, j1), Ext (o2, s2, i2, j2) ->
i1 = i2 && j1 = j2 && s1 = s2 && equal_signed eq o1 o2
| _, _ -> false

let compare_simple_term_aux cmp st1 st2 =
match st1, st2 with
| Cte b1, Cte b2 -> Bool.compare b1 b2
| Cte false , _ | _ , Cte true -> -1
| _ , Cte false | Cte true,_ -> 1
| Cte b1, Cte b2 -> Z.compare b1 b2
| Cte _, _ -> -1
| _, Cte _ -> 1

| Other t1 , Other t2 -> compare_signed cmp t1 t2
| _ , Other _ -> -1
Expand All @@ -243,13 +240,13 @@ let compare_simple_term_aux cmp st1 st2 =
if c2 <> 0 then c2 else compare_signed cmp t1 t2

let hash_simple_term_aux hash = function
| Cte b -> 11 * Hashtbl.hash b
| Cte b -> 11 * Z.hash b
| Other x -> 17 * hash_signed hash x
| Ext (x, a, b, c) ->
hash_signed hash x + 19 * (a + b + c)

let negate_simple_term_aux = function
| Cte b -> Cte (not b)
let negate_simple_term_aux sz = function
| Cte b -> Cte (Z.extract (Z.lognot b) 0 sz)
| Other o -> Other (negate_signed o)
| Ext (o, sz, i, j) -> Ext (negate_signed o, sz, i, j)

Expand All @@ -261,37 +258,20 @@ let compare_simple_term cmp = compare_alpha_term (compare_simple_term_aux cmp)

let hash_simple_term hash = hash_alpha_term (hash_simple_term_aux hash)

let negate_simple_term st = { st with bv = negate_simple_term_aux st.bv }
let negate_simple_term st = { st with bv = negate_simple_term_aux st.sz st.bv }

type 'a abstract = 'a simple_term list

let rec to_Z_opt_aux acc = function
| [] -> Some acc
| { bv = Cte false; sz } :: sts ->
to_Z_opt_aux Z.(acc lsl sz) sts
| { bv = Cte true; sz } :: sts ->
to_Z_opt_aux Z.((acc lsl sz) + (~$1 lsl sz) - ~$1) sts
| { bv = Cte n; sz } :: sts ->
to_Z_opt_aux Z.((acc lsl sz) + n) sts
| _ -> None

let to_Z_opt r = to_Z_opt_aux Z.zero r

let int2bv_const n z =
(* If [z] is out of the [0 .. 2^n] range (including if [z] is negative),
considering only the first [n] bits is equivalent to computing [z mod 2^n],
so we just do that and don't bother computing the modulus. *)
let acc = ref [] in
for i = 0 to n - 1 do
match Z.testbit z i, !acc with
| false, { bv = Cte false; sz } :: rst ->
acc := { bv = Cte false; sz = sz + 1 } :: rst
| false, rst ->
acc := { bv = Cte false; sz = 1 } :: rst
| true, { bv = Cte true; sz } :: rst ->
acc := { bv = Cte true; sz = sz + 1 } :: rst
| true, rst ->
acc := { bv = Cte true; sz = 1 } :: rst
done;
!acc
[ { bv = Cte (Z.extract z 0 n) ; sz = n } ]

let equal_abstract eq = Stdcompat.List.equal (equal_simple_term eq)

Expand Down Expand Up @@ -354,7 +334,8 @@ module Shostak(X : ALIEN) = struct

let view t =
match E.term_view t with
| { f = Bitv (_, s); ty = Tbitv size; _ } -> { descr = Vcte s; size }
| { f = Bitv (_, s); ty = Tbitv size; _ } ->
{ descr = Vcte s; size }
| { f = Op Concat; xs = [ t1; t2 ]; ty = Tbitv size; _ } ->
{ descr = Vconcat (t1, t2); size }
| { f = Op Extract (i, j); xs = [ t' ]; ty = Tbitv size; _ } ->
Expand Down Expand Up @@ -386,9 +367,9 @@ module Shostak(X : ALIEN) = struct
let bv = embed r in
if neg then negate_abstract bv, ctx else bv, ctx

let extract_st i j ({ bv; sz } as st) =
let extract_st i j { bv; sz } =
match bv with
| Cte _ -> [{ st with sz = j - i + 1 }]
| Cte b -> [{ bv = Cte (Z.extract b i (j - i + 1)); sz = j - i + 1 }]
| Other r -> [{ bv = Ext (r, sz, i, j) ; sz = j - i + 1 }]
| Ext (r, sz, k, _) ->
[{ bv = Ext (r, sz, i + k, j + k) ; sz = j - i + 1 }]
Expand Down Expand Up @@ -424,8 +405,9 @@ module Shostak(X : ALIEN) = struct
| [s] -> [normalize_st s]
| s :: (t :: ts as tts) ->
begin match s.bv, t.bv with
| Cte bs, Cte bt when Bool.equal bs bt ->
normalize ({ bv = Cte bs; sz = s.sz + t.sz } :: ts)
| Cte bs, Cte bt ->
normalize @@
{ bv = Cte Z.(bs lsl t.sz + bt); sz = s.sz + t.sz } :: ts
| Ext (d1, ds, i, j), Ext (d2, _, k, l)
when equal_signed X.equal d1 d2 && l = i - 1 ->
let d = { bv = Ext (d1, ds, k, j); sz = s.sz + t.sz } in
Expand Down Expand Up @@ -487,7 +469,10 @@ module Shostak(X : ALIEN) = struct
let print fmt ast =
let open Format in
match ast.bv with
| Cte b -> fprintf fmt "%d[%d]" (if b then 1 else 0) ast.sz
| Cte b ->
fprintf fmt "%s[%d]"
(Z.format ("%0" ^ string_of_int ast.sz ^ "b") b)
ast.sz
| Other t -> fprintf fmt "%a[%d]" (pp_signed X.print) t ast.sz
| Ext (t,sz,i,j) ->
fprintf fmt "%a[%d]" (pp_signed X.print) t sz;
Expand Down Expand Up @@ -570,7 +555,9 @@ module Shostak(X : ALIEN) = struct
- [y] is [[b(siz); ..; bn]] *)
let st_slice st siz =
let siz_bis = st.sz - siz in match st.bv with
|Cte _ -> {st with sz = siz},{st with sz = siz_bis}
|Cte b ->
{bv = Cte (Z.extract b siz_bis siz); sz = siz},
{bv = Cte (Z.extract b 0 siz_bis) ; sz = siz_bis}
|Other x ->
let s1 = Ext(x,st.sz, siz_bis, st.sz - 1) in
let s2 = Ext(x,st.sz, 0, siz_bis - 1) in
Expand Down Expand Up @@ -610,14 +597,14 @@ module Shostak(X : ALIEN) = struct
end
in f_rec [] (t,u)

(* Orient the equality [b = r] where [b] is a boolean constant and [r] is
(* Orient the equality [b = r] where [b] is a bitvector constant and [r] is
an uninterpreted ("Other") term, possibly negated. *)
let cte_vs_other bol r sz =
let bol = if r.negated then not bol else bol in
{ bv = r.value; sz } , [{bv = s_cte bol ; sz }]
let bol = if r.negated then Z.lognot bol else bol in
{ bv = r.value; sz } , [{bv = s_cte bol sz ; sz }]

(* Orient the equality [b = xt[s_xt]^{i,j}] where [b] is a boolean constant
and [xt] is uninterpreted of size [s_xt], possibly negated.
(* Orient the equality [b = xt[s_xt]^{i,j}] where [b] is a bitvector
constant and [xt] is uninterpreted of size [s_xt], possibly negated.
We introduce two A-variables [a1[i]] and [a2[s_xt-1-j]] and orient:
Expand All @@ -636,8 +623,9 @@ module Shostak(X : ALIEN) = struct
let cte_vs_ext bol xt s_xt i j =
let a1 = fresh_bitv A i in
let a2 = fresh_bitv A (s_xt - 1 - j) in
let bol = if xt.negated then not bol else bol in
let cte = [ {bv = s_cte bol ; sz =j - i + 1 } ] in
let b_sz = j - i + 1 in
let bol = if xt.negated then Z.lognot bol else bol in
let cte = [ {bv = s_cte bol b_sz ; sz = b_sz } ] in
let var = { bv = xt.value ; sz = s_xt }
in var, a2@cte@a1

Expand Down Expand Up @@ -820,7 +808,9 @@ module Shostak(X : ALIEN) = struct
*)
let sys_solve sys =
let c_solve (st1,st2) = match st1.bv,st2.bv with
|Cte _, Cte _ -> raise Util.Unsolvable (* forcement un 1 et un 0 *)
|Cte b1, Cte b2 ->
assert (not (Z.equal b1 b2));
raise Util.Unsolvable (* forcement distincts *)

|Cte b, Other r -> [cte_vs_other b r st2.sz]
|Other r, Cte b -> [cte_vs_other b r st1.sz]
Expand Down Expand Up @@ -879,21 +869,24 @@ module Shostak(X : ALIEN) = struct
let slice_var var pat_hd pat_tl =
let mk, tr =
match var.bv.defn with
| Dcte _ -> (fun sz -> { var with sz }), None
| Dcte (n, _) ->
(fun ofs sz ->
{ bv = s_cte (Z.extract n (ofs - sz) sz) sz ; sz }
), None
| Droot { sorte; _ } ->
(fun sz -> { bv = s_var sorte; sz }), Some sorte
(fun _ sz -> { bv = s_var sorte; sz }), Some sorte
| Dlink _ -> assert false
in
let rec aux cnt plist =
match plist with
| [] -> [], []
| h :: t when cnt < h -> [ mk cnt ], (h - cnt) :: t
| h :: t when cnt = h -> [ mk cnt ], t
| h :: t when cnt < h -> [ mk cnt cnt ], (h - cnt) :: t
| h :: t when cnt = h -> [ mk cnt cnt ], t
| h :: t ->
let vl, ptail = aux (cnt - h) t in
mk h :: vl, ptail
mk cnt h :: vl, ptail
in
let fst_v = mk pat_hd in
let fst_v = mk var.sz pat_hd in
let cnt = var.sz - pat_hd in
let vl, pat_tail = aux cnt pat_tl in
fst_v :: vl, pat_tail, tr
Expand Down Expand Up @@ -1176,7 +1169,7 @@ module Shostak(X : ALIEN) = struct
let get_rep var =
match (find var.bv).defn with
| Dlink _ -> assert false
| Dcte b -> Cte b
| Dcte (n, _) -> Cte n
| Droot { id; _ } ->
assert (id <> 0);
match Hashtbl.find vars (abs id) with
Expand All @@ -1198,8 +1191,8 @@ module Shostak(X : ALIEN) = struct
|a::b::r ->
begin
match a.bv,b.bv with
| Cte b1, Cte b2 when Bool.equal b1 b2 ->
cnf_max ({ b with sz = a.sz + b.sz }::r)
| Cte b1, Cte b2 ->
cnf_max ({ bv = Cte Z.(b1 lsl b.sz + b2) ; sz = a.sz + b.sz }::r)
| _, Cte _ -> a::(cnf_max (b::r))
| _ -> a::b::(cnf_max r)
end
Expand Down Expand Up @@ -1296,8 +1289,7 @@ module Shostak(X : ALIEN) = struct
expression. *)
let simple_term_to_nat acc st =
match st.bv with
| Cte false -> E.Ints.(acc * ~$$Z.(~$1 lsl st.sz))
| Cte true -> E.Ints.((acc + ~$1) * ~$$Z.(~$1 lsl st.sz) - ~$1)
| Cte n -> E.Ints.(acc * ~$$Z.(~$1 lsl st.sz) + ~$$n)
| Other r ->
let t = term_extract r.value in
let t = if r.negated then E.BV.bvnot t else t in
Expand Down Expand Up @@ -1451,11 +1443,17 @@ module Shostak(X : ALIEN) = struct
let solve r1 r2 pb =
Sig.{pb with sbt = List.rev_append (solve_bis r1 r2) pb.sbt}

(* Pop the first bit, raises [Not_found] if there is no first bit *)
(* Pop the first bit, raises [Not_found] if there is no first bit.
Note that the returned bv has an incorrect size. *)
let pop_bit = function
| [] -> raise Not_found
| { bv = Cte b; sz } as st :: rst ->
Some b, if sz > 1 then { st with sz = sz - 1 } :: rst else rst
| ({ bv = Cte n; sz } as st) :: rst ->
Some (Z.testbit n (sz - 1)),
if sz > 1 then
{ st with sz = sz - 1 } :: rst
else
rst
| { bv = Other _ | Ext _ ; sz } as st :: rst ->
None, if sz > 1 then { st with sz = sz - 1 } :: rst else rst

Expand Down Expand Up @@ -1552,7 +1550,8 @@ module Shostak(X : ALIEN) = struct
in
let s =
List.map (function
| { bv = Cte b; sz } -> String.make sz (if b then '1' else '0')
| { bv = Cte b; sz } ->
Z.format ("%0" ^ string_of_int sz ^ "b") b
| _ ->
(* Cannot happen because [a] must satisfy [is_cte_abstract] at this
point. *)
Expand Down
Loading

0 comments on commit c9865f4

Please sign in to comment.