Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enzyme #311

Closed
wants to merge 4 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Enzyme WIP
michel2323 committed Nov 26, 2024
commit b4d1765cb7e759705e2739b7a00d784fbdc68399
167 changes: 158 additions & 9 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,170 @@
struct EnzymeADGradient <: ADNLPModels.ADBackend end
struct EnzymeReverseADJacobian <: ADBackend end
struct EnzymeReverseADHessian <: ADBackend end

function EnzymeADGradient(
struct EnzymeReverseADGradient <: ADNLPModels.ADBackend end

function EnzymeReverseADGradient(
nvar::Integer,
f,
ncon::Integer = 0,
c::Function = (args...) -> [];
x0::AbstractVector = rand(nvar),
kwargs...,
)
return EnzymeADGradient()
return EnzymeReverseADGradient()
end

function ADNLPModels.gradient!(::EnzymeReverseADGradient, g, f, x)
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
return g
end

function EnzymeReverseADJacobian(
nvar::Integer,
f,
ncon::Integer = 0,
c::Function = (args...) -> [];
kwargs...,
)
return EnzymeReverseADJacobian()
end

jacobian(::EnzymeReverseADJacobian, f, x) = Enzyme.jacobian(Enzyme.Reverse, f, x)

function EnzymeReverseADHessian(
nvar::Integer,

f,
ncon::Integer = 0,
c::Function = (args...) -> [];
kwargs...,
)
@assert nvar > 0
nnzh = nvar * (nvar + 1) / 2
return EnzymeReverseADHessian()
end

@init begin
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
function ADNLPModels.gradient!(::EnzymeADGradient, g, f, x)
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Duplicated(x, g)) # gradient!(Reverse, g, f, x)
return g
end
function hessian(::EnzymeReverseADHessian, f, x)
seed = similar(x)
hess = zeros(eltype(x), length(x), length(x))
fill!(seed, zero(x))
for i in 1:length(x)
seed[i] = one(x)
Enzyme.hvp!(view(hess, i, :), f, x, seed)
seed[i] = zero(x)
end
return hess
end

struct EnzymeReverseADJprod <: InPlaceADBackend
x::Vector{Float64}
end

function EnzymeReverseADJprod(
nvar::Integer,
f,
ncon::Integer = 0,
c::Function = (args...) -> [];
kwargs...,
)
x = zeros(nvar)
return EnzymeReverseADJprod(x)
end

function Jprod!(b::EnzymeReverseADJprod, Jv, c!, x, v, ::Val)
Enzyme.autodiff(Enzyme.Forward, c!, Duplicated(b.x, Jv), Enzyme.Duplicated(x, v))
return Jv
end

struct EnzymeReverseADJtprod <: InPlaceADBackend
x::Vector{Float64}
end

function EnzymeReverseADJtprod(
nvar::Integer,
f,
ncon::Integer = 0,
c::Function = (args...) -> [];
kwargs...,
)
x = zeros(nvar)
return EnzymeReverseADJtprod(x)
end

function Jtvprod!(b::EnzymeReverseADJtprod, Jtv, c!, x, v, ::Val)
Enzyme.autodiff(Enzyme.Reverse, c!, Duplicated(b.x, Jtv), Enzyme.Duplicated(x, v))
return Jtv
end

struct EnzymeReverseADHprod <: InPlaceADBackend
grad::Vector{Float64}
end

function EnzymeReverseADHvprod(
nvar::Integer,
f,
ncon::Integer = 0,
c!::Function = (args...) -> [];
x0::AbstractVector{T} = rand(nvar),
kwargs...,
) where {T}
grad = zeros(nvar)
return EnzymeReverseADHprod(grad)
end

function Hvprod!(b::EnzymeReverseADHvprod, Hv, x, v, f, args...)
# What to do with args?
Enzyme.autodiff(
Forward,
gradient!,
Const(Reverse),
DuplicatedNoNeed(b.grad, Hv),
Const(f),
Duplicated(x, v),
)
return Hv
end

function Hvprod!(
b::EnzymeReverseADHvprod,
Hv,
x::AbstractVector{T},
v,
ℓ,
::Val{:lag},
y,
obj_weight::Real = one(T),
)
Enzyme.autodiff(
Forward,
gradient!,
Const(Reverse),
DuplicatedNoNeed(b.grad, Hv),
Const(ℓ),
Duplicated(x, v),
Const(y),
)

return Hv
end

function Hvprod!(
b::EnzymeReverseADHvprod{T, S, Tagf},
Hv,
x,
v,
f,
::Val{:obj},
obj_weight::Real = one(T),
)
Enzyme.autodiff(
Forward,
gradient!,
Const(Reverse),
DuplicatedNoNeed(b.grad, Hv),
Const(f),
Duplicated(x, v),
Const(y),
)
return Hv
end