Support Flux Conv layers with ReLU activation functions #70

Merged · 6 commits · May 17, 2024

37 changes: 32 additions & 5 deletions README.md
@@ -32,13 +32,14 @@ julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];

julia> jacobian_pattern(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
⋅ ⋅ 1
```

As a larger example, let's compute the sparsity pattern from a convolutional layer from [Flux.jl](https://github.com/FluxML/Flux.jl):

```julia-repl
julia> using SparseConnectivityTracer, Flux

@@ -47,7 +48,7 @@ julia> x = rand(28, 28, 3, 1);
julia> layer = Conv((3, 3), 3 => 2);

julia> jacobian_pattern(layer, x)
1352×2352 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 36504 stored entries:
1352×2352 SparseArrays.SparseMatrixCSC{Bool, Int64} with 36504 stored entries:
⎡⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠻⣷⣤⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣷⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤
⎢⠀⠀⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣷⣤⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥
⎢⠀⠀⠀⠀⠙⢿⣦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⎥
@@ -64,7 +65,7 @@ julia> jacobian_pattern(layer, x)
```

The type of index set `S` that is internally used to keep track of connectivity can be specified via `jacobian_pattern(f, x, S)`, defaulting to `BitSet`.
For high-dimensional functions, `Set{UInt64}` can be more efficient .
For high-dimensional functions, `Set{Int64}` can be more efficient .
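
For instance, passing the set type as a third argument might look as follows (a sketch reusing the `f` and `x` defined above; the printed index type of the returned matrix is illustrative):

```julia-repl
julia> jacobian_pattern(f, x, Set{Int64})
3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries:
 1  ⋅  ⋅
 1  1  ⋅
 ⋅  ⋅  1
```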

### Hessian

@@ -77,7 +78,7 @@ julia> x = rand(5);
julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5];

julia> hessian_pattern(f, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
⋅ 1 ⋅ ⋅ ⋅
@@ -87,7 +88,7 @@ julia> hessian_pattern(g, x)
julia> g(x) = f(x) + x[2]^x[5];

julia> hessian_pattern(g, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries:
5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 7 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ 1
⋅ 1 ⋅ ⋅ ⋅
@@ -97,6 +98,32 @@ julia> hessian_pattern(g, x)

For more detailled examples, take a look at the [documentation](https://adrianhill.de/SparseConnectivityTracer.jl/dev).

### Global function tracing

The functions `jacobian_pattern`, `hessian_pattern` and `connectivity_pattern` return conservative sparsity patterns over the entire input domain of `x`.
They are not compatible with functions that require information about the primal values of a computation (e.g. `iszero`, `>`, `==`).

To compute a less conservative sparsity pattern at an input point `x`, use `local_jacobian_pattern`, `local_hessian_pattern` and `local_connectivity_pattern` instead.
Note that these patterns depend on the input `x`:

```julia-repl
julia> f(x) = ifelse(x[2] < x[3], x[1] ^ x[2], x[3] * x[4]);

julia> local_hessian_pattern(f, [1 2 3 4])
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries:
1 1 ⋅ ⋅
1 1 ⋅ ⋅
⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅

julia> local_hessian_pattern(f, [1 3 2 4])
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries:
⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 1
⋅ ⋅ 1 ⋅
```
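
Analogously, only the branch taken at the given input contributes to a local Jacobian pattern. A sketch, assuming `local_jacobian_pattern` accepts scalar-valued functions like its global counterpart (printed shapes are illustrative):

```julia-repl
julia> local_jacobian_pattern(f, [1 2 3 4])
1×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries:
 1  1  ⋅  ⋅

julia> local_jacobian_pattern(f, [1 3 2 4])
1×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries:
 ⋅  ⋅  1  1
```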

## Related packages
* [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl): automatic sparsity detection via Symbolics.jl and Cassette.jl
* [SparsityTracing.jl](https://github.com/PALEOtoolkit/SparsityTracing.jl): automatic Jacobian sparsity detection using an algorithm based on SparsLinC by Bischof et al. (1996)
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -4,4 +4,4 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"

[compat]
ADTypes = "1"
ADTypes = "1"
31 changes: 21 additions & 10 deletions src/conversion.jl
@@ -17,6 +17,7 @@ for TT in (GradientTracer, ConnectivityTracer, HessianTracer)
## Constants
Base.zero(::Type{T}) where {T<:TT} = empty(T)
Base.one(::Type{T}) where {T<:TT} = empty(T)
Base.oneunit(::Type{T}) where {T<:TT} = empty(T)
Base.typemin(::Type{T}) where {T<:TT} = empty(T)
Base.typemax(::Type{T}) where {T<:TT} = empty(T)
Base.eps(::Type{T}) where {T<:TT} = empty(T)
@@ -27,6 +28,7 @@ for TT in (GradientTracer, ConnectivityTracer, HessianTracer)

Base.zero(::T) where {T<:TT} = empty(T)
Base.one(::T) where {T<:TT} = empty(T)
Base.oneunit(::T) where {T<:TT} = empty(T)
Base.typemin(::T) where {T<:TT} = empty(T)
Base.typemax(::T) where {T<:TT} = empty(T)
Base.eps(::T) where {T<:TT} = empty(T)
@@ -48,29 +50,37 @@ Base.similar(::Array, ::Type{ConnectivityTracer{C}}, dims::Dims{N}) where {C,N} = zeros(ConnectivityTracer{C}, dims)
Base.similar(::Array, ::Type{GradientTracer{G}}, dims::Dims{N}) where {G,N} = zeros(GradientTracer{G}, dims)
Base.similar(::Array, ::Type{HessianTracer{G,H}}, dims::Dims{N}) where {G,H,N} = zeros(HessianTracer{G,H}, dims)


## Duals
function Base.promote_rule(::Type{D}, ::Type{N}) where {P,T,D<:Dual{P,T},N<:Number}
PP = Base.promote_rule(P, N) # TODO: possible method call error?
return D{PP,T}
function Base.promote_rule(::Type{Dual{P1, T}}, ::Type{Dual{P2, T}}) where {P1,P2,T}
PP = Base.promote_type(P1, P2) # TODO: possible method call error?
return Dual{PP,T}
end
function Base.promote_rule(::Type{Dual{P, T}}, ::Type{N}) where {P,T,N<:Number}
PP = Base.promote_type(P, N) # TODO: possible method call error?
return Dual{PP,T}
end
function Base.promote_rule(::Type{N}, ::Type{D}) where {P,T,D<:Dual{P,T},N<:Number}
PP = Base.promote_rule(P, N) # TODO: possible method call error?
return D{PP,T}
function Base.promote_rule(::Type{N}, ::Type{Dual{P, T}}) where {P,T,N<:Number}
PP = Base.promote_type(P, N) # TODO: possible method call error?
return Dual{PP,T}
end

Base.big(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{big(P),T}
Base.widen(::Type{D}) where {P,T,D<:Dual{P,T}} = Dual{widen(P),T}
Base.big(d::D) where {P,T,D<:Dual{P,T}} = Dual(big(primal(d)), tracer(d))
Base.widen(d::D) where {P,T,D<:Dual{P,T}} = Dual(widen(primal(d)), tracer(d))

Base.convert(::Type{D}, x::Number) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T))
Base.convert(::Type{D}, d::D) where {D<:Dual} = d
Base.convert(::Type{T}, d::D) where {T<:Number,D<:Dual} = Dual(convert(T, primal(d)), tracer(d))
Base.convert(::Type{D}, x::Number) where {P,T,D<:Dual{P,T}} = Dual(x, empty(T))
Base.convert(::Type{D}, d::D) where {P,T,D<:Dual{P,T}} = d
Base.convert(::Type{N}, d::D) where {N<:Number,P,T,D<:Dual{P,T}} = Dual(convert(T, primal(d)), tracer(d))

function Base.convert(::Type{Dual{P1,T}}, d::Dual{P2,T}) where {P1,P2,T}
return Dual(convert(P1, primal(d)), tracer(d))
end

## Constants
Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = D(zero(P), empty(T))
Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = D(one(P), empty(T))
Base.oneunit(::Type{D}) where {P,T,D<:Dual{P,T}} = D(oneunit(P), empty(T))
Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemin(P), empty(T))
Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemax(P), empty(T))
Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = D(eps(P), empty(T))
@@ -81,6 +91,7 @@ Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = D(maxintfloat(P), empty(T

Base.zero(d::D) where {P,T,D<:Dual{P,T}} = D(zero(primal(d)), empty(T))
Base.one(d::D) where {P,T,D<:Dual{P,T}} = D(one(primal(d)), empty(T))
Base.oneunit(d::D) where {P,T,D<:Dual{P,T}} = D(oneunit(primal(d)), empty(T))
Base.typemin(d::D) where {P,T,D<:Dual{P,T}} = D(typemin(primal(d)), empty(T))
Base.typemax(d::D) where {P,T,D<:Dual{P,T}} = D(typemax(primal(d)), empty(T))
Base.eps(d::D) where {P,T,D<:Dual{P,T}} = D(eps(primal(d)), empty(T))
14 changes: 12 additions & 2 deletions src/overload_dual.jl
@@ -22,7 +22,17 @@ end

for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=))
@eval Base.$fn(dx::D, dy::D) where {D<:Dual} = $fn(primal(dx), primal(dy))
@eval function Base.$fn(t1::T, t2::T) where {T<:AbstractTracer}
throw(MissingPrimalError($fn, t1))
@eval Base.$fn(dx::D, y::Number) where {D<:Dual} = $fn(primal(dx), y)
@eval Base.$fn(x::Number, dy::D) where {D<:Dual} = $fn(x, primal(dy))

# Error on non-dual tracers
@eval function Base.$fn(tx::T, ty::T) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, tx))
end
@eval function Base.$fn(tx::T, y::Number) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, tx))
end
@eval function Base.$fn(x::Number, ty::T) where {T<:AbstractTracer}
return throw(MissingPrimalError($fn, ty))
end
end
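
For context, a minimal sketch of what these overloads enable (the branch function below is illustrative and not part of the test suite; printed output is approximate): comparisons on plain tracers carry no primal value and raise `MissingPrimalError`, whereas the `Dual` numbers used by the `local_*_pattern` functions forward comparisons to their primal, which is what allows branching activations such as ReLU to be traced locally.

```julia-repl
julia> using SparseConnectivityTracer

julia> relu_like(x) = ifelse(x[1] > 0, [x[1] + x[2]], [x[2]]);  # hypothetical branching function

julia> jacobian_pattern(relu_like, [1.0, 2.0])        # global tracing: `>` has no primal value
ERROR: MissingPrimalError: ...

julia> local_jacobian_pattern(relu_like, [1.0, 2.0])  # local tracing: `>` is evaluated on the primal 1.0
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries:
 1  1
```
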
4 changes: 4 additions & 0 deletions src/tracers.jl
@@ -211,6 +211,10 @@ gradient(d::Dual{P,T}) where {P,T<:GradientTracer} = gradient(d.tracer)
gradient(d::Dual{P,T}) where {P,T<:HessianTracer} = gradient(d.tracer)
hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(d.tracer)

function Dual{P,T}(x::Number) where {P<:Number,T<:AbstractTracer}
return Dual(convert(P, x), empty(T))
end

#===========#
# Utilities #
#===========#
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -2,11 +2,11 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
67 changes: 67 additions & 0 deletions test/flux.jl
@@ -0,0 +1,67 @@
using ADTypes
using ADTypes: AbstractSparsityDetector
using Flux: Conv, relu
using ReferenceTests
using SparseConnectivityTracer
using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector
using Test

function test_flux_conv(S::Type)
x = reshape(
[
0.2677768300138966
1.1934917429169245
-1.0496617141319355
0.456668782925957
0.09678342859916624
-0.7962039825333248
-0.6138709208787495
-0.6809396498148278
0.4938230574627916
0.7847107012511034
0.7423059724033608
-0.6914378396432983
1.2062310319178624
-0.19647670394840708
0.10708057449244994
-0.4787927739226245
0.045072020113458774
-1.219617669693635
],
3,
3,
2,
1,
) # WHCN
weights = reshape(
[
0.311843398150865
0.488663701947109
0.648497438559604
-0.41742794246238
0.174865988551499
1.061745573803265
-0.72434245370475
-0.05213963181095
],
2,
2,
2,
1,
)
bias = [0.1]

layer = Conv(weights, bias) # Conv((2, 2), 2 => 1)
J1 = jacobian_pattern(layer, x, S)
@test_reference "references/pattern/jacobian/NNlib/conv.txt" BitMatrix(J1)

layer = Conv(weights, bias, relu)
J2 = local_jacobian_pattern(layer, x, S)
@test_reference "references/pattern/jacobian/NNlib/conv_relu.txt" BitMatrix(J2)
end

@testset "$S" for S in (
BitSet, Set{UInt64}, DuplicateVector{UInt64}, RecursiveSet{UInt64}, SortedVector{UInt64}
)
test_flux_conv(S)
end
26 changes: 0 additions & 26 deletions test/nnlib.jl

This file was deleted.

1 change: 1 addition & 0 deletions test/references/pattern/jacobian/NNlib/conv_relu.txt
@@ -0,0 +1 @@
Bool[1 1 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
5 changes: 2 additions & 3 deletions test/runtests.jl
@@ -9,7 +9,6 @@ using Documenter

using LinearAlgebra
using Random
using NNlib

DocMeta.setdocmeta!(
SparseConnectivityTracer,
@@ -70,8 +69,8 @@ DocMeta.setdocmeta!(
@testset "Brusselator" begin
include("brusselator.jl")
end
@testset "NNlib" begin
include("nnlib.jl")
@testset "Flux.jl" begin
include("flux.jl")
end
end
