Skip to content

Commit

Permalink
Use ADTypes interface in real world tests (#52)
Browse files Browse the repository at this point in the history
* Move benchmarks to benchmark folder

* Update benchmarks to use ADTypes interface in real world tests
  • Loading branch information
adrhill authored May 6, 2024
1 parent f20d411 commit fdbc36c
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 76 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
/docs/Manifest.toml
/docs/build/
/docs/src/index.md
/benchmark/Manifest.toml
5 changes: 5 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
56 changes: 56 additions & 0 deletions benchmark/benchmark.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using ADTypes
using ADTypes: AbstractSparsityDetector
using BenchmarkTools
using SparseConnectivityTracer
using SparseConnectivityTracer: SortedVector
using NNlib: conv

include("../test/brusselator_definition.jl")

const METHODS = (
TracerSparsityDetector(BitSet),
TracerSparsityDetector(Set{UInt64}),
TracerSparsityDetector(SortedVector{UInt64}),
)

function benchmark_brusselator(N::Integer, method::AbstractSparsityDetector)
dims = (N, N, 2)
A = 1.0
B = 1.0
alpha = 1.0
xyd = fill(1.0, N)
dx = 1.0
p = (A, B, alpha, xyd, dx, N)

u = rand(dims...)
du = similar(u)
f!(du, u) = brusselator_2d_loop(du, u, p, nothing)

return @benchmark ADTypes.jacobian_sparsity($f!, $du, $u, $method)
end

function benchmark_conv(N, method::AbstractSparsityDetector)
x = rand(N, N, 3, 1) # WHCN image
w = rand(5, 5, 3, 2) # corresponds to Conv((5, 5), 3 => 2)
f(x) = conv(x, w)

return @benchmark ADTypes.jacobian_sparsity($f, $x, $method)
end

## Run Brusselator benchmarks
for N in (6, 24, 100)
for method in METHODS
@info "Benchmarking Brusselator of size $N with $method..."
b = benchmark_brusselator(N, method)
display(b)
end
end

## Run conv benchmarks
for N in (28, 128)
for method in METHODS # Symbolics fails on this example
@info "Benchmarking NNlib.conv on image of size ($N, $N, 3) with with $method..."
b = benchmark_conv(N, method)
display(b)
end
end
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
61 changes: 0 additions & 61 deletions test/benchmark.jl

This file was deleted.

19 changes: 12 additions & 7 deletions test/brusselator.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using ADTypes
using ADTypes: AbstractSparsityDetector
using ReferenceTests
using SparseConnectivityTracer
using SparseConnectivityTracer: SortedVector
using Test

include("brusselator_definition.jl")

@testset "Set type $S" for S in (BitSet, Set{UInt64}, SortedVector{UInt64})
function test_brusselator(method::AbstractSparsityDetector)
N = 6
dims = (N, N, 2)
A = 1.0
Expand All @@ -18,12 +21,14 @@ include("brusselator_definition.jl")
du = similar(u)
f!(du, u) = brusselator_2d_loop(du, u, p, nothing)

C = connectivity_pattern(f!, du, u, S)
@test_reference "references/pattern/connectivity/Brusselator.txt" BitMatrix(C)
J = jacobian_pattern(f!, du, u, S)
J = ADTypes.jacobian_sparsity(f!, du, u, method)
@test_reference "references/pattern/jacobian/Brusselator.txt" BitMatrix(J)
@test C == J
end

C_ref = Symbolics.jacobian_sparsity(f!, du, u)
@test C == C_ref
@testset "$method" for method in (
TracerSparsityDetector(BitSet),
TracerSparsityDetector(Set{UInt64}),
TracerSparsityDetector(SortedVector{UInt64}),
)
test_brusselator(method)
end
22 changes: 16 additions & 6 deletions test/nnlib.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
using NNlib
using ADTypes
using ADTypes: AbstractSparsityDetector
using ReferenceTests
using SparseConnectivityTracer
using SparseConnectivityTracer: SortedVector
using NNlib
using Test

@testset "Set type $S" for S in (BitSet, Set{UInt64}, SortedVector{UInt64})
function test_nnlib_conv(method::AbstractSparsityDetector)
x = rand(3, 3, 2, 1) # WHCN
w = rand(2, 2, 2, 1) # Conv((2, 2), 2 => 1)
C = jacobian_pattern(x -> NNlib.conv(x, w), x, S)
@test_reference "references/pattern/connectivity/NNlib/conv.txt" BitMatrix(C)
J = jacobian_pattern(x -> NNlib.conv(x, w), x, S)
f(x) = NNlib.conv(x, w)

J = ADTypes.jacobian_sparsity(f, x, method)
@test_reference "references/pattern/jacobian/NNlib/conv.txt" BitMatrix(J)
@test C == J
end

@testset "$method" for method in (
TracerSparsityDetector(BitSet),
TracerSparsityDetector(Set{UInt64}),
TracerSparsityDetector(SortedVector{UInt64}),
)
test_nnlib_conv(method)
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ using Documenter

using LinearAlgebra
using Random
using Symbolics: Symbolics
using NNlib

DocMeta.setdocmeta!(
Expand Down

0 comments on commit fdbc36c

Please sign in to comment.