Skip to content

Commit

Permalink
Shorter printing of default detectors (#190)
Browse files Browse the repository at this point in the history
* Shorter printing of default detectors

* Add short print to doctests

* Remove unused `Base.show` methods on internals
  • Loading branch information
adrhill authored Sep 5, 2024
1 parent 39fbdbd commit 87f026a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 49 deletions.
41 changes: 28 additions & 13 deletions src/adtypes_interface.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#= This file implements the ADTypes interface for `AbstractSparsityDetector`s =#

const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}}
const DEFAULT_HESSIAN_TRACER = HessianTracer{
DictHessianPattern{Int,BitSet,Dict{Int,BitSet},NotShared}
}
const DEFAULT_GRADIENT_PATTERN = IndexSetGradientPattern{Int,BitSet}
const DEFAULT_GRADIENT_TRACER = GradientTracer{DEFAULT_GRADIENT_PATTERN}

const DEFAULT_HESSIAN_PATTERN = DictHessianPattern{Int,BitSet,Dict{Int,BitSet},NotShared}
const DEFAULT_HESSIAN_TRACER = HessianTracer{DEFAULT_HESSIAN_PATTERN}

"""
TracerSparsityDetector <: ADTypes.AbstractSparsityDetector
Expand All @@ -18,17 +19,18 @@ For local sparsity patterns at a specific input point, use [`TracerLocalSparsity
```jldoctest
julia> using SparseConnectivityTracer
julia> jacobian_sparsity(diff, rand(4), TracerSparsityDetector())
julia> detector = TracerSparsityDetector()
TracerSparsityDetector()
julia> jacobian_sparsity(diff, rand(4), detector)
3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 6 stored entries:
1 1 ⋅ ⋅
⋅ 1 1 ⋅
⋅ ⋅ 1 1
```
```jldoctest
julia> f(x) = x[1] + x[2]*x[3] + 1/x[4];
julia> hessian_sparsity(f, rand(4), TracerSparsityDetector())
julia> hessian_sparsity(f, rand(4), detector)
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅
Expand Down Expand Up @@ -77,23 +79,24 @@ Local sparsity patterns are less convervative than global patterns and need to b
```jldoctest
julia> using SparseConnectivityTracer
julia> method = TracerLocalSparsityDetector();
julia> detector = TracerLocalSparsityDetector()
TracerLocalSparsityDetector()
julia> f(x) = x[1] * x[2]; # J_f = [x[2], x[1]]
julia> jacobian_sparsity(f, [1, 0], method)
julia> jacobian_sparsity(f, [1, 0], detector)
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1 stored entry:
⋅ 1
julia> jacobian_sparsity(f, [0, 1], method)
julia> jacobian_sparsity(f, [0, 1], detector)
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1 stored entry:
1 ⋅
julia> jacobian_sparsity(f, [0, 0], method)
julia> jacobian_sparsity(f, [0, 0], detector)
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 0 stored entries:
⋅ ⋅
julia> jacobian_sparsity(f, [1, 1], method)
julia> jacobian_sparsity(f, [1, 1], detector)
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries:
1 1
```
Expand Down Expand Up @@ -155,3 +158,15 @@ end
function ADTypes.hessian_sparsity(f, x, ::TracerLocalSparsityDetector{TG,TH}) where {TG,TH}
return _local_hessian_sparsity(f, x, TH)
end

## Pretty printing
for detector in (:TracerSparsityDetector, :TracerLocalSparsityDetector)
@eval function Base.show(io::IO, d::$detector{TG,TH}) where {TG,TH}
if TG == DEFAULT_GRADIENT_TRACER && TH == DEFAULT_HESSIAN_TRACER
println(io, $detector, "()")
else
println(io, $detector, "{", TG, ",", TH, "}()")
end
return nothing
end
end
2 changes: 0 additions & 2 deletions src/patterns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,6 @@ struct IndexSetGradientPattern{I<:Integer,S<:AbstractSet{I}} <: AbstractGradient
gradient::S
end

Base.show(io::IO, p::IndexSetGradientPattern) = Base.show(io, gradient(p))

function myempty(::Type{IndexSetGradientPattern{I,S}}) where {I,S}
return IndexSetGradientPattern{I,S}(myempty(S))
end
Expand Down
34 changes: 0 additions & 34 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,6 @@ isemptytracer(t::GradientTracer) = t.isempty
pattern(t::GradientTracer) = t.pattern
gradient(t::GradientTracer) = gradient(pattern(t))

function Base.show(io::IO, t::GradientTracer)
print(io, typeof(t))
if isemptytracer(t)
print(io, "()")
else
printsorted(io, gradient(t))
end
println(io)
return nothing
end

#===============#
# HessianTracer #
#===============#
Expand Down Expand Up @@ -88,20 +77,6 @@ pattern(t::HessianTracer) = t.pattern
gradient(t::HessianTracer) = gradient(pattern(t))
hessian(t::HessianTracer) = hessian(pattern(t))

function Base.show(io::IO, t::HessianTracer)
print(io, typeof(t))
if isemptytracer(t)
print(io, "()")
else
print(io, "(\n", " Gradient:")
printlnsorted(io, gradient(t))
print(io, " Hessian: ")
printlnsorted(io, hessian(t))
println(io, ")")
end
return nothing
end

#================================#
# Dual numbers for local tracing #
#================================#
Expand Down Expand Up @@ -178,12 +153,3 @@ name(::Type{T}) where {T<:HessianTracer} = "HessianTracer"
name(::Type{D}) where {P,T,D<:Dual{P,T}} = "Dual-$(name(T))"
name(::T) where {T<:AbstractTracer} = name(T)
name(::D) where {D<:Dual} = name(D)

# Utilities for printing sets
printsorted(io::IO, x) = Base.show_delim_array(io, sort(x), "(", ',', ')', true)
printsorted(io::IO, s::AbstractSet) = printsorted(io, collect(s))
function printlnsorted(io::IO, x)
printsorted(io, x)
println(io)
return nothing
end

0 comments on commit 87f026a

Please sign in to comment.