Skip to content

Commit

Permalink
Update the extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Nov 27, 2024
1 parent d1dd9a0 commit b97ff8d
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions ext/ADNLPModelsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ module ADNLPModelsZygoteExt

using Zygote, ADNLPModels

function gradient(::ZygoteADGradient, f, x)
function gradient(::ADNLPModels.ZygoteADGradient, f, x)
g = Zygote.gradient(f, x)[1]
return g === nothing ? zero(x) : g
end
function gradient!(::ZygoteADGradient, g, f, x)
function gradient!(::ADNLPModels.ZygoteADGradient, g, f, x)
_g = Zygote.gradient(f, x)[1]
g .= _g === nothing ? 0 : _g
end

function Jprod!(::ZygoteADJprod, Jv, f, x, v, ::Val)
function Jprod!(::ADNLPModels.ZygoteADJprod, Jv, f, x, v, ::Val)
Jv .= vec(Zygote.jacobian(t -> f(x + t * v), 0)[1])
return Jv
end

function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val)
function Jtprod!(::ADNLPModels.ZygoteADJtprod, Jtv, f, x, v, ::Val)
g = Zygote.gradient(x -> dot(f(x), v), x)[1]
if g === nothing
Jtv .= zero(x)
Expand All @@ -26,14 +26,14 @@ function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val)
return Jtv
end

function jacobian(::ZygoteADJacobian, f, x)
function jacobian(::ADNLPModels.ZygoteADJacobian, f, x)
return Zygote.jacobian(f, x)[1]
end

function hessian(b::ZygoteADHessian, f, x)
function hessian(b::ADNLPModels.ZygoteADHessian, f, x)
return jacobian(
ForwardDiffADJacobian(length(x), f, x0 = x),
x -> gradient(ZygoteADGradient(), f, x),
ADNLPModels.ForwardDiffADJacobian(length(x), f, x0 = x),
x -> gradient(ADNLPModels.ZygoteADGradient(), f, x),
x,
)
end
Expand Down

0 comments on commit b97ff8d

Please sign in to comment.