diff --git a/compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp b/compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp index 9258fdea3..df9e1950d 100644 --- a/compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp +++ b/compiler+runtime/src/cpp/jank/codegen/llvm_processor.cpp @@ -628,6 +628,8 @@ namespace jank::codegen expr::function_arity const &arity) { auto const fn_expr(boost::get>(expr.recursion_ref.fn_ctx->fn->data)); + auto const &captures(fn_expr.captures()); + /* Named recursion is a special kind of call. We can't go always through a var, since there * may not be one. We can't just use the fn's name, since we could be recursing into a * different arity. Finally, we need to keep in account whether or not this fn is a closure. @@ -635,7 +637,29 @@ namespace jank::codegen * For named recursion calls, we don't use dynamic_call. We just call the generated C fn * directly. This doesn't impede interactivity, since the whole thing will be redefined * if a new fn is created. */ - auto const is_closure(!fn_expr.captures().empty()); + auto const is_closure(!captures.empty()); + + /* We may have a named recursion in a closure which crosses another function in order to + * recurse. For example: + * + * ```clojure + * (let [a 1] + * (fn foo [] + * (fn bar [] + * (println a) + * (foo)))) + * ``` + * + * Here, the `(foo)` call is a named recursion, but we're not actually in the `foo` fn. + * We need to "cross" `bar` in order to get back into `foo`. This is an important + * distinction, since the closure context for `foo` and `bar` may be different, such + * as if `bar` closes over more data than `foo` does. + * + * In this case of a named recursion which crosses a fn, we can't use the current fn's + * closure context. We need to build a new one. */ + auto const crosses_fn( + boost::get>(&expr.recursion_ref.fn_ctx->fn->data) + != boost::get>(&arity.fn_ctx->fn->data)); llvm::SmallVector arg_handles; llvm::SmallVector arg_types; @@ -651,8 +675,43 @@ namespace jank::codegen } else if(is_closure) { - arg_handles.emplace_back(ctx->builder->GetInsertBlock()->getParent()->getArg(0)); - arg_types.emplace_back(ctx->builder->getPtrTy()); + /* TODO: If nested closures all build their contexts on their parents, we can always + * pass a nested closure upward for a named recursion. This would require sorted captures + * based on lexical scope though, which is a big jump from what we currently have. */ + if(crosses_fn) + { + auto const &fn(boost::get>(expr.recursion_ref.fn_ctx->fn->data)); + std::vector const capture_types{ captures.size(), ctx->builder->getPtrTy() }; + auto const closure_ctx_type( + get_or_insert_struct_type(fmt::format("{}_context", munge(fn.unique_name)), + capture_types)); + + auto const malloc_fn_type( + llvm::FunctionType::get(ctx->builder->getPtrTy(), { ctx->builder->getInt64Ty() }, false)); + auto const malloc_fn(ctx->module->getOrInsertFunction("GC_malloc", malloc_fn_type)); + auto const closure_obj( + ctx->builder->CreateCall(malloc_fn, { llvm::ConstantExpr::getSizeOf(closure_ctx_type) })); + + size_t index{}; + for(auto const &capture : captures) + { + auto const field_ptr( + ctx->builder->CreateStructGEP(closure_ctx_type, closure_obj, index++)); + expr::local_reference const local_ref{ + expression_base{ {}, expression_position::value, fn.frame }, + capture.first, + *capture.second + }; + ctx->builder->CreateStore(gen(local_ref, arity), field_ptr); + } + arg_handles.emplace_back(closure_obj); + arg_types.emplace_back(ctx->builder->getPtrTy()); + } + else + { + arg_handles.emplace_back(ctx->builder->GetInsertBlock()->getParent()->getArg(0)); + arg_types.emplace_back(ctx->builder->getPtrTy()); + } } for(auto const &arg_expr : expr.arg_exprs) diff --git a/compiler+runtime/test/jank/form/fn/named-recur/pass-recur-across-fn.jank b/compiler+runtime/test/jank/form/fn/named-recur/pass-recur-across-fn.jank new file mode 100644 index 000000000..38c330085 --- /dev/null +++ b/compiler+runtime/test/jank/form/fn/named-recur/pass-recur-across-fn.jank @@ -0,0 +1,7 @@ +(let* [result :success + foo (fn* foo [loop?] + ((fn* bar [] + (if loop? + (foo false) + result))))] + (foo true))