Skip to content

Commit

Permalink
removed useless @_def; added nothing case in to_out_of_place
Browse files Browse the repository at this point in the history
  • Loading branch information
jbcaillau committed Sep 6, 2024
1 parent c1bed0a commit 57c8fa5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
30 changes: 11 additions & 19 deletions src/onepass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin
ee2 = replace_call(ee2, p.x, p.tf, xf)
args = [x0, xf]
__v_dep(p) && push!(args, p.v)
quote
quote # todo
function $gs($(args...))
$ee2
end
Expand All @@ -374,7 +374,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin
__t_dep(p) && push!(args, p.t)
push!(args, ut)
__v_dep(p) && push!(args, p.v)
quote
quote # todo
function $gs($(args...))
$ee2
end
Expand All @@ -392,7 +392,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin
__t_dep(p) && push!(args, p.t)
push!(args, xt)
__v_dep(p) && push!(args, p.v)
quote
quote # todo
function $gs($(args...))
$ee2
end
Expand All @@ -404,7 +404,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin
:variable_fun => begin
gs = gensym()
args = [p.v]
quote
quote # todo
function $gs($(args...))
$e2
end
Expand All @@ -421,7 +421,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin
__t_dep(p) && push!(args, p.t)
push!(args, xt, ut)
__v_dep(p) && push!(args, p.v)
quote
quote # todo
function $gs($(args...))
$ee2
end
Expand All @@ -448,7 +448,7 @@ p_dynamics!(p, ocp, x, t, e, label = nothing; log = false) = begin
p.t_dep = p.t_dep || has(e, t)
gs = gensym()
args = [ ]; __t_dep(p) && push!(args, p.t); push!(args, xt, ut); __v_dep(p) && push!(args, p.v)
code = quote
code = quote # todo
function $gs($(args...))
$e
end
Expand All @@ -469,7 +469,7 @@ p_lagrange!(p, ocp, e, type; log = false) = begin
ttype = QuoteNode(type)
gs = gensym()
args = [ ]; __t_dep(p) && push!(args, p.t); push!(args, xt, ut); __v_dep(p) && push!(args, p.v)
code = quote
code = quote # todo
function $gs($(args...))
$e
end
Expand All @@ -495,7 +495,7 @@ p_mayer!(p, ocp, e, type; log = false) = begin
e = replace_call(e, p.x, p.tf, xf)
ttype = QuoteNode(type)
args = [ x0, xf ]; __v_dep(p) && push!(args, p.v)
code = quote
code = quote # todo
function $gs($(args...))
$e
end
Expand Down Expand Up @@ -528,7 +528,7 @@ p_bolza!(p, ocp, e1, e2, type; log = false) = begin
push!(args2, xt, ut)
__v_dep(p) && push!(args2, p.v)
ttype = QuoteNode(type)
code = quote
code = quote # todo
function $gs1($(args1...))
$e1
end
Expand Down Expand Up @@ -608,23 +608,15 @@ macro def(e)
esc(code)
end

macro __def(e) # todo: remove after in place test
ocp = gensym()
code = quote
@def $ocp $e false true # Force in place
$ocp
end
esc(code)
end

macro def(ocp, e, log = false, in_place = false) # todo: default to in_place
macro def(ocp, e, log = false)
try
p0 = ParsingInfo()
parse!(p0, ocp, e; log = false) # initial pass to get the dependencies (time and variable)
p = ParsingInfo()
p.t_dep = p0.t_dep
p.v = p0.v
code = parse!(p, ocp, e; log = log)
in_place = False # todo: change to True for in place
init = @match (__t_dep(p), __v_dep(p)) begin
(false, false) => :($ocp = __OCPModel(; in_place = $in_place))
(true, false) => :($ocp = __OCPModel(autonomous = false; in_place = $in_place))
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,15 @@ __view(x::AbstractVector, rg::AbstractRange) = view(x, rg) # Allows StepRange
$(TYPEDSIGNATURES)
Tranform in place function to out of place. Pass the result size and type (default = `Float64`).
Return a scalar when the result has size one.
Return a scalar when the result has size one. If `f!` is `nothing`, return `nothing`.
"""
function to_out_of_place(f!, n; T = Float64)
function f(x...)
function f(args...; kwargs...)
r = zeros(T, n)
f!(r, x...)
f!(r, args...; kwargs...)
return n == 1 ? r[1] : r
end
return f
return isnothing(f!) ? nothing : f
end

# Adapt getters to test in place
Expand Down
8 changes: 8 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,13 @@ function test_utils()
end
@test CTBase.to_out_of_place(f2!, 1; T = Int32)(1, 2) == 3

function f3!(r, x; y = 1)
r[:] .= x + y
return nothing
end
@test CTBase.to_out_of_place(f3!, 1; T = Int32)(1; y = 2) == 3

@test isnothing( CTBase.to_out_of_place(nothing, 1) )

end
end

0 comments on commit 57c8fa5

Please sign in to comment.