From 103cf56469b964ee31c4f93abce52ec4b72ba8ef Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 16 May 2024 17:38:30 +0200 Subject: [PATCH] Fixes --- src/conversion.jl | 6 +++--- src/overload_connectivity.jl | 10 +++++----- src/overload_gradient.jl | 10 +++++----- src/overload_hessian.jl | 14 +++++++------- src/pattern.jl | 11 ++++++----- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/conversion.jl b/src/conversion.jl index 21616758..6d712e28 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -90,11 +90,11 @@ function Base.similar(a::Array{D,2}) where {P,T,D<:Dual{P,T}} p_out = similar(primal.(a)) return Dual.(p_out, empty(T)) end -function Base.similar(a::Array{A,1}, ::Type{D}) where {A,P,T,D<:Dual{P,T}} +function Base.similar(a::Array{A,1}, ::Type{D}) where {P,T,D<:Dual{P,T},A} p_out = similar(a, P) return Dual.(p_out, empty(T)) end -function Base.similar(a::Array{A,2}, ::Type{D}) where {A,P,T,D<:Dual{P,T}} +function Base.similar(a::Array{A,2}, ::Type{D}) where {P,T,D<:Dual{P,T},A} p_out = similar(a, P) return Dual.(p_out, empty(T)) end @@ -102,7 +102,7 @@ function Base.similar(a::Array{D}, m::Int) where {P,T,D<:Dual{P,T}} p_out = similar(primal.(a), m) return Dual.(p_out, empty(T)) end -function Base.similar(a::Array{D}, dims::Dims{N}) where {N,D<:Dual} +function Base.similar(a::Array{D}, dims::Dims{N}) where {P,T,D<:Dual{P,T}, N} p_out = similar(primal.(a), dims) return Dual.(p_out, empty(T)) end diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl index 376169b3..78fdbdc3 100644 --- a/src/overload_connectivity.jl +++ b/src/overload_connectivity.jl @@ -32,14 +32,14 @@ for fn in ops_1_to_2 end # Extra types required for exponent -for T in (Real, Integer, Rational, Irrational{:ℯ}) - Base.:^(t::ConnectivityTracer, ::T) = t - function Base.:^(dx::D, y::Number) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} +for S in (Real, Integer, Rational, Irrational{:ℯ}) + Base.:^(t::ConnectivityTracer, ::S) = t + function Base.:^(dx::D, y::S) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} return Dual(primal(dx)^y, tracer(dx)) end - Base.:^(::T, t::ConnectivityTracer) = t - function Base.:^(x::Number, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} + Base.:^(::S, t::ConnectivityTracer) = t + function Base.:^(x::S, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} return Dual(x^primal(dy), tracer(dy)) end end diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index 3a01915c..183eee05 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -124,14 +124,14 @@ for fn in ops_1_to_2 end # Extra types required for exponent -for T in (Real, Integer, Rational, Irrational{:ℯ}) - Base.:^(t::GradientTracer, ::T) = t - Base.:^(::T, t::GradientTracer) = t +for S in (Real, Integer, Rational, Irrational{:ℯ}) + Base.:^(t::GradientTracer, ::S) = t + Base.:^(::S, t::GradientTracer) = t - function Base.:^(dx::D, y::T) where {P,GT<:GradientTracer,D<:Dual{P,GT}} + function Base.:^(dx::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}} return Dual(primal(dx)^y, tracer(dx)) end - function Base.:^(x::T, dy::D) where {P,GT<:GradientTracer,D<:Dual{P,GT}} + function Base.:^(x::S, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} return Dual(x^primal(dy), tracer(dy)) end end diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index e70206a5..bfbb29d6 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -38,7 +38,7 @@ function hessian_tracer_2_to_1( is_firstder_arg2_zero::Bool, is_seconder_arg2_zero::Bool, is_crossder_zero::Bool, -) where {T<:HessianTracer} +) where {G,H,T<:HessianTracer{G,H}} grad = empty(G) hess = empty(H) if !is_firstder_arg1_zero @@ -63,7 +63,7 @@ end function hessian_tracer_2_to_1_one_tracer( t::T, is_firstder_zero::Bool, is_seconder_zero::Bool -) where {T<:GradientTracer} +) where {T<:HessianTracer} # NOTE: this is identical to hessian_tracer_1_to_1 due to ignored second argument having empty set # TODO: remove once gdalle agrees if is_seconder_zero @@ -177,12 +177,12 @@ end # TODO: support Dual tracers for these. # Extra types required for exponent -for T in (Real, Integer, Rational, Irrational{:ℯ}) - function Base.:^(t::HT, ::T) where {HT<:HessianTracer} - return HT(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) +for S in (Real, Integer, Rational, Irrational{:ℯ}) + function Base.:^(t::T, ::S) where {T<:HessianTracer} + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end - function Base.:^(::T, t::HT) where {HT<:HessianTracer} - return HT(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) + function Base.:^(::S, t::T) where {T<:HessianTracer} + return T(gradient(t), hessian(t) ∪ (gradient(t) × gradient(t))) end end diff --git a/src/pattern.jl b/src/pattern.jl index 4f3ad688..6c413696 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -182,14 +182,14 @@ function local_jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where { end function jacobian_pattern_to_mat( - xt::AbstractArray{T}, yt::AbstractArray{<:Number} -) where {P,G<:GradientTracer,T<:Union{G,Dual{P,G}}} + xt::AbstractArray{TT}, yt::AbstractArray{<:Number} +) where {P,T<:GradientTracer,D<:Dual{P,T},TT<:Union{T,D}} n, m = length(xt), length(yt) I = Int[] # row indices J = Int[] # column indices V = Bool[] # values for (i, y) in enumerate(yt) - if y isa T + if y isa TT for j in gradient(y) push!(I, i) push!(J, j) @@ -285,8 +285,9 @@ function local_hessian_pattern( end function hessian_pattern_to_mat( - xt::AbstractArray{T}, yt::T -) where {P,G,H<:AbstractSet,HT<:HessianTracer{G,H},T<:Union{HT,Dual{P,HT}}} + xt::AbstractArray{TT}, yt::TT +) where {P,T<:HessianTracer,D<:Dual{P,T},TT<:Union{T,D}} + # Allocate Hessian matrix n = length(xt) I = Int[] # row indices