Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type variables and quantifiers to the AST #775

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions compiler/dcalc/from_scopelang.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ type 'm ctx = {
date_rounding : date_rounding;
}

let mark_tany m pos = Expr.with_ty m (Mark.add pos TAny) ~pos
let mark_tany m pos = Expr.with_ty m (Type.any pos) ~pos

(* Expression argument is used as a type witness, its type and positions aren't
used *)
let pos_mark_mk (type a m) (e : (a, m) gexpr) :
(Pos.t -> m mark) * ((_, Pos.t) Mark.ed -> m mark) =
let pos_mark pos =
Expr.map_mark (fun _ -> pos) (fun _ -> TAny, pos) (Mark.get e)
Expr.map_mark (fun _ -> pos) (fun _ -> Type.any pos) (Mark.get e)
in
let pos_mark_as e = pos_mark (Mark.get e) in
pos_mark, pos_mark_as
Expand Down Expand Up @@ -129,7 +129,7 @@ let tag_with_log_entry

if Global.options.trace <> None then
let pos = Expr.pos e in
Expr.eappop ~op:(Log (l, markings), pos) ~tys:[TAny, pos] ~args:[e] m
Expr.eappop ~op:(Log (l, markings), pos) ~tys:[Type.any pos] ~args:[e] m
else e

(* In a list of exceptions, it is normally an error if more than a single one
Expand Down Expand Up @@ -458,9 +458,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
https://github.com/CatalaLang/catala/pull/280#discussion_r898851693. *)
let retrieve_out_typ_or_any var vars =
let _, typ, _ = ScopeVar.Map.find (Mark.remove var) vars in
match typ with
| TArrow (_, marked_output_typ) -> Mark.remove marked_output_typ
| _ -> TAny
match Type.unquantify (typ, Expr.pos f) with
| TArrow (_, marked_output_typ), _ -> Mark.remove marked_output_typ
| _, pos -> Mark.remove (Type.any pos)
in
match Mark.remove f with
| ELocation (ScopelangScopeVar { name = var }) ->
Expand All @@ -474,7 +474,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm S.expr) : 'm Ast.expr boxed =
| _ ->
Message.error ~pos:(Expr.pos e)
"Application of non-function toplevel variable")
| _ -> TAny
| _ -> Mark.remove (Type.any (Expr.pos f))
in
(* Message.debug "new_args %d, input_typs: %d, input_typs %a" (List.length
new_args) (List.length input_typs) (Format.pp_print_list Print.typ_debug)
Expand Down
4 changes: 2 additions & 2 deletions compiler/dcalc/invariants.ml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ let rec check_typ_no_default ctx ty =
List.for_all (check_typ_no_default ctx) args && check_typ_no_default ctx res
| TArray ty -> check_typ_no_default ctx ty
| TDefault _t -> false
| TAny ->
| TVar _ | TAny _ ->
Message.error ~internal:true
"Some Dcalc invariants are invalid: TAny was found whereas it should be \
fully resolved."
Expand Down Expand Up @@ -188,7 +188,7 @@ let invariant_typing_defaults () : string * invariant_expr =
fun ctx e ->
if check_type_root ctx (Expr.ty e) then Pass
else (
Message.warning "typing error %a@." (Print.typ ctx) (Expr.ty e);
Message.warning "typing error %a@." Print.typ (Expr.ty e);
Fail) )

let check_all_invariants prgm =
Expand Down
92 changes: 49 additions & 43 deletions compiler/desugared/from_surface.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@ let translate_binop :
~args:[lhs; rhs]
(Untyped { pos })
in
let tany () = Mark.remove (Type.any op_pos) in
match op with
| S.And -> op_expr And [TLit TBool; TLit TBool]
| S.Or -> op_expr Or [TLit TBool; TLit TBool]
| S.Xor -> op_expr Xor [TLit TBool; TLit TBool]
| S.Add k ->
op_expr Add
(match k with
| S.KPoly -> [TAny; TAny]
| S.KPoly -> [tany (); tany ()]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TMoney]
Expand All @@ -63,7 +64,7 @@ let translate_binop :
| S.Sub k ->
op_expr Sub
(match k with
| S.KPoly -> [TAny; TAny]
| S.KPoly -> [tany (); tany ()]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TMoney]
Expand All @@ -72,7 +73,7 @@ let translate_binop :
| S.Mult k ->
op_expr Mult
(match k with
| S.KPoly -> [TAny; TAny]
| S.KPoly -> [tany (); tany ()]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TRat]
Expand All @@ -83,7 +84,7 @@ let translate_binop :
| S.Div k ->
op_expr Div
(match k with
| S.KPoly -> [TAny; TAny]
| S.KPoly -> [tany (); tany ()]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TMoney]
Expand All @@ -100,17 +101,22 @@ let translate_binop :
| S.Gte _ -> Gte
| _ -> assert false)
(match k with
| S.KPoly -> [TAny; TAny]
| S.KPoly ->
let a = tany () in
[a; a]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TMoney]
| S.KDate -> [TLit TDate; TLit TDate]
| S.KDuration -> [TLit TDuration; TLit TDuration])
| S.Eq ->
op_expr Eq [TAny; TAny]
let a = tany () in
op_expr Eq [a; a]
(* This is a truly polymorphic operator, not an overload *)
| S.Neq -> assert false (* desugared already *)
| S.Concat -> op_expr Concat [TArray (TAny, op_pos); TArray (TAny, op_pos)]
| S.Concat ->
let a = Type.any op_pos in
op_expr Concat [TArray a; TArray a]

let translate_unop ((op, op_pos) : S.unop Mark.pos) pos arg : Ast.expr boxed =
let op_expr op ty =
Expand All @@ -124,7 +130,7 @@ let translate_unop ((op, op_pos) : S.unop Mark.pos) pos arg : Ast.expr boxed =
| S.Minus k ->
op_expr Minus
(match k with
| S.KPoly -> TAny
| S.KPoly -> Mark.remove (Type.any op_pos)
| S.KInt -> TLit TInt
| S.KDec -> TLit TRat
| S.KMoney -> TLit TMoney
Expand Down Expand Up @@ -256,14 +262,14 @@ let rec translate_expr
let rhs = zip names r in
let rtys, explode =
match List.length r with
| 1 -> (TAny, pos), fun e -> [e]
| 1 -> Type.any pos, fun e -> [e]
| size ->
( (TTuple (List.map (fun _ -> TAny, pos) r), pos),
( (TTuple (List.map (fun _ -> Type.any pos) r), pos),
fun e ->
List.init size (fun index ->
Expr.etupleaccess ~e ~size ~index m) )
in
let tys = [TAny, pos; rtys] in
let tys = [Type.any pos; rtys] in
let f_join =
let x1 = Var.make name1 in
let x2 =
Expand All @@ -275,7 +281,7 @@ let rec translate_expr
tys pos
in
Expr.eappop ~op:(Map2, opos) ~args:[f_join; l1; rhs]
~tys:((TAny, pos) :: List.map (fun ty -> TArray ty, pos) tys)
~tys:(Type.any pos :: List.map (fun ty -> TArray ty, pos) tys)
m
in
zip names ls
Expand Down Expand Up @@ -501,11 +507,11 @@ let rec translate_expr
| FunCall ((Builtin b, pos), [arg]) ->
let op, ty =
match b with
| S.ToInteger -> Op.ToInt, TAny
| S.ToDecimal -> Op.ToRat, TAny
| S.ToMoney -> Op.ToMoney, TAny
| S.Round -> Op.Round, TAny
| S.Cardinal -> Op.Length, TArray (TAny, pos)
| S.ToInteger -> Op.ToInt, Mark.remove (Type.any pos)
| S.ToDecimal -> Op.ToRat, Mark.remove (Type.any pos)
| S.ToMoney -> Op.ToMoney, Mark.remove (Type.any pos)
| S.Round -> Op.Round, Mark.remove (Type.any pos)
| S.Cardinal -> Op.Length, TArray (Type.any pos)
| S.GetDay -> Op.GetDay, TLit TDate
| S.GetMonth -> Op.GetMonth, TLit TDate
| S.GetYear -> Op.GetYear, TLit TDate
Expand Down Expand Up @@ -568,7 +574,7 @@ let rec translate_expr
Ident.Map.add (Mark.remove x) (Mark.remove v) local_vars)
local_vars xs m_xs
in
let taus = List.map (fun x -> TAny, Mark.get x) xs in
let taus = List.map (fun x -> Type.any (Mark.get x)) xs in
(* This type will be resolved in Scopelang.Desambiguation *)
let f = Expr.make_abs m_xs (rec_helper ~local_vars e2) taus pos in
Expr.eapp ~f ~args:[rec_helper e1] ~tys:[] emark
Expand Down Expand Up @@ -760,7 +766,7 @@ let rec translate_expr
let f_pred =
Expr.make_abs params
(rec_helper ~local_vars predicate)
(List.map (fun _ -> TAny, pos) params)
(List.map (fun _ -> Type.any pos) params)
pos
in
let f_pred =
Expand All @@ -773,14 +779,14 @@ let rec translate_expr
Var.make (String.concat "_" (List.map Mark.remove param_names))
in
let x = Expr.evar v emark in
let tys = List.map (fun _ -> TAny, pos) param_names in
let tys = List.map (fun _ -> Type.any pos) param_names in
Expr.make_abs
[Mark.add Pos.no_pos v]
(Expr.make_app f_pred
(List.init nb_args (fun i ->
Expr.etupleaccess ~e:x ~index:i ~size:nb_args emark))
tys pos)
[TAny, pos]
[Type.any pos]
pos
in
Expr.eappop
Expand All @@ -789,7 +795,7 @@ let rec translate_expr
| S.Map _, pos -> Map, pos
| S.Filter _, pos -> Filter, pos
| _ -> assert false)
~tys:[TAny, pos; TAny, pos]
~tys:[Type.any pos; Type.any pos]
~args:[f_pred; collection] emark
| CollectionOp ((Fold { f; init }, opos), collection) ->
let acc_names, param_names, fct = f in
Expand All @@ -812,7 +818,7 @@ let rec translate_expr
let f_proc =
Expr.make_abs (accs @ params)
(rec_helper ~local_vars fct)
(List.map (fun _ -> TAny, pos) (accs @ params))
(List.map (fun _ -> Type.any pos) (accs @ params))
pos
in
let f_proc =
Expand All @@ -832,7 +838,7 @@ let rec translate_expr
in
let x_acc = Expr.evar v_acc emark in
let x_param = Expr.evar v_param emark in
let tys = List.init (nb_accs + nb_args) (fun _ -> TAny, pos) in
let tys = List.init (nb_accs + nb_args) (fun _ -> Type.any pos) in
Expr.make_ghost_abs [v_acc; v_param]
(Expr.make_app f_proc
((if nb_accs = 1 then [x_acc]
Expand All @@ -849,11 +855,11 @@ let rec translate_expr
Expr.etupleaccess ~e:x_param ~index ~size:nb_args emark)
params)
tys pos)
[TAny, pos; TAny, pos]
[Type.any pos; Type.any pos]
pos
in
Expr.eappop ~op:(Fold, opos)
~tys:[TAny, pos; TAny, pos; TAny, pos]
~tys:[Type.any pos; Type.any pos; Type.any pos]
~args:[f_proc; init; collection] emark
| CollectionOp
( ( S.AggregateArgExtremum { max; default; f = param_names, predicate },
Expand All @@ -872,7 +878,7 @@ let rec translate_expr
in
let cmp_op = if max then Op.Gt, opos else Op.Lt, opos in
let f_pred =
Expr.make_abs params (rec_helper ~local_vars predicate) [TAny, pos] pos
Expr.make_abs params (rec_helper ~local_vars predicate) [Type.any pos] pos
in
let add_weight_f =
let vs =
Expand All @@ -882,7 +888,7 @@ let rec translate_expr
let x = match xs with [x] -> x | xs -> Expr.etuple xs emark in
Expr.make_ghost_abs vs
(Expr.make_tuple [x; Expr.eapp ~f:f_pred ~args:xs ~tys:[] emark] emark)
[TAny, pos]
[Type.any pos]
pos
in
let reduce_f =
Expand All @@ -892,27 +898,27 @@ let rec translate_expr
Expr.make_ghost_abs [v1; v2]
(Expr.eifthenelse
(Expr.eappop ~op:cmp_op
~tys:[TAny, pos_dft; TAny, pos_dft]
~tys:[Type.any pos_dft; Type.any pos_dft]
~args:
[
Expr.etupleaccess ~e:x1 ~index:1 ~size:2 emark;
Expr.etupleaccess ~e:x2 ~index:1 ~size:2 emark;
]
emark)
x1 x2 emark)
[TAny, pos; TAny, pos]
[Type.any pos; Type.any pos]
pos
in
let weights_var = Var.make "weights" in
let default = Expr.make_app add_weight_f [default] [TAny, pos] pos_dft in
let default = Expr.make_app add_weight_f [default] [Type.any pos] pos_dft in
let weighted_result =
Expr.make_let_in (Mark.ghost weights_var)
(TArray (TTuple [TAny, pos; TAny, pos], pos), pos)
(TArray (TTuple [Type.any pos; Type.any pos], pos), pos)
(Expr.eappop ~op:(Map, opos)
~tys:[TAny, pos; TArray (TAny, pos), pos]
~tys:[Type.any pos; TArray (Type.any pos), pos]
~args:[add_weight_f; collection] emark)
(Expr.eappop ~op:(Reduce, opos)
~tys:[TAny, pos; TAny, pos; TAny, pos]
~tys:[Type.any pos; Type.any pos; Type.any pos]
~args:[reduce_f; default; Expr.evar weights_var emark]
emark)
pos
Expand Down Expand Up @@ -950,10 +956,10 @@ let rec translate_expr
(Array.of_list (List.map Mark.remove vs))
(translate_binop op pos acc (rec_helper ~local_vars predicate))
in
Expr.eabs mvars vs_marks [TAny, pos; TAny, pos] emark
Expr.eabs mvars vs_marks [Type.any pos; Type.any pos] emark
in
Expr.eappop ~op:(Fold, opos)
~tys:[TAny, pos; TAny, pos; TAny, pos]
~tys:[Type.any pos; Type.any pos; Type.any pos]
~args:[f; init; collection] emark
| CollectionOp ((AggregateExtremum { max; default }, opos), collection) ->
let collection = rec_helper collection in
Expand All @@ -967,11 +973,11 @@ let rec translate_expr
let x2 = Expr.make_var v2 emark in
Expr.make_ghost_abs [v1; v2]
(Expr.eifthenelse (translate_binop (op, pos) pos x1 x2) x1 x2 emark)
[TAny, pos; TAny, pos]
[Type.any pos; Type.any pos]
pos
in
Expr.eappop ~op:(Reduce, opos)
~tys:[TAny, pos; TAny, pos; TAny, pos]
~tys:[Type.any pos; Type.any pos; Type.any pos]
~args:[op_f; default; collection]
emark
| CollectionOp ((AggregateSum { typ }, opos), collection) ->
Expand All @@ -997,11 +1003,11 @@ let rec translate_expr
let x2 = Expr.make_var v2 emark in
Expr.make_ghost_abs [v1; v2]
(translate_binop (S.Add KPoly, opos) pos x1 x2)
[TAny, pos; TAny, pos]
[Type.any pos; Type.any pos]
pos
in
Expr.eappop ~op:(Reduce, opos)
~tys:[TAny, pos; TAny, pos; TAny, pos]
~tys:[Type.any pos; Type.any pos; Type.any pos]
~args:[op_f; Expr.elit default_lit emark; collection]
emark
| CollectionOp ((Member { element = member }, opos), collection) ->
Expand All @@ -1018,7 +1024,7 @@ let rec translate_expr
~args:
[
Expr.eappop ~op:(Eq, opos)
~tys:[TAny, pos; TAny, pos]
~tys:[Type.any pos; Type.any pos]
~args:[member; param] emark;
acc;
]
Expand All @@ -1029,11 +1035,11 @@ let rec translate_expr
Expr.eabs
(Expr.bind (Array.of_list (List.map Mark.remove vars)) f_body)
(List.map Mark.get vars)
[TLit TBool, pos; TAny, pos]
[TLit TBool, pos; Type.any pos]
emark
in
Expr.eappop ~op:(Fold, opos)
~tys:[TAny, pos; TAny, pos; TAny, pos]
~tys:[Type.any pos; Type.any pos; Type.any pos]
~args:[f; init; collection] emark

and disambiguate_match_and_build_expression
Expand Down
4 changes: 2 additions & 2 deletions compiler/driver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ module Commands = struct
match ex_scope_opt with
| Some scope ->
let scope_uid = get_scope_uid prg.program_ctx scope in
Scopelang.Print.scope ~debug:options.Global.debug prg.program_ctx fmt
Scopelang.Print.scope ~debug:options.Global.debug fmt
(scope_uid, ScopeName.Map.find scope_uid prg.program_scopes);
Format.pp_print_newline fmt ()
| None ->
Expand Down Expand Up @@ -895,7 +895,7 @@ module Commands = struct
let prg, type_ordering, _ =
Passes.lcalc options ~includes ~optimize ~check_invariants ~autotest
~typed:Expr.typed ~closure_conversion ~keep_special_ops:true
~monomorphize_types:false ~expand_ops:true
~monomorphize_types:false ~expand_ops:false
~renaming:(Some Lcalc.To_ocaml.renaming)
in
let output_file, with_output =
Expand Down
Loading
Loading