Skip to content

Commit

Permalink
Add __auto
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored May 30, 2024
1 parent f18da04 commit 1314e97
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/CTBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module CTBase
import Base
using DocStringExtensions
using DifferentiationInterface: AutoForwardDiff, derivative, gradient, jacobian, prepare_derivative, prepare_gradient, prepare_jacobian
using ForwardDiff: ForwardDiff # automatic differentiation
import ForwardDiff
using Interpolations: linear_interpolation, Line, Interpolations # for default interpolation
using MLStyle # pattern matching
using Parameters # @with_kw: to have default values in struct
Expand Down Expand Up @@ -95,6 +95,8 @@ Type alias for a tangent vector to the costate space.
"""
const DCostate = ctVector

__auto() = AutoForwardDiff()

#
include("exception.jl")
include("description.jl")
Expand Down
10 changes: 5 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ $(TYPEDSIGNATURES)
Return the gradient of `f` at `x`.
"""
function ctgradient(f::Function, x::ctNumber)
backend = AutoForwardDiff()
backend = __auto()
extras = prepare_derivative(f, backend, x)
return derivative(f, backend, x, extras)
end
Expand All @@ -85,7 +85,7 @@ $(TYPEDSIGNATURES)
Return the gradient of `f` at `x`.
"""
function ctgradient(f::Function, x)
backend = AutoForwardDiff()
backend = __auto()
extras = prepare_gradient(f, backend, x)
return gradient(f, backend, x, extras)
end
Expand All @@ -104,7 +104,7 @@ Return the Jacobian of `f` at `x`.
"""
function ctjacobian(f::Function, x::ctNumber)
f_number_to_number = only f only
backend = AutoForwardDiff()
backend = __auto()
extras = prepare_derivative(f_number_to_number, backend, x)
der = derivative(f_number_to_number, backend, x, extras)
return [der;;]
Expand All @@ -116,7 +116,7 @@ $(TYPEDSIGNATURES)
Return the Jacobian of `f` at `x`.
"""
function ctjacobian(f::Function, x)
backend = AutoForwardDiff()
backend = __auto()
extras = prepare_jacobian(f, backend, x)
return jacobian(f, backend, x, extras)
end
Expand Down Expand Up @@ -207,4 +207,4 @@ function matrix2vec(x::Matrix{<:ctNumber}, dim::Integer=__matrix_dimension_stock
end
end
return y
end
end

0 comments on commit 1314e97

Please sign in to comment.