From 931cb4f01bc6576d5a707920face447c580b0b13 Mon Sep 17 00:00:00 2001 From: Pierre Martinon Date: Mon, 26 Aug 2024 16:59:10 +0200 Subject: [PATCH] updated json format export/import (#220) --- Project.toml | 2 +- ext/CTDirectExt.jl | 35 +++++++++++++++++++++++++++-------- src/solution.jl | 12 ++++++------ test/suite/test_misc.jl | 10 +++++----- 4 files changed, 39 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index e4dd5268..492e9be2 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ CTSolveExtMadNLP = ["MadNLP"] [compat] ADNLPModels = "0.8" -CTBase = "0.12, 0.13" +CTBase = "0.13" DocStringExtensions = "0.9" HSL = "0.4" JLD2 = "0.4" diff --git a/ext/CTDirectExt.jl b/ext/CTDirectExt.jl index b2989498..07b6e701 100644 --- a/ext/CTDirectExt.jl +++ b/ext/CTDirectExt.jl @@ -33,10 +33,18 @@ $(TYPEDSIGNATURES) Export OCP solution in JSON format """ function CTDirect.export_ocp_solution(sol::OptimalControlSolution; filename_prefix="solution") - # +++ redo this, start with basics, fuse into save - #open(filename_prefix * ".json", "w") do io - # JSON3.pretty(io, CTDirect.OCPDiscreteSolution(sol)) - #end + # fuse into save ? + blob = Dict( + "objective" => sol.objective, + "time_grid" => sol.time_grid, + "state" => state_discretized(sol), + "control" => control_discretized(sol), + "costate" => costate_discretized(sol)[1:end-1,:], + "variable" => sol.variable + ) + open(filename_prefix * ".json", "w") do io + JSON3.pretty(io, blob) + end return nothing end @@ -45,10 +53,21 @@ $(TYPEDSIGNATURES) Read OCP solution in JSON format """ -function CTDirect.import_ocp_solution(filename_prefix="solution") - # +++ add constructor from json blob, fuse into load - #json_string = read(filename_prefix * ".json", String) - #return OptimalControlSolution(JSON3.read(json_string)) +function CTDirect.import_ocp_solution(ocp::OptimalControlModel; filename_prefix="solution") + # fuse into load ? + json_string = read(filename_prefix * ".json", String) + blob = JSON3.read(json_string) + + # NB. convert vect{vect} to matrix + return OptimalControlSolution( + ocp, + blob.time_grid, + stack(blob.state, dims=1), + stack(blob.control, dims=1), + blob.variable, + stack(blob.costate, dims=1); + objective = blob.objective + ) end diff --git a/src/solution.jl b/src/solution.jl index f5f6559e..5343187e 100644 --- a/src/solution.jl +++ b/src/solution.jl @@ -79,7 +79,7 @@ function CTBase.OptimalControlSolution( # call lowest level constructor return OptimalControlSolution( - docp, + docp.ocp, T, X, U, @@ -263,7 +263,7 @@ $(TYPEDSIGNATURES) Build OCP functional solution from DOCP vector solution (given as raw variables and multipliers plus some optional infos) """ function CTBase.OptimalControlSolution( - docp, + ocp::OptimalControlModel, T, X, U, @@ -280,7 +280,6 @@ function CTBase.OptimalControlSolution( box_multipliers = ((nothing, nothing), (nothing, nothing), (nothing, nothing)), ) - ocp = docp.ocp dim_x = state_dimension(ocp) dim_u = control_dimension(ocp) dim_v = variable_dimension(ocp) @@ -292,13 +291,14 @@ function CTBase.OptimalControlSolution( "WARNING: time grid at solution is not strictly increasing, replacing with list of indices...", ) println(T) - T = LinRange(0, docp.dim_NLP_steps, docp.dim_NLP_steps + 1) + dim_NLP_steps = length(T) - 1 + T = LinRange(0, dim_NLP_steps, dim_NLP_steps + 1) end # variables: remove additional state for lagrange cost x = ctinterpolate(T, matrix2vec(X[:, 1:dim_x], 1)) p = ctinterpolate(T[1:end-1], matrix2vec(P[:, 1:dim_x], 1)) - u = ctinterpolate(T, matrix2vec(U, 1)) + u = ctinterpolate(T, matrix2vec(U[:, 1:dim_u], 1)) # force scalar output when dimension is 1 fx = (dim_x == 1) ? deepcopy(t -> x(t)[1]) : deepcopy(t -> x(t)) @@ -335,7 +335,7 @@ function CTBase.OptimalControlSolution( ) = set_box_multipliers(T, box_multipliers, dim_x, dim_u) # build and return solution - if docp.has_variable + if is_variable_dependent(ocp) return OptimalControlSolution( ocp; state = fx, diff --git a/test/suite/test_misc.jl b/test/suite/test_misc.jl index b430b639..8bb39bfc 100644 --- a/test/suite/test_misc.jl +++ b/test/suite/test_misc.jl @@ -24,16 +24,16 @@ sol0 = direct_solve(ocp, display = false) # test save / load solution in JLD2 format @testset verbose = true showtiming = true ":save_load :JLD2" begin - save(sol0, filename_prefix = "solution_test") + save(sol0; filename_prefix = "solution_test") sol_reloaded = load("solution_test") @test sol0.objective == sol_reloaded.objective end -#= + # test export / read solution in JSON format @testset verbose = true showtiming = true ":export_read :JSON" begin - export_ocp_solution(sol0, filename_prefix = "solution_test") - sol_reloaded = import_ocp_solution("solution_test") + export_ocp_solution(sol0; filename_prefix = "solution_test") + sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test") @test sol0.objective == sol_reloaded.objective end -=# +