-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement gradient-free optimization via Optimization.jl
- Loading branch information
Showing
6 changed files
with
440 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,21 @@ | ||
name = "ParameterizedQuantumControl" | ||
uuid = "409be4c9-afa4-4246-894e-472b92a1ed06" | ||
authors = ["Michael Goerz <[email protected]>"] | ||
version = "0.0.1" | ||
version = "0.1.0-dev" | ||
|
||
[deps] | ||
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" | ||
QuantumControlBase = "f10a33bc-5a64-497c-be7b-6f86b4f0c2aa" | ||
QuantumGradientGenerators = "a563f35e-61db-434d-8c01-8b9e3ccdfd85" | ||
|
||
[weakdeps] | ||
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" | ||
|
||
[extensions] | ||
ParameterizedQuantumControlOptimizationExt = "Optimization" | ||
|
||
[compat] | ||
LinearAlgebra = "1" | ||
QuantumControlBase = ">=0.9.0" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
module ParameterizedQuantumControlOptimizationExt | ||
using Dates: now | ||
|
||
using Optimization: Optimization, OptimizationProblem | ||
import ParameterizedQuantumControl: run_optimizer, update_result! | ||
|
||
function run_optimizer( | ||
backend::Val{:Optimization}, | ||
optimizer, | ||
wrk, | ||
f, | ||
info_hook, | ||
check_convergence! | ||
) | ||
u0 = copy(wrk.result.guess_parameters) | ||
#kwargs = ... # TODO: optimization_kwargs | ||
function callback(state, loss_val) | ||
wrk.optimizer_state = state | ||
iter = wrk.result.iter | ||
update_result!(wrk, iter) | ||
#update_hook!(...) # TODO | ||
info_tuple = info_hook(wrk, iter) | ||
copyto!(wrk.result.guess_parameters, wrk.result.optimized_parameters) | ||
wrk.fg_count .= 0 | ||
(info_tuple !== nothing) && push!(wrk.result.records, info_tuple) | ||
check_convergence!(wrk.result) | ||
wrk.result.iter += 1 # next iteration | ||
return wrk.result.converged | ||
end | ||
prob = OptimizationProblem((u, _) -> f(u), u0, nothing) | ||
try | ||
sol = Optimization.solve(prob, optimizer; callback) | ||
catch exc | ||
exc_msg = sprint(showerror, exc) | ||
if !contains(exc_msg, "Optimization halted by callback") | ||
rethrow() | ||
end | ||
end | ||
wrk.result.end_local_time = now() | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,7 @@ | ||
module ParameterizedQuantumControl | ||
|
||
"""Print "HELLO WORLD".""" | ||
function hello_world() | ||
println("HELLO WORLD") | ||
end | ||
|
||
#include("workspace.jl") | ||
#include("result.jl") | ||
#include("optimize.jl") | ||
include("workspace.jl") | ||
include("result.jl") | ||
include("optimize.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import QuantumControlBase: optimize | ||
|
||
using LinearAlgebra | ||
using QuantumControlBase: @threadsif | ||
using QuantumControlBase: set_atexit_save_optimization | ||
using QuantumControlBase.QuantumPropagators: reinit_prop!, propagate | ||
|
||
|
||
@doc raw""" | ||
```julia | ||
using ParameterizedQuantumControl | ||
result = optimize(problem; method=ParameterizedQuantumControl, kwargs...) | ||
``` | ||
optimizes the given control [`problem`](@ref QuantumControlBase.ControlProblem) | ||
by varying a set of control parameters in order to minimize the functional | ||
```math | ||
J(\{u_{n}\}) = J_T(\{|ϕ_k(T)⟩\}) | ||
``` | ||
where ``|ϕ_k(T)⟩`` is the result of propagating the initial state of the | ||
``k``'th trajectory under the parameters ``\{u_n\}`` | ||
Returns a [`ParameterizedOptResult`](@ref). | ||
Keyword arguments that control the optimization are taken from the keyword | ||
arguments used in the instantiation of `problem`; any of these can be overridden | ||
with explicit keyword arguments to `optimize`. | ||
# Required problem keyword arguments | ||
* `backend`: A package to perform the optimization, e.g., `Optimization` (for | ||
[Optimization.jl](https://github.com/SciML/Optimization.jl)) | ||
* `optimizer`: A backend-specific object to perform the optimizatino, e.g., | ||
`NLopt.LN_NELDERMEAD()` from `NLOpt`/`OptimizationNLOpt` | ||
* `J_T`: A function `J_T(ϕ, trajectories; τ=τ)` that evaluates the final time | ||
functional from a vector `ϕ` of forward-propagated states and | ||
`problem.trajectories`. For all `trajectories` that define a `target_state`, | ||
the element `τₖ` of the vector `τ` will contain the overlap of the state `ϕₖ` | ||
with the `target_state` of the `k`'th trajectory, or `NaN` otherwise. | ||
""" | ||
function optimize_parameters(problem) | ||
|
||
verbose = get(problem.kwargs, :verbose, false) | ||
wrk = ParameterizedOptWrk(problem; verbose) | ||
|
||
J_T_func = wrk.kwargs[:J_T] | ||
|
||
initial_states = [traj.initial_state for traj ∈ wrk.trajectories] | ||
Ψtgt = Union{eltype(initial_states),Nothing}[ | ||
(hasproperty(traj, :target_state) ? traj.target_state : nothing) for | ||
traj ∈ wrk.trajectories | ||
] | ||
τ = wrk.result.tau_vals | ||
J = wrk.J_parts | ||
|
||
N = length(wrk.trajectories) | ||
|
||
# loss function | ||
function f(u; count_call=true) | ||
copyto!(wrk.parameters, u) | ||
@threadsif wrk.use_threads for k = 1:N | ||
Ψ₀ = wrk.trajectories[k].initial_state | ||
Ψₖ = propagate(Ψ₀, wrk.propagators[k]) | ||
τ[k] = isnothing(Ψtgt[k]) ? NaN : (Ψtgt[k] ⋅ Ψₖ) | ||
end | ||
Ψ = [p.state for p ∈ wrk.propagators] | ||
J[1] = J_T_func(Ψ, wrk.trajectories; τ=τ) | ||
if count_call | ||
wrk.fg_count[2] += 1 | ||
end | ||
return sum(J) | ||
end | ||
|
||
backend = wrk.backend | ||
optimizer = wrk.optimizer | ||
info_hook = get(problem.kwargs, :info_hook, print_table) | ||
check_convergence! = get(problem.kwargs, :check_convergence, res -> res) | ||
|
||
atexit_filename = get(problem.kwargs, :atexit_filename, nothing) | ||
# atexit_filename is undocumented on purpose: this is considered a feature | ||
# of @optimize_or_load | ||
if !isnothing(atexit_filename) | ||
set_atexit_save_optimization(atexit_filename, wrk.result) | ||
if !isinteractive() | ||
@info "Set callback to store result in $(relpath(atexit_filename)) on unexpected exit." | ||
# In interactive mode, `atexit` is very unlikely, and | ||
# `InterruptException` is handles via try/catch instead. | ||
end | ||
end | ||
try | ||
run_optimizer(Val(backend), optimizer, wrk, f, info_hook, check_convergence!) | ||
catch exc | ||
# Primarily, this is intended to catch Ctrl-C in interactive | ||
# optimizations (InterruptException) | ||
exc_msg = sprint(showerror, exc) | ||
wrk.result.message = "Exception: $exc_msg" | ||
end | ||
if !isnothing(atexit_filename) | ||
popfirst!(Base.atexit_hooks) | ||
end | ||
|
||
return wrk.result | ||
|
||
end | ||
|
||
|
||
# backend code stub (see extensions) | ||
function run_optimizer end | ||
|
||
|
||
"""Print optimization progress as a table. | ||
This functions serves as the default `info_hook` for an optimization with | ||
`ParameterizedQuantumControl`. | ||
""" | ||
function print_table(wrk, iteration, args...) | ||
# TODO: make_print_table that precomputes headers and such, and maybe | ||
# allows for more options. | ||
# TODO: should we report ΔJ instead of ΔJ_T? | ||
|
||
J_T = wrk.result.J_T | ||
ΔJ_T = J_T - wrk.result.J_T_prev | ||
secs = wrk.result.secs | ||
|
||
headers = ["iter.", "J_T", "ΔJ_T", "FG(F)", "secs"] | ||
@assert length(wrk.J_parts) == 1 | ||
|
||
iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))" | ||
width = Dict( | ||
"iter." => max(length("$iter_stop"), 6), | ||
"J_T" => 11, | ||
"|∇J_T|" => 11, | ||
"|∇J_a|" => 11, | ||
"|∇J|" => 11, | ||
"ΔJ" => 11, | ||
"ΔJ_T" => 11, | ||
"FG(F)" => 8, | ||
"secs" => 8, | ||
) | ||
|
||
if iteration == 0 | ||
for header in headers | ||
w = width[header] | ||
print(lpad(header, w)) | ||
end | ||
print("\n") | ||
end | ||
|
||
strs = [ | ||
"$iteration", | ||
@sprintf("%.2e", J_T), | ||
(iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a", | ||
@sprintf("%d(%d)", wrk.fg_count[1], wrk.fg_count[2]), | ||
@sprintf("%.1f", secs), | ||
] | ||
for (str, header) in zip(strs, headers) | ||
w = width[header] | ||
print(lpad(str, w)) | ||
end | ||
print("\n") | ||
flush(stdout) | ||
end | ||
|
||
|
||
# Transfer information from `wrk` to `wrk.result` in each iteration (before the | ||
# `info_hook`) | ||
function update_result!(wrk::ParameterizedOptWrk, i::Int64) | ||
res = wrk.result | ||
for (k, propagator) in enumerate(wrk.propagators) | ||
copyto!(res.states[k], propagator.state) | ||
end | ||
copyto!(wrk.result.optimized_parameters, wrk.parameters) | ||
res.f_calls += wrk.fg_count[2] | ||
res.fg_calls += wrk.fg_count[1] | ||
res.J_T_prev = res.J_T | ||
res.J_T = wrk.J_parts[1] | ||
(i > 0) && (res.iter = i) | ||
if i >= res.iter_stop | ||
res.converged = true | ||
res.message = "Reached maximum number of iterations" | ||
# Note: other convergence checks are done in user-supplied | ||
# check_convergence routine | ||
end | ||
prev_time = res.end_local_time | ||
res.end_local_time = now() | ||
res.secs = Dates.toms(res.end_local_time - prev_time) / 1000.0 | ||
end | ||
|
||
|
||
optimize(problem, method::Val{:ParameterizedQuantumControl}) = optimize_parameters(problem) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
using QuantumControlBase: get_parameters | ||
using Printf | ||
using Dates | ||
|
||
|
||
"""Result object returned by [`optimize_parameters`](@ref).""" | ||
mutable struct ParameterizedOptResult{STST,PT} | ||
tlist::Vector{Float64} | ||
iter_start::Int64 # the starting iteration number | ||
iter_stop::Int64 # the maximum iteration number | ||
iter::Int64 # the current iteration number | ||
secs::Float64 # seconds that the last iteration took | ||
tau_vals::Vector{ComplexF64} | ||
J_T::Float64 # the current value of the final-time functional J_T | ||
J_T_prev::Float64 # previous value of J_T | ||
guess_parameters::PT | ||
optimized_parameters::PT | ||
states::Vector{STST} | ||
start_local_time::DateTime | ||
end_local_time::DateTime | ||
records::Vector{Tuple} # storage for info_hook to write data into at each iteration | ||
converged::Bool | ||
f_calls::Int64 | ||
fg_calls::Int64 | ||
message::String | ||
|
||
function ParameterizedOptResult(problem) | ||
tlist = problem.tlist | ||
iter_start = get(problem.kwargs, :iter_start, 0) | ||
iter_stop = get(problem.kwargs, :iter_stop, 5000) | ||
iter = iter_start | ||
secs = 0 | ||
tau_vals = zeros(ComplexF64, length(problem.trajectories)) | ||
J_T = 0.0 | ||
J_T_prev = 0.0 | ||
parameters = get(problem.kwargs, :parameters, get_parameters(problem)) | ||
optimized_parameters = copy(parameters) | ||
guess_parameters = copy(parameters) | ||
states = [similar(traj.initial_state) for traj in problem.trajectories] | ||
start_local_time = now() | ||
end_local_time = now() | ||
records = Vector{Tuple}() | ||
converged = false | ||
message = "in progress" | ||
f_calls = 0 | ||
fg_calls = 0 | ||
PT = typeof(optimized_parameters) | ||
STST = eltype(states) | ||
new{STST,PT}( | ||
tlist, | ||
iter_start, | ||
iter_stop, | ||
iter, | ||
secs, | ||
tau_vals, | ||
J_T, | ||
J_T_prev, | ||
guess_parameters, | ||
optimized_parameters, | ||
states, | ||
start_local_time, | ||
end_local_time, | ||
records, | ||
converged, | ||
f_calls, | ||
fg_calls, | ||
message | ||
) | ||
end | ||
end | ||
|
||
|
||
Base.show(io::IO, r::ParameterizedOptResult) = | ||
print(io, "ParameterizedOptResult<$(r.message)>") | ||
Base.show(io::IO, ::MIME"text/plain", r::ParameterizedOptResult) = print( | ||
io, | ||
""" | ||
Parameterized Optimization Result | ||
--------------------------------- | ||
- Started at $(r.start_local_time) | ||
- Number of trajectories: $(length(r.states)) | ||
- Number of iterations: $(max(r.iter - r.iter_start, 0)) | ||
- Number of pure func evals: $(r.f_calls) | ||
- Number of func/grad evals: $(r.fg_calls) | ||
- Value of functional: $(@sprintf("%.5e", r.J_T)) | ||
- Reason for termination: $(r.message) | ||
- Ended at $(r.end_local_time) ($(Dates.canonicalize(Dates.CompoundPeriod(r.end_local_time - r.start_local_time)))) | ||
""" | ||
) |
Oops, something went wrong.