Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support clamp and clamp! #208

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# SparseConnectivityTracer.jl

## Version `v0.6.8`

* ![Feature][badge-feature] Support `clamp` and `clamp!` ([#208])

## Version `v0.6.7`

* ![Enhancement][badge-enhancement] Drop compatibility with Julia <1.10 to improve tracer performance ([#204], [#205])
Expand Down Expand Up @@ -80,6 +84,7 @@
[badge-maintenance]: https://img.shields.io/badge/maintenance-gray.svg
[badge-docs]: https://img.shields.io/badge/docs-orange.svg

[#208]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/208
[#205]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/205
[#204]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/204
[#202]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/202
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseConnectivityTracer"
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
authors = ["Adrian Hill <[email protected]>"]
version = "0.6.7"
version = "0.6.8-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ include("operators.jl")
include("overloads/conversion.jl")
include("overloads/gradient_tracer.jl")
include("overloads/hessian_tracer.jl")
include("overloads/utils.jl")
include("overloads/special_cases.jl")
include("overloads/three_arg.jl")
include("overloads/ifelse_global.jl")
include("overloads/dual.jl")
include("overloads/arrays.jl")
include("overloads/utils.jl")
include("overloads/ambiguities.jl")

include("trace_functions.jl")
Expand Down
57 changes: 12 additions & 45 deletions src/overloads/arrays.jl
Original file line number Diff line number Diff line change
@@ -1,48 +1,3 @@
"""
second_order_or(tracers)

Compute the most conservative elementwise OR of tracer sparsity patterns,
including second-order interactions to update the `hessian` field of `HessianTracer`.

This is functionally equivalent to:
```julia
reduce(^, tracers)
```
"""
function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(second_order_or, ts; init=myempty(T))
end

function second_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function second_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, false, false, false, false)
end

"""
first_order_or(tracers)

Compute the most conservative elementwise OR of tracer sparsity patterns,
excluding second-order interactions of `HessianTracer`.

This is functionally equivalent to:
```julia
reduce(+, tracers)
```
"""
function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(first_order_or, ts; init=myempty(T))
end
function first_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function first_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, true, false, true, true)
end

#===========#
# Utilities #
#===========#
Expand Down Expand Up @@ -168,6 +123,18 @@ function Base.literal_pow(::typeof(^), D::Diagonal{T}, ::Val{0}) where {T<:Abstr
return Diagonal(ts)
end

## clamp!
Base.clamp!(A::AbstractArray{T}, lo, hi) where {T<:AbstractTracer} = A
function Base.clamp!(A::AbstractArray{T}, lo::T, hi) where {T<:AbstractTracer}
return first_order_or.(A, lo)
end
function Base.clamp!(A::AbstractArray{T}, lo, hi::T) where {T<:AbstractTracer}
return first_order_or.(A, hi)
end
function Base.clamp!(A::AbstractArray{T}, lo::T, hi::T) where {T<:AbstractTracer}
return first_order_or.(A, first_order_or(lo, hi))
end

#==========================#
# LinearAlgebra.jl on Dual #
#==========================#
Expand Down
12 changes: 12 additions & 0 deletions src/overloads/three_arg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#= For now, three-argument functions are overloaded individually.
If this file grows too large:
1. 3-arg operators should be classified in src/operators.jl
2. the classification should be tested in test/classification.jl
3. code generation utilities should be added to the src/overloads/*_tracer.jl files
=#
Base.clamp(t::T, lo, hi) where {T<:AbstractTracer} = t
Base.clamp(t::T, lo::T, hi) where {T<:AbstractTracer} = first_order_or(t, lo)
Base.clamp(t::T, lo, hi::T) where {T<:AbstractTracer} = first_order_or(t, hi)
function Base.clamp(t::T, lo::T, hi::T) where {T<:AbstractTracer}
return first_order_or(t, first_order_or(lo, hi))
end
53 changes: 53 additions & 0 deletions src/overloads/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,56 @@
#===============#
# Tracer unions #
#===============#

"""
first_order_or(tracers)

Compute the most conservative elementwise OR of tracer sparsity patterns,
excluding second-order interactions of `HessianTracer`.

This is functionally equivalent to:
```julia
reduce(+, tracers)
```
"""
function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(first_order_or, ts; init=myempty(T))
end
function first_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function first_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, true, false, true, true)
end

"""
second_order_or(tracers)

Compute the most conservative elementwise OR of tracer sparsity patterns,
including second-order interactions to update the `hessian` field of `HessianTracer`.

This is functionally equivalent to:
```julia
reduce(^, tracers)
```
"""
function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
# TODO: improve performance
return reduce(second_order_or, ts; init=myempty(T))
end

function second_order_or(a::T, b::T) where {T<:GradientTracer}
return gradient_tracer_2_to_1(a, b, false, false)
end
function second_order_or(a::T, b::T) where {T<:HessianTracer}
return hessian_tracer_2_to_1(a, b, false, false, false, false, false)
end

#=================#
# Code generation #
#=================#

dims = (Symbol("1_to_1"), Symbol("2_to_1"), Symbol("1_to_2"))

# Generate both Gradient and Hessian code with one call to `generate_code_X_to_Y`
Expand Down
35 changes: 35 additions & 0 deletions test/test_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,41 @@ S = BitSet
P = IndexSetGradientPattern{Int,S}
TG = GradientTracer{P}

@testset "clamp!" begin
t1 = TG(P(S(1)))
t2 = TG(P(S(2)))
t3 = TG(P(S(3)))
t4 = TG(P(S(4)))
A = [t1 t2; t3 t4]

t_lo = TG(P(S(5)))
t_hi = TG(P(S(6)))

out = clamp!(A, 0.0, 1.0)
@test SCT.gradient(out[1, 1]) == S(1)
@test SCT.gradient(out[1, 2]) == S(2)
@test SCT.gradient(out[2, 1]) == S(3)
@test SCT.gradient(out[2, 2]) == S(4)

out = clamp!(A, t_lo, 1.0)
@test SCT.gradient(out[1, 1]) == S([1, 5])
@test SCT.gradient(out[1, 2]) == S([2, 5])
@test SCT.gradient(out[2, 1]) == S([3, 5])
@test SCT.gradient(out[2, 2]) == S([4, 5])

out = clamp!(A, 0.0, t_hi)
@test SCT.gradient(out[1, 1]) == S([1, 6])
@test SCT.gradient(out[1, 2]) == S([2, 6])
@test SCT.gradient(out[2, 1]) == S([3, 6])
@test SCT.gradient(out[2, 2]) == S([4, 6])

out = clamp!(A, t_lo, t_hi)
@test SCT.gradient(out[1, 1]) == S([1, 5, 6])
@test SCT.gradient(out[1, 2]) == S([2, 5, 6])
@test SCT.gradient(out[2, 1]) == S([3, 5, 6])
@test SCT.gradient(out[2, 2]) == S([4, 5, 6])
end

@testset "Matrix division" begin
t1 = TG(P(S([1, 3, 4])))
t2 = TG(P(S([2, 4])))
Expand Down
19 changes: 19 additions & 0 deletions test/test_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ J(f, x) = jacobian_sparsity(f, x, detector)
@test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]
end

@testset "Three-argument operators" begin
@test J(x -> clamp(x, 0.0, 1.0), rand()) == [1;;]
@test J(x -> clamp(x[1], x[2], 1.0), rand(2)) == [1 1]
@test J(x -> clamp(x[1], 0.0, x[2]), rand(2)) == [1 1]
@test J(x -> clamp(x[1], x[2], x[3]), rand(3)) == [1 1 1]
end

@testset "Random" begin
@test J(x -> rand(typeof(x)), 1) ≈ [0;;]
@test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]
Expand Down Expand Up @@ -219,6 +226,18 @@ end
@test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]
end

@testset "Three-argument operators" begin
@test J(x -> clamp(x, 0.0, 1.0), 0.5) == [1;;]
@test J(x -> clamp(x, 0.0, 1.0), -0.5) == [0;;]
@test J(x -> clamp(x[1], x[2], 1.0), [0.5, 0.0]) == [1 0]
@test J(x -> clamp(x[1], x[2], 1.0), [0.5, 0.6]) == [0 1]
@test J(x -> clamp(x[1], 0.0, x[2]), [0.5, 1.0]) == [1 0]
@test J(x -> clamp(x[1], 0.0, x[2]), [0.5, 0.4]) == [0 1]
@test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.0, 1.0]) == [1 0 0]
@test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.6, 1.0]) == [0 1 0]
@test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.0, 0.4]) == [0 0 1]
end

@testset "Random" begin
@test J(x -> rand(typeof(x)), 1) ≈ [0;;]
@test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]
Expand Down
25 changes: 25 additions & 0 deletions test/test_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ D = Dual{Int,T}
@test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]
end

@testset "Three-argument operators" begin
@test H(x -> clamp(x, 0.1, 0.9), rand()) == [0;;]
@test H(x -> clamp(x[1], x[2], 0.9), rand(2)) == [0 0; 0 0]
@test H(x -> clamp(x[1], 0.1, x[2]), rand(2)) == [0 0; 0 0]
@test H(x -> clamp(x[1], x[2], x[3]), rand(3)) == [0 0 0; 0 0 0; 0 0 0]
@test H(x -> x[1] * clamp(x[1], x[2], x[3]), rand(3)) == [1 1 1; 1 0 0; 1 0 0]
@test H(x -> x[2] * clamp(x[1], x[2], x[3]), rand(3)) == [0 1 0; 1 1 1; 0 1 0]
@test H(x -> x[3] * clamp(x[1], x[2], x[3]), rand(3)) == [0 0 1; 0 0 1; 1 1 1]
end

@testset "Random" begin
@test H(x -> rand(typeof(x)), 1) ≈ [0;;]
@test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]
Expand Down Expand Up @@ -377,6 +387,21 @@ end
@test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]
end

@testset "Three-argument operators" begin
@test H(x -> x * clamp(x, 0.0, 1.0), 0.5) == [1;;]
@test H(x -> x * clamp(x, 0.0, 1.0), -0.5) == [0;;]
@test H(x -> sum(x) * clamp(x[1], x[2], 1.0), [0.5, 0.0]) == [1 1; 1 0]
@test H(x -> sum(x) * clamp(x[1], x[2], 1.0), [0.5, 0.6]) == [0 1; 1 1]
@test H(x -> sum(x) * clamp(x[1], 0.0, x[2]), [0.5, 1.0]) == [1 1; 1 0]
@test H(x -> sum(x) * clamp(x[1], 0.0, x[2]), [0.5, 0.4]) == [0 1; 1 1]
@test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.0, 1.0]) ==
[1 1 1; 1 0 0; 1 0 0]
@test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.6, 1.0]) ==
[0 1 0; 1 1 1; 0 1 0]
@test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.0, 0.4]) ==
[0 0 1; 0 0 1; 1 1 1]
end

@testset "Random" begin
@test H(x -> rand(typeof(x)), 1) ≈ [0;;]
@test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]
Expand Down
Loading