From d0c6c15b595a62fbcbc57e321d4178bd45fe2af0 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Fri, 29 Nov 2024 23:45:05 -0600 Subject: [PATCH] fix tests again --- src/enzyme.jl | 56 +++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 82591806..b03df03c 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -148,20 +148,18 @@ end @init begin @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - import Enzyme: Const, Reverse, Forward, Duplicated, DuplicatedNoNeed - function ADNLPModels.gradient(::EnzymeReverseADGradient, f, x) g = similar(x) - Enzyme.gradient!(Reverse, g, Const(f), x) + Enzyme.gradient!(Enzyme.Reverse, g, Enzyme.Const(f), x) return g end function ADNLPModels.gradient!(::EnzymeReverseADGradient, g, f, x) - Enzyme.autodiff(Reverse, Const(f), Active, Duplicated(x, g)) + Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, g)) return g end - jacobian(::EnzymeReverseADJacobian, f, x) = Enzyme.jacobian(Reverse, f, x) + jacobian(::EnzymeReverseADJacobian, f, x) = Enzyme.jacobian(Enzyme.Reverse, f, x) function hessian(::EnzymeReverseADHessian, f, x) seed = similar(x) @@ -170,7 +168,7 @@ end tmp = similar(x) for i in 1:length(x) seed[i] = one(eltype(seed)) - Enzyme.hvp!(tmp, Const(f), x, seed) + Enzyme.hvp!(tmp, Enzyme.Const(f), x, seed) hess[:, i] .= tmp seed[i] = zero(eltype(seed)) end @@ -178,24 +176,24 @@ end end function Jprod!(b::EnzymeReverseADJprod, Jv, c!, x, v, ::Val) - Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(b.x, Jv), Duplicated(x, v)) + Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), Enzyme.Duplicated(b.x, Jv), Enzyme.Duplicated(x, v)) return Jv end function Jtprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val) - Enzyme.autodiff(Reverse, Const(c!), Duplicated(b.x, Jtv), Duplicated(x, v)) + Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(c!), Enzyme.Duplicated(b.x, Jtv), Enzyme.Duplicated(x, v)) return Jtv end function Hvprod!(b::EnzymeReverseADHvprod, Hv, x, v, f, args...) # What to do with args? Enzyme.autodiff( - Forward, - Const(Enzyme.gradient!), - Const(Reverse), - DuplicatedNoNeed(b.grad, Hv), - Const(f), - Duplicated(x, v), + Enzyme.Forward, + Enzyme.Const(Enzyme.gradient!), + Enzyme.Const(Enzyme.Reverse), + Enzyme.DuplicatedNoNeed(b.grad, Hv), + Enzyme.Const(f), + Enzyme.Duplicated(x, v), ) return Hv end @@ -211,13 +209,13 @@ end obj_weight::Real = one(eltype(x)), ) Enzyme.autodiff( - Forward, - Const(Enzyme.gradient!), - Const(Reverse), - DuplicatedNoNeed(b.grad, Hv), - Const(ℓ), - Duplicated(x, v), - Const(y), + Enzyme.Forward, + Enzyme.Const(Enzyme.gradient!), + Enzyme.Const(Enzyme.Reverse), + Enzyme.DuplicatedNoNeed(b.grad, Hv), + Enzyme.Const(ℓ), + Enzyme.Duplicated(x, v), + Enzyme.Const(y), ) return Hv @@ -233,13 +231,13 @@ end obj_weight::Real = one(eltype(x)), ) Enzyme.autodiff( - Forward, - Const(Enzyme.gradient!), - Const(Reverse), - DuplicatedNoNeed(b.grad, Hv), - Const(f), - Duplicated(x, v), - Const(y), + Enzyme.Forward, + Enzyme.Const(Enzyme.gradient!), + Enzyme.Const(Enzyme.Reverse), + Enzyme.DuplicatedNoNeed(b.grad, Hv), + Enzyme.Const(f), + Enzyme.Duplicated(x, v), + Enzyme.Const(y), ) return Hv end @@ -264,7 +262,7 @@ end # b.compressed_jacobian is just a vector Jv here # We don't use the vector mode - Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(b.buffer, b.compressed_jacobian), Duplicated(x, b.v)) + Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), Enzyme.Duplicated(b.buffer, b.compressed_jacobian), Enzyme.Duplicated(x, b.v)) # Update the columns of the Jacobian that have the color `icol` decompress_single_color!(A, b.compressed_jacobian, icol, b.result_coloring)