diff --git a/.utils/expand.sh b/.utils/expand.sh new file mode 100755 index 000000000..6f0e5ea96 --- /dev/null +++ b/.utils/expand.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +# This script expands a crate so that one can inspect macro expansion +# by hax. It is a wrapper around `cargo expand` that inject the +# required rustc flags. + +RUSTFLAGS='-Zcrate-attr=register_tool(_hax) -Zcrate-attr=feature(register_tool) --cfg hax_compilation --cfg _hax --cfg hax --cfg hax_backend_fstar --cfg hax' cargo expand "$@" + diff --git a/engine/lib/attr_payloads.ml b/engine/lib/attr_payloads.ml index 61cbcd855..a13d8bd85 100644 --- a/engine/lib/attr_payloads.ml +++ b/engine/lib/attr_payloads.ml @@ -161,10 +161,6 @@ module Make (F : Features.T) (Error : Phase_utils.ERROR) = struct val expect_expr : ?keep_last_args:int -> generics * param list * expr -> expr - val associated_expr_rebinding : - span -> pat list -> AssocRole.t -> attrs -> expr option - (** Looks up an expression but takes care of rebinding free variables. *) - val associated_refinement_in_type : span -> string list -> attrs -> expr option (** For type, there is a special treatment. The name of fields are @@ -277,55 +273,6 @@ module Make (F : Features.T) (Error : Phase_utils.ERROR) = struct attrs -> expr list = associated_fns role >> List.map ~f:(expect_expr ~keep_last_args) - let associated_expr_rebinding span (params : pat list) (role : AssocRole.t) - (attrs : attrs) : expr option = - let* _, original_params, body = associated_fn role attrs in - let original_params = - List.map ~f:(fun param -> param.pat) original_params - in - let vars_of_pat = - U.Reducers.collect_local_idents#visit_pat () >> Set.to_list - in - let original_vars = List.concat_map ~f:vars_of_pat original_params in - let target_vars = List.concat_map ~f:vars_of_pat params in - let mk_error_message prefix = - prefix ^ "\n" ^ "\n - original_vars: " - ^ [%show: local_ident list] original_vars - ^ "\n - target_vars: " - ^ [%show: local_ident list] target_vars - ^ "\n\n - original_params: " - ^ [%show: pat list] original_params - ^ "\n - params: " - ^ [%show: pat list] params - in - let replacements = - List.zip_opt original_vars target_vars - |> Option.value_or_thunk ~default:(fun _ -> - let details = - mk_error_message - "associated_expr_rebinding: zip two lists of different \ - lengths (original_vars and target_vars)" - in - Error.unimplemented ~details span) - in - let replacements = - match Map.of_alist (module Local_ident) replacements with - | `Ok replacements -> replacements - | `Duplicate_key key -> - let details = - mk_error_message - "associated_expr_rebinding: of_alist failed because `" - ^ [%show: local_ident] key - ^ "` is a duplicate key. Context: " - in - Error.unimplemented ~details span - in - Some - ((U.Mappers.rename_local_idents (fun v -> - Map.find replacements v |> Option.value ~default:v)) - #visit_expr - () body) - let associated_refinement_in_type span (free_variables : string list) : attrs -> expr option = associated_fn Refine diff --git a/engine/lib/phases/phase_traits_specs.ml b/engine/lib/phases/phase_traits_specs.ml index 05de6178e..2d2aa087f 100644 --- a/engine/lib/phases/phase_traits_specs.ml +++ b/engine/lib/phases/phase_traits_specs.ml @@ -100,44 +100,43 @@ module Make (F : Features.T) = match item.ii_v with | IIFn { params = []; _ } -> [] | IIFn { body; params } -> - let out_ident = - U.fresh_local_ident_in - (U.Reducers.collect_local_idents#visit_impl_item () - item - |> Set.to_list) - "out" - in - let params_pat = - List.map ~f:(fun param -> param.pat) params - in - let pat = U.make_var_pat out_ident body.typ body.span in - let typ = body.typ in - let out = { pat; typ; typ_span = None; attrs = [] } in + (* We always need to produce a pre and a post + condition implementation for each method in + the impl. *) [ - { - (mk "pre") with - ii_v = - IIFn - { - body = - Attrs.associated_expr_rebinding item.ii_span - params_pat Requires item.ii_attrs - |> Option.value ~default; - params; - }; - }; - { - (mk "post") with - ii_v = - IIFn - { - body = - Attrs.associated_expr_rebinding item.ii_span - (params_pat @ [ pat ]) Ensures item.ii_attrs - |> Option.value ~default; - params = params @ [ out ]; - }; - }; + (let params, body = + match Attrs.associated_fn Requires item.ii_attrs with + | Some (_, params, body) -> (params, body) + | None -> (params, default) + in + { (mk "pre") with ii_v = IIFn { body; params } }); + (let params, body = + match Attrs.associated_fn Ensures item.ii_attrs with + | Some (_, params, body) -> (params, body) + | None -> + (* There is no explicit post-condition + on this method. We need to define a + trivial one. *) + (* Post-condition *always* an extra + argument in final position for the + output. *) + let out_ident = + U.fresh_local_ident_in + (U.Reducers.collect_local_idents + #visit_impl_item () item + |> Set.to_list) + "out" + in + let pat = + U.make_var_pat out_ident body.typ body.span + in + let typ = body.typ in + let out = + { pat; typ; typ_span = None; attrs = [] } + in + (params @ [ out ], default) + in + { (mk "post") with ii_v = IIFn { body; params } }); ] | IIType _ -> [] in diff --git a/flake.nix b/flake.nix index 19791054d..f51767ff4 100644 --- a/flake.nix +++ b/flake.nix @@ -159,6 +159,7 @@ mkdir -p $out/bin cp ${./.utils/rebuild.sh} $out/bin/rebuild cp ${./.utils/list-names.sh} $out/bin/list-names + cp ${./.utils/expand.sh} $out/bin/expand-hax-macros ''; }; packages = [ diff --git a/test-harness/src/snapshots/toolchain__attributes into-fstar.snap b/test-harness/src/snapshots/toolchain__attributes into-fstar.snap index 6c30e8a40..d81d020ab 100644 --- a/test-harness/src/snapshots/toolchain__attributes into-fstar.snap +++ b/test-harness/src/snapshots/toolchain__attributes into-fstar.snap @@ -186,11 +186,11 @@ let impl: t_Operation t_ViaAdd = (Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) <= (127 <: Hax_lib.Int.t_Int)); f_double_post = - (fun (x: u8) (out: u8) -> + (fun (x: u8) (result: u8) -> ((Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) * (2 <: Hax_lib.Int.t_Int) <: Hax_lib.Int.t_Int) = - (Rust_primitives.Hax.Int.from_machine out <: Hax_lib.Int.t_Int)); + (Rust_primitives.Hax.Int.from_machine result <: Hax_lib.Int.t_Int)); f_double = fun (x: u8) -> x +! x } @@ -205,11 +205,11 @@ let impl_1: t_Operation t_ViaMul = (Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) <= (127 <: Hax_lib.Int.t_Int)); f_double_post = - (fun (x: u8) (out: u8) -> + (fun (x: u8) (result: u8) -> ((Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) * (2 <: Hax_lib.Int.t_Int) <: Hax_lib.Int.t_Int) = - (Rust_primitives.Hax.Int.from_machine out <: Hax_lib.Int.t_Int)); + (Rust_primitives.Hax.Int.from_machine result <: Hax_lib.Int.t_Int)); f_double = fun (x: u8) -> x *! 2uy } ''' @@ -225,7 +225,7 @@ type t_Foo = | Foo : u8 -> t_Foo let impl: Core.Ops.Arith.t_Add t_Foo t_Foo = { f_Output = t_Foo; - f_add_pre = (fun (self: t_Foo) (rhs: t_Foo) -> self._0 <. (255uy -! rhs._0 <: u8)); + f_add_pre = (fun (self___: t_Foo) (rhs: t_Foo) -> self___._0 <. (255uy -! rhs._0 <: u8)); f_add_post = (fun (self: t_Foo) (rhs: t_Foo) (out: t_Foo) -> true); f_add = fun (self: t_Foo) (rhs: t_Foo) -> Foo (self._0 +! rhs._0) <: t_Foo } @@ -236,7 +236,7 @@ let impl_1: Core.Ops.Arith.t_Mul t_Foo t_Foo = f_Output = t_Foo; f_mul_pre = - (fun (self: t_Foo) (rhs: t_Foo) -> rhs._0 =. 0uy || self._0 <. (255uy /! rhs._0 <: u8)); + (fun (self___: t_Foo) (rhs: t_Foo) -> rhs._0 =. 0uy || self___._0 <. (255uy /! rhs._0 <: u8)); f_mul_post = (fun (self: t_Foo) (rhs: t_Foo) (out: t_Foo) -> true); f_mul = fun (self: t_Foo) (rhs: t_Foo) -> Foo (self._0 *! rhs._0) <: t_Foo } @@ -277,7 +277,7 @@ let mutation_example let impl: Core.Ops.Index.t_Index t_MyArray usize = { f_Output = u8; - f_index_pre = (fun (self: t_MyArray) (index: usize) -> index <. v_MAX); + f_index_pre = (fun (self___: t_MyArray) (index: usize) -> index <. v_MAX); f_index_post = (fun (self: t_MyArray) (index: usize) (out: u8) -> true); f_index = fun (self: t_MyArray) (index: usize) -> self.[ index ] } @@ -339,6 +339,78 @@ let double (x: u8) : Prims.Pure t_Even (requires x <. 127uy) (fun _ -> Prims.l_T let double_refine (x: u8) : Prims.Pure t_Even (requires x <. 127uy) (fun _ -> Prims.l_True) = x +! x <: t_Even ''' +"Attributes.Requires_mut.fst" = ''' +module Attributes.Requires_mut +#set-options "--fuel 0 --ifuel 1 --z3rlimit 15" +open Core +open FStar.Mul + +class t_Foo (v_Self: Type0) = { + f_f_pre:x: u8 -> y: u8 + -> pred: + Type0 + { ((Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) + + (Rust_primitives.Hax.Int.from_machine y <: Hax_lib.Int.t_Int) + <: + Hax_lib.Int.t_Int) < + (254 <: Hax_lib.Int.t_Int) ==> + pred }; + f_f_post:x: u8 -> y: u8 -> x1: (u8 & u8) + -> pred: + Type0 + { pred ==> + (let y_future, output_variable:(u8 & u8) = x1 in + output_variable =. y_future) }; + f_f:x0: u8 -> x1: u8 -> Prims.Pure (u8 & u8) (f_f_pre x0 x1) (fun result -> f_f_post x0 x1 result); + f_g_pre:u8 -> u8 -> Type0; + f_g_post:u8 -> u8 -> u8 -> Type0; + f_g:x0: u8 -> x1: u8 -> Prims.Pure u8 (f_g_pre x0 x1) (fun result -> f_g_post x0 x1 result); + f_h_pre:u8 -> u8 -> Type0; + f_h_post:u8 -> u8 -> Prims.unit -> Type0; + f_h:x0: u8 -> x1: u8 + -> Prims.Pure Prims.unit (f_h_pre x0 x1) (fun result -> f_h_post x0 x1 result); + f_i_pre:u8 -> u8 -> Type0; + f_i_post:u8 -> u8 -> u8 -> Type0; + f_i:x0: u8 -> x1: u8 -> Prims.Pure u8 (f_i_pre x0 x1) (fun result -> f_i_post x0 x1 result) +} + +[@@ FStar.Tactics.Typeclasses.tcinstance] +let impl: t_Foo Prims.unit = + { + f_f_pre + = + (fun (x: u8) (y: u8) -> + ((Rust_primitives.Hax.Int.from_machine x <: Hax_lib.Int.t_Int) + + (Rust_primitives.Hax.Int.from_machine y <: Hax_lib.Int.t_Int) + <: + Hax_lib.Int.t_Int) < + (254 <: Hax_lib.Int.t_Int)); + f_f_post + = + (fun (x: u8) (y: u8) (y_future, output_variable: (u8 & u8)) -> output_variable =. y_future); + f_f + = + (fun (x: u8) (y: u8) -> + let y:u8 = y +! x in + let hax_temp_output:u8 = y in + y, hax_temp_output <: (u8 & u8)); + f_g_pre = (fun (x: u8) (y: u8) -> true); + f_g_post = (fun (x: u8) (y: u8) (output_variable: u8) -> output_variable =. y); + f_g = (fun (x: u8) (y: u8) -> y); + f_h_pre = (fun (x: u8) (y: u8) -> true); + f_h_post + = + (fun (x: u8) (y: u8) (output_variable: Prims.unit) -> output_variable =. (() <: Prims.unit)); + f_h = (fun (x: u8) (y: u8) -> () <: Prims.unit); + f_i_pre = (fun (x: u8) (y: u8) -> true); + f_i_post = (fun (x: u8) (y: u8) (y_future: u8) -> y_future =. y); + f_i + = + fun (x: u8) (y: u8) -> + let hax_temp_output:Prims.unit = () <: Prims.unit in + y + } +''' "Attributes.Verifcation_status.fst" = ''' module Attributes.Verifcation_status #set-options "--fuel 0 --ifuel 1 --z3rlimit 15" diff --git a/tests/attributes/src/lib.rs b/tests/attributes/src/lib.rs index 181057d2f..c0deae940 100644 --- a/tests/attributes/src/lib.rs +++ b/tests/attributes/src/lib.rs @@ -338,3 +338,46 @@ mod verifcation_status { let still_not_much = not_much + nothing; } } + +mod requires_mut { + use hax_lib::int::*; + + #[hax_lib::attributes] + trait Foo { + #[hax_lib::requires(x.lift() + y.lift() < int!(254))] + #[hax_lib::ensures(|output_variable| output_variable == *future(y))] + fn f(x: u8, y: &mut u8) -> u8; + + fn g(x: u8, y: u8) -> u8; + fn h(x: u8, y: u8); + fn i(x: u8, y: &mut u8); + } + + #[hax_lib::attributes] + impl Foo for () { + #[hax_lib::requires(x.lift() + y.lift() < int!(254))] + #[hax_lib::ensures(|output_variable| output_variable == *future(y))] + fn f(x: u8, y: &mut u8) -> u8 { + *y += x; + *y + } + + #[hax_lib::requires(true)] + #[hax_lib::ensures(|output_variable| output_variable == y)] + fn g(x: u8, y: u8) -> u8 { + y + } + + #[hax_lib::requires(true)] + #[hax_lib::ensures(|output_variable| output_variable == ())] + fn h(x: u8, y: u8) { + () + } + + #[hax_lib::requires(true)] + #[hax_lib::ensures(|out| *future(y) == *y)] + fn i(x: u8, y: &mut u8) { + () + } + } +}