Skip to content

Commit

Permalink
Switch to ADTypes.jl v1.0 (#16)
Browse files Browse the repository at this point in the history
* Switch to ADTypes.jl v1.0

---------

Co-authored-by: adrhill <[email protected]>
  • Loading branch information
gdalle and adrhill authored Apr 21, 2024
1 parent f6350a9 commit 3830ca0
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 98 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
version: '1.10'
- uses: julia-actions/cache@v1
- name: Configure doc environment
shell: julia --project=docs --color=yes {0}
run: |
using Pkg
Pkg.Registry.update()
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()
- uses: julia-actions/julia-buildpkg@v1
Expand Down
10 changes: 3 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@ authors = ["Adrian Hill <[email protected]>"]
version = "0.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[extensions]
SparseConnectivityTracerSparseDiffToolsExt = "SparseDiffTools"

[compat]
SparseDiffTools = "2.17"
ADTypes = "1"
SparseArrays = "1"
julia = "1.6"
4 changes: 4 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"

[compat]
ADTypes = "1"
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CollapsedDocStrings = true
## Interface
```@docs
connectivity
TracerSparsityDetector
```

## Internals
Expand Down
29 changes: 0 additions & 29 deletions docs/src/index.md

This file was deleted.

35 changes: 0 additions & 35 deletions ext/SparseConnectivityTracerSparseDiffToolsExt.jl

This file was deleted.

4 changes: 4 additions & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
module SparseConnectivityTracer

using ADTypes: ADTypes
import Random: rand, AbstractRNG, SamplerType
import SparseArrays: sparse

Expand All @@ -7,10 +9,12 @@ include("conversion.jl")
include("operators.jl")
include("overload_tracer.jl")
include("connectivity.jl")
include("adtypes.jl")

export Tracer
export tracer, trace_input
export inputs
export connectivity
export TracerSparsityDetector

end # module
30 changes: 30 additions & 0 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
TracerSparsityDetector <: ADTypes.AbstractSparsityDetector
Singleton struct for integration with the sparsity detection framework of ADTypes.jl.
# Example
```jldoctest
julia> using ADTypes, SparseConnectivityTracer
julia> ADTypes.jacobian_sparsity(diff, rand(4), TracerSparsityDetector())
3×4 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 6 stored entries:
1 1 ⋅ ⋅
⋅ 1 1 ⋅
⋅ ⋅ 1 1
```
"""
struct TracerSparsityDetector <: ADTypes.AbstractSparsityDetector end

function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector)
return connectivity(f, x)
end

function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector)
return connectivity(f!, y, x)
end

function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector)
return error("Hessian sparsity is not yet implemented for `TracerSparsityDetector`.")
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
16 changes: 16 additions & 0 deletions test/adtypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using ADTypes
using SparseConnectivityTracer
using SparseArrays
using Test

sd = TracerSparsityDetector()

x = rand(10)
y = zeros(9)
J1 = ADTypes.jacobian_sparsity(diff, x, sd)
J2 = ADTypes.jacobian_sparsity((y, x) -> y .= diff(x), y, x, sd)
@test J1 == J2
@test J1 isa SparseMatrixCSC
@test J2 isa SparseMatrixCSC
@test nnz(J1) == nnz(J2) == 18
@test_throws ErrorException ADTypes.hessian_sparsity(sum, x, sd)
7 changes: 4 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ DocMeta.setdocmeta!(
Aqua.test_all(
SparseConnectivityTracer;
ambiguities=false,
deps_compat=(ignore=[:Random, :SparseArrays],),
deps_compat=(ignore=[:Random, :SparseArrays], check_extras=false),
persistent_tasks=false,
)
end
@testset "JET tests" begin
Expand Down Expand Up @@ -91,8 +92,8 @@ DocMeta.setdocmeta!(
@test C == C_ref
end
end
@testset "SparseDiffTools integration" begin
include("sparsedifftools.jl")
@testset "ADTypes integration" begin
include("adtypes.jl")
end
@testset "Doctests" begin
Documenter.doctest(SparseConnectivityTracer)
Expand Down
23 changes: 0 additions & 23 deletions test/sparsedifftools.jl

This file was deleted.

0 comments on commit 3830ca0

Please sign in to comment.