Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nls solvers #54

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/PreProcess/PreProcessLSKARC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
    preprocess(PData::PDataLSKARC, Jx, Fx, gNorm2, calls, max_calls, α)

Run the shifted Krylov solver stored in `PData.solver` on `(Jx, Fx)` for every value in
`PData.shifts`, and cache the outcome in `PData`.

# Arguments
- `PData`: storage holding the shifts, the tolerance rules and the Krylov workspace.
- `Jx`: operator (or matrix) passed to `Krylov.solve!`; only `size(Jx)` is used directly.
- `Fx`: right-hand side passed to `Krylov.solve!`.
- `gNorm2`: value forwarded to `PData.cgatol` and `PData.cgrtol` to build the tolerances.
- `calls`: counter(s) already consumed; `sum(calls)` is subtracted from `max_calls`.
- `max_calls`: budget used to cap the number of Krylov iterations.
- `α`: threshold of the early-stopping test `‖x_i‖ / shifts[i] > α`.

# Returns
The mutated `PData`, with
- `positives` set to the per-shift convergence flags of the solver,
- `xShift[i]` set to the *negated* solution of the `i`-th shifted system,
- `norm_dirs[i]` set to the norm of that solution,
- `indmin` reset to `0`,
- `OK` true when at least one shifted system converged.
"""
function preprocess(PData::PDataLSKARC, Jx, Fx, gNorm2, calls, max_calls, α)
  ζ, ξ, maxtol, mintol = PData.ζ, PData.ξ, PData.maxtol, PData.mintol
  shifts = PData.shifts
  nshifts = length(shifts)

  m, n = size(Jx)
  # Tolerance used in Assumption 2.6b in the paper ( ξ > 0, 0 < ζ ≤ 1 )
  atol = PData.cgatol(ζ, ξ, maxtol, mintol, gNorm2)
  rtol = PData.cgrtol(ζ, ξ, maxtol, mintol, gNorm2)

  # Early-stopping rule, evaluated by Krylov at every iteration: once more than one
  # shifted system has converged, stop as soon as one of the converged solutions
  # satisfies ‖x_i‖ / shifts[i] > α. Written without temporaries since it runs on
  # every iteration.
  cb = (slv) -> begin
    count(slv.converged) > 1 || return false
    for i in eachindex(shifts)
      # for lsqr: sqrt(slv.xNorm²[i])
      if slv.converged[i] && norm(slv.x[i]) / shifts[i] - α > 0
        return true
      end
    end
    return false
  end

  solver = PData.solver
  Krylov.solve!(
    solver,
    Jx,
    Fx,
    shifts,
    itmax = min(max_calls - sum(calls), 2 * (m + n)),
    atol = atol,
    rtol = rtol,
    verbose = 0,
    callback = cb,
  )

  PData.indmin = 0
  PData.positives .= solver.converged
  for i = 1:nshifts
    # The stored direction is the opposite of the Krylov solution.
    @. PData.xShift[i] = -solver.x[i]
    PData.norm_dirs[i] = norm(solver.x[i])
  end
  PData.nshifts = nshifts
  PData.OK = any(solver.converged) # at least one system was solved

  return PData
end
23 changes: 23 additions & 0 deletions src/SolveModel/SolveModelLSKARC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
    solve_modelLSKARC(X::PDataLSKARC, Jx, Fx, gNorm2, calls, max_calls, α)

Pick, among the shifted solutions cached in `X`, the one whose shift `λ` comes closest
to satisfying `α λ = ‖d‖`, i.e. the index minimizing `|α * X.shifts[i] - X.norm_dirs[i]|`.

The search starts at the first `true` entry of `X.positives` (or at index 1 when there
is none) and runs to the last shift. `X.indmin`, `X.d` and `X.λ` are updated in place.
`Jx`, `Fx`, `gNorm2`, `calls` and `max_calls` are not used by this method.

Returns the tuple `(X.d, X.λ)`.
"""
function solve_modelLSKARC(X::PDataLSKARC, Jx, Fx, gNorm2, calls, max_calls, α::T) where {T}
  first_ok = something(findfirst(X.positives), 1)
  candidates = first_ok:length(X.positives)

  # distance to the target relation αλ = ||d|| for the i-th shift
  gap(i) = abs(α * X.shifts[i] - X.norm_dirs[i])

  # `argmin` over a generator requires Julia ≥ 1.7; older versions need a vector.
  offset = if VERSION < v"1.7.0"
    argmin(map(gap, candidates))
  else
    argmin(gap(i) for i in candidates)
  end

  # `offset` is relative to `candidates`; translate it back to an index into the shifts
  best = first_ok + offset - 1
  X.indmin = best
  X.d .= X.xShift[best]
  X.λ = X.shifts[best]

  return X.d, X.λ
end
45 changes: 30 additions & 15 deletions src/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,21 @@ function preprocess(stp::NLPStopping, PData::TPData, workspace::TRARCWorkspace,
return PData
end

# Gauss-Newton specialization of `preprocess` for least-squares parameter data:
# builds the Jacobian operator of the residual at `workspace.xt` and delegates to the
# `preprocess(PData, Jx, Fx, ...)` method, using the `neval_jprod_residual` counter of
# the problem against its limit in `stp.meta.max_cntrs`. `∇f` is not used here.
function preprocess(
  stp::NLPStopping,
  PData::PDataIterLS,
  workspace::TRARCWorkspace{T,S,Hess},
  ∇f,
  norm_∇f,
  α,
) where {T,S,Hess<:HessGaussNewtonOp}
  nls = stp.pb
  budget = stp.meta.max_cntrs[:neval_jprod_residual]
  residual = workspace.Fx
  gn = workspace.Hstruct
  Jx = jac_op_residual!(nls, workspace.xt, gn.Jv, gn.Jtv)
  return preprocess(PData, Jx, residual, norm_∇f, neval_jprod_residual(nls), budget, α)
end

function compute_direction(
stp::NLPStopping,
PData::TPData,
Expand Down Expand Up @@ -132,21 +147,21 @@ function hessian!(workspace::TRARCWorkspace, nlp, x)
end

"""
    TRARC(nlp_stop::NLPStopping; TR, hess_type, pdata_type, kwargs...)

Build the parameter data and the workspace for the problem `nlp_stop.pb`, then forward
to `TRARC(nlp_stop, PData, workspace, TR; kwargs...)`.

# Keywords
- `TR::TrustRegion = TrustRegion(T(10.0))`: trust-region structure passed through.
- `hess_type = HessOp`: type handed to `TRARCWorkspace(nlp, hess_type)`.
- `pdata_type = PDataKARC`: type of the parameter data to construct. `PDataNLSST` and
  `PDataLSKARC` are built with both `nlp.meta.nvar` and `nlp.nls_meta.nequ`; any other
  type is built with `nlp.meta.nvar` only.
- `kwargs...`: forwarded both to the parameter-data constructor and to the inner `TRARC`.
"""
function TRARC(
  nlp_stop::NLPStopping{Pb,M,SRC,NLPAtX{Score,T,S},MStp,LoS};
  TR::TrustRegion = TrustRegion(T(10.0)),
  hess_type::Type{Hess} = HessOp,
  pdata_type::Type{ParamData} = PDataKARC,
  kwargs...,
) where {Pb,M,SRC,MStp,LoS,Score,S,T,Hess,ParamData}
  nlp = nlp_stop.pb

  # Least-squares parameter data also need the number of residuals.
  if ParamData in (PDataNLSST, PDataLSKARC)
    PData = ParamData(S, T, nlp.meta.nvar, nlp.nls_meta.nequ; kwargs...)
  else
    PData = ParamData(S, T, nlp.meta.nvar; kwargs...)
  end
  workspace = TRARCWorkspace(nlp, Hess)
  return TRARC(nlp_stop, PData, workspace, TR; kwargs...)
end

"""
Expand Down
26 changes: 14 additions & 12 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ const solvers_const = Dict(
)

# Solver variants for nonlinear least-squares problems, keyed by name.
# Each value is a 4-tuple; presumably (Hessian type, parameter-data type,
# model-solving function, keyword arguments) — confirm against the code consuming it.
const solvers_nls_const = Dict(
  :ARCqKOpGN => (
    HessGaussNewtonOp,
    PDataKARC,
    solve_modelKARC,
    [:shifts => 10.0 .^ (collect(-10.0:0.5:20.0))],
  ),
  :ST_TROpGN => (HessGaussNewtonOp, PDataST, solve_modelST_TR, ()),
  :ST_TROpGNLSCgls =>
    (HessGaussNewtonOp, PDataNLSST, solve_modelNLSST_TR, [:solver_method => :cgls]),
  :ST_TROpGNLSLsqr =>
    (HessGaussNewtonOp, PDataNLSST, solve_modelNLSST_TR, [:solver_method => :lsqr]),
  :ST_TROpLS => (HessOp, PDataNLSST, solve_modelNLSST_TR, ()),
  :LSARCqKOpCgls => (
    HessGaussNewtonOp,
    PDataLSKARC,
    solve_modelLSKARC,
    [:shifts => 10.0 .^ (collect(-10.0:0.5:20.0)), :solver_method => :cgls],
  ),
  :LSARCqKOpLsqr => (
    HessGaussNewtonOp,
    PDataLSKARC,
    solve_modelLSKARC,
    [:shifts => 10.0 .^ (collect(-10.0:0.5:20.0)), :solver_method => :lsqr],
  ),
)
77 changes: 77 additions & 0 deletions src/utils/pdata_struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,83 @@ function PDataKARC(
)
end

"""
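    PDataLSKARC(::Type{S}, ::Type{T}, n, m; kwargs...)
Return a structure used for the preprocessing of the least-squares ARCqK (LSARCqK) methods, where `n` and `m` are the dimensions passed to the shifted Krylov solver workspace.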
"""
mutable struct PDataLSKARC{T} <: PDataIterLS{T}
  d::Vector{T} # currently selected direction (copied from one entry of `xShift`)
  λ::T # "active" shift value; initialised to 0
  ζ::T # Inexact Newton order parameter: stop when ||∇q|| < ξ * ||g||^(1+ζ)
  ξ::T # Inexact Newton order parameter: stop when ||∇q|| < ξ * ||g||^(1+ζ)
  maxtol::T # Largest tolerance for Inexact Newton
  mintol::T # Smallest tolerance for Inexact Newton
  cgatol # callable (ζ, ξ, maxtol, mintol, gNorm2) -> absolute tolerance of the Krylov solve
  cgrtol # callable (ζ, ξ, maxtol, mintol, gNorm2) -> relative tolerance of the Krylov solve

  indmin::Int # index of the selected shift; reset to 0 by the preprocessing step

  positives::Vector{Bool} # per-shift flags, filled from the solver's `converged` field
  xShift::Vector{Vector{T}} # negated solution of each shifted system
  shifts::Vector{T} # values of the shifts
  nshifts::Int # number of shifts
  norm_dirs::Vector{T} # norms of the entries of `xShift`
  OK::Bool # preprocess success (at least one shifted system converged)

  solver::Union{CglsLanczosShiftSolver,LsqrShiftSolver} # Krylov workspace
end

# Keyword constructor: allocates (uninitialised) storage for `nshifts = length(shifts)`
# shifted systems and the matching Krylov workspace. `solver_method == :cgls` selects
# `CglsLanczosShiftSolver`; any other value selects `LsqrShiftSolver`. Extra keyword
# arguments are accepted and ignored.
function PDataLSKARC(
  ::Type{S},
  ::Type{T},
  n,
  m;
  ζ = T(0.5),
  ξ = T(0.01),
  maxtol = T(0.01),
  mintol = sqrt(eps(T)),
  cgatol = (ζ, ξ, maxtol, mintol, gNorm2) -> max(mintol, min(maxtol, ξ * gNorm2^(1 + ζ))),
  cgrtol = (ζ, ξ, maxtol, mintol, gNorm2) -> max(mintol, min(maxtol, ξ * gNorm2^ζ)),
  shifts = 10.0 .^ collect(-20.0:1.0:20.0),
  solver_method = :cgls,
  kwargs...,
) where {S,T}
  nshifts = length(shifts)

  # work arrays; their content is only meaningful after a call to `preprocess`
  direction = S(undef, n)
  converged_flags = Vector{Bool}(undef, nshifts)
  shifted_sols = S[S(undef, n) for _ = 1:nshifts]
  sol_norms = S(undef, nshifts)

  krylov_ws =
    solver_method == :cgls ? CglsLanczosShiftSolver(m, n, nshifts, S) :
    LsqrShiftSolver(m, n, nshifts, S)

  return PDataLSKARC(
    direction,
    zero(T), # λ
    ζ,
    ξ,
    maxtol,
    mintol,
    cgatol,
    cgrtol,
    1, # indmin
    converged_flags,
    shifted_sols,
    T.(shifts), # shifts converted to the working precision
    nshifts,
    sol_norms,
    true, # OK
    krylov_ws,
  )
end

"""
PDataTRK(::Type{S}, ::Type{T}, n)
Return a structure used for the preprocessing of TRK methods.
Expand Down