diff --git a/Project.toml b/Project.toml index 0cc7316e..005e6f49 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseConnectivityTracer" uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" authors = ["Adrian Hill "] -version = "0.1.0" +version = "0.2.0-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/README.md b/README.md index 00bb1d25..8a0117a1 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,7 @@ Fast sparsity detection via operator-overloading. -Will soon include Jacobian sparsity detection ([#19](https://github.com/adrhill/SparseConnectivityTracer.jl/issues/19)) -and Hessian sparsity detection ([#20](https://github.com/adrhill/SparseConnectivityTracer.jl/issues/20)). +Will soon include Hessian sparsity detection ([#20](https://github.com/adrhill/SparseConnectivityTracer.jl/issues/20)). ## Installation To install this package, open the Julia REPL and run @@ -26,7 +25,7 @@ julia> x = rand(3); julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; -julia> connectivity(f, x) +julia> pattern(f, JacobianTracer, x) 3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries: 1 ⋅ ⋅ 1 1 ⋅ @@ -41,37 +40,32 @@ julia> x = rand(28, 28, 3, 1); julia> layer = Conv((3, 3), 3 => 8); -julia> connectivity(layer, x) +julia> pattern(layer, JacobianTracer, x) 5408×2352 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 146016 stored entries: -⎡⠙⢶⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠘⢷⣄⠀⠀⠀⠀⠀⎤ -⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠙⢷⣄⠀⠀⠀⎥ -⎢⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠉⠳⣦⡀⎥ -⎢⠙⢷⣄⠀⠀⠀⠉⠻⣦⡀⠀⠀⠈⠙⢷⣄⠀⠀⠀⠈⠁⎥ -⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠉⠳⣦⡀⠀⠀⎥ -⎢⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠈⠻⣦⡀⎥ -⎢⠙⢷⣄⠀⠀⠀⠉⠻⣦⡀⠀⠀⠈⠙⠷⣤⡀⠀⠀⠈⠁⎥ -⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥ -⎢⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⢦⣄⠀⠀⠀⠈⠻⣦⡀⎥ -⎢⠙⢷⣄⠀⠀⠀⠉⠻⣦⡀⠀⠀⠀⠉⠻⣦⡀⠀⠀⠈⠁⎥ -⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⢦⣀⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥ -⎢⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⎥ -⎢⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⎥ -⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠙⢶⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥ -⎢⣀⠀⠀⠀⠙⢷⣄⡀⠀⠀⠀⠙⢷⣄⡀⠀⠀⠈⠻⣦⡀⎥ -⎢⠙⢷⣄⠀⠀⠀⠈⠛⢶⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⎥ -⎢⠀⠀⠙⢷⣄⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥ -⎢⣀⠀⠀⠀⠙⠳⣦⣀⠀⠀⠀⠙⢷⣄⡀⠀⠀⠈⠻⣦⡀⎥ -⎢⠙⢷⣄⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⎥ -⎢⠀⠀⠙⠷⣄⡀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥ -⎢⣀⠀⠀⠀⠈⠻⣦⣀⠀⠀⠀⠙⢷⣄⡀⠀⠀⠈⠻⣦⡀⎥ -⎢⠙⠷⣄⡀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⎥ -⎢⠀⠀⠈⠻⣦⡀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥ -⎣⠀⠀⠀⠀⠈⠻⣦⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⢦⡀⎦ +⎡⠙⢦⡀⠀⠀⠘⢷⣄⠀⠀⠈⠻⣦⡀⠀⠀⠀⎤ +⎢⠀⠀⠙⢷⣄⠀⠀⠙⠷⣄⠀⠀⠈⠻⣦⡀⠀⎥ +⎢⢶⣄⠀⠀⠙⠳⣦⡀⠀⠈⠳⢦⡀⠀⠈⠛⠂⎥ +⎢⠀⠙⢷⣄⠀⠀⠈⠻⣦⡀⠀⠀⠙⢦⣄⠀⠀⎥ +⎢⣀⡀⠀⠉⠳⣄⡀⠀⠈⠻⣦⣀⠀⠀⠙⢷⡄⎥ +⎢⠈⠻⣦⡀⠀⠈⠛⢦⡀⠀⠀⠙⢷⣄⠀⠀⠀⎥ +⎢⠀⠀⠈⠻⣦⡀⠀⠀⠙⢷⣄⠀⠀⠙⠷⣄⠀⎥ +⎢⠻⣦⡀⠀⠈⠙⢷⣄⠀⠀⠉⠻⣦⡀⠀⠈⠁⎥ +⎢⠀⠀⠙⢦⣀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⡀⠀⎥ +⎢⢤⣄⠀⠀⠙⠳⣄⡀⠀⠉⠳⣤⡀⠀⠈⠛⠂⎥ +⎢⠀⠙⢷⣄⠀⠀⠈⠻⣦⡀⠀⠈⠙⢦⡀⠀⠀⎥ +⎢⣀⠀⠀⠙⢷⣄⡀⠀⠈⠻⣦⣀⠀⠀⠙⢷⡄⎥ +⎢⠈⠳⣦⡀⠀⠈⠻⣦⡀⠀⠀⠙⢷⣄⠀⠀⠀⎥ +⎢⠀⠀⠈⠻⣦⡀⠀⠀⠙⢦⣄⠀⠀⠙⢷⣄⠀⎥ +⎢⠻⣦⡀⠀⠈⠙⢷⣄⠀⠀⠉⠳⣄⡀⠀⠉⠁⎥ +⎢⠀⠈⠛⢦⡀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⡀⠀⎥ +⎢⢤⣄⠀⠀⠙⠶⣄⠀⠀⠙⠷⣤⡀⠀⠈⠻⠆⎥ +⎢⠀⠙⢷⣄⠀⠀⠈⠳⣦⡀⠀⠈⠻⣦⡀⠀⠀⎥ +⎣⠀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⠀⠀⠀⠙⢦⡀⎦ ``` -SparseConnectivityTracer enumerates inputs `x` and primal outputs `y=f(x)` and returns a sparse connectivity matrix `C` of size $m \times n$, where `C[i, j]` is `true` if the compute graph connects the $j$-th entry in `x` to the $i$-th entry in `y`. +SparseConnectivityTracer enumerates inputs `x` and primal outputs `y = f(x)` and returns a sparse matrix `C` of size $m \times n$, where `C[i, j]` is `true` if the compute graph connects the $j$-th entry in `x` to the $i$-th entry in `y`. -For more detailled examples, take a look at the [API reference](https://adrianhill.de/SparseConnectivityTracer.jl/dev/api). +For more detailled examples, take a look at the [documentation](https://adrianhill.de/SparseConnectivityTracer.jl/dev). ## Related packages * [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl): automatic sparsity detection via Symbolics.jl and Cassette.jl diff --git a/docs/src/api.md b/docs/src/api.md index baac9133..d3e6c2a4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -10,19 +10,26 @@ CollapsedDocStrings = true ## Interface ```@docs -connectivity +pattern TracerSparsityDetector ``` ## Internals -SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions: +SparseConnectivityTracer works by pushing `Number` types called tracers through generic functions. +Currently, two tracer types are provided: + +```@docs +JacobianTracer +ConnectivityTracer +``` + +Utilities to create tracers: ```@docs -Tracer tracer trace_input ``` -The following utilities can be used to extract input indices from [`Tracer`](@ref)s: +Utility to extract input indices from tracers: ```@docs inputs ``` diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index e84fac90..039d9896 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -4,17 +4,20 @@ using ADTypes: ADTypes import Random: rand, AbstractRNG, SamplerType import SparseArrays: sparse -include("tracer.jl") +abstract type AbstractTracer <: Number end + +include("tracers.jl") include("conversion.jl") include("operators.jl") -include("overload_tracer.jl") -include("connectivity.jl") +include("overload_connectivity.jl") +include("overload_jacobian.jl") +include("pattern.jl") include("adtypes.jl") -export Tracer +export JacobianTracer, ConnectivityTracer export tracer, trace_input export inputs -export connectivity +export pattern export TracerSparsityDetector end # module diff --git a/src/adtypes.jl b/src/adtypes.jl index eabea79a..25a39ea4 100644 --- a/src/adtypes.jl +++ b/src/adtypes.jl @@ -18,13 +18,14 @@ julia> ADTypes.jacobian_sparsity(diff, rand(4), TracerSparsityDetector()) struct TracerSparsityDetector <: ADTypes.AbstractSparsityDetector end function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector) - return connectivity(f, x) + return pattern(f, JacobianTracer, x) end function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector) - return connectivity(f!, y, x) + return pattern(f!, y, JacobianTracer, x) end function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector) + # TODO: return pattern(f, HessianTracer, x) return error("Hessian sparsity is not yet implemented for `TracerSparsityDetector`.") end diff --git a/src/connectivity.jl b/src/connectivity.jl deleted file mode 100644 index fd736017..00000000 --- a/src/connectivity.jl +++ /dev/null @@ -1,98 +0,0 @@ -## Enumerate inputs - -""" - trace_input(x) - -Enumerates input indices and constructs [`Tracer`](@ref)s. - -## Example -```jldoctest -julia> x = rand(3); - -julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; - -julia> xt = trace_input(x) -3-element Vector{Tracer}: - Tracer(1,) - Tracer(2,) - Tracer(3,) - -julia> yt = f(xt) -3-element Vector{Tracer}: - Tracer(1,) - Tracer(1, 2) - Tracer(3,) -``` -""" -trace_input(x) = trace_input(x, 1) -trace_input(::Number, i) = tracer(i) -function trace_input(x::AbstractArray, i) - indices = (i - 1) .+ reshape(1:length(x), size(x)) - return tracer.(indices) -end - -## Construct connectivity matrix -""" - connectivity(f, x) - -Enumerates inputs `x` and primal outputs `y=f(x)` and returns sparse connectivity matrix `C` of size `(m, n)` -where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. - -## Example -```jldoctest -julia> x = rand(3); - -julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; - -julia> connectivity(f, x) -3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries: - 1 ⋅ ⋅ - 1 1 ⋅ - ⋅ ⋅ 1 -``` -""" -function connectivity(f, x) - xt = trace_input(x) - yt = f(xt) - return _connectivity(xt, yt) -end - -""" - connectivity(f!, y, x) - -Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse connectivity matrix `C` of size `(m, n)` -where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. -""" -function connectivity(f!, y, x) - xt = trace_input(x) - yt = similar(y, Tracer) - f!(yt, xt) - return _connectivity(xt, yt) -end - -_connectivity(xt::Tracer, yt::Number) = _connectivity([xt], [yt]) -_connectivity(xt::Tracer, yt::AbstractArray{Number}) = _connectivity([xt], yt) -_connectivity(xt::AbstractArray{Tracer}, yt::Number) = _connectivity(xt, [yt]) -function _connectivity(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number}) - return connectivity_sparsematrixcsc(xt, yt) -end - -function connectivity_sparsematrixcsc( - xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number} -) - # Construct connectivity matrix of size (ouput_dim, input_dim) - n, m = length(xt), length(yt) - I = UInt64[] - J = UInt64[] - V = Bool[] - for (i, y) in enumerate(yt) - if y isa Tracer - for j in inputs(y) - push!(I, i) - push!(J, j) - push!(V, true) - end - end - end - return sparse(I, J, V, m, n) -end diff --git a/src/conversion.jl b/src/conversion.jl index fad90eb4..fcdc5b0d 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -1,23 +1,25 @@ ## Type conversions -Base.promote_rule(::Type{Tracer}, ::Type{N}) where {N<:Number} = Tracer -Base.promote_rule(::Type{N}, ::Type{Tracer}) where {N<:Number} = Tracer +for T in (:JacobianTracer, :ConnectivityTracer) + @eval Base.promote_rule(::Type{$T}, ::Type{N}) where {N<:Number} = $T + @eval Base.promote_rule(::Type{N}, ::Type{$T}) where {N<:Number} = $T -Base.big(::Type{Tracer}) = Tracer -Base.widen(::Type{Tracer}) = Tracer -Base.widen(t::Tracer) = t + @eval Base.big(::Type{$T}) = $T + @eval Base.widen(::Type{$T}) = $T + @eval Base.widen(t::$T) = t -Base.convert(::Type{Tracer}, x::Number) = EMPTY_TRACER -Base.convert(::Type{Tracer}, t::Tracer) = t -Base.convert(::Type{<:Number}, t::Tracer) = t + @eval Base.convert(::Type{$T}, x::Number) = empty($T) + @eval Base.convert(::Type{$T}, t::$T) = t + @eval Base.convert(::Type{<:Number}, t::$T) = t -## Array constructors -Base.zero(::Type{Tracer}) = EMPTY_TRACER -Base.one(::Type{Tracer}) = EMPTY_TRACER + ## Array constructors + @eval Base.zero(::Type{$T}) = empty($T) + @eval Base.one(::Type{$T}) = empty($T) -Base.similar(a::Array{Tracer,1}) = zeros(Tracer, size(a, 1)) -Base.similar(a::Array{Tracer,2}) = zeros(Tracer, size(a, 1), size(a, 2)) -Base.similar(a::Array{T,1}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1)) -Base.similar(a::Array{T,2}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1), size(a, 2)) -Base.similar(::Array{Tracer}, m::Int) = zeros(Tracer, m) -Base.similar(::Array, ::Type{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims) -Base.similar(::Array{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims) + @eval Base.similar(a::Array{$T,1}) = zeros($T, size(a, 1)) + @eval Base.similar(a::Array{$T,2}) = zeros($T, size(a, 1), size(a, 2)) + @eval Base.similar(a::Array{A,1}, ::Type{$T}) where {A} = zeros($T, size(a, 1)) + @eval Base.similar(a::Array{A,2}, ::Type{$T}) where {A} = zeros($T, size(a, 1), size(a, 2)) + @eval Base.similar(::Array{$T}, m::Int) = zeros($T, m) + @eval Base.similar(::Array, ::Type{$T}, dims::Dims{N}) where {N} = zeros($T, dims) + @eval Base.similar(::Array{$T}, dims::Dims{N}) where {N} = zeros($T, dims) +end diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl new file mode 100644 index 00000000..2bd667b3 --- /dev/null +++ b/src/overload_connectivity.jl @@ -0,0 +1,32 @@ +for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z) + @eval Base.$fn(t::ConnectivityTracer) = t +end + +for fn in ops_1_to_1_const + @eval Base.$fn(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER +end + +for fn in ops_1_to_2 + @eval Base.$fn(t::ConnectivityTracer) = (t, t) +end + +for fn in ops_2_to_1 + @eval Base.$fn(a::ConnectivityTracer, b::ConnectivityTracer) = uniontracer(a, b) + @eval Base.$fn(t::ConnectivityTracer, ::Number) = t + @eval Base.$fn(::Number, t::ConnectivityTracer) = t +end + +# Extra types required for exponent +Base.:^(a::ConnectivityTracer, b::ConnectivityTracer) = uniontracer(a, b) +for T in (:Real, :Integer, :Rational) + @eval Base.:^(t::ConnectivityTracer, ::$T) = t + @eval Base.:^(::$T, t::ConnectivityTracer) = t +end +Base.:^(t::ConnectivityTracer, ::Irrational{:ℯ}) = t +Base.:^(::Irrational{:ℯ}, t::ConnectivityTracer) = t + +## Rounding +Base.round(t::ConnectivityTracer, ::RoundingMode; kwargs...) = t + +## Random numbers +rand(::AbstractRNG, ::SamplerType{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER diff --git a/src/overload_jacobian.jl b/src/overload_jacobian.jl new file mode 100644 index 00000000..6b4d9b15 --- /dev/null +++ b/src/overload_jacobian.jl @@ -0,0 +1,68 @@ +for fn in union(ops_1_to_1_s, ops_1_to_1_f) + @eval Base.$fn(t::JacobianTracer) = t +end + +for fn in union(ops_1_to_1_z, ops_1_to_1_const) + @eval Base.$fn(::JacobianTracer) = EMPTY_JACOBIAN_TRACER +end + +for fn in union( + ops_2_to_1_ssc, + ops_2_to_1_ssz, + ops_2_to_1_sfc, + ops_2_to_1_sfz, + ops_2_to_1_fsc, + ops_2_to_1_fsz, + ops_2_to_1_ffc, + ops_2_to_1_ffz, +) + @eval Base.$fn(a::JacobianTracer, b::JacobianTracer) = uniontracer(a, b) + @eval Base.$fn(t::JacobianTracer, ::Number) = t + @eval Base.$fn(::Number, t::JacobianTracer) = t +end + +for fn in union(ops_2_to_1_zsz, ops_2_to_1_zfz) + @eval Base.$fn(::JacobianTracer, t::JacobianTracer) = t + @eval Base.$fn(::JacobianTracer, ::Number) = EMPTY_JACOBIAN_TRACER + @eval Base.$fn(::Number, t::JacobianTracer) = t +end +for fn in union(ops_2_to_1_szz, ops_2_to_1_fzz) + @eval Base.$fn(t::JacobianTracer, ::JacobianTracer) = t + @eval Base.$fn(t::JacobianTracer, ::Number) = t + @eval Base.$fn(::Number, t::JacobianTracer) = EMPTY_JACOBIAN_TRACER +end +for fn in ops_2_to_1_zzz + @eval Base.$fn(::JacobianTracer, ::JacobianTracer) = EMPTY_JACOBIAN_TRACER + @eval Base.$fn(::JacobianTracer, ::Number) = EMPTY_JACOBIAN_TRACER + @eval Base.$fn(::Number, ::JacobianTracer) = EMPTY_JACOBIAN_TRACER +end + +for fn in union(ops_1_to_2_ss, ops_1_to_2_sf, ops_1_to_2_fs, ops_1_to_2_ff) + @eval Base.$fn(t::JacobianTracer) = (t, t) +end + +for fn in union(ops_1_to_2_sz, ops_1_to_2_fz) + @eval Base.$fn(t::JacobianTracer) = (t, EMPTY_JACOBIAN_TRACER) +end + +for fn in union(ops_1_to_2_zs, ops_1_to_2_zf) + @eval Base.$fn(t::JacobianTracer) = (EMPTY_JACOBIAN_TRACER, t) +end +for fn in ops_1_to_2_zz + @eval Base.$fn(::JacobianTracer) = (EMPTY_JACOBIAN_TRACER, EMPTY_JACOBIAN_TRACER) +end + +# Extra types required for exponent +Base.:^(a::JacobianTracer, b::JacobianTracer) = uniontracer(a, b) +for T in (:Real, :Integer, :Rational) + @eval Base.:^(t::JacobianTracer, ::$T) = t + @eval Base.:^(::$T, t::JacobianTracer) = t +end +Base.:^(t::JacobianTracer, ::Irrational{:ℯ}) = t +Base.:^(::Irrational{:ℯ}, t::JacobianTracer) = t + +## Rounding +Base.round(t::JacobianTracer, ::RoundingMode; kwargs...) = EMPTY_JACOBIAN_TRACER + +## Random numbers +rand(::AbstractRNG, ::SamplerType{JacobianTracer}) = EMPTY_JACOBIAN_TRACER diff --git a/src/overload_tracer.jl b/src/overload_tracer.jl deleted file mode 100644 index 9409f715..00000000 --- a/src/overload_tracer.jl +++ /dev/null @@ -1,32 +0,0 @@ -for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z) - @eval Base.$fn(t::Tracer) = t -end - -for fn in ops_1_to_1_const - @eval Base.$fn(::Tracer) = EMPTY_TRACER -end - -for fn in ops_1_to_2 - @eval Base.$fn(t::Tracer) = (t, t) -end - -for fn in ops_2_to_1 - @eval Base.$fn(a::Tracer, b::Tracer) = uniontracer(a, b) - @eval Base.$fn(t::Tracer, ::Number) = t - @eval Base.$fn(::Number, t::Tracer) = t -end - -# Extra types required for exponent -Base.:^(a::Tracer, b::Tracer) = uniontracer(a, b) -for T in (:Real, :Integer, :Rational) - @eval Base.:^(t::Tracer, ::$T) = t - @eval Base.:^(::$T, t::Tracer) = t -end -Base.:^(t::Tracer, ::Irrational{:ℯ}) = t -Base.:^(::Irrational{:ℯ}, t::Tracer) = t - -## Rounding -Base.round(t::Tracer, ::RoundingMode; kwargs...) = t - -## Random numbers -rand(::AbstractRNG, ::SamplerType{Tracer}) = EMPTY_TRACER diff --git a/src/pattern.jl b/src/pattern.jl new file mode 100644 index 00000000..cde0afbe --- /dev/null +++ b/src/pattern.jl @@ -0,0 +1,109 @@ +## Enumerate inputs + +""" + trace_input(JacobianTracer, x) + trace_input(ConnectivityTracer, x) + + +Enumerates input indices and constructs the specified type of tracer. +Supports [`JacobianTracer`](@ref) and [`ConnectivityTracer`](@ref). + +## Example +```jldoctest +julia> x = rand(3); + +julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; + +julia> xt = trace_input(ConnectivityTracer, x) +3-element Vector{ConnectivityTracer}: + ConnectivityTracer(1,) + ConnectivityTracer(2,) + ConnectivityTracer(3,) + +julia> yt = f(xt) +3-element Vector{ConnectivityTracer}: + ConnectivityTracer(1,) + ConnectivityTracer(1, 2) + ConnectivityTracer(3,) +``` +""" +trace_input(::Type{T}, x) where {T<:AbstractTracer} = trace_input(T, x, 1) +trace_input(::Type{T}, ::Number, i) where {T<:AbstractTracer} = tracer(T, i) +function trace_input(::Type{T}, x::AbstractArray, i) where {T<:AbstractTracer} + indices = (i - 1) .+ reshape(1:length(x), size(x)) + return tracer.(T, indices) +end + +## Construct sparsity pattern matrix +""" + pattern(f, JacobianTracer, x) + +Computes the sparsity pattern of the Jacobian of `y = f(x)`. + + pattern(f, ConnectivityTracer, x) + +Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)` +where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. + +## Example +```jldoctest +julia> x = rand(3); + +julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; + +julia> pattern(f, ConnectivityTracer, x) +3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries: + 1 ⋅ ⋅ + 1 1 ⋅ + ⋅ ⋅ 1 +``` +""" +function pattern(f, ::Type{T}, x) where {T<:AbstractTracer} + xt = trace_input(T, x) + yt = f(xt) + return _pattern(xt, yt) +end + +""" + pattern(f!, y, JacobianTracer, x) + +Computes the sparsity pattern of the Jacobian of `f!(y, x)`. + + pattern(f!, y, ConnectivityTracer, x) + +Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)` +where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. +""" +function pattern(f!, y, ::Type{T}, x) where {T<:AbstractTracer} + xt = trace_input(T, x) + yt = similar(y, T) + f!(yt, xt) + return _pattern(xt, yt) +end + +_pattern(xt::AbstractTracer, yt::Number) = _pattern([xt], [yt]) +_pattern(xt::AbstractTracer, yt::AbstractArray{Number}) = _pattern([xt], yt) +_pattern(xt::AbstractArray{<:AbstractTracer}, yt::Number) = _pattern(xt, [yt]) +function _pattern(xt::AbstractArray{<:AbstractTracer}, yt::AbstractArray{<:Number}) + return _pattern_to_sparsemat(xt, yt) +end + +function _pattern_to_sparsemat( + xt::AbstractArray{<:AbstractTracer}, yt::AbstractArray{<:Number} +) + # Construct matrix of size (ouput_dim, input_dim) + n, m = length(xt), length(yt) + I = UInt64[] + J = UInt64[] + V = Bool[] + for (i, y) in enumerate(yt) + if y isa AbstractTracer + for j in inputs(y) + push!(I, i) + push!(J, j) + push!(V, true) + end + end + end + return sparse(I, J, V, m, n) +end diff --git a/src/tracer.jl b/src/tracer.jl deleted file mode 100644 index cee0c7d7..00000000 --- a/src/tracer.jl +++ /dev/null @@ -1,123 +0,0 @@ -""" - Tracer(indexset) <: Number - -Number type keeping track of input indices of previous computations. - -See also the convenience constructor [`tracer`](@ref). -For a higher-level interface, refer to [`connectivity`](@ref). - -## Examples -By enumerating inputs with tracers, we can keep track of input connectivities: -```jldoctest -julia> xt = [tracer(1), tracer(2), tracer(3)] -3-element Vector{Tracer}: - Tracer(1,) - Tracer(2,) - Tracer(3,) - -julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; - -julia> yt = f(xt) -3-element Vector{Tracer}: - Tracer(1,) - Tracer(1, 2) - Tracer(3,) -``` - -This works by overloading operators to either keep input connectivities constant, -compute unions or set connectivities to zero: -```jldoctest Tracer -julia> x = tracer(1, 2, 3) -Tracer(1, 2, 3) - -julia> sin(x) # Most operators don't modify input connectivities. -Tracer(1, 2, 3) - -julia> 2 * x^3 -Tracer(1, 2, 3) - -julia> zero(x) # Tracer is strictly operator overloading... -Tracer() - -julia> 0 * x # ...and doesn't look at input values. -Tracer(1, 2, 3) - -julia> y = tracer(3, 5) -Tracer(3, 5) - -julia> x + y # Operations on two Tracers construct union sets -Tracer(1, 2, 3, 5) - -julia> x ^ y -Tracer(1, 2, 3, 5) -``` - -[`Tracer`](@ref) also supports random number generation and pre-allocations: -```jldoctest Tracer -julia> M = rand(Tracer, 3, 2) -3×2 Matrix{Tracer}: - Tracer() Tracer() - Tracer() Tracer() - Tracer() Tracer() - -julia> similar(M) -3×2 Matrix{Tracer}: - Tracer() Tracer() - Tracer() Tracer() - Tracer() Tracer() - -julia> M * [x, y] -3-element Vector{Tracer}: - Tracer(1, 2, 3, 5) - Tracer(1, 2, 3, 5) - Tracer(1, 2, 3, 5) -``` -""" -struct Tracer <: Number - inputs::BitSet # indices of connected, enumerated inputs -end - -const EMPTY_TRACER = Tracer(BitSet()) - -# We have to be careful when defining constructors: -# Generic code expecting "regular" numbers `x` will sometimes convert them -# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `Tracer`. -# When this happens, we create a new empty tracer with no input connectivity. -Tracer(::Number) = EMPTY_TRACER -Tracer(t::Tracer) = t - -uniontracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs)) - -""" - tracer(index) - tracer(indices) - -Convenience constructor for [`Tracer`](@ref) from input indices. -""" -tracer(index::Integer) = Tracer(BitSet(index)) -tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(BitSet(inds)) -tracer(inds...) = tracer(inds) - -# Utilities for accessing input indices -""" - inputs(tracer) - -Return raw `UInt64` input indices of a [`Tracer`](@ref). - -## Example -```jldoctest -julia> t = tracer(1, 2, 4) -Tracer(1, 2, 4) - -julia> inputs(t) -3-element Vector{Int64}: - 1 - 2 - 4 -``` -""" -inputs(t::Tracer) = collect(t.inputs) - -function Base.show(io::IO, t::Tracer) - return Base.show_delim_array(io, inputs(t), "Tracer(", ',', ')', true) -end diff --git a/src/tracers.jl b/src/tracers.jl new file mode 100644 index 00000000..c932e492 --- /dev/null +++ b/src/tracers.jl @@ -0,0 +1,113 @@ +#==============# +# Connectivity # +#==============# + +""" + ConnectivityTracer(indexset) <: Number + +Number type keeping track of input indices of previous computations. + +See also the convenience constructor [`tracer`](@ref). +For a higher-level interface, refer to [`pattern`](@ref). +""" +struct ConnectivityTracer <: AbstractTracer + inputs::BitSet # indices of connected, enumerated inputs +end + +function Base.show(io::IO, t::ConnectivityTracer) + return Base.show_delim_array(io, inputs(t), "ConnectivityTracer(", ',', ')', true) +end + +const EMPTY_CONNECTIVITY_TRACER = ConnectivityTracer(BitSet()) + +# We have to be careful when defining constructors: +# Generic code expecting "regular" numbers `x` will sometimes convert them +# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `ConnectivityTracer`. +# When this happens, we create a new empty tracer with no input pattern. +ConnectivityTracer(::Number) = EMPTY_CONNECTIVITY_TRACER +ConnectivityTracer(t::ConnectivityTracer) = t + +#==========# +# Jacobian # +#==========# + +""" + JacobianTracer(indexset) <: Number + +Number type keeping track of input indices of previous computations with non-zero derivatives. + +See also the convenience constructor [`tracer`](@ref). +For a higher-level interface, refer to [`pattern`](@ref). +""" +struct JacobianTracer <: AbstractTracer + inputs::BitSet # indices of connected, enumerated inputs +end + +function Base.show(io::IO, t::JacobianTracer) + return Base.show_delim_array(io, inputs(t), "JacobianTracer(", ',', ')', true) +end + +const EMPTY_JACOBIAN_TRACER = JacobianTracer(BitSet()) + +JacobianTracer(::Number) = EMPTY_JACOBIAN_TRACER +JacobianTracer(t::JacobianTracer) = t + +#===========# +# Utilities # +#===========# + +## Access inputs +""" + inputs(tracer) + +Return raw `UInt64` input indices of a [`ConnectivityTracer`](@ref) or [`JacobianTracer`](@ref) + +## Example +```jldoctest +julia> t = tracer(ConnectivityTracer, 1, 2, 4) +ConnectivityTracer(1, 2, 4) + +julia> inputs(t) +3-element Vector{Int64}: + 1 + 2 + 4 +``` +""" +inputs(t::ConnectivityTracer) = collect(t.inputs) +inputs(t::JacobianTracer) = collect(t.inputs) + +## Unions of tracers +function uniontracer(a::ConnectivityTracer, b::ConnectivityTracer) + return ConnectivityTracer(union(a.inputs, b.inputs)) +end + +function uniontracer(a::JacobianTracer, b::JacobianTracer) + return JacobianTracer(union(a.inputs, b.inputs)) +end + +## Get empty tracer +empty(::JacobianTracer) = EMPTY_JACOBIAN_TRACER +empty(::Type{JacobianTracer}) = EMPTY_JACOBIAN_TRACER +empty(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER +empty(::Type{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER + +""" + tracer(JacobianTracer, index) + tracer(JacobianTracer, indices) + tracer(ConnectivityTracer, index) + tracer(ConnectivityTracer, indices) + +Convenience constructor for [`JacobianTracer`](@ref) [`ConnectivityTracer`](@ref) from input indices. +""" +tracer(::Type{JacobianTracer}, index::Integer) = JacobianTracer(BitSet(index)) +tracer(::Type{ConnectivityTracer}, index::Integer) = ConnectivityTracer(BitSet(index)) + +function tracer(::Type{JacobianTracer}, inds::NTuple{N,<:Integer}) where {N} + return JacobianTracer(BitSet(inds)) +end +function tracer(::Type{ConnectivityTracer}, inds::NTuple{N,<:Integer}) where {N} + return ConnectivityTracer(BitSet(inds)) +end + +tracer(::Type{T}, inds...) where {T<:AbstractTracer} = tracer(T, inds) diff --git a/test/benchmark.jl b/test/benchmark.jl index 5b614e16..c98db653 100644 --- a/test/benchmark.jl +++ b/test/benchmark.jl @@ -19,7 +19,7 @@ function benchmark_brusselator(N::Integer, method=:tracer) f!(du, u) = brusselator_2d_loop(du, u, p, nothing) if method == :tracer - return @benchmark connectivity($f!, $du, $u) + return @benchmark pattern($f!, $du, $u) elseif method == :symbolics return @benchmark Symbolics.jacobian_sparsity($f!, $du, $u) end @@ -31,7 +31,7 @@ function benchmark_conv(method=:tracer) f(x) = conv(x, w) if method == :tracer - return @benchmark connectivity($f, $x) + return @benchmark pattern($f, $x) elseif method == :symbolics return @benchmark Symbolics.jacobian_sparsity($f, $x) end diff --git a/test/references/connectivity/Brusselator.txt b/test/references/pattern/connectivity/Brusselator.txt similarity index 100% rename from test/references/connectivity/Brusselator.txt rename to test/references/pattern/connectivity/Brusselator.txt diff --git a/test/references/connectivity/NNlib/conv.txt b/test/references/pattern/connectivity/NNlib/conv.txt similarity index 100% rename from test/references/connectivity/NNlib/conv.txt rename to test/references/pattern/connectivity/NNlib/conv.txt diff --git a/test/references/pattern/jacobian/Brusselator.txt b/test/references/pattern/jacobian/Brusselator.txt new file mode 100644 index 00000000..24a82b43 --- /dev/null +++ b/test/references/pattern/jacobian/Brusselator.txt @@ -0,0 +1 @@ +Bool[1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 1 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0; 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0; 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0; 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0; 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0; 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0; 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1; 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0; 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0; 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0; 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0; 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0; 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1; 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0 0 1 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 1; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 1 0 0 0 1; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1 0; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 1 1; 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1] \ No newline at end of file diff --git a/test/references/pattern/jacobian/NNlib/conv.txt b/test/references/pattern/jacobian/NNlib/conv.txt new file mode 100644 index 00000000..aaf436e1 --- /dev/null +++ b/test/references/pattern/jacobian/NNlib/conv.txt @@ -0,0 +1 @@ +Bool[1 1 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0; 0 1 1 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0; 0 0 0 1 1 0 1 1 0 0 0 0 1 1 0 1 1 0; 0 0 0 0 1 1 0 1 1 0 0 0 0 1 1 0 1 1] \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7853d8e8..3217f9d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,19 +37,22 @@ DocMeta.setdocmeta!( @testset "JET tests" begin JET.test_package(SparseConnectivityTracer; target_defined_modules=true) end - @testset verbose = true "Classification of operators by diff'ability" begin + @testset "Doctests" begin + Documenter.doctest(SparseConnectivityTracer) + end + @testset "Classification of operators by diff'ability" begin include("test_differentiability.jl") end @testset "Connectivity" begin x = rand(3) - xt = trace_input(x) + xt = trace_input(ConnectivityTracer, x) # Matrix multiplication A = rand(1, 3) yt = only(A * xt) @test inputs(yt) == [1, 2, 3] - @test connectivity(x -> only(A * x), x) ≈ [1 1 1] + @test pattern(x -> only(A * x), ConnectivityTracer, x) ≈ [1 1 1] # Custom functions f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] @@ -58,17 +61,29 @@ DocMeta.setdocmeta!( @test inputs(yt[2]) == [1, 2] @test inputs(yt[3]) == [3] - @test connectivity(f, x) ≈ [1 0 0; 1 1 0; 0 0 1] + @test pattern(f, ConnectivityTracer, x) ≈ [1 0 0; 1 1 0; 0 0 1] + @test pattern(f, JacobianTracer, x) ≈ [1 0 0; 1 1 0; 0 0 1] + + @test pattern(identity, ConnectivityTracer, rand()) ≈ [1;;] + @test pattern(identity, JacobianTracer, rand()) ≈ [1;;] + @test pattern(Returns(1), ConnectivityTracer, 1) ≈ [0;;] + @test pattern(Returns(1), JacobianTracer, 1) ≈ [0;;] - @test connectivity(identity, rand()) ≈ [1;;] - @test connectivity(Returns(1), 1) ≈ [0;;] + # Test JacobianTracer on functions with zero derivatives + x = rand(2) + g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] + @test pattern(g, ConnectivityTracer, x) ≈ [1 1; 1 1; 1 1] + @test pattern(g, JacobianTracer, x) ≈ [1 1; 0 0; 1 0] end @testset "Real-world tests" begin @testset "NNlib" begin x = rand(3, 3, 2, 1) # WHCN w = rand(2, 2, 2, 1) # Conv((2, 2), 2 => 1) - C = connectivity(x -> NNlib.conv(x, w), x) - @test_reference "references/connectivity/NNlib/conv.txt" BitMatrix(C) + C = pattern(x -> NNlib.conv(x, w), ConnectivityTracer, x) + @test_reference "references/pattern/connectivity/NNlib/conv.txt" BitMatrix(C) + J = pattern(x -> NNlib.conv(x, w), JacobianTracer, x) + @test_reference "references/pattern/jacobian/NNlib/conv.txt" BitMatrix(J) + @test C == J end @testset "Brusselator" begin include("brusselator.jl") @@ -85,8 +100,11 @@ DocMeta.setdocmeta!( du = similar(u) f!(du, u) = brusselator_2d_loop(du, u, p, nothing) - C = connectivity(f!, du, u) - @test_reference "references/connectivity/Brusselator.txt" BitMatrix(C) + C = pattern(f!, du, ConnectivityTracer, u) + @test_reference "references/pattern/connectivity/Brusselator.txt" BitMatrix(C) + J = pattern(f!, du, JacobianTracer, u) + @test_reference "references/pattern/jacobian/Brusselator.txt" BitMatrix(J) + @test C == J C_ref = Symbolics.jacobian_sparsity(f!, du, u) @test C == C_ref @@ -95,7 +113,4 @@ DocMeta.setdocmeta!( @testset "ADTypes integration" begin include("adtypes.jl") end - @testset "Doctests" begin - Documenter.doctest(SparseConnectivityTracer) - end end