From adb97b66fbc8652547264fcaf0f00eb940ecf690 Mon Sep 17 00:00:00 2001 From: Lasse Letager Hansen Date: Wed, 1 Nov 2023 17:42:34 +0100 Subject: [PATCH] Adding phase to simplify enum and record matching statements --- engine/backends/coq/coq/coq_backend.ml | 8 +- .../backends/easycrypt/easycrypt_backend.ml | 2 + engine/backends/fstar/fstar_backend.ml | 8 +- engine/lib/ast.ml | 5 +- engine/lib/ast_utils.ml | 2 +- engine/lib/diagnostics.ml | 1 + engine/lib/features.ml | 3 +- .../generic_printer/generic_printer_base.ml | 3 +- engine/lib/import_thir.ml | 7 +- engine/lib/phases.ml | 1 + .../phases/phase_project_instead_of_match.ml | 226 ++++++++++++++++++ .../phases/phase_project_instead_of_match.mli | 18 ++ engine/lib/print_rust.ml | 2 +- engine/lib/subtype.ml | 7 +- 14 files changed, 282 insertions(+), 11 deletions(-) create mode 100644 engine/lib/phases/phase_project_instead_of_match.ml create mode 100644 engine/lib/phases/phase_project_instead_of_match.mli diff --git a/engine/backends/coq/coq/coq_backend.ml b/engine/backends/coq/coq/coq_backend.ml index de0489b66..7db2a51b2 100644 --- a/engine/backends/coq/coq/coq_backend.ml +++ b/engine/backends/coq/coq/coq_backend.ml @@ -12,6 +12,7 @@ include include On.Monadic_binding include On.Macro include On.Construct_base + include On.Project_instead_of_match end) (struct let backend = Diagnostics.Backend.Coq @@ -36,6 +37,8 @@ module SubtypeToInputLanguage and type nontrivial_lhs = Features.Off.nontrivial_lhs and type loop = Features.Off.loop and type block = Features.Off.block + and type project_instead_of_match = + Features.On.project_instead_of_match and type for_loop = Features.Off.for_loop and type for_index_loop = Features.Off.for_index_loop and type state_passing_loop = Features.Off.state_passing_loop) = @@ -52,6 +55,7 @@ struct include Features.SUBTYPE.On.Construct_base include Features.SUBTYPE.On.Slice include Features.SUBTYPE.On.Macro + include Features.SUBTYPE.On.Project_instead_of_match end) let metadata = Phase_utils.Metadata.make (Reject (NotInBackendLang backend)) @@ -216,9 +220,9 @@ struct __TODO_pat__ p.span "tuple 1" | PConstruct { name = `TupleCons n; args } -> C.AST.TuplePat (List.map ~f:(fun { pat } -> ppat pat) args) - | PConstruct { name; args; is_record = true } -> + | PConstruct { name; args; is_record = Some _ } -> C.AST.RecordPat (pglobal_ident name, pfield_pats args) - | PConstruct { name; args; is_record = false } -> + | PConstruct { name; args; is_record = None } -> C.AST.ConstructorPat (pglobal_ident name, List.map ~f:(fun p -> ppat p.pat) args) | PConstant { lit } -> C.AST.Lit (pliteral p.span lit) diff --git a/engine/backends/easycrypt/easycrypt_backend.ml b/engine/backends/easycrypt/easycrypt_backend.ml index 6b609021b..8d4acdc5a 100644 --- a/engine/backends/easycrypt/easycrypt_backend.ml +++ b/engine/backends/easycrypt/easycrypt_backend.ml @@ -15,6 +15,7 @@ include include On.Mutable_variable include On.Macro include On.Construct_base + include On.Project_instead_of_match end) (struct let backend = Diagnostics.Backend.EasyCrypt @@ -60,6 +61,7 @@ module RejectNotEC (FA : Features.T) = struct let state_passing_loop = reject let nontrivial_lhs = reject let block = reject + let project_instead_of_match _ _ = Features.On.project_instead_of_match let for_loop = reject let construct_base _ _ = Features.On.construct_base let for_index_loop _ _ = Features.On.for_index_loop diff --git a/engine/backends/fstar/fstar_backend.ml b/engine/backends/fstar/fstar_backend.ml index 59dc83917..9c26865d1 100644 --- a/engine/backends/fstar/fstar_backend.ml +++ b/engine/backends/fstar/fstar_backend.ml @@ -11,6 +11,7 @@ include include On.Slice include On.Macro include On.Construct_base + include On.Project_instead_of_match end) (struct let backend = Diagnostics.Backend.FStar @@ -35,6 +36,8 @@ module SubtypeToInputLanguage and type nontrivial_lhs = Features.Off.nontrivial_lhs and type loop = Features.Off.loop and type block = Features.Off.block + and type project_instead_of_match = + Features.On.project_instead_of_match and type for_loop = Features.Off.for_loop and type for_index_loop = Features.Off.for_index_loop and type state_passing_loop = Features.Off.state_passing_loop) = @@ -51,6 +54,7 @@ struct include Features.SUBTYPE.On.Construct_base include Features.SUBTYPE.On.Slice include Features.SUBTYPE.On.Macro + include Features.SUBTYPE.On.Project_instead_of_match end) let metadata = Phase_utils.Metadata.make (Reject (NotInBackendLang backend)) @@ -354,12 +358,12 @@ struct let pat_rec () = F.pat @@ F.AST.PatRecord (List.map ~f:pfield_pat args) in - if is_struct && is_record then pat_rec () + if is_struct && Option.is_some is_record then pat_rec () else let pat_name = F.pat @@ F.AST.PatName (pglobal_ident p.span name) in F.pat_app pat_name @@ - if is_record then [ pat_rec () ] + if Option.is_some is_record then [ pat_rec () ] else List.map ~f:(fun { field; pat } -> ppat pat) args | PConstant { lit } -> F.pat @@ F.AST.PatConst (pliteral p.span lit) | _ -> . diff --git a/engine/lib/ast.ml b/engine/lib/ast.ml index 75f7ca2ff..b6f351869 100644 --- a/engine/lib/ast.ml +++ b/engine/lib/ast.ml @@ -270,13 +270,16 @@ functor and trait_ref = { trait : concrete_ident; args : generic_value list } + (* and is_record_construct = { witness : F.project_instead_of_match } *) and pat' = | PWild | PAscription of { typ : ty; typ_span : span; pat : pat } | PConstruct of { name : global_ident; args : field_pat list; - is_record : bool; (* are fields named? *) + is_record : F.project_instead_of_match option; + (* are fields named? *) + (* F.project_instead_of_match *) is_struct : bool; (* a struct has one constructor *) } (* An or-pattern, e.g. `p | q`. diff --git a/engine/lib/ast_utils.ml b/engine/lib/ast_utils.ml index 0d3437cbc..b651885e4 100644 --- a/engine/lib/ast_utils.ml +++ b/engine/lib/ast_utils.ml @@ -564,7 +564,7 @@ module Make (F : Features.T) = struct { name = `TupleCons len; args = tuple; - is_record = false; + is_record = None; is_struct = true; }; typ = make_tuple_typ @@ List.map ~f:(fun { pat; _ } -> pat.typ) tuple; diff --git a/engine/lib/diagnostics.ml b/engine/lib/diagnostics.ml index 58a99c66e..27d36b4af 100644 --- a/engine/lib/diagnostics.ml +++ b/engine/lib/diagnostics.ml @@ -37,6 +37,7 @@ module Phase = struct | TrivializeAssignLhs | CfIntoMonads | FunctionalizeLoops + | ProjectInsteadOfMatch | DummyA | DummyB | DummyC diff --git a/engine/lib/features.ml b/engine/lib/features.ml index 5ae6001a8..fdc1b90b5 100644 --- a/engine/lib/features.ml +++ b/engine/lib/features.ml @@ -21,7 +21,8 @@ loop, construct_base, monadic_action, monadic_binding, - block] + block, + project_instead_of_match] module Full = On diff --git a/engine/lib/generic_printer/generic_printer_base.ml b/engine/lib/generic_printer/generic_printer_base.ml index b42a8f9ac..a267802f9 100644 --- a/engine/lib/generic_printer/generic_printer_base.ml +++ b/engine/lib/generic_printer/generic_printer_base.ml @@ -210,7 +210,8 @@ module Make (F : Features.T) = struct | PConstruct { name; args; is_record; is_struct } -> ( match name with | `Concrete constructor -> - print#doc_construct_inductive ~is_record ~is_struct + print#doc_construct_inductive + ~is_record:(Option.is_some is_record) ~is_struct ~constructor ~base:None (List.map ~f:(fun fp -> diff --git a/engine/lib/import_thir.ml b/engine/lib/import_thir.ml index c02a8d498..83c2d9531 100644 --- a/engine/lib/import_thir.ml +++ b/engine/lib/import_thir.ml @@ -734,7 +734,12 @@ end) : EXPR = struct { name; args; - is_record = info.variant_is_record; + is_record = + (if info.variant_is_record then Some W.project_instead_of_match + else None (* W.project_instead_of_match *)); + (* if info.variant_is_record *) + (* then Some W.project_instead_of_match *) + (* else None; *) is_struct = info.typ_is_struct; } | Tuple { subpatterns } -> diff --git a/engine/lib/phases.ml b/engine/lib/phases.ml index 07a971f21..cb0a28805 100644 --- a/engine/lib/phases.ml +++ b/engine/lib/phases.ml @@ -9,3 +9,4 @@ module Cf_into_monads = Phase_cf_into_monads.Make module Functionalize_loops = Phase_functionalize_loops.Make module Reject = Phase_reject module Local_mutation = Phase_local_mutation.Make +module Project_instead_of_match = Phase_project_instead_of_match.Make diff --git a/engine/lib/phases/phase_project_instead_of_match.ml b/engine/lib/phases/phase_project_instead_of_match.ml new file mode 100644 index 000000000..da1fbd7d3 --- /dev/null +++ b/engine/lib/phases/phase_project_instead_of_match.ml @@ -0,0 +1,226 @@ +open! Prelude + +module%inlined_contents Make (F : Features.T) = struct + open Ast + module FA = F + + module FB = struct + include F + include Features.Off.Project_instead_of_match + end + + include + Phase_utils.MakeBase (F) (FB) + (struct + let phase_id = Diagnostics.Phase.ProjectInsteadOfMatch + end) + + module UA = Ast_utils.Make (F) + module UB = Ast_utils.Make (FB) + + module Implem : ImplemT.T = struct + let metadata = metadata + + module S = struct + include Features.SUBTYPE.Id + include Features.Off.Project_instead_of_match + end + + [%%inline_defs dmutability] + + let rec dty (span : span) (ty : A.ty) : B.ty = + match ty with [%inline_arms "dty.*"] -> auto + + and dpat' (span : span) (pat : A.pat') : B.pat' = + match pat with + | [%inline_arms "dpat'.*" - PConstruct] -> auto + | PConstruct { name; args; is_record = _; is_struct } -> + PConstruct + { + name; + args = List.map ~f:(dfield_pat span) args; + is_record = None; + is_struct; + } + + and project_pat (p : A.pat) : B.pat * (B.pat * B.expr) list = + let simple_pat, remaining_pats = project_pat' p.span p.p in + ({ p = simple_pat; span = p.span; typ = dty p.span p.typ }, remaining_pats) + + and project_field_pat (_span : span) (p : A.field_pat) : + B.field_pat * (B.pat * B.expr) list = + let pat, pat_list = project_pat p.pat in + ({ field = p.field; pat }, pat_list) + + and project_pat' (span : span) (pat : A.pat') : + B.pat' * (B.pat * B.expr) list = + match pat with + | PWild -> (PWild, []) + | PAscription { typ; typ_span; pat } -> + let simple_pat, remaining_pats = project_pat pat in + ( PAscription { typ = dty span typ; pat = simple_pat; typ_span }, + remaining_pats ) + | PConstruct { name; args; is_record = Some _; is_struct } -> + let update_args = List.map ~f:(project_field_pat span) args in + let new_id = UA.fresh_local_ident_in_expr [] "some_name" in + ( PConstruct + { + name; + args = + [ + { + field = name; + pat = + { + p = + PBinding + { + mut = Immutable; + mode = ByValue; + var = new_id (* name *); + typ = TApp { ident = name; args = [] }; + (* TODO? *) + subpat = None; + }; + span; + typ = TApp { ident = name; args = [] }; + }; + }; + ]; + is_record = (None : FB.project_instead_of_match option); + is_struct; + }, + List.map + ~f:(fun ({ field; pat }, _) -> + ( pat, + ({ + e = + App + { + f = + { + e = GlobalVar field; + typ = TApp { ident = field; args = [] }; + (* TODO *) + span = pat.span; + }; + args = + [ + { + e = LocalVar new_id; + typ = TApp { ident = name; args = [] }; + span = pat.span; + }; + ]; + generic_args = []; + (* TODO *) + }; + typ = pat.typ; + span = pat.span; + } + : B.expr) )) + update_args + @ List.concat_map ~f:snd update_args ) + | PConstruct { name; args; is_record = None; is_struct } -> + let update_args = List.map ~f:(project_field_pat span) args in + ( PConstruct + { + name; + args = List.map ~f:fst update_args; + is_record = None; + is_struct; + }, + List.concat_map ~f:snd update_args ) + | PArray { args } -> + let update_args = List.map ~f:project_pat args in + ( PArray { args = List.map ~f:fst update_args }, + List.concat_map ~f:snd update_args ) + | PConstant { lit } -> (PConstant { lit }, []) + | PBinding { mut; mode; var : Local_ident.t; typ; subpat } -> + let simple_pat, remaining_pats = + match subpat with + | Some (subpat, as_pat) -> + let simple_pat, remaining_pats = project_pat subpat in + (Some (simple_pat, S.as_pattern span as_pat), remaining_pats) + | None -> (None, []) + in + ( PBinding + { + mut = dmutability span S.mutable_variable mut; + mode = dbinding_mode span mode; + var; + typ = dty span typ; + subpat = simple_pat; + (* TODO *) + (* Option.map ~f:(dpat *** S.as_pattern) subpat; *) + }, + remaining_pats ) + | PDeref { subpat; witness } -> + let simple_pat, remaining_pats = project_pat subpat in + ( PDeref { subpat = simple_pat; witness = S.reference span witness }, + remaining_pats ) + | POr { subpats } -> + let updated_subpats = List.map ~f:project_pat subpats in + ( POr { subpats = List.map ~f:fst updated_subpats }, + List.concat_map ~f:snd updated_subpats ) + + and let_of_pat_binding ((p, rhs) : B.pat * B.expr) (body : B.expr) : B.expr + = + UB.make_let p rhs body + + and lets_of_pat_bindings (bindings : (B.pat * B.expr) list) (body : B.expr) + : B.expr = + List.fold_right ~init:body ~f:let_of_pat_binding bindings + + and dexpr' (span : span) (e : A.expr') : B.expr' = + match (UA.unbox_underef_expr { e; span; typ = UA.never_typ }).e with + | [%inline_arms "dexpr'.*" - Let - Closure - Loop - Match] -> auto + | Match { scrutinee; arms } -> + Match { scrutinee = dexpr scrutinee; arms = List.map ~f:darm arms } + | Let { monadic; lhs; rhs; body } -> + let simple_pat, remaining_pats = project_pat lhs in + Let + { + monadic = + Option.map + ~f:(dsupported_monads span *** S.monadic_binding span) + monadic; + lhs = simple_pat; + rhs = dexpr rhs; + body = lets_of_pat_bindings remaining_pats (dexpr body); + } + | Loop { body; kind; state; label; witness } -> + Loop + { + body = dexpr body; + kind = dloop_kind span kind; + state = Option.map ~f:(dloop_state span) state; + label; + witness = S.loop span witness; + } + | Closure { params; body; captures } -> + let projected_params = List.map ~f:project_pat params in + Closure + { + params = List.map ~f:fst projected_params; + body = + lets_of_pat_bindings + (List.concat_map ~f:snd projected_params) + (dexpr body); + captures = List.map ~f:dexpr captures; + } + + and darm' (_span : span) (a : A.arm') : B.arm' = + let simple_pat, remaining_pats = project_pat a.arm_pat in + { + arm_pat = simple_pat; + body = lets_of_pat_bindings remaining_pats (dexpr a.body); + } + [@@inline_ands bindings_of dexpr] + + [%%inline_defs "Item.*"] + end + + include Implem +end +[@@add "subtype.ml"] diff --git a/engine/lib/phases/phase_project_instead_of_match.mli b/engine/lib/phases/phase_project_instead_of_match.mli new file mode 100644 index 000000000..ba51d7968 --- /dev/null +++ b/engine/lib/phases/phase_project_instead_of_match.mli @@ -0,0 +1,18 @@ +open! Prelude + +module Make (F : Features.T) : sig + include module type of struct + module FA = F + + module FB = struct + include F + include Features.Off.Project_instead_of_match + end + + module A = Ast.Make (F) + module B = Ast.Make (FB) + module ImplemT = Phase_utils.MakePhaseImplemT (A) (B) + end + + include ImplemT.T +end diff --git a/engine/lib/print_rust.ml b/engine/lib/print_rust.ml index 87b76cb5c..738df8fae 100644 --- a/engine/lib/print_rust.ml +++ b/engine/lib/print_rust.ml @@ -173,7 +173,7 @@ module Raw = struct pglobal_ident e.span name & if List.is_empty args then !"" - else if is_record then + else if Option.is_some is_record then !"{" & concat ~sep:!", " (List.map diff --git a/engine/lib/subtype.ml b/engine/lib/subtype.ml index 5907cfed4..7b09c7b2d 100644 --- a/engine/lib/subtype.ml +++ b/engine/lib/subtype.ml @@ -96,7 +96,12 @@ struct { name; args = List.map ~f:(dfield_pat span) args; - is_record; + is_record = + Option.map + ~f:(fun x : FB.project_instead_of_match -> + S.project_instead_of_match span x) + is_record; + (* S.project_instead_of_match *) is_struct; } | POr { subpats } -> POr { subpats = List.map ~f:dpat subpats }