From b97ff8dca39414c78bd09fe24ea39641f745cebf Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 27 Nov 2024 11:53:49 -0600 Subject: [PATCH] Update the extensions --- ext/ADNLPModelsZygoteExt.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/ADNLPModelsZygoteExt.jl b/ext/ADNLPModelsZygoteExt.jl index 8001ffdf..86f5101e 100644 --- a/ext/ADNLPModelsZygoteExt.jl +++ b/ext/ADNLPModelsZygoteExt.jl @@ -2,21 +2,21 @@ module ADNLPModelsZygoteExt using Zygote, ADNLPModels -function gradient(::ZygoteADGradient, f, x) +function gradient(::ADNLPModels.ZygoteADGradient, f, x) g = Zygote.gradient(f, x)[1] return g === nothing ? zero(x) : g end -function gradient!(::ZygoteADGradient, g, f, x) +function gradient!(::ADNLPModels.ZygoteADGradient, g, f, x) _g = Zygote.gradient(f, x)[1] g .= _g === nothing ? 0 : _g end -function Jprod!(::ZygoteADJprod, Jv, f, x, v, ::Val) +function Jprod!(::ADNLPModels.ZygoteADJprod, Jv, f, x, v, ::Val) Jv .= vec(Zygote.jacobian(t -> f(x + t * v), 0)[1]) return Jv end -function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val) +function Jtprod!(::ADNLPModels.ZygoteADJtprod, Jtv, f, x, v, ::Val) g = Zygote.gradient(x -> dot(f(x), v), x)[1] if g === nothing Jtv .= zero(x) @@ -26,14 +26,14 @@ function Jtprod!(::ZygoteADJtprod, Jtv, f, x, v, ::Val) return Jtv end -function jacobian(::ZygoteADJacobian, f, x) +function jacobian(::ADNLPModels.ZygoteADJacobian, f, x) return Zygote.jacobian(f, x)[1] end -function hessian(b::ZygoteADHessian, f, x) +function hessian(b::ADNLPModels.ZygoteADHessian, f, x) return jacobian( - ForwardDiffADJacobian(length(x), f, x0 = x), - x -> gradient(ZygoteADGradient(), f, x), + ADNLPModels.ForwardDiffADJacobian(length(x), f, x0 = x), + x -> gradient(ADNLPModels.ZygoteADGradient(), f, x), x, ) end