diff --git a/src/onepass.jl b/src/onepass.jl index 324a32e..85eaf50 100644 --- a/src/onepass.jl +++ b/src/onepass.jl @@ -356,7 +356,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin ee2 = replace_call(ee2, p.x, p.tf, xf) args = [x0, xf] __v_dep(p) && push!(args, p.v) - quote + quote # todo function $gs($(args...)) $ee2 end @@ -374,7 +374,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin __t_dep(p) && push!(args, p.t) push!(args, ut) __v_dep(p) && push!(args, p.v) - quote + quote # todo function $gs($(args...)) $ee2 end @@ -392,7 +392,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin __t_dep(p) && push!(args, p.t) push!(args, xt) __v_dep(p) && push!(args, p.v) - quote + quote # todo function $gs($(args...)) $ee2 end @@ -404,7 +404,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin :variable_fun => begin gs = gensym() args = [p.v] - quote + quote # todo function $gs($(args...)) $e2 end @@ -421,7 +421,7 @@ p_constraint!(p, ocp, e1, e2, e3, label = gensym(); log = false) = begin __t_dep(p) && push!(args, p.t) push!(args, xt, ut) __v_dep(p) && push!(args, p.v) - quote + quote # todo function $gs($(args...)) $ee2 end @@ -448,7 +448,7 @@ p_dynamics!(p, ocp, x, t, e, label = nothing; log = false) = begin p.t_dep = p.t_dep || has(e, t) gs = gensym() args = [ ]; __t_dep(p) && push!(args, p.t); push!(args, xt, ut); __v_dep(p) && push!(args, p.v) - code = quote + code = quote # todo function $gs($(args...)) $e end @@ -469,7 +469,7 @@ p_lagrange!(p, ocp, e, type; log = false) = begin ttype = QuoteNode(type) gs = gensym() args = [ ]; __t_dep(p) && push!(args, p.t); push!(args, xt, ut); __v_dep(p) && push!(args, p.v) - code = quote + code = quote # todo function $gs($(args...)) $e end @@ -495,7 +495,7 @@ p_mayer!(p, ocp, e, type; log = false) = begin e = replace_call(e, p.x, p.tf, xf) ttype = QuoteNode(type) args = [ x0, xf ]; __v_dep(p) && push!(args, p.v) - code = quote + code = quote # todo function $gs($(args...)) $e end @@ -528,7 +528,7 @@ p_bolza!(p, ocp, e1, e2, type; log = false) = begin push!(args2, xt, ut) __v_dep(p) && push!(args2, p.v) ttype = QuoteNode(type) - code = quote + code = quote # todo function $gs1($(args1...)) $e1 end @@ -608,16 +608,7 @@ macro def(e) esc(code) end -macro __def(e) # todo: remove after in place test - ocp = gensym() - code = quote - @def $ocp $e false true # Force in place - $ocp - end - esc(code) -end - -macro def(ocp, e, log = false, in_place = false) # todo: default to in_place +macro def(ocp, e, log = false) try p0 = ParsingInfo() parse!(p0, ocp, e; log = false) # initial pass to get the dependencies (time and variable) @@ -625,6 +616,7 @@ macro def(ocp, e, log = false, in_place = false) # todo: default to in_place p.t_dep = p0.t_dep p.v = p0.v code = parse!(p, ocp, e; log = log) + in_place = False # todo: change to True for in place init = @match (__t_dep(p), __v_dep(p)) begin (false, false) => :($ocp = __OCPModel(; in_place = $in_place)) (true, false) => :($ocp = __OCPModel(autonomous = false; in_place = $in_place)) diff --git a/src/utils.jl b/src/utils.jl index 8487e24..bdf4096 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -252,15 +252,15 @@ __view(x::AbstractVector, rg::AbstractRange) = view(x, rg) # Allows StepRange $(TYPEDSIGNATURES) Tranform in place function to out of place. Pass the result size and type (default = `Float64`). -Return a scalar when the result has size one. +Return a scalar when the result has size one. If `f!` is `nothing`, return `nothing`. """ function to_out_of_place(f!, n; T = Float64) - function f(x...) + function f(args...; kwargs...) r = zeros(T, n) - f!(r, x...) + f!(r, args...; kwargs...) return n == 1 ? r[1] : r end - return f + return isnothing(f!) ? nothing : f end # Adapt getters to test in place diff --git a/test/test_utils.jl b/test/test_utils.jl index cc917e7..2d34ffa 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -94,5 +94,13 @@ function test_utils() end @test CTBase.to_out_of_place(f2!, 1; T = Int32)(1, 2) == 3 + function f3!(r, x; y = 1) + r[:] .= x + y + return nothing + end + @test CTBase.to_out_of_place(f3!, 1; T = Int32)(1; y = 2) == 3 + + @test isnothing( CTBase.to_out_of_place(nothing, 1) ) + end end \ No newline at end of file