From 9d4f60f4a37cb65339fa91c4570e4295af5dc9f0 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 27 Sep 2024 17:57:41 +0200 Subject: [PATCH] support forwardiff of states --- src/GraphDynamics.jl | 27 ++++++-- src/graph_solve.jl | 10 --- src/problems.jl | 34 ++++++++-- src/subsystems.jl | 58 +++++++++++++--- test/Project.toml | 2 + test/particle_osc_example.jl | 125 ++++++++++++++++++++--------------- 6 files changed, 173 insertions(+), 83 deletions(-) diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index 65a95e1..07bd3c3 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -41,6 +41,7 @@ export GraphSystem, ODEGraphSystem, SDEGraphSystem, + get_tag, get_states, get_params, ODEProblem, @@ -69,7 +70,8 @@ using SciMLBase: CallbackSet, VectorContinuousCallback, ContinuousCallback, - DiscreteCallback + DiscreteCallback, + remake using RecursiveArrayTools: ArrayPartition @@ -101,20 +103,31 @@ include("utils.jl") #---------------------------------------------------------- # API functions to be implemented by new Systems -struct SubsystemStates{Name, Eltype, States <: NamedTuple} <: AbstractVector{Eltype} +struct SubsystemStates{T, Eltype, States <: NamedTuple} <: AbstractVector{Eltype} states::States end -struct SubsystemParams{Name, Params <: NamedTuple} +struct SubsystemParams{T, Params <: NamedTuple} params::Params end -struct Subsystem{Name, Eltype, States, Params} - states::SubsystemStates{Name, Eltype, States} - params::SubsystemParams{Name, Params} +""" + Subsystem{T, Eltype, StateNT, ParamNT} + +A `Subsystem` struct describes a complete subcomponent to an `GraphSystem`. This stores a `SubsystemStates` to describe the continuous dynamical state of the subsystem, and a `GraphSystemParams` which describes various non-dynamical parameters of the subsystem. The type parameter `T` is the subsystem's \"tag\" which labels what sort of subsystem it is. + +See also `subsystem_differential`, `SubsystemStates`, `SubsystemParams`. + +For example, if we wanted to describe a system where one sub-component is a billiard ball, + + +""" +struct Subsystem{T, Eltype, States, Params} + states::SubsystemStates{T, Eltype, States} + params::SubsystemParams{T, Params} end -function get_name end +function get_tag end function get_params end function get_states end diff --git a/src/graph_solve.jl b/src/graph_solve.jl index 372bd9b..4a7cf5a 100644 --- a/src/graph_solve.jl +++ b/src/graph_solve.jl @@ -299,7 +299,6 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...) @nexprs $Len k -> begin M = connection_matrices[nc][k, i] if has_discrete_events(eltype(M)) - #tany(foo(Val(k), Val(i), Val(NConn), M, t), eachindex(states_partitioned[i])) && return true for j ∈ eachindex(states_partitioned[i]) for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) discrete_event_condition(Mlj, t) && return true @@ -313,15 +312,6 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...) end end -function foo(::Val{k}, ::Val{i}, ::Val{NConn}, M, t) where {i, NConn, k} - function f(j) - for (l, Mlj) ∈ maybe_sparse_enumerate_col(M, j) - discrete_event_condition(Mlj, t) && return true - end - false - end -end - function discrete_affect!(integrator) (;params_partitioned, state_types_val, connection_matrices) = integrator.p state_data = integrator.u.x diff --git a/src/problems.jl b/src/problems.jl index 02ea918..c09d8fb 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -6,7 +6,7 @@ end function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[]; scheduler=SerialScheduler(), tstops=Float64[], allow_nonconcrete=false, kwargs...) - nt = _problem(g, tspan; scheduler, allow_nonconcrete) + nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map) (; f, u, tspan, p, callback) = nt tstops = vcat(tstops, nt.tstops) prob = ODEProblem(f, u, tspan, p; callback, tstops, kwargs...) @@ -21,11 +21,18 @@ end function SciMLBase.SDEProblem(g::SDEGraphSystem, u0map, tspan, param_map=[]; scheduler=SerialScheduler(), tstops=Float64[], allow_nonconcrete=false, kwargs...) - nt = _problem(g, tspan; scheduler, allow_nonconcrete) + nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map) (; f, u, tspan, p, callback) = nt noise_rate_prototype = nothing # zeros(length(u)) # this'll need to change once we support correlated noise - SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...) + prob = SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...) + for (k, v) ∈ u0map + setu(prob, k)(prob, v) + end + for (k, v) ∈ param_map + setp(prob, k)(prob, v) + end + prob end Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV} @@ -35,7 +42,7 @@ Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV} state_types_val::STV end -function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete) +function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete, u0map, param_map) (; states_partitioned, params_partitioned, connection_matrices, @@ -43,6 +50,25 @@ function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete) composite_discrete_events_partitioned, composite_continuous_events_partitioned,) = g + total_eltype = let + states_eltype = mapreduce(promote_type, states_partitioned) do v + eltype(eltype(v)) + end + u0map_eltype = mapreduce(promote_type, u0map; init=Union{}) do (k, v) + typeof(v) + end + promote_type(states_eltype, u0map_eltype) + end + + re_eltype(s::SubsystemStates{T}) where {T} = convert(SubsystemStates{T, total_eltype}, s) + states_partitioned = map(states_partitioned) do v + if eltype(eltype(v)) <: total_eltype && eltype(eltype(v)) !== Union{} + v + else + re_eltype.(v) + end + end + length(states_partitioned) == length(params_partitioned) || error("Incompatible state and parameter lengths") for i ∈ eachindex(states_partitioned, params_partitioned) diff --git a/src/subsystems.jl b/src/subsystems.jl index 55a56fe..0b5c3cb 100644 --- a/src/subsystems.jl +++ b/src/subsystems.jl @@ -21,8 +21,8 @@ function ConstructionBase.setproperties(s::SubsystemParams{T}, patch::NamedTuple SubsystemParams{T}(props′) end -get_name(::SubsystemParams{Name}) where {Name} = Name -get_name(::Type{<:SubsystemParams{Name}}) where {Name} = Name +get_tag(::SubsystemParams{Name}) where {Name} = Name +get_tag(::Type{<:SubsystemParams{Name}}) where {Name} = Name Base.NamedTuple(p::SubsystemParams) = getfield(p, :params) Base.Tuple(s::SubsystemParams) = Tuple(getfield(s, :params)) Base.getproperty(p::SubsystemParams, prop::Symbol) = getproperty(NamedTuple(p), prop) @@ -43,6 +43,10 @@ end function SubsystemStates{Name}(nt::NamedTuple{state_names, NTuple{N, Eltype}}) where {Name, state_names, N, Eltype} SubsystemStates{Name, Eltype, typeof(nt)}(nt) end +function SubsystemStates{Name}(nt::NamedTuple{state_names, <:NTuple{N, Any}}) where {Name, state_names, N} + nt_promoted = NamedTuple{state_names}(promote(nt...)) + SubsystemStates{Name}(nt_promoted) +end function SubsystemStates{Name}(nt::NamedTuple{(), Tuple{}}) where {Name} SubsystemStates{Name, Union{}, NamedTuple{(), Tuple{}}}(nt) end @@ -81,8 +85,8 @@ function ConstructionBase.setproperties(s::SubsystemStates{T}, patch::NamedTuple SubsystemStates{T}(props′) end -get_name(::SubsystemStates{Name}) where {Name} = Name -get_name(::Type{<:SubsystemStates{Name}}) where {Name} = Name +get_tag(::SubsystemStates{Name}) where {Name} = Name +get_tag(::Type{<:SubsystemStates{Name}}) where {Name} = Name Base.NamedTuple(s::SubsystemStates) = getfield(s, :states) Base.Tuple(s::SubsystemStates) = Tuple(getfield(s, :states)) Base.getproperty(s::SubsystemStates, prop::Symbol) = getproperty(NamedTuple(s), prop) @@ -93,12 +97,25 @@ function state_ind(::Type{SubsystemStates{Name, Eltype, NamedTuple{names, Tup}}} i = findfirst(==(s), names) end +function Base.convert(::Type{SubsystemStates{Name, Eltype, NT}}, s::SubsystemStates{Name}) where {Name, Eltype, NT} + SubsystemStates{Name}(convert(NT, NamedTuple(s))) +end +function Base.convert(::Type{SubsystemStates{Name, Eltype}}, + s::SubsystemStates{Name, <:Any, <:NamedTuple{state_names}}) where {Name, Eltype, state_names} + nt = NamedTuple{state_names}(convert.(Eltype, Tuple(s))) + SubsystemStates{Name, Eltype, typeof(nt)}(nt) +end + #------------------------------------------------------------ # Subsystem function Subsystem{T}(;states, params) where {T} - ET = eltype(states) - Subsystem{T, ET, typeof(states), typeof(params)}(SubsystemStates{T}(states), SubsystemParams{T}(params)) + Subsystem{T}(SubsystemStates{T}(states), SubsystemParams{T}(params)) +end +function Subsystem{T}(states::SubsystemStates{T, Eltype, States}, + params::SubsystemParams{T, Params}) where {T, Eltype, States, Params} + Subsystem{T, Eltype, States, Params}(states, params) end + function Base.show(io::IO, sys::Subsystem{Name, Eltype}) where {Name, Eltype} print(io, "$Subsystem{$Name, $Eltype}(states = ", @@ -124,15 +141,40 @@ function ConstructionBase.setproperties(s::Subsystem{T, Eltype, States, Params}, Subsystem{T, Eltype, States, Params}(SubsystemStates{T}(states′), SubsystemParams{T}(params′)) end +function Base.convert(::Type{Subsystem{Name, Eltype, SNT, PNT}}, s::Subsystem{Name}) where {Name, Eltype, SNT, PNT} + Subsystem{Name}(convert(SubsystemStates{Name, Eltype, SNT}, get_states(s)), + convert(SubsystemParams{Name, PNT}, get_params(s))) +end +function Base.convert(::Type{Subsystem{Name, Eltype}}, s::Subsystem{Name}) where {Name, Eltype} + Subsystem{Name}(convert(SubsystemStates{Name, Eltype}, get_states(s)), get_params(s)) +end + +@generated function promote_nt_type(::Type{NamedTuple{names, T1}}, + ::Type{NamedTuple{names, T2}}) where {names, T1, T2} + NamedTuple{names, Tuple{(promote_type(T1.parameters[i], T2.parameters[i]) for i ∈ eachindex(names))...}} +end +function Base.promote_rule(::Type{SubsystemParams{Name, NT1}}, + ::Type{SubsystemParams{Name, NT2}}) where {Name, NT1, NT2} + SubsystemParams{Name, promote_nt_type(NT1, NT2)} +end +function Base.promote_rule(::Type{SubsystemStates{Name, ET1, NT1}}, + ::Type{SubsystemStates{Name, ET2, NT2}}) where {Name, ET1, ET2, NT1, NT2} + SubsystemStates{Name, promote_type(ET1, ET2), promote_nt_type(NT1, NT2)} +end + +function Base.promote_rule(::Type{Subsystem{Name, ET1, SNT1, PNT1}}, + ::Type{Subsystem{Name, ET2, SNT2, PNT2}}) where {Name, ET1, SNT1, PNT1, ET2, SNT2, PNT2} + Subsystem{Name, promote_type(ET1, ET2), promote_nt_type(SNT1, SNT2), promote_nt_type(PNT1, PNT2)} +end get_states(s::Subsystem) = getfield(s, :states) get_params(s::Subsystem) = getfield(s, :params) -get_name(::Subsystem{Name}) where {Name} = Name +get_tag(::Subsystem{Name}) where {Name} = Name -get_name(::Type{<:Subsystem{Name}}) where {Name} = Name +get_tag(::Type{<:Subsystem{Name}}) where {Name} = Name function Base.getproperty(s::Subsystem{<:Any, States, Params}, diff --git a/test/Project.toml b/test/Project.toml index 2242f59..8e9cb2d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GraphDynamics = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 36a583f..245f381 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -1,4 +1,4 @@ -using GraphDynamics, OrdinaryDiffEq, Test +using GraphDynamics, OrdinaryDiffEq, Test, ForwardDiff struct Particle end function GraphDynamics.subsystem_differential(sys::Subsystem{Particle}, F, t) @@ -34,56 +34,73 @@ function ((;fac)::Coulomb)(a, b) -fac * a.q * b.q * sign(a.x - b.x)/(abs(a.x - b.x) + 1e-10)^2 end -# put some garbage values in here for states and params, but we'll set them to reasonable values later with the -# u0map and param_map -subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)), - Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))], - [Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))]) - -states_partitioned = map(v -> map(get_states, v), subsystems_partitioned) -params_partitioned = map(v -> map(get_params, v), subsystems_partitioned) -names_partitioned = ([:particle1, :particle2], [:osc]) - -spring_conns_par_par = NotConnected() -spring_conns_par_osc = [Spring(1) - Spring(1);;] -spring_conns_osc_par = [Spring(1) Spring(1)] -spring_conns_osc_osc = NotConnected() - -spring_conns = ConnectionMatrix(((spring_conns_par_par, spring_conns_par_osc), - (spring_conns_osc_par, spring_conns_osc_osc))) - -# Spring[⎡. .⎤ ⎡2⎤ -# ⎣. .⎦ ⎣0⎦ -# [2 0] [.]] - -coulomb_conns_par_par = [Coulomb(0) Coulomb(.05) - Coulomb(.05) Coulomb(0)] -coulomb_conns_par_osc = [Coulomb(.05) - Coulomb(.05);;] -coulomb_conns_osc_par = [Coulomb(.05) Coulomb(.05)] -coulomb_conns_osc_osc = NotConnected() - -coulomb_conns = ConnectionMatrix(((coulomb_conns_par_par, coulomb_conns_par_osc), - (coulomb_conns_osc_par, coulomb_conns_osc_osc))) - -# Coulomb[⎡0 1⎤ ⎡1⎤ -# ⎣1 0⎦ ⎣1⎦ -# [0 0] [.]] - -connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns)) - -sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned) -tspan = (0.0, 20.0) - -prob = ODEProblem(sys, - # Fix the garbage state values - [:particle1₊x => 1.0, :particle2₊v => 0.0, :osc₊x => 0.0], - tspan, - # fix the garbage param values - [:osc₊m => 3.0]) -sol = solve(prob, Tsit5()) - -@test sol[:particle1₊x][end] ≈ 1.4923823131014389 -@test sol[:particle2₊x][end] ≈ -0.11189010002787175 -@test sol[:osc₊x][end] ≈ 1.3175449091469553 + +function solve_particle_osc(;x1, x2) + # put some garbage values in here for states and params, but we'll set them to reasonable values later with the + # u0map and param_map + subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)), + Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))], + [Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))]) + + states_partitioned = map(v -> map(get_states, v), subsystems_partitioned) + params_partitioned = map(v -> map(get_params, v), subsystems_partitioned) + names_partitioned = ([:particle1, :particle2], [:osc]) + + spring_conns_par_par = NotConnected() + spring_conns_par_osc = [Spring(1) + Spring(1);;] + spring_conns_osc_par = [Spring(1) Spring(1)] + spring_conns_osc_osc = NotConnected() + + spring_conns = ConnectionMatrix(((spring_conns_par_par, spring_conns_par_osc), + (spring_conns_osc_par, spring_conns_osc_osc))) + + # Spring[⎡. .⎤ ⎡2⎤ + # ⎣. .⎦ ⎣0⎦ + # [2 0] [.]] + + coulomb_conns_par_par = [Coulomb(0) Coulomb(.05) + Coulomb(.05) Coulomb(0)] + coulomb_conns_par_osc = [Coulomb(.05) + Coulomb(.05);;] + coulomb_conns_osc_par = [Coulomb(.05) Coulomb(.05)] + coulomb_conns_osc_osc = NotConnected() + + coulomb_conns = ConnectionMatrix(((coulomb_conns_par_par, coulomb_conns_par_osc), + (coulomb_conns_osc_par, coulomb_conns_osc_osc))) + + # Coulomb[⎡0 1⎤ ⎡1⎤ + # ⎣1 0⎦ ⎣1⎦ + # [0 0] [.]] + + connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns)) + + sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned) + tspan = (0.0, 20.0) + + prob = ODEProblem(sys, + # Fix the garbage state values + [:particle1₊x => x1, :particle2₊x => x2, :particle2₊v => 0.0, :osc₊x => 0.0], + tspan, + # fix the garbage param values + [:osc₊m => 3.0]) + sol = solve(prob, Tsit5()) +end + +@testset "solutions" begin + sol = solve_particle_osc(;x1=1.0, x2=-1.0) + @test sol[:particle1₊x][end] ≈ 1.4923823131014389 rtol=1e-7 + @test sol[:particle2₊x][end] ≈ -0.11189010002787175 rtol=1e-7 + @test sol[:osc₊x][end] ≈ 1.3175449091469553 rtol=1e-7 +end + +@testset "sensitivies" begin + jac = ForwardDiff.jacobian([1.0, -1.0]) do (x1, x2) + sol = solve_particle_osc(;x1, x2) + [sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]] + end + @test jac ≈ [0.498565 -0.0161443 + -1.92556 3.14649 + -0.249007 0.808641] rtol=1e-5 + +end