Skip to content

Commit

Permalink
Merge pull request cryspen#872 from hacspec/pre_post_impl_blocks
Browse files Browse the repository at this point in the history
Pre post impl blocks
  • Loading branch information
W95Psp authored Sep 3, 2024
2 parents c051292 + 8923a70 commit fd62ddd
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 97 deletions.
8 changes: 8 additions & 0 deletions .utils/expand.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env bash

# This script expands a crate so that one can inspect macro expansion
# by hax. It is a wrapper around `cargo expand` that inject the
# required rustc flags.

RUSTFLAGS='-Zcrate-attr=register_tool(_hax) -Zcrate-attr=feature(register_tool) --cfg hax_compilation --cfg _hax --cfg hax --cfg hax_backend_fstar --cfg hax' cargo expand "$@"

53 changes: 0 additions & 53 deletions engine/lib/attr_payloads.ml
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,6 @@ module Make (F : Features.T) (Error : Phase_utils.ERROR) = struct
val expect_expr :
?keep_last_args:int -> generics * param list * expr -> expr

val associated_expr_rebinding :
span -> pat list -> AssocRole.t -> attrs -> expr option
(** Looks up an expression but takes care of rebinding free variables. *)

val associated_refinement_in_type :
span -> string list -> attrs -> expr option
(** For type, there is a special treatment. The name of fields are
Expand Down Expand Up @@ -277,55 +273,6 @@ module Make (F : Features.T) (Error : Phase_utils.ERROR) = struct
attrs -> expr list =
associated_fns role >> List.map ~f:(expect_expr ~keep_last_args)

let associated_expr_rebinding span (params : pat list) (role : AssocRole.t)
(attrs : attrs) : expr option =
let* _, original_params, body = associated_fn role attrs in
let original_params =
List.map ~f:(fun param -> param.pat) original_params
in
let vars_of_pat =
U.Reducers.collect_local_idents#visit_pat () >> Set.to_list
in
let original_vars = List.concat_map ~f:vars_of_pat original_params in
let target_vars = List.concat_map ~f:vars_of_pat params in
let mk_error_message prefix =
prefix ^ "\n" ^ "\n - original_vars: "
^ [%show: local_ident list] original_vars
^ "\n - target_vars: "
^ [%show: local_ident list] target_vars
^ "\n\n - original_params: "
^ [%show: pat list] original_params
^ "\n - params: "
^ [%show: pat list] params
in
let replacements =
List.zip_opt original_vars target_vars
|> Option.value_or_thunk ~default:(fun _ ->
let details =
mk_error_message
"associated_expr_rebinding: zip two lists of different \
lengths (original_vars and target_vars)"
in
Error.unimplemented ~details span)
in
let replacements =
match Map.of_alist (module Local_ident) replacements with
| `Ok replacements -> replacements
| `Duplicate_key key ->
let details =
mk_error_message
"associated_expr_rebinding: of_alist failed because `"
^ [%show: local_ident] key
^ "` is a duplicate key. Context: "
in
Error.unimplemented ~details span
in
Some
((U.Mappers.rename_local_idents (fun v ->
Map.find replacements v |> Option.value ~default:v))
#visit_expr
() body)

let associated_refinement_in_type span (free_variables : string list) :
attrs -> expr option =
associated_fn Refine
Expand Down
73 changes: 36 additions & 37 deletions engine/lib/phases/phase_traits_specs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -100,44 +100,43 @@ module Make (F : Features.T) =
match item.ii_v with
| IIFn { params = []; _ } -> []
| IIFn { body; params } ->
let out_ident =
U.fresh_local_ident_in
(U.Reducers.collect_local_idents#visit_impl_item ()
item
|> Set.to_list)
"out"
in
let params_pat =
List.map ~f:(fun param -> param.pat) params
in
let pat = U.make_var_pat out_ident body.typ body.span in
let typ = body.typ in
let out = { pat; typ; typ_span = None; attrs = [] } in
(* We always need to produce a pre and a post
condition implementation for each method in
the impl. *)
[
{
(mk "pre") with
ii_v =
IIFn
{
body =
Attrs.associated_expr_rebinding item.ii_span
params_pat Requires item.ii_attrs
|> Option.value ~default;
params;
};
};
{
(mk "post") with
ii_v =
IIFn
{
body =
Attrs.associated_expr_rebinding item.ii_span
(params_pat @ [ pat ]) Ensures item.ii_attrs
|> Option.value ~default;
params = params @ [ out ];
};
};
(let params, body =
match Attrs.associated_fn Requires item.ii_attrs with
| Some (_, params, body) -> (params, body)
| None -> (params, default)
in
{ (mk "pre") with ii_v = IIFn { body; params } });
(let params, body =
match Attrs.associated_fn Ensures item.ii_attrs with
| Some (_, params, body) -> (params, body)
| None ->
(* There is no explicit post-condition
on this method. We need to define a
trivial one. *)
(* Post-condition *always* an extra
argument in final position for the
output. *)
let out_ident =
U.fresh_local_ident_in
(U.Reducers.collect_local_idents
#visit_impl_item () item
|> Set.to_list)
"out"
in
let pat =
U.make_var_pat out_ident body.typ body.span
in
let typ = body.typ in
let out =
{ pat; typ; typ_span = None; attrs = [] }
in
(params @ [ out ], default)
in
{ (mk "post") with ii_v = IIFn { body; params } });
]
| IIType _ -> []
in
Expand Down
1 change: 1 addition & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
mkdir -p $out/bin
cp ${./.utils/rebuild.sh} $out/bin/rebuild
cp ${./.utils/list-names.sh} $out/bin/list-names
cp ${./.utils/expand.sh} $out/bin/expand-hax-macros
'';
};
packages = [
Expand Down
86 changes: 79 additions & 7 deletions test-harness/src/snapshots/toolchain__attributes into-fstar.snap
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ let impl: t_Operation t_ViaAdd =
(Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) <= (127 <: Hax_lib.Int.t_Int));
f_double_post
=
(fun (x: u8) (out: u8) ->
(fun (x: u8) (result: u8) ->
((Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) * (2 <: Hax_lib.Int.t_Int)
<:
Hax_lib.Int.t_Int) =
(Rust_primitives.Hax.Int.from_machine out <: Hax_lib.Int.t_Int));
(Rust_primitives.Hax.Int.from_machine result <: Hax_lib.Int.t_Int));
f_double = fun (x: u8) -> x +! x
}

Expand All @@ -205,11 +205,11 @@ let impl_1: t_Operation t_ViaMul =
(Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) <= (127 <: Hax_lib.Int.t_Int));
f_double_post
=
(fun (x: u8) (out: u8) ->
(fun (x: u8) (result: u8) ->
((Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) * (2 <: Hax_lib.Int.t_Int)
<:
Hax_lib.Int.t_Int) =
(Rust_primitives.Hax.Int.from_machine out <: Hax_lib.Int.t_Int));
(Rust_primitives.Hax.Int.from_machine result <: Hax_lib.Int.t_Int));
f_double = fun (x: u8) -> x *! 2uy
}
'''
Expand All @@ -225,7 +225,7 @@ type t_Foo = | Foo : u8 -> t_Foo
let impl: Core.Ops.Arith.t_Add t_Foo t_Foo =
{
f_Output = t_Foo;
f_add_pre = (fun (self: t_Foo) (rhs: t_Foo) -> self._0 <. (255uy -! rhs._0 <: u8));
f_add_pre = (fun (self___: t_Foo) (rhs: t_Foo) -> self___._0 <. (255uy -! rhs._0 <: u8));
f_add_post = (fun (self: t_Foo) (rhs: t_Foo) (out: t_Foo) -> true);
f_add = fun (self: t_Foo) (rhs: t_Foo) -> Foo (self._0 +! rhs._0) <: t_Foo
}
Expand All @@ -236,7 +236,7 @@ let impl_1: Core.Ops.Arith.t_Mul t_Foo t_Foo =
f_Output = t_Foo;
f_mul_pre
=
(fun (self: t_Foo) (rhs: t_Foo) -> rhs._0 =. 0uy || self._0 <. (255uy /! rhs._0 <: u8));
(fun (self___: t_Foo) (rhs: t_Foo) -> rhs._0 =. 0uy || self___._0 <. (255uy /! rhs._0 <: u8));
f_mul_post = (fun (self: t_Foo) (rhs: t_Foo) (out: t_Foo) -> true);
f_mul = fun (self: t_Foo) (rhs: t_Foo) -> Foo (self._0 *! rhs._0) <: t_Foo
}
Expand Down Expand Up @@ -277,7 +277,7 @@ let mutation_example
let impl: Core.Ops.Index.t_Index t_MyArray usize =
{
f_Output = u8;
f_index_pre = (fun (self: t_MyArray) (index: usize) -> index <. v_MAX);
f_index_pre = (fun (self___: t_MyArray) (index: usize) -> index <. v_MAX);
f_index_post = (fun (self: t_MyArray) (index: usize) (out: u8) -> true);
f_index = fun (self: t_MyArray) (index: usize) -> self.[ index ]
}
Expand Down Expand Up @@ -339,6 +339,78 @@ let double (x: u8) : Prims.Pure t_Even (requires x <. 127uy) (fun _ -> Prims.l_T
let double_refine (x: u8) : Prims.Pure t_Even (requires x <. 127uy) (fun _ -> Prims.l_True) =
x +! x <: t_Even
'''
"Attributes.Requires_mut.fst" = '''
module Attributes.Requires_mut
#set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
open Core
open FStar.Mul

class t_Foo (v_Self: Type0) = {
f_f_pre:x: u8 -> y: u8
-> pred:
Type0
{ ((Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) +
(Rust_primitives.Hax.Int.from_machine y <: Hax_lib.Int.t_Int)
<:
Hax_lib.Int.t_Int) <
(254 <: Hax_lib.Int.t_Int) ==>
pred };
f_f_post:x: u8 -> y: u8 -> x1: (u8 & u8)
-> pred:
Type0
{ pred ==>
(let y_future, output_variable:(u8 & u8) = x1 in
output_variable =. y_future) };
f_f:x0: u8 -> x1: u8 -> Prims.Pure (u8 & u8) (f_f_pre x0 x1) (fun result -> f_f_post x0 x1 result);
f_g_pre:u8 -> u8 -> Type0;
f_g_post:u8 -> u8 -> u8 -> Type0;
f_g:x0: u8 -> x1: u8 -> Prims.Pure u8 (f_g_pre x0 x1) (fun result -> f_g_post x0 x1 result);
f_h_pre:u8 -> u8 -> Type0;
f_h_post:u8 -> u8 -> Prims.unit -> Type0;
f_h:x0: u8 -> x1: u8
-> Prims.Pure Prims.unit (f_h_pre x0 x1) (fun result -> f_h_post x0 x1 result);
f_i_pre:u8 -> u8 -> Type0;
f_i_post:u8 -> u8 -> u8 -> Type0;
f_i:x0: u8 -> x1: u8 -> Prims.Pure u8 (f_i_pre x0 x1) (fun result -> f_i_post x0 x1 result)
}

[@@ FStar.Tactics.Typeclasses.tcinstance]
let impl: t_Foo Prims.unit =
{
f_f_pre
=
(fun (x: u8) (y: u8) ->
((Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) +
(Rust_primitives.Hax.Int.from_machine y <: Hax_lib.Int.t_Int)
<:
Hax_lib.Int.t_Int) <
(254 <: Hax_lib.Int.t_Int));
f_f_post
=
(fun (x: u8) (y: u8) (y_future, output_variable: (u8 & u8)) -> output_variable =. y_future);
f_f
=
(fun (x: u8) (y: u8) ->
let y:u8 = y +! x in
let hax_temp_output:u8 = y in
y, hax_temp_output <: (u8 & u8));
f_g_pre = (fun (x: u8) (y: u8) -> true);
f_g_post = (fun (x: u8) (y: u8) (output_variable: u8) -> output_variable =. y);
f_g = (fun (x: u8) (y: u8) -> y);
f_h_pre = (fun (x: u8) (y: u8) -> true);
f_h_post
=
(fun (x: u8) (y: u8) (output_variable: Prims.unit) -> output_variable =. (() <: Prims.unit));
f_h = (fun (x: u8) (y: u8) -> () <: Prims.unit);
f_i_pre = (fun (x: u8) (y: u8) -> true);
f_i_post = (fun (x: u8) (y: u8) (y_future: u8) -> y_future =. y);
f_i
=
fun (x: u8) (y: u8) ->
let hax_temp_output:Prims.unit = () <: Prims.unit in
y
}
'''
"Attributes.Verifcation_status.fst" = '''
module Attributes.Verifcation_status
#set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
Expand Down
43 changes: 43 additions & 0 deletions tests/attributes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,46 @@ mod verifcation_status {
let still_not_much = not_much + nothing;
}
}

mod requires_mut {
use hax_lib::int::*;

#[hax_lib::attributes]
trait Foo {
#[hax_lib::requires(x.lift() + y.lift() < int!(254))]
#[hax_lib::ensures(|output_variable| output_variable == *future(y))]
fn f(x: u8, y: &mut u8) -> u8;

fn g(x: u8, y: u8) -> u8;
fn h(x: u8, y: u8);
fn i(x: u8, y: &mut u8);
}

#[hax_lib::attributes]
impl Foo for () {
#[hax_lib::requires(x.lift() + y.lift() < int!(254))]
#[hax_lib::ensures(|output_variable| output_variable == *future(y))]
fn f(x: u8, y: &mut u8) -> u8 {
*y += x;
*y
}

#[hax_lib::requires(true)]
#[hax_lib::ensures(|output_variable| output_variable == y)]
fn g(x: u8, y: u8) -> u8 {
y
}

#[hax_lib::requires(true)]
#[hax_lib::ensures(|output_variable| output_variable == ())]
fn h(x: u8, y: u8) {
()
}

#[hax_lib::requires(true)]
#[hax_lib::ensures(|out| *future(y) == *y)]
fn i(x: u8, y: &mut u8) {
()
}
}
}

0 comments on commit fd62ddd

Please sign in to comment.