Skip to content

Commit

Permalink
add solution
Browse files Browse the repository at this point in the history
  • Loading branch information
ocots committed Dec 13, 2024
1 parent 68da876 commit fe0e6d2
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 99 deletions.
2 changes: 2 additions & 0 deletions src/CTModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ const Time = ctNumber
const ctVector = AbstractVector{<:ctNumber}
const Variable = ctVector
const ConstraintsDictType = Dict{Symbol, Tuple{Symbol, Union{Function, OrdinalRange{<:Int}}, ctVector, ctVector}}
const Times = AbstractVector{<:Time}
const TimesDisc = Union{Times, StepRangeLen}

#
include("types.jl")
Expand Down
20 changes: 6 additions & 14 deletions src/control.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,11 @@ end
# ------------------------------------------------------------------------------ #

# from ControlModel
name(ocp::ControlModel)::String = ocp.name
components(ocp::ControlModel)::Vector{String} = ocp.components
(dimension(ocp::ControlModel)::Dimension) = length(components(ocp))
name(model::ControlModel)::String = model.name
components(model::ControlModel)::Vector{String} = model.components
(dimension(model::ControlModel)::Dimension) = length(components(model))

# from Model
(control(ocp::Model{T, S, C, V, D, O, B})::C) where {
T<:AbstractTimesModel,
S<:AbstractStateModel,
C<:AbstractControlModel,
V<:AbstractVariableModel,
D<:Function,
O<:AbstractObjectiveModel,
B<:ConstraintsDictType} = ocp.control
control_name(ocp::Model)::String = name(control(ocp))
control_components(ocp::Model)::Vector{String} = components(control(ocp))
control_dimension(ocp::Model)::Dimension = dimension(control(ocp))
control_name(ocp::Model)::String = name(ocp.control)
control_components(ocp::Model)::Vector{String} = components(ocp.control)
control_dimension(ocp::Model)::Dimension = dimension(ocp.control)
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function build(pre_ocp::PreModel)::Model
function build_model(pre_ocp::PreModel)::Model

# checkings: times must be set
__is_times_set(pre_ocp) || throw(CTBase.UnauthorizedCall("the times must be set before building the model."))
Expand Down
30 changes: 15 additions & 15 deletions src/objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,25 @@ end
# ------------------------------------------------------------------------------ #

# From MayerObjectiveModel
criterion(ocp::MayerObjectiveModel)::Symbol = ocp.criterion
(mayer(ocp::MayerObjectiveModel{M})::M) where {M <: Function} = ocp.mayer
lagrange(ocp::MayerObjectiveModel) = throw(CTBase.UnauthorizedCall("a Mayer objective ocp does not have a Lagrange function."))
has_mayer_cost(ocp::MayerObjectiveModel)::Bool = true
has_lagrange_cost(ocp::MayerObjectiveModel)::Bool = false
criterion(model::MayerObjectiveModel)::Symbol = model.criterion
(mayer(model::MayerObjectiveModel{M})::M) where {M <: Function} = model.mayer
lagrange(model::MayerObjectiveModel) = throw(CTBase.UnauthorizedCall("a Mayer objective model does not have a Lagrange function."))
has_mayer_cost(model::MayerObjectiveModel)::Bool = true
has_lagrange_cost(model::MayerObjectiveModel)::Bool = false

# From LagrangeObjectiveModel
criterion(ocp::LagrangeObjectiveModel)::Symbol = ocp.criterion
mayer(ocp::LagrangeObjectiveModel) = throw(CTBase.UnauthorizedCall("a Lagrange objective ocp does not have a Mayer function."))
(lagrange(ocp::LagrangeObjectiveModel{L})::L) where {L <: Function} = ocp.lagrange
has_mayer_cost(ocp::LagrangeObjectiveModel)::Bool = false
has_lagrange_cost(ocp::LagrangeObjectiveModel)::Bool = true
criterion(model::LagrangeObjectiveModel)::Symbol = model.criterion
mayer(model::LagrangeObjectiveModel) = throw(CTBase.UnauthorizedCall("a Lagrange objective model does not have a Mayer function."))
(lagrange(model::LagrangeObjectiveModel{L})::L) where {L <: Function} = model.lagrange
has_mayer_cost(model::LagrangeObjectiveModel)::Bool = false
has_lagrange_cost(model::LagrangeObjectiveModel)::Bool = true

# From BolzaObjectiveModel
criterion(ocp::BolzaObjectiveModel)::Symbol = ocp.criterion
(mayer(ocp::BolzaObjectiveModel{M, L})::M) where {M <: Function, L <: Function} = ocp.mayer
(lagrange(ocp::BolzaObjectiveModel{M, L})::L) where {M <: Function, L <: Function} = ocp.lagrange
has_mayer_cost(ocp::BolzaObjectiveModel)::Bool = true
has_lagrange_cost(ocp::BolzaObjectiveModel)::Bool = true
criterion(model::BolzaObjectiveModel)::Symbol = model.criterion
(mayer(model::BolzaObjectiveModel{M, L})::M) where {M <: Function, L <: Function} = model.mayer
(lagrange(model::BolzaObjectiveModel{M, L})::L) where {M <: Function, L <: Function} = model.lagrange
has_mayer_cost(model::BolzaObjectiveModel)::Bool = true
has_lagrange_cost(model::BolzaObjectiveModel)::Bool = true

# From Model
(objective(ocp::Model{T, S, C, V, D, O, B})::O) where {
Expand Down
108 changes: 108 additions & 0 deletions src/solution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
function build_solution(
ocp::Model,
T::Vector{Float64},
X::Matrix{Float64},
U::Matrix{Float64},
v::Vector{Float64},
P::Vector{Float64};
cost::Float64,
iterations::Int,
constraints_violation::Float64,
message::String,
stopping::Symbol,
success::Bool,
state_constraints_lb_dual::Matrix{Float64},
state_constraints_ub_dual::Matrix{Float64},
control_constraints_lb_dual::Matrix{Float64},
control_constraints_ub_dual::Matrix{Float64},
variable_constraints_lb_dual::Vector{Float64},
variable_constraints_ub_dual::Vector{Float64},
boundary_constraints::Vector{Float64},
boundary_constraints_dual::Vector{Float64},
path_constraints::Matrix{Float64},
path_constraints_dual::Matrix{Float64},
variable_constraints::Vector{Float64},
variable_constraints_dual::Vector{Float64}
)

# get dimensions
dim_x = state_dimension(ocp)
dim_u = control_dimension(ocp)
dim_v = variable_dimension(ocp)

# check that time grid is strictly increasing
# if not proceed with list of indexes as time grid
if !issorted(T, lt = <=)
println(
"WARNING: time grid at solution is not strictly increasing, replacing with list of indices...",
)
println(T)
dim_NLP_steps = length(T) - 1
T = LinRange(0, dim_NLP_steps, dim_NLP_steps + 1)
end

# variables: remove additional state for lagrange cost
x = CTBase.ctinterpolate(T, CTBase.matrix2vec(X[:, 1:dim_x], 1))
p = CTBase.ctinterpolate(T[1:(end - 1)], CTBase.matrix2vec(P[:, 1:dim_x], 1))
u = CTBase.ctinterpolate(T, CTBase.matrix2vec(U[:, 1:dim_u], 1))

# force scalar output when dimension is 1
fx = (dim_x == 1) ? deepcopy(t -> x(t)[1]) : deepcopy(t -> x(t))
fu = (dim_u == 1) ? deepcopy(t -> u(t)[1]) : deepcopy(t -> u(t))
fp = (dim_x == 1) ? deepcopy(t -> p(t)[1]) : deepcopy(t -> p(t))
var = (dim_v == 1) ? v[1] : v

# misc infos
infos = Dict{Symbol, Any}()

# nonlinear constraints and dual variables
path_constraints_fun = t -> CTBase.ctinterpolate(T, CTBase.matrix2vec(path_constraints, 1))(t)
path_constraints_dual_fun = t -> CTBase.ctinterpolate(T, CTBase.matrix2vec(path_constraints_dual, 1))(t)

# box constraints multipliers
state_constraints_lb_dual_fun = t -> CTBase.ctinterpolate(T, CTBase.matrix2vec(state_constraints_lb_dual[:, 1:dim_x], 1))(t)
state_constraints_ub_dual_fun = t -> CTBase.ctinterpolate(T, CTBase.matrix2vec(state_constraints_ub_dual[:, 1:dim_x], 1))(t)
control_constraints_lb_dual_fun = t -> CTBase.ctinterpolate(T, CTBase.matrix2vec(control_constraints_lb_dual[:, 1:dim_u], 1))(t)
control_constraints_ub_dual_fun = t -> CTBase.ctinterpolate(T, CTBase.matrix2vec(control_constraints_ub_dual[:, 1:dim_u], 1))(t)

# build Models
time_grid = TimeGridModel(T)
state = StateModelSolution(state_name(ocp), state_components(ocp), fx)
control = ControlModelSolution(control_name(ocp), control_components(ocp), fu)
variable = VariableModelSolution(variable_name(ocp), variable_cpùomponents(ocp), var)
dual = DualModel(
state_constraints_lb_dual_fun,
state_constraints_ub_dual_fun,
control_constraints_lb_dual_fun,
control_constraints_ub_dual_fun,
variable_constraints_lb_dual,
variable_constraints_ub_dual,
boundary_constraints,
boundary_constraints_dual,
path_constraints_fun,
path_constraints_dual_fun,
variable_constraints,
variable_constraints_dual,
)
solver_infos = SolverInfos(
iterations,
stopping,
message,
success,
constraints_violation,
infos,
)

return Solution(
time_grid,
times(ocp),
state,
control,
variable,
fp,
cost,
dual,
solver_infos
)

end
20 changes: 6 additions & 14 deletions src/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,11 @@ end
# ------------------------------------------------------------------------------ #

# from StateModel
name(ocp::StateModel)::String = ocp.name
components(ocp::StateModel)::Vector{String} = ocp.components
(dimension(ocp::StateModel)::Dimension) = length(components(ocp))
name(model::StateModel)::String = model.name
components(model::StateModel)::Vector{String} = model.components
(dimension(model::StateModel)::Dimension) = length(components(model))

# from Model
(state(ocp::Model{T, S, C, V, D, O, B})::S) where {
T<:AbstractTimesModel,
S<:AbstractStateModel,
C<:AbstractControlModel,
V<:AbstractVariableModel,
D<:Function,
O<:AbstractObjectiveModel,
B<:ConstraintsDictType} = ocp.state
state_name(ocp::Model)::String = name(state(ocp))
state_components(ocp::Model)::Vector{String} = components(state(ocp))
state_dimension(ocp::Model)::Dimension = dimension(state(ocp))
state_name(ocp::Model)::String = name(ocp.state)
state_components(ocp::Model)::Vector{String} = components(ocp.state)
state_dimension(ocp::Model)::Dimension = dimension(ocp.state)
54 changes: 27 additions & 27 deletions src/times.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,38 +129,38 @@ end
# ------------------------------------------------------------------------------ #

# From FixedTimeModel
time(ocp::FixedTimeModel)::Time = ocp.time
name(ocp::FixedTimeModel)::String = ocp.name
time(model::FixedTimeModel)::Time = model.time
name(model::FixedTimeModel)::String = model.name

# From FreeTimeModel
index(ocp::FreeTimeModel)::Int = ocp.index
name(ocp::FreeTimeModel)::String = ocp.name
function time(ocp::FreeTimeModel, variable::AbstractVector{T})::T where T<:ctNumber
# check if ocp.index in [1, length(variable)]
!(1 ocp.index length(variable)) && throw(CTBase.IncorrectArgument("the index of the time variable must be contained in 1:$(length(variable))"))
return variable[ocp.index]
index(model::FreeTimeModel)::Int = model.index
name(model::FreeTimeModel)::String = model.name
function time(model::FreeTimeModel, variable::AbstractVector{T})::T where T<:ctNumber
# check if model.index in [1, length(variable)]
!(1 model.index length(variable)) && throw(CTBase.IncorrectArgument("the index of the time variable must be contained in 1:$(length(variable))"))
return variable[model.index]
end

# From TimesModel
(initial(ocp::TimesModel{TI, TF})::TI) where {TI <: AbstractTimeModel, TF <: AbstractTimeModel} = ocp.initial
(final(ocp::TimesModel{TI, TF})::TF) where {TI <: AbstractTimeModel, TF <: AbstractTimeModel} = ocp.final
time_name(ocp::TimesModel)::String = ocp.time_name
initial_time(ocp::TimesModel{FixedTimeModel, <:AbstractTimeModel})::Time = time(initial(ocp))
final_time(ocp::TimesModel{<:AbstractTimeModel, FixedTimeModel})::Time = time(final(ocp))
initial_time(ocp::TimesModel{FreeTimeModel, <:AbstractTimeModel}, variable::Variable)::Time = time(initial(ocp), variable)
final_time(ocp::TimesModel{<:AbstractTimeModel, FreeTimeModel}, variable::Variable)::Time = time(final(ocp), variable)
(initial(model::TimesModel{TI, TF})::TI) where {TI <: AbstractTimeModel, TF <: AbstractTimeModel} = model.initial
(final(model::TimesModel{TI, TF})::TF) where {TI <: AbstractTimeModel, TF <: AbstractTimeModel} = model.final
time_name(model::TimesModel)::String = model.time_name
initial_time(model::TimesModel{FixedTimeModel, <:AbstractTimeModel})::Time = time(initial(model))
final_time(model::TimesModel{<:AbstractTimeModel, FixedTimeModel})::Time = time(final(model))
initial_time(model::TimesModel{FreeTimeModel, <:AbstractTimeModel}, variable::Variable)::Time = time(initial(model), variable)
final_time(model::TimesModel{<:AbstractTimeModel, FreeTimeModel}, variable::Variable)::Time = time(final(model), variable)

# From Model
(times(ocp::Model{T, S, C, V, D, O, B})::T) where {
T<:AbstractTimesModel,
S<:AbstractStateModel,
C<:AbstractControlModel,
V<:AbstractVariableModel,
T<:TimesModel,
S<:AbstractStateModel,
C<:AbstractControlModel,
V<:AbstractVariableModel,
D<:Function,
O<:AbstractObjectiveModel,
B<:ConstraintsDictType} = ocp.times

time_name(ocp::Model)::String = time_name(times(ocp))
time_name(ocp::Model)::String = time_name(ocp.times)

(initial_time(ocp::Model{T, S, C, V, D, O, B})::Time) where {
T<:TimesModel{FixedTimeModel, <:AbstractTimeModel},
Expand All @@ -169,7 +169,7 @@ time_name(ocp::Model)::String = time_name(times(ocp))
V<:AbstractVariableModel,
D<:Function,
O<:AbstractObjectiveModel,
B<:ConstraintsDictType} = initial_time(times(ocp))
B<:ConstraintsDictType} = initial_time(ocp.times)

(final_time(ocp::Model{T, S, C, V, D, O, B})::Time) where {
T<:TimesModel{<:AbstractTimeModel, FixedTimeModel},
Expand All @@ -178,7 +178,7 @@ time_name(ocp::Model)::String = time_name(times(ocp))
V<:AbstractVariableModel,
D<:Function,
O<:AbstractObjectiveModel,
B<:ConstraintsDictType} = final_time(times(ocp))
B<:ConstraintsDictType} = final_time(ocp.times)

(initial_time(ocp::Model{T, S, C, V, D, O, B}, variable::Variable)::Time) where {
T<:TimesModel{FreeTimeModel, <:AbstractTimeModel},
Expand All @@ -187,7 +187,7 @@ time_name(ocp::Model)::String = time_name(times(ocp))
V<:AbstractVariableModel,
D<:Function,
O<:AbstractObjectiveModel,
B<:ConstraintsDictType} = initial_time(times(ocp), variable)
B<:ConstraintsDictType} = initial_time(ocp.times, variable)

(final_time(ocp::Model{T, S, C, V, D, O, B}, variable::Variable)::Time) where {
T<:TimesModel{<:AbstractTimeModel, FreeTimeModel},
Expand All @@ -196,11 +196,11 @@ time_name(ocp::Model)::String = time_name(times(ocp))
V<:AbstractVariableModel,
D<:Function,
O<:AbstractObjectiveModel,
B<:ConstraintsDictType} = final_time(times(ocp), variable)
B<:ConstraintsDictType} = final_time(ocp.times, variable)

initial_time_name(ocp::Model)::String = name(initial(times(ocp)))
initial_time_name(ocp::Model)::String = name(initial(ocp.times))

final_time_name(ocp::Model)::String = name(final(times(ocp)))
final_time_name(ocp::Model)::String = name(final(ocp.times))

(has_fixed_initial_time(ocp::Model{T, S, C, V, D, O, B})::Bool) where {
T<:TimesModel{FixedTimeModel, <:AbstractTimeModel},
Expand Down
Loading

0 comments on commit fe0e6d2

Please sign in to comment.