Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed May 16, 2024
1 parent 52cf2e5 commit 103cf56
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 25 deletions.
6 changes: 3 additions & 3 deletions src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,19 @@ function Base.similar(a::Array{D,2}) where {P,T,D<:Dual{P,T}}
p_out = similar(primal.(a))
return Dual.(p_out, empty(T))
end
function Base.similar(a::Array{A,1}, ::Type{D}) where {A,P,T,D<:Dual{P,T}}
function Base.similar(a::Array{A,1}, ::Type{D}) where {P,T,D<:Dual{P,T},A}
p_out = similar(a, P)
return Dual.(p_out, empty(T))
end
function Base.similar(a::Array{A,2}, ::Type{D}) where {A,P,T,D<:Dual{P,T}}
function Base.similar(a::Array{A,2}, ::Type{D}) where {P,T,D<:Dual{P,T},A}
p_out = similar(a, P)
return Dual.(p_out, empty(T))
end
function Base.similar(a::Array{D}, m::Int) where {P,T,D<:Dual{P,T}}
p_out = similar(primal.(a), m)
return Dual.(p_out, empty(T))
end
function Base.similar(a::Array{D}, dims::Dims{N}) where {N,D<:Dual}
function Base.similar(a::Array{D}, dims::Dims{N}) where {P,T,D<:Dual{P,T}, N}
p_out = similar(primal.(a), dims)
return Dual.(p_out, empty(T))
end
Expand Down
10 changes: 5 additions & 5 deletions src/overload_connectivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ for fn in ops_1_to_2
end

# Extra types required for exponent
for T in (Real, Integer, Rational, Irrational{:ℯ})
Base.:^(t::ConnectivityTracer, ::T) = t
function Base.:^(dx::D, y::Number) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
for S in (Real, Integer, Rational, Irrational{:ℯ})
Base.:^(t::ConnectivityTracer, ::S) = t
function Base.:^(dx::D, y::S) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
return Dual(primal(dx)^y, tracer(dx))
end

Base.:^(::T, t::ConnectivityTracer) = t
function Base.:^(x::Number, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
Base.:^(::S, t::ConnectivityTracer) = t
function Base.:^(x::S, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
return Dual(x^primal(dy), tracer(dy))
end
end
Expand Down
10 changes: 5 additions & 5 deletions src/overload_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ for fn in ops_1_to_2
end

# Extra types required for exponent
for T in (Real, Integer, Rational, Irrational{:ℯ})
Base.:^(t::GradientTracer, ::T) = t
Base.:^(::T, t::GradientTracer) = t
for S in (Real, Integer, Rational, Irrational{:ℯ})
Base.:^(t::GradientTracer, ::S) = t
Base.:^(::S, t::GradientTracer) = t

function Base.:^(dx::D, y::T) where {P,GT<:GradientTracer,D<:Dual{P,GT}}
function Base.:^(dx::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}}
return Dual(primal(dx)^y, tracer(dx))
end
function Base.:^(x::T, dy::D) where {P,GT<:GradientTracer,D<:Dual{P,GT}}
function Base.:^(x::S, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}}
return Dual(x^primal(dy), tracer(dy))
end
end
Expand Down
14 changes: 7 additions & 7 deletions src/overload_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function hessian_tracer_2_to_1(
is_firstder_arg2_zero::Bool,
is_seconder_arg2_zero::Bool,
is_crossder_zero::Bool,
) where {T<:HessianTracer}
) where {G,H,T<:HessianTracer{G,H}}
grad = empty(G)
hess = empty(H)
if !is_firstder_arg1_zero
Expand All @@ -63,7 +63,7 @@ end

function hessian_tracer_2_to_1_one_tracer(
t::T, is_firstder_zero::Bool, is_seconder_zero::Bool
) where {T<:GradientTracer}
) where {T<:HessianTracer}
# NOTE: this is identical to hessian_tracer_1_to_1 due to ignored second argument having empty set
# TODO: remove once gdalle agrees
if is_seconder_zero
Expand Down Expand Up @@ -177,12 +177,12 @@ end

# TODO: support Dual tracers for these.
# Extra types required for exponent
for T in (Real, Integer, Rational, Irrational{:ℯ})
function Base.:^(t::HT, ::T) where {HT<:HessianTracer}
return HT(gradient(t), hessian(t) (gradient(t) × gradient(t)))
for S in (Real, Integer, Rational, Irrational{:ℯ})
function Base.:^(t::T, ::S) where {T<:HessianTracer}
return T(gradient(t), hessian(t) (gradient(t) × gradient(t)))
end
function Base.:^(::T, t::HT) where {HT<:HessianTracer}
return HT(gradient(t), hessian(t) (gradient(t) × gradient(t)))
function Base.:^(::S, t::T) where {T<:HessianTracer}
return T(gradient(t), hessian(t) (gradient(t) × gradient(t)))
end
end

Expand Down
11 changes: 6 additions & 5 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,14 @@ function local_jacobian_pattern(f!, y, x, ::Type{G}=DEFAULT_VECTOR_TYPE) where {
end

function jacobian_pattern_to_mat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
) where {P,G<:GradientTracer,T<:Union{G,Dual{P,G}}}
xt::AbstractArray{TT}, yt::AbstractArray{<:Number}
) where {P,T<:GradientTracer,D<:Dual{P,T},TT<:Union{T,D}}
n, m = length(xt), length(yt)
I = Int[] # row indices
J = Int[] # column indices
V = Bool[] # values
for (i, y) in enumerate(yt)
if y isa T
if y isa TT
for j in gradient(y)
push!(I, i)
push!(J, j)
Expand Down Expand Up @@ -285,8 +285,9 @@ function local_hessian_pattern(
end

function hessian_pattern_to_mat(
xt::AbstractArray{T}, yt::T
) where {P,G,H<:AbstractSet,HT<:HessianTracer{G,H},T<:Union{HT,Dual{P,HT}}}
xt::AbstractArray{TT}, yt::TT
) where {P,T<:HessianTracer,D<:Dual{P,T},TT<:Union{T,D}}

# Allocate Hessian matrix
n = length(xt)
I = Int[] # row indices
Expand Down

0 comments on commit 103cf56

Please sign in to comment.