Skip to content

Commit

Permalink
added dimension checks for init
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreMartinon committed Jul 1, 2024
1 parent 9e112ec commit 87afaa9
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 45 deletions.
14 changes: 9 additions & 5 deletions ext/CTSolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@ function CommonSolve.solve(docp::DOCP;

# solve DOCP with NLP solver
print_level = display ? print_level : 0
if init == nothing
nlp = getNLP(docp)
if isnothing(init)
# use initial guess embedded in the DOCP
docp_solution = ipopt(getNLP(docp), print_level=print_level, mu_strategy=mu_strategy, sb="yes", linear_solver=linear_solver; kwargs...)
docp_solution = ipopt(nlp, print_level=print_level, mu_strategy=mu_strategy, sb="yes", linear_solver=linear_solver; kwargs...)
else
# use given initial guess
docp_solution = ipopt(getNLP(docp), x0=CTDirect.DOCP_initial_guess(docp, _OptimalControlInit(init)), print_level=print_level, mu_strategy=mu_strategy, sb="yes", linear_solver=linear_solver; kwargs...)
ocp = docp.ocp
x0 = CTDirect.DOCP_initial_guess(docp, _OptimalControlInit(init, state_dim=ocp.state_dimension, control_dim=ocp.control_dimension, variable_dim=ocp.variable_dimension))

docp_solution = ipopt(nlp, x0=x0, print_level=print_level, mu_strategy=mu_strategy, sb="yes", linear_solver=linear_solver; kwargs...)
end

# return DOCP solution
Expand All @@ -55,7 +59,7 @@ Solve an optimal control problem OCP by direct method
"""
function CommonSolve.solve(ocp::OptimalControlModel,
description...;
init=_OptimalControlInit(),
init=nothing,
grid_size::Integer=CTDirect.__grid_size_direct(),
time_grid=nothing,
display::Bool=CTDirect.__display(),
Expand All @@ -65,7 +69,7 @@ function CommonSolve.solve(ocp::OptimalControlModel,
kwargs...)

# build discretized OCP
docp = directTranscription(ocp, description, init=_OptimalControlInit(init), grid_size=grid_size, time_grid=time_grid)
docp = directTranscription(ocp, description, init=init, grid_size=grid_size, time_grid=time_grid)

# solve DOCP
docp_solution = solve(docp, display=display, print_level=print_level, mu_strategy=mu_strategy, linear_solver=linear_solver; kwargs...)
Expand Down
11 changes: 7 additions & 4 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ Discretize an optimal control problem into a nonlinear optimization problem (ie
"""
function directTranscription(ocp::OptimalControlModel,
description...;
init=_OptimalControlInit(),
init=nothing,
grid_size::Integer=__grid_size_direct(),
time_grid=nothing)

# build DOCP
docp = DOCP(ocp, grid_size, time_grid)

# set initial guess and bounds
x0 = DOCP_initial_guess(docp, _OptimalControlInit(init))
# set initial guess
x0 = DOCP_initial_guess(docp, _OptimalControlInit(init, state_dim=ocp.state_dimension, control_dim=ocp.control_dimension, variable_dim=ocp.variable_dimension))

# set bounds
docp.var_l, docp.var_u = variables_bounds(docp)
docp.con_l, docp.con_u = constraints_bounds(docp)

Expand Down Expand Up @@ -64,6 +66,7 @@ Extract the NLP problem from the DOCP
function setInitialGuess(docp::DOCP, init)

nlp = getNLP(docp)
nlp.meta.x0 .= DOCP_initial_guess(docp,_OptimalControlInit(init))
ocp = docp.ocp
nlp.meta.x0 .= DOCP_initial_guess(docp, _OptimalControlInit(init, state_dim=ocp.state_dimension, control_dim=ocp.control_dimension, variable_dim=ocp.variable_dimension))

end
96 changes: 60 additions & 36 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,107 +194,131 @@ function isaVectVect(data)
return (data isa Vector) && (data[1] isa ctVector)
end

function checkData(data)
function formatData(data)
if data isa Matrix
return matrix2vec(data,1)
else
return data
end
end

function setFunctionalInit(data, time)
function formatTimeGrid(time)
if isnothing(time)
return nothing
elseif time isa ctVector
return time
else
return vec(time)
end
end

# +++split in different methods vs data/time type ?!
function setFunctionalInit(data, time, dim)
if isnothing(data)
# fallback to method-dependent default initialization
return t-> nothing
elseif data isa Function
# functional initialization
return t -> data(t)
if !isnothing(dim) && length(data(0)) != dim
error("Init dimension mismatch: got ",length(data(0))," but expected ",dim )
else
return t -> data(t)
end
elseif (data isa ctVector)
if !isnothing(time) && (length(data) == length(time))
# interpolation vs time, dim 1 case
itp = ctinterpolate(time, data)
return t -> itp(t)
elseif !isnothing(dim) && length(data) != dim
error("Init dimension mismatch: got ",length(data)," but expected ",dim )
else
# constant initialization
return t -> data
end
elseif isaVectVect(data)
# interpolation vs time, general case
itp = ctinterpolate(time, data)
return t -> itp(t)
if !isnothing(dim) && length(itp(0)) != dim
error("Init dimension mismatch: got ",length(itp(0))," but expected ",dim )
else
return t -> itp(t)
end
else
error("Unrecognized initialization argument: ",typeof(data))
end

end

function checkTimeGrid(time)
if isnothing(time)
return nothing
elseif time isa ctVector
return time
else
return vec(time)
end
end

mutable struct _OptimalControlInit

state_dimension
control_dimension
variable_dimension
state_init::Function
control_init::Function
variable_init::Union{Nothing, ctVector}
costate_init::Function
multipliers_init::Union{Nothing, ctVector}

# base constructor with explicit arguments
function _OptimalControlInit(; state=nothing, control=nothing, variable=nothing, time=nothing)
function _OptimalControlInit(; state=nothing, control=nothing, variable=nothing, time=nothing, state_dim=nothing, control_dim=nothing, variable_dim=nothing)

init = new()
time = checkTimeGrid(time)
state = checkData(state)
control = checkData(control)
init.state_init = setFunctionalInit(state, time)
init.control_init = setFunctionalInit(control, time)
time = formatTimeGrid(time)
state = formatData(state)
control = formatData(control)
init.state_init = setFunctionalInit(state, time, state_dim)
init.control_init = setFunctionalInit(control, time, control_dim)
# check v dim
init.variable_init = variable
return init

end

# version with arguments as named tuple or dict
function _OptimalControlInit(init_data)
function _OptimalControlInit(init_data; state_dim=nothing, control_dim=nothing, variable_dim=nothing)

# trivial case: default init
x_init = nothing
u_init = nothing
v_init = nothing
t_init = nothing

for key in keys(init_data)
if key == :state
x_init = init_data[:state]
elseif key == :control
u_init = init_data[:control]
elseif key == :variable
v_init = init_data[:variable]
elseif key == :time
t_init = init_data[:time]
else
error("Unknown key in initialization data (allowed: state, control, variable, time): ", key)
x_dim = nothing
u_dim = nothing
v_dim = nothing

# parse arguments and call base constructor
if !isnothing(init_data)
for key in keys(init_data)
if key == :state
x_init = init_data[:state]
elseif key == :control
u_init = init_data[:control]
elseif key == :variable
v_init = init_data[:variable]
elseif key == :time
t_init = init_data[:time]
else
error("Unknown key in initialization data (allowed: state, control, variable, time, state_dim, control_dim, variable_dim): ", key)
end
end
end

return _OptimalControlInit(state=x_init, control=u_init, variable=v_init, time=t_init)
return _OptimalControlInit(state=x_init, control=u_init, variable=v_init, time=t_init, state_dim=state_dim, control_dim=control_dim, variable_dim=variable_dim)

end

# warm start from solution
function _OptimalControlInit(sol::OptimalControlSolution)
return _OptimalControlInit(state=sol.state, control=sol.control, variable=sol.variable)
function _OptimalControlInit(sol::OptimalControlSolution; unused_kwargs...)
return _OptimalControlInit(state=sol.state, control=sol.control, variable=sol.variable, state_dim=sol.state_dimension, control_dim=sol.control_dimension, variable_dim=sol.variable_dimension)
end

#=
# trivial version for unified syntax in caller functions
function _OptimalControlInit(init::_OptimalControlInit)
return init
end
=#

end

Expand Down

0 comments on commit 87afaa9

Please sign in to comment.