Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shared Hessian tracer à la Walther #135

Merged
merged 21 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions benchmark/bench_jogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ suite["OptimizationProblems"] = optbench([:britgas])
for S1 in SET_TYPES
S2 = Set{Tuple{Int,Int}}

# Non-shared tracers
shared = false
PG = IndexSetGradientPattern{Int,S1}
PH = IndexSetHessianPattern{Int,S1,S2}

PH = IndexSetHessianPattern{Int,S1,S2,shared}
G = GradientTracer{PG}
H = HessianTracer{PH}

Expand All @@ -34,4 +35,15 @@ for S1 in SET_TYPES
suite["Hessian"]["Local"][(nameof(S1), nameof(S2))] = hessbench(
TracerLocalSparsityDetector(G, H)
)

# Shared tracers
shared = true
PG = IndexSetGradientPattern{Int,S1}
PH = IndexSetHessianPattern{Int,S1,S2,shared}
G = GradientTracer{PG}
H = HessianTracer{PH}

suite["Hessian"]["Global (shared)"][(nameof(S1), nameof(S2))] = hessbench(
TracerSparsityDetector(G, H)
)
end
17 changes: 8 additions & 9 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}}
const DEFAULT_HESSIAN_TRACER = HessianTracer{
IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}}}
IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}},false}
}

#==================#
Expand All @@ -9,20 +9,19 @@ const DEFAULT_HESSIAN_TRACER = HessianTracer{

"""
trace_input(T, x)
trace_input(T, x)

trace_input(T, xs)

Enumerates input indices and constructs the specified type `T` of tracer.
Supports [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref).
"""
trace_input(::Type{T}, x) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, x, 1)
trace_input(::Type{T}, xs) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, xs, 1)

function trace_input(::Type{T}, x::Real, i::Integer) where {T<:Union{AbstractTracer,Dual}}
return create_tracer(T, x, i)
end
function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:Union{AbstractTracer,Dual}}
indices = reshape(1:length(xs), size(xs)) .+ (i - 1)
return create_tracer.(T, xs, indices)
is = reshape(1:length(xs), size(xs)) .+ (i - 1)
return create_tracers(T, xs, is)
end
function trace_input(::Type{T}, x::Real, i::Integer) where {T<:Union{AbstractTracer,Dual}}
return only(create_tracers(T, [x], [i]))
end

#=========================#
Expand Down
49 changes: 46 additions & 3 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end

function hessian_tracer_1_to_1_inner(
p::P, is_der1_zero::Bool, is_der2_zero::Bool
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}}
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}}
sg = gradient(p)
sh = hessian(p)
sg_out = gradient_tracer_1_to_1_inner(sg, is_der1_zero)
Expand All @@ -22,13 +22,32 @@ function hessian_tracer_1_to_1_inner(
elseif !is_der1_zero && is_der2_zero
sh
elseif is_der1_zero && !is_der2_zero
# TODO: this branch of the code currently isn't tested.
# Covering it would require a scalar 1-to-1 function with local overloads,
# such that ∂f/∂x == 0 and ∂²f/∂x² != 0.
union_product!(myempty(SH), sg, sg)
else
else # !is_der1_zero && !is_der2_zero
union_product!(copy(sh), sg, sg)
end
return P(sg_out, sh_out) # return pattern
end

# NOTE: mutates argument p and should arguably be called `hessian_tracer_1_to_1_inner!`
function hessian_tracer_1_to_1_inner(
    p::P, is_der1_zero::Bool, is_der2_zero::Bool
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}}
    grad_in = gradient(p)
    hess_in = hessian(p)
    grad_out = gradient_tracer_1_to_1_inner(grad_in, is_der1_zero)
    # A shared Hessian set is never emptied or copied: second-order information
    # can only be added, and it is added in place to the set stored in `p`.
    hess_out = is_der2_zero ? hess_in : union_product!(hess_in, grad_in, grad_in)
    return P(grad_out, hess_out) # return pattern
end

function overload_hessian_1_to_1(M, op)
SCT = SparseConnectivityTracer
return quote
Expand Down Expand Up @@ -96,7 +115,7 @@ function hessian_tracer_2_to_1_inner(
is_der1_arg2_zero::Bool,
is_der2_arg2_zero::Bool,
is_der_cross_zero::Bool,
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}}
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}}
sgx, shx = gradient(px), hessian(px)
sgy, shy = gradient(py), hessian(py)
sg_out = gradient_tracer_2_to_1_inner(sgx, sgy, is_der1_arg1_zero, is_der1_arg2_zero)
Expand All @@ -110,6 +129,30 @@ function hessian_tracer_2_to_1_inner(
return P(sg_out, sh_out) # return pattern
end

# NOTE: mutates arguments px and py and should arguably be called `hessian_tracer_2_to_1_inner!`
function hessian_tracer_2_to_1_inner(
    px::P,
    py::P,
    is_der1_arg1_zero::Bool,
    is_der2_arg1_zero::Bool,
    is_der1_arg2_zero::Bool,
    is_der2_arg2_zero::Bool,
    is_der_cross_zero::Bool,
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}}
    sgx, sgy = gradient(px), gradient(py)
    shx, shy = hessian(px), hessian(py)

    # Both operands must reference the very same Hessian set (identity, not equality).
    # Since they are one object, no union of `shx` and `shy` is needed.
    if shx !== shy
        error("Expected shared Hessians, got $shx, $shy.")
    end
    grad_out = gradient_tracer_2_to_1_inner(sgx, sgy, is_der1_arg1_zero, is_der1_arg2_zero)

    # Second-order information is accumulated in place in the shared set.
    if !is_der2_arg1_zero
        union_product!(shx, sgx, sgx) # product alpha
    end
    if !is_der2_arg2_zero
        union_product!(shx, sgy, sgy) # product beta
    end
    if !is_der_cross_zero
        union_product!(shx, sgx, sgy) # cross product 1
        union_product!(shx, sgy, sgx) # cross product 2
    end
    return P(grad_out, shx) # return pattern
end

function overload_hessian_2_to_1(M, op)
SCT = SparseConnectivityTracer
return quote
Expand Down
11 changes: 10 additions & 1 deletion src/overloads/ifelse_global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@
function output_union(px::P, py::P) where {P<:IndexSetGradientPattern}
    # Non-mutating: `union` builds a new set holding the indices of both patterns.
    merged = union(set(px), set(py))
    return P(merged) # return pattern
end
function output_union(px::P, py::P) where {P<:IndexSetHessianPattern}
function output_union(
    px::P, py::P
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}} # non-mutating
    # Neither argument is modified: both the first-order and the second-order
    # sets of the result are freshly built by `union`.
    return P(
        union(gradient(px), gradient(py)),
        union(hessian(px), hessian(py)),
    ) # return pattern
end
function output_union(
    px::P, py::P
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}} # mutating
    # The first-order sets are merged into a new set, whereas the second-order
    # set of `px` is extended in place with the entries of `py`'s set.
    return P(
        union(gradient(px), gradient(py)),
        union!(hessian(px), hessian(py)),
    ) # return pattern
end

# When only one of the two arguments is a tracer, that tracer is returned unchanged.
function output_union(tx::AbstractTracer, y)
    return tx
end
function output_union(x, ty::AbstractTracer)
    return ty
end
Expand Down
83 changes: 56 additions & 27 deletions src/patterns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ AbstractPattern
"""
abstract type AbstractPattern end

"""
    isshared(pattern)
    isshared(P)

Return `true` if patterns of this type **always** share memory,
in which case operators are **allowed** to mutate their `AbstractTracer` arguments.

Return `false` if patterns merely **can** share memory,
in which case operators are **prohibited** from mutating their `AbstractTracer` arguments.
This is the fallback for every `AbstractPattern` type.

## Note
In practice, memory sharing is limited to second-order information in `AbstractHessianPattern`.
"""
function isshared(::P) where {P<:AbstractPattern}
    # Instances defer to the method defined on their type.
    return isshared(P)
end
function isshared(::Type{P}) where {P<:AbstractPattern}
    return false
end

"""
myempty(T)
myempty(tracer)
Expand All @@ -25,13 +38,11 @@ Constructor for an empty tracer or pattern of type `T` representing a new number
myempty

"""
seed(T, i)
seed(tracer, i)
seed(pattern, i)
create_patterns(P, xs, is)

Constructor for a tracer or pattern of type `T` that only contains the given index `i`.
Convenience constructor for patterns of type `P` for multiple inputs `xs` and their indices `is`.
"""
seed
create_patterns

#==========================#
# Utilities on AbstractSet #
Expand All @@ -49,8 +60,8 @@ product(a::AbstractSet{I}, b::AbstractSet{I}) where {I<:Integer} =
Set((i, j) for i in a, j in b)

function union_product!(
hessian::SH, gradient_x::SG, gradient_y::SG
) where {I<:Integer,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}}
hessian::H, gradient_x::G, gradient_y::G
) where {I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}}}
hxy = product(gradient_x, gradient_y)
return union!(hessian, hxy)
end
Expand All @@ -69,18 +80,17 @@ For use with [`GradientTracer`](@ref).

## Expected interface

* `myempty(::Type{MyPattern})`: return a pattern representing a new number (usually an empty pattern)
* `seed(::Type{MyPattern}, i::Integer)`: return a pattern that only contains the given index `i`
* `gradient(p::MyPattern)`: return non-zero indices `i` for use with `GradientTracer`

Note that besides their names, the last two functions are usually identical.
* [`myempty`](@ref)
* [`create_patterns`](@ref)
* `gradient(p::MyPattern)`: return non-zero indices `i` in the gradient representation
* [`isshared`](@ref) in case the pattern is shared (mutates). Defaults to false.
"""
abstract type AbstractGradientPattern <: AbstractPattern end

"""
$(TYPEDEF)

Vector sparsity pattern represented by an `AbstractSet` of indices ``{i}`` of non-zero values.
Gradient sparsity pattern represented by an `AbstractSet` of indices ``{i}`` of non-zero values.

## Fields
$(TYPEDFIELDS)
Expand All @@ -97,8 +107,9 @@ Base.show(io::IO, p::IndexSetGradientPattern) = Base.show(io, set(p))
function myempty(::Type{IndexSetGradientPattern{I,S}}) where {I,S}
    # Wrap an empty index set of the underlying set type `S`.
    empty_set = myempty(S)
    return IndexSetGradientPattern{I,S}(empty_set)
end
function seed(::Type{IndexSetGradientPattern{I,S}}, i) where {I,S}
return IndexSetGradientPattern{I,S}(seed(S, i))
function create_patterns(::Type{P}, xs, is) where {I,S,P<:IndexSetGradientPattern{I,S}}
    # `xs` is not read here; each index in `is` is seeded into its own set of type `S`,
    # which is then wrapped into a pattern of type `P`.
    return P.(seed.(S, is))
end

# Tracer compatibility
Expand All @@ -118,29 +129,47 @@ For use with [`HessianTracer`](@ref).

## Expected interface

* `myempty(::Type{MyPattern})`: return a pattern representing a new number (usually an empty pattern)
* `seed(::Type{MyPattern}, i::Integer)`: return a pattern that only contains the given index `i` in the first-order representation
* [`myempty`](@ref)
* [`create_patterns`](@ref)
* `gradient(p::MyPattern)`: return non-zero indices `i` in the first-order representation
* `hessian(p::MyPattern)`: return non-zero indices `(i, j)` in the second-order representation
* [`isshared`](@ref) in case the pattern is shared (mutates). Defaults to false.
"""
abstract type AbstractHessianPattern <: AbstractPattern end

"""
IndexSetHessianPattern(vector::AbstractGradientPattern, mat::AbstractMatrixPattern)
$(TYPEDEF)

Hessian sparsity pattern represented by:
* an `AbstractSet` of indices ``i`` of non-zero values representing first-order sparsity
* an `AbstractSet` of index tuples ``(i,j)`` of non-zero values representing second-order sparsity

## Fields
$(TYPEDFIELDS)

## Internals

Gradient and Hessian sparsity patterns constructed by combining two AbstractSets.
The last type parameter `shared` is a `Bool` indicating whether the `hessian` field of this object should be shared among all intermediate scalar quantities involved in a function.
"""
struct IndexSetHessianPattern{I<:Integer,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}} <:
AbstractHessianPattern
gradient::SG
hessian::SH
struct IndexSetHessianPattern{
    I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}},shared
} <: AbstractHessianPattern
    # First-order sparsity: set of indices `i`.
    gradient::G
    # Second-order sparsity: set of index tuples `(i, j)`.
    # The `shared` type parameter is a `Bool` flag stating whether this field is meant to be
    # shared among all intermediate scalar quantities; it is not stored as a field.
    hessian::H
end
# Hessian patterns whose last type parameter is `true` report themselves as shared.
function isshared(::Type{IndexSetHessianPattern{I,G,H,true}}) where {I,G,H}
    return true
end
adrhill marked this conversation as resolved.
Show resolved Hide resolved

function myempty(::Type{IndexSetHessianPattern{I,SG,SH}}) where {I,SG,SH}
return IndexSetHessianPattern{I,SG,SH}(myempty(SG), myempty(SH))
function myempty(::Type{P}) where {I,G,H,S,P<:IndexSetHessianPattern{I,G,H,S}}
    # Both the first-order and the second-order sets start out empty.
    no_gradient = myempty(G)
    no_hessian = myempty(H)
    return P(no_gradient, no_hessian)
end
function seed(::Type{IndexSetHessianPattern{I,SG,SH}}, index) where {I,SG,SH}
return IndexSetHessianPattern{I,SG,SH}(seed(SG, index), myempty(SH))
function create_patterns(
    ::Type{P}, xs, is
) where {I,G,H,S,P<:IndexSetHessianPattern{I,G,H,S}}
    # A single empty second-order set is referenced by every returned pattern.
    # This is allowed upon initialization even if `shared=false`,
    # since mutation is prohibited when `isshared` is false.
    common_hessian = myempty(H)
    # `Ref` keeps the set from being broadcast over; each index in `is` gets its own
    # first-order set of type `G`.
    return P.(seed.(G, is), Ref(common_hessian))
end

# Tracer compatibility
Expand Down
31 changes: 17 additions & 14 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,29 +131,32 @@ end
# Utilities #
#===========#

myempty(::T) where {T<:AbstractTracer} = myempty(T)
# isshared(::Type{T}) where {P,T<:GradientTracer{P}} = isshared(P) # no shared AbstractGradientPattern yet
# A Hessian tracer type is shared exactly when its pattern type `P` is.
function isshared(::Type{T}) where {P,T<:HessianTracer{P}}
    return isshared(P)
end

# myempty(::Type{T}) where {P,T<:AbstractTracer{P}} = T(myempty(P), true) # JET complains about this
# Tracer instances defer to the method defined on their type.
function myempty(::T) where {T<:AbstractTracer}
    return myempty(T)
end
# myempty(::Type{T}) where {P,T<:AbstractTracer{P}} = T(myempty(P), true) # JET complains about this
# One method per concrete tracer family; each wraps an empty pattern of type `P`
# and passes `true` as the second constructor argument.
function myempty(::Type{T}) where {P,T<:GradientTracer{P}}
    return T(myempty(P), true)
end
function myempty(::Type{T}) where {P,T<:HessianTracer{P}}
    return T(myempty(P), true)
end

seed(::T, i) where {T<:AbstractTracer} = seed(T, i)

# seed(::Type{T}, i) where {P,T<:AbstractTracer{P}} = T(seed(P, i)) # JET complains about this
seed(::Type{T}, i) where {P,T<:GradientTracer{P}} = T(seed(P, i))
seed(::Type{T}, i) where {P,T<:HessianTracer{P}} = T(seed(P, i))

"""
create_tracer(T, index) where {T<:AbstractTracer}
create_tracers(T, xs, indices)

Convenience constructor for [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices.
Convenience constructor for [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref)
from multiple inputs `xs` and their `indices`.
"""
function create_tracer(::Type{T}, ::Real, index::Integer) where {P,T<:AbstractTracer{P}}
return T(seed(P, index))
function create_tracers(
    ::Type{T}, xs::AbstractArray{<:Real,N}, indices::AbstractArray{<:Integer,N}
) where {P<:AbstractPattern,T<:AbstractTracer{P},N}
    # Build all patterns in a single call, then wrap each one into a tracer of type `T`.
    return T.(create_patterns(P, xs, indices))
end

function create_tracer(::Type{Dual{P,T}}, primal::Real, index::Integer) where {P,T}
return Dual(primal, create_tracer(T, primal, index))
function create_tracers(
    ::Type{D}, xs::AbstractArray{<:Real,N}, indices::AbstractArray{<:Integer,N}
) where {P,T,D<:Dual{P,T},N}
    # Create the inner tracers of type `T` first, then pair every primal value in `xs`
    # with its tracer.
    return D.(xs, create_tracers(T, xs, indices))
end

# Pretty-printing of Dual tracers
Expand Down
2 changes: 1 addition & 1 deletion test/brusselator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector
using SparseConnectivityTracerBenchmarks.ODE: Brusselator!
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

function test_brusselator(method::AbstractSparsityDetector)
Expand Down
2 changes: 1 addition & 1 deletion test/flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SparseConnectivityTracer
using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

const INPUT_FLUX = reshape(
Expand Down
2 changes: 1 addition & 1 deletion test/test_constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using SparseConnectivityTracer: primal, tracer, isemptytracer
using SparseConnectivityTracer: myempty, name
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

function test_nested_duals(::Type{T}) where {T<:AbstractTracer}
Expand Down
Loading
Loading