Skip to content

Commit

Permalink
Allow mixing CPS and non-CPS functions
Browse files Browse the repository at this point in the history
- An escaping function does not need to be in CPS
- A CPS call site can call non-CPS functions
  • Loading branch information
vouillon authored and hhugo committed Jan 16, 2025
1 parent d170d27 commit 830a2ad
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 81 deletions.
23 changes: 15 additions & 8 deletions compiler/lib/effects.ml
Original file line number Diff line number Diff line change
Expand Up @@ -542,15 +542,22 @@ let rewrite_instr ~st (instr : instr) : instr =
(Extern "caml_alloc_dummy_function", [ size; Pc (Int (Targetint.succ a)) ])
)
| _ -> assert false)
| Let (x, Apply { f; args; _ }) when not (Var.Set.mem x st.cps_needed) ->
(* At the moment, we turn into CPS any function not called with
the right number of parameter *)
assert (
(* If this function is unknown to the global flow analysis, then it was
| Let (x, Apply { f; args; exact }) when not (Var.Set.mem x st.cps_needed) ->
if double_translate ()
then
let exact =
(* If this function is unknown to the global flow analysis, then it was
introduced by the lambda lifting and we don't have exactness info any more. *)
Var.idx f >= Var.Tbl.length st.flow_info.info_approximation
|| Global_flow.exact_call st.flow_info f (List.length args));
Let (x, Apply { f; args; exact = true })
exact
|| Var.idx f < Var.Tbl.length st.flow_info.info_approximation
&& Global_flow.exact_call st.flow_info f (List.length args)
in
Let (x, Apply { f; args; exact })
else (
(* At the moment, we turn into CPS any function not called with
the right number of parameter *)
assert (Global_flow.exact_call st.flow_info f (List.length args));
Let (x, Apply { f; args; exact = true }))
| Let (_, e) when effect_primitive_or_application e ->
(* For the CPS target, applications of CPS functions and effect primitives require
more work (allocating a continuation and/or modifying end-of-block branches) and
Expand Down
19 changes: 16 additions & 3 deletions compiler/lib/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,6 @@ let parallel_renaming loc back_edge params args continuation queue =
let apply_fun_raw =
let cps_field = Utf8_string.of_string_exn "cps" in
fun ctx f params exact trampolined cps loc ->
let n = List.length params in
let apply_directly f params =
(* Make sure we are performing a regular call, not a (slower)
method call *)
Expand All @@ -928,7 +927,7 @@ let apply_fun_raw =
J.call (J.dot f (Utf8_string.of_string_exn "call")) (s_var "null" :: params) loc
| _ -> J.call f params loc
in
let apply =
let apply ~cps f params =
(* Adapt if [f] is a (direct-style, CPS) closure pair *)
let real_closure =
match Config.effects () with
Expand All @@ -954,7 +953,7 @@ let apply_fun_raw =
( J.Eq
, J.dot real_closure l
, J.dot real_closure (Utf8_string.of_string_exn "length") ) )
, int n )
, int (List.length params) )
, apply_directly real_closure params
, J.call
(* Note: when double translation is enabled, [caml_call_gen*] functions takes a two-version function *)
Expand All @@ -967,6 +966,20 @@ let apply_fun_raw =
[ f; J.array params ]
J.N )
in
let apply =
match Config.effects () with
| `Double_translation when cps ->
let n = List.length params in
J.ECond
( J.EDot (f, J.ANormal, cps_field)
, apply ~cps:true f params
, J.call
(List.nth params (n - 1))
[ apply ~cps:false f (fst (List.take (n - 1) params)) ]
J.N )
| `Double_translation | `Cps | `Disabled -> apply ~cps f params
| `Jspi -> assert false
in
if trampolined
then (
assert (cps_transform ());
Expand Down
10 changes: 9 additions & 1 deletion compiler/lib/partial_cps_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ open! Stdlib

let times = Debug.find "times"

let double_translate () =
match Config.effects () with
| `Disabled | `Jspi -> assert false
| `Cps -> false
| `Double_translation -> true

open Code

let add_var = Var.ISet.add
Expand Down Expand Up @@ -85,7 +91,7 @@ let block_deps ~info ~vars ~tail_deps ~deps ~blocks ~fun_name pc =
add_dep deps x g;
(* Conversally, if a call point is in CPS then all
called functions must be in CPS *)
add_dep deps g x)
if not (double_translate ()) then add_dep deps g x)
known)
| Let (x, Prim (Extern ("%perform" | "%reperform" | "%resume"), _)) -> (
add_var vars x;
Expand Down Expand Up @@ -145,6 +151,8 @@ let cps_needed ~info ~in_mutual_recursion ~rev_deps st x =
| Top -> true
| Values { others; _ } -> others)
| Expr (Closure _) ->
(not (double_translate ()))
&&
(* If a function escapes, it must be in CPS *)
Var.ISet.mem info.Global_flow.info_may_escape x
| Expr (Prim (Extern ("%perform" | "%reperform" | "%resume"), _)) ->
Expand Down
114 changes: 53 additions & 61 deletions compiler/tests-compiler/double-translation/direct_calls.ml
Original file line number Diff line number Diff line change
Expand Up @@ -86,64 +86,66 @@ let%expect_test "direct calls with --effects=double-translation" =
}
function caml_trampoline_cps_call2(f, a0, a1){
return runtime.caml_stack_check_depth()
? (f.cps.l
>= 0
? f.cps.l
: f.cps.l = f.cps.length)
=== 2
? f.cps.call(null, a0, a1)
: runtime.caml_call_gen_cps(f, [a0, a1])
? f.cps
? (f.cps.l
>= 0
? f.cps.l
: f.cps.l = f.cps.length)
=== 2
? f.cps.call(null, a0, a1)
: runtime.caml_call_gen_cps(f, [a0, a1])
: a1
((f.l >= 0 ? f.l : f.l = f.length) === 1
? f(a0)
: runtime.caml_call_gen(f, [a0]))
: runtime.caml_trampoline_return(f, [a0, a1], 0);
}
function caml_exact_trampoline_cps_call(f, a0, a1){
return runtime.caml_stack_check_depth()
? f.cps.call(null, a0, a1)
? f.cps ? f.cps.call(null, a0, a1) : a1(f(a0))
: runtime.caml_trampoline_return(f, [a0, a1], 0);
}
function caml_trampoline_cps_call3(f, a0, a1, a2){
return runtime.caml_stack_check_depth()
? (f.cps.l
>= 0
? f.cps.l
: f.cps.l = f.cps.length)
=== 3
? f.cps.call(null, a0, a1, a2)
: runtime.caml_call_gen_cps(f, [a0, a1, a2])
? f.cps
? (f.cps.l
>= 0
? f.cps.l
: f.cps.l = f.cps.length)
=== 3
? f.cps.call(null, a0, a1, a2)
: runtime.caml_call_gen_cps(f, [a0, a1, a2])
: a2
((f.l >= 0 ? f.l : f.l = f.length) === 2
? f(a0, a1)
: runtime.caml_call_gen(f, [a0, a1]))
: runtime.caml_trampoline_return(f, [a0, a1, a2], 0);
}
function caml_exact_trampoline_cps_call$0(f, a0, a1, a2){
return runtime.caml_stack_check_depth()
? f.cps.call(null, a0, a1, a2)
? f.cps ? f.cps.call(null, a0, a1, a2) : a2(f(a0, a1))
: runtime.caml_trampoline_return(f, [a0, a1, a2], 0);
}
var
dummy = 0,
global_data = runtime.caml_get_global_data(),
_D_ = [0, [4, 0, 0, 0, 0], caml_string_of_jsbytes("%d")],
_z_ = [0, [4, 0, 0, 0, 0], caml_string_of_jsbytes("%d")],
cst_a$0 = caml_string_of_jsbytes("a"),
cst_a = caml_string_of_jsbytes("a"),
Stdlib = global_data.Stdlib,
Stdlib_Printf = global_data.Stdlib__Printf;
function f$1(){
function test1(param){
function f(g, x){
try{caml_call1(g, dummy); return;}
catch(e$0){
var e = caml_wrap_exception(e$0);
throw caml_maybe_attach_backtrace(e, 0);
}
}
return f;
}
function _d_(){return function(x){};}
function _f_(){return function(x){};}
function test1$0(param){var f = f$1(); f(_d_()); f(_f_()); return 0;}
function test1$1(param, cont){
var f = f$1();
f(_d_());
f(_f_());
return cont(0);
f(function(x){});
f(function(x){});
return 0;
}
var test1 = caml_cps_closure(test1$0, test1$1);
function f$0(){
function f$0(g, x){
try{caml_call1(g, x); return;}
Expand All @@ -159,15 +161,13 @@ let%expect_test "direct calls with --effects=double-translation" =
return raise(e);
});
return caml_exact_trampoline_cps_call
(g, x, function(_P_){caml_pop_trap(); return cont();});
(g, x, function(_K_){caml_pop_trap(); return cont();});
}
var f = caml_cps_closure(f$0, f$1);
return f;
}
function _k_(){
return caml_cps_closure(function(x){}, function(x, cont){return cont();});
}
function _m_(){
function _h_(){return function(x){};}
function _j_(){
return caml_cps_closure
(function(x){return caml_call2(Stdlib[28], x, cst_a$0);},
function(x, cont){
Expand All @@ -176,39 +176,31 @@ let%expect_test "direct calls with --effects=double-translation" =
}
function test2$0(param){
var f = f$0();
f(_k_(), 7);
f(_m_(), cst_a);
f(_h_(), 7);
f(_j_(), cst_a);
return 0;
}
function test2$1(param, cont){
var f = f$0();
return caml_exact_trampoline_cps_call$0
(f,
_k_(),
_h_(),
7,
function(_N_){
function(_I_){
return caml_exact_trampoline_cps_call$0
(f, _m_(), cst_a, function(_O_){return cont(0);});
(f, _j_(), cst_a, function(_J_){return cont(0);});
});
}
var test2 = caml_cps_closure(test2$0, test2$1);
function F$0(){
function test3(x){
function F(symbol){function f(x){return x + 1 | 0;} return [0, f];}
return F;
}
function test3$0(x){
var F = F$0(), M1 = F(), M2 = F(), _M_ = caml_call1(M2[1], 2);
return [0, caml_call1(M1[1], 1), _M_];
}
function test3$1(x, cont){
var F = F$0(), M1 = F(), M2 = F(), _L_ = M2[1].call(null, 2);
return cont([0, M1[1].call(null, 1), _L_]);
var M1 = F(), M2 = F(), _H_ = caml_call1(M2[1], 2);
return [0, caml_call1(M1[1], 1), _H_];
}
var test3 = caml_cps_closure(test3$0, test3$1);
function f(){
function f$0(x){return caml_call2(Stdlib_Printf[2], _D_, x);}
function f$0(x){return caml_call2(Stdlib_Printf[2], _z_, x);}
function f$1(x, cont){
return caml_trampoline_cps_call3(Stdlib_Printf[2], _D_, x, cont);
return caml_trampoline_cps_call3(Stdlib_Printf[2], _z_, x, cont);
}
var f = caml_cps_closure(f$0, f$1);
return f;
Expand All @@ -224,7 +216,7 @@ let%expect_test "direct calls with --effects=double-translation" =
return caml_exact_trampoline_cps_call
(M1[1],
1,
function(_K_){
function(_G_){
return caml_exact_trampoline_cps_call(M2[1], 2, cont);
});
}
Expand All @@ -241,18 +233,18 @@ let%expect_test "direct calls with --effects=double-translation" =
tuple = recfuncs(x),
f = tuple[2],
h = tuple[1],
_I_ = h(100),
_J_ = f(12) + _I_ | 0;
return caml_call1(Stdlib[44], _J_);
_E_ = h(100),
_F_ = f(12) + _E_ | 0;
return caml_call1(Stdlib[44], _F_);
}
function g$1(x, cont){
var
tuple = recfuncs(x),
f = tuple[2],
h = tuple[1],
_G_ = h(100),
_H_ = f(12) + _G_ | 0;
return caml_trampoline_cps_call2(Stdlib[44], _H_, cont);
_C_ = h(100),
_D_ = f(12) + _C_ | 0;
return caml_trampoline_cps_call2(Stdlib[44], _D_, cont);
}
var g = caml_cps_closure(g$0, g$1);
return g;
Expand All @@ -263,9 +255,9 @@ let%expect_test "direct calls with --effects=double-translation" =
return caml_exact_trampoline_cps_call
(g$0,
42,
function(_E_){
function(_A_){
return caml_exact_trampoline_cps_call
(g$0, - 5, function(_F_){return cont(0);});
(g$0, - 5, function(_B_){return cont(0);});
});
}
var
Expand Down
19 changes: 12 additions & 7 deletions compiler/tests-compiler/double-translation/effects_toplevel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ let%expect_test "test-compiler/lib-effects/test1.ml" =
}
function caml_trampoline_cps_call2(f, a0, a1){
return runtime.caml_stack_check_depth()
? (f.cps.l
>= 0
? f.cps.l
: f.cps.l = f.cps.length)
=== 2
? f.cps.call(null, a0, a1)
: runtime.caml_call_gen_cps(f, [a0, a1])
? f.cps
? (f.cps.l
>= 0
? f.cps.l
: f.cps.l = f.cps.length)
=== 2
? f.cps.call(null, a0, a1)
: runtime.caml_call_gen_cps(f, [a0, a1])
: a1
((f.l >= 0 ? f.l : f.l = f.length) === 1
? f(a0)
: runtime.caml_call_gen(f, [a0]))
: runtime.caml_trampoline_return(f, [a0, a1], 0);
}
var
Expand Down
5 changes: 4 additions & 1 deletion runtime/js/effect.js
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,10 @@ function caml_get_cps_fun(f) {
//If: effects
//If: doubletranslate
function caml_get_cps_fun(f) {
return f.cps;
// This function is only used to get the effect handler. If the
// effect handler has no CPS function, we know that we can directly
// call the direct version instead.
return f.cps ? f.cps : f;
}

//Provides: caml_alloc_stack
Expand Down
4 changes: 4 additions & 0 deletions runtime/js/stdlib.js
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ var caml_call_gen_tuple = (function () {
}
}
function caml_call_gen_cps(f, args) {
if (!f.cps) {
var k = args.pop();
return k(caml_call_gen_direct(f, args));
}
var n = f.cps.l >= 0 ? f.cps.l : (f.cps.l = f.cps.length);
var argsLen = args.length;
var d = n - argsLen;
Expand Down

0 comments on commit 830a2ad

Please sign in to comment.