From 4b7b0b1bc3c22e1fa03ee63dfd001fb07565406d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Dec 2023 19:44:49 -0500 Subject: [PATCH 1/7] Add a nlls trait to BVProblem --- src/SciMLBase.jl | 2 +- src/problems/bvp_problems.jl | 70 ++++++++++++++++++++++++++++++----- src/scimlfunctions.jl | 72 ++++++++++++++++++++++++------------ 3 files changed, 109 insertions(+), 35 deletions(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 0736503d1..122d6208d 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -180,7 +180,7 @@ $(TYPEDEF) Base for types which define BVP problems. """ -abstract type AbstractBVProblem{uType, tType, isinplace} <: +abstract type AbstractBVProblem{uType, tType, isinplace, nlls} <: AbstractODEProblem{uType, tType, isinplace} end """ diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 97bb0f9ba..351cf5a44 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -11,7 +11,7 @@ struct TwoPointBVProblem{iip} end # The iip is needed to make type stable constr @doc doc""" Defines an BVP problem. -Documentation Page: https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/ +Documentation Page: [https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/](https://docs.sciml.ai/DiffEqDocs/stable/types/bvp_types/) ## Mathematical Specification of a BVP Problem @@ -41,16 +41,16 @@ u(t_f) = b ### Constructors ```julia -TwoPointBVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...) -BVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...) +TwoPointBVProblem{isinplace}(f, bc, u0, tspan, p=NullParameters(); kwargs...) +BVProblem{isinplace}(f, bc, u0, tspan, p=NullParameters(); kwargs...) ``` -or if we have an initial guess function `initialGuess(t)` for the given BVP, +or if we have an initial guess function `initialGuess(p, t)` for the given BVP, we can pass the initial guess to the problem constructors: ```julia -TwoPointBVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...) -BVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...) +TwoPointBVProblem{isinplace}(f, bc, initialGuess, tspan, p=NullParameters(); kwargs...) +BVProblem{isinplace}(f, bc, initialGuess, tspan, p=NullParameters(); kwargs...) ``` For any BVP problem type, `bc` must be inplace if `f` is inplace. Otherwise it must be @@ -104,9 +104,18 @@ every solve call. * `tspan`: The timespan for the problem. * `p`: The parameters for the problem. Defaults to `NullParameters` * `kwargs`: The keyword arguments passed onto the solves. + +### Special Keyword Arguments + +- `nlls`: Specify that the BVP is a nonlinear least squares problem. Use `Val(true)` or + `Val(false)` for type stability. By default this is automatically inferred based on the + size of the input and outputs, however this is type unstable for any array type that + doesn't store array size as part of type information. Note that if problem is inplace + and `bcresid_prototype` in BVPFunction is not specified, then `nlls` is assumed to be + `false`. """ -struct BVProblem{uType, tType, isinplace, P, F, PT, K} <: - AbstractBVProblem{uType, tType, isinplace} +struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: + AbstractBVProblem{uType, tType, isinplace, nlls} f::F u0::uType tspan::tType @@ -115,18 +124,50 @@ struct BVProblem{uType, tType, isinplace, P, F, PT, K} <: kwargs::K @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan, - p = NullParameters(); problem_type = nothing, kwargs...) where {iip, TP} + p = NullParameters(); problem_type=nothing, nlls=nothing, + kwargs...) where {iip, TP} _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) prob_type = TP ? TwoPointBVProblem{iip}() : StandardBVProblem() + # Needed to ensure that `problem_type` doesn't get passed in kwargs if problem_type === nothing problem_type = prob_type else @assert prob_type===problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use." end - return new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f), + + if nlls === nothing + # Try to infer it + if problem_type isa TwoPointBVProblem + if f.bcresid_prototype !== nothing + l1, l2 = length(f.bcresid_prototype[1]), length(f.bcresid_prototype[2]) + _nlls = l1 + l2 != length(_u0) + else + # iip without bcresid_prototype is not possible + if !iip + l1 = length(f.bc[1](u0, p)) + l2 = length(f.bc[2](u0, p)) + _nlls = l1 + l2 != length(_u0) + end + end + else + if f.bcresid_prototype !== nothing + _nlls = length(f.bcresid_prototype) != length(_u0) + else + if iip + _nlls = false # Should we assume `true` instead? + else + _nlls = length(f.bc(FFakeSolutionObject(u0), p, tspan)) != length(_u0) + end + end + end + else + _nlls = _unwrap_val(nlls) + end + + return new{typeof(_u0), typeof(_tspan), iip, _nlls, typeof(p), typeof(f), typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs) end @@ -135,6 +176,15 @@ struct BVProblem{uType, tType, isinplace, P, F, PT, K} <: end end +struct FakeSolutionObject{U} + u::U +end + +(sol::FakeSolutionObject)(t) = sol.u +Base.getindex(sol::FakeSolutionObject, i::Int) = sol.u + +TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2 + function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 4) return BVProblem{iip}(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 081fb6668..567f53abf 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -1929,13 +1929,20 @@ $(TYPEDEF) A representation of a BVP function `f`, defined by: ```math -\frac{du}{dt}=f(u,p,t) +\frac{du}{dt} = f(u, p, t) ``` and the constraints: ```math -\frac{du}{dt}=g(u,p,t) +g(u, p, t) = 0 +``` + +If the size of `g(u, p, t)` is different from the size of `u`, then the constraints are +interpreted as a least squares problem, i.e. the objective function is: + +```math +\min_{u} \| g_i(u, p, t) \|^2 ``` and all of its related functions, such as the Jacobian of `f`, its gradient @@ -1943,21 +1950,25 @@ with respect to time, and more. For all cases, `u0` is the initial condition, `p` are the parameters, and `t` is the independent variable. ```julia -BVPFunction{iip,specialize}(f, bc; - mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, - analytic = __has_analytic(f) ? f.analytic : nothing, - tgrad= __has_tgrad(f) ? f.tgrad : nothing, - jac = __has_jac(f) ? f.jac : nothing, - bcjac = __has_jac(bc) ? bc.jac : nothing, - jvp = __has_jvp(f) ? f.jvp : nothing, - vjp = __has_vjp(f) ? f.vjp : nothing, - jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, - bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing, - sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, - paramjac = __has_paramjac(f) ? f.paramjac : nothing, - colorvec = __has_colorvec(f) ? f.colorvec : nothing, - bccolorvec = __has_colorvec(f) ? bc.colorvec : nothing, - sys = __has_sys(f) ? f.sys : nothing) +BVPFunction{iip, specialize}(f, bc; + mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I, + analytic = __has_analytic(f) ? f.analytic : nothing, + tgrad= __has_tgrad(f) ? f.tgrad : nothing, + jac = __has_jac(f) ? f.jac : nothing, + bcjac = __has_jac(bc) ? bc.jac : nothing, + jvp = __has_jvp(f) ? f.jvp : nothing, + vjp = __has_vjp(f) ? f.vjp : nothing, + jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, + bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing, + sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype, + paramjac = __has_paramjac(f) ? f.paramjac : nothing, + syms = nothing, + indepsym= nothing, + paramsyms = nothing, + colorvec = __has_colorvec(f) ? f.colorvec : nothing, + bccolorvec = __has_colorvec(f) ? bc.colorvec : nothing, + sys = __has_sys(f) ? f.sys : nothing, + twopoint::Union{Val, Bool} = Val(false) ``` Note that both the function `f` and boundary condition `bc` are required. `f` should @@ -1985,7 +1996,7 @@ the usage of `f` and `bc`. These include: sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. The default is `nothing`, which means a dense Jacobian. - `bcjac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example, - if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used + if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used as the prototype and integrators will specialize on this structure where possible. Non-structured sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian. The default is `nothing`, which means a dense Jacobian. @@ -2003,6 +2014,11 @@ the usage of `f` and `bc`. These include: internally computed on demand when required. The cost of this operation is highly dependent on the sparsity pattern. +Additional Options: + +- `twopoint`: Specify that the BVP is a two-point boundary value problem. Use `Val(true)` or + `Val(false)` for type stability. + ## iip: In-Place vs Out-Of-Place For more details on this argument, see the ODEFunction documentation. @@ -2016,8 +2032,8 @@ For more details on this argument, see the ODEFunction documentation. The fields of the BVPFunction type directly match the names of the inputs. """ struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, - JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV, - SYS} <: AbstractBVPFunction{iip, twopoint} + JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV, + SYS} <: AbstractBVPFunction{iip, twopoint} f::F bc::BF mass_matrix::TMM @@ -2321,7 +2337,11 @@ function ODEFunction{iip, specialize}(f; typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, +<<<<<<< HEAD observed, _colorvec, sys, initializeprob, initializeprobmap) +======= + observed, _colorvec, sys) +>>>>>>> 65d8f530 (Add a nlls trait to BVProblem) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2334,7 +2354,11 @@ function ODEFunction{iip, specialize}(f; typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, +<<<<<<< HEAD observed, _colorvec, sys, initializeprob, initializeprobmap) +======= + observed, _colorvec, sys) +>>>>>>> 65d8f530 (Add a nlls trait to BVProblem) end end @@ -3801,7 +3825,7 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; _f = prepare_function(f) - sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + sys = something(sys, SymbolCache(syms, paramsyms, indepsym)) if specialize === NoSpecialize BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, @@ -3813,9 +3837,9 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; sparsity, Wfact, Wfact_t, paramjac, observed, _colorvec, _bccolorvec, sys) else - BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), typeof(mass_matrix), - typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp), - typeof(vjp), typeof(jac_prototype), + BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), + typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), + typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}( From d55339bde66fa81cc751af660da42c7dad377cfd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Dec 2023 10:09:39 -0500 Subject: [PATCH 2/7] Update src/problems/bvp_problems.jl Co-authored-by: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> --- src/problems/bvp_problems.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 351cf5a44..0ab418630 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -159,7 +159,7 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: if iip _nlls = false # Should we assume `true` instead? else - _nlls = length(f.bc(FFakeSolutionObject(u0), p, tspan)) != length(_u0) + _nlls = length(f.bc(FakeSolutionObject(u0), p, tspan)) != length(_u0) end end end From 94f607fb5f91deba3418084f940e0a7c41424699 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Dec 2023 10:12:41 -0500 Subject: [PATCH 3/7] remake bvp correctly for nlls --- src/remake.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index c0f05dc8b..66b5c3ec8 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -128,16 +128,15 @@ end Remake the given `BVProblem`. """ -function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing, - p = missing, kwargs = missing, problem_type = missing, interpret_symbolicmap = true, _kwargs...) +function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing, + u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing, + _kwargs...) where {uType, tType, iip, nlls} if tspan === missing tspan = prob.tspan end u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) - iip = isinplace(prob) - if problem_type === missing problem_type = prob.problem_type end @@ -170,9 +169,10 @@ function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan end if kwargs === missing - BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, prob.kwargs..., _kwargs...) + BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls=Val(nlls), prob.kwargs..., + _kwargs...) else - BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, kwargs...) + BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls=Val(nlls), kwargs...) end end From ed752b19044b1055cadd7481bab6bce2f58a1748 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Dec 2023 11:41:03 -0500 Subject: [PATCH 4/7] Add original and resid to ODEsolution (used for BVPs) --- src/problems/bvp_problems.jl | 3 +++ src/solutions/ode_solutions.jl | 41 ++++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 0ab418630..f470d7070 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -181,6 +181,9 @@ struct FakeSolutionObject{U} end (sol::FakeSolutionObject)(t) = sol.u +Base.length(::FakeSolutionObject) = 1 +Base.firstindex(::FakeSolutionObject) = 1 +Base.lastindex(::FakeSolutionObject) = 1 Base.getindex(sol::FakeSolutionObject, i::Int) = sol.u TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2 diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index f0cf4a3c0..e091fa98e 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -106,7 +106,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S, - AC <: Union{Nothing, Vector{Int}}} <: + AC <: Union{Nothing, Vector{Int}}, R, O} <: AbstractODESolution{T, N, uType} u::uType u_analytic::uType2 @@ -121,6 +121,8 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S, stats::S alg_choice::AC retcode::ReturnCode.T + resid::R + original::O end Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol) @@ -133,13 +135,15 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy return getfield(x, s) end +# FIXME: Remove the defaults for resid and original on a breaking release function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, - tslocation, stats, alg_choice, retcode) where {T, N} + tslocation, stats, alg_choice, retcode, resid = nothing, + original = nothing) where {T, N} return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), - typeof(stats), - typeof(alg_choice)}(u, u_analytic, errors, t, k, prob, alg, interp, - dense, tslocation, stats, alg_choice, retcode) + typeof(stats), typeof(alg_choice), typeof(resid), + typeof(original)}(u, u_analytic, errors, t, k, prob, alg, interp, + dense, tslocation, stats, alg_choice, retcode, resid, original) end function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, @@ -232,6 +236,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, alg_choice = nothing, interp = LinearInterpolation(t, u), retcode = ReturnCode.Default, destats = missing, stats = nothing, + resid = nothing, original = nothing, kwargs...) T = eltype(eltype(u)) @@ -271,7 +276,9 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, 0, stats, alg_choice, - retcode) + retcode, + resid, + original) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, dense_errors = dense_errors) @@ -289,7 +296,9 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, 0, stats, alg_choice, - retcode) + retcode, + resid, + original) end end @@ -346,7 +355,9 @@ function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N} sol.tslocation, sol.stats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.resid, + sol.original) end function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} @@ -362,7 +373,9 @@ function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} sol.tslocation, sol.stats, sol.alg_choice, - retcode) + retcode, + sol.resid, + sol.original) end function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N} @@ -378,7 +391,9 @@ function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N tslocation, sol.stats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.resid, + sol.original) end function solution_slice(sol::ODESolution{T, N}, I) where {T, N} @@ -394,7 +409,9 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N} sol.tslocation, sol.stats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.resid, + sol.original) end function sensitivity_solution(sol::ODESolution, u, t) @@ -414,5 +431,5 @@ function sensitivity_solution(sol::ODESolution, u, t) sol.k, sol.prob, sol.alg, interp, sol.dense, sol.tslocation, - sol.stats, sol.alg_choice, sol.retcode) + sol.stats, sol.alg_choice, sol.retcode, sol.resid, sol.original) end From 01a789bdfded2de361c40fde8fc65ac502de64be Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Dec 2023 14:48:46 -0500 Subject: [PATCH 5/7] Add a function to check nlls trait --- src/problems/bvp_problems.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index f470d7070..01fdd2a6c 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -176,6 +176,16 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: end end +""" + isnonlinearleastsquares(prob::BVProblem) + +Returns `true` if the underlying problem is a nonlinear least squares problem. +""" +@inline function isnonlinearleastsquares(::BVProblem{uType, + tType, iip, nlls}) where {uType, tType, iip, nlls} + return nlls +end + struct FakeSolutionObject{U} u::U end From 0b4eeedf88fd36b301be299a327b535c468f278e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Dec 2023 17:32:48 -0500 Subject: [PATCH 6/7] Dont infer what is hard to infer --- src/problems/bvp_problems.jl | 35 +++-------------------- src/remake.jl | 25 +++++++++------- src/scimlfunctions.jl | 16 ++++------- src/solutions/ode_solutions.jl | 4 +-- test/downstream/modelingtoolkit_remake.jl | 22 +++++++------- test/remake_tests.jl | 23 +++++++-------- test/traits.jl | 2 +- 7 files changed, 48 insertions(+), 79 deletions(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 01fdd2a6c..e84073adf 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -110,9 +110,8 @@ every solve call. - `nlls`: Specify that the BVP is a nonlinear least squares problem. Use `Val(true)` or `Val(false)` for type stability. By default this is automatically inferred based on the size of the input and outputs, however this is type unstable for any array type that - doesn't store array size as part of type information. Note that if problem is inplace - and `bcresid_prototype` in BVPFunction is not specified, then `nlls` is assumed to be - `false`. + doesn't store array size as part of type information. If we can't reliably infer this, + we set it to `Nothing`. Downstreams solvers must be setup to deal with this case. """ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: AbstractBVProblem{uType, tType, isinplace, nlls} @@ -124,7 +123,7 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: kwargs::K @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan, - p = NullParameters(); problem_type=nothing, nlls=nothing, + p = NullParameters(); problem_type = nothing, nlls = nothing, kwargs...) where {iip, TP} _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) @@ -156,11 +155,7 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: if f.bcresid_prototype !== nothing _nlls = length(f.bcresid_prototype) != length(_u0) else - if iip - _nlls = false # Should we assume `true` instead? - else - _nlls = length(f.bc(FakeSolutionObject(u0), p, tspan)) != length(_u0) - end + _nlls = Nothing # Cannot reliably infer end end else @@ -176,28 +171,6 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: end end -""" - isnonlinearleastsquares(prob::BVProblem) - -Returns `true` if the underlying problem is a nonlinear least squares problem. -""" -@inline function isnonlinearleastsquares(::BVProblem{uType, - tType, iip, nlls}) where {uType, tType, iip, nlls} - return nlls -end - -struct FakeSolutionObject{U} - u::U -end - -(sol::FakeSolutionObject)(t) = sol.u -Base.length(::FakeSolutionObject) = 1 -Base.firstindex(::FakeSolutionObject) = 1 -Base.lastindex(::FakeSolutionObject) = 1 -Base.getindex(sol::FakeSolutionObject, i::Int) = sol.u - -TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2 - function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 4) return BVProblem{iip}(BVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) diff --git a/src/remake.jl b/src/remake.jl index 66b5c3ec8..ff64ad592 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -45,7 +45,8 @@ function isrecompile(prob::ODEProblem{iip}) where {iip} (prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true end -function remake(prob::AbstractSciMLProblem; u0 = missing, p = missing, interpret_symbolicmap = true, kwargs...) +function remake(prob::AbstractSciMLProblem; u0 = missing, + p = missing, interpret_symbolicmap = true, kwargs...) u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) _remake_internal(prob; kwargs..., u0, p) end @@ -54,7 +55,8 @@ function remake(prob::AbstractNoiseProblem; kwargs...) _remake_internal(prob; kwargs...) end -function remake(prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...) +function remake( + prob::AbstractIntegralProblem; p = missing, interpret_symbolicmap = true, kwargs...) p = updated_p(prob, p; interpret_symbolicmap) _remake_internal(prob; kwargs..., p) end @@ -129,8 +131,8 @@ end Remake the given `BVProblem`. """ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = missing, - u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing, - _kwargs...) where {uType, tType, iip, nlls} + u0 = missing, tspan = missing, p = missing, kwargs = missing, problem_type = missing, + interpret_symbolicmap = true, _kwargs...) where {uType, tType, iip, nlls} if tspan === missing tspan = prob.tspan end @@ -169,10 +171,11 @@ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = miss end if kwargs === missing - BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls=Val(nlls), prob.kwargs..., + BVProblem{iip}( + _f, bc, u0, tspan, p; problem_type, nlls = Val(nlls), prob.kwargs..., _kwargs...) else - BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls=Val(nlls), kwargs...) + BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, nlls = Val(nlls), kwargs...) end end @@ -254,7 +257,6 @@ function remake(prob::OptimizationProblem; kwargs = missing, interpret_symbolicmap = true, _kwargs...) - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap) if f === missing f = prob.f @@ -393,10 +395,11 @@ function updated_p(prob, p; interpret_symbolicmap = true) end if eltype(p) <: Pair if interpret_symbolicmap - has_sys(prob.f) || throw(ArgumentError("This problem does not support symbolic maps with " * - "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * - "with the `p` keyword argument as a vector of values (paying attention to" * - "parameter order) or pass `interpret_symbolicmap = false` as a keyword argument")) + has_sys(prob.f) || + throw(ArgumentError("This problem does not support symbolic maps with " * + "`remake`, i.e. it does not have a symbolic origin. Please use `remake`" * + "with the `p` keyword argument as a vector of values (paying attention to" * + "parameter order) or pass `interpret_symbolicmap = false` as a keyword argument")) else return p end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 567f53abf..0df138559 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2032,8 +2032,8 @@ For more details on this argument, see the ODEFunction documentation. The fields of the BVPFunction type directly match the names of the inputs. """ struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, - JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV, - SYS} <: AbstractBVPFunction{iip, twopoint} + JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV, + SYS} <: AbstractBVPFunction{iip, twopoint} f::F bc::BF mass_matrix::TMM @@ -2337,11 +2337,7 @@ function ODEFunction{iip, specialize}(f; typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, -<<<<<<< HEAD observed, _colorvec, sys, initializeprob, initializeprobmap) -======= - observed, _colorvec, sys) ->>>>>>> 65d8f530 (Add a nlls trait to BVProblem) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2354,11 +2350,7 @@ function ODEFunction{iip, specialize}(f; typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, -<<<<<<< HEAD observed, _colorvec, sys, initializeprob, initializeprobmap) -======= - observed, _colorvec, sys) ->>>>>>> 65d8f530 (Add a nlls trait to BVProblem) end end @@ -3921,7 +3913,9 @@ end function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing) if sys === nothing && (syms !== nothing || paramsyms !== nothing || indepsym !== nothing) - Base.depwarn("The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead.", :syms) + Base.depwarn( + "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead.", + :syms) sys = SymbolCache(syms, paramsyms, indepsym) end return sys diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index e091fa98e..1e3c8d816 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -137,8 +137,8 @@ end # FIXME: Remove the defaults for resid and original on a breaking release function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, - tslocation, stats, alg_choice, retcode, resid = nothing, - original = nothing) where {T, N} + tslocation, stats, alg_choice, retcode, resid = nothing, + original = nothing) where {T, N} return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(alg_choice), typeof(resid), diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index f23648fd6..fdc883f71 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -101,11 +101,11 @@ sprob3 = remake(sprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update # NonlinearProblem @named ns = NonlinearSystem( - [0 ~ σ*(y-x), - 0 ~ x*(ρ-z)-y, - 0 ~ x*y - β*z], - [x,y,z], - [σ,ρ,β] + [0 ~ σ * (y - x), + 0 ~ x * (ρ - z) - y, + 0 ~ x * y - β * z], + [x, y, z], + [σ, ρ, β] ) ns = complete(ns) nlprob = NonlinearProblem(ns, u0, p) @@ -131,14 +131,14 @@ nlprob3 = remake(nlprob; u0 = [x => 3.0], p = [σ => 30.0]) # partial update @parameters β γ @variables S(t) I(t) R(t) -rate₁ = β*S*I +rate₁ = β * S * I affect₁ = [S ~ S - 1, I ~ I + 1] -rate₂ = γ*I +rate₂ = γ * I affect₂ = [I ~ I - 1, R ~ R + 1] -j₁ = ConstantRateJump(rate₁,affect₁) -j₂ = ConstantRateJump(rate₂,affect₂) -j₃ = MassActionJump(2*β+γ, [R => 1], [S => 1, R => -1]) -@named js = JumpSystem([j₁,j₂,j₃], t, [S,I,R], [β,γ]) +j₁ = ConstantRateJump(rate₁, affect₁) +j₂ = ConstantRateJump(rate₂, affect₂) +j₃ = MassActionJump(2 * β + γ, [R => 1], [S => 1, R => -1]) +@named js = JumpSystem([j₁, j₂, j₃], t, [S, I, R], [β, γ]) js = complete(js) u₀map = [S => 999, I => 1, R => 0.0] parammap = [β => 0.1 / 1000, γ => 0.01] diff --git a/test/remake_tests.jl b/test/remake_tests.jl index 6b0937297..e5d6f6600 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -2,14 +2,14 @@ using SciMLBase using SymbolicIndexingInterface # ODE -function lorenz!(du,u,p,t) - du[1] = p[1] * (u[2]-u[1]) - du[2] = u[1]*(p[2]-u[3]) - u[2] - du[3] = u[1]*u[2] - p[3]*u[3] +function lorenz!(du, u, p, t) + du[1] = p[1] * (u[2] - u[1]) + du[2] = u[1] * (p[2] - u[3]) - u[2] + du[3] = u[1] * u[2] - p[3] * u[3] end -u0 = [1.0;0.0;0.0] -tspan = (0.0,100.0) -p = [10.0, 28.0, 8/3] +u0 = [1.0; 0.0; 0.0] +tspan = (0.0, 100.0) +p = [10.0, 28.0, 8 / 3] fn = ODEFunction(lorenz!; sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) prob = ODEProblem(fn, u0, tspan, p) @@ -21,7 +21,7 @@ prob = ODEProblem(fn, u0, tspan, p) @test remake(prob; p = [:a => 11.0, :c => 13.0, :b => 12.0]).p == [11.0, 12.0, 13.0] @test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) @test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] -@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 8/3] +@test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 8 / 3] # BVP g = 9.81 @@ -39,7 +39,7 @@ function bc1!(residual, u, p, t) end u0 = [pi / 2, pi / 2] p = [g, L] -fn = BVPFunction(simplependulum!, bc1!; sys = SymbolCache([:x, :y], [:a, :b], :t) ) +fn = BVPFunction(simplependulum!, bc1!; sys = SymbolCache([:x, :y], [:a, :b], :t)) prob = BVProblem(fn, u0, tspan, p) @test remake(prob).u0 == u0 @@ -80,11 +80,11 @@ prob = SDEProblem(fn, u0, tspan, p) # OptimizationProblem function loss(u, p) - return (p[1] - u[1]) ^ 2 + p[2] * (u[2] - u[1] ^ 2) ^ 2 + return (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2 end u0 = [1.0, 2.0] p = [1.0, 100.0] -fn = OptimizationFunction(loss; sys = SymbolCache([:x, :y], [:a, :b], :t) ) +fn = OptimizationFunction(loss; sys = SymbolCache([:x, :y], [:a, :b], :t)) prob = OptimizationProblem(fn, u0, p) @test remake(prob).u0 == u0 @test remake(prob).p == p @@ -128,4 +128,3 @@ prob = NonlinearLeastSquaresProblem(fn, u0, p) @test remake(prob; p = (11.0, 12.0, 13)).p == (11.0, 12.0, 13) @test remake(prob; u0 = [:x => 2.0]).u0 == [2.0, 0.0, 0.0] @test remake(prob; p = [:b => 11.0]).p == [10.0, 11.0, 26.0] - diff --git a/test/traits.jl b/test/traits.jl index 8b6ebe46a..b08e13a72 100644 --- a/test/traits.jl +++ b/test/traits.jl @@ -10,7 +10,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D @test !SciMLBase.Tables.isrowtable(SciMLBase.QuadratureSolution) @test !SciMLBase.Tables.isrowtable(SciMLBase.OptimizationSolution) -@variables x(t)=1 +@variables x(t) = 1 eqs = [D(x) ~ -x] @named sys = ODESystem(eqs, t) sys = complete(sys) From 7ff95fcfdb7debe512d3c21d8289b5653971178f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 24 Mar 2024 13:39:35 -0400 Subject: [PATCH 7/7] Specialize handling of functions as initial condition --- src/problems/bvp_problems.jl | 18 ++++++++++++++---- src/solutions/ode_solutions.jl | 4 ++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index e84073adf..0c84fb4b9 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -138,22 +138,32 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: end if nlls === nothing + if !hasmethod(length, Tuple{typeof(_u0)}) + # If _u0 is a function for initial guess we won't be able to infer + __u0 = _u0 isa Function ? + (hasmethod(_u0, Tuple{typeof(p), typeof(first(_tspan))}) ? + _u0(p, first(_tspan)) : _u0(first(_tspan))) : nothing + else + __u0 = _u0 + end # Try to infer it - if problem_type isa TwoPointBVProblem + if __u0 isa Nothing + _nlls = Nothing + elseif problem_type isa TwoPointBVProblem if f.bcresid_prototype !== nothing l1, l2 = length(f.bcresid_prototype[1]), length(f.bcresid_prototype[2]) - _nlls = l1 + l2 != length(_u0) + _nlls = l1 + l2 != length(__u0) else # iip without bcresid_prototype is not possible if !iip l1 = length(f.bc[1](u0, p)) l2 = length(f.bc[2](u0, p)) - _nlls = l1 + l2 != length(_u0) + _nlls = l1 + l2 != length(__u0) end end else if f.bcresid_prototype !== nothing - _nlls = length(f.bcresid_prototype) != length(_u0) + _nlls = length(f.bcresid_prototype) != length(__u0) else _nlls = Nothing # Cannot reliably infer end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 1e3c8d816..6ad75a3be 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -242,6 +242,10 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, if prob.u0 === nothing N = 2 + elseif prob isa BVProblem && !hasmethod(size, Tuple{typeof(prob.u0)}) + __u0 = hasmethod(prob.u0, Tuple{typeof(prob.p), typeof(first(prob.tspan))}) ? + prob.u0(prob.p, first(prob.tspan)) : prob.u0(first(prob.tspan)) + N = length((size(__u0)..., length(u))) else N = length((size(prob.u0)..., length(u))) end