-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Switch to ADTypes.jl v1.0 --------- Co-authored-by: adrhill <[email protected]>
- Loading branch information
Showing
12 changed files
with
65 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,15 +4,11 @@ authors = ["Adrian Hill <[email protected]>"] | |
version = "0.1.0" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
|
||
[weakdeps] | ||
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" | ||
|
||
[extensions] | ||
SparseConnectivityTracerSparseDiffToolsExt = "SparseDiffTools" | ||
|
||
[compat] | ||
SparseDiffTools = "2.17" | ||
ADTypes = "1" | ||
SparseArrays = "1" | ||
julia = "1.6" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" | ||
|
||
[compat] | ||
ADTypes = "1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
""" | ||
TracerSparsityDetector <: ADTypes.AbstractSparsityDetector | ||
Singleton struct for integration with the sparsity detection framework of ADTypes.jl. | ||
# Example | ||
```jldoctest | ||
julia> using ADTypes, SparseConnectivityTracer | ||
julia> ADTypes.jacobian_sparsity(diff, rand(4), TracerSparsityDetector()) | ||
3×4 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 6 stored entries: | ||
1 1 ⋅ ⋅ | ||
⋅ 1 1 ⋅ | ||
⋅ ⋅ 1 1 | ||
``` | ||
""" | ||
struct TracerSparsityDetector <: ADTypes.AbstractSparsityDetector end | ||
|
||
function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector) | ||
return connectivity(f, x) | ||
end | ||
|
||
function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector) | ||
return connectivity(f!, y, x) | ||
end | ||
|
||
function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector) | ||
return error("Hessian sparsity is not yet implemented for `TracerSparsityDetector`.") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
using ADTypes | ||
using SparseConnectivityTracer | ||
using SparseArrays | ||
using Test | ||
|
||
sd = TracerSparsityDetector() | ||
|
||
x = rand(10) | ||
y = zeros(9) | ||
J1 = ADTypes.jacobian_sparsity(diff, x, sd) | ||
J2 = ADTypes.jacobian_sparsity((y, x) -> y .= diff(x), y, x, sd) | ||
@test J1 == J2 | ||
@test J1 isa SparseMatrixCSC | ||
@test J2 isa SparseMatrixCSC | ||
@test nnz(J1) == nnz(J2) == 18 | ||
@test_throws ErrorException ADTypes.hessian_sparsity(sum, x, sd) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.