Skip to content

Commit

Permalink
Implement gradient-free optimization via Optimization.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
goerz committed Apr 1, 2024
1 parent 32e9d23 commit 0716b27
Show file tree
Hide file tree
Showing 6 changed files with 440 additions and 9 deletions.
10 changes: 9 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
name = "ParameterizedQuantumControl"
uuid = "409be4c9-afa4-4246-894e-472b92a1ed06"
authors = ["Michael Goerz <[email protected]>"]
version = "0.0.1"
version = "0.1.0-dev"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuantumControlBase = "f10a33bc-5a64-497c-be7b-6f86b4f0c2aa"
QuantumGradientGenerators = "a563f35e-61db-434d-8c01-8b9e3ccdfd85"

[weakdeps]
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"

[extensions]
ParameterizedQuantumControlOptimizationExt = "Optimization"

[compat]
LinearAlgebra = "1"
QuantumControlBase = ">=0.9.0"
Expand Down
42 changes: 42 additions & 0 deletions ext/ParameterizedQuantumControlOptimizationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
module ParameterizedQuantumControlOptimizationExt
using Dates: now

using Optimization: Optimization, OptimizationProblem
import ParameterizedQuantumControl: run_optimizer, update_result!

function run_optimizer(
backend::Val{:Optimization},
optimizer,
wrk,
f,
info_hook,
check_convergence!
)
u0 = copy(wrk.result.guess_parameters)
#kwargs = ... # TODO: optimization_kwargs
function callback(state, loss_val)
wrk.optimizer_state = state
iter = wrk.result.iter
update_result!(wrk, iter)
#update_hook!(...) # TODO
info_tuple = info_hook(wrk, iter)
copyto!(wrk.result.guess_parameters, wrk.result.optimized_parameters)
wrk.fg_count .= 0
(info_tuple !== nothing) && push!(wrk.result.records, info_tuple)
check_convergence!(wrk.result)
wrk.result.iter += 1 # next iteration
return wrk.result.converged
end
prob = OptimizationProblem((u, _) -> f(u), u0, nothing)
try
sol = Optimization.solve(prob, optimizer; callback)
catch exc
exc_msg = sprint(showerror, exc)
if !contains(exc_msg, "Optimization halted by callback")
rethrow()
end
end
wrk.result.end_local_time = now()
end

end
11 changes: 3 additions & 8 deletions src/ParameterizedQuantumControl.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
module ParameterizedQuantumControl

"""Print "HELLO WORLD"."""
function hello_world()
println("HELLO WORLD")
end

#include("workspace.jl")
#include("result.jl")
#include("optimize.jl")
include("workspace.jl")
include("result.jl")
include("optimize.jl")

end
192 changes: 192 additions & 0 deletions src/optimize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import QuantumControlBase: optimize

using LinearAlgebra
using QuantumControlBase: @threadsif
using QuantumControlBase: set_atexit_save_optimization
using QuantumControlBase.QuantumPropagators: reinit_prop!, propagate


@doc raw"""
```julia
using ParameterizedQuantumControl
result = optimize(problem; method=ParameterizedQuantumControl, kwargs...)
```
optimizes the given control [`problem`](@ref QuantumControlBase.ControlProblem)
by varying a set of control parameters in order to minimize the functional
```math
J(\{u_{n}\}) = J_T(\{|ϕ_k(T)⟩\})
```
where ``|ϕ_k(T)⟩`` is the result of propagating the initial state of the
``k``'th trajectory under the parameters ``\{u_n\}``
Returns a [`ParameterizedOptResult`](@ref).
Keyword arguments that control the optimization are taken from the keyword
arguments used in the instantiation of `problem`; any of these can be overridden
with explicit keyword arguments to `optimize`.
# Required problem keyword arguments
* `backend`: A package to perform the optimization, e.g., `Optimization` (for
[Optimization.jl](https://github.com/SciML/Optimization.jl))
* `optimizer`: A backend-specific object to perform the optimizatino, e.g.,
`NLopt.LN_NELDERMEAD()` from `NLOpt`/`OptimizationNLOpt`
* `J_T`: A function `J_T(ϕ, trajectories; τ=τ)` that evaluates the final time
functional from a vector `ϕ` of forward-propagated states and
`problem.trajectories`. For all `trajectories` that define a `target_state`,
the element `τₖ` of the vector `τ` will contain the overlap of the state `ϕₖ`
with the `target_state` of the `k`'th trajectory, or `NaN` otherwise.
"""
function optimize_parameters(problem)

verbose = get(problem.kwargs, :verbose, false)
wrk = ParameterizedOptWrk(problem; verbose)

J_T_func = wrk.kwargs[:J_T]

initial_states = [traj.initial_state for traj wrk.trajectories]
Ψtgt = Union{eltype(initial_states),Nothing}[
(hasproperty(traj, :target_state) ? traj.target_state : nothing) for
traj wrk.trajectories
]
τ = wrk.result.tau_vals
J = wrk.J_parts

N = length(wrk.trajectories)

# loss function
function f(u; count_call=true)
copyto!(wrk.parameters, u)
@threadsif wrk.use_threads for k = 1:N
Ψ₀ = wrk.trajectories[k].initial_state
Ψₖ = propagate(Ψ₀, wrk.propagators[k])
τ[k] = isnothing(Ψtgt[k]) ? NaN : (Ψtgt[k] Ψₖ)
end
Ψ = [p.state for p wrk.propagators]
J[1] = J_T_func(Ψ, wrk.trajectories; τ=τ)
if count_call
wrk.fg_count[2] += 1
end
return sum(J)
end

backend = wrk.backend
optimizer = wrk.optimizer
info_hook = get(problem.kwargs, :info_hook, print_table)
check_convergence! = get(problem.kwargs, :check_convergence, res -> res)

atexit_filename = get(problem.kwargs, :atexit_filename, nothing)
# atexit_filename is undocumented on purpose: this is considered a feature
# of @optimize_or_load
if !isnothing(atexit_filename)
set_atexit_save_optimization(atexit_filename, wrk.result)
if !isinteractive()
@info "Set callback to store result in $(relpath(atexit_filename)) on unexpected exit."
# In interactive mode, `atexit` is very unlikely, and
# `InterruptException` is handles via try/catch instead.
end
end
try
run_optimizer(Val(backend), optimizer, wrk, f, info_hook, check_convergence!)
catch exc
# Primarily, this is intended to catch Ctrl-C in interactive
# optimizations (InterruptException)
exc_msg = sprint(showerror, exc)
wrk.result.message = "Exception: $exc_msg"
end
if !isnothing(atexit_filename)
popfirst!(Base.atexit_hooks)
end

return wrk.result

end


# backend code stub (see extensions)
function run_optimizer end


"""Print optimization progress as a table.
This functions serves as the default `info_hook` for an optimization with
`ParameterizedQuantumControl`.
"""
function print_table(wrk, iteration, args...)
# TODO: make_print_table that precomputes headers and such, and maybe
# allows for more options.
# TODO: should we report ΔJ instead of ΔJ_T?

J_T = wrk.result.J_T
ΔJ_T = J_T - wrk.result.J_T_prev
secs = wrk.result.secs

headers = ["iter.", "J_T", "ΔJ_T", "FG(F)", "secs"]
@assert length(wrk.J_parts) == 1

iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))"
width = Dict(
"iter." => max(length("$iter_stop"), 6),
"J_T" => 11,
"|∇J_T|" => 11,
"|∇J_a|" => 11,
"|∇J|" => 11,
"ΔJ" => 11,
"ΔJ_T" => 11,
"FG(F)" => 8,
"secs" => 8,
)

if iteration == 0
for header in headers
w = width[header]
print(lpad(header, w))
end
print("\n")
end

strs = [
"$iteration",
@sprintf("%.2e", J_T),
(iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a",
@sprintf("%d(%d)", wrk.fg_count[1], wrk.fg_count[2]),
@sprintf("%.1f", secs),
]
for (str, header) in zip(strs, headers)
w = width[header]
print(lpad(str, w))
end
print("\n")
flush(stdout)
end


# Transfer information from `wrk` to `wrk.result` in each iteration (before the
# `info_hook`)
function update_result!(wrk::ParameterizedOptWrk, i::Int64)
res = wrk.result
for (k, propagator) in enumerate(wrk.propagators)
copyto!(res.states[k], propagator.state)
end
copyto!(wrk.result.optimized_parameters, wrk.parameters)
res.f_calls += wrk.fg_count[2]
res.fg_calls += wrk.fg_count[1]
res.J_T_prev = res.J_T
res.J_T = wrk.J_parts[1]
(i > 0) && (res.iter = i)
if i >= res.iter_stop
res.converged = true
res.message = "Reached maximum number of iterations"
# Note: other convergence checks are done in user-supplied
# check_convergence routine
end
prev_time = res.end_local_time
res.end_local_time = now()
res.secs = Dates.toms(res.end_local_time - prev_time) / 1000.0
end


optimize(problem, method::Val{:ParameterizedQuantumControl}) = optimize_parameters(problem)
89 changes: 89 additions & 0 deletions src/result.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using QuantumControlBase: get_parameters
using Printf
using Dates


"""Result object returned by [`optimize_parameters`](@ref)."""
mutable struct ParameterizedOptResult{STST,PT}
tlist::Vector{Float64}
iter_start::Int64 # the starting iteration number
iter_stop::Int64 # the maximum iteration number
iter::Int64 # the current iteration number
secs::Float64 # seconds that the last iteration took
tau_vals::Vector{ComplexF64}
J_T::Float64 # the current value of the final-time functional J_T
J_T_prev::Float64 # previous value of J_T
guess_parameters::PT
optimized_parameters::PT
states::Vector{STST}
start_local_time::DateTime
end_local_time::DateTime
records::Vector{Tuple} # storage for info_hook to write data into at each iteration
converged::Bool
f_calls::Int64
fg_calls::Int64
message::String

function ParameterizedOptResult(problem)
tlist = problem.tlist
iter_start = get(problem.kwargs, :iter_start, 0)
iter_stop = get(problem.kwargs, :iter_stop, 5000)
iter = iter_start
secs = 0
tau_vals = zeros(ComplexF64, length(problem.trajectories))
J_T = 0.0
J_T_prev = 0.0
parameters = get(problem.kwargs, :parameters, get_parameters(problem))
optimized_parameters = copy(parameters)
guess_parameters = copy(parameters)
states = [similar(traj.initial_state) for traj in problem.trajectories]
start_local_time = now()
end_local_time = now()
records = Vector{Tuple}()
converged = false
message = "in progress"
f_calls = 0
fg_calls = 0
PT = typeof(optimized_parameters)
STST = eltype(states)
new{STST,PT}(
tlist,
iter_start,
iter_stop,
iter,
secs,
tau_vals,
J_T,
J_T_prev,
guess_parameters,
optimized_parameters,
states,
start_local_time,
end_local_time,
records,
converged,
f_calls,
fg_calls,
message
)
end
end


Base.show(io::IO, r::ParameterizedOptResult) =
print(io, "ParameterizedOptResult<$(r.message)>")
Base.show(io::IO, ::MIME"text/plain", r::ParameterizedOptResult) = print(
io,
"""
Parameterized Optimization Result
---------------------------------
- Started at $(r.start_local_time)
- Number of trajectories: $(length(r.states))
- Number of iterations: $(max(r.iter - r.iter_start, 0))
- Number of pure func evals: $(r.f_calls)
- Number of func/grad evals: $(r.fg_calls)
- Value of functional: $(@sprintf("%.5e", r.J_T))
- Reason for termination: $(r.message)
- Ended at $(r.end_local_time) ($(Dates.canonicalize(Dates.CompoundPeriod(r.end_local_time - r.start_local_time))))
"""
)
Loading

0 comments on commit 0716b27

Please sign in to comment.