Skip to content

Commit

Permalink
fix tests again
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Nov 30, 2024
1 parent f23edf1 commit d0c6c15
Showing 1 changed file with 27 additions and 29 deletions.
56 changes: 27 additions & 29 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,20 +148,18 @@ end
@init begin
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin

import Enzyme: Const, Reverse, Forward, Duplicated, DuplicatedNoNeed

function ADNLPModels.gradient(::EnzymeReverseADGradient, f, x)
g = similar(x)
Enzyme.gradient!(Reverse, g, Const(f), x)
Enzyme.gradient!(Enzyme.Reverse, g, Enzyme.Const(f), x)
return g
end

function ADNLPModels.gradient!(::EnzymeReverseADGradient, g, f, x)
Enzyme.autodiff(Reverse, Const(f), Active, Duplicated(x, g))
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(f), Enzyme.Active, Enzyme.Duplicated(x, g))
return g
end

jacobian(::EnzymeReverseADJacobian, f, x) = Enzyme.jacobian(Reverse, f, x)
jacobian(::EnzymeReverseADJacobian, f, x) = Enzyme.jacobian(Enzyme.Reverse, f, x)

function hessian(::EnzymeReverseADHessian, f, x)
seed = similar(x)
Expand All @@ -170,32 +168,32 @@ end
tmp = similar(x)
for i in 1:length(x)
seed[i] = one(eltype(seed))
Enzyme.hvp!(tmp, Const(f), x, seed)
Enzyme.hvp!(tmp, Enzyme.Const(f), x, seed)
hess[:, i] .= tmp
seed[i] = zero(eltype(seed))
end
return hess
end

function Jprod!(b::EnzymeReverseADJprod, Jv, c!, x, v, ::Val)
Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(b.x, Jv), Duplicated(x, v))
Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), Enzyme.Duplicated(b.x, Jv), Enzyme.Duplicated(x, v))
return Jv
end

function Jtprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val)
Enzyme.autodiff(Reverse, Const(c!), Duplicated(b.x, Jtv), Duplicated(x, v))
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Const(c!), Enzyme.Duplicated(b.x, Jtv), Enzyme.Duplicated(x, v))
return Jtv
end

function Hvprod!(b::EnzymeReverseADHvprod, Hv, x, v, f, args...)
# What to do with args?
Enzyme.autodiff(
Forward,
Const(Enzyme.gradient!),
Const(Reverse),
DuplicatedNoNeed(b.grad, Hv),
Const(f),
Duplicated(x, v),
Enzyme.Forward,
Enzyme.Const(Enzyme.gradient!),
Enzyme.Const(Enzyme.Reverse),
Enzyme.DuplicatedNoNeed(b.grad, Hv),
Enzyme.Const(f),
Enzyme.Duplicated(x, v),
)
return Hv
end
Expand All @@ -211,13 +209,13 @@ end
obj_weight::Real = one(eltype(x)),
)
Enzyme.autodiff(
Forward,
Const(Enzyme.gradient!),
Const(Reverse),
DuplicatedNoNeed(b.grad, Hv),
Const(ℓ),
Duplicated(x, v),
Const(y),
Enzyme.Forward,
Enzyme.Const(Enzyme.gradient!),
Enzyme.Const(Enzyme.Reverse),
Enzyme.DuplicatedNoNeed(b.grad, Hv),
Enzyme.Const(ℓ),
Enzyme.Duplicated(x, v),
Enzyme.Const(y),
)

return Hv
Expand All @@ -233,13 +231,13 @@ end
obj_weight::Real = one(eltype(x)),
)
Enzyme.autodiff(
Forward,
Const(Enzyme.gradient!),
Const(Reverse),
DuplicatedNoNeed(b.grad, Hv),
Const(f),
Duplicated(x, v),
Const(y),
Enzyme.Forward,
Enzyme.Const(Enzyme.gradient!),
Enzyme.Const(Enzyme.Reverse),
Enzyme.DuplicatedNoNeed(b.grad, Hv),
Enzyme.Const(f),
Enzyme.Duplicated(x, v),
Enzyme.Const(y),
)
return Hv
end
Expand All @@ -264,7 +262,7 @@ end

# b.compressed_jacobian is just a vector Jv here
# We don't use the vector mode
Enzyme.autodiff(Enzyme.Forward, Const(c!), Duplicated(b.buffer, b.compressed_jacobian), Duplicated(x, b.v))
Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(c!), Enzyme.Duplicated(b.buffer, b.compressed_jacobian), Enzyme.Duplicated(x, b.v))

# Update the columns of the Jacobian that have the color `icol`
decompress_single_color!(A, b.compressed_jacobian, icol, b.result_coloring)
Expand Down

0 comments on commit d0c6c15

Please sign in to comment.