Skip to content

Commit

Permalink
Get local GradientTracer to work
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed May 15, 2024
1 parent 2c19273 commit cd5c3ab
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ include("pattern.jl")
include("adtypes.jl")

export connectivity_pattern
export jacobian_pattern
export hessian_pattern
export jacobian_pattern, local_jacobian_pattern
export hessian_pattern, local_hessian_pattern

# ADTypes interface
export TracerSparsityDetector
Expand Down
1 change: 0 additions & 1 deletion src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ end
# gradient of x*y: [y x]
SparseConnectivityTracer.is_firstder_arg1_zero_local(::typeof(Base.:*), x, y) = iszero(y)
SparseConnectivityTracer.is_firstder_arg2_zero_local(::typeof(Base.:*), x, y) = iszero(x)
SparseConnectivityTracer.is_crossder_zero_local(::typeof(Base.:*), x, y) = iszero(x) || iszero(y)

# ops_2_to_1_ffz:
# ∂f/∂x != 0
Expand Down
2 changes: 2 additions & 0 deletions src/overload_connectivity.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: support Duals

for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z)
@eval Base.$fn(t::ConnectivityTracer) = t
end
Expand Down
14 changes: 7 additions & 7 deletions src/overload_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ for fn in ops_1_to_1
if is_firstder_zero_local($fn, x)
return Dual(out, empty(T))
else
return Dual(out, t)
return Dual(out, tracer(t))
end
end
end
Expand Down Expand Up @@ -48,11 +48,11 @@ for fn in ops_2_to_1
if ∂f∂y0
return Dual(out, empty(T))
else # ∂f∂y ≠ 0
return Dual(out, ty)
return Dual(out, tracer(ty))
end
else # ∂f∂x ≠ 0
if ∂f∂y0
return Dual(out, tx)
return Dual(out, tracer(tx))
else # ∂f∂y ≠ 0
return Dual(out, T(gradient(tx) gradient(ty)))
end
Expand All @@ -72,7 +72,7 @@ for fn in ops_2_to_1
if is_firstder_arg1_zero_local($fn, x, y)
return Dual(out, empty(T))
else
return Dual(out, t)
return Dual(out, tracer(tx))
end
end

Expand All @@ -89,7 +89,7 @@ for fn in ops_2_to_1
if is_firstder_arg2_zero_local($fn, x, y)
return Dual(out, empty(T))
else
return Dual(out, t)
return Dual(out, tracer(ty))
end
end
end
Expand Down Expand Up @@ -117,12 +117,12 @@ for fn in ops_1_to_2
tracer1 = if is_firstder_out1_zero_global($fn)
empty(T)
else
t
tracer(tx)
end
tracer2 = if is_firstder_out2_zero_global($fn)
empty(T)
else
t
tracer(tx)
end
return (Dual(out1, tracer1), Dual(out2, tracer2))
end
Expand Down
8 changes: 4 additions & 4 deletions src/overload_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ for fn in ops_2_to_1
if is_firstder_arg1_zero_local($fn, x, y)
return Dual(out, empty(T))
else
return Dual(out, t)
return Dual(out, tracer(t))
end
else
return Dual(out, T(gradient(t), hessian(t) (gradient(t) × gradient(t))))
Expand All @@ -124,7 +124,7 @@ for fn in ops_2_to_1
if is_firstder_arg2_zero_local($fn, x, y)
return Dual(out, empty(T))
else
return Dual(out, t)
return Dual(out, tracer(t))
end
else
return Dual(out, T(gradient(t), hessian(t) (gradient(t) × gradient(t))))
Expand Down Expand Up @@ -164,7 +164,7 @@ for fn in ops_1_to_2
if is_firstder_out1_zero_local($fn, x)
return empty(T)
else
return t
return tracer(t)
end
else
return T(gradient(t), hessian(t) (gradient(t) × gradient(t)))
Expand All @@ -173,7 +173,7 @@ for fn in ops_1_to_2
if is_firstder_out2_zero_local($fn, x)
return empty(T)
else
return t
return tracer(t)
end
else
return T(gradient(t), hessian(t) (gradient(t) × gradient(t)))
Expand Down
55 changes: 50 additions & 5 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ Enumerates input indices and constructs the specified type `T` of tracer.
Supports [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref).
"""
trace_input(::Type{T}, x) where {T<:AbstractTracer} = trace_input(T, x, 1)
trace_input(::Type{T}, ::Number, i) where {T<:AbstractTracer} = tracer(T, i)
function trace_input(::Type{T}, x::AbstractArray, i) where {T<:AbstractTracer}
indices = (i - 1) .+ reshape(1:length(x), size(x))
return tracer.(T, indices)

function trace_input(::Type{T}, x::Number, i::Integer) where {T<:AbstractTracer}
return create_tracer(T, x, i)
end
function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:AbstractTracer}
indices = reshape(1:length(xs), size(xs)) .+ (i - 1)
return create_tracer.(T, xs, indices)
end

## Trace function
Expand Down Expand Up @@ -123,6 +126,34 @@ function jacobian_pattern(f, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G}
return jacobian_pattern_to_mat(to_array(xt), to_array(yt))
end

"""
local_jacobian_pattern(f, x)
local_jacobian_pattern(f, x, T)
Compute the local sparsity pattern of the Jacobian of `y = f(x)` at `x`.
The type of index set `S` can be specified as an optional argument and defaults to `BitSet`.
## Example
```jldoctest

Check failure on line 139 in src/pattern.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/pattern.jl:139-149 ```jldoctest julia> x = rand(3); julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])]; julia> local_jacobian_pattern(f, x) 3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ ``` Subexpression: local_jacobian_pattern(f, x) Evaluated output: ERROR: MethodError: ^(::SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.GradientTracer{BitSet}}, ::Int64) is ambiguous. Candidates: ^(tx::D, y::Number) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}} @ SparseConnectivityTracer ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overload_gradient.jl:69 ^(x::Number, p::Integer) @ Base intfuncs.jl:311 Possible fix, define ^(::D, ::Integer) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}} Stacktrace: [1] literal_pow @ ./intfuncs.jl:351 [inlined] [2] f(x::Vector{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.GradientTracer{BitSet}}}) @ Main ./none:1 [3] trace_function(::Type{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.GradientTracer{BitSet}}}, f::typeof(f), x::Vector{Float64}) @ SparseConnectivityTracer ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/pattern.jl:26 [4] local_jacobian_pattern(f::Function, x::Vector{Float64}, ::Type{BitSet}) @ SparseConnectivityTracer ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/pattern.jl:153 [5] local_jacobian_pattern(f::Function, x::Vector{Float64}) @ SparseConnectivityTracer ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/pattern.jl:152 [6] top-level scope @ none:1 Expected output: 3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ diff = Warning: Diff output requires color. 3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ERROR: MethodError: ^(::SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.GradientTracer{BitSet}}, ::Int64) is ambiguous. Candidates: ^(tx::D, y::Number) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}} @ SparseConnectivityTracer ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/overload_gradient.jl:69 ^(x::Number, p::Integer) @ Base intfuncs.jl:311 Possible fix, define ^(::D, ::Integer) where {P, T<:SparseConnectivityTracer.GradientTracer, D<:SparseConnectivityTracer.Dual{P, T}} Stacktrace: [1] literal_pow @ ./intfuncs.jl:351 [inlined] [2] f(x::Vector{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.GradientTracer{BitSet}}}) @ Main ./none:1 [3] trace_function(::Type{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.GradientTracer{BitSet}}}, f::typeof(f), x::Vector{Float64}) @ SparseConnectivityTracer ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/pattern.jl:26 [4] local_jacobian_pattern(f::Function, x::Vector{Float64}, ::Type{BitSet}) @ SparseConnectivityTracer ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/pattern.jl:153 [5] local_jacobian_pattern(f::Function, x::Vector{Float64}) @ SparseConnectivityTracer ~/work/SparseConnectivityTracer.jl/SparseConnectivityTracer.jl/src/pattern.jl:152 [6] top-level scope @ none:1
julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])];
julia> local_jacobian_pattern(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
1 ⋅ ⋅
1 1 ⋅
⋅ ⋅ ⋅
```
"""
function local_jacobian_pattern(f, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G}
D = Dual{eltype(x),GradientTracer{G}}
xt, yt = trace_function(D, f, x)
return jacobian_pattern_to_mat(to_array(xt), to_array(yt))
end

"""
jacobian_pattern(f!, y, x)
jacobian_pattern(f!, y, x, T)
Expand All @@ -136,9 +167,23 @@ function jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G}
return jacobian_pattern_to_mat(to_array(xt), to_array(yt))
end

"""
local_jacobian_pattern(f!, y, x)
local_jacobian_pattern(f!, y, x, T)
Compute the local sparsity pattern of the Jacobian of `f!(y, x)` at `x`.
The type of index set `S` can be specified as an optional argument and defaults to `BitSet`.
"""
function local_jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {G}
D = Dual{eltype(x),GradientTracer{G}}
xt, yt = trace_function(D, f!, y, x)
return jacobian_pattern_to_mat(to_array(xt), to_array(yt))
end

function jacobian_pattern_to_mat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
) where {T<:GradientTracer}
) where {P,G<:GradientTracer,T<:Union{G,Dual{P,G}}}
n, m = length(xt), length(yt)
I = Int[] # row indices
J = Int[] # column indices
Expand Down
15 changes: 10 additions & 5 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,11 @@ Dual number type keeping track of the results of a primal computation as well as
## Fields
$(TYPEDFIELDS)
"""
struct Dual{P<:Number,T<:AbstractTracer} <: AbstractTracer
struct Dual{P<:Number,T<:Union{GradientTracer,HessianTracer}} <: AbstractTracer
primal::P
tracer::T
end
# TODO: support ConnectivityTracer

primal(d::Dual) = d.primal
tracer(d::Dual) = d.tracer
Expand All @@ -214,16 +215,20 @@ hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(d.tracer)
#===========#

"""
tracer(T, index) where {T<:AbstractTracer}
create_tracer(T, index) where {T<:AbstractTracer}
Convenience constructor for [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices.
"""
function tracer(::Type{GradientTracer{G}}, index::Integer) where {G}
function create_tracer(::Type{Dual{P,T}}, primal::Number, index::Integer) where {P,T}
return Dual(primal, create_tracer(T, primal, index))
end

function create_tracer(::Type{GradientTracer{G}}, ::Number, index::Integer) where {G}
return GradientTracer{G}(sparse_vector(G, index))
end
function tracer(::Type{ConnectivityTracer{C}}, index::Integer) where {C}
function create_tracer(::Type{ConnectivityTracer{C}}, ::Number, index::Integer) where {C}
return ConnectivityTracer{C}(sparse_vector(C, index))
end
function tracer(::Type{HessianTracer{G,H}}, index::Integer) where {G,H}
function create_tracer(::Type{HessianTracer{G,H}}, ::Number, index::Integer) where {G,H}
return HessianTracer{G,H}(sparse_vector(G, index), empty(H))
end

0 comments on commit cd5c3ab

Please sign in to comment.