Skip to content

Commit

Permalink
Merge pull request #160 from control-toolbox/AD
Browse files Browse the repository at this point in the history
Add change of default AD backend
  • Loading branch information
ocots authored Jun 20, 2024
2 parents b4d8e3d + 28ab346 commit 7d7ce56
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/CTBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ export Description, add, getFullDescription
export CTException, ParsingError, AmbiguousDescription, IncorrectMethod
export IncorrectArgument, IncorrectOutput, NotImplemented, UnauthorizedCall

# checking
# AD
export set_AD_backend

# functions
export Hamiltonian, HamiltonianVectorField, VectorField
Expand Down
10 changes: 9 additions & 1 deletion src/default.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
#
__default_AD_backend = AutoForwardDiff()

function set_AD_backend(AD)
global __default_AD_backend = AD
nothing
end

"""
$(TYPEDSIGNATURES)
Used to set the default value of Automatic Differentiation backend.
The default value is `AutoForwardDiff()`, that is the `ForwardDiff` package is used by default.
"""
__auto() = AutoForwardDiff() # default AD backend
__get_AD_backend() = __default_AD_backend # default AD backend

"""
$(TYPEDSIGNATURES)
Expand Down
19 changes: 10 additions & 9 deletions src/repl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ end
end

#
global ct_repl_is_set::Bool = false
global ct_repl_data::CTRepl
global ct_repl_ct_repl_history::HistoryRepl
ct_repl_is_set::Bool = false
ct_repl_data::CTRepl = CTRepl()
ct_repl_history::HistoryRepl = HistoryRepl(0, Vector{ModelRepl}())

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -47,13 +47,14 @@ Create a ct REPL.
"""
function ct_repl(; debug=false, demo=false, verbose=false)

global ct_repl_is_set
global ct_repl_data
global ct_repl_history

if !ct_repl_is_set
#
global ct_repl_is_set = true

# init: ct_repl_data, ct_repl_history
global ct_repl_data = CTRepl()
global ct_repl_history = HistoryRepl(0, Vector{ModelRepl}())

#
ct_repl_is_set = true

#
ct_repl_data.debug = debug
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ $(TYPEDSIGNATURES)
Return the gradient of `f` at `x`.
"""
function ctgradient(f::Function, x::ctNumber; backend=__auto())
function ctgradient(f::Function, x::ctNumber; backend=__get_AD_backend())
extras = prepare_derivative(f, backend, x)
return derivative(f, backend, x, extras)
end
Expand All @@ -83,7 +83,7 @@ $(TYPEDSIGNATURES)
Return the gradient of `f` at `x`.
"""
function ctgradient(f::Function, x; backend=__auto())
function ctgradient(f::Function, x; backend=__get_AD_backend())
extras = prepare_gradient(f, backend, x)
return gradient(f, backend, x, extras)
end
Expand All @@ -100,7 +100,7 @@ $(TYPEDSIGNATURES)
Return the Jacobian of `f` at `x`.
"""
function ctjacobian(f::Function, x::ctNumber; backend=__auto())
function ctjacobian(f::Function, x::ctNumber; backend=__get_AD_backend())
f_number_to_number = only f only
extras = prepare_derivative(f_number_to_number, backend, x)
der = derivative(f_number_to_number, backend, x, extras)
Expand All @@ -112,7 +112,7 @@ $(TYPEDSIGNATURES)
Return the Jacobian of `f` at `x`.
"""
function ctjacobian(f::Function, x; backend=__auto())
function ctjacobian(f::Function, x; backend=__get_AD_backend())
extras = prepare_jacobian(f, backend, x)
return jacobian(f, backend, x, extras)
end
Expand Down
3 changes: 0 additions & 3 deletions test/test_description.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ descriptions = add(descriptions, (:b,))
@test descriptions[1] == (:a,)
@test descriptions[2] == (:b,)

# print a tuple of descriptions
@test display(descriptions) isa Nothing

# get the complete description of the chosen method
algorithmes = ()
algorithmes = add(algorithmes, (:descent, :bfgs, :bissection))
Expand Down

0 comments on commit 7d7ce56

Please sign in to comment.