Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add influence classification for ConnectivityTracer #78

Merged
merged 4 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ for TT in (GradientTracer, ConnectivityTracer, HessianTracer)

Base.big(::Type{T}) where {T<:TT} = T
Base.widen(::Type{T}) where {T<:TT} = T
Base.big(t::T) where {T<:TT} = t
Base.widen(t::T) where {T<:TT} = t

Base.convert(::Type{T}, x::Number) where {T<:TT} = empty(T)
Base.convert(::Type{T}, t::T) where {T<:TT} = t
Expand All @@ -26,17 +24,6 @@ for TT in (GradientTracer, ConnectivityTracer, HessianTracer)
Base.floatmax(::Type{T}) where {T<:TT} = empty(T)
Base.maxintfloat(::Type{T}) where {T<:TT} = empty(T)

Base.zero(::T) where {T<:TT} = empty(T)
Base.one(::T) where {T<:TT} = empty(T)
Base.oneunit(::T) where {T<:TT} = empty(T)
Base.typemin(::T) where {T<:TT} = empty(T)
Base.typemax(::T) where {T<:TT} = empty(T)
Base.eps(::T) where {T<:TT} = empty(T)
Base.float(::T) where {T<:TT} = empty(T)
Base.floatmin(::T) where {T<:TT} = empty(T)
Base.floatmax(::T) where {T<:TT} = empty(T)
Base.maxintfloat(::T) where {T<:TT} = empty(T)

## Array constructors
Base.similar(a::Array{T,1}) where {T<:TT} = zeros(T, size(a, 1))
Base.similar(a::Array{T,2}) where {T<:TT} = zeros(T, size(a, 1), size(a, 2))
Expand Down Expand Up @@ -66,8 +53,6 @@ end

Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{big(P),T}
Base.widen(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{widen(P),T}
Base.big(d::D) where {P,T,D<:Dual{P,T}} = Dual(big(primal(d)), tracer(d))
Base.widen(d::D) where {P,T,D<:Dual{P,T}} = Dual(widen(primal(d)), tracer(d))

Base.convert(::Type{D}, x::Number) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T))
Base.convert(::Type{D}, d::D) where {P,T,D<:Dual{P,T}} = d
Expand All @@ -89,17 +74,6 @@ Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmin(P), empty(T
Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmax(P), empty(T))
Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = D(maxintfloat(P), empty(T))

Base.zero(d::D) where {P,T,D<:Dual{P,T}} = D(zero(primal(d)), empty(T))
Base.one(d::D) where {P,T,D<:Dual{P,T}} = D(one(primal(d)), empty(T))
Base.oneunit(d::D) where {P,T,D<:Dual{P,T}} = D(oneunit(primal(d)), empty(T))
Base.typemin(d::D) where {P,T,D<:Dual{P,T}} = D(typemin(primal(d)), empty(T))
Base.typemax(d::D) where {P,T,D<:Dual{P,T}} = D(typemax(primal(d)), empty(T))
Base.eps(d::D) where {P,T,D<:Dual{P,T}} = D(eps(primal(d)), empty(T))
Base.float(d::D) where {P,T,D<:Dual{P,T}} = D(float(primal(d)), empty(T))
Base.floatmin(d::D) where {P,T,D<:Dual{P,T}} = D(floatmin(primal(d)), empty(T))
Base.floatmax(d::D) where {P,T,D<:Dual{P,T}} = D(floatmax(primal(d)), empty(T))
Base.maxintfloat(d::D) where {P,T,D<:Dual{P,T}} = D(maxintfloat(primal(d)), empty(T))

## Array constructors
function Base.similar(a::Array{D,1}) where {P,T,D<:Dual{P,T}}
p_out = similar(primal.(a))
Expand Down
91 changes: 86 additions & 5 deletions src/operators.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
## Operator definitions

# We use a system of letters to categorize operators:
# z: first- and second-order derivatives (FOD, SOD) are zero
# i: independence - no influence at all
# z: influence but first- and second-order derivatives (FOD, SOD) are zero
# f: FOD ∂f/∂x is non-zero, SOD ∂²f/∂x² is zero
# s: FOD ∂f/∂x is non-zero, SOD ∂²f/∂x² is non-zero
# c: Cross-derivative ∂²f/∂x∂y is non-zero
Expand All @@ -11,14 +12,17 @@
##=================================#
# Operators for functions f: ℝ → ℝ #
#==================================#
function is_influence_zero_global end
function is_firstder_zero_global end
function is_seconder_zero_global end

# Fallbacks for local derivatives:
is_influence_zero_local(f::F, x) where {F} = is_influence_zero_global(f)
is_firstder_zero_local(f::F, x) where {F} = is_firstder_zero_global(f)
is_seconder_zero_local(f::F, x) where {F} = is_seconder_zero_global(f)

# ops_1_to_1_s:
# x -> f != 0
# ∂f/∂x != 0
# ∂²f/∂x² != 0
ops_1_to_1_s = (
Expand Down Expand Up @@ -49,27 +53,34 @@ ops_1_to_1_s = (
)
for op in ops_1_to_1_s
T = typeof(op)
@eval is_influence_zero_global(::$T) = false
@eval is_firstder_zero_global(::$T) = false
@eval is_seconder_zero_global(::$T) = false
end

# ops_1_to_1_f:
# x -> f != 0
# ∂f/∂x != 0
# ∂²f/∂x² == 0
ops_1_to_1_f = (
+, -,
identity,
abs, hypot,
deg2rad, rad2deg,
mod2pi, prevfloat, nextfloat,
# angles
deg2rad, rad2deg, mod2pi,
# floats
float, prevfloat, nextfloat,
adrhill marked this conversation as resolved.
Show resolved Hide resolved
big, widen,
)
for op in ops_1_to_1_f
T = typeof(op)
@eval is_influence_zero_global(::$T) = false
@eval is_firstder_zero_global(::$T) = false
@eval is_seconder_zero_global(::$T) = true
end

# ops_1_to_1_z:
# x -> f != 0
# ∂f/∂x == 0
# ∂²f/∂x² == 0
ops_1_to_1_z = (
Expand All @@ -78,6 +89,23 @@ ops_1_to_1_z = (
)
for op in ops_1_to_1_z
T = typeof(op)
@eval is_influence_zero_global(::$T) = false
@eval is_firstder_zero_global(::$T) = true
@eval is_seconder_zero_global(::$T) = true
end

# ops_1_to_1_i:
# x -> f == 0
# ∂f/∂x == 0
# ∂²f/∂x² == 0
ops_1_to_1_i = (
zero, one, oneunit,
typemin, typemax, eps,
floatmin, floatmax, maxintfloat,
)
for op in ops_1_to_1_i
T = typeof(op)
@eval is_influence_zero_global(::$T) = true
@eval is_firstder_zero_global(::$T) = true
@eval is_seconder_zero_global(::$T) = true
end
Expand All @@ -86,26 +114,31 @@ ops_1_to_1 = union(
ops_1_to_1_s,
ops_1_to_1_f,
ops_1_to_1_z,
ops_1_to_1_i,
)

##==================================#
# Operators for functions f: ℝ² → ℝ #
#===================================#

function is_influence_arg1_zero_global end
function is_influence_arg2_zero_global end
function is_firstder_arg1_zero_global end
function is_seconder_arg1_zero_global end
function is_firstder_arg2_zero_global end
function is_seconder_arg2_zero_global end
function is_crossder_zero_global end

# Fallbacks for local derivatives:
is_influence_arg1_zero_local(f::F, x, y) where {F} = is_influence_arg1_zero_global(f)
is_influence_arg2_zero_local(f::F, x, y) where {F} = is_influence_arg2_zero_global(f)
is_firstder_arg1_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f)
is_seconder_arg1_zero_local(f::F, x, y) where {F} = is_seconder_arg1_zero_global(f)
is_firstder_arg2_zero_local(f::F, x, y) where {F} = is_firstder_arg2_zero_global(f)
is_seconder_arg2_zero_local(f::F, x, y) where {F} = is_seconder_arg2_zero_global(f)
is_crossder_zero_local(f::F, x, y) where {F} = is_crossder_zero_global(f)

# ops_2_to_1_ssc:
# ops_2_to_1_ssc:
# ∂f/∂x != 0
# ∂²f/∂x² != 0
# ∂f/∂y != 0
Expand All @@ -116,6 +149,8 @@ ops_2_to_1_ssc = (
)
for op in ops_2_to_1_ssc
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = false
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -133,6 +168,8 @@ ops_2_to_1_ssz = ()
#=
for op in ops_2_to_1_ssz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -151,6 +188,8 @@ ops_2_to_1_sfc = ()
#=
for op in ops_2_to_1_sfc
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = false
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -169,6 +208,8 @@ ops_2_to_1_sfz = ()
#=
for op in ops_2_to_1_sfz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = false
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -189,6 +230,8 @@ ops_2_to_1_fsc = (
)
for op in ops_2_to_1_fsc
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = true
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -209,6 +252,8 @@ ops_2_to_1_fsz = ()
#=
for op in ops_2_to_1_fsz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = true
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -228,6 +273,8 @@ ops_2_to_1_ffc = (
)
for op in ops_2_to_1_ffc
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = true
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -252,6 +299,8 @@ ops_2_to_1_ffz = (
)
for op in ops_2_to_1_ffz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = true
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -277,6 +326,8 @@ ops_2_to_1_szz = ()
#=
for op in ops_2_to_1_szz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = false
@eval is_firstder_arg2_zero_global(::$T) = true
Expand All @@ -295,6 +346,8 @@ ops_2_to_1_zsz = ()
#=
for op in ops_2_to_1_zsz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = true
@eval is_seconder_arg1_zero_global(::$T) = true
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -314,6 +367,8 @@ ops_2_to_1_fzz = (
)
for op in ops_2_to_1_fzz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = false
@eval is_seconder_arg1_zero_global(::$T) = true
@eval is_firstder_arg2_zero_global(::$T) = true
Expand All @@ -331,6 +386,8 @@ ops_2_to_1_zfz = ()
#=
for op in ops_2_to_1_zfz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = true
@eval is_seconder_arg1_zero_global(::$T) = true
@eval is_firstder_arg2_zero_global(::$T) = false
Expand All @@ -351,6 +408,8 @@ ops_2_to_1_zzz = (
)
for op in ops_2_to_1_zzz
T = typeof(op)
@eval is_influence_arg1_zero_global(::$T) = false
@eval is_influence_arg2_zero_global(::$T) = false
@eval is_firstder_arg1_zero_global(::$T) = true
@eval is_seconder_arg1_zero_global(::$T) = true
@eval is_firstder_arg2_zero_global(::$T) = true
Expand Down Expand Up @@ -385,14 +444,18 @@ ops_2_to_1 = union(
# Operators for functions f: ℝ → ℝ² #
#===================================#

function is_influence_out1_zero_global end
function is_influence_out2_zero_global end
function is_firstder_out1_zero_global end
function is_seconder_out1_zero_global end
function is_firstder_out2_zero_global end
function is_seconder_out2_zero_global end

# Fallbacks for local derivatives:
is_seconder_out1_zero_local(f::F, x) where {F} = is_seconder_out1_zero_global(f)
is_influence_out1_zero_local(f::F, x) where {F} = is_influence_out1_zero_global(f)
is_influence_out2_zero_local(f::F, x) where {F} = is_influence_out2_zero_global(f)
is_firstder_out1_zero_local(f::F, x) where {F} = is_firstder_out1_zero_global(f)
is_seconder_out1_zero_local(f::F, x) where {F} = is_seconder_out1_zero_global(f)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
is_firstder_out2_zero_local(f::F, x) where {F} = is_firstder_out2_zero_global(f)
is_seconder_out2_zero_local(f::F, x) where {F} = is_seconder_out2_zero_global(f)

Expand All @@ -408,6 +471,8 @@ ops_1_to_2_ss = (
)
for op in ops_1_to_2_ss
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = false
@eval is_seconder_out1_zero_global(::$T) = false
@eval is_firstder_out2_zero_global(::$T) = false
Expand All @@ -423,6 +488,8 @@ ops_1_to_2_sf = ()
#=
for op in ops_1_to_2_sf
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = false
@eval is_seconder_out1_zero_global(::$T) = false
@eval is_firstder_out2_zero_global(::$T) = false
Expand All @@ -439,6 +506,8 @@ ops_1_to_2_sz = ()
#=
for op in ops_1_to_2_sz
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = false
@eval is_seconder_out1_zero_global(::$T) = false
@eval is_firstder_out2_zero_global(::$T) = true
Expand All @@ -455,6 +524,8 @@ ops_1_to_2_fs = ()
#=
for op in ops_1_to_2_fs
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = false
@eval is_seconder_out1_zero_global(::$T) = true
@eval is_firstder_out2_zero_global(::$T) = false
Expand All @@ -471,6 +542,8 @@ ops_1_to_2_ff = ()
#=
for op in ops_1_to_2_ff
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = false
@eval is_seconder_out1_zero_global(::$T) = true
@eval is_firstder_out2_zero_global(::$T) = false
Expand All @@ -489,6 +562,8 @@ ops_1_to_2_fz = (
#=
for op in ops_1_to_2_fz
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = false
@eval is_seconder_out1_zero_global(::$T) = true
@eval is_firstder_out2_zero_global(::$T) = true
Expand All @@ -505,6 +580,8 @@ ops_1_to_2_zs = ()
#=
for op in ops_1_to_2_zs
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = true
@eval is_seconder_out1_zero_global(::$T) = true
@eval is_firstder_out2_zero_global(::$T) = false
Expand All @@ -521,6 +598,8 @@ ops_1_to_2_zf = ()
#=
for op in ops_1_to_2_zf
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = true
@eval is_seconder_out1_zero_global(::$T) = true
@eval is_firstder_out2_zero_global(::$T) = false
Expand All @@ -537,6 +616,8 @@ ops_1_to_2_zz = ()
#=
for op in ops_1_to_2_zz
T = typeof(op)
@eval is_influence_out1_zero_global(::$T) = false
@eval is_influence_out2_zero_global(::$T) = false
@eval is_firstder_out1_zero_global(::$T) = true
@eval is_seconder_out1_zero_global(::$T) = true
@eval is_firstder_out2_zero_global(::$T) = true
Expand Down
Loading
Loading