diff --git a/src/adtypes_interface.jl b/src/adtypes_interface.jl index b2b685b..c9da8f3 100644 --- a/src/adtypes_interface.jl +++ b/src/adtypes_interface.jl @@ -1,9 +1,10 @@ #= This file implements the ADTypes interface for `AbstractSparsityDetector`s =# -const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}} -const DEFAULT_HESSIAN_TRACER = HessianTracer{ - DictHessianPattern{Int,BitSet,Dict{Int,BitSet},NotShared} -} +const DEFAULT_GRADIENT_PATTERN = IndexSetGradientPattern{Int,BitSet} +const DEFAULT_GRADIENT_TRACER = GradientTracer{DEFAULT_GRADIENT_PATTERN} + +const DEFAULT_HESSIAN_PATTERN = DictHessianPattern{Int,BitSet,Dict{Int,BitSet},NotShared} +const DEFAULT_HESSIAN_TRACER = HessianTracer{DEFAULT_HESSIAN_PATTERN} """ TracerSparsityDetector <: ADTypes.AbstractSparsityDetector @@ -18,17 +19,18 @@ For local sparsity patterns at a specific input point, use [`TracerLocalSparsity ```jldoctest julia> using SparseConnectivityTracer -julia> jacobian_sparsity(diff, rand(4), TracerSparsityDetector()) +julia> detector = TracerSparsityDetector() +TracerSparsityDetector() + +julia> jacobian_sparsity(diff, rand(4), detector) 3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 6 stored entries: 1 1 ⋅ ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ 1 1 -``` -```jldoctest julia> f(x) = x[1] + x[2]*x[3] + 1/x[4]; -julia> hessian_sparsity(f, rand(4), TracerSparsityDetector()) +julia> hessian_sparsity(f, rand(4), detector) 4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ @@ -77,23 +79,24 @@ Local sparsity patterns are less convervative than global patterns and need to b ```jldoctest julia> using SparseConnectivityTracer -julia> method = TracerLocalSparsityDetector(); +julia> detector = TracerLocalSparsityDetector() +TracerLocalSparsityDetector() julia> f(x) = x[1] * x[2]; # J_f = [x[2], x[1]] -julia> jacobian_sparsity(f, [1, 0], method) +julia> jacobian_sparsity(f, [1, 0], detector) 1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1 stored entry: ⋅ 1 -julia> jacobian_sparsity(f, [0, 1], method) +julia> jacobian_sparsity(f, [0, 1], detector) 1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1 stored entry: 1 ⋅ -julia> jacobian_sparsity(f, [0, 0], method) +julia> jacobian_sparsity(f, [0, 0], detector) 1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 0 stored entries: ⋅ ⋅ -julia> jacobian_sparsity(f, [1, 1], method) +julia> jacobian_sparsity(f, [1, 1], detector) 1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries: 1 1 ``` @@ -155,3 +158,15 @@ end function ADTypes.hessian_sparsity(f, x, ::TracerLocalSparsityDetector{TG,TH}) where {TG,TH} return _local_hessian_sparsity(f, x, TH) end + +## Pretty printing +for detector in (:TracerSparsityDetector, :TracerLocalSparsityDetector) + @eval function Base.show(io::IO, d::$detector{TG,TH}) where {TG,TH} + if TG == DEFAULT_GRADIENT_TRACER && TH == DEFAULT_HESSIAN_TRACER + println(io, $detector, "()") + else + println(io, $detector, "{", TG, ",", TH, "}()") + end + return nothing + end +end diff --git a/src/patterns.jl b/src/patterns.jl index c8d4a70..3222685 100644 --- a/src/patterns.jl +++ b/src/patterns.jl @@ -176,8 +176,6 @@ struct IndexSetGradientPattern{I<:Integer,S<:AbstractSet{I}} <: AbstractGradient gradient::S end -Base.show(io::IO, p::IndexSetGradientPattern) = Base.show(io, gradient(p)) - function myempty(::Type{IndexSetGradientPattern{I,S}}) where {I,S} return IndexSetGradientPattern{I,S}(myempty(S)) end diff --git a/src/tracers.jl b/src/tracers.jl index 4bf041b..5e33d28 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -45,17 +45,6 @@ isemptytracer(t::GradientTracer) = t.isempty pattern(t::GradientTracer) = t.pattern gradient(t::GradientTracer) = gradient(pattern(t)) -function Base.show(io::IO, t::GradientTracer) - print(io, typeof(t)) - if isemptytracer(t) - print(io, "()") - else - printsorted(io, gradient(t)) - end - println(io) - return nothing -end - #===============# # HessianTracer # #===============# @@ -88,20 +77,6 @@ pattern(t::HessianTracer) = t.pattern gradient(t::HessianTracer) = gradient(pattern(t)) hessian(t::HessianTracer) = hessian(pattern(t)) -function Base.show(io::IO, t::HessianTracer) - print(io, typeof(t)) - if isemptytracer(t) - print(io, "()") - else - print(io, "(\n", " Gradient:") - printlnsorted(io, gradient(t)) - print(io, " Hessian: ") - printlnsorted(io, hessian(t)) - println(io, ")") - end - return nothing -end - #================================# # Dual numbers for local tracing # #================================# @@ -178,12 +153,3 @@ name(::Type{T}) where {T<:HessianTracer} = "HessianTracer" name(::Type{D}) where {P,T,D<:Dual{P,T}} = "Dual-$(name(T))" name(::T) where {T<:AbstractTracer} = name(T) name(::D) where {D<:Dual} = name(D) - -# Utilities for printing sets -printsorted(io::IO, x) = Base.show_delim_array(io, sort(x), "(", ',', ')', true) -printsorted(io::IO, s::AbstractSet) = printsorted(io, collect(s)) -function printlnsorted(io::IO, x) - printsorted(io, x) - println(io) - return nothing -end