Skip to content

Commit

Permalink
Move @noinline into generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Oct 9, 2024
1 parent 622b736 commit ccef7f6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 28 deletions.
30 changes: 15 additions & 15 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ SCT = SparseConnectivityTracer

## 1-to-1

@noinline function gradient_tracer_1_to_1(
t::T, is_der1_zero::Bool
) where {T<:GradientTracer}
function gradient_tracer_1_to_1(t::T, is_der1_zero::Bool) where {T<:GradientTracer}
if is_der1_zero && !isemptytracer(t)
return myempty(T)
else
Expand Down Expand Up @@ -36,7 +34,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function)

expr_gradienttracer = quote
function $M.$fname(t::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g)
end
end

Expand All @@ -55,7 +53,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function)

t = $SCT.tracer(d)
is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x)
t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(t, is_der1_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -65,7 +63,7 @@ end

## 2-to-1

@noinline function gradient_tracer_2_to_1(
function gradient_tracer_2_to_1(
tx::T, ty::T, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool
) where {T<:GradientTracer}
# TODO: add tests for isempty
Expand Down Expand Up @@ -116,7 +114,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function)

expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer}
return $SCT.gradient_tracer_2_to_1(
return @noinline $SCT.gradient_tracer_2_to_1(
tx, ty, $is_der1_arg1_zero_g, $is_der1_arg2_zero_g
)
end
Expand All @@ -141,7 +139,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function)
ty = $SCT.tracer(dy)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_2_to_1(
t_out = @noinline $SCT.gradient_tracer_2_to_1(
tx, ty, is_der1_arg1_zero, is_der1_arg2_zero
)
return $SCT.Dual(p_out, t_out)
Expand All @@ -164,12 +162,12 @@ function generate_code_gradient_2_to_1_typed(

expr_tracer_type = quote
function $M.$fname(tx::$SCT.GradientTracer, ::$Z)
return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
end
end
expr_type_tracer = quote
function $M.$fname(::$Z, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
end
end

Expand All @@ -188,7 +186,7 @@ function generate_code_gradient_2_to_1_typed(

tx = $SCT.tracer(dx)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -208,7 +206,7 @@ function generate_code_gradient_2_to_1_typed(

ty = $SCT.tracer(dy)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -218,7 +216,7 @@ end

## 1-to-2

@noinline function gradient_tracer_1_to_2(
function gradient_tracer_1_to_2(
t::T, is_der1_out1_zero::Bool, is_der1_out2_zero::Bool
) where {T<:GradientTracer}
if isemptytracer(t) # TODO: add test
Expand All @@ -237,7 +235,9 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function)

expr_gradienttracer = quote
function $M.$fname(t::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_2(t, $is_der1_out1_zero_g, $is_der1_out2_zero_g)
return @noinline $SCT.gradient_tracer_1_to_2(
t, $is_der1_out1_zero_g, $is_der1_out2_zero_g
)
end
end

Expand All @@ -257,7 +257,7 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function)
t = $SCT.tracer(d)
is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x)
is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x)
t_out1, t_out2 = $SCT.gradient_tracer_1_to_2(
t_out1, t_out2 = @noinline $SCT.gradient_tracer_1_to_2(
t, is_der1_out1_zero, is_der1_out2_zero
)
return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test
Expand Down
34 changes: 21 additions & 13 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SCT = SparseConnectivityTracer
# 𝟙[∇γ] = 𝟙[∂φ]⋅𝟙[∇α]
# 𝟙[∇²γ] = 𝟙[∂φ]⋅𝟙[∇²α] ∨ 𝟙[∂²φ]⋅(𝟙[∇α] ∨ 𝟙[∇α]ᵀ)

@noinline function hessian_tracer_1_to_1(
function hessian_tracer_1_to_1(
t::T, is_der1_zero::Bool, is_der2_zero::Bool
) where {P<:AbstractHessianPattern,T<:HessianTracer{P}}
if isemptytracer(t) # TODO: add test
Expand Down Expand Up @@ -65,7 +65,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function)
expr_hessiantracer = quote
## HessianTracer
function $M.$fname(t::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g)
end
end

Expand All @@ -85,7 +85,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function)
t = $SCT.tracer(d)
is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x)
is_der2_zero = $SCT.is_der2_zero_local($M.$fname, x)
t_out = $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -96,7 +96,7 @@ end

## 2-to-1

@noinline function hessian_tracer_2_to_1(
function hessian_tracer_2_to_1(
tx::T,
ty::T,
is_der1_arg1_zero::Bool,
Expand Down Expand Up @@ -189,7 +189,7 @@ function generate_code_hessian_2_to_1(

expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer}
return $SCT.hessian_tracer_2_to_1(
return @noinline $SCT.hessian_tracer_2_to_1(
tx,
ty,
$is_der1_arg1_zero_g,
Expand Down Expand Up @@ -232,7 +232,7 @@ function generate_code_hessian_2_to_1(
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y)
is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_2_to_1(
t_out = @noinline $SCT.hessian_tracer_2_to_1(
tx,
ty,
is_der1_arg1_zero,
Expand Down Expand Up @@ -263,12 +263,16 @@ function generate_code_hessian_2_to_1_typed(

expr_tracer_type = quote
function $M.$fname(tx::$SCT.HessianTracer, y::$Z)
return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(
tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g
)
end
end
expr_type_tracer = quote
function $M.$fname(x::$Z, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(
ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g
)
end
end

Expand All @@ -288,7 +292,9 @@ function generate_code_hessian_2_to_1_typed(
tx = $SCT.tracer(dx)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(
tx, is_der1_arg1_zero, is_der2_arg1_zero
)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -309,7 +315,9 @@ function generate_code_hessian_2_to_1_typed(
ty = $SCT.tracer(dy)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(
ty, is_der1_arg2_zero, is_der2_arg2_zero
)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -319,7 +327,7 @@ end

## 1-to-2

@noinline function hessian_tracer_1_to_2(
function hessian_tracer_1_to_2(
t::T,
is_der1_out1_zero::Bool,
is_der2_out1_zero::Bool,
Expand All @@ -344,7 +352,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function)

expr_hessiantracer = quote
function $M.$fname(t::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_2(
return @noinline $SCT.hessian_tracer_1_to_2(
t,
$is_der1_out1_zero_g,
$is_der2_out1_zero_g,
Expand Down Expand Up @@ -375,7 +383,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function)
is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x)
is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x)
is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x)
t_out1, t_out2 = $SCT.hessian_tracer_1_to_2(
t_out1, t_out2 = @noinline $SCT.hessian_tracer_1_to_2(
d,
is_der1_out1_zero,
is_der2_out1_zero,
Expand Down

0 comments on commit ccef7f6

Please sign in to comment.