From e335db0f5ddad94c38274efb60259d75e1d8babf Mon Sep 17 00:00:00 2001 From: Cameron Low Date: Tue, 14 Jan 2025 12:28:46 +0000 Subject: [PATCH] Module Tweaks: Allow for fine grain editing of existing modules. This commit introduces a new mechanism that permits the user to create a new module by slightly tweaking an existing module definition. It has the following operations: - Introduce new module variables. - Introduce new local variables. - Delete/Modify/Add statements at particular code positions - Delete branches (match support is not currently working fully) - Modify branch conditions - Insert new branches around a chunk of code - Modify the return expression Syntax: ``` module N = M with { var x : t (* add new module variable *) proc f [ var y : s (* add new local variable *) cp +/-/~ { s } (* insert after/insert before/modify a statement *) cp - (* delete a statement *) cp + ( e ) (* insert new if statement with condition `e` surrounding the suffix code block *) cp - ./?/#cstr (* delete all other branches except true/false/cstr *) ] res ~ ( e ) (* change the return expression *) } ``` --- examples/br93.ec | 63 +++---- src/ecLowPhlGoal.ml | 2 +- src/ecMatching.ml | 48 +++-- src/ecMatching.mli | 9 +- src/ecParser.mly | 27 +++ src/ecParsetree.ml | 64 ++++--- src/ecTyping.ml | 284 +++++++++++++++++++++++++++++- src/ecTyping.mli | 9 + src/ecUserMessages.ml | 18 ++ tests/fine-grained-module-defs.ec | 50 ++++++ 10 files changed, 479 insertions(+), 95 deletions(-) create mode 100644 tests/fine-grained-module-defs.ec diff --git a/examples/br93.ec b/examples/br93.ec index 91057cd134..8a7c0fb851 100644 --- a/examples/br93.ec +++ b/examples/br93.ec @@ -83,7 +83,7 @@ import H.Lazy. (* BR93 is a module that, given access to an oracle H from type *) (* `from` to type `rand` (see `print Oracle.`), implements procedures *) (* `keygen`, `enc` and `dec` as follows described below. *) -module BR93 (H:Oracle) = { +module BR93 (H : Oracle) = { (* `keygen` simply samples a key pair in `dkeys` *) proc keygen() = { var kp; @@ -183,14 +183,14 @@ qed. (* But we can't do it (yet) for IND-CPA because of the random oracle *) (* Instead, we define CPA for BR93 with that particular RO. *) -module type Adv (ARO: POracle) = { +module type Adv (ARO : POracle) = { proc a1(p:pkey): (ptxt * ptxt) proc a2(c:ctxt): bool }. (* We need to log the random oracle queries made to the adversary *) (* in order to express the final theorem. *) -module Log (H:Oracle) = { +module Log (H : Oracle) = { var qs: rand list proc init() = { @@ -251,23 +251,17 @@ declare axiom A_a1_ll (O <: POracle {-A}): islossless O.o => islossless A(O).a1. declare axiom A_a2_ll (O <: POracle {-A}): islossless O.o => islossless A(O).a2. (* Step 1: replace RO call with random sampling *) -local module Game1 = { - var r: rand - - proc main() = { - var pk, sk, m0, m1, b, h, c, b'; - Log(LRO).init(); - (pk,sk) <$ dkeys; - (m0,m1) <@ A(Log(LRO)).a1(pk); - b <$ {0,1}; - - r <$ drand; - h <$ dptxt; - c <- ((f pk r),h +^ (b?m0:m1)); - - b' <@ A(Log(LRO)).a2(c); - return b' = b; - } +local module Game1 = BR93_CPA(A) with { + var r : rand + + proc main [ + (* new local variable to store the sampled ptxt *) + var h : ptxt + (* inline key generation *) + ^ <@ {2} ~ { (pk, sk) <$ dkeys; } + (* inline challenge encryption and idealize RO call *) + ^ c<@ ~ { r <$ drand; h <$ dptxt; c <- (f pk r, h +^ (b ? m0 : m1)); } + ] }. local lemma pr_Game0_Game1 &m: @@ -327,23 +321,11 @@ by move=> _ rR aL mL aR qsR mR h /h [] ->. qed. (* Step 2: replace h ^ m with h in the challenge encryption *) -local module Game2 = { - var r: rand - - proc main() = { - var pk, sk, m0, m1, b, h, c, b'; - Log(LRO).init(); - (pk,sk) <$ dkeys; - (m0,m1) <@ A(Log(LRO)).a1(pk); - b <$ {0,1}; - - r <$ drand; - h <$ dptxt; - c <- ((f pk r),h); - - b' <@ A(Log(LRO)).a2(c); - return b' = b; - } +local module Game2 = Game1 with { + proc main [ + (* Challenge ciphertext is now produced uniformly at random *) + ^ c<- ~ { c <- (f pk r, h); } + ] }. local equiv eq_Game1_Game2: Game1.main ~ Game2.main: @@ -402,12 +384,12 @@ local module OWr (I : Inverter) = { (* We can easily prove that it is strictly equivalent to OW *) local lemma OW_OWr &m (I <: Inverter {-OWr}): - Pr[OW(I).main() @ &m: res] + Pr[OW(I).main() @ &m: res] = Pr[OWr(I).main() @ &m: res]. proof. by byequiv=> //=; sim. qed. local lemma pr_Game2_OW &m: - Pr[Game2.main() @ &m: Game2.r \in Log.qs] + Pr[Game2.main() @ &m: Game2.r \in Log.qs] <= Pr[OW(I(A)).main() @ &m: res]. proof. rewrite (OW_OWr &m (I(A))). (* Note: we proved it forall (abstract) I *) @@ -431,7 +413,7 @@ by auto=> /> [pk sk] ->. qed. lemma Reduction &m: - Pr[BR93_CPA(A).main() @ &m : res] - 1%r/2%r + Pr[BR93_CPA(A).main() @ &m : res] - 1%r/2%r <= Pr[OW(I(A)).main() @ &m: res]. proof. smt(pr_Game0_Game1 pr_Game1_Game2 pr_bad_Game1_Game2 pr_Game2 pr_Game2_OW). @@ -675,4 +657,3 @@ by move=> O O_o_ll; proc; call (A_a2_ll O O_o_ll). qed. end section. - diff --git a/src/ecLowPhlGoal.ml b/src/ecLowPhlGoal.ml index 683dd5e185..79244af03e 100644 --- a/src/ecLowPhlGoal.ml +++ b/src/ecLowPhlGoal.ml @@ -578,7 +578,7 @@ type 'a zip_t = let t_fold f (cenv : code_txenv) (cpos : codepos) (_ : form * form) (state, s) = try let env = EcEnv.LDecl.toenv (snd cenv) in - let (me, f) = Zpr.fold env cenv cpos f state s in + let (me, f) = Zpr.fold env cenv cpos (fun _ -> f) state s in ((me, f, []) : memenv * _ * form list) with Zpr.InvalidCPos -> tc_error (fst cenv) "invalid code position" diff --git a/src/ecMatching.ml b/src/ecMatching.ml index a2f5fce026..015a2c08ba 100644 --- a/src/ecMatching.ml +++ b/src/ecMatching.ml @@ -57,7 +57,10 @@ module Zipper = struct module P = EcPath type ('a, 'state) folder = - 'a -> 'state -> instr -> 'state * instr list + env -> 'a -> 'state -> instr -> 'state * instr list + + type ('a, 'state) folder_tl = + env -> 'a -> 'state -> instr -> instr list -> 'state * instr list type spath_match_ctxt = { locals : (EcIdent.t * ty) list; @@ -71,18 +74,19 @@ module Zipper = struct | ZIfThen of expr * spath * stmt | ZIfElse of expr * stmt * spath | ZMatch of expr * spath * spath_match_ctxt - + and spath = (instr list * instr list) * ipath type zipper = { z_head : instr list; (* instructions on my left (rev) *) z_tail : instr list; (* instructions on my right (me incl.) *) z_path : ipath; (* path (zipper) leading to me *) + z_env : env option; } let cpos (i : int) : codepos1 = (0, `ByPos i) - let zipper hd tl zpr = { z_head = hd; z_tail = tl; z_path = zpr; } + let zipper ?env hd tl zpr = { z_head = hd; z_tail = tl; z_path = zpr; z_env = env; } let find_by_cp_match (env : EcEnv.env) @@ -193,19 +197,19 @@ module Zipper = struct ((cp1, sub) : codepos1 * codepos_brsel) (s : stmt) (zpr : ipath) - : (ipath * stmt) * (codepos1 * codepos_brsel) + : (ipath * stmt) * (codepos1 * codepos_brsel) * env = let (s1, i, s2) = find_by_cpos1 env cp1 s in - let zpr = + let zpr, env = match i.i_node, sub with | Swhile (e, sw), `Cond true -> - (ZWhile (e, ((s1, s2), zpr)), sw) + (ZWhile (e, ((s1, s2), zpr)), sw), env | Sif (e, ifs1, ifs2), `Cond true -> - (ZIfThen (e, ((s1, s2), zpr), ifs2), ifs1) + (ZIfThen (e, ((s1, s2), zpr), ifs2), ifs1), env | Sif (e, ifs1, ifs2), `Cond false -> - (ZIfElse (e, ifs1, ((s1, s2), zpr)), ifs2) + (ZIfElse (e, ifs1, ((s1, s2), zpr)), ifs2), env | Smatch (e, bs), `Match cn -> let _, indt, _ = oget (EcEnv.Ty.get_top_decl e.e_ty env) in @@ -216,19 +220,20 @@ module Zipper = struct with Not_found -> raise InvalidCPos in let prebr, (locals, body), postbr = List.pivot_at ix bs in - (ZMatch (e, ((s1, s2), zpr), { locals; prebr; postbr; }), body) + let env = EcEnv.Var.bind_locals locals env in + (ZMatch (e, ((s1, s2), zpr), { locals; prebr; postbr; }), body), env | _ -> raise InvalidCPos - in zpr, ((0, `ByPos (1 + List.length s1)), sub) + in zpr, ((0, `ByPos (1 + List.length s1)), sub), env let zipper_of_cpos_r (env : EcEnv.env) ((nm, cp1) : codepos) (s : stmt) = - let (zpr, s), nm = + let ((zpr, s), env), nm = List.fold_left_map - (fun (zpr, s) nm1 -> zipper_at_nm_cpos1 env nm1 s zpr) - (ZTop, s) nm in + (fun ((zpr, s), env) nm1 -> let zpr, s, env = zipper_at_nm_cpos1 env nm1 s zpr in (zpr, env), s) + ((ZTop, s), env) nm in let s1, i, s2 = find_by_cpos1 env cp1 s in - let zpr = zipper s1 (i :: s2) zpr in + let zpr = zipper ~env s1 (i :: s2) zpr in (zpr, (nm, (0, `ByPos (1 + List.length s1)))) @@ -274,21 +279,28 @@ module Zipper = struct in List.rev after - let fold env cenv cpos f state s = + let fold_tl env cenv cpos f state s = let zpr = zipper_of_cpos env cpos s in match zpr.z_tail with | [] -> raise InvalidCPos | i :: tl -> begin - match f cenv state i with + match f (odfl env zpr.z_env) cenv state i tl with | (state', [i']) when i == i' && state == state' -> (state, s) - | (state', si ) -> (state', zip { zpr with z_tail = si @ tl }) + | (state', si ) -> (state', zip { zpr with z_tail = si }) end + let fold env cenv cpos f state s = + let f e ce st i tl = + let state', si = f e ce st i in + state', si @ tl + in + fold_tl env cenv cpos f state s + let map env cpos f s = fst_map Option.get - (fold env () cpos (fun () _ i -> fst_map some (f i)) None s) + (fold env () cpos (fun _ () _ i -> fst_map some (f i)) None s) end (* -------------------------------------------------------------------- *) diff --git a/src/ecMatching.mli b/src/ecMatching.mli index 9961f1c24e..1751637654 100644 --- a/src/ecMatching.mli +++ b/src/ecMatching.mli @@ -61,6 +61,7 @@ module Zipper : sig z_head : instr list; (* instructions on my left (rev) *) z_tail : instr list; (* instructions on my right (me incl.) *) z_path : ipath ; (* path (zipper) leading to me *) + z_env : env option; (* env with local vars from previous instructions *) } exception InvalidCPos @@ -79,7 +80,7 @@ module Zipper : sig val offset_of_position : env -> codepos1 -> stmt -> int (* [zipper] soft constructor *) - val zipper : instr list -> instr list -> ipath -> zipper + val zipper : ?env : env -> instr list -> instr list -> ipath -> zipper (* Return the zipper for the stmt [stmt] at code position [codepos]. * Raise [InvalidCPos] if [codepos] is not valid for [stmt]. It also @@ -101,7 +102,8 @@ module Zipper : sig *) val after : strict:bool -> zipper -> instr list list - type ('a, 'state) folder = 'a -> 'state -> instr -> 'state * instr list + type ('a, 'state) folder = env -> 'a -> 'state -> instr -> 'state * instr list + type ('a, 'state) folder_tl = env -> 'a -> 'state -> instr -> instr list -> 'state * instr list (* [fold env v cpos f state s] create the zipper for [s] at [cpos], and apply * [f] to it, along with [v] and the state [state]. [f] must return the @@ -112,6 +114,9 @@ module Zipper : sig *) val fold : env -> 'a -> codepos -> ('a, 'state) folder -> 'state -> stmt -> 'state * stmt + (* Same as above but using [folder_tl]. *) + val fold_tl : env -> 'a -> codepos -> ('a, 'state) folder_tl -> 'state -> stmt -> 'state * stmt + (* [map cpos env f s] is a special case of [fold] where the state and the * out-of-band data are absent *) diff --git a/src/ecParser.mly b/src/ecParser.mly index e07035a281..9f72e6abff 100644 --- a/src/ecParser.mly +++ b/src/ecParser.mly @@ -1443,6 +1443,30 @@ mod_item: | IMPORT VAR ms=loc(mod_qident)+ { Pst_import ms } +mod_update_var: +| v=var_decl { v } + +mod_update_fun: +| PROC x=lident LBRACKET lvs=var_decl* fups=fun_update+ RBRACKET res_up=option(RES TILD e=sexpr {e}) + { (x, lvs, (List.flatten fups, res_up)) } + +update_stmt: +| PLUS s=brace(stmt){ [Pups_add (s, true)] } +| MINUS s=brace(stmt){ [Pups_add (s, false)] } +| TILD s=brace(stmt) { [Pups_del; Pups_add (s, true)] } +| MINUS { [Pups_del] } + +update_cond: +| PLUS e=sexpr { Pupc_add e } +| TILD e=sexpr { Pupc_mod e } +| MINUS bs=branch_select { Pupc_del bs } + +fun_update: +| cp=loc(codepos) sup=update_stmt + { List.map (fun v -> (cp, Pup_stmt v)) sup } +| cp=loc(codepos) cup=update_cond + { [(cp, Pup_cond cup)] } + (* -------------------------------------------------------------------- *) (* Modules *) @@ -1453,6 +1477,9 @@ mod_body: | LBRACE stt=loc(mod_item)* RBRACE { Pm_struct stt } +| m=mod_qident WITH LBRACE vs=mod_update_var* fs=mod_update_fun* RBRACE + { Pm_update (m, vs, fs) } + mod_def_or_decl: | locality=locality MODULE header=mod_header c=mod_cast? EQ ptm_body=loc(mod_body) { let ptm_header = match c with None -> header | Some c -> Pmh_cast(header,c) in diff --git a/src/ecParsetree.ml b/src/ecParsetree.ml index 3fd5c16f93..67eba5e110 100644 --- a/src/ecParsetree.ml +++ b/src/ecParsetree.ml @@ -36,6 +36,30 @@ type osymbol_r = psymbol option type osymbol = osymbol_r located (* -------------------------------------------------------------------- *) +type pcp_match = [ + | `If + | `While + | `Match + | `Assign of plvmatch + | `Sample of plvmatch + | `Call of plvmatch +] + +and plvmatch = [ `LvmNone | `LvmVar of pqsymbol ] + +type pcp_base = [ `ByPos of int | `ByMatch of int option * pcp_match ] + +type pbranch_select = [`Cond of bool | `Match of psymbol] +type pcodepos1 = int * pcp_base +type pcodepos = (pcodepos1 * pbranch_select) list * pcodepos1 +type pdocodepos1 = pcodepos1 doption option + +type pcodeoffset1 = [ + | `ByOffset of int + | `ByPosition of pcodepos1 +] +(* -------------------------------------------------------------------- *) + type pty_r = | PTunivar | PTtuple of pty list @@ -305,6 +329,7 @@ and pmodule_params = (psymbol * pmodule_type) list and pmodule_expr_r = | Pm_ident of pmsymbol | Pm_struct of pstructure + | Pm_update of pmsymbol * pupdate_var list * pupdate_fun list and pmodule_expr = pmodule_expr_r located @@ -318,6 +343,21 @@ and pstructure_item = | Pst_include of (pmsymbol located * bool * minclude_proc option) | Pst_import of (pmsymbol located) list +and pupdate_var = psymbol list * pty +and pupdate_fun = psymbol * (psymbol list * pty) list * ((pcodepos located * pupdate_item) list * pexpr option) + +and pupdate_item = + | Pup_stmt of pupdate_stmt + | Pup_cond of pupdate_cond + +and pupdate_stmt = + | Pups_add of (pstmt * bool) + | Pups_del + +and pupdate_cond = + | Pupc_add of pexpr + | Pupc_mod of pexpr + | Pupc_del of pbranch_select and pfunction_body = { pfb_locals : pfunction_local list; @@ -468,30 +508,6 @@ type preduction = { puser : bool; (* user reduction *) } -(* -------------------------------------------------------------------- *) -type pcp_match = [ - | `If - | `While - | `Match - | `Assign of plvmatch - | `Sample of plvmatch - | `Call of plvmatch -] - -and plvmatch = [ `LvmNone | `LvmVar of pqsymbol ] - -type pcp_base = [ `ByPos of int | `ByMatch of int option * pcp_match ] - -type pbranch_select = [`Cond of bool | `Match of psymbol] -type pcodepos1 = int * pcp_base -type pcodepos = (pcodepos1 * pbranch_select) list * pcodepos1 -type pdocodepos1 = pcodepos1 doption option - -type pcodeoffset1 = [ - | `ByOffset of int - | `ByPosition of pcodepos1 -] - (* -------------------------------------------------------------------- *) type pswap_kind = { interval : (pcodepos1 * pcodepos1 option) option; diff --git a/src/ecTyping.ml b/src/ecTyping.ml index dc94ff5579..3e9951c574 100644 --- a/src/ecTyping.ml +++ b/src/ecTyping.ml @@ -89,6 +89,14 @@ type modsig_error = | MTS_DupProcName of symbol | MTS_DupArgName of symbol * symbol +type modupd_error = +| MUE_Functor +| MUE_AbstractFun +| MUE_AbstractModule +| MUE_InvalidFun +| MUE_InvalidCodePos +| MUE_InvalidTargetCond + type funapp_error = | FAE_WrongArgCount @@ -155,6 +163,7 @@ type tyerror = | InvalidModAppl of modapp_error | InvalidModType of modtyp_error | InvalidModSig of modsig_error +| InvalidModUpdate of modupd_error | InvalidMem of symbol * mem_error | InvalidMatch of fxerror | InvalidFilter of filter_error @@ -2016,6 +2025,263 @@ and transmod_body ~attop (env : EcEnv.env) x params (me:pmodule_expr) = me | Pm_struct ps -> transstruct ~attop env x.pl_desc stparams (mk_loc me.pl_loc ps) + | Pm_update (m, vars, funs) -> + let loc = me.pl_loc in + let (mp, sig_) = trans_msymbol env {pl_desc = m; pl_loc = loc} in + + (* Prohibit functor updates *) + if not (List.is_empty sig_.miss_params) then + tyerror loc env (InvalidModUpdate MUE_Functor); + + (* Construct the set of new module variables *) + let items = + List.concat_map + (fun (xs, ty) -> + let ty = transty_for_decl env ty in + let items = List.map + (fun { pl_desc = x } -> MI_Variable { v_name = x; v_type = ty; }) + xs + in + items) + vars + in + + let me, _ = EcEnv.Mod.by_mpath mp env in + let p = match mp.m_top with | `Concrete (p, _) -> p | _ -> assert false in + let subst = EcSubst.add_moddef EcSubst.empty ~src:p ~dst:(EcEnv.mroot env) in + let me = EcSubst.subst_module subst me in + + let update_fun env fn plocals pupdates pupdate_res = + (* Extract the function body and load the memory *) + let fun_ = EcEnv.Fun.by_xpath (xpath mp fn) env in + + (* Follow a function alias until we get to the concrete definition *) + let rec resolve_alias f = + match f.f_def with + | FBabs _ -> + tyerror loc env (InvalidModUpdate MUE_AbstractModule); + | FBalias xp -> resolve_alias (EcEnv.Fun.by_xpath xp env) + | FBdef _ -> f + in + + let target_fun = EcSubst.subst_function subst (resolve_alias fun_) in + let (_fs, fd), memenv = EcEnv.Fun.actmem_body mhr target_fun in + + let fun_ = EcSubst.subst_function subst fun_ in + + (* Introduce the new local variables *) + let locals = List.concat_map (fun (vs, pty) -> + let ty = transty_for_decl env pty in + List.map (fun x -> { v_name = x.pl_desc; v_type = ty; }, x.pl_loc) vs + ) + plocals + in + + let memenv = fundef_add_symbol env memenv locals in + let env = EcEnv.Memory.push_active memenv env in + + let locals = ref locals in + let memenv = ref memenv in + + (* Semantics for stmt updating, `i` is the target of the update. *) + let eval_supdate env sup i = + match sup with + | Pups_add (s, after) -> + let ue = UE.create (Some []) in + let s = transstmt env ue s in + let ts = Tuni.subst (UE.close ue) in + if after then + i :: (s_subst ts s).s_node + else + (s_subst ts s).s_node @ [i] + | Pups_del -> [] + in + + (* Semantics for condition updating *) + (* `i` is the target of the update, and `tl` is the instr suffix. *) + let eval_cupdate cp_loc env cup i tl = + match cup with + (* Insert an if with condition `e` with body `tl` *) + | Pupc_add e -> + let loc = e.pl_loc in + let ue = UE.create (Some []) in + let e, ty = transexp env `InProc ue e in + let ts = Tuni.subst (UE.close ue) in + let ty = ty_subst ts ty in + unify_or_fail env ue loc ~expct:tbool ty; + i :: [i_if (e_subst ts e, stmt tl, s_empty)] + + (* Change the condition expression to `e` for a conditional instr `i` *) + | Pupc_mod e -> begin + let loc = e.pl_loc in + let ue = UE.create (Some []) in + let e, ty = transexp env `InProc ue e in + let ts = Tuni.subst (UE.close ue) in + let ty = ty_subst ts ty in + match i.i_node with + | Sif (_, t, f) -> + unify_or_fail env ue loc ~expct:tbool ty; + [i_if (e_subst ts e, t, f)] @ tl + | Smatch (p, bs) -> + unify_or_fail env ue loc ~expct:p.e_ty ty; + [i_match (e_subst ts e, bs)] @ tl + | Swhile (_, t) -> + unify_or_fail env ue loc ~expct:tbool ty; + [i_while (e_subst ts e, t)] @ tl + | _ -> + tyerror cp_loc env (InvalidModUpdate MUE_InvalidTargetCond); + end + + (* Collapse a conditional `i` to a specific branch `bs` *) + | Pupc_del bs -> begin + let bs = trans_codepos_brsel bs in + match i.i_node, bs with + | Sif (_, t, _), `Cond true -> t.s_node + | Sif (_, _, f), `Cond false -> f.s_node + | Swhile (_, t), `Cond true -> t.s_node + | Smatch (e, bs), `Match cn -> begin + (* match e with | C a b c => b | ... ---> (a, b, c) <- oget (get_as_C e); b *) + + let typ, tydc, tyinst = oget (EcEnv.Ty.get_top_decl e.e_ty env) in + let tyinst = List.combine tydc.tyd_params tyinst in + let indt = oget (EcDecl.tydecl_as_datatype tydc) in + let cnames = List.fst indt.tydt_ctors in + let r = List.assoc_opt cn (List.combine cnames bs) in + match r with + | None -> + tyerror cp_loc env (InvalidModUpdate MUE_InvalidTargetCond) + | Some (p, b) -> begin + (* TODO: Factorize. This is mostly just a copy/paste from EcPhlRCond.gen_rcond_full. *) + let cvars = List.map (fun (x, xty) -> { ov_name = Some (EcIdent.name x); ov_type = xty; }) p in + let me, cvars = EcMemory.bindall_fresh cvars !memenv in + + let subst, pvs = + let s = Fsubst.f_subst_id in + let s, pvs = List.fold_left_map (fun s ((x, xty), name) -> + let pv = pv_loc (oget name.ov_name) in + let s = bind_elocal s x (e_var pv xty) in + (s, (pv, xty))) + s (List.combine p cvars) + in + (s, pvs) + in + + let asgn = EcModules.lv_of_list pvs |> omap (fun lv -> + let rty = ttuple (List.snd p) in + let proj = EcInductive.datatype_proj_path typ cn in + let proj = e_op proj (List.snd tyinst) (tfun e.e_ty (toption rty)) in + let proj = e_app proj [e] (toption rty) in + let proj = e_oget proj rty in + i_asgn (lv, proj)) + in + + memenv := me; + locals := !locals @ (List.map (fun ov -> {v_name = oget ov.ov_name; v_type = ov.ov_type; }, cp_loc) cvars); + + match asgn with + | None -> b.s_node @ tl + | Some a -> a :: (s_subst subst b).s_node @ tl + end + end + | _ -> + tyerror cp_loc env (InvalidModUpdate MUE_InvalidTargetCond); + end + in + + (* Apply each of updates in reverse *) + (* NOTE: This is with the expectation that the user entered them in chronological order. *) + let body = + List.fold_right (fun (cp, up) bd -> + let {pl_desc = cp; pl_loc = loc} = cp in + let cp = trans_codepos env cp in + let change env _ _ i tl = (), + match up with + | Pup_stmt sup -> + eval_supdate env sup i @ tl + | Pup_cond cup -> + eval_cupdate loc env cup i tl + in + let env = EcEnv.Memory.push_active !memenv env in + try + let _, s = EcMatching.Zipper.fold_tl env () cp change () bd in + s + with + | EcMatching.Zipper.InvalidCPos -> + tyerror loc env (InvalidModUpdate MUE_InvalidCodePos); + ) + pupdates + fd.f_body + in + + (* Apply the result update if given *) + let ret = match fd.f_ret, pupdate_res with + | Some e, Some e' -> + let loc = e'.pl_loc in + let ue = UE.create (Some []) in + let e', ty = transexp env `InProc ue e' in + unify_or_fail env ue loc ~expct:e.e_ty ty; + let ts = Tuni.subst (UE.close ue) in + Some (e_subst ts e') + | _ -> fd.f_ret + in + + (* Reconstruct the function def *) + let uses = ret |> ofold ((^~) se_inuse) (s_inuse body) in + let fd = {f_locals = fd.f_locals @ (List.fst !locals); f_body = body; f_ret = ret; f_uses = uses; } in + let fun_ = {fun_ with f_def = FBdef fd} in + fun_ + in + + let allowed_funs = List.map (fun (Tys_function f) -> f.fs_name) me.me_sig_body in + let funs = List.map (fun ({pl_loc = loc; pl_desc = fn}, lvs, v) -> + if List.mem fn allowed_funs then + fn, (lvs, v) + else + tyerror loc env (InvalidModUpdate MUE_InvalidFun) + ) funs + in + + (* Update all module items *) + let env, items = + match me.me_body with + | ME_Structure mb -> + let doit (env, items) mi = + match mi with + | MI_Variable v -> + let env = EcEnv.Var.bind_pvglob v.v_name v.v_type env in + env, items @ [mi] + | MI_Function f -> begin + match List.assoc_opt f.f_name funs with + | None -> + let env = EcEnv.bind1 (f.f_name, `Function f) env in + env, items @ [mi] + | Some (lvs, (upsc, rup)) -> + let f = update_fun env f.f_name lvs upsc rup in + let env = EcEnv.bind1 (f.f_name, `Function f) env in + env, items @ [MI_Function f] + end + | MI_Module me -> + let env = EcEnv.bind1 (me.me_name, `Module me) env in + env, items @ [mi] + in + List.fold_left doit (env, []) (items @ mb.ms_body) + + | _ -> + tyerror loc env (InvalidModUpdate MUE_AbstractModule); + in + + let ois = get_oi_calls env (stparams, items) in + + (* Construct structure representation *) + let me = + { me_name = x.pl_desc; + me_body = ME_Structure { ms_body = items; }; + me_comps = items; + me_params = stparams; + me_sig_body = me.me_sig_body; + me_oinfos = ois; } + in + me (* -------------------------------------------------------------------- *) (* Module parameters must have been added to the environment *) @@ -3300,7 +3566,7 @@ and transexpcast_opt (env : EcEnv.env) mode ue oty e = match oty with | None -> fst (transexp env mode ue e) | Some t -> transexpcast env mode ue t e - + (* -------------------------------------------------------------------- *) and trans_form_opt env ?mv ue pf oty = trans_form_or_pattern env `Form ?mv ue pf oty @@ -3318,10 +3584,10 @@ and trans_pattern env ps ue pf = trans_form_or_pattern env `Form ~ps ue pf None (* -------------------------------------------------------------------- *) -let trans_args env ue = transcall (transexp env `InProc ue) env ue +and trans_args env ue = transcall (transexp env `InProc ue) env ue (* -------------------------------------------------------------------- *) -let trans_lv_match ?(memory : memory option) (env : EcEnv.env) (p : plvmatch) : lvmatch = +and trans_lv_match ?(memory : memory option) (env : EcEnv.env) (p : plvmatch) : lvmatch = match p with | `LvmNone as p -> (p :> lvmatch) | `LvmVar pv -> begin @@ -3332,7 +3598,7 @@ let trans_lv_match ?(memory : memory option) (env : EcEnv.env) (p : plvmatch) : `LvmVar (transpvar env m pv) end (* -------------------------------------------------------------------- *) -let trans_cp_match ?(memory : memory option) (env : EcEnv.env) (p : pcp_match) : cp_match = +and trans_cp_match ?(memory : memory option) (env : EcEnv.env) (p : pcp_match) : cp_match = match p with | (`While | `If | `Match) as p -> (p :> cp_match) @@ -3343,29 +3609,29 @@ let trans_cp_match ?(memory : memory option) (env : EcEnv.env) (p : pcp_match) : | `Assign lv -> `Assign (trans_lv_match ?memory env lv) (* -------------------------------------------------------------------- *) -let trans_cp_base ?(memory : memory option) (env : EcEnv.env) (p : pcp_base) : cp_base = +and trans_cp_base ?(memory : memory option) (env : EcEnv.env) (p : pcp_base) : cp_base = match p with | `ByPos _ as p -> (p :> cp_base) | `ByMatch (i, p) -> `ByMatch (i, trans_cp_match ?memory env p) (* -------------------------------------------------------------------- *) -let trans_codepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1) : codepos1 = +and trans_codepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1) : codepos1 = snd_map (trans_cp_base ?memory env) p (* -------------------------------------------------------------------- *) -let trans_codepos_brsel (bs : pbranch_select) : codepos_brsel = +and trans_codepos_brsel (bs : pbranch_select) : codepos_brsel = match bs with | `Cond b -> `Cond b | `Match { pl_desc = x } -> `Match x (* -------------------------------------------------------------------- *) -let trans_codepos ?(memory : memory option) (env : EcEnv.env) ((nm, p) : pcodepos) : codepos = +and trans_codepos ?(memory : memory option) (env : EcEnv.env) ((nm, p) : pcodepos) : codepos = let nm = List.map (fun (cp1, bs) -> (trans_codepos1 ?memory env cp1, trans_codepos_brsel bs)) nm in let p = trans_codepos1 ?memory env p in (nm, p) (* -------------------------------------------------------------------- *) -let trans_dcodepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1 doption) : codepos1 doption = +and trans_dcodepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1 doption) : codepos1 doption = DOption.map (trans_codepos1 ?memory env) p (* -------------------------------------------------------------------- *) diff --git a/src/ecTyping.mli b/src/ecTyping.mli index eb3e48f9f1..a09265998a 100644 --- a/src/ecTyping.mli +++ b/src/ecTyping.mli @@ -81,6 +81,14 @@ type modsig_error = | MTS_DupProcName of symbol | MTS_DupArgName of symbol * symbol +type modupd_error = +| MUE_Functor +| MUE_AbstractFun +| MUE_AbstractModule +| MUE_InvalidFun +| MUE_InvalidCodePos +| MUE_InvalidTargetCond + type funapp_error = | FAE_WrongArgCount @@ -147,6 +155,7 @@ type tyerror = | InvalidModAppl of modapp_error | InvalidModType of modtyp_error | InvalidModSig of modsig_error +| InvalidModUpdate of modupd_error | InvalidMem of symbol * mem_error | InvalidMatch of fxerror | InvalidFilter of filter_error diff --git a/src/ecUserMessages.ml b/src/ecUserMessages.ml index 6973f029ee..9c947c1b6c 100644 --- a/src/ecUserMessages.ml +++ b/src/ecUserMessages.ml @@ -452,6 +452,24 @@ end = struct | InvalidModSig (MTS_DupArgName (f, x)) -> msg "duplicated proc. arg. name in signature: `%s.%s'" f x + | InvalidModUpdate MUE_Functor -> + msg "cannot update a functor" + + | InvalidModUpdate MUE_AbstractFun -> + msg "cannot update an abstract function" + + | InvalidModUpdate MUE_AbstractModule -> + msg "cannot update an abstract module" + + | InvalidModUpdate MUE_InvalidFun -> + msg "unknown function" + + | InvalidModUpdate MUE_InvalidCodePos-> + msg "invalid code position" + + | InvalidModUpdate MUE_InvalidTargetCond -> + msg "target instruction is not a conditional" + | InvalidMem (name, MAE_IsConcrete) -> msg "the memory %s must be abstract" name diff --git a/tests/fine-grained-module-defs.ec b/tests/fine-grained-module-defs.ec new file mode 100644 index 0000000000..09b8acadf8 --- /dev/null +++ b/tests/fine-grained-module-defs.ec @@ -0,0 +1,50 @@ +require import AllCore. + +module type T = { + proc run() : unit +}. + +module A (B : T) = { + var x : int + + proc f(y: int) = { + x <- x + y; + B.run(); + return x; + } + proc g(y: int) = { + x <- x - y; + B.run(); + return x; + } + proc h(x: int option) = { + var r <- 2; + match x with + | None => {} + | Some v => { + r <- v; + } +end; + return r; + } +}. + +module A_count (B : T) = A(B) with { + var c : int + proc f [1 - { c <- c + 1;}] + proc g [1 ~ { c <- c - 1;} 2 -] res ~ (x + 1) + proc h [^match - #Some.] +}. +print A_count. + +equiv A_A_count (B <: T{-A_count, -A}) : A(B).f ~ A_count(B).f: ={arg, glob B} /\ ={x}(A, A_count) ==> ={res, glob B} /\ ={x}(A, A_count). +proof. +proc. +by call (: true); auto. +qed. + +lemma Check_Delete_Branch (B <: T): hoare[A_count(B).h: arg = Some 4 ==> res = 4]. +proof. +proc. +by auto. +qed.