diff --git a/CHANGELOG.md b/CHANGELOG.md
index d512f5e..8b4559a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 # SparseConnectivityTracer.jl
 
+## Version `v0.6.8`
+
+* ![Feature][badge-feature] Support `clamp` and `clamp!` ([#208])
+
 ## Version `v0.6.7`
 
 * ![Enhancement][badge-enhancement] Drop compatibility with Julia <1.10 to improve tracer performance ([#204], [#205])
@@ -80,6 +84,7 @@
 [badge-maintenance]: https://img.shields.io/badge/maintenance-gray.svg
 [badge-docs]: https://img.shields.io/badge/docs-orange.svg
 
+[#208]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/208
 [#205]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/205
 [#204]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/204
 [#202]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/202
diff --git a/Project.toml b/Project.toml
index 00d02c1..88aa774 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SparseConnectivityTracer"
 uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
 authors = ["Adrian Hill "]
-version = "0.6.7"
+version = "0.6.8-DEV"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl
index dfa0d7c..8ecd652 100644
--- a/src/SparseConnectivityTracer.jl
+++ b/src/SparseConnectivityTracer.jl
@@ -23,11 +23,12 @@ include("operators.jl")
 include("overloads/conversion.jl")
 include("overloads/gradient_tracer.jl")
 include("overloads/hessian_tracer.jl")
+include("overloads/utils.jl")
 include("overloads/special_cases.jl")
+include("overloads/three_arg.jl")
 include("overloads/ifelse_global.jl")
 include("overloads/dual.jl")
 include("overloads/arrays.jl")
-include("overloads/utils.jl")
 include("overloads/ambiguities.jl")
 
 include("trace_functions.jl")
diff --git a/src/overloads/arrays.jl b/src/overloads/arrays.jl
index 903f044..5901f91 100644
--- a/src/overloads/arrays.jl
+++ b/src/overloads/arrays.jl
@@ -1,48 +1,3 @@
-"""
-    second_order_or(tracers)
-
-Compute the most conservative elementwise OR of tracer sparsity patterns,
-including second-order interactions to update the `hessian` field of `HessianTracer`.
-
-This is functionally equivalent to:
-```julia
-reduce(^, tracers)
-```
-"""
-function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
-    # TODO: improve performance
-    return reduce(second_order_or, ts; init=myempty(T))
-end
-
-function second_order_or(a::T, b::T) where {T<:GradientTracer}
-    return gradient_tracer_2_to_1(a, b, false, false)
-end
-function second_order_or(a::T, b::T) where {T<:HessianTracer}
-    return hessian_tracer_2_to_1(a, b, false, false, false, false, false)
-end
-
-"""
-    first_order_or(tracers)
-
-Compute the most conservative elementwise OR of tracer sparsity patterns,
-excluding second-order interactions of `HessianTracer`.
-
-This is functionally equivalent to:
-```julia
-reduce(+, tracers)
-```
-"""
-function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
-    # TODO: improve performance
-    return reduce(first_order_or, ts; init=myempty(T))
-end
-function first_order_or(a::T, b::T) where {T<:GradientTracer}
-    return gradient_tracer_2_to_1(a, b, false, false)
-end
-function first_order_or(a::T, b::T) where {T<:HessianTracer}
-    return hessian_tracer_2_to_1(a, b, false, true, false, true, true)
-end
-
 #===========#
 # Utilities #
 #===========#
@@ -168,6 +123,24 @@ function Base.literal_pow(::typeof(^), D::Diagonal{T}, ::Val{0}) where {T<:Abstr
     return Diagonal(ts)
 end
 
+## clamp!
+# `clamp!` works in-place: the updated tracers are stored in `A`, which is returned.
+# Without tracer bounds, the sparsity patterns in `A` are left untouched.
+Base.clamp!(A::AbstractArray{T}, lo, hi) where {T<:AbstractTracer} = A
+function Base.clamp!(A::AbstractArray{T}, lo::T, hi) where {T<:AbstractTracer}
+    A .= first_order_or.(A, lo)
+    return A
+end
+function Base.clamp!(A::AbstractArray{T}, lo, hi::T) where {T<:AbstractTracer}
+    A .= first_order_or.(A, hi)
+    return A
+end
+function Base.clamp!(A::AbstractArray{T}, lo::T, hi::T) where {T<:AbstractTracer}
+    bounds = first_order_or(lo, hi)
+    A .= first_order_or.(A, bounds)
+    return A
+end
+
 #==========================#
 # LinearAlgebra.jl on Dual #
 #==========================#
diff --git a/src/overloads/three_arg.jl b/src/overloads/three_arg.jl
new file mode 100644
index 0000000..e28c10b
--- /dev/null
+++ b/src/overloads/three_arg.jl
@@ -0,0 +1,14 @@
+#= For now, three-argument functions are overloaded individually.
+If this file grows too large:
+    1. 3-arg operators should be classified in src/operators.jl
+    2. the classification should be tested in test/classification.jl
+    3. code generation utilities should be added to the src/overloads/*_tracer.jl files
+=#
+# `clamp` returns one of its arguments, so the result conservatively depends on `t`
+# and on every bound that is a tracer. Only first-order patterns are combined.
+Base.clamp(t::T, lo, hi) where {T<:AbstractTracer} = t
+Base.clamp(t::T, lo::T, hi) where {T<:AbstractTracer} = first_order_or(t, lo)
+Base.clamp(t::T, lo, hi::T) where {T<:AbstractTracer} = first_order_or(t, hi)
+function Base.clamp(t::T, lo::T, hi::T) where {T<:AbstractTracer}
+    return first_order_or(t, first_order_or(lo, hi))
+end
diff --git a/src/overloads/utils.jl b/src/overloads/utils.jl
index 585b712..c70f9cf 100644
--- a/src/overloads/utils.jl
+++ b/src/overloads/utils.jl
@@ -1,3 +1,56 @@
+#===============#
+# Tracer unions #
+#===============#
+
+"""
+    first_order_or(tracers)
+
+Compute the most conservative elementwise OR of tracer sparsity patterns,
+excluding second-order interactions of `HessianTracer`.
+
+This is functionally equivalent to:
+```julia
+reduce(+, tracers)
+```
+"""
+function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
+    # TODO: improve performance
+    return reduce(first_order_or, ts; init=myempty(T))
+end
+function first_order_or(a::T, b::T) where {T<:GradientTracer}
+    return gradient_tracer_2_to_1(a, b, false, false)
+end
+function first_order_or(a::T, b::T) where {T<:HessianTracer}
+    return hessian_tracer_2_to_1(a, b, false, true, false, true, true)
+end
+
+"""
+    second_order_or(tracers)
+
+Compute the most conservative elementwise OR of tracer sparsity patterns,
+including second-order interactions to update the `hessian` field of `HessianTracer`.
+
+This is functionally equivalent to:
+```julia
+reduce(^, tracers)
+```
+"""
+function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer}
+    # TODO: improve performance
+    return reduce(second_order_or, ts; init=myempty(T))
+end
+
+function second_order_or(a::T, b::T) where {T<:GradientTracer}
+    return gradient_tracer_2_to_1(a, b, false, false)
+end
+function second_order_or(a::T, b::T) where {T<:HessianTracer}
+    return hessian_tracer_2_to_1(a, b, false, false, false, false, false)
+end
+
+#=================#
+# Code generation #
+#=================#
+
 dims = (Symbol("1_to_1"), Symbol("2_to_1"), Symbol("1_to_2"))
 
 # Generate both Gradient and Hessian code with one call to `generate_code_X_to_Y`
diff --git a/test/test_arrays.jl b/test/test_arrays.jl
index 0525a16..1e3ff74 100644
--- a/test/test_arrays.jl
+++ b/test/test_arrays.jl
@@ -330,6 +330,50 @@ S = BitSet
 P = IndexSetGradientPattern{Int,S}
 TG = GradientTracer{P}
 
+@testset "clamp!" begin
+    t1 = TG(P(S(1)))
+    t2 = TG(P(S(2)))
+    t3 = TG(P(S(3)))
+    t4 = TG(P(S(4)))
+    # `clamp!` mutates its first argument, so every case starts from a fresh array
+    fresh_array() = [t1 t2; t3 t4]
+
+    t_lo = TG(P(S(5)))
+    t_hi = TG(P(S(6)))
+
+    A = fresh_array()
+    out = clamp!(A, 0.0, 1.0)
+    @test out === A
+    @test SCT.gradient(A[1, 1]) == S(1)
+    @test SCT.gradient(A[1, 2]) == S(2)
+    @test SCT.gradient(A[2, 1]) == S(3)
+    @test SCT.gradient(A[2, 2]) == S(4)
+
+    A = fresh_array()
+    out = clamp!(A, t_lo, 1.0)
+    @test out === A
+    @test SCT.gradient(A[1, 1]) == S([1, 5])
+    @test SCT.gradient(A[1, 2]) == S([2, 5])
+    @test SCT.gradient(A[2, 1]) == S([3, 5])
+    @test SCT.gradient(A[2, 2]) == S([4, 5])
+
+    A = fresh_array()
+    out = clamp!(A, 0.0, t_hi)
+    @test out === A
+    @test SCT.gradient(A[1, 1]) == S([1, 6])
+    @test SCT.gradient(A[1, 2]) == S([2, 6])
+    @test SCT.gradient(A[2, 1]) == S([3, 6])
+    @test SCT.gradient(A[2, 2]) == S([4, 6])
+
+    A = fresh_array()
+    out = clamp!(A, t_lo, t_hi)
+    @test out === A
+    @test SCT.gradient(A[1, 1]) == S([1, 5, 6])
+    @test SCT.gradient(A[1, 2]) == S([2, 5, 6])
+    @test SCT.gradient(A[2, 1]) == S([3, 5, 6])
+    @test SCT.gradient(A[2, 2]) == S([4, 5, 6])
+end
+
 @testset "Matrix division" begin
     t1 = TG(P(S([1, 3, 4])))
     t2 = TG(P(S([2, 4])))
diff --git a/test/test_gradient.jl b/test/test_gradient.jl
index 360774d..2c515e1 100644
--- a/test/test_gradient.jl
+++ b/test/test_gradient.jl
@@ -69,6 +69,13 @@ J(f, x) = jacobian_sparsity(f, x, detector)
             @test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]
         end
 
+        @testset "Three-argument operators" begin
+            @test J(x -> clamp(x, 0.0, 1.0), rand()) == [1;;]
+            @test J(x -> clamp(x[1], x[2], 1.0), rand(2)) == [1 1]
+            @test J(x -> clamp(x[1], 0.0, x[2]), rand(2)) == [1 1]
+            @test J(x -> clamp(x[1], x[2], x[3]), rand(3)) == [1 1 1]
+        end
+
         @testset "Random" begin
             @test J(x -> rand(typeof(x)), 1) ≈ [0;;]
             @test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]
@@ -219,6 +226,18 @@ end
             @test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]
         end
 
+        @testset "Three-argument operators" begin
+            @test J(x -> clamp(x, 0.0, 1.0), 0.5) == [1;;]
+            @test J(x -> clamp(x, 0.0, 1.0), -0.5) == [0;;]
+            @test J(x -> clamp(x[1], x[2], 1.0), [0.5, 0.0]) == [1 0]
+            @test J(x -> clamp(x[1], x[2], 1.0), [0.5, 0.6]) == [0 1]
+            @test J(x -> clamp(x[1], 0.0, x[2]), [0.5, 1.0]) == [1 0]
+            @test J(x -> clamp(x[1], 0.0, x[2]), [0.5, 0.4]) == [0 1]
+            @test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.0, 1.0]) == [1 0 0]
+            @test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.6, 1.0]) == [0 1 0]
+            @test J(x -> clamp(x[1], x[2], x[3]), [0.5, 0.0, 0.4]) == [0 0 1]
+        end
+
         @testset "Random" begin
             @test J(x -> rand(typeof(x)), 1) ≈ [0;;]
             @test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]
diff --git a/test/test_hessian.jl b/test/test_hessian.jl
index f47014f..515f5c3 100644
--- a/test/test_hessian.jl
+++ b/test/test_hessian.jl
@@ -69,6 +69,16 @@ D = Dual{Int,T}
             @test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]
         end
 
+        @testset "Three-argument operators" begin
+            @test H(x -> clamp(x, 0.1, 0.9), rand()) == [0;;]
+            @test H(x -> clamp(x[1], x[2], 0.9), rand(2)) == [0 0; 0 0]
+            @test H(x -> clamp(x[1], 0.1, x[2]), rand(2)) == [0 0; 0 0]
+            @test H(x -> clamp(x[1], x[2], x[3]), rand(3)) == [0 0 0; 0 0 0; 0 0 0]
+            @test H(x -> x[1] * clamp(x[1], x[2], x[3]), rand(3)) == [1 1 1; 1 0 0; 1 0 0]
+            @test H(x -> x[2] * clamp(x[1], x[2], x[3]), rand(3)) == [0 1 0; 1 1 1; 0 1 0]
+            @test H(x -> x[3] * clamp(x[1], x[2], x[3]), rand(3)) == [0 0 1; 0 0 1; 1 1 1]
+        end
+
         @testset "Random" begin
             @test H(x -> rand(typeof(x)), 1) ≈ [0;;]
             @test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]
@@ -377,6 +387,21 @@ end
             @test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;]
         end
 
+        @testset "Three-argument operators" begin
+            @test H(x -> x * clamp(x, 0.0, 1.0), 0.5) == [1;;]
+            @test H(x -> x * clamp(x, 0.0, 1.0), -0.5) == [0;;]
+            @test H(x -> sum(x) * clamp(x[1], x[2], 1.0), [0.5, 0.0]) == [1 1; 1 0]
+            @test H(x -> sum(x) * clamp(x[1], x[2], 1.0), [0.5, 0.6]) == [0 1; 1 1]
+            @test H(x -> sum(x) * clamp(x[1], 0.0, x[2]), [0.5, 1.0]) == [1 1; 1 0]
+            @test H(x -> sum(x) * clamp(x[1], 0.0, x[2]), [0.5, 0.4]) == [0 1; 1 1]
+            @test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.0, 1.0]) ==
+                [1 1 1; 1 0 0; 1 0 0]
+            @test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.6, 1.0]) ==
+                [0 1 0; 1 1 1; 0 1 0]
+            @test H(x -> sum(x) * clamp(x[1], x[2], x[3]), [0.5, 0.0, 0.4]) ==
+                [0 0 1; 0 0 1; 1 1 1]
+        end
+
         @testset "Random" begin
             @test H(x -> rand(typeof(x)), 1) ≈ [0;;]
             @test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;]