From 957d99774ed10f0a51bb250207c14fd15f32cb44 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 18:52:23 +0200 Subject: [PATCH 01/47] Add operators --- src/operators.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/operators.jl b/src/operators.jl index 6463b1fe..14fabe24 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -14,6 +14,9 @@ function is_firstder_zero_global end function is_seconder_zero_global end +is_firstder_zero_local(f::Type, x, y) = is_firstder_zero_global(f) +is_seconder_zero_local(f::Type, x, y) = is_seconder_zero_global(f) + # ops_1_to_1_s: # ∂f/∂x != 0 # ∂²f/∂x² != 0 @@ -111,6 +114,12 @@ function is_firstder_arg2_zero_global end function is_seconder_arg2_zero_global end function is_crossder_zero_global end +is_firstder_arg1_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) +is_seconder_arg1_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) +is_firstder_arg2_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) +is_seconder_arg2_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) +is_crossder_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) + # ops_2_to_1_ssc: # ∂f/∂x != 0 # ∂²f/∂x² != 0 @@ -196,6 +205,9 @@ for op in ops_2_to_1_fsc SparseConnectivityTracer.is_crossder_zero_global(::T) = false end +is_firstder_arg1_zero_local(::Type{typeof(:/)}, x, y) = iszero(x) +is_crossder_zero_local(::Type{typeof(:/)}, x, y) = iszero(x) + # ops_2_to_1_fsz: # ∂f/∂x != 0 # ∂²f/∂x² == 0 @@ -230,6 +242,12 @@ for op in ops_2_to_1_ffc SparseConnectivityTracer.is_crossder_zero_global(::T) = false end +is_firstder_arg1_zero_local(::Type{typeof(:*)}, x, y) = iszero(x) +is_seconder_arg1_zero_local(::Type{typeof(:*)}, x, y) = iszero(x) +is_firstder_arg2_zero_local(::Type{typeof(:*)}, x, y) = iszero(y) +is_seconder_arg2_zero_local(::Type{typeof(:*)}, x, y) = iszero(y) +is_crossder_zero_local(::Type{typeof(:*)}, x, y) = iszero(x) || iszero(y) + # ops_2_to_1_ffz: # ∂f/∂x != 0 # ∂²f/∂x² == 0 @@ -239,6 +257,7 @@ end ops_2_to_1_ffz = ( :+, :-, :mod, :rem, + :min, :max, ) for op in ops_2_to_1_ffz T = typeof(eval(op)) @@ -249,6 +268,16 @@ for op in ops_2_to_1_ffz SparseConnectivityTracer.is_crossder_zero_global(::T) = true end + +is_firstder_arg2_zero_local(::Type{typeof(mod)}, x, y) = ifelse(y > 0, y>x, x>y) +is_firstder_arg2_zero_local(::Type{typeof(rem)}, x, y) = ifelse(y > 0, y>x, x>y) + +is_firstder_arg1_zero_local(::Type{typeof(max)}, x, y) = x < y +is_firstder_arg2_zero_local(::Type{typeof(max)}, x, y) = y < x + +is_firstder_arg1_zero_local(::Type{typeof(min)}, x, y) = x > y +is_firstder_arg2_zero_local(::Type{typeof(min)}, x, y) = y > x + # ops_2_to_1_szz: # ∂f/∂x != 0 # ∂²f/∂x² != 0 @@ -369,6 +398,11 @@ function is_seconder_out1_zero_global end function is_firstder_out2_zero_global end function is_seconder_out2_zero_global end +is_firstder_out1_zero_local(f::Type, x, y) = is_firstder_out1_zero_global(f) +is_seconder_out1_zero_local(f::Type, x, y) = is_seconder_out1_zero_global(f) +is_firstder_out2_zero_local(f::Type, x, y) = is_firstder_out2_zero_global(f) +is_seconder_out2_zero_local(f::Type, x, y) = is_seconder_out2_zero_global(f) + # ops_1_to_2_ss: # ∂f₁/∂x != 0 # ∂²f₁/∂x² != 0 From 194adab775266f2b2e5f940f3214eca02e5270bb Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 18:54:15 +0200 Subject: [PATCH 02/47] Rename tracers --- docs/src/api.md | 4 +-- src/conversion.jl | 12 ++++----- src/overload_gradient.jl | 48 +++++++++++++++++------------------ src/overload_hessian.jl | 20 +++++++-------- src/pattern.jl | 12 ++++----- src/tracers.jl | 55 ++++++++++++++++++++-------------------- test/first_order.jl | 6 ++--- test/second_order.jl | 4 +-- 8 files changed, 79 insertions(+), 82 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index f94f1c20..54e89b65 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -24,8 +24,8 @@ Currently, three tracer types are provided: ```@docs SparseConnectivityTracer.ConnectivityTracer -SparseConnectivityTracer.GlobalGradientTracer -SparseConnectivityTracer.GlobalHessianTracer +SparseConnectivityTracer.GradientTracer +SparseConnectivityTracer.HessianTracer ``` We also define alternative pseudo-set types that can deliver faster `union`: diff --git a/src/conversion.jl b/src/conversion.jl index 1190b451..1e9e419b 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -1,5 +1,5 @@ ## Type conversions -for TT in (:GlobalGradientTracer, :ConnectivityTracer, :GlobalHessianTracer) +for TT in (:GradientTracer, :ConnectivityTracer, :HessianTracer) @eval Base.promote_rule(::Type{T}, ::Type{N}) where {T<:$TT,N<:Number} = T @eval Base.promote_rule(::Type{N}, ::Type{T}) where {T<:$TT,N<:Number} = T @@ -29,11 +29,9 @@ end function Base.similar(::Array, ::Type{ConnectivityTracer{C}}, dims::Dims{N}) where {C,N} return zeros(ConnectivityTracer{C}, dims) end -function Base.similar(::Array, ::Type{GlobalGradientTracer{G}}, dims::Dims{N}) where {G,N} - return zeros(GlobalGradientTracer{G}, dims) +function Base.similar(::Array, ::Type{GradientTracer{G}}, dims::Dims{N}) where {G,N} + return zeros(GradientTracer{G}, dims) end -function Base.similar( - ::Array, ::Type{GlobalHessianTracer{G,H}}, dims::Dims{N} -) where {G,H,N} - return zeros(GlobalHessianTracer{G,H}, dims) +function Base.similar(::Array, ::Type{HessianTracer{G,H}}, dims::Dims{N}) where {G,H,N} + return zeros(HessianTracer{G,H}, dims) end diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 909a8029..7bd8e9ae 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -1,9 +1,9 @@ for fn in union(ops_1_to_1_s, ops_1_to_1_f) - @eval Base.$fn(t::GlobalGradientTracer) = t + @eval Base.$fn(t::GradientTracer) = t end for fn in union(ops_1_to_1_z, ops_1_to_1_const) - @eval Base.$fn(::T) where {T<:GlobalGradientTracer} = empty(T) + @eval Base.$fn(::T) where {T<:GradientTracer} = empty(T) end for fn in union( @@ -16,52 +16,52 @@ for fn in union( ops_2_to_1_ffc, ops_2_to_1_ffz, ) - @eval Base.$fn(a::T, b::T) where {T<:GlobalGradientTracer} = T(a.grad ∪ b.grad) - @eval Base.$fn(t::GlobalGradientTracer, ::Number) = t - @eval Base.$fn(::Number, t::GlobalGradientTracer) = t + @eval Base.$fn(a::T, b::T) where {T<:GradientTracer} = T(a.grad ∪ b.grad) + @eval Base.$fn(t::GradientTracer, ::Number) = t + @eval Base.$fn(::Number, t::GradientTracer) = t end for fn in union(ops_2_to_1_zsz, ops_2_to_1_zfz) - @eval Base.$fn(::GlobalGradientTracer, t::GlobalGradientTracer) = t - @eval Base.$fn(::T, ::Number) where {T<:GlobalGradientTracer} = empty(T) - @eval Base.$fn(::Number, t::GlobalGradientTracer) = t + @eval Base.$fn(::GradientTracer, t::GradientTracer) = t + @eval Base.$fn(::T, ::Number) where {T<:GradientTracer} = empty(T) + @eval Base.$fn(::Number, t::GradientTracer) = t end for fn in union(ops_2_to_1_szz, ops_2_to_1_fzz) - @eval Base.$fn(t::GlobalGradientTracer, ::GlobalGradientTracer) = t - @eval Base.$fn(t::GlobalGradientTracer, ::Number) = t - @eval Base.$fn(::Number, ::T) where {T<:GlobalGradientTracer} = empty(T) + @eval Base.$fn(t::GradientTracer, ::GradientTracer) = t + @eval Base.$fn(t::GradientTracer, ::Number) = t + @eval Base.$fn(::Number, ::T) where {T<:GradientTracer} = empty(T) end for fn in ops_2_to_1_zzz - @eval Base.$fn(::T, ::T) where {T<:GlobalGradientTracer} = empty(T) - @eval Base.$fn(::T, ::Number) where {T<:GlobalGradientTracer} = empty(T) - @eval Base.$fn(::Number, ::T) where {T<:GlobalGradientTracer} = empty(T) + @eval Base.$fn(::T, ::T) where {T<:GradientTracer} = empty(T) + @eval Base.$fn(::T, ::Number) where {T<:GradientTracer} = empty(T) + @eval Base.$fn(::Number, ::T) where {T<:GradientTracer} = empty(T) end for fn in union(ops_1_to_2_ss, ops_1_to_2_sf, ops_1_to_2_fs, ops_1_to_2_ff) - @eval Base.$fn(t::GlobalGradientTracer) = (t, t) + @eval Base.$fn(t::GradientTracer) = (t, t) end for fn in union(ops_1_to_2_sz, ops_1_to_2_fz) - @eval Base.$fn(t::T) where {T<:GlobalGradientTracer} = (t, empty(T)) + @eval Base.$fn(t::T) where {T<:GradientTracer} = (t, empty(T)) end for fn in union(ops_1_to_2_zs, ops_1_to_2_zf) - @eval Base.$fn(t::T) where {T<:GlobalGradientTracer} = (empty(T), t) + @eval Base.$fn(t::T) where {T<:GradientTracer} = (empty(T), t) end for fn in ops_1_to_2_zz - @eval Base.$fn(::T) where {T<:GlobalGradientTracer} = (empty(T), empty(T)) + @eval Base.$fn(::T) where {T<:GradientTracer} = (empty(T), empty(T)) end # Extra types required for exponent for T in (:Real, :Integer, :Rational) - @eval Base.:^(t::GlobalGradientTracer, ::$T) = t - @eval Base.:^(::$T, t::GlobalGradientTracer) = t + @eval Base.:^(t::GradientTracer, ::$T) = t + @eval Base.:^(::$T, t::GradientTracer) = t end -Base.:^(t::GlobalGradientTracer, ::Irrational{:ℯ}) = t -Base.:^(::Irrational{:ℯ}, t::GlobalGradientTracer) = t +Base.:^(t::GradientTracer, ::Irrational{:ℯ}) = t +Base.:^(::Irrational{:ℯ}, t::GradientTracer) = t ## Rounding -Base.round(t::T, ::RoundingMode; kwargs...) where {T<:GlobalGradientTracer} = empty(T) +Base.round(t::T, ::RoundingMode; kwargs...) where {T<:GradientTracer} = empty(T) ## Random numbers -rand(::AbstractRNG, ::SamplerType{T}) where {T<:GlobalGradientTracer} = empty(T) +rand(::AbstractRNG, ::SamplerType{T}) where {T<:GradientTracer} = empty(T) diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index dd0f8012..b9868d65 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -1,6 +1,6 @@ ## 1-to-1 for fn in ops_1_to_1 - @eval function Base.$fn(t::T) where {T<:GlobalHessianTracer} + @eval function Base.$fn(t::T) where {T<:HessianTracer} if is_seconder_zero_global($fn) if is_firstder_zero_global($fn) return empty(T) @@ -15,7 +15,7 @@ end ## 2-to-1 for fn in ops_2_to_1 - @eval function Base.$fn(a::T, b::T) where {G,H,T<:GlobalHessianTracer{G,H}} + @eval function Base.$fn(a::T, b::T) where {G,H,T<:HessianTracer{G,H}} grad = empty(G) hess = empty(H) if !is_firstder_arg1_zero_global($fn) @@ -38,7 +38,7 @@ for fn in ops_2_to_1 return T(grad, hess) end - @eval function Base.$fn(t::T, ::Number) where {G,H,T<:GlobalHessianTracer{G,H}} + @eval function Base.$fn(t::T, ::Number) where {G,H,T<:HessianTracer{G,H}} if is_seconder_arg1_zero_global($fn) if is_firstder_arg1_zero_global($fn) return empty(T) @@ -49,7 +49,7 @@ for fn in ops_2_to_1 return T(t.grad, t.hess ∪ (t.grad × t.grad)) end end - @eval function Base.$fn(::Number, t::T) where {G,H,T<:GlobalHessianTracer{G,H}} + @eval function Base.$fn(::Number, t::T) where {G,H,T<:HessianTracer{G,H}} if is_seconder_arg2_zero_global($fn) if is_firstder_arg2_zero_global($fn) return empty(T) @@ -64,22 +64,22 @@ end # Extra types required for exponent for T in (:Real, :Integer, :Rational) - @eval function Base.:^(t::T, ::$T) where {T<:GlobalHessianTracer} + @eval function Base.:^(t::T, ::$T) where {T<:HessianTracer} return T(t.grad, t.hess ∪ (t.grad × t.grad)) end - @eval function Base.:^(::$T, t::T) where {T<:GlobalHessianTracer} + @eval function Base.:^(::$T, t::T) where {T<:HessianTracer} return T(t.grad, t.hess ∪ (t.grad × t.grad)) end end -function Base.:^(t::T, ::Irrational{:ℯ}) where {T<:GlobalHessianTracer} +function Base.:^(t::T, ::Irrational{:ℯ}) where {T<:HessianTracer} return T(t.grad, t.hess ∪ (t.grad × t.grad)) end -function Base.:^(::Irrational{:ℯ}, t::T) where {T<:GlobalHessianTracer} +function Base.:^(::Irrational{:ℯ}, t::T) where {T<:HessianTracer} return T(t.grad, t.hess ∪ (t.grad × t.grad)) end ## Rounding -Base.round(t::T, ::RoundingMode; kwargs...) where {T<:GlobalHessianTracer} = empty(T) +Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = empty(T) ## Random numbers -rand(::AbstractRNG, ::SamplerType{T}) where {T<:GlobalHessianTracer} = empty(T) +rand(::AbstractRNG, ::SamplerType{T}) where {T<:HessianTracer} = empty(T) diff --git a/src/pattern.jl b/src/pattern.jl index e80fe0d7..b309a741 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -8,7 +8,7 @@ const DEFAULT_MATRIX_TYPE = Set{Tuple{Int,Int}} Enumerates input indices and constructs the specified type `T` of tracer. -Supports [`ConnectivityTracer`](@ref), [`GlobalGradientTracer`](@ref) and [`GlobalHessianTracer`](@ref). +Supports [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref). """ trace_input(::Type{T}, x) where {T<:AbstractTracer} = trace_input(T, x, 1) trace_input(::Type{T}, ::Number, i) where {T<:AbstractTracer} = tracer(T, i) @@ -119,7 +119,7 @@ julia> jacobian_pattern(f, x) ``` """ function jacobian_pattern(f, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} - xt, yt = trace_function(GlobalGradientTracer{G}, f, x) + xt, yt = trace_function(GradientTracer{G}, f, x) return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end @@ -132,13 +132,13 @@ Compute the sparsity pattern of the Jacobian of `f!(y, x)`. The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. """ function jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} - xt, yt = trace_function(GlobalGradientTracer{G}, f!, y, x) + xt, yt = trace_function(GradientTracer{G}, f!, y, x) return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end function jacobian_pattern_to_mat( xt::AbstractArray{T}, yt::AbstractArray{<:Number} -) where {T<:GlobalGradientTracer} +) where {T<:GradientTracer} n, m = length(xt), length(yt) I = Int[] # row indices J = Int[] # column indices @@ -192,13 +192,13 @@ julia> hessian_pattern(g, x) function hessian_pattern( f, x, ::Type{G}=DEFAULT_VECTOR_TYPE, ::Type{H}=DEFAULT_MATRIX_TYPE ) where {G,H} - xt, yt = trace_function(GlobalHessianTracer{G,H}, f, x) + xt, yt = trace_function(HessianTracer{G,H}, f, x) return hessian_pattern_to_mat(to_array(xt), yt) end function hessian_pattern_to_mat( xt::AbstractArray{T}, yt::T -) where {G,H<:AbstractSet,T<:GlobalHessianTracer{G,H}} +) where {G,H<:AbstractSet,T<:HessianTracer{G,H}} # Allocate Hessian matrix n = length(xt) I = Int[] # row indices diff --git a/src/tracers.jl b/src/tracers.jl index fb2829a2..ab88f9d5 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -84,31 +84,31 @@ Set{Int64} with 2 elements: 3 1 -julia> SparseConnectivityTracer.GlobalGradientTracer(grad) -SparseConnectivityTracer.GlobalGradientTracer{Set{Int64}}(1, 3) +julia> SparseConnectivityTracer.GradientTracer(grad) +SparseConnectivityTracer.GradientTracer{Set{Int64}}(1, 3) ``` """ -struct GlobalGradientTracer{G<:AbstractSet{<:Integer}} <: AbstractTracer +struct GradientTracer{G<:AbstractSet{<:Integer}} <: AbstractTracer "Sparse binary vector representing non-zero entries in the gradient." grad::G end -function Base.show(io::IO, t::GlobalGradientTracer) +function Base.show(io::IO, t::GradientTracer) return Base.show_delim_array( io, convert.(Int, sort(collect(t.grad))), "$(typeof(t))(", ',', ')', true ) end -function empty(::Type{GlobalGradientTracer{G}}) where {G} - return GlobalGradientTracer{G}(empty(G)) +function empty(::Type{GradientTracer{G}}) where {G} + return GradientTracer{G}(empty(G)) end -function GlobalGradientTracer{G}(::Number) where {G<:AbstractSet{<:Integer}} - return empty(GlobalGradientTracer{G}) +function GradientTracer{G}(::Number) where {G<:AbstractSet{<:Integer}} + return empty(GradientTracer{G}) end -GlobalGradientTracer{G}(t::GlobalGradientTracer{G}) where {G<:AbstractSet{<:Integer}} = t -GlobalGradientTracer(t::GlobalGradientTracer) = t +GradientTracer{G}(t::GradientTracer{G}) where {G<:AbstractSet{<:Integer}} = t +GradientTracer(t::GradientTracer) = t #=========# # Hessian # @@ -137,23 +137,22 @@ Set{Tuple{Int64, Int64}} with 3 elements: (1, 1) (2, 3) -julia> SparseConnectivityTracer.GlobalHessianTracer(grad, hess) -SparseConnectivityTracer.GlobalHessianTracer{Set{Int64}, Set{Tuple{Int64, Int64}}}( +julia> SparseConnectivityTracer.HessianTracer(grad, hess) +SparseConnectivityTracer.HessianTracer{Set{Int64}, Set{Tuple{Int64, Int64}}}( Gradient: Set([3, 1]), Hessian: Set([(3, 2), (1, 1), (2, 3)]) ) ``` """ -struct GlobalHessianTracer{ - G<:AbstractSet{<:Integer},H<:AbstractSet{<:Tuple{Integer,Integer}} -} <: AbstractTracer +struct HessianTracer{G<:AbstractSet{<:Integer},H<:AbstractSet{<:Tuple{Integer,Integer}}} <: + AbstractTracer "Sparse binary vector representation of non-zero entries in the gradient." grad::G "Sparse binary matrix representation of non-zero entries in the Hessian." hess::H end -function Base.show(io::IO, t::GlobalHessianTracer) +function Base.show(io::IO, t::HessianTracer) println(io, "$(eltype(t))(") println(io, " Gradient: ", t.grad, ",") println(io, " Hessian: ", t.hess) @@ -161,22 +160,22 @@ function Base.show(io::IO, t::GlobalHessianTracer) return nothing end -function empty(::Type{GlobalHessianTracer{G,H}}) where {G,H} - return GlobalHessianTracer{G,H}(empty(G), empty(H)) +function empty(::Type{HessianTracer{G,H}}) where {G,H} + return HessianTracer{G,H}(empty(G), empty(H)) end -function GlobalHessianTracer{G,H}( +function HessianTracer{G,H}( ::Number ) where {G<:AbstractSet{<:Integer},H<:AbstractSet{<:Tuple{Integer,Integer}}} - return empty(GlobalHessianTracer{G,H}) + return empty(HessianTracer{G,H}) end -function GlobalHessianTracer{G,H}( - t::GlobalHessianTracer{G,H} +function HessianTracer{G,H}( + t::HessianTracer{G,H} ) where {G<:AbstractSet{<:Integer},H<:AbstractSet{<:Tuple{Integer,Integer}}} return t end -GlobalHessianTracer(t::GlobalHessianTracer) = t +HessianTracer(t::HessianTracer) = t #===========# # Utilities # @@ -185,14 +184,14 @@ GlobalHessianTracer(t::GlobalHessianTracer) = t """ tracer(T, index) where {T<:AbstractTracer} -Convenience constructor for [`ConnectivityTracer`](@ref), [`GlobalGradientTracer`](@ref) and [`GlobalHessianTracer`](@ref) from input indices. +Convenience constructor for [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices. """ -function tracer(::Type{GlobalGradientTracer{G}}, index::Integer) where {G} - return GlobalGradientTracer{G}(sparse_vector(G, index)) +function tracer(::Type{GradientTracer{G}}, index::Integer) where {G} + return GradientTracer{G}(sparse_vector(G, index)) end function tracer(::Type{ConnectivityTracer{C}}, index::Integer) where {C} return ConnectivityTracer{C}(sparse_vector(C, index)) end -function tracer(::Type{GlobalHessianTracer{G,H}}, index::Integer) where {G,H} - return GlobalHessianTracer{G,H}(sparse_vector(G, index), empty(H)) +function tracer(::Type{HessianTracer{G,H}}, index::Integer) where {G,H} + return HessianTracer{G,H}(sparse_vector(G, index), empty(H)) end diff --git a/test/first_order.jl b/test/first_order.jl index baf53054..71b3aff1 100644 --- a/test/first_order.jl +++ b/test/first_order.jl @@ -1,6 +1,6 @@ using SparseConnectivityTracer using SparseConnectivityTracer: - ConnectivityTracer, GlobalGradientTracer, tracer, trace_input, empty + ConnectivityTracer, GradientTracer, tracer, trace_input, empty using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using Test @@ -8,7 +8,7 @@ using Test BitSet, Set{UInt64}, DuplicateVector{UInt64}, RecursiveSet{UInt64}, SortedVector{UInt64} ) CT = ConnectivityTracer{G} - JT = GlobalGradientTracer{G} + JT = GradientTracer{G} x = rand(3) xt = trace_input(CT, x) @@ -30,7 +30,7 @@ using Test @test connectivity_pattern(Returns(1), 1, G) ≈ [0;;] @test jacobian_pattern(Returns(1), 1, G) ≈ [0;;] - # Test GlobalGradientTracer on functions with zero derivatives + # Test GradientTracer on functions with zero derivatives x = rand(2) g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] @test connectivity_pattern(g, x, G) ≈ [1 1; 1 1; 1 1] diff --git a/test/second_order.jl b/test/second_order.jl index 6a318416..79f2e4c1 100644 --- a/test/second_order.jl +++ b/test/second_order.jl @@ -1,5 +1,5 @@ using SparseConnectivityTracer -using SparseConnectivityTracer: GlobalHessianTracer, tracer, trace_input, empty +using SparseConnectivityTracer: HessianTracer, tracer, trace_input, empty using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using Test @@ -18,7 +18,7 @@ end ) I = eltype(G) H = Set{Tuple{I,I}} - HT = GlobalHessianTracer{G,H} + HT = HessianTracer{G,H} @test hessian_pattern(identity, rand(), G, H) ≈ [0;;] @test hessian_pattern(sqrt, rand(), G, H) ≈ [1;;] From dad5e94ff8d9e3eae3e23c9d14da976e33175eee Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 18:54:38 +0200 Subject: [PATCH 03/47] Add `Dual` and accessor functions --- src/tracers.jl | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/tracers.jl b/src/tracers.jl index ab88f9d5..c284bc12 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -42,9 +42,11 @@ struct ConnectivityTracer{C<:AbstractSet{<:Integer}} <: AbstractTracer inputs::C end +inputs(t::ConnectivityTracer) = t.inputs + function Base.show(io::IO, t::ConnectivityTracer) return Base.show_delim_array( - io, convert.(Int, sort(collect(t.inputs))), "$(typeof(t))(", ',', ')', true + io, convert.(Int, sort(collect(inputs(t)))), "$(typeof(t))(", ',', ')', true ) end @@ -93,9 +95,11 @@ struct GradientTracer{G<:AbstractSet{<:Integer}} <: AbstractTracer grad::G end +gradient(t::GradientTracer) = t.grad + function Base.show(io::IO, t::GradientTracer) return Base.show_delim_array( - io, convert.(Int, sort(collect(t.grad))), "$(typeof(t))(", ',', ')', true + io, convert.(Int, sort(collect(gradient(t)))), "$(typeof(t))(", ',', ')', true ) end @@ -152,10 +156,13 @@ struct HessianTracer{G<:AbstractSet{<:Integer},H<:AbstractSet{<:Tuple{Integer,In hess::H end +gradient(t::HessianTracer) = t.grad +hessian(t::HessianTracer) = t.hess + function Base.show(io::IO, t::HessianTracer) println(io, "$(eltype(t))(") - println(io, " Gradient: ", t.grad, ",") - println(io, " Hessian: ", t.hess) + println(io, " Gradient: ", gradient(t), ",") + println(io, " Hessian: ", hessian(t)) print(io, ")") return nothing end @@ -177,6 +184,30 @@ function HessianTracer{G,H}( end HessianTracer(t::HessianTracer) = t +#================================# +# Dual numbers for local tracing # +#================================# + +""" +$(TYPEDEF) + +Dual number type keeping track of the results of a primal computation as well as a tracer. + +## Fields +$(TYPEDFIELDS) +""" +struct Dual{P<:Number,T<:AbstractTracer} <: AbstractTracer + primal::P + tracer::T +end + +primal(d::Dual) = d.primal + +input(d::Dual{P,T}) where {P,T<:ConnectivityTracer} = input(d.tracer) +gradient(d::Dual{P,T}) where {P,T<:GradientTracer} = gradient(d.tracer) +gradient(d::Dual{P,T}) where {P,T<:HessianTracer} = gradient(d.tracer) +hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(d.tracer) + #===========# # Utilities # #===========# From c29e0351c16570520cf85cab901aa677425256f9 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 19:20:25 +0200 Subject: [PATCH 04/47] Fix ops --- src/operators.jl | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index 14fabe24..40a7bebc 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -14,8 +14,9 @@ function is_firstder_zero_global end function is_seconder_zero_global end -is_firstder_zero_local(f::Type, x, y) = is_firstder_zero_global(f) -is_seconder_zero_local(f::Type, x, y) = is_seconder_zero_global(f) +# Fallbacks for local derivatives: +is_firstder_zero_local(f, x) = is_firstder_zero_global(f) +is_seconder_zero_local(f, x) = is_seconder_zero_global(f) # ops_1_to_1_s: # ∂f/∂x != 0 @@ -104,6 +105,7 @@ ops_1_to_1 = union( ops_1_to_1_const, ) + ##==================================# # Operators for functions f: ℝ² → ℝ # #===================================# @@ -114,11 +116,12 @@ function is_firstder_arg2_zero_global end function is_seconder_arg2_zero_global end function is_crossder_zero_global end -is_firstder_arg1_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) -is_seconder_arg1_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) -is_firstder_arg2_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) -is_seconder_arg2_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) -is_crossder_zero_local(f::Type, x, y) = is_firstder_arg1_zero_global(f) +# Fallbacks for local derivatives: +is_firstder_arg1_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) +is_seconder_arg1_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) +is_firstder_arg2_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) +is_seconder_arg2_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) +is_crossder_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) # ops_2_to_1_ssc: # ∂f/∂x != 0 @@ -205,9 +208,6 @@ for op in ops_2_to_1_fsc SparseConnectivityTracer.is_crossder_zero_global(::T) = false end -is_firstder_arg1_zero_local(::Type{typeof(:/)}, x, y) = iszero(x) -is_crossder_zero_local(::Type{typeof(:/)}, x, y) = iszero(x) - # ops_2_to_1_fsz: # ∂f/∂x != 0 # ∂²f/∂x² == 0 @@ -242,12 +242,6 @@ for op in ops_2_to_1_ffc SparseConnectivityTracer.is_crossder_zero_global(::T) = false end -is_firstder_arg1_zero_local(::Type{typeof(:*)}, x, y) = iszero(x) -is_seconder_arg1_zero_local(::Type{typeof(:*)}, x, y) = iszero(x) -is_firstder_arg2_zero_local(::Type{typeof(:*)}, x, y) = iszero(y) -is_seconder_arg2_zero_local(::Type{typeof(:*)}, x, y) = iszero(y) -is_crossder_zero_local(::Type{typeof(:*)}, x, y) = iszero(x) || iszero(y) - # ops_2_to_1_ffz: # ∂f/∂x != 0 # ∂²f/∂x² == 0 @@ -398,10 +392,12 @@ function is_seconder_out1_zero_global end function is_firstder_out2_zero_global end function is_seconder_out2_zero_global end -is_firstder_out1_zero_local(f::Type, x, y) = is_firstder_out1_zero_global(f) -is_seconder_out1_zero_local(f::Type, x, y) = is_seconder_out1_zero_global(f) -is_firstder_out2_zero_local(f::Type, x, y) = is_firstder_out2_zero_global(f) -is_seconder_out2_zero_local(f::Type, x, y) = is_seconder_out2_zero_global(f) +# Fallbacks for local derivatives: +is_seconder_out1_zero_local(f, x) = is_seconder_out1_zero_global(f) +is_firstder_out1_zero_local(f, x) = is_firstder_out1_zero_global(f) +is_firstder_out2_zero_local(f, x) = is_firstder_out2_zero_global(f) +is_seconder_out2_zero_local(f, x) = is_seconder_out2_zero_global(f) + # ops_1_to_2_ss: # ∂f₁/∂x != 0 From 15a2ccf42478531d184510bb4798481c4a93a211 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 20:38:39 +0200 Subject: [PATCH 05/47] Update ops --- src/operators.jl | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index 40a7bebc..d28b0afb 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -15,7 +15,7 @@ function is_firstder_zero_global end function is_seconder_zero_global end # Fallbacks for local derivatives: -is_firstder_zero_local(f, x) = is_firstder_zero_global(f) +is_firstder_zero_local(f, x) = is_firstder_zero_global(f) is_seconder_zero_local(f, x) = is_seconder_zero_global(f) # ops_1_to_1_s: @@ -208,6 +208,9 @@ for op in ops_2_to_1_fsc SparseConnectivityTracer.is_crossder_zero_global(::T) = false end +# gradient of x/y: [1/y -x/y²] +SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:/)) = iszero(x) + # ops_2_to_1_fsz: # ∂f/∂x != 0 # ∂²f/∂x² == 0 @@ -242,6 +245,11 @@ for op in ops_2_to_1_ffc SparseConnectivityTracer.is_crossder_zero_global(::T) = false end +# gradient of x*y: [y x] +SparseConnectivityTracer.is_firstder_arg1_zero_local(::typeof(Base.:*)) = iszero(y) +SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:*)) = iszero(x) +SparseConnectivityTracer.is_crossder_zero_local(::typeof(Base.:*)) = iszero(x) || iszero(y) + # ops_2_to_1_ffz: # ∂f/∂x != 0 # ∂²f/∂x² == 0 @@ -262,15 +270,14 @@ for op in ops_2_to_1_ffz SparseConnectivityTracer.is_crossder_zero_global(::T) = true end +is_firstder_arg2_zero_local(::typeof(mod), x, y) = ifelse(y > 0, y > x, x > y) +is_firstder_arg2_zero_local(::typeof(rem), x, y) = ifelse(y > 0, y > x, x > y) -is_firstder_arg2_zero_local(::Type{typeof(mod)}, x, y) = ifelse(y > 0, y>x, x>y) -is_firstder_arg2_zero_local(::Type{typeof(rem)}, x, y) = ifelse(y > 0, y>x, x>y) - -is_firstder_arg1_zero_local(::Type{typeof(max)}, x, y) = x < y -is_firstder_arg2_zero_local(::Type{typeof(max)}, x, y) = y < x +is_firstder_arg1_zero_local(::typeof(max), x, y) = x < y +is_firstder_arg2_zero_local(::typeof(max), x, y) = y < x -is_firstder_arg1_zero_local(::Type{typeof(min)}, x, y) = x > y -is_firstder_arg2_zero_local(::Type{typeof(min)}, x, y) = y > x +is_firstder_arg1_zero_local(::typeof(min), x, y) = x > y +is_firstder_arg2_zero_local(::typeof(min), x, y) = y > x # ops_2_to_1_szz: # ∂f/∂x != 0 @@ -365,10 +372,9 @@ ops_2_to_1 = union( # Including second- and first-order ops_2_to_1_sfc, ops_2_to_1_sfz, - ops_2_to_1_fsc, ops_2_to_1_fsz, - + # Including first-order only ops_2_to_1_ffc, ops_2_to_1_ffz, @@ -376,11 +382,9 @@ ops_2_to_1 = union( # Including zero-order ops_2_to_1_szz, ops_2_to_1_zsz, - ops_2_to_1_fzz, ops_2_to_1_zfz, - - ops_2_to_1_zzz, + ops_2_to_1_zzz, ) ##==================================# From 5205fb344762c05b7c93762e74b0a719357704a9 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 20:41:00 +0200 Subject: [PATCH 06/47] Please the code coverage docs --- src/operators.jl | 225 +++++++++++++++++++++++------------------------ 1 file changed, 112 insertions(+), 113 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index d28b0afb..139e31af 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -105,7 +105,6 @@ ops_1_to_1 = union( ops_1_to_1_const, ) - ##==================================# # Operators for functions f: ℝ² → ℝ # #===================================# @@ -148,14 +147,14 @@ end # ∂²f/∂y² != 0 # ∂²f/∂x∂y == 0 ops_2_to_1_ssz = () -for op in ops_2_to_1_ssz - T = typeof(eval(op)) - SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_crossder_zero_global(::T) = true -end +# for op in ops_2_to_1_ssz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_crossder_zero_global(::T) = true +# end # ops_2_to_1_sfc: # ∂f/∂x != 0 @@ -164,14 +163,14 @@ end # ∂²f/∂y² == 0 # ∂²f/∂x∂y != 0 ops_2_to_1_sfc = () -for op in ops_2_to_1_sfc - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = true - SparseConnectivityTracer.is_crossder_zero_global(::T) = false -end +# for op in ops_2_to_1_sfc +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = true +# SparseConnectivityTracer.is_crossder_zero_global(::T) = false +# end # ops_2_to_1_sfz: # ∂f/∂x != 0 @@ -180,14 +179,14 @@ end # ∂²f/∂y² == 0 # ∂²f/∂x∂y == 0 ops_2_to_1_sfz = () -for op in ops_2_to_1_sfz - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = true - SparseConnectivityTracer.is_crossder_zero_global(::T) = true -end +# for op in ops_2_to_1_sfz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = true +# SparseConnectivityTracer.is_crossder_zero_global(::T) = true +# end # ops_2_to_1_fsc: # ∂f/∂x != 0 @@ -218,14 +217,14 @@ SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:/)) = iszero # ∂²f/∂y² != 0 # ∂²f/∂x∂y == 0 ops_2_to_1_fsz = () -for op in ops_2_to_1_fsz - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_crossder_zero_global(::T) = true -end +# for op in ops_2_to_1_fsz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_crossder_zero_global(::T) = true +# end # ops_2_to_1_ffc: # ∂f/∂x != 0 @@ -286,14 +285,14 @@ is_firstder_arg2_zero_local(::typeof(min), x, y) = y > x # ∂²f/∂y² == 0 # ∂²f/∂x∂y == 0 ops_2_to_1_szz = () -for op in ops_2_to_1_szz - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = false - SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = true - SparseConnectivityTracer.is_crossder_zero_global(::T) = true -end +# for op in ops_2_to_1_szz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = false +# SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = true +# SparseConnectivityTracer.is_crossder_zero_global(::T) = true +# end # ops_2_to_1_zsz: # ∂f/∂x == 0 @@ -302,14 +301,14 @@ end # ∂²f/∂y² != 0 # ∂²f/∂x∂y == 0 ops_2_to_1_zsz = () -for op in ops_2_to_1_zsz - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_crossder_zero_global(::T) = true -end +# for op in ops_2_to_1_zsz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_crossder_zero_global(::T) = true +# end # ops_2_to_1_fzz: # ∂f/∂x != 0 @@ -336,14 +335,14 @@ end # ∂²f/∂y² == 0 # ∂²f/∂x∂y == 0 ops_2_to_1_zfz = () -for op in ops_2_to_1_zfz - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = true - SparseConnectivityTracer.is_crossder_zero_global(::T) = true -end +# for op in ops_2_to_1_zfz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_arg1_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_arg1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_arg2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_arg2_zero_global(::T) = true +# SparseConnectivityTracer.is_crossder_zero_global(::T) = true +# end # ops_2_to_1_zfz: # ∂f/∂x == 0 @@ -427,13 +426,13 @@ end # ∂f₂/∂x != 0 # ∂²f₂/∂x² == 0 ops_1_to_2_sf = () -for op in ops_1_to_2_sf - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = false - SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true -end +# for op in ops_1_to_2_sf +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = false +# SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true +# end # ops_1_to_2_sz: # ∂f₁/∂x != 0 @@ -441,13 +440,13 @@ end # ∂f₂/∂x == 0 # ∂²f₂/∂x² == 0 ops_1_to_2_sz = () -for op in ops_1_to_2_sz - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = false - SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true -end +# for op in ops_1_to_2_sz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = false +# SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true +# end # ops_1_to_2_fs: # ∂f₁/∂x != 0 @@ -455,13 +454,13 @@ end # ∂f₂/∂x != 0 # ∂²f₂/∂x² != 0 ops_1_to_2_fs = () -for op in ops_1_to_2_fs - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = false -end +# for op in ops_1_to_2_fs +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = false +# end # ops_1_to_2_ff: # ∂f₁/∂x != 0 @@ -469,13 +468,13 @@ end # ∂f₂/∂x != 0 # ∂²f₂/∂x² == 0 ops_1_to_2_ff = () -for op in ops_1_to_2_ff - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true -end +# for op in ops_1_to_2_ff +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true +# end # ops_1_to_2_fz: # ∂f₁/∂x != 0 @@ -485,13 +484,13 @@ end ops_1_to_2_fz = ( # :frexp, # TODO: removed for now ) -for op in ops_1_to_2_fz - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true -end +# for op in ops_1_to_2_fz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true +# end # ops_1_to_2_zs: # ∂f₁/∂x == 0 @@ -499,13 +498,13 @@ end # ∂f₂/∂x != 0 # ∂²f₂/∂x² != 0 ops_1_to_2_zs = () -for op in ops_1_to_2_zs - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = false -end +# for op in ops_1_to_2_zs +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = false +# end # ops_1_to_2_zf: # ∂f₁/∂x == 0 @@ -513,13 +512,13 @@ end # ∂f₂/∂x != 0 # ∂²f₂/∂x² == 0 ops_1_to_2_zf = () -for op in ops_1_to_2_zf - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false - SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true -end +# for op in ops_1_to_2_zf +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = false +# SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true +# end # ops_1_to_2_zz: # ∂f₁/∂x == 0 @@ -527,13 +526,13 @@ end # ∂f₂/∂x == 0 # ∂²f₂/∂x² == 0 ops_1_to_2_zz = () -for op in ops_1_to_2_zz - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true - SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true -end +# for op in ops_1_to_2_zz +# T = typeof(eval(op)) +# SparseConnectivityTracer.is_firstder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_out1_zero_global(::T) = true +# SparseConnectivityTracer.is_firstder_out2_zero_global(::T) = true +# SparseConnectivityTracer.is_seconder_out2_zero_global(::T) = true +# end ops_1_to_2 = union( ops_1_to_2_ss, From 1124a842a07d4e0cb29d863d0a13958fd182e99a Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 20:43:34 +0200 Subject: [PATCH 07/47] Fix typos --- src/operators.jl | 8 ++++---- src/tracers.jl | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index 139e31af..731d421a 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -208,7 +208,7 @@ for op in ops_2_to_1_fsc end # gradient of x/y: [1/y -x/y²] -SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:/)) = iszero(x) +SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:/), x, y) = iszero(x) # ops_2_to_1_fsz: # ∂f/∂x != 0 @@ -245,9 +245,9 @@ for op in ops_2_to_1_ffc end # gradient of x*y: [y x] -SparseConnectivityTracer.is_firstder_arg1_zero_local(::typeof(Base.:*)) = iszero(y) -SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:*)) = iszero(x) -SparseConnectivityTracer.is_crossder_zero_local(::typeof(Base.:*)) = iszero(x) || iszero(y) +SparseConnectivityTracer.is_firstder_arg1_zero_local(::typeof(Base.:*), x, y) = iszero(y) +SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:*), x, y) = iszero(x) +SparseConnectivityTracer.is_crossder_zero_local(::typeof(Base.:*), x, y) = iszero(x) || iszero(y) # ops_2_to_1_ffz: # ∂f/∂x != 0 diff --git a/src/tracers.jl b/src/tracers.jl index c284bc12..31eaded9 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -203,7 +203,7 @@ end primal(d::Dual) = d.primal -input(d::Dual{P,T}) where {P,T<:ConnectivityTracer} = input(d.tracer) +inputs(d::Dual{P,T}) where {P,T<:ConnectivityTracer} = inputs(d.tracer) gradient(d::Dual{P,T}) where {P,T<:GradientTracer} = gradient(d.tracer) gradient(d::Dual{P,T}) where {P,T<:HessianTracer} = gradient(d.tracer) hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(d.tracer) From a4db8724c0b5f4cb6b8a0be8e5a3615eee6ac2b0 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 20:44:14 +0200 Subject: [PATCH 08/47] Monkey-patch operator classific. on `min`/`max` --- test/classification.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/classification.jl b/test/classification.jl index 1c79dae8..0c95b3bc 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -132,6 +132,10 @@ function classify_2_to_1(f, x, y; atol) return (first_arg, second_arg, cross) end +# Some exceptions have to be manually specified +classify_2_to_1(::typeof(max), x, y; atol) = (first_order, first_order, zero_order) +classify_2_to_1(::typeof(min), x, y; atol) = (first_order, first_order, zero_order) + function classify_2_to_1(op::Symbol; atol=1e-5, trials=100) f = sym2fn(op) try From ce5c1723262ba9972527050588c454df4011bf00 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 20:44:26 +0200 Subject: [PATCH 09/47] Refactor `GradientTracer` overload --- src/overload_gradient.jl | 104 +++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 47 deletions(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 7bd8e9ae..2ce27b35 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -1,55 +1,65 @@ -for fn in union(ops_1_to_1_s, ops_1_to_1_f) - @eval Base.$fn(t::GradientTracer) = t +## 1-to-1 +for fn in ops_1_to_1 + @eval function Base.$fn(t::T) where {T<:GradientTracer} + if is_firstder_zero_global($fn) + return empty(T) + else + return t + end + end end -for fn in union(ops_1_to_1_z, ops_1_to_1_const) - @eval Base.$fn(::T) where {T<:GradientTracer} = empty(T) -end - -for fn in union( - ops_2_to_1_ssc, - ops_2_to_1_ssz, - ops_2_to_1_sfc, - ops_2_to_1_sfz, - ops_2_to_1_fsc, - ops_2_to_1_fsz, - ops_2_to_1_ffc, - ops_2_to_1_ffz, -) - @eval Base.$fn(a::T, b::T) where {T<:GradientTracer} = T(a.grad ∪ b.grad) - @eval Base.$fn(t::GradientTracer, ::Number) = t - @eval Base.$fn(::Number, t::GradientTracer) = t -end - -for fn in union(ops_2_to_1_zsz, ops_2_to_1_zfz) - @eval Base.$fn(::GradientTracer, t::GradientTracer) = t - @eval Base.$fn(::T, ::Number) where {T<:GradientTracer} = empty(T) - @eval Base.$fn(::Number, t::GradientTracer) = t -end -for fn in union(ops_2_to_1_szz, ops_2_to_1_fzz) - @eval Base.$fn(t::GradientTracer, ::GradientTracer) = t - @eval Base.$fn(t::GradientTracer, ::Number) = t - @eval Base.$fn(::Number, ::T) where {T<:GradientTracer} = empty(T) -end -for fn in ops_2_to_1_zzz - @eval Base.$fn(::T, ::T) where {T<:GradientTracer} = empty(T) - @eval Base.$fn(::T, ::Number) where {T<:GradientTracer} = empty(T) - @eval Base.$fn(::Number, ::T) where {T<:GradientTracer} = empty(T) -end +## 2-to-1 +for fn in ops_2_to_1 + @eval function Base.$fn(tx::T, ty::T) where {T<:GradientTracer} + ∂f∂x0 = is_firstder_arg1_zero_global($fn) + ∂f∂y0 = is_firstder_arg2_zero_global($fn) + if ∂f∂x0 + if ∂f∂y0 + return empty(T) + else # ∂f∂y ≠ 0 + return ty + end + else # ∂f∂x ≠ 0 + if ∂f∂y0 + return tx + else # ∂f∂y ≠ 0 + return T(gradient(tx) ∪ gradient(ty)) + end + end + end -for fn in union(ops_1_to_2_ss, ops_1_to_2_sf, ops_1_to_2_fs, ops_1_to_2_ff) - @eval Base.$fn(t::GradientTracer) = (t, t) + @eval function Base.$fn(t::T, ::Number) where {T<:GradientTracer} + if is_firstder_arg1_zero_global($fn) + return empty(T) + else + return t + end + end + @eval function Base.$fn(::Number, t::T) where {T<:GradientTracer} + if is_firstder_arg2_zero_global($fn) + return empty(T) + else + return t + end + end end -for fn in union(ops_1_to_2_sz, ops_1_to_2_fz) - @eval Base.$fn(t::T) where {T<:GradientTracer} = (t, empty(T)) -end - -for fn in union(ops_1_to_2_zs, ops_1_to_2_zf) - @eval Base.$fn(t::T) where {T<:GradientTracer} = (empty(T), t) -end -for fn in ops_1_to_2_zz - @eval Base.$fn(::T) where {T<:GradientTracer} = (empty(T), empty(T)) +## 1-to-2 +for fn in ops_1_to_2 + @eval function Base.$fn(t::T) where {T<:GradientTracer} + g1 = if is_firstder_out1_zero_global($fn) + empty(T) + else + t + end + g2 = if is_firstder_out2_zero_global($fn) + empty(T) + else + t + end + return (g1, g2) + end end # Extra types required for exponent From 6cb875be40bb3bdbd16bf35398ae3b5df6a56e1e Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 21:29:10 +0200 Subject: [PATCH 10/47] First draft of Dual `GradientTracer` --- src/overload_gradient.jl | 75 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 2ce27b35..8ebc7bb1 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -7,6 +7,15 @@ for fn in ops_1_to_1 return t end end + @eval function Base.$fn(t::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(t) + out = Base.$fn(x) + if is_firstder_zero_local($fn, x) + return Dual(out, empty(T)) + else + return Dual(out, t) + end + end end ## 2-to-1 @@ -28,6 +37,27 @@ for fn in ops_2_to_1 end end end + @eval function Base.$fn(tx::D, ty::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(tx) + y = primal(ty) + out = Base.$fn(x, y) + + ∂f∂x0 = is_firstder_arg1_zero_local($fn, x, y) + ∂f∂y0 = is_firstder_arg2_zero_local($fn, x, y) + if ∂f∂x0 + if ∂f∂y0 + return Dual(out, empty(T)) + else # ∂f∂y ≠ 0 + return Dual(out, ty) + end + else # ∂f∂x ≠ 0 + if ∂f∂y0 + return Dual(out, tx) + else # ∂f∂y ≠ 0 + return Dual(out, T(gradient(tx) ∪ gradient(ty))) + end + end + end @eval function Base.$fn(t::T, ::Number) where {T<:GradientTracer} if is_firstder_arg1_zero_global($fn) @@ -36,6 +66,16 @@ for fn in ops_2_to_1 return t end end + @eval function Base.$fn(tx::D, y::Number) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(tx) + out = Base.$fn(x, y) + if is_firstder_arg1_zero_local($fn, x, y) + return Dual(out, empty(T)) + else + return Dual(out, t) + end + end + @eval function Base.$fn(::Number, t::T) where {T<:GradientTracer} if is_firstder_arg2_zero_global($fn) return empty(T) @@ -43,26 +83,53 @@ for fn in ops_2_to_1 return t end end + @eval function Base.$fn(x::Number, ty::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + y = primal(ty) + out = Base.$fn(x, y) + if is_firstder_arg2_zero_local($fn, x, y) + return Dual(out, empty(T)) + else + return Dual(out, t) + end + end end ## 1-to-2 for fn in ops_1_to_2 @eval function Base.$fn(t::T) where {T<:GradientTracer} - g1 = if is_firstder_out1_zero_global($fn) + tracer1 = if is_firstder_out1_zero_global($fn) + empty(T) + else + t + end + tracer2 = if is_firstder_out2_zero_global($fn) + empty(T) + else + t + end + return (tracer1,) + end + + @eval function Base.$fn(tx::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(tx) + out1, out2 = Base.$fn(x) + + tracer1 = if is_firstder_out1_zero_global($fn) empty(T) else t end - g2 = if is_firstder_out2_zero_global($fn) + tracer2 = if is_firstder_out2_zero_global($fn) empty(T) else t end - return (g1, g2) + return (Dual(out1, tracer1), Dual(out2, tracer2)) end end -# Extra types required for exponent +# TODO: support Dual tracers for these. +# Extra types required for exponent for T in (:Real, :Integer, :Rational) @eval Base.:^(t::GradientTracer, ::$T) = t @eval Base.:^(::$T, t::GradientTracer) = t From b1610ca2d586d1ae2f54f48f6b99bbf3d924fe83 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 21:42:06 +0200 Subject: [PATCH 11/47] add draft for local `HessianTracer` --- src/overload_hessian.jl | 101 ++++++++++++++++++++++++++++++++++------ src/tracers.jl | 1 + 2 files changed, 87 insertions(+), 15 deletions(-) diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index b9868d65..2f750ea1 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -8,7 +8,20 @@ for fn in ops_1_to_1 return t end else - return T(t.grad, t.hess ∪ (t.grad × t.grad)) + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + end + end + @eval function Base.$fn(t::D) where {P,T<:HessianTracer,D<:Dual{P,T}} + x = primal(t) + out = Base.$fn(x) + if is_seconder_zero_local($fn, x) + if is_firstder_zero_local($fn, x) + return Dual(out, empty(T)) + else + return Dual(out, t) + end + else + return Dual(out, T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t)))) end end end @@ -19,24 +32,50 @@ for fn in ops_2_to_1 grad = empty(G) hess = empty(H) if !is_firstder_arg1_zero_global($fn) - grad = union(grad, a.grad) # TODO: use union! - union!(hess, a.hess) + grad = union(grad, gradient(a)) # TODO: use union! + union!(hess, hessian(a)) end if !is_firstder_arg2_zero_global($fn) - grad = union(grad, b.grad) # TODO: use union! - union!(hess, b.hess) + grad = union(grad, gradient(b)) # TODO: use union! + union!(hess, hessian(b)) end if !is_seconder_arg1_zero_global($fn) - union!(hess, a.grad × a.grad) + union!(hess, gradient(a) × gradient(a)) end if !is_seconder_arg2_zero_global($fn) - union!(hess, b.grad × b.grad) + union!(hess, gradient(b) × gradient(b)) end if !is_crossder_zero_global($fn) - union!(hess, (a.grad × b.grad) ∪ (b.grad × a.grad)) + union!(hess, (gradient(a) × gradient(b)) ∪ (gradient(b) × gradient(a))) end return T(grad, hess) end + @eval function Base.$fn(a::D, b::D) where {P,G,H,T<:HessianTracer{G,H},D<:Dual{P,T}} + x = primal(a) + y = primal(b) + out = Base.$fn(x, y) + + grad = empty(G) + hess = empty(H) + if !is_firstder_arg1_zero_local($fn, x, y) + grad = union(grad, gradient(a)) # TODO: use union! + union!(hess, hessian(a)) + end + if !is_firstder_arg2_zero_local($fn, x, y) + grad = union(grad, gradient(b)) # TODO: use union! + union!(hess, hessian(b)) + end + if !is_seconder_arg1_zero_local($fn, x, y) + union!(hess, gradient(a) × gradient(a)) + end + if !is_seconder_arg2_zero_local($fn, x, y) + union!(hess, gradient(b) × gradient(b)) + end + if !is_crossder_zero_local($fn, x, y) + union!(hess, (gradient(a) × gradient(b)) ∪ (gradient(b) × gradient(a))) + end + return Dual(out, T(grad, hess)) + end @eval function Base.$fn(t::T, ::Number) where {G,H,T<:HessianTracer{G,H}} if is_seconder_arg1_zero_global($fn) @@ -46,10 +85,26 @@ for fn in ops_2_to_1 return t end else - return T(t.grad, t.hess ∪ (t.grad × t.grad)) + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + end + end + @eval function Base.$fn( + t::D, y::Number + ) where {P,G,H,T<:HessianTracer{G,H},D<:Dual{P,T}} + x = primal(t) + out = Base.$fn(x, y) + if is_seconder_arg1_zero_local($fn, x, y) + if is_firstder_arg1_zero_local($fn, x, y) + return Dual(out, empty(T)) + else + return Dual(out, t) + end + else + return Dual(out, T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t)))) end end - @eval function Base.$fn(::Number, t::T) where {G,H,T<:HessianTracer{G,H}} + + @eval function Base.$fn(x::Number, t::T) where {G,H,T<:HessianTracer{G,H}} if is_seconder_arg2_zero_global($fn) if is_firstder_arg2_zero_global($fn) return empty(T) @@ -57,25 +112,41 @@ for fn in ops_2_to_1 return t end else - return T(t.grad, t.hess ∪ (t.grad × t.grad)) + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + end + end + @eval function Base.$fn( + x::Number, t::D + ) where {P,G,H,T<:HessianTracer{G,H},D<:Dual{P,T}} + y = primal(t) + out = Base.$fn(x, y) + if is_seconder_arg2_zero_local($fn, x, y) + if is_firstder_arg2_zero_local($fn, x, y) + return Dual(out, empty(T)) + else + return Dual(out, t) + end + else + return Dual(out, T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t)))) end end end +# TODO: support Dual tracers for these. # Extra types required for exponent for T in (:Real, :Integer, :Rational) @eval function Base.:^(t::T, ::$T) where {T<:HessianTracer} - return T(t.grad, t.hess ∪ (t.grad × t.grad)) + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end @eval function Base.:^(::$T, t::T) where {T<:HessianTracer} - return T(t.grad, t.hess ∪ (t.grad × t.grad)) + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end end function Base.:^(t::T, ::Irrational{:ℯ}) where {T<:HessianTracer} - return T(t.grad, t.hess ∪ (t.grad × t.grad)) + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end function Base.:^(::Irrational{:ℯ}, t::T) where {T<:HessianTracer} - return T(t.grad, t.hess ∪ (t.grad × t.grad)) + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end ## Rounding diff --git a/src/tracers.jl b/src/tracers.jl index 31eaded9..65e26486 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -202,6 +202,7 @@ struct Dual{P<:Number,T<:AbstractTracer} <: AbstractTracer end primal(d::Dual) = d.primal +tracer(d::Dual) = d.tracer inputs(d::Dual{P,T}) where {P,T<:ConnectivityTracer} = inputs(d.tracer) gradient(d::Dual{P,T}) where {P,T<:GradientTracer} = gradient(d.tracer) From 2a0e77de0c556f45e1a28d662a0b3b8c32f2ed50 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 21:46:30 +0200 Subject: [PATCH 12/47] Add Hessian overloads on functions 1-to-2 --- src/overload_gradient.jl | 2 +- src/overload_hessian.jl | 50 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 8ebc7bb1..2fd8cecf 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -107,7 +107,7 @@ for fn in ops_1_to_2 else t end - return (tracer1,) + return (tracer1, tracer2) end @eval function Base.$fn(tx::D) where {P,T<:GradientTracer,D<:Dual{P,T}} diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index 2f750ea1..83058706 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -132,6 +132,56 @@ for fn in ops_2_to_1 end end +## 1-to-2 +for fn in ops_1_to_2 + @eval function Base.$fn(t::T) where {T<:HessianTracer} + tracer1 = if is_seconder_out1_zero_global($fn) + if is_firstder_out1_zero_global($fn) + return empty(T) + else + return t + end + else + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + end + tracer2 = if is_seconder_out2_zero_global($fn) + if is_firstder_out2_zero_global($fn) + return empty(T) + else + return t + end + else + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + end + return (tracer1, tracer2) + end + + @eval function Base.$fn(tx::D) where {P,T<:HessianTracer,D<:Dual{P,T}} + x = primal(tx) + out1, out2 = Base.$fn(x) + + tracer1 = if is_seconder_out1_zero_local($fn, x) + if is_firstder_out1_zero_local($fn, x) + return empty(T) + else + return t + end + else + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + end + tracer2 = if is_seconder_out2_zero_local($fn, x) + if is_firstder_out2_zero_local($fn, x) + return empty(T) + else + return t + end + else + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + end + return (Dual(out1, tracer1), Dual(out2, tracer2)) + end +end + # TODO: support Dual tracers for these. # Extra types required for exponent for T in (:Real, :Integer, :Rational) From 2c19273d92f3d83b492db21936229dcae1372e7e Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 21:54:29 +0200 Subject: [PATCH 13/47] Update pattern --- src/pattern.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pattern.jl b/src/pattern.jl index b309a741..f9dcb3ed 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -86,7 +86,7 @@ function connectivity_pattern_to_mat( V = Bool[] # values for (i, y) in enumerate(yt) if y isa T - for j in y.inputs + for j in inputs(y) push!(I, i) push!(J, j) push!(V, true) @@ -145,7 +145,7 @@ function jacobian_pattern_to_mat( V = Bool[] # values for (i, y) in enumerate(yt) if y isa T - for j in y.grad + for j in gradient(y) push!(I, i) push!(J, j) push!(V, true) @@ -205,7 +205,7 @@ function hessian_pattern_to_mat( J = Int[] # column indices V = Bool[] # values - for (i, j) in yt.hess + for (i, j) in hessian(yt) push!(I, i) push!(J, j) push!(V, true) From cd5c3abb9ee0add6cd7cb343a93f3e64e1550a7c Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 15 May 2024 22:41:42 +0200 Subject: [PATCH 14/47] Get local `GradientTracer` to work --- src/SparseConnectivityTracer.jl | 4 +-- src/operators.jl | 1 - src/overload_connectivity.jl | 2 ++ src/overload_gradient.jl | 14 ++++----- src/overload_hessian.jl | 8 ++--- src/pattern.jl | 55 ++++++++++++++++++++++++++++++--- src/tracers.jl | 15 ++++++--- 7 files changed, 75 insertions(+), 24 deletions(-) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index d8c8e38c..c5d3a657 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -20,8 +20,8 @@ include("pattern.jl") include("adtypes.jl") export connectivity_pattern -export jacobian_pattern -export hessian_pattern +export jacobian_pattern, local_jacobian_pattern +export hessian_pattern, local_hessian_pattern # ADTypes interface export TracerSparsityDetector diff --git a/src/operators.jl b/src/operators.jl index 731d421a..7210483e 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -247,7 +247,6 @@ end # gradient of x*y: [y x] SparseConnectivityTracer.is_firstder_arg1_zero_local(::typeof(Base.:*), x, y) = iszero(y) SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:*), x, y) = iszero(x) -SparseConnectivityTracer.is_crossder_zero_local(::typeof(Base.:*), x, y) = iszero(x) || iszero(y) # ops_2_to_1_ffz: # ∂f/∂x != 0 diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl index e213c293..8ba7ccb6 100644 --- a/src/overload_connectivity.jl +++ b/src/overload_connectivity.jl @@ -1,3 +1,5 @@ +# TODO: support Duals + for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z) @eval Base.$fn(t::ConnectivityTracer) = t end diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 2fd8cecf..8194658a 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -13,7 +13,7 @@ for fn in ops_1_to_1 if is_firstder_zero_local($fn, x) return Dual(out, empty(T)) else - return Dual(out, t) + return Dual(out, tracer(t)) end end end @@ -48,11 +48,11 @@ for fn in ops_2_to_1 if ∂f∂y0 return Dual(out, empty(T)) else # ∂f∂y ≠ 0 - return Dual(out, ty) + return Dual(out, tracer(ty)) end else # ∂f∂x ≠ 0 if ∂f∂y0 - return Dual(out, tx) + return Dual(out, tracer(tx)) else # ∂f∂y ≠ 0 return Dual(out, T(gradient(tx) ∪ gradient(ty))) end @@ -72,7 +72,7 @@ for fn in ops_2_to_1 if is_firstder_arg1_zero_local($fn, x, y) return Dual(out, empty(T)) else - return Dual(out, t) + return Dual(out, tracer(tx)) end end @@ -89,7 +89,7 @@ for fn in ops_2_to_1 if is_firstder_arg2_zero_local($fn, x, y) return Dual(out, empty(T)) else - return Dual(out, t) + return Dual(out, tracer(ty)) end end end @@ -117,12 +117,12 @@ for fn in ops_1_to_2 tracer1 = if is_firstder_out1_zero_global($fn) empty(T) else - t + tracer(tx) end tracer2 = if is_firstder_out2_zero_global($fn) empty(T) else - t + tracer(tx) end return (Dual(out1, tracer1), Dual(out2, tracer2)) end diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index 83058706..e15ee829 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -97,7 +97,7 @@ for fn in ops_2_to_1 if is_firstder_arg1_zero_local($fn, x, y) return Dual(out, empty(T)) else - return Dual(out, t) + return Dual(out, tracer(t)) end else return Dual(out, T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t)))) @@ -124,7 +124,7 @@ for fn in ops_2_to_1 if is_firstder_arg2_zero_local($fn, x, y) return Dual(out, empty(T)) else - return Dual(out, t) + return Dual(out, tracer(t)) end else return Dual(out, T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t)))) @@ -164,7 +164,7 @@ for fn in ops_1_to_2 if is_firstder_out1_zero_local($fn, x) return empty(T) else - return t + return tracer(t) end else return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) @@ -173,7 +173,7 @@ for fn in ops_1_to_2 if is_firstder_out2_zero_local($fn, x) return empty(T) else - return t + return tracer(t) end else return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) diff --git a/src/pattern.jl b/src/pattern.jl index f9dcb3ed..12e2632c 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -11,10 +11,13 @@ Enumerates input indices and constructs the specified type `T` of tracer. Supports [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref). """ trace_input(::Type{T}, x) where {T<:AbstractTracer} = trace_input(T, x, 1) -trace_input(::Type{T}, ::Number, i) where {T<:AbstractTracer} = tracer(T, i) -function trace_input(::Type{T}, x::AbstractArray, i) where {T<:AbstractTracer} - indices = (i - 1) .+ reshape(1:length(x), size(x)) - return tracer.(T, indices) + +function trace_input(::Type{T}, x::Number, i::Integer) where {T<:AbstractTracer} + return create_tracer(T, x, i) +end +function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:AbstractTracer} + indices = reshape(1:length(xs), size(xs)) .+ (i - 1) + return create_tracer.(T, xs, indices) end ## Trace function @@ -123,6 +126,34 @@ function jacobian_pattern(f, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end +""" + local_jacobian_pattern(f, x) + local_jacobian_pattern(f, x, T) + +Compute the local sparsity pattern of the Jacobian of `y = f(x)` at `x`. + +The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. + +## Example + +```jldoctest +julia> x = rand(3); + +julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])]; + +julia> local_jacobian_pattern(f, x) +3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: + 1 ⋅ ⋅ + 1 1 ⋅ + ⋅ ⋅ ⋅ +``` +""" +function local_jacobian_pattern(f, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} + D = Dual{eltype(x),GradientTracer{G}} + xt, yt = trace_function(D, f, x) + return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) +end + """ jacobian_pattern(f!, y, x) jacobian_pattern(f!, y, x, T) @@ -136,9 +167,23 @@ function jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end +""" + local_jacobian_pattern(f!, y, x) + local_jacobian_pattern(f!, y, x, T) + +Compute the local sparsity pattern of the Jacobian of `f!(y, x)` at `x`. + +The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. +""" +function local_jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} + D = Dual{eltype(x),GradientTracer{G}} + xt, yt = trace_function(D, f!, y, x) + return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) +end + function jacobian_pattern_to_mat( xt::AbstractArray{T}, yt::AbstractArray{<:Number} -) where {T<:GradientTracer} +) where {P,G<:GradientTracer,T<:Union{G,Dual{P,G}}} n, m = length(xt), length(yt) I = Int[] # row indices J = Int[] # column indices diff --git a/src/tracers.jl b/src/tracers.jl index 65e26486..6ba9062c 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -196,10 +196,11 @@ Dual number type keeping track of the results of a primal computation as well as ## Fields $(TYPEDFIELDS) """ -struct Dual{P<:Number,T<:AbstractTracer} <: AbstractTracer +struct Dual{P<:Number,T<:Union{GradientTracer,HessianTracer}} <: AbstractTracer primal::P tracer::T end +# TODO: support ConnectivityTracer primal(d::Dual) = d.primal tracer(d::Dual) = d.tracer @@ -214,16 +215,20 @@ hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(d.tracer) #===========# """ - tracer(T, index) where {T<:AbstractTracer} + create_tracer(T, index) where {T<:AbstractTracer} Convenience constructor for [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices. """ -function tracer(::Type{GradientTracer{G}}, index::Integer) where {G} +function create_tracer(::Type{Dual{P,T}}, primal::Number, index::Integer) where {P,T} + return Dual(primal, create_tracer(T, primal, index)) +end + +function create_tracer(::Type{GradientTracer{G}}, ::Number, index::Integer) where {G} return GradientTracer{G}(sparse_vector(G, index)) end -function tracer(::Type{ConnectivityTracer{C}}, index::Integer) where {C} +function create_tracer(::Type{ConnectivityTracer{C}}, ::Number, index::Integer) where {C} return ConnectivityTracer{C}(sparse_vector(C, index)) end -function tracer(::Type{HessianTracer{G,H}}, index::Integer) where {G,H} +function create_tracer(::Type{HessianTracer{G,H}}, ::Number, index::Integer) where {G,H} return HessianTracer{G,H}(sparse_vector(G, index), empty(H)) end From 6c462ae28ca04ed8b4722f3c308bab0a7eae2e75 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 12:03:06 +0200 Subject: [PATCH 15/47] Remove overload on `rem` --- src/operators.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operators.jl b/src/operators.jl index 7210483e..1f52deb0 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -269,7 +269,6 @@ for op in ops_2_to_1_ffz end is_firstder_arg2_zero_local(::typeof(mod), x, y) = ifelse(y > 0, y > x, x > y) -is_firstder_arg2_zero_local(::typeof(rem), x, y) = ifelse(y > 0, y > x, x > y) is_firstder_arg1_zero_local(::typeof(max), x, y) = x < y is_firstder_arg2_zero_local(::typeof(max), x, y) = y < x From 2f12b8f32098b659584785db9350128768e18608 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 12:03:56 +0200 Subject: [PATCH 16/47] Help Julia specialize --- src/operators.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index 1f52deb0..39f6084b 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -15,8 +15,8 @@ function is_firstder_zero_global end function is_seconder_zero_global end # Fallbacks for local derivatives: -is_firstder_zero_local(f, x) = is_firstder_zero_global(f) -is_seconder_zero_local(f, x) = is_seconder_zero_global(f) +is_firstder_zero_local(f::F, x) where {F} = is_firstder_zero_global(f) +is_seconder_zero_local(f::F, x) where {F} = is_seconder_zero_global(f) # ops_1_to_1_s: # ∂f/∂x != 0 @@ -116,11 +116,11 @@ function is_seconder_arg2_zero_global end function is_crossder_zero_global end # Fallbacks for local derivatives: -is_firstder_arg1_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) -is_seconder_arg1_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) -is_firstder_arg2_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) -is_seconder_arg2_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) -is_crossder_zero_local(f, x, y) = is_firstder_arg1_zero_global(f) +is_firstder_arg1_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) +is_seconder_arg1_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) +is_firstder_arg2_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) +is_seconder_arg2_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) +is_crossder_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) # ops_2_to_1_ssc: # ∂f/∂x != 0 @@ -394,10 +394,10 @@ function is_firstder_out2_zero_global end function is_seconder_out2_zero_global end # Fallbacks for local derivatives: -is_seconder_out1_zero_local(f, x) = is_seconder_out1_zero_global(f) -is_firstder_out1_zero_local(f, x) = is_firstder_out1_zero_global(f) -is_firstder_out2_zero_local(f, x) = is_firstder_out2_zero_global(f) -is_seconder_out2_zero_local(f, x) = is_seconder_out2_zero_global(f) +is_seconder_out1_zero_local(f::F, x) where {F} = is_seconder_out1_zero_global(f) +is_firstder_out1_zero_local(f::F, x) where {F} = is_firstder_out1_zero_global(f) +is_firstder_out2_zero_local(f::F, x) where {F} = is_firstder_out2_zero_global(f) +is_seconder_out2_zero_local(f::F, x) where {F} = is_seconder_out2_zero_global(f) # ops_1_to_2_ss: From c3a14aaad896ef1eca2eed0136d6427637e910aa Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 13:46:43 +0200 Subject: [PATCH 17/47] Refactor `GradientTracer` overloads --- src/overload_gradient.jl | 187 +++++++++++++++++++-------------------- 1 file changed, 90 insertions(+), 97 deletions(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 8194658a..763e561e 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -1,130 +1,123 @@ ## 1-to-1 +function gradient_tracer_1_to_1(t::T, is_firstder_zero::Bool) where {T<:GradientTracer} + if is_firstder_zero + return empty(T) + else + return t + end +end + for fn in ops_1_to_1 @eval function Base.$fn(t::T) where {T<:GradientTracer} - if is_firstder_zero_global($fn) - return empty(T) - else - return t - end + return gradient_tracer_1_to_1(t, is_firstder_zero_global($fn)) end @eval function Base.$fn(t::D) where {P,T<:GradientTracer,D<:Dual{P,T}} x = primal(t) - out = Base.$fn(x) - if is_firstder_zero_local($fn, x) - return Dual(out, empty(T)) + p_out = Base.$fn(x) + t_out = gradient_tracer_1_to_1(tracer(t), is_firstder_zero_local($fn, x)) + return Dual(p_out, t_out) + end +end + +## 2-to-1 +function gradient_tracer_2_to_1( + tx::T, + ty::T, + is_firstder_arg1_zero_or_number::Bool, + is_firstder_arg2_zero_or_number::Bool, +) where {T<:GradientTracer} + if is_firstder_arg1_zero_or_number + if is_firstder_arg2_zero_or_number + return empty(T) else - return Dual(out, tracer(t)) + return ty + end + else # ∂f∂x ≠ 0 + if is_firstder_arg2_zero_or_number + return tx + else + return T(gradient(tx) ∪ gradient(ty)) end end end -## 2-to-1 +function gradient_tracer_2_to_1_one_tracer( + t::T, is_firstder_zero::Bool +) where {T<:GradientTracer} + if is_firstder_zero + return empty(T) + else + return t + end +end + for fn in ops_2_to_1 @eval function Base.$fn(tx::T, ty::T) where {T<:GradientTracer} - ∂f∂x0 = is_firstder_arg1_zero_global($fn) - ∂f∂y0 = is_firstder_arg2_zero_global($fn) - if ∂f∂x0 - if ∂f∂y0 - return empty(T) - else # ∂f∂y ≠ 0 - return ty - end - else # ∂f∂x ≠ 0 - if ∂f∂y0 - return tx - else # ∂f∂y ≠ 0 - return T(gradient(tx) ∪ gradient(ty)) - end - end + return gradient_tracer_2_to_1( + tx, ty, is_firstder_arg1_zero_global($fn), is_firstder_arg2_zero_global($fn) + ) end - @eval function Base.$fn(tx::D, ty::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(tx) - y = primal(ty) - out = Base.$fn(x, y) - - ∂f∂x0 = is_firstder_arg1_zero_local($fn, x, y) - ∂f∂y0 = is_firstder_arg2_zero_local($fn, x, y) - if ∂f∂x0 - if ∂f∂y0 - return Dual(out, empty(T)) - else # ∂f∂y ≠ 0 - return Dual(out, tracer(ty)) - end - else # ∂f∂x ≠ 0 - if ∂f∂y0 - return Dual(out, tracer(tx)) - else # ∂f∂y ≠ 0 - return Dual(out, T(gradient(tx) ∪ gradient(ty))) - end - end + @eval function Base.$fn(dx::D, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(dx) + y = primal(dy) + p_out = Base.$fn(x, y) + t_out = gradient_tracer_2_to_1( + tracer(dx), + tracer(dy), + is_firstder_arg1_zero_local($fn, x, y), + is_firstder_arg2_zero_local($fn, x, y), + ) + return Dual(p_out, t_out) end @eval function Base.$fn(t::T, ::Number) where {T<:GradientTracer} - if is_firstder_arg1_zero_global($fn) - return empty(T) - else - return t - end + return gradient_tracer_2_to_1_one_tracer(t, is_firstder_arg1_zero_global($fn)) end - @eval function Base.$fn(tx::D, y::Number) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(tx) - out = Base.$fn(x, y) - if is_firstder_arg1_zero_local($fn, x, y) - return Dual(out, empty(T)) - else - return Dual(out, tracer(tx)) - end + @eval function Base.$fn(dx::D, y::Number) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(dx) + p_out = Base.$fn(x, y) + t_out = gradient_tracer_2_to_1_one_tracer( + tracer(dx), is_firstder_arg1_zero_local($fn, x, y) + ) + return Dual(p_out, t_out) end @eval function Base.$fn(::Number, t::T) where {T<:GradientTracer} - if is_firstder_arg2_zero_global($fn) - return empty(T) - else - return t - end + return gradient_tracer_2_to_1_one_tracer(t, is_firstder_arg2_zero_global($fn)) end - @eval function Base.$fn(x::Number, ty::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - y = primal(ty) - out = Base.$fn(x, y) - if is_firstder_arg2_zero_local($fn, x, y) - return Dual(out, empty(T)) - else - return Dual(out, tracer(ty)) - end + @eval function Base.$fn(x::Number, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + y = primal(dy) + p_out = Base.$fn(x, y) + t_out = gradient_tracer_2_to_1_one_tracer( + tracer(dx), is_firstder_arg2_zero_local($fn, x, y) + ) + return Dual(p_out, t_out) end end ## 1-to-2 +function gradient_tracer_1_to_2( + t::T, is_firstder_out1_zero::Bool, is_firstder_out2_zero::Bool +) where {T<:GradientTracer} + t1 = gradient_tracer_1_to_1(t, is_firstder_out1_zero) + t2 = gradient_tracer_1_to_1(t, is_firstder_out2_zero) + return (t1, t2) +end + for fn in ops_1_to_2 @eval function Base.$fn(t::T) where {T<:GradientTracer} - tracer1 = if is_firstder_out1_zero_global($fn) - empty(T) - else - t - end - tracer2 = if is_firstder_out2_zero_global($fn) - empty(T) - else - t - end - return (tracer1, tracer2) + return gradient_tracer_1_to_2( + t, is_firstder_out1_zero_global($fn), is_firstder_out2_zero_global($fn) + ) end - @eval function Base.$fn(tx::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(tx) - out1, out2 = Base.$fn(x) - - tracer1 = if is_firstder_out1_zero_global($fn) - empty(T) - else - tracer(tx) - end - tracer2 = if is_firstder_out2_zero_global($fn) - empty(T) - else - tracer(tx) - end - return (Dual(out1, tracer1), Dual(out2, tracer2)) + @eval function Base.$fn(dx::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(dx) + p1_out, p2_out = Base.$fn(x) + t1_out, t2_out = gradient_tracer_1_to_2( + t, is_firstder_out1_zero_local($fn, x), is_firstder_out2_zero_local($fn, x) + ) + return (Dual(p1_out, t1_out), Dual(p1_out, t1_out)) end end From 5fc8218b41ebd23a2442383936088fe873c41819 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 14:18:18 +0200 Subject: [PATCH 18/47] Add local first order tests --- test/first_order.jl | 120 +++++++++++++++++++++++++++----------------- 1 file changed, 74 insertions(+), 46 deletions(-) diff --git a/test/first_order.jl b/test/first_order.jl index 71b3aff1..560afcf9 100644 --- a/test/first_order.jl +++ b/test/first_order.jl @@ -4,52 +4,80 @@ using SparseConnectivityTracer: using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using Test -@testset "Set type $G" for G in ( +const FIRST_ORDER_SET_TYPES = ( BitSet, Set{UInt64}, DuplicateVector{UInt64}, RecursiveSet{UInt64}, SortedVector{UInt64} ) - CT = ConnectivityTracer{G} - JT = GradientTracer{G} - - x = rand(3) - xt = trace_input(CT, x) - - # Matrix multiplication - A = rand(1, 3) - yt = only(A * xt) - @test connectivity_pattern(x -> only(A * x), x, G) ≈ [1 1 1] - - # Custom functions - f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] - yt = f(xt) - - @test connectivity_pattern(f, x, G) ≈ [1 0 0; 1 1 0; 0 0 1] - @test jacobian_pattern(f, x, G) ≈ [1 0 0; 1 1 0; 0 0 1] - - @test connectivity_pattern(identity, rand(), G) ≈ [1;;] - @test jacobian_pattern(identity, rand(), G) ≈ [1;;] - @test connectivity_pattern(Returns(1), 1, G) ≈ [0;;] - @test jacobian_pattern(Returns(1), 1, G) ≈ [0;;] - - # Test GradientTracer on functions with zero derivatives - x = rand(2) - g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] - @test connectivity_pattern(g, x, G) ≈ [1 1; 1 1; 1 1] - @test jacobian_pattern(g, x, G) ≈ [1 1; 0 0; 1 0] - - # Code coverage - @test connectivity_pattern(x -> [sincos(x)...], 1, G) ≈ [1; 1] - @test connectivity_pattern(typemax, 1, G) ≈ [0;;] - @test connectivity_pattern(x -> x^(2//3), 1, G) ≈ [1;;] - @test connectivity_pattern(x -> (2//3)^x, 1, G) ≈ [1;;] - @test connectivity_pattern(x -> x^ℯ, 1, G) ≈ [1;;] - @test connectivity_pattern(x -> ℯ^x, 1, G) ≈ [1;;] - @test connectivity_pattern(x -> round(x, RoundNearestTiesUp), 1, G) ≈ [1;;] - - @test jacobian_pattern(x -> [sincos(x)...], 1, G) ≈ [1; 1] - @test jacobian_pattern(typemax, 1, G) ≈ [0;;] - @test jacobian_pattern(x -> x^(2//3), 1, G) ≈ [1;;] - @test jacobian_pattern(x -> (2//3)^x, 1, G) ≈ [1;;] - @test jacobian_pattern(x -> x^ℯ, 1, G) ≈ [1;;] - @test jacobian_pattern(x -> ℯ^x, 1, G) ≈ [1;;] - @test jacobian_pattern(x -> round(x, RoundNearestTiesUp), 1, G) ≈ [0;;] + +@testset "Global" begin + @testset "Set type $G" for G in FIRST_ORDER_SET_TYPES + CT = ConnectivityTracer{G} + JT = GradientTracer{G} + + x = rand(3) + xt = trace_input(CT, x) + + # Matrix multiplication + A = rand(1, 3) + yt = only(A * xt) + @test connectivity_pattern(x -> only(A * x), x, G) ≈ [1 1 1] + + # Custom functions + f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] + yt = f(xt) + + @test connectivity_pattern(f, x, G) ≈ [1 0 0; 1 1 0; 0 0 1] + @test jacobian_pattern(f, x, G) ≈ [1 0 0; 1 1 0; 0 0 1] + + @test connectivity_pattern(identity, rand(), G) ≈ [1;;] + @test jacobian_pattern(identity, rand(), G) ≈ [1;;] + @test connectivity_pattern(Returns(1), 1, G) ≈ [0;;] + @test jacobian_pattern(Returns(1), 1, G) ≈ [0;;] + + # Test GradientTracer on functions with zero derivatives + x = rand(2) + g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] + @test connectivity_pattern(g, x, G) ≈ [1 1; 1 1; 1 1] + @test jacobian_pattern(g, x, G) ≈ [1 1; 0 0; 1 0] + + # Code coverage + @test connectivity_pattern(x -> [sincos(x)...], 1, G) ≈ [1; 1] + @test connectivity_pattern(typemax, 1, G) ≈ [0;;] + @test connectivity_pattern(x -> x^(2//3), 1, G) ≈ [1;;] + @test connectivity_pattern(x -> (2//3)^x, 1, G) ≈ [1;;] + @test connectivity_pattern(x -> x^ℯ, 1, G) ≈ [1;;] + @test connectivity_pattern(x -> ℯ^x, 1, G) ≈ [1;;] + @test connectivity_pattern(x -> round(x, RoundNearestTiesUp), 1, G) ≈ [1;;] + + @test jacobian_pattern(x -> [sincos(x)...], 1, G) ≈ [1; 1] + @test jacobian_pattern(typemax, 1, G) ≈ [0;;] + @test jacobian_pattern(x -> x^(2//3), 1, G) ≈ [1;;] + @test jacobian_pattern(x -> (2//3)^x, 1, G) ≈ [1;;] + @test jacobian_pattern(x -> x^ℯ, 1, G) ≈ [1;;] + @test jacobian_pattern(x -> ℯ^x, 1, G) ≈ [1;;] + @test jacobian_pattern(x -> round(x, RoundNearestTiesUp), 1, G) ≈ [0;;] + end +end + +@testset "Local" verbose = true begin + @testset "Set type $G" for G in FIRST_ORDER_SET_TYPES + # Multiplication + @test local_jacobian_pattern(x -> x[1] * x[2], [1.0, 1.0], G) ≈ [1 1;] + @test local_jacobian_pattern(x -> x[1] * x[2], [1.0, 0.0], G) ≈ [0 1;] + @test local_jacobian_pattern(x -> x[1] * x[2], [0.0, 1.0], G) ≈ [1 0;] + @test local_jacobian_pattern(x -> x[1] * x[2], [0.0, 0.0], G) ≈ [0 0;] + + # Division + @test local_jacobian_pattern(x -> x[1] / x[2], [1.0, 1.0], G) ≈ [1 1;] + @test local_jacobian_pattern(x -> x[1] / x[2], [0.0, 0.0], G) ≈ [1 0;] + + # Maximum + @test local_jacobian_pattern(x -> max(x[1], x[2]), [1.0, 2.0], G) ≈ [0 1;] + @test local_jacobian_pattern(x -> max(x[1], x[2]), [2.0, 1.0], G) ≈ [1 0;] + @test local_jacobian_pattern(x -> max(x[1], x[2]), [1.0, 1.0], G) ≈ [1 1;] + + # Minimum + @test local_jacobian_pattern(x -> min(x[1], x[2]), [1.0, 2.0], G) ≈ [1 0;] + @test local_jacobian_pattern(x -> min(x[1], x[2]), [2.0, 1.0], G) ≈ [0 1;] + @test local_jacobian_pattern(x -> min(x[1], x[2]), [1.0, 1.0], G) ≈ [1 1;] + end end From a9cffc837956e71889cb7a7189766eca10b213a0 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 14:53:54 +0200 Subject: [PATCH 19/47] Minor fixes --- src/overload_gradient.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 763e561e..0444e344 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -11,10 +11,10 @@ for fn in ops_1_to_1 @eval function Base.$fn(t::T) where {T<:GradientTracer} return gradient_tracer_1_to_1(t, is_firstder_zero_global($fn)) end - @eval function Base.$fn(t::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(t) + @eval function Base.$fn(d::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(d) p_out = Base.$fn(x) - t_out = gradient_tracer_1_to_1(tracer(t), is_firstder_zero_local($fn, x)) + t_out = gradient_tracer_1_to_1(tracer(d), is_firstder_zero_local($fn, x)) return Dual(p_out, t_out) end end @@ -70,8 +70,8 @@ for fn in ops_2_to_1 return Dual(p_out, t_out) end - @eval function Base.$fn(t::T, ::Number) where {T<:GradientTracer} - return gradient_tracer_2_to_1_one_tracer(t, is_firstder_arg1_zero_global($fn)) + @eval function Base.$fn(tx::T, ::Number) where {T<:GradientTracer} + return gradient_tracer_2_to_1_one_tracer(tx, is_firstder_arg1_zero_global($fn)) end @eval function Base.$fn(dx::D, y::Number) where {P,T<:GradientTracer,D<:Dual{P,T}} x = primal(dx) @@ -82,8 +82,8 @@ for fn in ops_2_to_1 return Dual(p_out, t_out) end - @eval function Base.$fn(::Number, t::T) where {T<:GradientTracer} - return gradient_tracer_2_to_1_one_tracer(t, is_firstder_arg2_zero_global($fn)) + @eval function Base.$fn(::Number, ty::T) where {T<:GradientTracer} + return gradient_tracer_2_to_1_one_tracer(ty, is_firstder_arg2_zero_global($fn)) end @eval function Base.$fn(x::Number, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} y = primal(dy) @@ -111,8 +111,8 @@ for fn in ops_1_to_2 ) end - @eval function Base.$fn(dx::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(dx) + @eval function Base.$fn(d::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(d) p1_out, p2_out = Base.$fn(x) t1_out, t2_out = gradient_tracer_1_to_2( t, is_firstder_out1_zero_local($fn, x), is_firstder_out2_zero_local($fn, x) From efcd18ea1653ae0f6b50eb6038a2d69c250e05db Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 14:54:06 +0200 Subject: [PATCH 20/47] Test on `logdet` (#68) --- test/first_order.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/first_order.jl b/test/first_order.jl index 560afcf9..4d6e1f09 100644 --- a/test/first_order.jl +++ b/test/first_order.jl @@ -2,6 +2,8 @@ using SparseConnectivityTracer using SparseConnectivityTracer: ConnectivityTracer, GradientTracer, tracer, trace_input, empty using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector + +using LinearAlgebra: det, logdet using Test const FIRST_ORDER_SET_TYPES = ( @@ -79,5 +81,9 @@ end @test local_jacobian_pattern(x -> min(x[1], x[2]), [1.0, 2.0], G) ≈ [1 0;] @test local_jacobian_pattern(x -> min(x[1], x[2]), [2.0, 1.0], G) ≈ [0 1;] @test local_jacobian_pattern(x -> min(x[1], x[2]), [1.0, 1.0], G) ≈ [1 1;] + + # Linear algebra + @test local_jacobian_pattern(logdet, [1.0 0.0; 2.0 2.0], G) ≈ [1 1; 1 1] # (#68) + @test local_jacobian_pattern(x -> log(det(x)), [1.0 0.0; 2.0 2.0], G) ≈ [1 1; 1 1] end end From d9566bbd39f16a889abf6677bbc2ff895637d9fb Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 14:54:13 +0200 Subject: [PATCH 21/47] Overload comparisons --- src/overload_comparisons.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 src/overload_comparisons.jl diff --git a/src/overload_comparisons.jl b/src/overload_comparisons.jl new file mode 100644 index 00000000..78ce1b20 --- /dev/null +++ b/src/overload_comparisons.jl @@ -0,0 +1,20 @@ +# Overload comparisons on Dual numbers +for fn in ( + :iseven, + :isfinite, + :isinf, + :isinteger, + :ismissing, + :isnan, + :isnothing, + :isodd, + :isone, + :isreal, + :iszero, +) + @eval Base.$fn(d::D) where {D<:Dual} = Base.$fn(primal(d)) +end + +for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=)) + @eval Base.$fn(dx::D, dy::D) where {D<:Dual} = Base.$fn(primal(dx), primal(dy)) +end From 86c6a440e41039fa67ace0eb8129f06a4c84b54a Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 15:29:16 +0200 Subject: [PATCH 22/47] Remove constant functions from `operators.jl` --- src/conversion.jl | 23 ++++++++++++++++++----- src/operators.jl | 17 ----------------- src/overload_connectivity.jl | 6 +----- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/src/conversion.jl b/src/conversion.jl index 1e9e419b..574f132a 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -1,9 +1,10 @@ -## Type conversions +## Type conversions (non-dual) for TT in (:GradientTracer, :ConnectivityTracer, :HessianTracer) @eval Base.promote_rule(::Type{T}, ::Type{N}) where {T<:$TT,N<:Number} = T @eval Base.promote_rule(::Type{N}, ::Type{T}) where {T<:$TT,N<:Number} = T @eval Base.big(::Type{T}) where {T<:$TT} = T + @eval Base.big(t::T) where {T<:$TT} = t @eval Base.widen(::Type{T}) where {T<:$TT} = T @eval Base.widen(t::T) where {T<:$TT} = t @@ -12,10 +13,22 @@ for TT in (:GradientTracer, :ConnectivityTracer, :HessianTracer) @eval Base.convert(::Type{<:Number}, t::T) where {T<:$TT} = t ## Constants - @eval Base.zero(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.one(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.typemin(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.typemax(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.zero(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.zero(::T) where {T<:$TT} = empty(T) + @eval Base.one(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.one(::T) where {T<:$TT} = empty(T) + @eval Base.typemin(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.typemin(::T) where {T<:$TT} = empty(T) + @eval Base.typemax(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.typemax(::T) where {T<:$TT} = empty(T) + @eval Base.eps(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.eps(::T) where {T<:$TT} = empty(T) + @eval Base.floatmin(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.floatmin(::T) where {T<:$TT} = empty(T) + @eval Base.floatmax(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.floatmax(::T) where {T<:$TT} = empty(T) + @eval Base.maxintfloat(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.maxintfloat(::T) where {T<:$TT} = empty(T) ## Array constructors @eval Base.similar(a::Array{T,1}) where {T<:$TT} = zeros(T, size(a, 1)) diff --git a/src/operators.jl b/src/operators.jl index 39f6084b..0a6c51cc 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -82,27 +82,10 @@ for op in ops_1_to_1_z SparseConnectivityTracer.is_seconder_zero_global(::T) = true end -# Functions returning constant output -# that only depends on the input type. -# For the purpose of operator overloading, -# these are kept separate from ops_1_to_1_z. -ops_1_to_1_const = ( - :zero, :one, - :eps, - :typemin, :typemax, - :floatmin, :floatmax, :maxintfloat, -) -for op in ops_1_to_1_const - T = typeof(eval(op)) - SparseConnectivityTracer.is_firstder_zero_global(::T) = true - SparseConnectivityTracer.is_seconder_zero_global(::T) = true -end - ops_1_to_1 = union( ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z, - ops_1_to_1_const, ) ##==================================# diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl index 8ba7ccb6..8ed583f8 100644 --- a/src/overload_connectivity.jl +++ b/src/overload_connectivity.jl @@ -1,13 +1,9 @@ # TODO: support Duals -for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z) +for fn in ops_1_to_1 @eval Base.$fn(t::ConnectivityTracer) = t end -for fn in ops_1_to_1_const - @eval Base.$fn(::T) where {T<:ConnectivityTracer} = empty(T) -end - for fn in ops_1_to_2 @eval Base.$fn(t::ConnectivityTracer) = (t, t) end From 424f5a06f69711acbb8a77bd8908351291364d5a Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 15:32:59 +0200 Subject: [PATCH 23/47] Start adding conversions on duals --- src/conversion.jl | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/conversion.jl b/src/conversion.jl index 574f132a..7eea2473 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -48,3 +48,42 @@ end function Base.similar(::Array, ::Type{HessianTracer{G,H}}, dims::Dims{N}) where {G,H,N} return zeros(HessianTracer{G,H}, dims) end + +## Duals +function Base.promote_rule(::Type{D}, ::Type{N}) where {P,T,D<:Dual{P,T},N<:Number} + PP = Base.promote_rule(P, N) # TODO: possible method call error? + return D{PP,T} +end +function Base.promote_rule(::Type{N}, ::Type{D}) where {P,T,D<:Dual{P,T},N<:Number} + PP = Base.promote_rule(P, N) # TODO: possible method call error? + return D{PP,T} +end + +Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{big(P),T} +Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual(big(primal(d)), tracer(d)) +Base.widen(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{widen(P),T} +Base.widen(d::D) where {P,T,D<:Dual{P,T}} = Dual(widen(primal(d)), tracer(d)) + +Base.convert(::Type{D}, x::Number) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T)) +Base.convert(::Type{D}, d::D) where {D<:Dual} = d +function Base.convert(::Type{T}, d::D) where {T<:Number,D<:Dual} + return Dual(convert(T, primal(d)), tracer(d)) +end + +## Constants +Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = D(zero(P), empty(T)) +Base.zero(d::D) where {P,T,D<:Dual{P,T}} = D(zero(primal(d)), empty(T)) +Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = D(one(P), empty(T)) +Base.one(d::D) where {P,T,D<:Dual{P,T}} = D(one(primal(d)), empty(T)) +Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemin(P), empty(T)) +Base.typemin(d::D) where {P,T,D<:Dual{P,T}} = D(typemin(primal(d)), empty(T)) +Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemax(P), empty(T)) +Base.typemax(d::D) where {P,T,D<:Dual{P,T}} = D(typemax(primal(d)), empty(T)) +Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = D(eps(P), empty(T)) +Base.eps(d::D) where {P,T,D<:Dual{P,T}} = D(eps(primal(d)), empty(T)) +Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmin(P), empty(T)) +Base.floatmin(d::D) where {P,T,D<:Dual{P,T}} = D(floatmin(primal(d)), empty(T)) +Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmax(P), empty(T)) +Base.floatmax(d::D) where {P,T,D<:Dual{P,T}} = D(floatmax(primal(d)), empty(T)) +Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = D(maxintfloat(P), empty(T)) +Base.maxintfloat(d::D) where {P,T,D<:Dual{P,T}} = D(maxintfloat(primal(d)), empty(T)) From 40b27d1e2c67c887b3d98ee6457d89d59f85fd49 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 16:08:18 +0200 Subject: [PATCH 24/47] Support Dual `GradientTracer` on `^` and `rand` --- src/SparseConnectivityTracer.jl | 1 + src/overload_comparisons.jl | 4 ++-- src/overload_gradient.jl | 24 +++++++++++++++++------- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index c5d3a657..46fbd61d 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -16,6 +16,7 @@ include("operators.jl") include("overload_connectivity.jl") include("overload_gradient.jl") include("overload_hessian.jl") +include("overload_comparisons.jl") include("pattern.jl") include("adtypes.jl") diff --git a/src/overload_comparisons.jl b/src/overload_comparisons.jl index 78ce1b20..5df719d5 100644 --- a/src/overload_comparisons.jl +++ b/src/overload_comparisons.jl @@ -12,9 +12,9 @@ for fn in ( :isreal, :iszero, ) - @eval Base.$fn(d::D) where {D<:Dual} = Base.$fn(primal(d)) + @eval Base.$fn(d::D) where {D<:Dual} = $fn(primal(d)) end for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=)) - @eval Base.$fn(dx::D, dy::D) where {D<:Dual} = Base.$fn(primal(dx), primal(dy)) + @eval Base.$fn(dx::D, dy::D) where {D<:Dual} = $fn(primal(dx), primal(dy)) end diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 0444e344..7ae42de6 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -121,17 +121,27 @@ for fn in ops_1_to_2 end end -# TODO: support Dual tracers for these. # Extra types required for exponent -for T in (:Real, :Integer, :Rational) - @eval Base.:^(t::GradientTracer, ::$T) = t - @eval Base.:^(::$T, t::GradientTracer) = t +for T in (Real, Integer, Rational, Irrational{:ℯ}) + Base.:^(t::GradientTracer, ::T) = t + Base.:^(::T, t::GradientTracer) = t + + function Base.:^(dx::D, y::T) where {P,T<:GradientTracer,D<:Dual{P,T}} + return Dual(primal(dx)^y, tracer(dx)) + end + function Base.:^(x::T, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + return Dual(x^primal(dy), tracer(dy)) + end end -Base.:^(t::GradientTracer, ::Irrational{:ℯ}) = t -Base.:^(::Irrational{:ℯ}, t::GradientTracer) = t ## Rounding Base.round(t::T, ::RoundingMode; kwargs...) where {T<:GradientTracer} = empty(T) +function Base.round( + d::D, mode::RoundingMode; kwargs... +) where {P,T<:GradientTracer,D<:Dual{P,T}} + return Dual(round(primal(d), mode; kwargs...), empty(T)) +end -## Random numbers +## Random numbers +# TODO: support random numbers on Duals rand(::AbstractRNG, ::SamplerType{T}) where {T<:GradientTracer} = empty(T) From 770a129b818a3c99133088bdb065cd8cfc800c39 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 16:08:44 +0200 Subject: [PATCH 25/47] First draft of `similar` on Duals --- src/conversion.jl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/conversion.jl b/src/conversion.jl index 7eea2473..38b975f7 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -87,3 +87,33 @@ Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmax(P), empty(T)) Base.floatmax(d::D) where {P,T,D<:Dual{P,T}} = D(floatmax(primal(d)), empty(T)) Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = D(maxintfloat(P), empty(T)) Base.maxintfloat(d::D) where {P,T,D<:Dual{P,T}} = D(maxintfloat(primal(d)), empty(T)) + +## Array constructors +function Base.similar(a::Array{D,1}) where {P,T,D<:Dual{P,T}} + p_out = similar(primal.(a)) + return Dual.(p_out, empty(T)) +end +function Base.similar(a::Array{D,2}) where {P,T,D<:Dual{P,T}} + p_out = similar(primal.(a)) + return Dual.(p_out, empty(T)) +end +function Base.similar(a::Array{A,1}, ::Type{D}) where {A,P,T,D<:Dual{P,T}} + p_out = similar(a, P) + return Dual.(p_out, empty(T)) +end +function Base.similar(a::Array{A,2}, ::Type{D}) where {A,P,T,D<:Dual{P,T}} + p_out = similar(a, P) + return Dual.(p_out, empty(T)) +end +function Base.similar(a::Array{D}, m::Int) where {P,T,D<:Dual{P,T}} + p_out = similar(primal.(a), m) + return Dual.(p_out, empty(T)) +end +function Base.similar(a::Array{D}, dims::Dims{N}) where {N,D<:Dual} + p_out = similar(primal.(a), dims) + return Dual.(p_out, empty(T)) +end +function Base.similar(a::Array, ::Type{D}, dims::Dims{N}) where {P,T,D<:Dual{P,T},N} + p_out = similar(primal.(a), P, dims) + return Dual.(p_out, empty(T)) +end From 951b157d93bfd1ea50b4f8cf175607d8255eaed5 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 16:20:26 +0200 Subject: [PATCH 26/47] Remove `@eval` from `conversion.jl` --- src/conversion.jl | 125 ++++++++++++++++++++++------------------------ 1 file changed, 60 insertions(+), 65 deletions(-) diff --git a/src/conversion.jl b/src/conversion.jl index 38b975f7..21616758 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -1,52 +1,46 @@ +#! format: off + ## Type conversions (non-dual) -for TT in (:GradientTracer, :ConnectivityTracer, :HessianTracer) - @eval Base.promote_rule(::Type{T}, ::Type{N}) where {T<:$TT,N<:Number} = T - @eval Base.promote_rule(::Type{N}, ::Type{T}) where {T<:$TT,N<:Number} = T +for TT in (GradientTracer, ConnectivityTracer, HessianTracer) + Base.promote_rule(::Type{T}, ::Type{N}) where {T<:TT,N<:Number} = T + Base.promote_rule(::Type{N}, ::Type{T}) where {T<:TT,N<:Number} = T - @eval Base.big(::Type{T}) where {T<:$TT} = T - @eval Base.big(t::T) where {T<:$TT} = t - @eval Base.widen(::Type{T}) where {T<:$TT} = T - @eval Base.widen(t::T) where {T<:$TT} = t + Base.big(::Type{T}) where {T<:TT} = T + Base.widen(::Type{T}) where {T<:TT} = T + Base.big(t::T) where {T<:TT} = t + Base.widen(t::T) where {T<:TT} = t - @eval Base.convert(::Type{T}, x::Number) where {T<:$TT} = empty(T) - @eval Base.convert(::Type{T}, t::T) where {T<:$TT} = t - @eval Base.convert(::Type{<:Number}, t::T) where {T<:$TT} = t + Base.convert(::Type{T}, x::Number) where {T<:TT} = empty(T) + Base.convert(::Type{T}, t::T) where {T<:TT} = t + Base.convert(::Type{<:Number}, t::T) where {T<:TT} = t ## Constants - @eval Base.zero(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.zero(::T) where {T<:$TT} = empty(T) - @eval Base.one(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.one(::T) where {T<:$TT} = empty(T) - @eval Base.typemin(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.typemin(::T) where {T<:$TT} = empty(T) - @eval Base.typemax(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.typemax(::T) where {T<:$TT} = empty(T) - @eval Base.eps(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.eps(::T) where {T<:$TT} = empty(T) - @eval Base.floatmin(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.floatmin(::T) where {T<:$TT} = empty(T) - @eval Base.floatmax(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.floatmax(::T) where {T<:$TT} = empty(T) - @eval Base.maxintfloat(::Type{T}) where {T<:$TT} = empty(T) - @eval Base.maxintfloat(::T) where {T<:$TT} = empty(T) + Base.zero(::Type{T}) where {T<:TT} = empty(T) + Base.one(::Type{T}) where {T<:TT} = empty(T) + Base.typemin(::Type{T}) where {T<:TT} = empty(T) + Base.typemax(::Type{T}) where {T<:TT} = empty(T) + Base.eps(::Type{T}) where {T<:TT} = empty(T) + Base.floatmin(::Type{T}) where {T<:TT} = empty(T) + Base.floatmax(::Type{T}) where {T<:TT} = empty(T) + Base.maxintfloat(::Type{T}) where {T<:TT} = empty(T) + + Base.zero(::T) where {T<:TT} = empty(T) + Base.one(::T) where {T<:TT} = empty(T) + Base.typemin(::T) where {T<:TT} = empty(T) + Base.typemax(::T) where {T<:TT} = empty(T) + Base.eps(::T) where {T<:TT} = empty(T) + Base.floatmin(::T) where {T<:TT} = empty(T) + Base.floatmax(::T) where {T<:TT} = empty(T) + Base.maxintfloat(::T) where {T<:TT} = empty(T) ## Array constructors - @eval Base.similar(a::Array{T,1}) where {T<:$TT} = zeros(T, size(a, 1)) - @eval Base.similar(a::Array{T,2}) where {T<:$TT} = zeros(T, size(a, 1), size(a, 2)) - @eval Base.similar(a::Array{A,1}, ::Type{T}) where {A,T<:$TT} = zeros(T, size(a, 1)) - @eval Base.similar(a::Array{A,2}, ::Type{T}) where {A,T<:$TT} = zeros(T, size(a, 1), size(a, 2)) - @eval Base.similar(::Array{T}, m::Int) where {T<:$TT} = zeros(T, m) - @eval Base.similar(::Array{T}, dims::Dims{N}) where {N,T<:$TT} = zeros(T, dims) -end - -function Base.similar(::Array, ::Type{ConnectivityTracer{C}}, dims::Dims{N}) where {C,N} - return zeros(ConnectivityTracer{C}, dims) -end -function Base.similar(::Array, ::Type{GradientTracer{G}}, dims::Dims{N}) where {G,N} - return zeros(GradientTracer{G}, dims) -end -function Base.similar(::Array, ::Type{HessianTracer{G,H}}, dims::Dims{N}) where {G,H,N} - return zeros(HessianTracer{G,H}, dims) + Base.similar(a::Array{T,1}) where {T<:TT} = zeros(T, size(a, 1)) + Base.similar(a::Array{T,2}) where {T<:TT} = zeros(T, size(a, 1), size(a, 2)) + Base.similar(a::Array{A,1}, ::Type{T}) where {T<:TT,A} = zeros(T, size(a, 1)) + Base.similar(a::Array{A,2}, ::Type{T}) where {T<:TT,A} = zeros(T, size(a, 1), size(a, 2)) + Base.similar(::Array{T}, m::Int) where {T<:TT} = zeros(T, m) + Base.similar(::Array{T}, dims::Dims{N}) where {T<:TT,N} = zeros(T, dims) + Base.similar(::Array, ::Type{T}, dims::Dims{N}) where {T<:TT,N} = zeros(T, dims) end ## Duals @@ -59,34 +53,33 @@ function Base.promote_rule(::Type{N}, ::Type{D}) where {P,T,D<:Dual{P,T},N<:Numb return D{PP,T} end -Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{big(P),T} -Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual(big(primal(d)), tracer(d)) +Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{big(P),T} Base.widen(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{widen(P),T} -Base.widen(d::D) where {P,T,D<:Dual{P,T}} = Dual(widen(primal(d)), tracer(d)) +Base.big(d::D) where {P,T,D<:Dual{P,T}} = Dual(big(primal(d)), tracer(d)) +Base.widen(d::D) where {P,T,D<:Dual{P,T}} = Dual(widen(primal(d)), tracer(d)) -Base.convert(::Type{D}, x::Number) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T)) -Base.convert(::Type{D}, d::D) where {D<:Dual} = d -function Base.convert(::Type{T}, d::D) where {T<:Number,D<:Dual} - return Dual(convert(T, primal(d)), tracer(d)) -end +Base.convert(::Type{D}, x::Number) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T)) +Base.convert(::Type{D}, d::D) where {D<:Dual} = d +Base.convert(::Type{T}, d::D) where {T<:Number,D<:Dual} = Dual(convert(T, primal(d)), tracer(d)) ## Constants -Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = D(zero(P), empty(T)) -Base.zero(d::D) where {P,T,D<:Dual{P,T}} = D(zero(primal(d)), empty(T)) -Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = D(one(P), empty(T)) -Base.one(d::D) where {P,T,D<:Dual{P,T}} = D(one(primal(d)), empty(T)) -Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemin(P), empty(T)) -Base.typemin(d::D) where {P,T,D<:Dual{P,T}} = D(typemin(primal(d)), empty(T)) -Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemax(P), empty(T)) -Base.typemax(d::D) where {P,T,D<:Dual{P,T}} = D(typemax(primal(d)), empty(T)) -Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = D(eps(P), empty(T)) -Base.eps(d::D) where {P,T,D<:Dual{P,T}} = D(eps(primal(d)), empty(T)) -Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmin(P), empty(T)) -Base.floatmin(d::D) where {P,T,D<:Dual{P,T}} = D(floatmin(primal(d)), empty(T)) -Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmax(P), empty(T)) -Base.floatmax(d::D) where {P,T,D<:Dual{P,T}} = D(floatmax(primal(d)), empty(T)) +Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = D(zero(P), empty(T)) +Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = D(one(P), empty(T)) +Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemin(P), empty(T)) +Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemax(P), empty(T)) +Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = D(eps(P), empty(T)) +Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmin(P), empty(T)) +Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmax(P), empty(T)) Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = D(maxintfloat(P), empty(T)) -Base.maxintfloat(d::D) where {P,T,D<:Dual{P,T}} = D(maxintfloat(primal(d)), empty(T)) + +Base.zero(d::D) where {P,T,D<:Dual{P,T}} = D(zero(primal(d)), empty(T)) +Base.one(d::D) where {P,T,D<:Dual{P,T}} = D(one(primal(d)), empty(T)) +Base.typemin(d::D) where {P,T,D<:Dual{P,T}} = D(typemin(primal(d)), empty(T)) +Base.typemax(d::D) where {P,T,D<:Dual{P,T}} = D(typemax(primal(d)), empty(T)) +Base.eps(d::D) where {P,T,D<:Dual{P,T}} = D(eps(primal(d)), empty(T)) +Base.floatmin(d::D) where {P,T,D<:Dual{P,T}} = D(floatmin(primal(d)), empty(T)) +Base.floatmax(d::D) where {P,T,D<:Dual{P,T}} = D(floatmax(primal(d)), empty(T)) +Base.maxintfloat(d::D) where {P,T,D<:Dual{P,T}} = D(maxintfloat(primal(d)), empty(T)) ## Array constructors function Base.similar(a::Array{D,1}) where {P,T,D<:Dual{P,T}} @@ -117,3 +110,5 @@ function Base.similar(a::Array, ::Type{D}, dims::Dims{N}) where {P,T,D<:Dual{P,T p_out = similar(primal.(a), P, dims) return Dual.(p_out, empty(T)) end + +#! format: on From 3426f9ceb7786d6f3baf335173ae350fd682e4a3 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 16:36:15 +0200 Subject: [PATCH 27/47] Fixes --- src/SparseConnectivityTracer.jl | 2 +- src/{overload_comparisons.jl => overload_dual.jl} | 3 ++- src/overload_gradient.jl | 6 +++--- test/first_order.jl | 4 ++-- 4 files changed, 8 insertions(+), 7 deletions(-) rename src/{overload_comparisons.jl => overload_dual.jl} (88%) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 46fbd61d..b4bc31a5 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -16,7 +16,7 @@ include("operators.jl") include("overload_connectivity.jl") include("overload_gradient.jl") include("overload_hessian.jl") -include("overload_comparisons.jl") +include("overload_dual.jl") include("pattern.jl") include("adtypes.jl") diff --git a/src/overload_comparisons.jl b/src/overload_dual.jl similarity index 88% rename from src/overload_comparisons.jl rename to src/overload_dual.jl index 5df719d5..113a234b 100644 --- a/src/overload_comparisons.jl +++ b/src/overload_dual.jl @@ -1,4 +1,4 @@ -# Overload comparisons on Dual numbers +# Special overloads for Dual numbers for fn in ( :iseven, :isfinite, @@ -11,6 +11,7 @@ for fn in ( :isone, :isreal, :iszero, + :real, ) @eval Base.$fn(d::D) where {D<:Dual} = $fn(primal(d)) end diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 7ae42de6..baaf7507 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -89,7 +89,7 @@ for fn in ops_2_to_1 y = primal(dy) p_out = Base.$fn(x, y) t_out = gradient_tracer_2_to_1_one_tracer( - tracer(dx), is_firstder_arg2_zero_local($fn, x, y) + tracer(dy), is_firstder_arg2_zero_local($fn, x, y) ) return Dual(p_out, t_out) end @@ -126,10 +126,10 @@ for T in (Real, Integer, Rational, Irrational{:ℯ}) Base.:^(t::GradientTracer, ::T) = t Base.:^(::T, t::GradientTracer) = t - function Base.:^(dx::D, y::T) where {P,T<:GradientTracer,D<:Dual{P,T}} + function Base.:^(dx::D, y::T) where {P,GT<:GradientTracer,D<:Dual{P,GT}} return Dual(primal(dx)^y, tracer(dx)) end - function Base.:^(x::T, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + function Base.:^(x::T, dy::D) where {P,GT<:GradientTracer,D<:Dual{P,GT}} return Dual(x^primal(dy), tracer(dy)) end end diff --git a/test/first_order.jl b/test/first_order.jl index 4d6e1f09..a17e1e4a 100644 --- a/test/first_order.jl +++ b/test/first_order.jl @@ -83,7 +83,7 @@ end @test local_jacobian_pattern(x -> min(x[1], x[2]), [1.0, 1.0], G) ≈ [1 1;] # Linear algebra - @test local_jacobian_pattern(logdet, [1.0 0.0; 2.0 2.0], G) ≈ [1 1; 1 1] # (#68) - @test local_jacobian_pattern(x -> log(det(x)), [1.0 0.0; 2.0 2.0], G) ≈ [1 1; 1 1] + @test local_jacobian_pattern(logdet, [1.0 -1.0; 2.0 2.0], G) ≈ [1 1 1 1] # (#68) + @test local_jacobian_pattern(x -> log(det(x)), [1.0 -1.0; 2.0 2.0], G) ≈ [1 1 1 1] end end From 24004c207b3477ca7b81b332e7545c1d7062fbcc Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 16:36:35 +0200 Subject: [PATCH 28/47] Add `local_*_pattern` functions --- src/pattern.jl | 51 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/src/pattern.jl b/src/pattern.jl index 12e2632c..4f3ad688 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -137,15 +137,15 @@ The type of index set `S` can be specified as an optional argument and defaults ## Example ```jldoctest -julia> x = rand(3); +julia> x = [1.0, 2.0, 3.0]; -julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])]; +julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, max(x[2],x[3])]; julia> local_jacobian_pattern(f, x) 3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: 1 ⋅ ⋅ 1 1 ⋅ - ⋅ ⋅ ⋅ + ⋅ ⋅ 1 ``` """ function local_jacobian_pattern(f, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} @@ -241,9 +241,52 @@ function hessian_pattern( return hessian_pattern_to_mat(to_array(xt), yt) end +""" + local_hessian_pattern(f, x) + local_hessian_pattern(f, x, T) + +Computes the local sparsity pattern of the Hessian of a scalar function `y = f(x)` at `x`. + +The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. + +## Example + +```jldoctest +julia> x = [1.0 3.0 5.0 1.0 2.0]; + +julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + x[5]; + +julia> hessian_pattern(f, x) +5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: + ⋅ ⋅ ⋅ ⋅ ⋅ + ⋅ ⋅ 1 ⋅ ⋅ + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + ⋅ ⋅ ⋅ ⋅ ⋅ + +julia> g(x) = x[2] * max(x[1], x[5]); + +julia> local_hessian_pattern(f, x) +5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: + ⋅ ⋅ ⋅ ⋅ ⋅ + ⋅ ⋅ 1 ⋅ 1 + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + ⋅ 1 ⋅ ⋅ ⋅ + +``` +""" +function local_hessian_pattern( + f, x, ::Type{G}=DEFAULT_VECTOR_TYPE, ::Type{H}=DEFAULT_MATRIX_TYPE +) where {G,H} + D = Dual{eltype(x),HessianTracer{G,H}} + xt, yt = trace_function(D, f, x) + return hessian_pattern_to_mat(to_array(xt), yt) +end + function hessian_pattern_to_mat( xt::AbstractArray{T}, yt::T -) where {G,H<:AbstractSet,T<:HessianTracer{G,H}} +) where {P,G,H<:AbstractSet,HT<:HessianTracer{G,H},T<:Union{HT,Dual{P,HT}}} # Allocate Hessian matrix n = length(xt) I = Int[] # row indices From 8459bad162eb4388068389c35babbc945fc93cdb Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 16:52:01 +0200 Subject: [PATCH 29/47] Support duals on `ConnectivityTracer` --- src/overload_connectivity.jl | 46 ++++++++++++++++++++++++++---------- test/classification.jl | 1 - 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl index 8ed583f8..376169b3 100644 --- a/src/overload_connectivity.jl +++ b/src/overload_connectivity.jl @@ -1,26 +1,48 @@ -# TODO: support Duals - +## 1-to-1 for fn in ops_1_to_1 @eval Base.$fn(t::ConnectivityTracer) = t + @eval function Base.$fn(d::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} + return Dual($fn(primal(d)), tracer(d)) + end end -for fn in ops_1_to_2 - @eval Base.$fn(t::ConnectivityTracer) = (t, t) -end - +## 2-to-1 for fn in ops_2_to_1 - @eval Base.$fn(a::T, b::T) where {T<:ConnectivityTracer} = T(a.inputs ∪ b.inputs) + @eval Base.$fn(a::T, b::T) where {T<:ConnectivityTracer} = T(inputs(a) ∪ inputs(b)) + @eval function Base.$fn(da::D, db::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} + return Dual($fn(primal(da), primal(db)), $fn(tracer(da), tracer(db))) + end + @eval Base.$fn(t::ConnectivityTracer, ::Number) = t + @eval Base.$fn(dx::D, y::Number) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} = + Dual($fn(primal(dx), y), tracer(dx)) + @eval Base.$fn(::Number, t::ConnectivityTracer) = t + @eval Base.$fn(x::Number, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} = + Dual($fn(x, primal(dy)), tracer(dy)) +end + +## 1-to-2 +for fn in ops_1_to_2 + @eval Base.$fn(t::ConnectivityTracer) = (t, t) + @eval function Base.$fn(d::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} + p1, p2 = $fn(primal(d)) + return (Dual(p1, tracer(d)), Dual(p2, tracer(d))) + end end # Extra types required for exponent -for T in (:Real, :Integer, :Rational) - @eval Base.:^(t::ConnectivityTracer, ::$T) = t - @eval Base.:^(::$T, t::ConnectivityTracer) = t +for T in (Real, Integer, Rational, Irrational{:ℯ}) + Base.:^(t::ConnectivityTracer, ::T) = t + function Base.:^(dx::D, y::Number) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} + return Dual(primal(dx)^y, tracer(dx)) + end + + Base.:^(::T, t::ConnectivityTracer) = t + function Base.:^(x::Number, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} + return Dual(x^primal(dy), tracer(dy)) + end end -Base.:^(t::ConnectivityTracer, ::Irrational{:ℯ}) = t -Base.:^(::Irrational{:ℯ}, t::ConnectivityTracer) = t ## Rounding Base.round(t::ConnectivityTracer, ::RoundingMode; kwargs...) = t diff --git a/test/classification.jl b/test/classification.jl index 0c95b3bc..9ea8bd19 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -3,7 +3,6 @@ using SparseConnectivityTracer: ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z, - ops_1_to_1_const, ops_2_to_1, ops_2_to_1_ssc, ops_2_to_1_ssz, From 839821179a679c866ff6a2981cdcd54c618efec7 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 17:26:39 +0200 Subject: [PATCH 30/47] Refactor `HessianTracer` overloads --- src/overload_hessian.jl | 323 +++++++++++++++++++--------------------- 1 file changed, 155 insertions(+), 168 deletions(-) diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index e15ee829..e70206a5 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -1,203 +1,190 @@ ## 1-to-1 -for fn in ops_1_to_1 - @eval function Base.$fn(t::T) where {T<:HessianTracer} - if is_seconder_zero_global($fn) - if is_firstder_zero_global($fn) - return empty(T) - else - return t - end +function hessian_tracer_1_to_1( + t::T, is_firstder_zero::Bool, is_seconder_zero::Bool +) where {T<:HessianTracer} + if is_seconder_zero + if is_firstder_zero + return empty(T) else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + return t end + else + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end - @eval function Base.$fn(t::D) where {P,T<:HessianTracer,D<:Dual{P,T}} - x = primal(t) - out = Base.$fn(x) - if is_seconder_zero_local($fn, x) - if is_firstder_zero_local($fn, x) - return Dual(out, empty(T)) - else - return Dual(out, t) - end - else - return Dual(out, T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t)))) - end +end + +for fn in ops_1_to_1 + @eval function Base.$fn(t::HessianTracer) + return hessian_tracer_1_to_1( + t, is_firstder_zero_global($fn), is_seconder_zero_global($fn) + ) + end + @eval function Base.$fn(d::D) where {P,T<:HessianTracer,D<:Dual{P,T}} + x = primal(d) + p_out = Base.$fn(x) + t_out = hessian_tracer_1_to_1( + tracer(d), is_firstder_zero_local($fn, x), is_seconder_zero_local($fn, x) + ) + return Dual(p_out, t_out) end end ## 2-to-1 -for fn in ops_2_to_1 - @eval function Base.$fn(a::T, b::T) where {G,H,T<:HessianTracer{G,H}} - grad = empty(G) - hess = empty(H) - if !is_firstder_arg1_zero_global($fn) - grad = union(grad, gradient(a)) # TODO: use union! - union!(hess, hessian(a)) - end - if !is_firstder_arg2_zero_global($fn) - grad = union(grad, gradient(b)) # TODO: use union! - union!(hess, hessian(b)) - end - if !is_seconder_arg1_zero_global($fn) - union!(hess, gradient(a) × gradient(a)) - end - if !is_seconder_arg2_zero_global($fn) - union!(hess, gradient(b) × gradient(b)) - end - if !is_crossder_zero_global($fn) - union!(hess, (gradient(a) × gradient(b)) ∪ (gradient(b) × gradient(a))) - end - return T(grad, hess) +function hessian_tracer_2_to_1( + a::T, + b::T, + is_firstder_arg1_zero::Bool, + is_seconder_arg1_zero::Bool, + is_firstder_arg2_zero::Bool, + is_seconder_arg2_zero::Bool, + is_crossder_zero::Bool, +) where {T<:HessianTracer} + grad = empty(G) + hess = empty(H) + if !is_firstder_arg1_zero + grad = union(grad, gradient(a)) # TODO: use union! + union!(hess, hessian(a)) end - @eval function Base.$fn(a::D, b::D) where {P,G,H,T<:HessianTracer{G,H},D<:Dual{P,T}} - x = primal(a) - y = primal(b) - out = Base.$fn(x, y) - - grad = empty(G) - hess = empty(H) - if !is_firstder_arg1_zero_local($fn, x, y) - grad = union(grad, gradient(a)) # TODO: use union! - union!(hess, hessian(a)) - end - if !is_firstder_arg2_zero_local($fn, x, y) - grad = union(grad, gradient(b)) # TODO: use union! - union!(hess, hessian(b)) - end - if !is_seconder_arg1_zero_local($fn, x, y) - union!(hess, gradient(a) × gradient(a)) - end - if !is_seconder_arg2_zero_local($fn, x, y) - union!(hess, gradient(b) × gradient(b)) - end - if !is_crossder_zero_local($fn, x, y) - union!(hess, (gradient(a) × gradient(b)) ∪ (gradient(b) × gradient(a))) - end - return Dual(out, T(grad, hess)) + if !is_firstder_arg2_zero + grad = union(grad, gradient(b)) # TODO: use union! + union!(hess, hessian(b)) end + if !is_seconder_arg1_zero + union!(hess, gradient(a) × gradient(a)) + end + if !is_seconder_arg2_zero + union!(hess, gradient(b) × gradient(b)) + end + if !is_crossder_zero + union!(hess, (gradient(a) × gradient(b)) ∪ (gradient(b) × gradient(a))) + end + return T(grad, hess) +end - @eval function Base.$fn(t::T, ::Number) where {G,H,T<:HessianTracer{G,H}} - if is_seconder_arg1_zero_global($fn) - if is_firstder_arg1_zero_global($fn) - return empty(T) - else - return t - end +function hessian_tracer_2_to_1_one_tracer( + t::T, is_firstder_zero::Bool, is_seconder_zero::Bool +) where {T<:GradientTracer} + # NOTE: this is identical to hessian_tracer_1_to_1 due to ignored second argument having empty set + # TODO: remove once gdalle agrees + if is_seconder_zero + if is_firstder_zero + return empty(T) else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + return t end + else + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end - @eval function Base.$fn( - t::D, y::Number - ) where {P,G,H,T<:HessianTracer{G,H},D<:Dual{P,T}} - x = primal(t) - out = Base.$fn(x, y) - if is_seconder_arg1_zero_local($fn, x, y) - if is_firstder_arg1_zero_local($fn, x, y) - return Dual(out, empty(T)) - else - return Dual(out, tracer(t)) - end - else - return Dual(out, T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t)))) - end +end + +for fn in ops_2_to_1 + @eval function Base.$fn(tx::T, ty::T) where {T<:HessianTracer} + return hessian_tracer_2_to_1( + tx, + ty, + is_firstder_arg1_zero_global($fn), + is_seconder_arg1_zero_global($fn), + is_firstder_arg2_zero_global($fn), + is_seconder_arg2_zero_global($fn), + is_crossder_zero_global($fn), + ) + end + @eval function Base.$fn(dx::D, dy::D) where {P,T<:HessianTracer,D<:Dual{P,T}} + x = primal(dx) + y = primal(dy) + p_out = Base.$fn(x, y) + t_out = hessian_tracer_2_to_1( + tracer(dx), + tracer(dy), + is_firstder_arg1_zero_local($fn, x, y), + is_seconder_arg1_zero_local($fn, x, y), + is_firstder_arg2_zero_local($fn, x, y), + is_seconder_arg2_zero_local($fn, x, y), + is_crossder_zero_local($fn, x, y), + ) + return Dual(p_out, t_out) end - @eval function Base.$fn(x::Number, t::T) where {G,H,T<:HessianTracer{G,H}} - if is_seconder_arg2_zero_global($fn) - if is_firstder_arg2_zero_global($fn) - return empty(T) - else - return t - end - else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) - end + @eval function Base.$fn(tx::HessianTracer, y::Number) + return hessian_tracer_2_to_1_one_tracer( + tx, is_firstder_arg1_zero_global($fn), is_seconder_arg1_zero_global($fn) + ) end - @eval function Base.$fn( - x::Number, t::D - ) where {P,G,H,T<:HessianTracer{G,H},D<:Dual{P,T}} - y = primal(t) - out = Base.$fn(x, y) - if is_seconder_arg2_zero_local($fn, x, y) - if is_firstder_arg2_zero_local($fn, x, y) - return Dual(out, empty(T)) - else - return Dual(out, tracer(t)) - end - else - return Dual(out, T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t)))) - end + @eval function Base.$fn(x::Number, ty::HessianTracer) + return hessian_tracer_2_to_1_one_tracer( + ty, is_firstder_arg2_zero_global($fn), is_seconder_arg2_zero_global($fn) + ) + end + + @eval function Base.$fn(dx::D, y::Number) where {P,T<:HessianTracer,D<:Dual{P,T}} + x = primal(dx) + p_out = Base.$fn(x, y) + t_out = hessian_tracer_2_to_1_one_tracer( + tracer(dx), + is_firstder_arg1_zero_local($fn, x, y), + is_seconder_arg1_zero_local($fn, x, y), + ) + return Dual(p_out, t_out) + end + @eval function Base.$fn(x::Number, dy::D) where {P,T<:HessianTracer,D<:Dual{P,T}} + y = primal(dy) + p_out = Base.$fn(x, y) + t_out = hessian_tracer_2_to_1_one_tracer( + tracer(dy), + is_firstder_arg2_zero_local($fn, x, y), + is_seconder_arg2_zero_local($fn, x, y), + ) + return Dual(p_out, t_out) end end ## 1-to-2 +function hessian_tracer_1_to_2( + t::T, + is_firstder_out1_zero::Bool, + is_seconder_out1_zero::Bool, + is_firstder_out2_zero::Bool, + is_seconder_out2_zero::Bool, +) where {T<:HessianTracer} + t1 = hessian_tracer_1_to_1(t, is_firstder_out1_zero, is_seconder_out1_zero) + t2 = hessian_tracer_1_to_1(t, is_firstder_out2_zero, is_seconder_out2_zero) + return (t1, t2) +end for fn in ops_1_to_2 - @eval function Base.$fn(t::T) where {T<:HessianTracer} - tracer1 = if is_seconder_out1_zero_global($fn) - if is_firstder_out1_zero_global($fn) - return empty(T) - else - return t - end - else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) - end - tracer2 = if is_seconder_out2_zero_global($fn) - if is_firstder_out2_zero_global($fn) - return empty(T) - else - return t - end - else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) - end - return (tracer1, tracer2) + @eval function Base.$fn(t::HessianTracer) + return hessian_tracer_1_to_2( + t, + is_firstder_out1_zero_global($fn), + is_seconder_out1_zero_global($fn), + is_firstder_out2_zero_global($fn), + is_seconder_out2_zero_global($fn), + ) end - @eval function Base.$fn(tx::D) where {P,T<:HessianTracer,D<:Dual{P,T}} - x = primal(tx) - out1, out2 = Base.$fn(x) - - tracer1 = if is_seconder_out1_zero_local($fn, x) - if is_firstder_out1_zero_local($fn, x) - return empty(T) - else - return tracer(t) - end - else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) - end - tracer2 = if is_seconder_out2_zero_local($fn, x) - if is_firstder_out2_zero_local($fn, x) - return empty(T) - else - return tracer(t) - end - else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) - end - return (Dual(out1, tracer1), Dual(out2, tracer2)) + @eval function Base.$fn(d::D) where {P,T<:HessianTracer,D<:Dual{P,T}} + x = primal(d) + p1_out, p2_out = Base.$fn(x) + t1_out, t2_out = hessian_tracer_1_to_2( + d, + is_firstder_out1_zero_local($fn, x), + is_seconder_out1_zero_local($fn, x), + is_firstder_out2_zero_local($fn, x), + is_seconder_out2_zero_local($fn, x), + ) + return (Dual(p1_out, t1_out), Dual(p2_out, t2_out)) end end # TODO: support Dual tracers for these. # Extra types required for exponent -for T in (:Real, :Integer, :Rational) - @eval function Base.:^(t::T, ::$T) where {T<:HessianTracer} - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) +for T in (Real, Integer, Rational, Irrational{:ℯ}) + function Base.:^(t::HT, ::T) where {HT<:HessianTracer} + return HT(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end - @eval function Base.:^(::$T, t::T) where {T<:HessianTracer} - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + function Base.:^(::T, t::HT) where {HT<:HessianTracer} + return HT(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end end -function Base.:^(t::T, ::Irrational{:ℯ}) where {T<:HessianTracer} - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) -end -function Base.:^(::Irrational{:ℯ}, t::T) where {T<:HessianTracer} - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) -end ## Rounding Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = empty(T) From 52cf2e58da04e3cc7ca17f2f4c32974e001036a1 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 17:26:54 +0200 Subject: [PATCH 31/47] Minor changes to `GradientTracer` overloads --- src/overload_gradient.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index baaf7507..3a01915c 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -8,7 +8,7 @@ function gradient_tracer_1_to_1(t::T, is_firstder_zero::Bool) where {T<:Gradient end for fn in ops_1_to_1 - @eval function Base.$fn(t::T) where {T<:GradientTracer} + @eval function Base.$fn(t::GradientTracer) return gradient_tracer_1_to_1(t, is_firstder_zero_global($fn)) end @eval function Base.$fn(d::D) where {P,T<:GradientTracer,D<:Dual{P,T}} @@ -44,6 +44,8 @@ end function gradient_tracer_2_to_1_one_tracer( t::T, is_firstder_zero::Bool ) where {T<:GradientTracer} + # NOTE: this is identical to gradient_tracer_1_to_1 due to ignored second argument having empty set + # TODO: remove once gdalle agrees if is_firstder_zero return empty(T) else @@ -70,7 +72,7 @@ for fn in ops_2_to_1 return Dual(p_out, t_out) end - @eval function Base.$fn(tx::T, ::Number) where {T<:GradientTracer} + @eval function Base.$fn(tx::GradientTracer, ::Number) return gradient_tracer_2_to_1_one_tracer(tx, is_firstder_arg1_zero_global($fn)) end @eval function Base.$fn(dx::D, y::Number) where {P,T<:GradientTracer,D<:Dual{P,T}} @@ -82,7 +84,7 @@ for fn in ops_2_to_1 return Dual(p_out, t_out) end - @eval function Base.$fn(::Number, ty::T) where {T<:GradientTracer} + @eval function Base.$fn(::Number, ty::GradientTracer) return gradient_tracer_2_to_1_one_tracer(ty, is_firstder_arg2_zero_global($fn)) end @eval function Base.$fn(x::Number, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} @@ -105,7 +107,7 @@ function gradient_tracer_1_to_2( end for fn in ops_1_to_2 - @eval function Base.$fn(t::T) where {T<:GradientTracer} + @eval function Base.$fn(t::GradientTracer) return gradient_tracer_1_to_2( t, is_firstder_out1_zero_global($fn), is_firstder_out2_zero_global($fn) ) From 103cf56469b964ee31c4f93abce52ec4b72ba8ef Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 17:38:30 +0200 Subject: [PATCH 32/47] Fixes --- src/conversion.jl | 6 +++--- src/overload_connectivity.jl | 10 +++++----- src/overload_gradient.jl | 10 +++++----- src/overload_hessian.jl | 14 +++++++------- src/pattern.jl | 11 ++++++----- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/conversion.jl b/src/conversion.jl index 21616758..6d712e28 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -90,11 +90,11 @@ function Base.similar(a::Array{D,2}) where {P,T,D<:Dual{P,T}} p_out = similar(primal.(a)) return Dual.(p_out, empty(T)) end -function Base.similar(a::Array{A,1}, ::Type{D}) where {A,P,T,D<:Dual{P,T}} +function Base.similar(a::Array{A,1}, ::Type{D}) where {P,T,D<:Dual{P,T},A} p_out = similar(a, P) return Dual.(p_out, empty(T)) end -function Base.similar(a::Array{A,2}, ::Type{D}) where {A,P,T,D<:Dual{P,T}} +function Base.similar(a::Array{A,2}, ::Type{D}) where {P,T,D<:Dual{P,T},A} p_out = similar(a, P) return Dual.(p_out, empty(T)) end @@ -102,7 +102,7 @@ function Base.similar(a::Array{D}, m::Int) where {P,T,D<:Dual{P,T}} p_out = similar(primal.(a), m) return Dual.(p_out, empty(T)) end -function Base.similar(a::Array{D}, dims::Dims{N}) where {N,D<:Dual} +function Base.similar(a::Array{D}, dims::Dims{N}) where {P,T,D<:Dual{P,T}, N} p_out = similar(primal.(a), dims) return Dual.(p_out, empty(T)) end diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl index 376169b3..78fdbdc3 100644 --- a/src/overload_connectivity.jl +++ b/src/overload_connectivity.jl @@ -32,14 +32,14 @@ for fn in ops_1_to_2 end # Extra types required for exponent -for T in (Real, Integer, Rational, Irrational{:ℯ}) - Base.:^(t::ConnectivityTracer, ::T) = t - function Base.:^(dx::D, y::Number) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} +for S in (Real, Integer, Rational, Irrational{:ℯ}) + Base.:^(t::ConnectivityTracer, ::S) = t + function Base.:^(dx::D, y::S) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} return Dual(primal(dx)^y, tracer(dx)) end - Base.:^(::T, t::ConnectivityTracer) = t - function Base.:^(x::Number, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} + Base.:^(::S, t::ConnectivityTracer) = t + function Base.:^(x::S, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} return Dual(x^primal(dy), tracer(dy)) end end diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 3a01915c..183eee05 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -124,14 +124,14 @@ for fn in ops_1_to_2 end # Extra types required for exponent -for T in (Real, Integer, Rational, Irrational{:ℯ}) - Base.:^(t::GradientTracer, ::T) = t - Base.:^(::T, t::GradientTracer) = t +for S in (Real, Integer, Rational, Irrational{:ℯ}) + Base.:^(t::GradientTracer, ::S) = t + Base.:^(::S, t::GradientTracer) = t - function Base.:^(dx::D, y::T) where {P,GT<:GradientTracer,D<:Dual{P,GT}} + function Base.:^(dx::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}} return Dual(primal(dx)^y, tracer(dx)) end - function Base.:^(x::T, dy::D) where {P,GT<:GradientTracer,D<:Dual{P,GT}} + function Base.:^(x::S, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} return Dual(x^primal(dy), tracer(dy)) end end diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index e70206a5..bfbb29d6 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -38,7 +38,7 @@ function hessian_tracer_2_to_1( is_firstder_arg2_zero::Bool, is_seconder_arg2_zero::Bool, is_crossder_zero::Bool, -) where {T<:HessianTracer} +) where {G,H,T<:HessianTracer{G,H}} grad = empty(G) hess = empty(H) if !is_firstder_arg1_zero @@ -63,7 +63,7 @@ end function hessian_tracer_2_to_1_one_tracer( t::T, is_firstder_zero::Bool, is_seconder_zero::Bool -) where {T<:GradientTracer} +) where {T<:HessianTracer} # NOTE: this is identical to hessian_tracer_1_to_1 due to ignored second argument having empty set # TODO: remove once gdalle agrees if is_seconder_zero @@ -177,12 +177,12 @@ end # TODO: support Dual tracers for these. # Extra types required for exponent -for T in (Real, Integer, Rational, Irrational{:ℯ}) - function Base.:^(t::HT, ::T) where {HT<:HessianTracer} - return HT(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) +for S in (Real, Integer, Rational, Irrational{:ℯ}) + function Base.:^(t::T, ::S) where {T<:HessianTracer} + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end - function Base.:^(::T, t::HT) where {HT<:HessianTracer} - return HT(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + function Base.:^(::S, t::T) where {T<:HessianTracer} + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end end diff --git a/src/pattern.jl b/src/pattern.jl index 4f3ad688..6c413696 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -182,14 +182,14 @@ function local_jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where { end function jacobian_pattern_to_mat( - xt::AbstractArray{T}, yt::AbstractArray{<:Number} -) where {P,G<:GradientTracer,T<:Union{G,Dual{P,G}}} + xt::AbstractArray{TT}, yt::AbstractArray{<:Number} +) where {P,T<:GradientTracer,D<:Dual{P,T},TT<:Union{T,D}} n, m = length(xt), length(yt) I = Int[] # row indices J = Int[] # column indices V = Bool[] # values for (i, y) in enumerate(yt) - if y isa T + if y isa TT for j in gradient(y) push!(I, i) push!(J, j) @@ -285,8 +285,9 @@ function local_hessian_pattern( end function hessian_pattern_to_mat( - xt::AbstractArray{T}, yt::T -) where {P,G,H<:AbstractSet,HT<:HessianTracer{G,H},T<:Union{HT,Dual{P,HT}}} + xt::AbstractArray{TT}, yt::TT +) where {P,T<:HessianTracer,D<:Dual{P,T},TT<:Union{T,D}} + # Allocate Hessian matrix n = length(xt) I = Int[] # row indices From 9f0b516e731780757bbdfa69d2f5d6bcca1f5dde Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 17:48:17 +0200 Subject: [PATCH 33/47] omg --- src/operators.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index 0a6c51cc..719c30a1 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -100,10 +100,10 @@ function is_crossder_zero_global end # Fallbacks for local derivatives: is_firstder_arg1_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) -is_seconder_arg1_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) +is_seconder_arg1_zero_local(f::F, x, y) where {F} = is_seconder_arg1_zero_global(f) is_firstder_arg2_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) -is_seconder_arg2_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) -is_crossder_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) +is_seconder_arg2_zero_local(f::F, x, y) where {F} = is_seconder_arg1_zero_global(f) +is_crossder_zero_local(f::F, x, y) where {F} = is_crossder_zero_global(f) # ops_2_to_1_ssc: # ∂f/∂x != 0 From f72d31083649a214d8d955d716128d81bdecc617 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 17:49:21 +0200 Subject: [PATCH 34/47] omgv2 --- src/operators.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index 719c30a1..9507ba99 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -101,8 +101,8 @@ function is_crossder_zero_global end # Fallbacks for local derivatives: is_firstder_arg1_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) is_seconder_arg1_zero_local(f::F, x, y) where {F} = is_seconder_arg1_zero_global(f) -is_firstder_arg2_zero_local(f::F, x, y) where {F} = is_firstder_arg1_zero_global(f) -is_seconder_arg2_zero_local(f::F, x, y) where {F} = is_seconder_arg1_zero_global(f) +is_firstder_arg2_zero_local(f::F, x, y) where {F} = is_firstder_arg2_zero_global(f) +is_seconder_arg2_zero_local(f::F, x, y) where {F} = is_seconder_arg2_zero_global(f) is_crossder_zero_local(f::F, x, y) where {F} = is_crossder_zero_global(f) # ops_2_to_1_ssc: From e92dcd4d669b45e2534678172ce82b44ea2f10bf Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 17:51:12 +0200 Subject: [PATCH 35/47] Fix docstring test --- src/pattern.jl | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/pattern.jl b/src/pattern.jl index 6c413696..b7b1fbad 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -142,7 +142,7 @@ julia> x = [1.0, 2.0, 3.0]; julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, max(x[2],x[3])]; julia> local_jacobian_pattern(f, x) -3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: +3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries: 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ 1 @@ -254,26 +254,15 @@ The type of index set `S` can be specified as an optional argument and defaults ```jldoctest julia> x = [1.0 3.0 5.0 1.0 2.0]; -julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + x[5]; - -julia> hessian_pattern(f, x) -5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: - ⋅ ⋅ ⋅ ⋅ ⋅ - ⋅ ⋅ 1 ⋅ ⋅ - ⋅ 1 ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ 1 ⋅ - ⋅ ⋅ ⋅ ⋅ ⋅ - -julia> g(x) = x[2] * max(x[1], x[5]); +julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + x[2] * max(x[1], x[5]); julia> local_hessian_pattern(f, x) -5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: +5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 5 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ ⋅ - ``` """ function local_hessian_pattern( From 7ce3d80df3bde49cc593461d8056d4a645553763 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 17:54:25 +0200 Subject: [PATCH 36/47] More tests --- src/pattern.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/pattern.jl b/src/pattern.jl index b7b1fbad..bcbfff84 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -263,6 +263,16 @@ julia> local_hessian_pattern(f, x) ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ ⋅ + +julia> x = [4.0 3.0 5.0 1.0 2.0]; + +julia> local_hessian_pattern(f, x) +5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 5 stored entries: + ⋅ 1 ⋅ ⋅ ⋅ + 1 ⋅ 1 ⋅ ⋅ + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + ⋅ ⋅ ⋅ ⋅ ⋅ ``` """ function local_hessian_pattern( From 68a2310cc20d21ff96f4da0189251ab708efd86a Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 22:25:32 +0200 Subject: [PATCH 37/47] Update tests --- test/first_order.jl | 14 ++ test/second_order.jl | 296 +++++++++++++++++++++++-------------------- 2 files changed, 176 insertions(+), 134 deletions(-) diff --git a/test/first_order.jl b/test/first_order.jl index a17e1e4a..9d4584df 100644 --- a/test/first_order.jl +++ b/test/first_order.jl @@ -82,6 +82,20 @@ end @test local_jacobian_pattern(x -> min(x[1], x[2]), [2.0, 1.0], G) ≈ [0 1;] @test local_jacobian_pattern(x -> min(x[1], x[2]), [1.0, 1.0], G) ≈ [1 1;] + # Comparisons + @test local_jacobian_pattern( + x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0], G + ) ≈ [0 0 0 1;] + @test local_jacobian_pattern( + x -> x[1] > x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0], G + ) ≈ [0 0 1 0;] + @test local_jacobian_pattern( + x -> x[1] < x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0], G + ) ≈ [0 0 1 0;] + @test local_jacobian_pattern( + x -> x[1] < x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0], G + ) ≈ [0 0 0 1;] + # Linear algebra @test local_jacobian_pattern(logdet, [1.0 -1.0; 2.0 2.0], G) ≈ [1 1 1 1] # (#68) @test local_jacobian_pattern(x -> log(det(x)), [1.0 -1.0; 2.0 2.0], G) ≈ [1 1 1 1] diff --git a/test/second_order.jl b/test/second_order.jl index 79f2e4c1..2d73518c 100644 --- a/test/second_order.jl +++ b/test/second_order.jl @@ -3,140 +3,168 @@ using SparseConnectivityTracer: HessianTracer, tracer, trace_input, empty using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using Test -@testset "Default hessian_pattern" begin - h = hessian_pattern(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4)) - @test h ≈ [ - 0 1 0 0 - 1 1 0 0 - 0 0 0 0 - 0 0 0 1 - ] -end - -@testset "Set type $G" for G in ( +const SECOND_ORDER_SET_TYPES = ( BitSet, Set{UInt64}, DuplicateVector{UInt64}, RecursiveSet{UInt64}, SortedVector{UInt64} ) - I = eltype(G) - H = Set{Tuple{I,I}} - HT = HessianTracer{G,H} - - @test hessian_pattern(identity, rand(), G, H) ≈ [0;;] - @test hessian_pattern(sqrt, rand(), G, H) ≈ [1;;] - - @test hessian_pattern(x -> 1 * x, rand(), G, H) ≈ [0;;] - @test hessian_pattern(x -> x * 1, rand(), G, H) ≈ [0;;] - - # Code coverage - @test hessian_pattern(typemax, 1, G, H) ≈ [0;;] - @test hessian_pattern(x -> x^(2im), 1, G, H) ≈ [1;;] - @test hessian_pattern(x -> (2im)^x, 1, G, H) ≈ [1;;] - @test hessian_pattern(x -> x^(2//3), 1, G, H) ≈ [1;;] - @test hessian_pattern(x -> (2//3)^x, 1, G, H) ≈ [1;;] - @test hessian_pattern(x -> x^ℯ, 1, G, H) ≈ [1;;] - @test hessian_pattern(x -> ℯ^x, 1, G, H) ≈ [1;;] - @test hessian_pattern(x -> round(x, RoundNearestTiesUp), 1, G, H) ≈ [0;;] - - h = hessian_pattern(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4), G, H) - @test h ≈ [ - 0 1 0 0 - 1 1 0 0 - 0 0 0 0 - 0 0 0 1 - ] - - h = hessian_pattern(x -> x[1] * x[2] + x[3] * 1 + 1 * x[4], rand(4), G, H) - @test h ≈ [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - h = hessian_pattern(x -> (x[1] * x[2]) * (x[3] * x[4]), rand(4), G, H) - @test h ≈ [ - 0 1 1 1 - 1 0 1 1 - 1 1 0 1 - 1 1 1 0 - ] - - h = hessian_pattern(x -> (x[1] + x[2]) * (x[3] + x[4]), rand(4), G, H) - @test h ≈ [ - 0 0 1 1 - 0 0 1 1 - 1 1 0 0 - 1 1 0 0 - ] - - h = hessian_pattern(x -> (x[1] + x[2] + x[3] + x[4])^2, rand(4), G, H) - @test h ≈ [ - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - ] - - h = hessian_pattern(x -> 1 / (x[1] + x[2] + x[3] + x[4]), rand(4), G, H) - @test h ≈ [ - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - ] - - h = hessian_pattern(x -> (x[1] - x[2]) + (x[3] - 1) + (1 - x[4]), rand(4), G, H) - @test h ≈ [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - h = hessian_pattern(x -> copysign(x[1] * x[2], x[3] * x[4]), rand(4), G, H) - @test h ≈ [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - h = hessian_pattern(x -> div(x[1] * x[2], x[3] * x[4]), rand(4), G, H) - @test h ≈ [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - h = hessian_pattern(x -> sum(sincosd(x)), 1.0, G, H) - @test h ≈ [1;;] - - h = hessian_pattern(x -> sum(diff(x) .^ 3), rand(4), G, H) - @test h ≈ [ - 1 1 0 0 - 1 1 1 0 - 0 1 1 1 - 0 0 1 1 - ] - - x = rand(5) - foo(x) = x[1] + x[2] * x[3] + 1 / x[4] + 1 * x[5] - h = hessian_pattern(foo, x, G, H) - @test h ≈ [ - 0 0 0 0 0 - 0 0 1 0 0 - 0 1 0 0 0 - 0 0 0 1 0 - 0 0 0 0 0 - ] - - bar(x) = foo(x) + x[2]^x[5] - h = hessian_pattern(bar, x, G, H) - @test h ≈ [ - 0 0 0 0 0 - 0 1 1 0 1 - 0 1 0 0 0 - 0 0 0 1 0 - 0 1 0 0 1 - ] + +@testset "Global" begin + @testset "Default hessian_pattern" begin + h = hessian_pattern(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4)) + @test h ≈ [ + 0 1 0 0 + 1 1 0 0 + 0 0 0 0 + 0 0 0 1 + ] + end + + @testset "Set type $G" for G in SECOND_ORDER_SET_TYPES + I = eltype(G) + H = Set{Tuple{I,I}} + HT = HessianTracer{G,H} + + @test hessian_pattern(identity, rand(), G, H) ≈ [0;;] + @test hessian_pattern(sqrt, rand(), G, H) ≈ [1;;] + + @test hessian_pattern(x -> 1 * x, rand(), G, H) ≈ [0;;] + @test hessian_pattern(x -> x * 1, rand(), G, H) ≈ [0;;] + + # Code coverage + @test hessian_pattern(typemax, 1, G, H) ≈ [0;;] + @test hessian_pattern(x -> x^(2im), 1, G, H) ≈ [1;;] + @test hessian_pattern(x -> (2im)^x, 1, G, H) ≈ [1;;] + @test hessian_pattern(x -> x^(2//3), 1, G, H) ≈ [1;;] + @test hessian_pattern(x -> (2//3)^x, 1, G, H) ≈ [1;;] + @test hessian_pattern(x -> x^ℯ, 1, G, H) ≈ [1;;] + @test hessian_pattern(x -> ℯ^x, 1, G, H) ≈ [1;;] + @test hessian_pattern(x -> round(x, RoundNearestTiesUp), 1, G, H) ≈ [0;;] + + h = hessian_pattern(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4), G, H) + @test h ≈ [ + 0 1 0 0 + 1 1 0 0 + 0 0 0 0 + 0 0 0 1 + ] + + h = hessian_pattern(x -> x[1] * x[2] + x[3] * 1 + 1 * x[4], rand(4), G, H) + @test h ≈ [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + h = hessian_pattern(x -> (x[1] * x[2]) * (x[3] * x[4]), rand(4), G, H) + @test h ≈ [ + 0 1 1 1 + 1 0 1 1 + 1 1 0 1 + 1 1 1 0 + ] + + h = hessian_pattern(x -> (x[1] + x[2]) * (x[3] + x[4]), rand(4), G, H) + @test h ≈ [ + 0 0 1 1 + 0 0 1 1 + 1 1 0 0 + 1 1 0 0 + ] + + h = hessian_pattern(x -> (x[1] + x[2] + x[3] + x[4])^2, rand(4), G, H) + @test h ≈ [ + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + ] + + h = hessian_pattern(x -> 1 / (x[1] + x[2] + x[3] + x[4]), rand(4), G, H) + @test h ≈ [ + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + ] + + h = hessian_pattern(x -> (x[1] - x[2]) + (x[3] - 1) + (1 - x[4]), rand(4), G, H) + @test h ≈ [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + h = hessian_pattern(x -> copysign(x[1] * x[2], x[3] * x[4]), rand(4), G, H) + @test h ≈ [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + h = hessian_pattern(x -> div(x[1] * x[2], x[3] * x[4]), rand(4), G, H) + @test h ≈ [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + h = hessian_pattern(x -> sum(sincosd(x)), 1.0, G, H) + @test h ≈ [1;;] + + h = hessian_pattern(x -> sum(diff(x) .^ 3), rand(4), G, H) + @test h ≈ [ + 1 1 0 0 + 1 1 1 0 + 0 1 1 1 + 0 0 1 1 + ] + + x = rand(5) + foo(x) = x[1] + x[2] * x[3] + 1 / x[4] + 1 * x[5] + h = hessian_pattern(foo, x, G, H) + @test h ≈ [ + 0 0 0 0 0 + 0 0 1 0 0 + 0 1 0 0 0 + 0 0 0 1 0 + 0 0 0 0 0 + ] + + bar(x) = foo(x) + x[2]^x[5] + h = hessian_pattern(bar, x, G, H) + @test h ≈ [ + 0 0 0 0 0 + 0 1 1 0 1 + 0 1 0 0 0 + 0 0 0 1 0 + 0 1 0 0 1 + ] + end +end + +@testset "Local" begin + @testset "Set type $G" for G in SECOND_ORDER_SET_TYPES + x = f1(x) = x[1] + x[2] * x[3] + 1 / x[4] + x[2] * max(x[1], x[5]) + + h = local_hessian_pattern(f1, [1.0 3.0 5.0 1.0 2.0], G) + @test h ≈ [ + 0 0 0 0 0 + 0 0 1 0 1 + 0 1 0 0 0 + 0 0 0 1 0 + 0 1 0 0 0 + ] + + h = local_hessian_pattern(f1, [4.0 3.0 5.0 1.0 2.0], G) + @test h ≈ [ + 0 1 0 0 0 + 1 0 1 0 0 + 0 1 0 0 0 + 0 0 0 1 0 + 0 0 0 0 0 + ] + end end From d7b9402443011f76b8f9af015c47d670a162f6a9 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 22:38:18 +0200 Subject: [PATCH 38/47] Fixes --- src/overload_gradient.jl | 10 +++++----- src/pattern.jl | 27 +++++++++++++++++++-------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 183eee05..7a2e8762 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -23,17 +23,17 @@ end function gradient_tracer_2_to_1( tx::T, ty::T, - is_firstder_arg1_zero_or_number::Bool, - is_firstder_arg2_zero_or_number::Bool, + is_firstder_arg1_zero::Bool, + is_firstder_arg2_zero::Bool, ) where {T<:GradientTracer} - if is_firstder_arg1_zero_or_number - if is_firstder_arg2_zero_or_number + if is_firstder_arg1_zero + if is_firstder_arg2_zero return empty(T) else return ty end else # ∂f∂x ≠ 0 - if is_firstder_arg2_zero_or_number + if is_firstder_arg2_zero return tx else return T(gradient(tx) ∪ gradient(ty)) diff --git a/src/pattern.jl b/src/pattern.jl index bcbfff84..0b6c4e4b 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -182,14 +182,14 @@ function local_jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where { end function jacobian_pattern_to_mat( - xt::AbstractArray{TT}, yt::AbstractArray{<:Number} -) where {P,T<:GradientTracer,D<:Dual{P,T},TT<:Union{T,D}} + xt::AbstractArray{T}, yt::AbstractArray{<:Number} +) where {T<:GradientTracer} n, m = length(xt), length(yt) I = Int[] # row indices J = Int[] # column indices V = Bool[] # values for (i, y) in enumerate(yt) - if y isa TT + if y isa T for j in gradient(y) push!(I, i) push!(J, j) @@ -200,6 +200,15 @@ function jacobian_pattern_to_mat( return sparse(I, J, V, m, n) end +_tracer_or_number(x::Number) = x +_tracer_or_number(d::Dual) = tracer(d) + +function jacobian_pattern_to_mat( + xt::AbstractArray{D}, yt::AbstractArray{<:Number} +) where {P,T<:GradientTracer,D<:Dual{P,T}} + return jacobian_pattern_to_mat(tracer.(xt), _tracer_or_number.(yt)) +end + """ hessian_pattern(f, x) hessian_pattern(f, x, T) @@ -283,11 +292,7 @@ function local_hessian_pattern( return hessian_pattern_to_mat(to_array(xt), yt) end -function hessian_pattern_to_mat( - xt::AbstractArray{TT}, yt::TT -) where {P,T<:HessianTracer,D<:Dual{P,T},TT<:Union{T,D}} - - # Allocate Hessian matrix +function hessian_pattern_to_mat(xt::AbstractArray{T}, yt::T) where {T<:HessianTracer} n = length(xt) I = Int[] # row indices J = Int[] # column indices @@ -301,3 +306,9 @@ function hessian_pattern_to_mat( h = sparse(I, J, V, n, n) return h end + +function hessian_pattern_to_mat( + xt::AbstractArray{D}, yt::D +) where {P,T<:HessianTracer,D<:Dual{P,T}} + return hessian_pattern_to_mat(tracer.(xt), tracer(yt)) +end From c5edf82b3db958f4972099f8b1b069de22a56c82 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 22:41:21 +0200 Subject: [PATCH 39/47] Remove `2_to_1_one_tracer` functions --- src/overload_gradient.jl | 29 +++++------------------------ src/overload_hessian.jl | 24 ++++-------------------- 2 files changed, 9 insertions(+), 44 deletions(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 7a2e8762..ff859910 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -21,10 +21,7 @@ end ## 2-to-1 function gradient_tracer_2_to_1( - tx::T, - ty::T, - is_firstder_arg1_zero::Bool, - is_firstder_arg2_zero::Bool, + tx::T, ty::T, is_firstder_arg1_zero::Bool, is_firstder_arg2_zero::Bool ) where {T<:GradientTracer} if is_firstder_arg1_zero if is_firstder_arg2_zero @@ -41,18 +38,6 @@ function gradient_tracer_2_to_1( end end -function gradient_tracer_2_to_1_one_tracer( - t::T, is_firstder_zero::Bool -) where {T<:GradientTracer} - # NOTE: this is identical to gradient_tracer_1_to_1 due to ignored second argument having empty set - # TODO: remove once gdalle agrees - if is_firstder_zero - return empty(T) - else - return t - end -end - for fn in ops_2_to_1 @eval function Base.$fn(tx::T, ty::T) where {T<:GradientTracer} return gradient_tracer_2_to_1( @@ -73,26 +58,22 @@ for fn in ops_2_to_1 end @eval function Base.$fn(tx::GradientTracer, ::Number) - return gradient_tracer_2_to_1_one_tracer(tx, is_firstder_arg1_zero_global($fn)) + return gradient_tracer_1_to_1(tx, is_firstder_arg1_zero_global($fn)) end @eval function Base.$fn(dx::D, y::Number) where {P,T<:GradientTracer,D<:Dual{P,T}} x = primal(dx) p_out = Base.$fn(x, y) - t_out = gradient_tracer_2_to_1_one_tracer( - tracer(dx), is_firstder_arg1_zero_local($fn, x, y) - ) + t_out = gradient_tracer_1_to_1(tracer(dx), is_firstder_arg1_zero_local($fn, x, y)) return Dual(p_out, t_out) end @eval function Base.$fn(::Number, ty::GradientTracer) - return gradient_tracer_2_to_1_one_tracer(ty, is_firstder_arg2_zero_global($fn)) + return gradient_tracer_1_to_1(ty, is_firstder_arg2_zero_global($fn)) end @eval function Base.$fn(x::Number, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} y = primal(dy) p_out = Base.$fn(x, y) - t_out = gradient_tracer_2_to_1_one_tracer( - tracer(dy), is_firstder_arg2_zero_local($fn, x, y) - ) + t_out = gradient_tracer_1_to_1(tracer(dy), is_firstder_arg2_zero_local($fn, x, y)) return Dual(p_out, t_out) end end diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index bfbb29d6..94825af8 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -61,22 +61,6 @@ function hessian_tracer_2_to_1( return T(grad, hess) end -function hessian_tracer_2_to_1_one_tracer( - t::T, is_firstder_zero::Bool, is_seconder_zero::Bool -) where {T<:HessianTracer} - # NOTE: this is identical to hessian_tracer_1_to_1 due to ignored second argument having empty set - # TODO: remove once gdalle agrees - if is_seconder_zero - if is_firstder_zero - return empty(T) - else - return t - end - else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) - end -end - for fn in ops_2_to_1 @eval function Base.$fn(tx::T, ty::T) where {T<:HessianTracer} return hessian_tracer_2_to_1( @@ -106,12 +90,12 @@ for fn in ops_2_to_1 end @eval function Base.$fn(tx::HessianTracer, y::Number) - return hessian_tracer_2_to_1_one_tracer( + return hessian_tracer_1_to_1( tx, is_firstder_arg1_zero_global($fn), is_seconder_arg1_zero_global($fn) ) end @eval function Base.$fn(x::Number, ty::HessianTracer) - return hessian_tracer_2_to_1_one_tracer( + return hessian_tracer_1_to_1( ty, is_firstder_arg2_zero_global($fn), is_seconder_arg2_zero_global($fn) ) end @@ -119,7 +103,7 @@ for fn in ops_2_to_1 @eval function Base.$fn(dx::D, y::Number) where {P,T<:HessianTracer,D<:Dual{P,T}} x = primal(dx) p_out = Base.$fn(x, y) - t_out = hessian_tracer_2_to_1_one_tracer( + t_out = hessian_tracer_1_to_1( tracer(dx), is_firstder_arg1_zero_local($fn, x, y), is_seconder_arg1_zero_local($fn, x, y), @@ -129,7 +113,7 @@ for fn in ops_2_to_1 @eval function Base.$fn(x::Number, dy::D) where {P,T<:HessianTracer,D<:Dual{P,T}} y = primal(dy) p_out = Base.$fn(x, y) - t_out = hessian_tracer_2_to_1_one_tracer( + t_out = hessian_tracer_1_to_1( tracer(dy), is_firstder_arg2_zero_local($fn, x, y), is_seconder_arg2_zero_local($fn, x, y), From 01b6435387ea21582aa8da9a330af8c0ca54ccd9 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 22:42:12 +0200 Subject: [PATCH 40/47] Remove random number TODO --- src/overload_gradient.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index ff859910..3de5f526 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -126,5 +126,4 @@ function Base.round( end ## Random numbers -# TODO: support random numbers on Duals rand(::AbstractRNG, ::SamplerType{T}) where {T<:GradientTracer} = empty(T) From 609120e7b256b5456cdcbaac7a7dcffac87e0b07 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 22:49:51 +0200 Subject: [PATCH 41/47] Fix `hessian_tracer_1_to_1` according to review --- src/overload_hessian.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index 94825af8..298542f5 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -1,7 +1,7 @@ ## 1-to-1 function hessian_tracer_1_to_1( t::T, is_firstder_zero::Bool, is_seconder_zero::Bool -) where {T<:HessianTracer} +) where {G,H,T<:HessianTracer{G,H}} if is_seconder_zero if is_firstder_zero return empty(T) @@ -9,7 +9,11 @@ function hessian_tracer_1_to_1( return t end else - return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + if is_firstder_zero + return T(empty(G), gradient(t) × gradient(t)) + else + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + end end end From f0217f4f575b1176c67fa46c45d11d181b475db8 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 17 May 2024 14:01:19 +0200 Subject: [PATCH 42/47] Fix `similar` --- src/conversion.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/conversion.jl b/src/conversion.jl index 6d712e28..28b24383 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -20,6 +20,7 @@ for TT in (GradientTracer, ConnectivityTracer, HessianTracer) Base.typemin(::Type{T}) where {T<:TT} = empty(T) Base.typemax(::Type{T}) where {T<:TT} = empty(T) Base.eps(::Type{T}) where {T<:TT} = empty(T) + Base.float(::Type{T}) where {T<:TT} = empty(T) Base.floatmin(::Type{T}) where {T<:TT} = empty(T) Base.floatmax(::Type{T}) where {T<:TT} = empty(T) Base.maxintfloat(::Type{T}) where {T<:TT} = empty(T) @@ -29,6 +30,7 @@ for TT in (GradientTracer, ConnectivityTracer, HessianTracer) Base.typemin(::T) where {T<:TT} = empty(T) Base.typemax(::T) where {T<:TT} = empty(T) Base.eps(::T) where {T<:TT} = empty(T) + Base.float(::T) where {T<:TT} = empty(T) Base.floatmin(::T) where {T<:TT} = empty(T) Base.floatmax(::T) where {T<:TT} = empty(T) Base.maxintfloat(::T) where {T<:TT} = empty(T) @@ -40,9 +42,13 @@ for TT in (GradientTracer, ConnectivityTracer, HessianTracer) Base.similar(a::Array{A,2}, ::Type{T}) where {T<:TT,A} = zeros(T, size(a, 1), size(a, 2)) Base.similar(::Array{T}, m::Int) where {T<:TT} = zeros(T, m) Base.similar(::Array{T}, dims::Dims{N}) where {T<:TT,N} = zeros(T, dims) - Base.similar(::Array, ::Type{T}, dims::Dims{N}) where {T<:TT,N} = zeros(T, dims) end +Base.similar(::Array, ::Type{ConnectivityTracer{C}}, dims::Dims{N}) where {C,N} = zeros(T, dims) +Base.similar(::Array, ::Type{GradientTracer{G}}, dims::Dims{N}) where {G,N} = zeros(T, dims) +Base.similar(::Array, ::Type{HessianTracer{G,H}}, dims::Dims{N}) where {G,H,N} = zeros(T, dims) + + ## Duals function Base.promote_rule(::Type{D}, ::Type{N}) where {P,T,D<:Dual{P,T},N<:Number} PP = Base.promote_rule(P, N) # TODO: possible method call error? @@ -68,6 +74,7 @@ Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = D(one(P), empty(T Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemin(P), empty(T)) Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemax(P), empty(T)) Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = D(eps(P), empty(T)) +Base.float(::Type{D}) where {P,T,D<:Dual{P,T}} = D(float(P), empty(T)) Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmin(P), empty(T)) Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmax(P), empty(T)) Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = D(maxintfloat(P), empty(T)) @@ -77,6 +84,7 @@ Base.one(d::D) where {P,T,D<:Dual{P,T}} = D(one(primal(d)), empt Base.typemin(d::D) where {P,T,D<:Dual{P,T}} = D(typemin(primal(d)), empty(T)) Base.typemax(d::D) where {P,T,D<:Dual{P,T}} = D(typemax(primal(d)), empty(T)) Base.eps(d::D) where {P,T,D<:Dual{P,T}} = D(eps(primal(d)), empty(T)) +Base.float(d::D) where {P,T,D<:Dual{P,T}} = D(float(primal(d)), empty(T)) Base.floatmin(d::D) where {P,T,D<:Dual{P,T}} = D(floatmin(primal(d)), empty(T)) Base.floatmax(d::D) where {P,T,D<:Dual{P,T}} = D(floatmax(primal(d)), empty(T)) Base.maxintfloat(d::D) where {P,T,D<:Dual{P,T}} = D(maxintfloat(primal(d)), empty(T)) @@ -106,7 +114,7 @@ function Base.similar(a::Array{D}, dims::Dims{N}) where {P,T,D<:Dual{P,T}, N} p_out = similar(primal.(a), dims) return Dual.(p_out, empty(T)) end -function Base.similar(a::Array, ::Type{D}, dims::Dims{N}) where {P,T,D<:Dual{P,T},N} +function Base.similar(a::Array, ::Type{Dual{P,T}}, dims::Dims{N}) where {P,T,N} p_out = similar(primal.(a), P, dims) return Dual.(p_out, empty(T)) end From db6929fa9652bb2f3595a503378e14425d28e830 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 17 May 2024 14:01:28 +0200 Subject: [PATCH 43/47] Update docs --- docs/src/api.md | 31 +++++++++++++++++++++++++++++++ src/adtypes.jl | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 54e89b65..0a244e54 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -9,16 +9,40 @@ CollapsedDocStrings = true ``` ## Interface + +### Global sparsity + +The following functions can be used to compute global sparsity patterns of `f(x)` over the entire input domain `x`. + ```@docs connectivity_pattern jacobian_pattern hessian_pattern ``` + +Alternatively, [ADTypes.jl](https://github.com/SciML/ADTypes.jl)'s interface can be used: ```@docs TracerSparsityDetector ``` +### Local sparsity + +The following functions can be used to compute local sparsity patterns of `f(x)` at a specific input `x`. +Note that these patterns are sparser than global patterns but need to be recomputed when `x` changes. + +```@docs +local_connectivity_pattern +local_jacobian_pattern +local_hessian_pattern +``` + +Note that [ADTypes.jl](https://github.com/SciML/ADTypes.jl) doesn't provide an interface for local sparsity detection. + ## Internals + +!!! warning + Internals may change without warning in a future release of SparseConnectivityTracer. + SparseConnectivityTracer works by pushing `Number` types called tracers through generic functions. Currently, three tracer types are provided: @@ -28,6 +52,13 @@ SparseConnectivityTracer.GradientTracer SparseConnectivityTracer.HessianTracer ``` +These can be used alone or inside of the dual number type [`Dual`](@ref), +which keeps track of the primal computation and allows tracing through comparisons and control flow: + +```@docs +SparseConnectivityTracer.Dual +``` + We also define alternative pseudo-set types that can deliver faster `union`: ```@docs diff --git a/src/adtypes.jl b/src/adtypes.jl index 4f1a8e83..fa9c6f0d 100644 --- a/src/adtypes.jl +++ b/src/adtypes.jl @@ -1,7 +1,7 @@ """ TracerSparsityDetector <: ADTypes.AbstractSparsityDetector -Singleton struct for integration with the sparsity detection framework of ADTypes.jl. +Singleton struct for integration with the sparsity detection framework of [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Example From 3257e48dff65cc3bd114f591b3ab60635d3ebb65 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 17 May 2024 14:22:22 +0200 Subject: [PATCH 44/47] Add `local_connectivity_pattern` --- src/SparseConnectivityTracer.jl | 2 +- src/pattern.jl | 115 ++++++++++++++++++++++++++------ src/tracers.jl | 21 +++--- 3 files changed, 108 insertions(+), 30 deletions(-) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index b4bc31a5..4c7ea1c0 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -20,7 +20,7 @@ include("overload_dual.jl") include("pattern.jl") include("adtypes.jl") -export connectivity_pattern +export connectivity_pattern, local_connectivity_pattern export jacobian_pattern, local_jacobian_pattern export hessian_pattern, local_hessian_pattern diff --git a/src/pattern.jl b/src/pattern.jl index 0b6c4e4b..9638c4b5 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -1,7 +1,10 @@ const DEFAULT_VECTOR_TYPE = BitSet const DEFAULT_MATRIX_TYPE = Set{Tuple{Int,Int}} -## Enumerate inputs +#==================# +# Enumerate inputs # +#==================# + """ trace_input(T, x) trace_input(T, x) @@ -20,7 +23,10 @@ function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:AbstractTracer} return create_tracer.(T, xs, indices) end -## Trace function +#=========================# +# Trace through functions # +#=========================# + function trace_function(::Type{T}, f, x) where {T<:AbstractTracer} xt = trace_input(T, x) yt = f(xt) @@ -37,7 +43,14 @@ end to_array(x::Number) = [x] to_array(x::AbstractArray) = x -## Construct sparsity pattern matrix +# Utilities +_tracer_or_number(x::Number) = x +_tracer_or_number(d::Dual) = tracer(d) + +#====================# +# ConnectivityTracer # +#====================# + """ connectivity_pattern(f, x) connectivity_pattern(f, x, T) @@ -80,6 +93,59 @@ function connectivity_pattern(f!, y, x, ::Type{C}=DEFAULT_VECTOR_TYPE) where {C} return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) end +""" + local_connectivity_pattern(f, x) + local_connectivity_pattern(f, x, T) + +Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)` +where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. + +Unlike [`connectivity_pattern`](@ref), this function supports control flow and comparisons. + +The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. + +## Example + +```jldoctest +julia> f(x) = ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]); + +julia> x = [1 2 3 4]; + +julia> local_connectivity_pattern(f, x) +1×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries: + 1 1 ⋅ ⋅ + +julia> x = [1 3 2 4]; + +julia> local_connectivity_pattern(f, x) +1×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries: + ⋅ ⋅ 1 1 +``` +""" +function local_connectivity_pattern(f, x, ::Type{C}=DEFAULT_VECTOR_TYPE) where {C} + D = Dual{eltype(x),ConnectivityTracer{C}} + xt, yt = trace_function(D, f, x) + return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) +end + +""" + local_connectivity_pattern(f!, y, x) + local_connectivity_pattern(f!, y, x, T) + +Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)` +where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. + +Unlike [`connectivity_pattern`](@ref), this function supports control flow and comparisons. + + +The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. +""" +function local_connectivity_pattern(f!, y, x, ::Type{C}=DEFAULT_VECTOR_TYPE) where {C} + D = Dual{eltype(x),ConnectivityTracer{C}} + xt, yt = trace_function(D, f!, y, x) + return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) +end + function connectivity_pattern_to_mat( xt::AbstractArray{T}, yt::AbstractArray{<:Number} ) where {T<:ConnectivityTracer} @@ -99,6 +165,16 @@ function connectivity_pattern_to_mat( return sparse(I, J, V, m, n) end +function connectivity_pattern_to_mat( + xt::AbstractArray{D}, yt::AbstractArray{<:Number} +) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} + return connectivity_pattern_to_mat(tracer.(xt), _tracer_or_number.(yt)) +end + +#================# +# GradientTracer # +#================# + """ jacobian_pattern(f, x) jacobian_pattern(f, x, T) @@ -126,6 +202,19 @@ function jacobian_pattern(f, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end +""" + jacobian_pattern(f!, y, x) + jacobian_pattern(f!, y, x, T) + +Compute the sparsity pattern of the Jacobian of `f!(y, x)`. + +The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. +""" +function jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} + xt, yt = trace_function(GradientTracer{G}, f!, y, x) + return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) +end + """ local_jacobian_pattern(f, x) local_jacobian_pattern(f, x, T) @@ -154,19 +243,6 @@ function local_jacobian_pattern(f, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end -""" - jacobian_pattern(f!, y, x) - jacobian_pattern(f!, y, x, T) - -Compute the sparsity pattern of the Jacobian of `f!(y, x)`. - -The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. -""" -function jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G} - xt, yt = trace_function(GradientTracer{G}, f!, y, x) - return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) -end - """ local_jacobian_pattern(f!, y, x) local_jacobian_pattern(f!, y, x, T) @@ -200,15 +276,16 @@ function jacobian_pattern_to_mat( return sparse(I, J, V, m, n) end -_tracer_or_number(x::Number) = x -_tracer_or_number(d::Dual) = tracer(d) - function jacobian_pattern_to_mat( xt::AbstractArray{D}, yt::AbstractArray{<:Number} ) where {P,T<:GradientTracer,D<:Dual{P,T}} return jacobian_pattern_to_mat(tracer.(xt), _tracer_or_number.(yt)) end +#===============# +# HessianTracer # +#===============# + """ hessian_pattern(f, x) hessian_pattern(f, x, T) diff --git a/src/tracers.jl b/src/tracers.jl index 6ba9062c..503b94f6 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -12,9 +12,9 @@ sparse_vector(T, index) = T([index]) ×(a::G, b::G) where {G<:AbstractSet} = Set((i, j) for i in a, j in b) -#==============# -# Connectivity # -#==============# +#====================# +# ConnectivityTracer # +#====================# """ $(TYPEDEF) @@ -65,9 +65,9 @@ end ConnectivityTracer{C}(t::ConnectivityTracer{C}) where {C<:AbstractSet{<:Integer}} = t ConnectivityTracer(t::ConnectivityTracer) = t -#=================# -# Gradient Tracer # -#=================# +#================# +# GradientTracer # +#================# """ $(TYPEDEF) @@ -114,9 +114,9 @@ end GradientTracer{G}(t::GradientTracer{G}) where {G<:AbstractSet{<:Integer}} = t GradientTracer(t::GradientTracer) = t -#=========# -# Hessian # -#=========# +#===============# +# HessianTracer # +#===============# """ $(TYPEDEF) @@ -196,7 +196,8 @@ Dual number type keeping track of the results of a primal computation as well as ## Fields $(TYPEDFIELDS) """ -struct Dual{P<:Number,T<:Union{GradientTracer,HessianTracer}} <: AbstractTracer +struct Dual{P<:Number,T<:Union{ConnectivityTracer,GradientTracer,HessianTracer}} <: + AbstractTracer primal::P tracer::T end From 82f05148e6dafb4a0c26926a7e6348f652df3853 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 17 May 2024 14:28:42 +0200 Subject: [PATCH 45/47] Fix similar v2 --- src/conversion.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/conversion.jl b/src/conversion.jl index 28b24383..4ef98db4 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -44,9 +44,9 @@ for TT in (GradientTracer, ConnectivityTracer, HessianTracer) Base.similar(::Array{T}, dims::Dims{N}) where {T<:TT,N} = zeros(T, dims) end -Base.similar(::Array, ::Type{ConnectivityTracer{C}}, dims::Dims{N}) where {C,N} = zeros(T, dims) -Base.similar(::Array, ::Type{GradientTracer{G}}, dims::Dims{N}) where {G,N} = zeros(T, dims) -Base.similar(::Array, ::Type{HessianTracer{G,H}}, dims::Dims{N}) where {G,H,N} = zeros(T, dims) +Base.similar(::Array, ::Type{ConnectivityTracer{C}}, dims::Dims{N}) where {C,N} = zeros(ConnectivityTracer{C}, dims) +Base.similar(::Array, ::Type{GradientTracer{G}}, dims::Dims{N}) where {G,N} = zeros(GradientTracer{G}, dims) +Base.similar(::Array, ::Type{HessianTracer{G,H}}, dims::Dims{N}) where {G,H,N} = zeros(HessianTracer{G,H}, dims) ## Duals From 923c5f0096a286434402304319687bc9afda7aa6 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 17 May 2024 14:30:11 +0200 Subject: [PATCH 46/47] Fix docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 0a244e54..5383a82f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -52,7 +52,7 @@ SparseConnectivityTracer.GradientTracer SparseConnectivityTracer.HessianTracer ``` -These can be used alone or inside of the dual number type [`Dual`](@ref), +These can be used alone or inside of the dual number type `Dual`, which keeps track of the primal computation and allows tracing through comparisons and control flow: ```@docs From 5c14422100cbb69718072b1efdf80325c3f26b31 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 17 May 2024 14:58:15 +0200 Subject: [PATCH 47/47] Add `MissingPrimalError` --- src/SparseConnectivityTracer.jl | 1 + src/exceptions.jl | 23 +++++++++++++++++++++++ src/overload_dual.jl | 7 +++++++ test/first_order.jl | 18 +++++++++++++++++- test/second_order.jl | 25 +++++++++++++++++++++---- 5 files changed, 69 insertions(+), 5 deletions(-) create mode 100644 src/exceptions.jl diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 4c7ea1c0..8cf0e2b9 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -11,6 +11,7 @@ include("settypes/recursiveset.jl") include("settypes/sortedvector.jl") include("tracers.jl") +include("exceptions.jl") include("conversion.jl") include("operators.jl") include("overload_connectivity.jl") diff --git a/src/exceptions.jl b/src/exceptions.jl new file mode 100644 index 00000000..378e94f2 --- /dev/null +++ b/src/exceptions.jl @@ -0,0 +1,23 @@ +struct MissingPrimalError <: Exception + fn::Function + tracer::AbstractTracer +end + +function Base.showerror(io::IO, e::MissingPrimalError) + println(io, "Function ", e.fn, " requires primal value(s).") + print( + io, + "A dual-number tracer for local sparsity detection can be used via `", + str_local_pattern_fn(e.tracer), + "`.", + ) + return nothing +end + +str_pattern_fn(::ConnectivityTracer) = "connectivity_pattern" +str_pattern_fn(::GradientTracer) = "jacobian_pattern" +str_pattern_fn(::HessianTracer) = "hessian_pattern" + +str_local_pattern_fn(::ConnectivityTracer) = "local_connectivity_pattern" +str_local_pattern_fn(::GradientTracer) = "local_jacobian_pattern" +str_local_pattern_fn(::HessianTracer) = "local_hessian_pattern" diff --git a/src/overload_dual.jl b/src/overload_dual.jl index 113a234b..f442c302 100644 --- a/src/overload_dual.jl +++ b/src/overload_dual.jl @@ -1,3 +1,4 @@ + # Special overloads for Dual numbers for fn in ( :iseven, @@ -14,8 +15,14 @@ for fn in ( :real, ) @eval Base.$fn(d::D) where {D<:Dual} = $fn(primal(d)) + @eval function Base.$fn(t::T) where {T<:AbstractTracer} + throw(MissingPrimalError($fn, t)) + end end for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=)) @eval Base.$fn(dx::D, dy::D) where {D<:Dual} = $fn(primal(dx), primal(dy)) + @eval function Base.$fn(t1::T, t2::T) where {T<:AbstractTracer} + throw(MissingPrimalError($fn, t1)) + end end diff --git a/test/first_order.jl b/test/first_order.jl index 9d4584df..82445354 100644 --- a/test/first_order.jl +++ b/test/first_order.jl @@ -1,6 +1,6 @@ using SparseConnectivityTracer using SparseConnectivityTracer: - ConnectivityTracer, GradientTracer, tracer, trace_input, empty + ConnectivityTracer, GradientTracer, MissingPrimalError, tracer, trace_input, empty using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using LinearAlgebra: det, logdet @@ -99,5 +99,21 @@ end # Linear algebra @test local_jacobian_pattern(logdet, [1.0 -1.0; 2.0 2.0], G) ≈ [1 1 1 1] # (#68) @test local_jacobian_pattern(x -> log(det(x)), [1.0 -1.0; 2.0 2.0], G) ≈ [1 1 1 1] + + ## ConnectivityTracer + @test local_connectivity_pattern( + x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4], G + ) ≈ [1 1 0 0] + @test local_connectivity_pattern( + x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 3 2 4], G + ) ≈ [0 0 1 1] + + ## Error handling when applying non-dual tracers to "local" functions with control flow + @test_throws MissingPrimalError connectivity_pattern( + x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4], G + ) ≈ [1 1 0 0] + @test_throws MissingPrimalError jacobian_pattern( + x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0], G + ) ≈ [0 0 0 1;] end end diff --git a/test/second_order.jl b/test/second_order.jl index 2d73518c..ba8ab903 100644 --- a/test/second_order.jl +++ b/test/second_order.jl @@ -1,5 +1,6 @@ using SparseConnectivityTracer -using SparseConnectivityTracer: HessianTracer, tracer, trace_input, empty +using SparseConnectivityTracer: + HessianTracer, MissingPrimalError, tracer, trace_input, empty using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using Test @@ -147,8 +148,7 @@ end @testset "Local" begin @testset "Set type $G" for G in SECOND_ORDER_SET_TYPES - x = f1(x) = x[1] + x[2] * x[3] + 1 / x[4] + x[2] * max(x[1], x[5]) - + f1(x) = x[1] + x[2] * x[3] + 1 / x[4] + x[2] * max(x[1], x[5]) h = local_hessian_pattern(f1, [1.0 3.0 5.0 1.0 2.0], G) @test h ≈ [ 0 0 0 0 0 @@ -157,7 +157,6 @@ end 0 0 0 1 0 0 1 0 0 0 ] - h = local_hessian_pattern(f1, [4.0 3.0 5.0 1.0 2.0], G) @test h ≈ [ 0 1 0 0 0 @@ -166,5 +165,23 @@ end 0 0 0 1 0 0 0 0 0 0 ] + + f2(x) = ifelse(x[2] < x[3], x[1] * x[2], x[3] * x[4]) + h = local_hessian_pattern(f2, [1 2 3 4], G) + @test h ≈ [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + h = local_hessian_pattern(f2, [1 3 2 4], G) + @test h ≈ [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + ## Error handling when applying non-dual tracers to "local" functions with control flow + @test_throws MissingPrimalError hessian_pattern(f2, [1 3 2 4], G) end end