Skip to content

Commit

Permalink
Merge pull request #1301 from cryspen/prop-predicates
Browse files Browse the repository at this point in the history
`hax-lib`: introduce a `Prop` abstraction
  • Loading branch information
karthikbhargavan authored Feb 20, 2025
2 parents ce2ea4b + 908f653 commit b5321c1
Show file tree
Hide file tree
Showing 25 changed files with 833 additions and 406 deletions.
57 changes: 55 additions & 2 deletions engine/backends/fstar/fstar_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,15 @@ module FStarNamePolicy = struct

let anonymous_field_transform index = "_" ^ index

let reserved_words = Hash_set.of_list (module String) ["attributes";"noeq";"unopteq";"and";"assert";"assume";"begin";"by";"calc";"class";"default";"decreases";"effect";"eliminate";"else";"end";"ensures";"exception";"exists";"false";"friend";"forall";"fun";"λ";"function";"if";"in";"include";"inline";"inline_for_extraction";"instance";"introduce";"irreducible";"let";"logic";"match";"returns";"as";"module";"new";"new_effect";"layered_effect";"polymonadic_bind";"polymonadic_subcomp";"noextract";"of";"open";"opaque";"private";"quote";"range_of";"rec";"reifiable";"reify";"reflectable";"requires";"set_range_of";"sub_effect";"synth";"then";"total";"true";"try";"type";"unfold";"unfoldable";"val";"when";"with";"_";"__SOURCE_FILE__";"__LINE__";"match";"if";"let";"and";"string"]
let reserved_words = Hash_set.of_list (module String) ["attributes";"noeq";"unopteq";"and";"assert";"assume";"begin";"by";"calc";"class";"default";"decreases";"b2t";"effect";"eliminate";"else";"end";"ensures";"exception";"exists";"false";"friend";"forall";"fun";"λ";"function";"if";"in";"include";"inline";"inline_for_extraction";"instance";"introduce";"irreducible";"let";"logic";"match";"returns";"as";"module";"new";"new_effect";"layered_effect";"polymonadic_bind";"polymonadic_subcomp";"noextract";"of";"open";"opaque";"private";"quote";"range_of";"rec";"reifiable";"reify";"reflectable";"requires";"set_range_of";"sub_effect";"synth";"then";"total";"true";"try";"type";"unfold";"unfoldable";"val";"when";"with";"_";"__SOURCE_FILE__";"__LINE__";"match";"if";"let";"and";"string"]
end

module RenderId = Concrete_ident.MakeRenderAPI (FStarNamePolicy)
module U = Ast_utils.Make (InputLanguage)
module Visitors = Ast_visitors.Make (InputLanguage)
open AST
module F = Fstar_ast
module Destruct = Ast_destruct.Make (InputLanguage)

module Context = struct
type t = {
Expand Down Expand Up @@ -317,6 +318,12 @@ struct
(c Rust_primitives__hax__int__lt, (2, "<"));
(c Rust_primitives__hax__int__ne, (2, "<>"));
(c Rust_primitives__hax__int__eq, (2, "="));
(c Hax_lib__prop__constructors__and, (2, "/\\"));
(c Hax_lib__prop__constructors__or, (2, "\\/"));
(c Hax_lib__prop__constructors__not, (1, "~"));
(c Hax_lib__prop__constructors__eq, (2, "=="));
(c Hax_lib__prop__constructors__ne, (2, "=!="));
(c Hax_lib__prop__constructors__implies, (2, "==>"));
]
|> Map.of_alist_exn (module Global_ident)

Expand Down Expand Up @@ -511,6 +518,52 @@ struct
F.AST.unit_const F.dummyRange
| GlobalVar global_ident ->
F.term @@ F.AST.Var (pglobal_ident e.span @@ global_ident)
| App { f = { e = GlobalVar f; _ }; args = [ x ] }
when Global_ident.eq_name Hax_lib__prop__constructors__from_bool f ->
let x = pexpr x in
F.mk_e_app (F.term_of_lid [ "b2t" ]) [ x ]
| App
{
f = { e = GlobalVar f; _ };
args = [ { e = Closure { params = [ x ]; body = phi; _ }; _ } ];
}
when Global_ident.eq_name Hax_lib__prop__constructors__forall f ->
let phi = pexpr phi in
let binders =
let b = Destruct.pat_PBinding x |> Option.value_exn in
[
F.AST.
{
b = F.AST.Annotated (plocal_ident b.var, pty x.span b.typ);
brange = F.dummyRange;
blevel = Un;
aqual = None;
battributes = [];
};
]
in
F.term @@ F.AST.QForall (binders, ([], []), phi)
| App
{
f = { e = GlobalVar f; _ };
args = [ { e = Closure { params = [ x ]; body = phi; _ }; _ } ];
}
when Global_ident.eq_name Hax_lib__prop__constructors__exists f ->
let phi = pexpr phi in
let binders =
let b = Destruct.pat_PBinding x |> Option.value_exn in
[
F.AST.
{
b = F.AST.Annotated (plocal_ident b.var, pty x.span b.typ);
brange = F.dummyRange;
blevel = Un;
aqual = None;
battributes = [];
};
]
in
F.term @@ F.AST.QExists (binders, ([], []), phi)
| App
{
f = { e = GlobalVar (`Projector (`TupleField (n, len))) };
Expand All @@ -525,7 +578,7 @@ struct
let arity, op = Map.find_exn operators x in
if List.length args <> arity then
Error.assertion_failure e.span
"pexpr: bad arity for operator application";
("pexpr: bad arity for operator application (" ^ op ^ ")");
F.term @@ F.AST.Op (F.Ident.id_of_text op, List.map ~f:pexpr args)
| App
{
Expand Down
8 changes: 8 additions & 0 deletions engine/lib/ast_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,14 @@ module Make (F : Features.T) = struct
super#visit_expr' ascribe_app e

method! visit_expr (ascribe_app : bool) e =
let ascribe_app =
ascribe_app
&& not
(match e.typ with
| TApp { ident; _ } ->
Global_ident.eq_name Hax_lib__prop__Prop ident
| _ -> false)
in
let e = super#visit_expr ascribe_app e in
let ascribe (e : expr) =
if [%matches? Ascription _] e.e then e
Expand Down
208 changes: 169 additions & 39 deletions engine/lib/phases/phase_specialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,57 @@ module Make (F : Features.T) =
open struct
open Concrete_ident_generated

module FnReplace = struct
type t =
span:Span.t ->
typ:ty ->
f:expr ->
args:expr list ->
generic_args:generic_value list ->
bounds_impls:impl_expr list ->
trait:(impl_expr * generic_value list) option ->
expr

(** Retype a function application: this concretize the types, using concrete types from arguments. *)
let retype (fn : t) : t =
fun ~span ~typ ~f ~args ~generic_args ~bounds_impls ~trait ->
let f =
let typ =
if List.is_empty args then f.typ
else TArrow (List.map ~f:(fun e -> e.typ) args, typ)
in
{ f with typ }
in
fn ~span ~typ ~f ~args ~generic_args ~bounds_impls ~trait

(** Gets rid of trait and impl informations. *)
let remove_traits (fn : t) : t =
fun ~span ~typ ~f ~args ~generic_args:_ ~bounds_impls:_ ~trait:_ ->
fn ~span ~typ ~f ~args ~generic_args:[] ~bounds_impls:[] ~trait:None

(** Monomorphize a function call: this removes any impl references, and concretize types. *)
let monorphic (fn : t) : t = remove_traits (retype fn)

let name name : t =
fun ~span ~typ ~f ~args ~generic_args ~bounds_impls ~trait ->
let name = Ast.Global_ident.of_name ~value:true name in
let f = { f with e = GlobalVar name } in
let e = App { args; f; generic_args; bounds_impls; trait } in
{ typ; span; e }

let and_then (f1 : t) (f2 : expr -> expr) : t =
fun ~span ~typ ~f ~args ~generic_args ~bounds_impls ~trait ->
f1 ~span ~typ ~f ~args ~generic_args ~bounds_impls ~trait |> f2

let map_args (fn : int -> expr -> expr) : t -> t =
fun g ~span ~typ ~f ~args ~generic_args ~bounds_impls ~trait ->
let args = List.mapi ~f:fn args in
g ~span ~typ ~f ~args ~generic_args ~bounds_impls ~trait
end

type pattern = {
fn : t;
fn_replace : t;
fn_replace : FnReplace.t;
args : (expr -> bool) list;
ret : ty -> bool;
}
Expand All @@ -29,12 +77,16 @@ module Make (F : Features.T) =
work with `_ -> _ option` so that we can chain them *)

(** Constructs a predicate out of predicates and names *)
let mk (args : ('a, 'b) predicate list) (ret : ('c, 'd) predicate)
(fn : t) (fn_replace : t) : pattern =
let mk' (args : ('a, 'b) predicate list) (ret : ('c, 'd) predicate)
(fn : t) (fn_replace : FnReplace.t) : pattern =
let args = List.map ~f:(fun p x -> p x |> Option.is_some) args in
let ret t = ret t |> Option.is_some in
{ fn; fn_replace; args; ret }

let mk (args : ('a, 'b) predicate list) (ret : ('c, 'd) predicate)
(fn : t) (fn_replace : t) : pattern =
mk' args ret fn (FnReplace.name fn_replace |> FnReplace.monorphic)

open struct
let etyp (e : expr) : ty = e.typ
let tref = function TRef { typ; _ } -> Some typ | _ -> None
Expand All @@ -56,9 +108,21 @@ module Make (F : Features.T) =

let erase : 'a. ('a, unit) predicate = fun _ -> Some ()

let ( ||. ) (type a b) (f : (a, b) predicate) (g : (a, b) predicate) :
(a, b) predicate =
fun x ->
match (f x, g x) with Some a, _ | _, Some a -> Some a | _ -> None

let is_int : (ty, unit) predicate =
tapp0 >>& eq_global_ident Hax_lib__int__Int >>& erase

let is_prop : (ty, unit) predicate =
tapp0 >>& eq_global_ident Hax_lib__prop__Prop >>& erase

let is_bool : (ty, unit) predicate = function
| TBool -> Some ()
| _ -> None

let any _ = Some ()
let int_any = mk [ etyp >> is_int ] any
let int_int_any = mk [ etyp >> is_int; etyp >> is_int ] any
Expand All @@ -69,10 +133,24 @@ module Make (F : Features.T) =
mk [ etyp >> (tref >>& is_int); etyp >> (tref >>& is_int) ] any

let any_rint = mk [ any ] (tref >>& is_int)
let bool_prop = mk [ etyp >> is_bool ] is_prop
let prop_bool = mk [ etyp >> is_prop ] is_bool

let arrow : (ty, ty list) predicate = function
| TArrow (ts, t) -> Some (ts @ [ t ])
| _ -> None

let a_to_b a b : _ predicate =
arrow >> fun x ->
let* t, u =
match x with Some [ a; b ] -> Some (a, b) | _ -> None
in
let* a = a t in
let* b = b u in
Some (a, b)
end

(** The list of replacements *)
let patterns =
let int_replacements =
[
int_int_any Core__ops__arith__Add__add
Rust_primitives__hax__int__add;
Expand All @@ -94,11 +172,72 @@ module Make (F : Features.T) =
Rust_primitives__hax__int__le;
rint_rint_any Core__cmp__PartialEq__ne Rust_primitives__hax__int__ne;
rint_rint_any Core__cmp__PartialEq__eq Rust_primitives__hax__int__eq;
any_rint Hax_lib__int__Abstraction__lift
any_int Hax_lib__abstraction__Abstraction__lift
Rust_primitives__hax__int__from_machine;
int_any Hax_lib__int__Concretization__concretize
any_int Hax_lib__int__ToInt__to_int
Rust_primitives__hax__int__from_machine;
int_any Hax_lib__abstraction__Concretization__concretize
Rust_primitives__hax__int__into_machine;
]

let prop_replacements =
let name_from_bool = Hax_lib__prop__constructors__from_bool in
let prop_type =
let ident =
Ast.Global_ident.of_name ~value:false Hax_lib__prop__Prop
in
TApp { ident; args = [] }
in
let bool_prop__from_bool f = bool_prop f name_from_bool in
let poly n f g =
let args =
let prop_or_bool = is_bool ||. is_prop in
List.init n ~f:(fun _ ->
etyp
>> (prop_or_bool
||. (a_to_b prop_or_bool prop_or_bool >> erase)))
in
let promote_bool (e : A.expr) =
match e.typ with
| TBool -> U.call name_from_bool [ e ] e.span prop_type
| _ -> e
in
mk' args is_prop f
(FnReplace.map_args
(fun _ e ->
let e = promote_bool e in
match e.e with
| Closure { params; body; captures } ->
let body = promote_bool body in
{ e with e = Closure { params; body; captures } }
| _ -> e)
(FnReplace.name g |> FnReplace.monorphic))
in
[
bool_prop__from_bool Hax_lib__abstraction__Abstraction__lift;
bool_prop__from_bool Hax_lib__prop__ToProp__to_prop;
bool_prop__from_bool Core__convert__Into__into;
bool_prop__from_bool Core__convert__From__from;
(* Transform inherent methods on Prop *)
poly 2 Hax_lib__prop__Impl__and Hax_lib__prop__constructors__and;
poly 2 Hax_lib__prop__Impl__or Hax_lib__prop__constructors__or;
poly 1 Hax_lib__prop__Impl__not Hax_lib__prop__constructors__not;
poly 2 Hax_lib__prop__Impl__eq Hax_lib__prop__constructors__eq;
poly 2 Hax_lib__prop__Impl__ne Hax_lib__prop__constructors__ne;
poly 2 Hax_lib__prop__Impl__implies
Hax_lib__prop__constructors__implies;
(* Transform standalone functions in `prop` *)
poly 2 Hax_lib__prop__implies Hax_lib__prop__constructors__implies;
poly 1 Hax_lib__prop__forall Hax_lib__prop__constructors__forall;
poly 1 Hax_lib__prop__exists Hax_lib__prop__constructors__exists;
(* Transform core `&`, `|`, `!` on `Prop` *)
poly 2 Core__ops__bit__BitAnd__bitand
Hax_lib__prop__constructors__and;
poly 2 Core__ops__bit__BitOr__bitor Hax_lib__prop__constructors__or;
poly 1 Core__ops__bit__Not__not Hax_lib__prop__constructors__not;
]

let replacements = List.concat [ int_replacements; prop_replacements ]
end

module Error = Phase_utils.MakeError (struct
Expand All @@ -123,45 +262,31 @@ module Make (F : Features.T) =
} -> (
let l = List.map ~f:(self#visit_expr ()) l in
let matching =
List.filter patterns ~f:(fun { fn; args; _ } ->
List.filter
(List.mapi ~f:(fun i x -> (i, x)) replacements)
~f:(fun (_, { fn; args; ret; fn_replace = _ }) ->
Ast.Global_ident.eq_name fn f
&& ret e.typ
&&
match List.for_all2 args l ~f:apply with
| Ok r -> r
| _ -> false)
in
match matching with
| [ { fn_replace; _ } ] ->
let f = Ast.Global_ident.of_name ~value:true fn_replace in
let f = { f' with e = GlobalVar f } in
{
e with
e =
App
{
f;
args = l;
trait = None;
generic_args = [];
bounds_impls = [];
};
}
| [ (_, { fn_replace; _ }) ] ->
let e =
fn_replace ~args:l ~typ:e.typ ~span:e.span ~generic_args
~bounds_impls ~trait ~f:f'
in
self#visit_expr () e
| [] -> (
(* In this case we need to avoid recursing again through the arguments *)
let visited =
super#visit_expr ()
{
e with
e =
App
{
f = f';
args = [];
trait;
generic_args;
bounds_impls;
};
}
let args = [] in
let e' =
App { f = f'; args; trait; generic_args; bounds_impls }
in
super#visit_expr () { e with e = e' }
in
match visited.e with
| App { f; trait; generic_args; bounds_impls; _ } ->
Expand All @@ -172,9 +297,14 @@ module Make (F : Features.T) =
{ f; args = l; trait; generic_args; bounds_impls };
}
| _ -> super#visit_expr () e)
| _ ->
Error.assertion_failure e.span
"Found multiple matching patterns")
| r ->
let msg =
"Found multiple matching patterns: "
^ [%show: int list] (List.map ~f:fst r)
in
Stdio.prerr_endline msg;
U.Debug.expr e;
Error.assertion_failure e.span msg)
| _ -> super#visit_expr () e
end

Expand Down
Loading

0 comments on commit b5321c1

Please sign in to comment.