From 85f81e60a77612cbdc7044ad36a8491fa5023c5d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 6 Aug 2024 15:08:59 +0200 Subject: [PATCH 1/4] Support ComponentArrays --- Project.toml | 2 +- src/interface.jl | 3 ++- test/Project.toml | 1 + test/componentarrays.jl | 30 ++++++++++++++++++++++++++++++ test/runtests.jl | 3 +++ 5 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 test/componentarrays.jl diff --git a/Project.toml b/Project.toml index 0b0f54fa..8a00512a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseConnectivityTracer" uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" authors = ["Adrian Hill "] -version = "0.6.0" +version = "0.6.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interface.jl b/src/interface.jl index 325e0b86..89182127 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -17,7 +17,8 @@ Supports [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref). trace_input(::Type{T}, xs) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, xs, 1) function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:Union{AbstractTracer,Dual}} - is = reshape(1:length(xs), size(xs)) .+ (i - 1) + is = similar(xs, Int) # same array type as xs + is .= reshape(1:length(xs), size(xs)) .+ (i - 1) return create_tracers(T, xs, is) end function trace_input(::Type{T}, x::Real, i::Integer) where {T<:Union{AbstractTracer,Dual}} diff --git a/test/Project.toml b/test/Project.toml index 81e54b10..fca60bc6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/test/componentarrays.jl b/test/componentarrays.jl new file mode 100644 index 00000000..a341bb90 --- /dev/null +++ b/test/componentarrays.jl @@ -0,0 +1,30 @@ +using ComponentArrays +using SparseConnectivityTracer +using Test + +f(x::AbstractVector) = abs2.(x) +f_comp(x::ComponentVector) = ComponentVector(; a=abs2.(x.a), b=abs2.(x.b)) + +function f!(y::AbstractVector, x::AbstractVector) + y .= abs2.(x) + return y +end + +function f_comp!(y::ComponentVector, x::ComponentVector) + y.a .= abs2.(x.a) + y.b .= abs2.(x.b) + return y +end + +x_comp = ComponentVector(; a=rand(2), b=rand(3)) +y_comp = ComponentVector(; a=rand(2), b=rand(3)) +x = Vector(x_comp) +y = Vector(y_comp) + +detector = TracerSparsityDetector() + +@test jacobian_sparsity(f_comp, x_comp, detector) == jacobian_sparsity(f, x, detector) +@test jacobian_sparsity(f_comp!, similar(y_comp), x_comp, detector) == + jacobian_sparsity(f!, similar(y), x, detector) +@test hessian_sparsity(sum ∘ f_comp, x_comp, detector) == + hessian_sparsity(sum ∘ f, x, detector) diff --git a/test/runtests.jl b/test/runtests.jl index 0186e6a1..4b19a515 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -90,6 +90,9 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") @testset "Array overloads" begin include("test_arrays.jl") end + @testset "ComponentArrays" begin + include("componentarrays.jl") + end end end From a4527b13728a5bd7dea141f4772abc5a634ccc63 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 7 Aug 2024 16:02:54 +0200 Subject: [PATCH 2/4] Fix: `similar` isn't always viable --- src/SparseConnectivityTracer.jl | 2 +- src/interface.jl | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 47d421e4..5f91d426 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -6,7 +6,7 @@ using SparseArrays: SparseArrays using SparseArrays: sparse using Random: AbstractRNG, SamplerType -using LinearAlgebra: LinearAlgebra +using LinearAlgebra: LinearAlgebra, Symmetric, Diagonal using FillArrays: Fill using DocStringExtensions diff --git a/src/interface.jl b/src/interface.jl index 89182127..52dab76c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -16,11 +16,19 @@ Supports [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref). """ trace_input(::Type{T}, xs) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, xs, 1) +# If possible, this should call `similar` and have a function signature `A -> A`. +# For some array types like `Symmetric`, this function signature isn't possible, +# e.g. because symmetry doesn't hold for the index matrix. +allocate_index_matrix(A::AbstractArray) = similar(A, Int) +allocate_index_matrix(A::Symmetric) = Matrix{Int}(undef, size(A)...) +allocate_index_matrix(A::Diagonal) = Matrix{Int}(undef, size(A)...) + function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:Union{AbstractTracer,Dual}} - is = similar(xs, Int) # same array type as xs + is = allocate_index_matrix(xs) is .= reshape(1:length(xs), size(xs)) .+ (i - 1) return create_tracers(T, xs, is) end + function trace_input(::Type{T}, x::Real, i::Integer) where {T<:Union{AbstractTracer,Dual}} return only(create_tracers(T, [x], [i])) end From d3939b91b17f8bfc69cd79f0c0962b5411eae2c2 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 7 Aug 2024 16:03:29 +0200 Subject: [PATCH 3/4] Update version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8a00512a..eae1d812 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseConnectivityTracer" uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" authors = ["Adrian Hill "] -version = "0.6.1" +version = "0.6.1-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From f48d3ffac30e851948cbea45760bdce67194f504 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 7 Aug 2024 16:06:57 +0200 Subject: [PATCH 4/4] Improve comment --- src/interface.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 52dab76c..11765957 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -16,9 +16,9 @@ Supports [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref). """ trace_input(::Type{T}, xs) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, xs, 1) -# If possible, this should call `similar` and have a function signature `A -> A`. -# For some array types like `Symmetric`, this function signature isn't possible, -# e.g. because symmetry doesn't hold for the index matrix. +# If possible, this should call `similar` and have the function signature `A{T} -> A{Int}`. +# For some array types, this function signature isn't possible, +# e.g. on `Symmetric`, where symmetry doesn't hold for the index matrix. allocate_index_matrix(A::AbstractArray) = similar(A, Int) allocate_index_matrix(A::Symmetric) = Matrix{Int}(undef, size(A)...) allocate_index_matrix(A::Diagonal) = Matrix{Int}(undef, size(A)...)