From 0dd1bbde83366eab41e1b4c42b7e4fa56eb583e8 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Thu, 9 May 2024 22:31:44 +0200 Subject: [PATCH] Refactor tracer types (#53) Update tracers to parametric signatures * `ConnectivityTracer{I,S}` * `JacobianTracer{I,S}` * `HessianTracer{I,S,D}` where `I` is the index type, `S` is a (pseudo-)set type and `D` is an `AbstractDict{I,S}`. --- src/SparseConnectivityTracer.jl | 1 + src/conversion.jl | 11 +- src/pattern.jl | 20 ++- src/settypes/base_sets.jl | 44 ++++++ src/settypes/duplicatevector.jl | 36 +++-- src/settypes/recursiveset.jl | 36 +++-- src/settypes/sortedvector.jl | 36 +++-- src/tracers.jl | 133 +++++++----------- test/first_order.jl | 5 +- .../show/ConnectivityTracer_BitSet.txt | 2 +- ...ectivityTracer_DuplicateVector{UInt64}.txt | 2 +- ...onnectivityTracer_RecursiveSet{UInt64}.txt | 2 +- .../show/ConnectivityTracer_Set{UInt64}.txt | 2 +- ...onnectivityTracer_SortedVector{UInt64}.txt | 2 +- test/references/show/HessianTracer_BitSet.txt | 2 +- .../HessianTracer_DuplicateVector{UInt64}.txt | 2 +- .../HessianTracer_RecursiveSet{UInt64}.txt | 2 +- .../show/HessianTracer_Set{UInt64}.txt | 2 +- .../HessianTracer_SortedVector{UInt64}.txt | 2 +- .../references/show/JacobianTracer_BitSet.txt | 2 +- ...JacobianTracer_DuplicateVector{UInt64}.txt | 2 +- .../JacobianTracer_RecursiveSet{UInt64}.txt | 2 +- .../show/JacobianTracer_Set{UInt64}.txt | 2 +- .../JacobianTracer_SortedVector{UInt64}.txt | 2 +- test/second_order.jl | 3 +- 25 files changed, 217 insertions(+), 138 deletions(-) create mode 100644 src/settypes/base_sets.jl diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 9673c0c0..8ab92231 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -13,6 +13,7 @@ include("overload_hessian.jl") include("pattern.jl") include("adtypes.jl") +include("settypes/base_sets.jl") include("settypes/duplicatevector.jl") include("settypes/recursiveset.jl") include("settypes/sortedvector.jl") diff --git a/src/conversion.jl b/src/conversion.jl index c93022d9..0df697dd 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -24,7 +24,14 @@ for TT in (:JacobianTracer, :ConnectivityTracer, :HessianTracer) @eval Base.similar(a::Array{A,2}, ::Type{T}) where {A,T<:$TT} = zeros(T, size(a, 1), size(a, 2)) @eval Base.similar(::Array{T}, m::Int) where {T<:$TT} = zeros(T, m) @eval Base.similar(::Array{T}, dims::Dims{N}) where {N,T<:$TT} = zeros(T, dims) +end - @eval Base.similar(::Array, ::Type{$TT{S}}, dims::Dims{N}) where {N,S} = - zeros($TT{S}, dims) +function Base.similar(::Array, ::Type{ConnectivityTracer{I,S}}, dims::Dims{N}) where {I,S,N} + return zeros(ConnectivityTracer{I,S}, dims) +end +function Base.similar(::Array, ::Type{JacobianTracer{I,S}}, dims::Dims{N}) where {I,S,N} + return zeros(JacobianTracer{I,S}, dims) +end +function Base.similar(::Array, ::Type{HessianTracer{I,S,D}}, dims::Dims{N}) where {I,S,D,N} + return zeros(HessianTracer{I,S,D}, dims) end diff --git a/src/pattern.jl b/src/pattern.jl index 863e401d..40e8ac5a 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -58,7 +58,8 @@ julia> connectivity_pattern(f, x) ``` """ function connectivity_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - xt, yt = trace_function(ConnectivityTracer{S}, f, x) + I = eltype(S) + xt, yt = trace_function(ConnectivityTracer{I,S}, f, x) return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) end @@ -72,7 +73,8 @@ where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. """ function connectivity_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - xt, yt = trace_function(ConnectivityTracer{S}, f!, y, x) + I = eltype(S) + xt, yt = trace_function(ConnectivityTracer{I,S}, f!, y, x) return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) end @@ -118,7 +120,8 @@ julia> jacobian_pattern(f, x) ``` """ function jacobian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - xt, yt = trace_function(JacobianTracer{S}, f, x) + I = eltype(S) + xt, yt = trace_function(JacobianTracer{I,S}, f, x) return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end @@ -131,7 +134,8 @@ Compute the sparsity pattern of the Jacobian of `f!(y, x)`. The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. """ function jacobian_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - xt, yt = trace_function(JacobianTracer{S}, f!, y, x) + I = eltype(S) + xt, yt = trace_function(JacobianTracer{I,S}, f!, y, x) return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end @@ -189,13 +193,15 @@ julia> hessian_pattern(g, x) ``` """ function hessian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - xt, yt = trace_function(HessianTracer{S}, f, x) + I = eltype(S) + T = HessianTracer{I,S,Dict{I,S}} + xt, yt = trace_function(T, f, x) return hessian_pattern_to_mat(to_array(xt), yt) end function hessian_pattern_to_mat( - xt::AbstractArray{HessianTracer{S,T}}, yt::HessianTracer{S,T} -) where {S,T} + xt::AbstractArray{HessianTracer{TI,S,D}}, yt::HessianTracer{TI,S,D} +) where {TI,S,D} # Allocate Hessian matrix n = length(xt) I = UInt64[] # row indices diff --git a/src/settypes/base_sets.jl b/src/settypes/base_sets.jl new file mode 100644 index 00000000..4f2ae786 --- /dev/null +++ b/src/settypes/base_sets.jl @@ -0,0 +1,44 @@ +function keys2set(::Type{S}, d::Dict{I}) where {I<:Integer,S<:AbstractSet{<:I}} + return S(keys(d)) +end + +# Performance can be gained by not re-allocating empty tracers +## BitSet +const EMPTY_CONNECTIVITY_TRACER_BITSET = ConnectivityTracer(BitSet()) +const EMPTY_JACOBIAN_TRACER_BITSET = JacobianTracer(BitSet()) +const EMPTY_HESSIAN_TRACER_BITSET = HessianTracer(Dict{Int,BitSet}()) + +empty(::Type{ConnectivityTracer{Int,BitSet}}) = EMPTY_CONNECTIVITY_TRACER_BITSET +empty(::Type{JacobianTracer{Int,BitSet}}) = EMPTY_JACOBIAN_TRACER_BITSET +empty(::Type{HessianTracer{Int,BitSet,Dict{Int,BitSet}}}) = EMPTY_HESSIAN_TRACER_BITSET + +## Set +const EMPTY_CONNECTIVITY_TRACER_SET_U8 = ConnectivityTracer(Set{UInt8}()) +const EMPTY_CONNECTIVITY_TRACER_SET_U16 = ConnectivityTracer(Set{UInt16}()) +const EMPTY_CONNECTIVITY_TRACER_SET_U32 = ConnectivityTracer(Set{UInt32}()) +const EMPTY_CONNECTIVITY_TRACER_SET_U64 = ConnectivityTracer(Set{UInt64}()) + +const EMPTY_JACOBIAN_TRACER_SET_U8 = JacobianTracer(Set{UInt8}()) +const EMPTY_JACOBIAN_TRACER_SET_U16 = JacobianTracer(Set{UInt16}()) +const EMPTY_JACOBIAN_TRACER_SET_U32 = JacobianTracer(Set{UInt32}()) +const EMPTY_JACOBIAN_TRACER_SET_U64 = JacobianTracer(Set{UInt64}()) + +const EMPTY_HESSIAN_TRACER_SET_U8 = HessianTracer(Dict{UInt8,Set{UInt8}}()) +const EMPTY_HESSIAN_TRACER_SET_U16 = HessianTracer(Dict{UInt16,Set{UInt16}}()) +const EMPTY_HESSIAN_TRACER_SET_U32 = HessianTracer(Dict{UInt32,Set{UInt32}}()) +const EMPTY_HESSIAN_TRACER_SET_U64 = HessianTracer(Dict{UInt64,Set{UInt64}}()) + +empty(::Type{ConnectivityTracer{UInt8,Set{UInt8}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U8 +empty(::Type{ConnectivityTracer{UInt16,Set{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U16 +empty(::Type{ConnectivityTracer{UInt32,Set{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U32 +empty(::Type{ConnectivityTracer{UInt64,Set{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U64 + +empty(::Type{JacobianTracer{UInt8,Set{UInt8}}}) = EMPTY_JACOBIAN_TRACER_SET_U8 +empty(::Type{JacobianTracer{UInt16,Set{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SET_U16 +empty(::Type{JacobianTracer{UInt32,Set{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SET_U32 +empty(::Type{JacobianTracer{UInt64,Set{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SET_U64 + +empty(::Type{HessianTracer{UInt8,Set{UInt8},Dict{UInt8,Set{UInt8}}}}) = EMPTY_HESSIAN_TRACER_SET_U8 +empty(::Type{HessianTracer{UInt16,Set{UInt16},Dict{UInt16,Set{UInt16}}}}) = EMPTY_HESSIAN_TRACER_SET_U16 +empty(::Type{HessianTracer{UInt32,Set{UInt32},Dict{UInt32,Set{UInt32}}}}) = EMPTY_HESSIAN_TRACER_SET_U32 +empty(::Type{HessianTracer{UInt64,Set{UInt64},Dict{UInt64,Set{UInt64}}}}) = EMPTY_HESSIAN_TRACER_SET_U64 diff --git a/src/settypes/duplicatevector.jl b/src/settypes/duplicatevector.jl index 7203d318..5c266c2e 100644 --- a/src/settypes/duplicatevector.jl +++ b/src/settypes/duplicatevector.jl @@ -37,20 +37,38 @@ const EMPTY_HESSIAN_TRACER_DV_U16 = HessianTracer(Dict{UInt16,DuplicateVector{UI const EMPTY_HESSIAN_TRACER_DV_U32 = HessianTracer(Dict{UInt32,DuplicateVector{UInt32}}()) const EMPTY_HESSIAN_TRACER_DV_U64 = HessianTracer(Dict{UInt64,DuplicateVector{UInt64}}()) -function empty(::Type{ConnectivityTracer{DuplicateVector{UInt16}}}) +function empty(::Type{ConnectivityTracer{UInt16,DuplicateVector{UInt16}}}) return EMPTY_CONNECTIVITY_TRACER_DV_U16 end -function empty(::Type{ConnectivityTracer{DuplicateVector{UInt32}}}) +function empty(::Type{ConnectivityTracer{UInt32,DuplicateVector{UInt32}}}) return EMPTY_CONNECTIVITY_TRACER_DV_U32 end -function empty(::Type{ConnectivityTracer{DuplicateVector{UInt64}}}) +function empty(::Type{ConnectivityTracer{UInt64,DuplicateVector{UInt64}}}) return EMPTY_CONNECTIVITY_TRACER_DV_U64 end -empty(::Type{JacobianTracer{DuplicateVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_DV_U16 -empty(::Type{JacobianTracer{DuplicateVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_DV_U32 -empty(::Type{JacobianTracer{DuplicateVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_DV_U64 +empty(::Type{JacobianTracer{UInt16,DuplicateVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_DV_U16 +empty(::Type{JacobianTracer{UInt32,DuplicateVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_DV_U32 +empty(::Type{JacobianTracer{UInt64,DuplicateVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_DV_U64 -empty(::Type{HessianTracer{DuplicateVector{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_DV_U16 -empty(::Type{HessianTracer{DuplicateVector{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_DV_U32 -empty(::Type{HessianTracer{DuplicateVector{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_DV_U64 +function empty( + ::Type{ + HessianTracer{UInt16,DuplicateVector{UInt16},Dict{UInt16,DuplicateVector{UInt16}}} + }, +) + return EMPTY_HESSIAN_TRACER_DV_U16 +end +function empty( + ::Type{ + HessianTracer{UInt32,DuplicateVector{UInt32},Dict{UInt32,DuplicateVector{UInt32}}} + }, +) + return EMPTY_HESSIAN_TRACER_DV_U32 +end +function empty( + ::Type{ + HessianTracer{UInt64,DuplicateVector{UInt64},Dict{UInt64,DuplicateVector{UInt64}}} + }, +) + return EMPTY_HESSIAN_TRACER_DV_U64 +end diff --git a/src/settypes/recursiveset.jl b/src/settypes/recursiveset.jl index 4e7a0bd6..f45c8a35 100644 --- a/src/settypes/recursiveset.jl +++ b/src/settypes/recursiveset.jl @@ -81,14 +81,32 @@ const EMPTY_HESSIAN_TRACER_RS_U16 = HessianTracer(Dict{UInt16,RecursiveSet{UInt1 const EMPTY_HESSIAN_TRACER_RS_U32 = HessianTracer(Dict{UInt32,RecursiveSet{UInt32}}()) const EMPTY_HESSIAN_TRACER_RS_U64 = HessianTracer(Dict{UInt64,RecursiveSet{UInt64}}()) -empty(::Type{ConnectivityTracer{RecursiveSet{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_RS_U16 -empty(::Type{ConnectivityTracer{RecursiveSet{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_RS_U32 -empty(::Type{ConnectivityTracer{RecursiveSet{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_RS_U64 +function empty(::Type{ConnectivityTracer{UInt16,RecursiveSet{UInt16}}}) + return EMPTY_CONNECTIVITY_TRACER_RS_U16 +end +function empty(::Type{ConnectivityTracer{UInt32,RecursiveSet{UInt32}}}) + return EMPTY_CONNECTIVITY_TRACER_RS_U32 +end +function empty(::Type{ConnectivityTracer{UInt64,RecursiveSet{UInt64}}}) + return EMPTY_CONNECTIVITY_TRACER_RS_U64 +end -empty(::Type{JacobianTracer{RecursiveSet{UInt16}}}) = EMPTY_JACOBIAN_TRACER_RS_U16 -empty(::Type{JacobianTracer{RecursiveSet{UInt32}}}) = EMPTY_JACOBIAN_TRACER_RS_U32 -empty(::Type{JacobianTracer{RecursiveSet{UInt64}}}) = EMPTY_JACOBIAN_TRACER_RS_U64 +empty(::Type{JacobianTracer{UInt16,RecursiveSet{UInt16}}}) = EMPTY_JACOBIAN_TRACER_RS_U16 +empty(::Type{JacobianTracer{UInt32,RecursiveSet{UInt32}}}) = EMPTY_JACOBIAN_TRACER_RS_U32 +empty(::Type{JacobianTracer{UInt64,RecursiveSet{UInt64}}}) = EMPTY_JACOBIAN_TRACER_RS_U64 -empty(::Type{HessianTracer{RecursiveSet{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_RS_U16 -empty(::Type{HessianTracer{RecursiveSet{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_RS_U32 -empty(::Type{HessianTracer{RecursiveSet{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_RS_U64 +function empty( + ::Type{HessianTracer{UInt16,RecursiveSet{UInt16},Dict{UInt16,RecursiveSet{UInt16}}}} +) + return EMPTY_HESSIAN_TRACER_RS_U16 +end +function empty( + ::Type{HessianTracer{UInt32,RecursiveSet{UInt32},Dict{UInt32,RecursiveSet{UInt32}}}} +) + return EMPTY_HESSIAN_TRACER_RS_U32 +end +function empty( + ::Type{HessianTracer{UInt64,RecursiveSet{UInt64},Dict{UInt64,RecursiveSet{UInt64}}}} +) + return EMPTY_HESSIAN_TRACER_RS_U64 +end diff --git a/src/settypes/sortedvector.jl b/src/settypes/sortedvector.jl index 364928ec..814fe247 100644 --- a/src/settypes/sortedvector.jl +++ b/src/settypes/sortedvector.jl @@ -82,14 +82,32 @@ const EMPTY_HESSIAN_TRACER_SV_U16 = HessianTracer(Dict{UInt16,SortedVector{UInt1 const EMPTY_HESSIAN_TRACER_SV_U32 = HessianTracer(Dict{UInt32,SortedVector{UInt32}}()) const EMPTY_HESSIAN_TRACER_SV_U64 = HessianTracer(Dict{UInt64,SortedVector{UInt64}}()) -empty(::Type{ConnectivityTracer{SortedVector{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U16 -empty(::Type{ConnectivityTracer{SortedVector{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U32 -empty(::Type{ConnectivityTracer{SortedVector{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U64 +function empty(::Type{ConnectivityTracer{UInt16,SortedVector{UInt16}}}) + return EMPTY_CONNECTIVITY_TRACER_SV_U16 +end +function empty(::Type{ConnectivityTracer{UInt32,SortedVector{UInt32}}}) + return EMPTY_CONNECTIVITY_TRACER_SV_U32 +end +function empty(::Type{ConnectivityTracer{UInt64,SortedVector{UInt64}}}) + return EMPTY_CONNECTIVITY_TRACER_SV_U64 +end -empty(::Type{JacobianTracer{SortedVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SV_U16 -empty(::Type{JacobianTracer{SortedVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SV_U32 -empty(::Type{JacobianTracer{SortedVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SV_U64 +empty(::Type{JacobianTracer{UInt16,SortedVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SV_U16 +empty(::Type{JacobianTracer{UInt32,SortedVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SV_U32 +empty(::Type{JacobianTracer{UInt64,SortedVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SV_U64 -empty(::Type{HessianTracer{SortedVector{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_SV_U16 -empty(::Type{HessianTracer{SortedVector{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_SV_U32 -empty(::Type{HessianTracer{SortedVector{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_SV_U64 +function empty( + ::Type{HessianTracer{UInt16,SortedVector{UInt16},Dict{UInt16,SortedVector{UInt16}}}} +) + return EMPTY_HESSIAN_TRACER_SV_U16 +end +function empty( + ::Type{HessianTracer{UInt32,SortedVector{UInt32},Dict{UInt32,SortedVector{UInt32}}}} +) + return EMPTY_HESSIAN_TRACER_SV_U32 +end +function empty( + ::Type{HessianTracer{UInt64,SortedVector{UInt64},Dict{UInt64,SortedVector{UInt64}}}} +) + return EMPTY_HESSIAN_TRACER_SV_U64 +end diff --git a/src/tracers.jl b/src/tracers.jl index a0fcb364..64d6000c 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -3,10 +3,6 @@ abstract type AbstractTracer <: Number end # Convenience constructor for empty tracers empty(tracer::T) where {T<:AbstractTracer} = empty(T) -#==============# -# Connectivity # -#==============# - const SET_TYPE_MESSAGE = """ The provided index set type `S` has to satisfy the following conditions: @@ -16,8 +12,12 @@ The provided index set type `S` has to satisfy the following conditions: Subtypes of `AbstractSet{<:Integer}` are a natural choice, like `BitSet` or `Set{UInt64}`. """ +#==============# +# Connectivity # +#==============# + """ - ConnectivityTracer{S}(indexset) <: Number + ConnectivityTracer{I,S}(indexset) <: Number Number type keeping track of input indices of previous computations. @@ -25,40 +25,31 @@ $SET_TYPE_MESSAGE For a higher-level interface, refer to [`connectivity_pattern`](@ref). """ -struct ConnectivityTracer{S} <: AbstractTracer +struct ConnectivityTracer{I<:Integer,S} <: AbstractTracer inputs::S # indices of connected, enumerated inputs end +function ConnectivityTracer(inputs::S) where {S} + I = eltype(S) + return ConnectivityTracer{I,S}(inputs) +end -function Base.show(io::IO, t::ConnectivityTracer{S}) where {S} +function Base.show(io::IO, t::ConnectivityTracer{I,S}) where {I,S} return Base.show_delim_array( - io, convert.(Int, inputs(t)), "ConnectivityTracer{$S}(", ',', ')', true + io, convert.(Int, inputs(t)), "ConnectivityTracer{$I,$S}(", ',', ')', true ) end -empty(::Type{ConnectivityTracer{S}}) where {S} = ConnectivityTracer(S()) - -# Performance can be gained by not re-allocating empty tracers -const EMPTY_CONNECTIVITY_TRACER_BITSET = ConnectivityTracer(BitSet()) -const EMPTY_CONNECTIVITY_TRACER_SET_U8 = ConnectivityTracer(Set{UInt8}()) -const EMPTY_CONNECTIVITY_TRACER_SET_U16 = ConnectivityTracer(Set{UInt16}()) -const EMPTY_CONNECTIVITY_TRACER_SET_U32 = ConnectivityTracer(Set{UInt32}()) -const EMPTY_CONNECTIVITY_TRACER_SET_U64 = ConnectivityTracer(Set{UInt64}()) - -empty(::Type{ConnectivityTracer{BitSet}}) = EMPTY_CONNECTIVITY_TRACER_BITSET -empty(::Type{ConnectivityTracer{Set{UInt8}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U8 -empty(::Type{ConnectivityTracer{Set{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U16 -empty(::Type{ConnectivityTracer{Set{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U32 -empty(::Type{ConnectivityTracer{Set{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U64 +empty(::Type{ConnectivityTracer{I,S}}) where {I<:Integer,S} = ConnectivityTracer{I,S}(S()) # We have to be careful when defining constructors: # Generic code expecting "regular" numbers `x` will sometimes convert them # by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `ConnectivityTracer`. # When this happens, we create a new empty tracer with no input pattern. -ConnectivityTracer{S}(::Number) where {S} = empty(ConnectivityTracer{S}) +ConnectivityTracer{I,S}(::Number) where {I<:Integer,S} = empty(ConnectivityTracer{I,S}) ConnectivityTracer(t::ConnectivityTracer) = t ## Unions of tracers -function uniontracer(a::ConnectivityTracer{S}, b::ConnectivityTracer{S}) where {S} +function uniontracer(a::ConnectivityTracer{I,S}, b::ConnectivityTracer{I,S}) where {I,S} return ConnectivityTracer(union(a.inputs, b.inputs)) end @@ -67,7 +58,7 @@ end #==========# """ - JacobianTracer{S}(indexset) <: Number + JacobianTracer{I,S}(indexset) <: Number Number type keeping track of input indices of previous computations with non-zero derivatives. @@ -75,36 +66,27 @@ $SET_TYPE_MESSAGE For a higher-level interface, refer to [`jacobian_pattern`](@ref). """ -struct JacobianTracer{S} <: AbstractTracer +struct JacobianTracer{I<:Integer,S} <: AbstractTracer inputs::S end +function JacobianTracer(inputs::S) where {S} + I = eltype(S) + return JacobianTracer{I,S}(inputs) +end -function Base.show(io::IO, t::JacobianTracer{S}) where {S} +function Base.show(io::IO, t::JacobianTracer{I,S}) where {I,S} return Base.show_delim_array( - io, convert.(Int, inputs(t)), "JacobianTracer{$S}(", ',', ')', true + io, convert.(Int, inputs(t)), "JacobianTracer{$I,$S}(", ',', ')', true ) end -empty(::Type{JacobianTracer{S}}) where {S} = JacobianTracer(S()) - -# Performance can be gained by not re-allocating empty tracers -const EMPTY_JACOBIAN_TRACER_BITSET = JacobianTracer(BitSet()) -const EMPTY_JACOBIAN_TRACER_SET_U8 = JacobianTracer(Set{UInt8}()) -const EMPTY_JACOBIAN_TRACER_SET_U16 = JacobianTracer(Set{UInt16}()) -const EMPTY_JACOBIAN_TRACER_SET_U32 = JacobianTracer(Set{UInt32}()) -const EMPTY_JACOBIAN_TRACER_SET_U64 = JacobianTracer(Set{UInt64}()) +empty(::Type{JacobianTracer{I,S}}) where {I<:Integer,S} = JacobianTracer{I,S}(S()) -empty(::Type{JacobianTracer{BitSet}}) = EMPTY_JACOBIAN_TRACER_BITSET -empty(::Type{JacobianTracer{Set{UInt8}}}) = EMPTY_JACOBIAN_TRACER_SET_U8 -empty(::Type{JacobianTracer{Set{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SET_U16 -empty(::Type{JacobianTracer{Set{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SET_U32 -empty(::Type{JacobianTracer{Set{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SET_U64 - -JacobianTracer{S}(::Number) where {S} = empty(JacobianTracer{S}) +JacobianTracer{I,S}(::Number) where {I<:Integer,S} = empty(JacobianTracer{I,S}) JacobianTracer(t::JacobianTracer) = t ## Unions of tracers -function uniontracer(a::JacobianTracer{S}, b::JacobianTracer{S}) where {S} +function uniontracer(a::JacobianTracer{I,S}, b::JacobianTracer{I,S}) where {I,S} return JacobianTracer(union(a.inputs, b.inputs)) end @@ -112,7 +94,7 @@ end # Hessian # #=========# """ - HessianTracer{S}(indexset) <: Number + HessianTracer{I,S,D}(indexset) <: Number Number type keeping track of input indices of previous computations with non-zero first and second derivatives. @@ -120,11 +102,11 @@ $SET_TYPE_MESSAGE For a higher-level interface, refer to [`hessian_pattern`](@ref). """ -struct HessianTracer{S,I<:Integer} <: AbstractTracer - inputs::Dict{I,S} +struct HessianTracer{I<:Integer,S,D<:AbstractDict{I,S}} <: AbstractTracer + inputs::D end -function Base.show(io::IO, t::HessianTracer{S}) where {S} - println(io, "HessianTracer{", S, "}(") +function Base.show(io::IO, t::HessianTracer{I,S,D}) where {I,S,D} + println(io, "$(eltype(t))(") for key in keys(t.inputs) print(io, " ", Int(key), " => ") Base.show_delim_array(io, convert.(Int, t.inputs[key]), "(", ',', ')', true) @@ -133,32 +115,15 @@ function Base.show(io::IO, t::HessianTracer{S}) where {S} return print(io, ")") end -function empty(::Type{HessianTracer{S,I}}) where {S,I} - return HessianTracer(Dict{I,S}()) +function empty(::Type{HessianTracer{I,S,D}}) where {I<:Integer,S,D<:AbstractDict{I,S}} + return HessianTracer{I,S,D}(D()) end -# Performance can be gained by not re-allocating empty tracers -const EMPTY_HESSIAN_TRACER_BITSET = HessianTracer(Dict{Int,BitSet}()) -const EMPTY_HESSIAN_TRACER_SET_U8 = HessianTracer(Dict{UInt8,Set{UInt8}}()) -const EMPTY_HESSIAN_TRACER_SET_U16 = HessianTracer(Dict{UInt16,Set{UInt16}}()) -const EMPTY_HESSIAN_TRACER_SET_U32 = HessianTracer(Dict{UInt32,Set{UInt32}}()) -const EMPTY_HESSIAN_TRACER_SET_U64 = HessianTracer(Dict{UInt64,Set{UInt64}}()) - -empty(::Type{HessianTracer{BitSet,Int}}) = EMPTY_HESSIAN_TRACER_BITSET -empty(::Type{HessianTracer{Set{UInt8},UInt8}}) = EMPTY_HESSIAN_TRACER_SET_U8 -empty(::Type{HessianTracer{Set{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_SET_U16 -empty(::Type{HessianTracer{Set{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_SET_U32 -empty(::Type{HessianTracer{Set{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_SET_U64 - -HessianTracer{S,I}(::Number) where {S,I} = empty(HessianTracer{S,I}) +HessianTracer{I,S,D}(::Number) where {I<:Integer,S,D} = empty(HessianTracer{I,S,D}) HessianTracer(t::HessianTracer) = t -function keys2set(::Type{S}, d::Dict{I}) where {I<:Integer,S<:AbstractSet{<:I}} - return S(keys(d)) -end - # Turn first-order interactions into second-order interactions -function promote_order(t::HessianTracer{S}) where {S} +function promote_order(t::HessianTracer{I,S}) where {I,S} d = deepcopy(t.inputs) s = keys2set(S, d) for (k, v) in pairs(d) @@ -173,7 +138,7 @@ function additive_merge(a::HessianTracer, b::HessianTracer) end # Merge first- and second-order terms in a "distributive" fashion -function distributive_merge(a::HessianTracer{S}, b::HessianTracer{S}) where {S} +function distributive_merge(a::HessianTracer{I,S,D}, b::HessianTracer{I,S,D}) where {I,S,D} da = deepcopy(a.inputs) db = deepcopy(b.inputs) sa = keys2set(S, da) @@ -208,22 +173,22 @@ inputs(t::HessianTracer, i::Integer) = collect(t.inputs[i]) Convenience constructor for [`ConnectivityTracer`](@ref), [`JacobianTracer`](@ref) and [`HessianTracer`](@ref) from input indices. """ -tracer(::Type{JacobianTracer{S}}, index::Integer) where {S} = JacobianTracer(S(index)) -function tracer(::Type{ConnectivityTracer{S}}, index::Integer) where {S} - return ConnectivityTracer(S(index)) +function tracer(::Type{JacobianTracer{I,S}}, index::Integer) where {I,S} + return JacobianTracer{I,S}(S(index)) end -function tracer(::Type{HessianTracer{S}}, index::Integer) where {S} - I = eltype(S) - return HessianTracer{S,I}(Dict{I,S}(index => S())) +function tracer(::Type{ConnectivityTracer{I,S}}, index::Integer) where {I,S} + return ConnectivityTracer{I,S}(S(index)) +end +function tracer(::Type{HessianTracer{I,S,D}}, index::Integer) where {I,S,D} + return HessianTracer{I,S,D}(D(index => S())) end -function tracer(::Type{JacobianTracer{S}}, inds::NTuple{N,<:Integer}) where {N,S} - return JacobianTracer{S}(S(inds)) +function tracer(::Type{JacobianTracer{I,S}}, inds::NTuple{N,<:Integer}) where {I,S,N} + return JacobianTracer{I,S}(S(inds)) end -function tracer(::Type{ConnectivityTracer{S}}, inds::NTuple{N,<:Integer}) where {N,S} - return ConnectivityTracer{S}(S(inds)) +function tracer(::Type{ConnectivityTracer{I,S}}, inds::NTuple{N,<:Integer}) where {I,S,N} + return ConnectivityTracer{I,S}(S(inds)) end -function tracer(::Type{HessianTracer{S}}, inds::NTuple{N,<:Integer}) where {N,S} - I = eltype(S) - return HessianTracer{S,I}(Dict{I,S}(i => S() for i in inds)) +function tracer(::Type{HessianTracer{I,S,D}}, inds::NTuple{N,<:Integer}) where {I,S,D,N} + return HessianTracer{I,S,D}(D(i => S() for i in inds)) end diff --git a/test/first_order.jl b/test/first_order.jl index b6f13be7..a6cd9b23 100644 --- a/test/first_order.jl +++ b/test/first_order.jl @@ -7,8 +7,9 @@ using Test @testset "Set type $S" for S in ( BitSet, Set{UInt64}, DuplicateVector{UInt64}, RecursiveSet{UInt64}, SortedVector{UInt64} ) - CT = ConnectivityTracer{S} - JT = JacobianTracer{S} + I = eltype(S) + CT = ConnectivityTracer{I,S} + JT = JacobianTracer{I,S} x = rand(3) xt = trace_input(CT, x) diff --git a/test/references/show/ConnectivityTracer_BitSet.txt b/test/references/show/ConnectivityTracer_BitSet.txt index 5bf162ba..245b6f0e 100644 --- a/test/references/show/ConnectivityTracer_BitSet.txt +++ b/test/references/show/ConnectivityTracer_BitSet.txt @@ -1 +1 @@ -ConnectivityTracer{BitSet}(2,) \ No newline at end of file +ConnectivityTracer{Int64,BitSet}(2,) \ No newline at end of file diff --git a/test/references/show/ConnectivityTracer_DuplicateVector{UInt64}.txt b/test/references/show/ConnectivityTracer_DuplicateVector{UInt64}.txt index e84637a4..7e04670a 100644 --- a/test/references/show/ConnectivityTracer_DuplicateVector{UInt64}.txt +++ b/test/references/show/ConnectivityTracer_DuplicateVector{UInt64}.txt @@ -1 +1 @@ -ConnectivityTracer{DuplicateVector{UInt64}}(2,) \ No newline at end of file +ConnectivityTracer{UInt64,DuplicateVector{UInt64}}(2,) \ No newline at end of file diff --git a/test/references/show/ConnectivityTracer_RecursiveSet{UInt64}.txt b/test/references/show/ConnectivityTracer_RecursiveSet{UInt64}.txt index e90ca8f8..19545d48 100644 --- a/test/references/show/ConnectivityTracer_RecursiveSet{UInt64}.txt +++ b/test/references/show/ConnectivityTracer_RecursiveSet{UInt64}.txt @@ -1 +1 @@ -ConnectivityTracer{RecursiveSet{UInt64}}(2,) \ No newline at end of file +ConnectivityTracer{UInt64,RecursiveSet{UInt64}}(2,) \ No newline at end of file diff --git a/test/references/show/ConnectivityTracer_Set{UInt64}.txt b/test/references/show/ConnectivityTracer_Set{UInt64}.txt index 60799161..52a9e29e 100644 --- a/test/references/show/ConnectivityTracer_Set{UInt64}.txt +++ b/test/references/show/ConnectivityTracer_Set{UInt64}.txt @@ -1 +1 @@ -ConnectivityTracer{Set{UInt64}}(2,) \ No newline at end of file +ConnectivityTracer{UInt64,Set{UInt64}}(2,) \ No newline at end of file diff --git a/test/references/show/ConnectivityTracer_SortedVector{UInt64}.txt b/test/references/show/ConnectivityTracer_SortedVector{UInt64}.txt index c53d707c..7b47c900 100644 --- a/test/references/show/ConnectivityTracer_SortedVector{UInt64}.txt +++ b/test/references/show/ConnectivityTracer_SortedVector{UInt64}.txt @@ -1 +1 @@ -ConnectivityTracer{SortedVector{UInt64}}(2,) \ No newline at end of file +ConnectivityTracer{UInt64,SortedVector{UInt64}}(2,) \ No newline at end of file diff --git a/test/references/show/HessianTracer_BitSet.txt b/test/references/show/HessianTracer_BitSet.txt index 0aeee5ed..2b812225 100644 --- a/test/references/show/HessianTracer_BitSet.txt +++ b/test/references/show/HessianTracer_BitSet.txt @@ -1,3 +1,3 @@ -HessianTracer{BitSet}( +HessianTracer{Int64, BitSet, Dict{Int64, BitSet}}( 2 => (), ) \ No newline at end of file diff --git a/test/references/show/HessianTracer_DuplicateVector{UInt64}.txt b/test/references/show/HessianTracer_DuplicateVector{UInt64}.txt index baf1bcea..1cc4dd72 100644 --- a/test/references/show/HessianTracer_DuplicateVector{UInt64}.txt +++ b/test/references/show/HessianTracer_DuplicateVector{UInt64}.txt @@ -1,3 +1,3 @@ -HessianTracer{DuplicateVector{UInt64}}( +HessianTracer{UInt64, DuplicateVector{UInt64}, Dict{UInt64, DuplicateVector{UInt64}}}( 2 => (), ) \ No newline at end of file diff --git a/test/references/show/HessianTracer_RecursiveSet{UInt64}.txt b/test/references/show/HessianTracer_RecursiveSet{UInt64}.txt index fa015a48..011fbe4f 100644 --- a/test/references/show/HessianTracer_RecursiveSet{UInt64}.txt +++ b/test/references/show/HessianTracer_RecursiveSet{UInt64}.txt @@ -1,3 +1,3 @@ -HessianTracer{RecursiveSet{UInt64}}( +HessianTracer{UInt64, RecursiveSet{UInt64}, Dict{UInt64, RecursiveSet{UInt64}}}( 2 => (), ) \ No newline at end of file diff --git a/test/references/show/HessianTracer_Set{UInt64}.txt b/test/references/show/HessianTracer_Set{UInt64}.txt index 54572703..18a245fb 100644 --- a/test/references/show/HessianTracer_Set{UInt64}.txt +++ b/test/references/show/HessianTracer_Set{UInt64}.txt @@ -1,3 +1,3 @@ -HessianTracer{Set{UInt64}}( +HessianTracer{UInt64, Set{UInt64}, Dict{UInt64, Set{UInt64}}}( 2 => (), ) \ No newline at end of file diff --git a/test/references/show/HessianTracer_SortedVector{UInt64}.txt b/test/references/show/HessianTracer_SortedVector{UInt64}.txt index fb4fbb7a..03239fd4 100644 --- a/test/references/show/HessianTracer_SortedVector{UInt64}.txt +++ b/test/references/show/HessianTracer_SortedVector{UInt64}.txt @@ -1,3 +1,3 @@ -HessianTracer{SortedVector{UInt64}}( +HessianTracer{UInt64, SortedVector{UInt64}, Dict{UInt64, SortedVector{UInt64}}}( 2 => (), ) \ No newline at end of file diff --git a/test/references/show/JacobianTracer_BitSet.txt b/test/references/show/JacobianTracer_BitSet.txt index 706e87bf..0eea7cb3 100644 --- a/test/references/show/JacobianTracer_BitSet.txt +++ b/test/references/show/JacobianTracer_BitSet.txt @@ -1 +1 @@ -JacobianTracer{BitSet}(2,) \ No newline at end of file +JacobianTracer{Int64,BitSet}(2,) \ No newline at end of file diff --git a/test/references/show/JacobianTracer_DuplicateVector{UInt64}.txt b/test/references/show/JacobianTracer_DuplicateVector{UInt64}.txt index 8a564f63..a2251a68 100644 --- a/test/references/show/JacobianTracer_DuplicateVector{UInt64}.txt +++ b/test/references/show/JacobianTracer_DuplicateVector{UInt64}.txt @@ -1 +1 @@ -JacobianTracer{DuplicateVector{UInt64}}(2,) \ No newline at end of file +JacobianTracer{UInt64,DuplicateVector{UInt64}}(2,) \ No newline at end of file diff --git a/test/references/show/JacobianTracer_RecursiveSet{UInt64}.txt b/test/references/show/JacobianTracer_RecursiveSet{UInt64}.txt index 80e416e4..8dc8b4ac 100644 --- a/test/references/show/JacobianTracer_RecursiveSet{UInt64}.txt +++ b/test/references/show/JacobianTracer_RecursiveSet{UInt64}.txt @@ -1 +1 @@ -JacobianTracer{RecursiveSet{UInt64}}(2,) \ No newline at end of file +JacobianTracer{UInt64,RecursiveSet{UInt64}}(2,) \ No newline at end of file diff --git a/test/references/show/JacobianTracer_Set{UInt64}.txt b/test/references/show/JacobianTracer_Set{UInt64}.txt index 3dfd6c7c..2a7f9754 100644 --- a/test/references/show/JacobianTracer_Set{UInt64}.txt +++ b/test/references/show/JacobianTracer_Set{UInt64}.txt @@ -1 +1 @@ -JacobianTracer{Set{UInt64}}(2,) \ No newline at end of file +JacobianTracer{UInt64,Set{UInt64}}(2,) \ No newline at end of file diff --git a/test/references/show/JacobianTracer_SortedVector{UInt64}.txt b/test/references/show/JacobianTracer_SortedVector{UInt64}.txt index 4199cd90..c7b7b673 100644 --- a/test/references/show/JacobianTracer_SortedVector{UInt64}.txt +++ b/test/references/show/JacobianTracer_SortedVector{UInt64}.txt @@ -1 +1 @@ -JacobianTracer{SortedVector{UInt64}}(2,) \ No newline at end of file +JacobianTracer{UInt64,SortedVector{UInt64}}(2,) \ No newline at end of file diff --git a/test/second_order.jl b/test/second_order.jl index e9dc16dc..a51a04a2 100644 --- a/test/second_order.jl +++ b/test/second_order.jl @@ -7,7 +7,8 @@ using Test @testset "Set type $S" for S in ( BitSet, Set{UInt64}, DuplicateVector{UInt64}, RecursiveSet{UInt64}, SortedVector{UInt64} ) - HT = HessianTracer{S} + I = eltype(S) + HT = HessianTracer{I,S,Dict{I,S}} @test hessian_pattern(identity, rand(), S) ≈ [0;;] @test hessian_pattern(sqrt, rand(), S) ≈ [1;;]