Skip to content

Commit

Permalink
Refactor tracer types (#53)
Browse files Browse the repository at this point in the history
Update tracers to parametric signatures

* `ConnectivityTracer{I,S}`
* `JacobianTracer{I,S}`
* `HessianTracer{I,S,D}`

where `I` is the index type, `S` is a (pseudo-)set type and `D` is an `AbstractDict{I,S}`.
  • Loading branch information
adrhill authored May 9, 2024
1 parent 2fcd730 commit 0dd1bbd
Show file tree
Hide file tree
Showing 25 changed files with 217 additions and 138 deletions.
1 change: 1 addition & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include("overload_hessian.jl")
include("pattern.jl")
include("adtypes.jl")

include("settypes/base_sets.jl")
include("settypes/duplicatevector.jl")
include("settypes/recursiveset.jl")
include("settypes/sortedvector.jl")
Expand Down
11 changes: 9 additions & 2 deletions src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@ for TT in (:JacobianTracer, :ConnectivityTracer, :HessianTracer)
@eval Base.similar(a::Array{A,2}, ::Type{T}) where {A,T<:$TT} = zeros(T, size(a, 1), size(a, 2))
@eval Base.similar(::Array{T}, m::Int) where {T<:$TT} = zeros(T, m)
@eval Base.similar(::Array{T}, dims::Dims{N}) where {N,T<:$TT} = zeros(T, dims)
end

@eval Base.similar(::Array, ::Type{$TT{S}}, dims::Dims{N}) where {N,S} =
zeros($TT{S}, dims)
function Base.similar(::Array, ::Type{ConnectivityTracer{I,S}}, dims::Dims{N}) where {I,S,N}
return zeros(ConnectivityTracer{I,S}, dims)
end
function Base.similar(::Array, ::Type{JacobianTracer{I,S}}, dims::Dims{N}) where {I,S,N}
return zeros(JacobianTracer{I,S}, dims)
end
function Base.similar(::Array, ::Type{HessianTracer{I,S,D}}, dims::Dims{N}) where {I,S,D,N}
return zeros(HessianTracer{I,S,D}, dims)
end
20 changes: 13 additions & 7 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ julia> connectivity_pattern(f, x)
```
"""
function connectivity_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
xt, yt = trace_function(ConnectivityTracer{S}, f, x)
I = eltype(S)
xt, yt = trace_function(ConnectivityTracer{I,S}, f, x)
return connectivity_pattern_to_mat(to_array(xt), to_array(yt))
end

Expand All @@ -72,7 +73,8 @@ where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to
The type of index set `S` can be specified as an optional argument and defaults to `BitSet`.
"""
function connectivity_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
xt, yt = trace_function(ConnectivityTracer{S}, f!, y, x)
I = eltype(S)
xt, yt = trace_function(ConnectivityTracer{I,S}, f!, y, x)
return connectivity_pattern_to_mat(to_array(xt), to_array(yt))
end

Expand Down Expand Up @@ -118,7 +120,8 @@ julia> jacobian_pattern(f, x)
```
"""
function jacobian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
xt, yt = trace_function(JacobianTracer{S}, f, x)
I = eltype(S)
xt, yt = trace_function(JacobianTracer{I,S}, f, x)
return jacobian_pattern_to_mat(to_array(xt), to_array(yt))
end

Expand All @@ -131,7 +134,8 @@ Compute the sparsity pattern of the Jacobian of `f!(y, x)`.
The type of index set `S` can be specified as an optional argument and defaults to `BitSet`.
"""
function jacobian_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
xt, yt = trace_function(JacobianTracer{S}, f!, y, x)
I = eltype(S)
xt, yt = trace_function(JacobianTracer{I,S}, f!, y, x)
return jacobian_pattern_to_mat(to_array(xt), to_array(yt))
end

Expand Down Expand Up @@ -189,13 +193,15 @@ julia> hessian_pattern(g, x)
```
"""
function hessian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
xt, yt = trace_function(HessianTracer{S}, f, x)
I = eltype(S)
T = HessianTracer{I,S,Dict{I,S}}
xt, yt = trace_function(T, f, x)
return hessian_pattern_to_mat(to_array(xt), yt)
end

function hessian_pattern_to_mat(
xt::AbstractArray{HessianTracer{S,T}}, yt::HessianTracer{S,T}
) where {S,T}
xt::AbstractArray{HessianTracer{TI,S,D}}, yt::HessianTracer{TI,S,D}
) where {TI,S,D}
# Allocate Hessian matrix
n = length(xt)
I = UInt64[] # row indices
Expand Down
44 changes: 44 additions & 0 deletions src/settypes/base_sets.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
function keys2set(::Type{S}, d::Dict{I}) where {I<:Integer,S<:AbstractSet{<:I}}
return S(keys(d))
end

# Performance can be gained by not re-allocating empty tracers
## BitSet
const EMPTY_CONNECTIVITY_TRACER_BITSET = ConnectivityTracer(BitSet())
const EMPTY_JACOBIAN_TRACER_BITSET = JacobianTracer(BitSet())
const EMPTY_HESSIAN_TRACER_BITSET = HessianTracer(Dict{Int,BitSet}())

empty(::Type{ConnectivityTracer{Int,BitSet}}) = EMPTY_CONNECTIVITY_TRACER_BITSET
empty(::Type{JacobianTracer{Int,BitSet}}) = EMPTY_JACOBIAN_TRACER_BITSET
empty(::Type{HessianTracer{Int,BitSet,Dict{Int,BitSet}}}) = EMPTY_HESSIAN_TRACER_BITSET

## Set
const EMPTY_CONNECTIVITY_TRACER_SET_U8 = ConnectivityTracer(Set{UInt8}())
const EMPTY_CONNECTIVITY_TRACER_SET_U16 = ConnectivityTracer(Set{UInt16}())
const EMPTY_CONNECTIVITY_TRACER_SET_U32 = ConnectivityTracer(Set{UInt32}())
const EMPTY_CONNECTIVITY_TRACER_SET_U64 = ConnectivityTracer(Set{UInt64}())

const EMPTY_JACOBIAN_TRACER_SET_U8 = JacobianTracer(Set{UInt8}())
const EMPTY_JACOBIAN_TRACER_SET_U16 = JacobianTracer(Set{UInt16}())
const EMPTY_JACOBIAN_TRACER_SET_U32 = JacobianTracer(Set{UInt32}())
const EMPTY_JACOBIAN_TRACER_SET_U64 = JacobianTracer(Set{UInt64}())

const EMPTY_HESSIAN_TRACER_SET_U8 = HessianTracer(Dict{UInt8,Set{UInt8}}())
const EMPTY_HESSIAN_TRACER_SET_U16 = HessianTracer(Dict{UInt16,Set{UInt16}}())
const EMPTY_HESSIAN_TRACER_SET_U32 = HessianTracer(Dict{UInt32,Set{UInt32}}())
const EMPTY_HESSIAN_TRACER_SET_U64 = HessianTracer(Dict{UInt64,Set{UInt64}}())

empty(::Type{ConnectivityTracer{UInt8,Set{UInt8}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U8
empty(::Type{ConnectivityTracer{UInt16,Set{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U16
empty(::Type{ConnectivityTracer{UInt32,Set{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U32
empty(::Type{ConnectivityTracer{UInt64,Set{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U64

empty(::Type{JacobianTracer{UInt8,Set{UInt8}}}) = EMPTY_JACOBIAN_TRACER_SET_U8
empty(::Type{JacobianTracer{UInt16,Set{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SET_U16
empty(::Type{JacobianTracer{UInt32,Set{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SET_U32
empty(::Type{JacobianTracer{UInt64,Set{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SET_U64

empty(::Type{HessianTracer{UInt8,Set{UInt8},Dict{UInt8,Set{UInt8}}}}) = EMPTY_HESSIAN_TRACER_SET_U8
empty(::Type{HessianTracer{UInt16,Set{UInt16},Dict{UInt16,Set{UInt16}}}}) = EMPTY_HESSIAN_TRACER_SET_U16
empty(::Type{HessianTracer{UInt32,Set{UInt32},Dict{UInt32,Set{UInt32}}}}) = EMPTY_HESSIAN_TRACER_SET_U32
empty(::Type{HessianTracer{UInt64,Set{UInt64},Dict{UInt64,Set{UInt64}}}}) = EMPTY_HESSIAN_TRACER_SET_U64
36 changes: 27 additions & 9 deletions src/settypes/duplicatevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,38 @@ const EMPTY_HESSIAN_TRACER_DV_U16 = HessianTracer(Dict{UInt16,DuplicateVector{UI
const EMPTY_HESSIAN_TRACER_DV_U32 = HessianTracer(Dict{UInt32,DuplicateVector{UInt32}}())
const EMPTY_HESSIAN_TRACER_DV_U64 = HessianTracer(Dict{UInt64,DuplicateVector{UInt64}}())

function empty(::Type{ConnectivityTracer{DuplicateVector{UInt16}}})
function empty(::Type{ConnectivityTracer{UInt16,DuplicateVector{UInt16}}})
return EMPTY_CONNECTIVITY_TRACER_DV_U16
end
function empty(::Type{ConnectivityTracer{DuplicateVector{UInt32}}})
function empty(::Type{ConnectivityTracer{UInt32,DuplicateVector{UInt32}}})
return EMPTY_CONNECTIVITY_TRACER_DV_U32
end
function empty(::Type{ConnectivityTracer{DuplicateVector{UInt64}}})
function empty(::Type{ConnectivityTracer{UInt64,DuplicateVector{UInt64}}})
return EMPTY_CONNECTIVITY_TRACER_DV_U64
end

empty(::Type{JacobianTracer{DuplicateVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_DV_U16
empty(::Type{JacobianTracer{DuplicateVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_DV_U32
empty(::Type{JacobianTracer{DuplicateVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_DV_U64
empty(::Type{JacobianTracer{UInt16,DuplicateVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_DV_U16
empty(::Type{JacobianTracer{UInt32,DuplicateVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_DV_U32
empty(::Type{JacobianTracer{UInt64,DuplicateVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_DV_U64

empty(::Type{HessianTracer{DuplicateVector{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_DV_U16
empty(::Type{HessianTracer{DuplicateVector{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_DV_U32
empty(::Type{HessianTracer{DuplicateVector{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_DV_U64
function empty(
::Type{
HessianTracer{UInt16,DuplicateVector{UInt16},Dict{UInt16,DuplicateVector{UInt16}}}
},
)
return EMPTY_HESSIAN_TRACER_DV_U16
end
function empty(
::Type{
HessianTracer{UInt32,DuplicateVector{UInt32},Dict{UInt32,DuplicateVector{UInt32}}}
},
)
return EMPTY_HESSIAN_TRACER_DV_U32
end
function empty(
::Type{
HessianTracer{UInt64,DuplicateVector{UInt64},Dict{UInt64,DuplicateVector{UInt64}}}
},
)
return EMPTY_HESSIAN_TRACER_DV_U64
end
36 changes: 27 additions & 9 deletions src/settypes/recursiveset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,32 @@ const EMPTY_HESSIAN_TRACER_RS_U16 = HessianTracer(Dict{UInt16,RecursiveSet{UInt1
const EMPTY_HESSIAN_TRACER_RS_U32 = HessianTracer(Dict{UInt32,RecursiveSet{UInt32}}())
const EMPTY_HESSIAN_TRACER_RS_U64 = HessianTracer(Dict{UInt64,RecursiveSet{UInt64}}())

empty(::Type{ConnectivityTracer{RecursiveSet{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_RS_U16
empty(::Type{ConnectivityTracer{RecursiveSet{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_RS_U32
empty(::Type{ConnectivityTracer{RecursiveSet{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_RS_U64
function empty(::Type{ConnectivityTracer{UInt16,RecursiveSet{UInt16}}})
return EMPTY_CONNECTIVITY_TRACER_RS_U16
end
function empty(::Type{ConnectivityTracer{UInt32,RecursiveSet{UInt32}}})
return EMPTY_CONNECTIVITY_TRACER_RS_U32
end
function empty(::Type{ConnectivityTracer{UInt64,RecursiveSet{UInt64}}})
return EMPTY_CONNECTIVITY_TRACER_RS_U64
end

empty(::Type{JacobianTracer{RecursiveSet{UInt16}}}) = EMPTY_JACOBIAN_TRACER_RS_U16
empty(::Type{JacobianTracer{RecursiveSet{UInt32}}}) = EMPTY_JACOBIAN_TRACER_RS_U32
empty(::Type{JacobianTracer{RecursiveSet{UInt64}}}) = EMPTY_JACOBIAN_TRACER_RS_U64
empty(::Type{JacobianTracer{UInt16,RecursiveSet{UInt16}}}) = EMPTY_JACOBIAN_TRACER_RS_U16
empty(::Type{JacobianTracer{UInt32,RecursiveSet{UInt32}}}) = EMPTY_JACOBIAN_TRACER_RS_U32
empty(::Type{JacobianTracer{UInt64,RecursiveSet{UInt64}}}) = EMPTY_JACOBIAN_TRACER_RS_U64

empty(::Type{HessianTracer{RecursiveSet{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_RS_U16
empty(::Type{HessianTracer{RecursiveSet{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_RS_U32
empty(::Type{HessianTracer{RecursiveSet{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_RS_U64
function empty(
::Type{HessianTracer{UInt16,RecursiveSet{UInt16},Dict{UInt16,RecursiveSet{UInt16}}}}
)
return EMPTY_HESSIAN_TRACER_RS_U16
end
function empty(
::Type{HessianTracer{UInt32,RecursiveSet{UInt32},Dict{UInt32,RecursiveSet{UInt32}}}}
)
return EMPTY_HESSIAN_TRACER_RS_U32
end
function empty(
::Type{HessianTracer{UInt64,RecursiveSet{UInt64},Dict{UInt64,RecursiveSet{UInt64}}}}
)
return EMPTY_HESSIAN_TRACER_RS_U64
end
36 changes: 27 additions & 9 deletions src/settypes/sortedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,32 @@ const EMPTY_HESSIAN_TRACER_SV_U16 = HessianTracer(Dict{UInt16,SortedVector{UInt1
const EMPTY_HESSIAN_TRACER_SV_U32 = HessianTracer(Dict{UInt32,SortedVector{UInt32}}())
const EMPTY_HESSIAN_TRACER_SV_U64 = HessianTracer(Dict{UInt64,SortedVector{UInt64}}())

empty(::Type{ConnectivityTracer{SortedVector{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U16
empty(::Type{ConnectivityTracer{SortedVector{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U32
empty(::Type{ConnectivityTracer{SortedVector{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U64
function empty(::Type{ConnectivityTracer{UInt16,SortedVector{UInt16}}})
return EMPTY_CONNECTIVITY_TRACER_SV_U16
end
function empty(::Type{ConnectivityTracer{UInt32,SortedVector{UInt32}}})
return EMPTY_CONNECTIVITY_TRACER_SV_U32
end
function empty(::Type{ConnectivityTracer{UInt64,SortedVector{UInt64}}})
return EMPTY_CONNECTIVITY_TRACER_SV_U64
end

empty(::Type{JacobianTracer{SortedVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SV_U16
empty(::Type{JacobianTracer{SortedVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SV_U32
empty(::Type{JacobianTracer{SortedVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SV_U64
empty(::Type{JacobianTracer{UInt16,SortedVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SV_U16
empty(::Type{JacobianTracer{UInt32,SortedVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SV_U32
empty(::Type{JacobianTracer{UInt64,SortedVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SV_U64

empty(::Type{HessianTracer{SortedVector{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_SV_U16
empty(::Type{HessianTracer{SortedVector{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_SV_U32
empty(::Type{HessianTracer{SortedVector{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_SV_U64
function empty(
::Type{HessianTracer{UInt16,SortedVector{UInt16},Dict{UInt16,SortedVector{UInt16}}}}
)
return EMPTY_HESSIAN_TRACER_SV_U16
end
function empty(
::Type{HessianTracer{UInt32,SortedVector{UInt32},Dict{UInt32,SortedVector{UInt32}}}}
)
return EMPTY_HESSIAN_TRACER_SV_U32
end
function empty(
::Type{HessianTracer{UInt64,SortedVector{UInt64},Dict{UInt64,SortedVector{UInt64}}}}
)
return EMPTY_HESSIAN_TRACER_SV_U64
end
Loading

0 comments on commit 0dd1bbd

Please sign in to comment.