Skip to content

Commit

Permalink
Fix the tests with Zygote
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Nov 26, 2024
1 parent 9c10a5c commit ae65c66
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 26 deletions.
27 changes: 1 addition & 26 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,7 @@ ReverseDiffAD(nvar, f) = ADNLPModels.ADModelBackend(
hessian_backend = ADNLPModels.ReverseDiffADHessian,
)

function test_getter_setter(nlp)
@test get_adbackend(nlp) == nlp.adbackend
if typeof(nlp) <: ADNLPModel
set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, nlp.f))
elseif typeof(nlp) <: ADNLSModel
function F(x; nequ = nlp.nls_meta.nequ)
Fx = similar(x, nequ)
nlp.F!(Fx, x)
return Fx
end
set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, x -> sum(F(x) .^ 2)))
end
@test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ReverseDiffADGradient
@test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
@test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
set_adbackend!(
nlp,
gradient_backend = ADNLPModels.ForwardDiffADGradient,
jtprod_backend = ADNLPModels.GenericForwardDiffADJtprod(),
)
@test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ForwardDiffADGradient
@test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
@test typeof(get_adbackend(nlp).jtprod_backend) <: ADNLPModels.GenericForwardDiffADJtprod
@test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
end

include("utils.jl")
include("nlp/basic.jl")
include("nls/basic.jl")
include("nlp/nlpmodelstest.jl")
Expand Down
25 changes: 25 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
function test_getter_setter(nlp)
@test get_adbackend(nlp) == nlp.adbackend
if typeof(nlp) <: ADNLPModel
set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, nlp.f))
elseif typeof(nlp) <: ADNLSModel
function F(x; nequ = nlp.nls_meta.nequ)
Fx = similar(x, nequ)
nlp.F!(Fx, x)
return Fx
end
set_adbackend!(nlp, ReverseDiffAD(nlp.meta.nvar, x -> sum(F(x) .^ 2)))
end
@test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ReverseDiffADGradient
@test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
@test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
set_adbackend!(
nlp,
gradient_backend = ADNLPModels.ForwardDiffADGradient,
jtprod_backend = ADNLPModels.GenericForwardDiffADJtprod(),
)
@test typeof(get_adbackend(nlp).gradient_backend) <: ADNLPModels.ForwardDiffADGradient
@test typeof(get_adbackend(nlp).hprod_backend) <: ADNLPModels.ReverseDiffADHvprod
@test typeof(get_adbackend(nlp).jtprod_backend) <: ADNLPModels.GenericForwardDiffADJtprod
@test typeof(get_adbackend(nlp).hessian_backend) <: ADNLPModels.ReverseDiffADHessian
end
1 change: 1 addition & 0 deletions test/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ ADNLPModels.predefined_backend = Dict(
# Automatically loads the code for Zygote with Requires
import Zygote

include("utils.jl")
include("nlp/basic.jl")
include("nls/basic.jl")
include("nlp/nlpmodelstest.jl")
Expand Down

0 comments on commit ae65c66

Please sign in to comment.