Skip to content

Commit

Permalink
Dont infer what is hard to infer
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 20, 2024
1 parent 01a789b commit 0b4eeed
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 79 deletions.
35 changes: 4 additions & 31 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ every solve call.
- `nlls`: Specify that the BVP is a nonlinear least squares problem. Use `Val(true)` or
`Val(false)` for type stability. By default this is automatically inferred based on the
size of the input and outputs, however this is type unstable for any array type that
doesn't store array size as part of type information. Note that if problem is inplace
and `bcresid_prototype` in BVPFunction is not specified, then `nlls` is assumed to be
`false`.
doesn't store array size as part of type information. If we can't reliably infer this,
we set it to `Nothing`. Downstreams solvers must be setup to deal with this case.
"""
struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
AbstractBVProblem{uType, tType, isinplace, nlls}
Expand All @@ -124,7 +123,7 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
kwargs::K

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
p = NullParameters(); problem_type=nothing, nlls=nothing,
p = NullParameters(); problem_type = nothing, nlls = nothing,
kwargs...) where {iip, TP}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
Expand Down Expand Up @@ -156,11 +155,7 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
if f.bcresid_prototype !== nothing
_nlls = length(f.bcresid_prototype) != length(_u0)
else
if iip
_nlls = false # Should we assume `true` instead?
else
_nlls = length(f.bc(FakeSolutionObject(u0), p, tspan)) != length(_u0)
end
_nlls = Nothing # Cannot reliably infer
end
end
else
Expand All @@ -176,28 +171,6 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
end
end

"""
isnonlinearleastsquares(prob::BVProblem)
Returns `true` if the underlying problem is a nonlinear least squares problem.
"""
@inline function isnonlinearleastsquares(::BVProblem{uType,
tType, iip, nlls}) where {uType, tType, iip, nlls}
return nlls
end

struct FakeSolutionObject{U}
u::U
end

(sol::FakeSolutionObject)(t) = sol.u
Base.length(::FakeSolutionObject) = 1
Base.firstindex(::FakeSolutionObject) = 1
Base.lastindex(::FakeSolutionObject) = 1
Base.getindex(sol::FakeSolutionObject, i::Int) = sol.u

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
return BVProblem{iip}(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...)
Expand Down
25 changes: 14 additions & 11 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ function isrecompile(prob::ODEProblem{iip}) where {iip}
(prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
end

function remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, kwargs...)
function remake(prob::AbstractSciMLProblem; u0 = missing,

Check warning on line 48 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L48

Added line #L48 was not covered by tests
p = missing, interpret_symbolicmap = true, kwargs...)
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., u0, p)
end
Expand All @@ -54,7 +55,8 @@ function remake(prob::AbstractNoiseProblem; kwargs...)
_remake_internal(prob; kwargs...)
end

function remake(prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
function remake(

Check warning on line 58 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L58

Added line #L58 was not covered by tests
prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...)
p = updated_p(prob, p; interpret_symbolicmap)
_remake_internal(prob; kwargs..., p)
end
Expand Down Expand Up @@ -129,8 +131,8 @@ end
Remake the given `BVProblem`.
"""
function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing,
u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing,
_kwargs...) where {uType, tType, iip, nlls}
u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing,
interpret_symbolicmap = true, _kwargs...) where {uType, tType, iip, nlls}
if tspan === missing
tspan = prob.tspan
end
Expand Down Expand Up @@ -169,10 +171,11 @@ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = miss
end

if kwargs === missing
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls=Val(nlls), prob.kwargs...,
BVProblem{iip}(
_f, bc, u0, tspan, p; problem_type, nlls = Val(nlls), prob.kwargs...,
_kwargs...)
else
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls=Val(nlls), kwargs...)
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls = Val(nlls), kwargs...)

Check warning on line 178 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L178

Added line #L178 was not covered by tests
end
end

Expand Down Expand Up @@ -254,7 +257,6 @@ function remake(prob::OptimizationProblem;
kwargs = missing,
interpret_symbolicmap = true,
_kwargs...)

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap)
if f === missing
f = prob.f
Expand Down Expand Up @@ -393,10 +395,11 @@ function updated_p(prob, p; interpret_symbolicmap = true)
end
if eltype(p) <: Pair
if interpret_symbolicmap
has_sys(prob.f) || throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
has_sys(prob.f) ||
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
"with the `p` keyword argument as a vector of values (paying attention to" *
"parameter order) or pass `interpret_symbolicmap = false` as a keyword argument"))
else
return p
end
Expand Down
16 changes: 5 additions & 11 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2032,8 +2032,8 @@ For more details on this argument, see the ODEFunction documentation.
The fields of the BVPFunction type directly match the names of the inputs.
"""
struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV,
SYS} <: AbstractBVPFunction{iip, twopoint}
JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV,
SYS} <: AbstractBVPFunction{iip, twopoint}
f::F
bc::BF
mass_matrix::TMM
Expand Down Expand Up @@ -2337,11 +2337,7 @@ function ODEFunction{iip, specialize}(f;
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
<<<<<<< HEAD
observed, _colorvec, sys, initializeprob, initializeprobmap)
=======
observed, _colorvec, sys)
>>>>>>> 65d8f530 (Add a nlls trait to BVProblem)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2354,11 +2350,7 @@ function ODEFunction{iip, specialize}(f;
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
<<<<<<< HEAD
observed, _colorvec, sys, initializeprob, initializeprobmap)
=======
observed, _colorvec, sys)
>>>>>>> 65d8f530 (Add a nlls trait to BVProblem)
end
end

Expand Down Expand Up @@ -3921,7 +3913,9 @@ end
function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
if sys === nothing &&
(syms !== nothing || paramsyms !== nothing || indepsym !== nothing)
Base.depwarn("The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead.", :syms)
Base.depwarn(

Check warning on line 3916 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L3916

Added line #L3916 was not covered by tests
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead.",
:syms)
sys = SymbolCache(syms, paramsyms, indepsym)
end
return sys
Expand Down
4 changes: 2 additions & 2 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ end

# FIXME: Remove the defaults for resid and original on a breaking release
function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense,
tslocation, stats, alg_choice, retcode, resid = nothing,
original = nothing) where {T, N}
tslocation, stats, alg_choice, retcode, resid = nothing,
original = nothing) where {T, N}
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
typeof(k), typeof(prob), typeof(alg), typeof(interp),
typeof(stats), typeof(alg_choice), typeof(resid),
Expand Down
22 changes: 11 additions & 11 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ sprob3 = remake(sprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update

# NonlinearProblem
@named ns = NonlinearSystem(
[0 ~ σ*(y-x),
0 ~ x*-z)-y,
0 ~ x*y - β*z],
[x,y,z],
[σ,ρ,β]
[0 ~ σ * (y - x),
0 ~ x *- z) - y,
0 ~ x * y - β * z],
[x, y, z],
[σ, ρ, β]
)
ns = complete(ns)
nlprob = NonlinearProblem(ns, u0, p)
Expand All @@ -131,14 +131,14 @@ nlprob3 = remake(nlprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update

@parameters β γ
@variables S(t) I(t) R(t)
rate₁ = β*S*I
rate₁ = β * S * I
affect₁ = [S ~ S - 1, I ~ I + 1]
rate₂ = γ*I
rate₂ = γ * I
affect₂ = [I ~ I - 1, R ~ R + 1]
j₁ = ConstantRateJump(rate₁,affect₁)
j₂ = ConstantRateJump(rate₂,affect₂)
j₃ = MassActionJump(2*β+γ, [R => 1], [S => 1, R => -1])
@named js = JumpSystem([j₁,j₂,j₃], t, [S,I,R], [β,γ])
j₁ = ConstantRateJump(rate₁, affect₁)
j₂ = ConstantRateJump(rate₂, affect₂)
j₃ = MassActionJump(2 * β + γ, [R => 1], [S => 1, R => -1])
@named js = JumpSystem([j₁, j₂, j₃], t, [S, I, R], [β, γ])
js = complete(js)
u₀map = [S => 999, I => 1, R => 0.0]
parammap ==> 0.1 / 1000, γ => 0.01]
Expand Down
23 changes: 11 additions & 12 deletions test/remake_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ using SciMLBase
using SymbolicIndexingInterface

# ODE
function lorenz!(du,u,p,t)
du[1] = p[1] * (u[2]-u[1])
du[2] = u[1]*(p[2]-u[3]) - u[2]
du[3] = u[1]*u[2] - p[3]*u[3]
function lorenz!(du, u, p, t)
du[1] = p[1] * (u[2] - u[1])
du[2] = u[1] * (p[2] - u[3]) - u[2]
du[3] = u[1] * u[2] - p[3] * u[3]
end
u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
p = [10.0, 28.0, 8/3]
u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 100.0)
p = [10.0, 28.0, 8 / 3]
fn = ODEFunction(lorenz!; sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t))
prob = ODEProblem(fn, u0, tspan, p)

Expand All @@ -21,7 +21,7 @@ prob = ODEProblem(fn, u0, tspan, p)
@test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0]
@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13)
@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0]
@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 8/3]
@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 8 / 3]

# BVP
g = 9.81
Expand All @@ -39,7 +39,7 @@ function bc1!(residual, u, p, t)
end
u0 = [pi / 2, pi / 2]
p = [g, L]
fn = BVPFunction(simplependulum!, bc1!; sys = SymbolCache([:x, :y], [:a, :b], :t) )
fn = BVPFunction(simplependulum!, bc1!; sys = SymbolCache([:x, :y], [:a, :b], :t))
prob = BVProblem(fn, u0, tspan, p)

@test remake(prob).u0 == u0
Expand Down Expand Up @@ -80,11 +80,11 @@ prob = SDEProblem(fn, u0, tspan, p)

# OptimizationProblem
function loss(u, p)
return (p[1] - u[1]) ^ 2 + p[2] * (u[2] - u[1] ^ 2) ^ 2
return (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
end
u0 = [1.0, 2.0]
p = [1.0, 100.0]
fn = OptimizationFunction(loss; sys = SymbolCache([:x, :y], [:a, :b], :t) )
fn = OptimizationFunction(loss; sys = SymbolCache([:x, :y], [:a, :b], :t))
prob = OptimizationProblem(fn, u0, p)
@test remake(prob).u0 == u0
@test remake(prob).p == p
Expand Down Expand Up @@ -128,4 +128,3 @@ prob = NonlinearLeastSquaresProblem(fn, u0, p)
@test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13)
@test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0]
@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 26.0]

2 changes: 1 addition & 1 deletion test/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
@test !SciMLBase.Tables.isrowtable(SciMLBase.QuadratureSolution)
@test !SciMLBase.Tables.isrowtable(SciMLBase.OptimizationSolution)

@variables x(t)=1
@variables x(t) = 1
eqs = [D(x) ~ -x]
@named sys = ODESystem(eqs, t)
sys = complete(sys)
Expand Down

0 comments on commit 0b4eeed

Please sign in to comment.