diff --git a/src/stabilization.jl b/src/stabilization.jl index 5f21a0c..b35b37e 100644 --- a/src/stabilization.jl +++ b/src/stabilization.jl @@ -227,17 +227,20 @@ function _stabilize_fnc( print_name = "anonymous function" end - args, destructurings = let - args_destructurings = map(sanitize_arg_for_stability_check, func[:args]) + args, destructurings, typevars = let genwhereparam=true #(codegen_level == "min") + args_destructurings_typevars = map( + arg -> sanitize_arg_for_stability_check(arg; genwhereparam), func[:args]) ( - map(first, args_destructurings), - filter(!isnothing, map(last, args_destructurings)), + map(first, args_destructurings_typevars), + filter(!isnothing, map(adt -> adt[2], args_destructurings_typevars)), + filter(!isnothing, map(last, args_destructurings_typevars)), ) end kwargs = func[:kwargs] where_params = func[:whereparams] func[:args] = args + func[:whereparams] = (where_params..., typevars...) arg_symbols = map(extract_symbol, args) kwarg_symbols = map(extract_symbol, kwargs) diff --git a/src/utils.jl b/src/utils.jl index 715dba6..59919c5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -54,37 +54,50 @@ Amend args that do not have a symbol or are destructured in the signature. Retur arg expression and, if needed, an equivalent destructuring assignment for the body. """ function sanitize_arg_for_stability_check( - ex::Symbol -)::Tuple{Union{Expr,Symbol},Union{Expr,Nothing}} - return ex, nothing + ex::Symbol; genwhereparam +)::Tuple{Union{Expr,Symbol},Union{Expr,Nothing},Union{Symbol,Nothing}} + genwhereparam || return ex, nothing, nothing + whereparam = gensym("T") + return Expr(:(::), ex, whereparam), nothing, whereparam end function sanitize_arg_for_stability_check( - ex::Expr -)::Tuple{Union{Expr,Symbol},Union{Expr,Nothing}} + ex::Expr; genwhereparam +)::Tuple{Union{Expr,Symbol},Union{Expr,Nothing},Union{Symbol,Nothing}} head, args = ex.head, ex.args if head == :(tuple) # (Base case) # matches things like (x,) and (; x) arg = gensym("arg") - return arg, Expr(:(=), ex, arg) + arg_ex, _, whereparam = sanitize_arg_for_stability_check(arg; genwhereparam) + return arg_ex, Expr(:(=), ex, arg), whereparam elseif head == :(::) && length(args) == 1 # (Base case) # matches things like `::T` arg = gensym("arg") - return Expr(head, arg, only(args)), nothing + return Expr(head, arg, only(args)), nothing, nothing elseif head == :(...) && length(args) == 1 # (Composite case) # matches things like `::Int...` - arg_ex, destructure_ex = sanitize_arg_for_stability_check(only(args)) - return Expr(head, arg_ex), destructure_ex - elseif head in (:kw, :(::)) && length(args) == 2 + arg_ex, destructure_ex = sanitize_arg_for_stability_check( + only(args); genwhereparam=false + ) + return Expr(head, arg_ex), destructure_ex, nothing + elseif head == :(::) && length(args) == 2 # (Composite case) - # :(::) => matches things like `(x,)::T` and `(; x)::T` - # :kw => matches things like `::Type{T}=MyType` - arg_ex, destructure_ex = sanitize_arg_for_stability_check(first(args)) - return Expr(head, arg_ex, last(args)), destructure_ex + # matches things like `(x,)::T` and `(; x)::T` + arg_ex, destructure_ex = sanitize_arg_for_stability_check( + first(args); genwhereparam=false + ) + return Expr(head, arg_ex, last(args)), destructure_ex, nothing + elseif head == :kw && length(args) == 2 + # (Composite case) + # matches things like `::Type{T}=MyType` + arg_ex, destructure_ex, whereparam = sanitize_arg_for_stability_check( + first(args); genwhereparam + ) + return Expr(head, arg_ex, last(args)), destructure_ex, whereparam else - return ex, nothing + return ex, nothing, nothing end end