Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement (if-let) guards in pattern matchings #821

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion engine/backends/coq/coq/coq_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ module SubtypeToInputLanguage
and type for_index_loop = Features.Off.for_index_loop
and type quote = Features.Off.quote
and type state_passing_loop = Features.Off.state_passing_loop
and type dyn = Features.Off.dyn) =
and type dyn = Features.Off.dyn
and type match_guard = Features.Off.match_guard) =
struct
module FB = InputLanguage

Expand Down Expand Up @@ -705,6 +706,7 @@ module TransformToInputLanguage =
|> Phases.Direct_and_mut
|> Phases.Reject.Arbitrary_lhs
|> Phases.Drop_blocks
|> Phases.Drop_match_guards
|> Phases.Reject.Continue
|> Phases.Drop_references
|> Phases.Trivialize_assign_lhs
Expand Down
4 changes: 3 additions & 1 deletion engine/backends/coq/ssprove/ssprove_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ module SubtypeToInputLanguage
and type nontrivial_lhs = Features.Off.nontrivial_lhs
and type quote = Features.Off.quote
and type block = Features.Off.block
and type dyn = Features.Off.dyn) =
and type dyn = Features.Off.dyn
and type match_guard = Features.Off.match_guard) =
struct
module FB = InputLanguage

Expand Down Expand Up @@ -570,6 +571,7 @@ module TransformToInputLanguage (* : PHASE *) =
|> Phases.Direct_and_mut
|> Phases.Reject.Arbitrary_lhs
|> Phases.Drop_blocks
|> Phases.Drop_match_guards
(* |> Phases.Reject.Continue *)
|> Phases.Drop_references
|> Phases.Trivialize_assign_lhs
Expand Down
1 change: 1 addition & 0 deletions engine/backends/easycrypt/easycrypt_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ module RejectNotEC (FA : Features.T) = struct
let while_loop = reject
let quote = reject
let dyn = reject
let match_guard = reject
let construct_base _ _ = Features.On.construct_base
let for_index_loop _ _ = Features.On.for_index_loop

Expand Down
4 changes: 3 additions & 1 deletion engine/backends/fstar/fstar_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ module SubtypeToInputLanguage
and type for_loop = Features.Off.for_loop
and type while_loop = Features.Off.while_loop
and type for_index_loop = Features.Off.for_index_loop
and type state_passing_loop = Features.Off.state_passing_loop) =
and type state_passing_loop = Features.Off.state_passing_loop
and type match_guard = Features.Off.match_guard) =
struct
module FB = InputLanguage

Expand Down Expand Up @@ -1670,6 +1671,7 @@ module TransformToInputLanguage =
|> Phases.Direct_and_mut
|> Phases.Reject.Arbitrary_lhs
|> Phases.Drop_blocks
|> Phases.Drop_match_guards
|> Phases.Drop_references
|> Phases.Trivialize_assign_lhs
|> Side_effect_utils.Hoist
Expand Down
1 change: 1 addition & 0 deletions engine/backends/proverif/proverif_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct
let monadic_binding = reject
let block = reject
let dyn = reject
let match_guard = reject
let metadata = Phase_reject.make_metadata (NotInBackendLang ProVerif)
end)

Expand Down
10 changes: 9 additions & 1 deletion engine/lib/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,16 @@ functor
witness : F.nontrivial_lhs;
}

(* A guard is a condition on a pattern like: *)
(* match x {.. if guard => .., ..}*)
and guard = { guard : guard'; span : span }

(* Only if-let guards are supported for now but other variants like regular if *)
(* could be added later (regular if guards are for now desugared as IfLet) *)
and guard' = IfLet of { lhs : pat; rhs : expr; witness : F.match_guard }

(* OCaml + visitors is not happy with `pat`... hence `arm_pat`... *)
and arm' = { arm_pat : pat; body : expr }
and arm' = { arm_pat : pat; body : expr; guard : guard option }
and arm = { arm : arm'; span : span } [@@deriving show, yojson, hash, eq]

type generic_param = {
Expand Down
27 changes: 23 additions & 4 deletions engine/lib/ast_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,12 @@ module Make (F : Features.T) = struct
inherit [_] Visitors.reduce as super
inherit [_] Sets.Local_ident.monoid as _m

method! visit_arm' env { arm_pat; body } =
shadows ~env [ arm_pat ] body super#visit_expr
method! visit_arm' env { arm_pat; body; guard } =
match guard with
| None -> shadows ~env [ arm_pat ] body super#visit_expr
| Some { guard = IfLet { lhs; rhs; _ }; _ } ->
shadows ~env [ arm_pat ] rhs super#visit_expr
++ shadows ~env [ arm_pat; lhs ] body super#visit_expr

method! visit_expr' env e =
match e with
Expand Down Expand Up @@ -466,6 +470,8 @@ module Make (F : Features.T) = struct
(module Local_ident)
end

(* This removes "fake" shadowing introduced by macros.
See PR #368 *)
let disambiguate_local_idents (item : item) =
let ambiguous = collect_ambiguous_local_idents#visit_item [] item in
let local_vars = collect_local_idents#visit_item () item |> ref in
Expand Down Expand Up @@ -601,8 +607,17 @@ module Make (F : Features.T) = struct
(without_vars (self#visit_expr () body) vars))
| _ -> super#visit_expr' () e

method! visit_arm' () { arm_pat; body } =
without_pat_vars (self#visit_expr () body) arm_pat
method! visit_arm' () { arm_pat; body; guard } =
match guard with
| Some { guard = IfLet { lhs; rhs; _ }; _ } ->
let rhs_vars =
without_pat_vars (self#visit_expr () rhs) arm_pat
in
let body_vars =
without_pats_vars (self#visit_expr () body) [ arm_pat; lhs ]
in
Set.union rhs_vars body_vars
| None -> without_pat_vars (self#visit_expr () body) arm_pat
end

class ['s] expr_list_monoid =
Expand Down Expand Up @@ -777,6 +792,10 @@ module Make (F : Features.T) = struct

let make_wild_pat (typ : ty) (span : span) : pat = { p = PWild; span; typ }

let make_arm (arm_pat : pat) (body : expr) ?(guard : guard option = None)
(span : span) : arm =
{ arm = { arm_pat; body; guard }; span }

let make_unit_param (span : span) : param =
let typ = unit_typ in
let pat = make_wild_pat typ span in
Expand Down
1 change: 1 addition & 0 deletions engine/lib/diagnostics.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ module Phase = struct
| DropReferences
| DropBlocks
| DropSizedTrait
| DropMatchGuards
| RefMut
| ResugarAsserts
| ResugarForLoops
Expand Down
3 changes: 2 additions & 1 deletion engine/lib/features.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ loop,
monadic_binding,
quote,
block,
dyn]
dyn,
match_guard]

module Full = On

Expand Down
15 changes: 12 additions & 3 deletions engine/lib/generic_printer/generic_printer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ module Make (F : Features.T) (View : Concrete_ident.VIEW_API) = struct
| TAssociatedType _ -> string "assoc_type!()"
| TOpaque _ -> string "opaque_type!()"
| TApp _ -> super#ty ctx ty
| TDyn _ -> string "" (* TODO *)
| TDyn _ -> empty (* TODO *)

method! expr' : par_state -> expr' fn =
fun ctx e ->
Expand Down Expand Up @@ -429,11 +429,20 @@ module Make (F : Features.T) (View : Concrete_ident.VIEW_API) = struct
method generic_params : generic_param list fn =
separate_map comma print#generic_param >> group >> angles

(*Option.map ~f:(...) guard |> Option.value ~default:empty*)
method arm' : arm' fn =
fun { arm_pat; body } ->
fun { arm_pat; body; guard } ->
let pat = print#pat_at Arm_pat arm_pat |> group in
let body = print#expr_at Arm_body body in
pat ^^ string " => " ^^ body ^^ comma
let guard =
Option.map
~f:(fun { guard = IfLet { lhs; rhs; _ }; _ } ->
string " if let " ^^ print#pat_at Arm_pat lhs ^^ string " = "
^^ print#expr_at Arm_body rhs)
guard
|> Option.value ~default:empty
in
pat ^^ guard ^^ string " => " ^^ body ^^ comma
end
end

Expand Down
54 changes: 34 additions & 20 deletions engine/lib/import_thir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -418,20 +418,11 @@ end) : EXPR = struct
{
arms =
[
{ arm = { arm_pat = lhs; body }; span = lhs_body_span };
{
arm =
{
arm_pat =
{
p = PWild;
span = else_block.span;
typ = lhs.typ;
};
body = { else_block with typ = body.typ };
};
span = else_block.span;
};
U.make_arm lhs body lhs_body_span;
U.make_arm
{ p = PWild; span = else_block.span; typ = lhs.typ }
{ else_block with typ = body.typ }
else_block.span;
];
scrutinee = rhs;
}
Expand Down Expand Up @@ -487,12 +478,10 @@ end) : EXPR = struct
Option.value ~default:(U.unit_expr span)
@@ Option.map ~f:c_expr else_opt
in
let arm_then =
{ arm = { arm_pat; body = then_ }; span = then_.span }
in
let arm_then = U.make_arm arm_pat then_ then_.span in
let arm_else =
let arm_pat = { arm_pat with p = PWild } in
{ arm = { arm_pat; body = else_ }; span = else_.span }
U.make_arm arm_pat else_ else_.span
in
Match { scrutinee; arms = [ arm_then; arm_else ] }
| If { cond; else_opt; then'; _ } ->
Expand Down Expand Up @@ -1091,7 +1080,31 @@ end) : EXPR = struct
let arm_pat = c_pat arm.pattern in
let body = c_expr arm.body in
let span = Span.of_thir arm.span in
{ arm = { arm_pat; body }; span }
let guard =
Option.map
~f:(fun (e : Thir.decorated_for__expr_kind) ->
let guard =
match e.contents with
| Let { expr; pat } ->
IfLet
{
lhs = c_pat pat;
rhs = c_expr expr;
witness = W.match_guard;
}
| _ ->
IfLet
{
lhs =
{ p = PConstant { lit = Bool true }; span; typ = TBool };
rhs = c_expr e;
witness = W.match_guard;
}
in
{ guard; span = Span.of_thir e.span })
arm.guard
in
{ arm = { arm_pat; body; guard }; span }

and c_param span (param : Thir.param) : param =
{
Expand Down Expand Up @@ -1327,7 +1340,8 @@ let cast_of_enum typ_name generics typ thir_span
in
(Exp e, (pat, acc)))
|> List.map ~f:(Fn.id *** function Exp e -> e | Lit n -> to_expr n)
|> List.map ~f:(fun (arm_pat, body) -> { arm = { arm_pat; body }; span })
|> List.map ~f:(fun (arm_pat, body) ->
{ arm = { arm_pat; body; guard = None }; span })
in
let scrutinee_var =
Local_ident.{ name = "x"; id = Local_ident.mk_id Expr (-1) }
Expand Down
9 changes: 5 additions & 4 deletions engine/lib/phases/phase_cf_into_monads.ml
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,11 @@ struct
| Match { scrutinee; arms } ->
let arms =
List.map
~f:(fun { arm = { arm_pat; body = a }; span } ->
~f:(fun { arm = { arm_pat; body = a; guard }; span } ->
let b = dexpr a in
let m = KnownMonads.from_typ dty a.typ b.typ in
(m, (dpat arm_pat, span, b)))
let g = Option.map ~f:dguard guard in
(m, (dpat arm_pat, span, b, g)))
arms
in
let arms =
Expand All @@ -177,10 +178,10 @@ struct
|> List.reduce_exn ~f:(KnownMonads.lub span)
in
List.map
~f:(fun (mself, (arm_pat, span, body)) ->
~f:(fun (mself, (arm_pat, span, body, guard)) ->
let body = KnownMonads.lift "Match" body mself.monad m in
let arm_pat = { arm_pat with typ = body.typ } in
({ arm = { arm_pat; body }; span } : B.arm))
({ arm = { arm_pat; body; guard }; span } : B.arm))
arms
in
let typ =
Expand Down
Loading
Loading