Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempts at specialization transparency, ref #57 #58

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/stabilization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,20 @@ function _stabilize_fnc(
print_name = "anonymous function"
end

args, destructurings = let
args_destructurings = map(sanitize_arg_for_stability_check, func[:args])
args, destructurings, typevars = let genwhereparam=true #(codegen_level == "min")
args_destructurings_typevars = map(
arg -> sanitize_arg_for_stability_check(arg; genwhereparam), func[:args])
(
map(first, args_destructurings),
filter(!isnothing, map(last, args_destructurings)),
map(first, args_destructurings_typevars),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error messages might need to be treated since they now show ::var"..." after every arg.

filter(!isnothing, map(adt -> adt[2], args_destructurings_typevars)),
filter(!isnothing, map(last, args_destructurings_typevars)),
)
end
kwargs = func[:kwargs]
where_params = func[:whereparams]

func[:args] = args
func[:whereparams] = (where_params..., typevars...)

arg_symbols = map(extract_symbol, args)
kwarg_symbols = map(extract_symbol, kwargs)
Expand Down
43 changes: 28 additions & 15 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,37 +54,50 @@ Amend args that do not have a symbol or are destructured in the signature. Retur
arg expression and, if needed, an equivalent destructuring assignment for the body.
"""
function sanitize_arg_for_stability_check(
ex::Symbol
)::Tuple{Union{Expr,Symbol},Union{Expr,Nothing}}
return ex, nothing
ex::Symbol; genwhereparam
)::Tuple{Union{Expr,Symbol},Union{Expr,Nothing},Union{Symbol,Nothing}}
genwhereparam || return ex, nothing, nothing
whereparam = gensym("T")
return Expr(:(::), ex, whereparam), nothing, whereparam
end
function sanitize_arg_for_stability_check(
Copy link
Owner

@MilesCranmer MilesCranmer Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that there are 3 returns, it might be nice to return a NamedTuple instead, for robustness. Returning tuples is always a bit risky because of things like

(x, y) = (1, 2, 3)

being valid Julia syntax.

e.g., could be:

return (; arg=Expr(:(::), ex, whereparam), destruct=nothing, whereparam=whereparam)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's a good idea!

ex::Expr
)::Tuple{Union{Expr,Symbol},Union{Expr,Nothing}}
ex::Expr; genwhereparam
)::Tuple{Union{Expr,Symbol},Union{Expr,Nothing},Union{Symbol,Nothing}}
head, args = ex.head, ex.args
if head == :(tuple)
# (Base case)
# matches things like (x,) and (; x)
arg = gensym("arg")
return arg, Expr(:(=), ex, arg)
arg_ex, _, whereparam = sanitize_arg_for_stability_check(arg; genwhereparam)
return arg_ex, Expr(:(=), ex, arg), whereparam
elseif head == :(::) && length(args) == 1
# (Base case)
# matches things like `::T`
arg = gensym("arg")
return Expr(head, arg, only(args)), nothing
return Expr(head, arg, only(args)), nothing, nothing
elseif head == :(...) && length(args) == 1
# (Composite case)
# matches things like `::Int...`
arg_ex, destructure_ex = sanitize_arg_for_stability_check(only(args))
return Expr(head, arg_ex), destructure_ex
elseif head in (:kw, :(::)) && length(args) == 2
arg_ex, destructure_ex = sanitize_arg_for_stability_check(
only(args); genwhereparam=false
)
return Expr(head, arg_ex), destructure_ex, nothing
elseif head == :(::) && length(args) == 2
# (Composite case)
# :(::) => matches things like `(x,)::T` and `(; x)::T`
# :kw => matches things like `::Type{T}=MyType`
arg_ex, destructure_ex = sanitize_arg_for_stability_check(first(args))
return Expr(head, arg_ex, last(args)), destructure_ex
# matches things like `(x,)::T` and `(; x)::T`
arg_ex, destructure_ex = sanitize_arg_for_stability_check(
first(args); genwhereparam=false
)
return Expr(head, arg_ex, last(args)), destructure_ex, nothing
elseif head == :kw && length(args) == 2
# (Composite case)
# matches things like `::Type{T}=MyType`
arg_ex, destructure_ex, whereparam = sanitize_arg_for_stability_check(
first(args); genwhereparam
)
return Expr(head, arg_ex, last(args)), destructure_ex, whereparam
Comment on lines +93 to +98
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this one need the genwhereparam too? It already has the ::T part, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's needed because this is also the branch that matches a regular x=default argument without a type parameter

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it happens to be a ::Type{T}=MyType argument, the inner recursive call will land in the branch that sets genwhereparam=false as required

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks!

else
return ex, nothing
return ex, nothing, nothing
end
end

Expand Down
Loading