Skip to content

Commit

Permalink
support forwardiff of states
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Sep 27, 2024
1 parent 983cada commit 9d4f60f
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 83 deletions.
27 changes: 20 additions & 7 deletions src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export
GraphSystem,
ODEGraphSystem,
SDEGraphSystem,
get_tag,
get_states,
get_params,
ODEProblem,
Expand Down Expand Up @@ -69,7 +70,8 @@ using SciMLBase:
CallbackSet,
VectorContinuousCallback,
ContinuousCallback,
DiscreteCallback
DiscreteCallback,
remake

using RecursiveArrayTools: ArrayPartition

Expand Down Expand Up @@ -101,20 +103,31 @@ include("utils.jl")
#----------------------------------------------------------
# API functions to be implemented by new Systems

struct SubsystemStates{Name, Eltype, States <: NamedTuple} <: AbstractVector{Eltype}
struct SubsystemStates{T, Eltype, States <: NamedTuple} <: AbstractVector{Eltype}
states::States
end

struct SubsystemParams{Name, Params <: NamedTuple}
struct SubsystemParams{T, Params <: NamedTuple}
params::Params
end

struct Subsystem{Name, Eltype, States, Params}
states::SubsystemStates{Name, Eltype, States}
params::SubsystemParams{Name, Params}
"""
Subsystem{T, Eltype, StateNT, ParamNT}
A `Subsystem` struct describes a complete subcomponent to an `GraphSystem`. This stores a `SubsystemStates` to describe the continuous dynamical state of the subsystem, and a `GraphSystemParams` which describes various non-dynamical parameters of the subsystem. The type parameter `T` is the subsystem's \"tag\" which labels what sort of subsystem it is.
See also `subsystem_differential`, `SubsystemStates`, `SubsystemParams`.
For example, if we wanted to describe a system where one sub-component is a billiard ball,
"""
struct Subsystem{T, Eltype, States, Params}
states::SubsystemStates{T, Eltype, States}
params::SubsystemParams{T, Params}
end

function get_name end
function get_tag end
function get_params end
function get_states end

Expand Down
10 changes: 0 additions & 10 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)
@nexprs $Len k -> begin
M = connection_matrices[nc][k, i]
if has_discrete_events(eltype(M))
#tany(foo(Val(k), Val(i), Val(NConn), M, t), eachindex(states_partitioned[i])) && return true
for j eachindex(states_partitioned[i])
for (l, Mlj) maybe_sparse_enumerate_col(M, j)
discrete_event_condition(Mlj, t) && return true
Expand All @@ -313,15 +312,6 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)
end
end

function foo(::Val{k}, ::Val{i}, ::Val{NConn}, M, t) where {i, NConn, k}
function f(j)
for (l, Mlj) maybe_sparse_enumerate_col(M, j)
discrete_event_condition(Mlj, t) && return true
end
false
end
end

function discrete_affect!(integrator)
(;params_partitioned, state_types_val, connection_matrices) = integrator.p
state_data = integrator.u.x
Expand Down
34 changes: 30 additions & 4 deletions src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ end
function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[];
scheduler=SerialScheduler(), tstops=Float64[],
allow_nonconcrete=false, kwargs...)
nt = _problem(g, tspan; scheduler, allow_nonconcrete)
nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; f, u, tspan, p, callback) = nt
tstops = vcat(tstops, nt.tstops)
prob = ODEProblem(f, u, tspan, p; callback, tstops, kwargs...)
Expand All @@ -21,11 +21,18 @@ end
function SciMLBase.SDEProblem(g::SDEGraphSystem, u0map, tspan, param_map=[];
scheduler=SerialScheduler(), tstops=Float64[],
allow_nonconcrete=false, kwargs...)
nt = _problem(g, tspan; scheduler, allow_nonconcrete)
nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; f, u, tspan, p, callback) = nt

noise_rate_prototype = nothing # zeros(length(u)) # this'll need to change once we support correlated noise
SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...)
prob = SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...)
for (k, v) u0map
setu(prob, k)(prob, v)
end
for (k, v) param_map
setp(prob, k)(prob, v)
end
prob
end

Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV}
Expand All @@ -35,14 +42,33 @@ Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV}
state_types_val::STV
end

function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete)
function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; states_partitioned,
params_partitioned,
connection_matrices,
tstops,
composite_discrete_events_partitioned,
composite_continuous_events_partitioned,) = g

total_eltype = let
states_eltype = mapreduce(promote_type, states_partitioned) do v
eltype(eltype(v))
end
u0map_eltype = mapreduce(promote_type, u0map; init=Union{}) do (k, v)
typeof(v)
end
promote_type(states_eltype, u0map_eltype)
end

re_eltype(s::SubsystemStates{T}) where {T} = convert(SubsystemStates{T, total_eltype}, s)
states_partitioned = map(states_partitioned) do v
if eltype(eltype(v)) <: total_eltype && eltype(eltype(v)) !== Union{}
v
else
re_eltype.(v)
end
end

length(states_partitioned) == length(params_partitioned) ||
error("Incompatible state and parameter lengths")
for i eachindex(states_partitioned, params_partitioned)
Expand Down
58 changes: 50 additions & 8 deletions src/subsystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ function ConstructionBase.setproperties(s::SubsystemParams{T}, patch::NamedTuple
SubsystemParams{T}(props′)
end

get_name(::SubsystemParams{Name}) where {Name} = Name
get_name(::Type{<:SubsystemParams{Name}}) where {Name} = Name
get_tag(::SubsystemParams{Name}) where {Name} = Name
get_tag(::Type{<:SubsystemParams{Name}}) where {Name} = Name
Base.NamedTuple(p::SubsystemParams) = getfield(p, :params)
Base.Tuple(s::SubsystemParams) = Tuple(getfield(s, :params))
Base.getproperty(p::SubsystemParams, prop::Symbol) = getproperty(NamedTuple(p), prop)
Expand All @@ -43,6 +43,10 @@ end
function SubsystemStates{Name}(nt::NamedTuple{state_names, NTuple{N, Eltype}}) where {Name, state_names, N, Eltype}
SubsystemStates{Name, Eltype, typeof(nt)}(nt)
end
function SubsystemStates{Name}(nt::NamedTuple{state_names, <:NTuple{N, Any}}) where {Name, state_names, N}
nt_promoted = NamedTuple{state_names}(promote(nt...))
SubsystemStates{Name}(nt_promoted)
end
function SubsystemStates{Name}(nt::NamedTuple{(), Tuple{}}) where {Name}
SubsystemStates{Name, Union{}, NamedTuple{(), Tuple{}}}(nt)
end
Expand Down Expand Up @@ -81,8 +85,8 @@ function ConstructionBase.setproperties(s::SubsystemStates{T}, patch::NamedTuple
SubsystemStates{T}(props′)
end

get_name(::SubsystemStates{Name}) where {Name} = Name
get_name(::Type{<:SubsystemStates{Name}}) where {Name} = Name
get_tag(::SubsystemStates{Name}) where {Name} = Name
get_tag(::Type{<:SubsystemStates{Name}}) where {Name} = Name
Base.NamedTuple(s::SubsystemStates) = getfield(s, :states)
Base.Tuple(s::SubsystemStates) = Tuple(getfield(s, :states))
Base.getproperty(s::SubsystemStates, prop::Symbol) = getproperty(NamedTuple(s), prop)
Expand All @@ -93,12 +97,25 @@ function state_ind(::Type{SubsystemStates{Name, Eltype, NamedTuple{names, Tup}}}
i = findfirst(==(s), names)
end

function Base.convert(::Type{SubsystemStates{Name, Eltype, NT}}, s::SubsystemStates{Name}) where {Name, Eltype, NT}
SubsystemStates{Name}(convert(NT, NamedTuple(s)))
end
function Base.convert(::Type{SubsystemStates{Name, Eltype}},
s::SubsystemStates{Name, <:Any, <:NamedTuple{state_names}}) where {Name, Eltype, state_names}
nt = NamedTuple{state_names}(convert.(Eltype, Tuple(s)))
SubsystemStates{Name, Eltype, typeof(nt)}(nt)
end

#------------------------------------------------------------
# Subsystem
function Subsystem{T}(;states, params) where {T}
ET = eltype(states)
Subsystem{T, ET, typeof(states), typeof(params)}(SubsystemStates{T}(states), SubsystemParams{T}(params))
Subsystem{T}(SubsystemStates{T}(states), SubsystemParams{T}(params))
end
function Subsystem{T}(states::SubsystemStates{T, Eltype, States},
params::SubsystemParams{T, Params}) where {T, Eltype, States, Params}
Subsystem{T, Eltype, States, Params}(states, params)
end

function Base.show(io::IO, sys::Subsystem{Name, Eltype}) where {Name, Eltype}
print(io,
"$Subsystem{$Name, $Eltype}(states = ",
Expand All @@ -124,15 +141,40 @@ function ConstructionBase.setproperties(s::Subsystem{T, Eltype, States, Params},
Subsystem{T, Eltype, States, Params}(SubsystemStates{T}(states′), SubsystemParams{T}(params′))
end

function Base.convert(::Type{Subsystem{Name, Eltype, SNT, PNT}}, s::Subsystem{Name}) where {Name, Eltype, SNT, PNT}
Subsystem{Name}(convert(SubsystemStates{Name, Eltype, SNT}, get_states(s)),
convert(SubsystemParams{Name, PNT}, get_params(s)))
end
function Base.convert(::Type{Subsystem{Name, Eltype}}, s::Subsystem{Name}) where {Name, Eltype}
Subsystem{Name}(convert(SubsystemStates{Name, Eltype}, get_states(s)), get_params(s))
end

@generated function promote_nt_type(::Type{NamedTuple{names, T1}},
::Type{NamedTuple{names, T2}}) where {names, T1, T2}
NamedTuple{names, Tuple{(promote_type(T1.parameters[i], T2.parameters[i]) for i eachindex(names))...}}
end

function Base.promote_rule(::Type{SubsystemParams{Name, NT1}},
::Type{SubsystemParams{Name, NT2}}) where {Name, NT1, NT2}
SubsystemParams{Name, promote_nt_type(NT1, NT2)}
end
function Base.promote_rule(::Type{SubsystemStates{Name, ET1, NT1}},
::Type{SubsystemStates{Name, ET2, NT2}}) where {Name, ET1, ET2, NT1, NT2}
SubsystemStates{Name, promote_type(ET1, ET2), promote_nt_type(NT1, NT2)}
end

function Base.promote_rule(::Type{Subsystem{Name, ET1, SNT1, PNT1}},
::Type{Subsystem{Name, ET2, SNT2, PNT2}}) where {Name, ET1, SNT1, PNT1, ET2, SNT2, PNT2}
Subsystem{Name, promote_type(ET1, ET2), promote_nt_type(SNT1, SNT2), promote_nt_type(PNT1, PNT2)}
end

get_states(s::Subsystem) = getfield(s, :states)
get_params(s::Subsystem) = getfield(s, :params)
get_name(::Subsystem{Name}) where {Name} = Name
get_tag(::Subsystem{Name}) where {Name} = Name



get_name(::Type{<:Subsystem{Name}}) where {Name} = Name
get_tag(::Type{<:Subsystem{Name}}) where {Name} = Name


function Base.getproperty(s::Subsystem{<:Any, States, Params},
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GraphDynamics = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
125 changes: 71 additions & 54 deletions test/particle_osc_example.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using GraphDynamics, OrdinaryDiffEq, Test
using GraphDynamics, OrdinaryDiffEq, Test, ForwardDiff

struct Particle end
function GraphDynamics.subsystem_differential(sys::Subsystem{Particle}, F, t)
Expand Down Expand Up @@ -34,56 +34,73 @@ function ((;fac)::Coulomb)(a, b)
-fac * a.q * b.q * sign(a.x - b.x)/(abs(a.x - b.x) + 1e-10)^2
end

# put some garbage values in here for states and params, but we'll set them to reasonable values later with the
# u0map and param_map
subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)),
Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))],
[Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))])

states_partitioned = map(v -> map(get_states, v), subsystems_partitioned)
params_partitioned = map(v -> map(get_params, v), subsystems_partitioned)
names_partitioned = ([:particle1, :particle2], [:osc])

spring_conns_par_par = NotConnected()
spring_conns_par_osc = [Spring(1)
Spring(1);;]
spring_conns_osc_par = [Spring(1) Spring(1)]
spring_conns_osc_osc = NotConnected()

spring_conns = ConnectionMatrix(((spring_conns_par_par, spring_conns_par_osc),
(spring_conns_osc_par, spring_conns_osc_osc)))

# Spring[⎡. .⎤ ⎡2⎤
# ⎣. .⎦ ⎣0⎦
# [2 0] [.]]

coulomb_conns_par_par = [Coulomb(0) Coulomb(.05)
Coulomb(.05) Coulomb(0)]
coulomb_conns_par_osc = [Coulomb(.05)
Coulomb(.05);;]
coulomb_conns_osc_par = [Coulomb(.05) Coulomb(.05)]
coulomb_conns_osc_osc = NotConnected()

coulomb_conns = ConnectionMatrix(((coulomb_conns_par_par, coulomb_conns_par_osc),
(coulomb_conns_osc_par, coulomb_conns_osc_osc)))

# Coulomb[⎡0 1⎤ ⎡1⎤
# ⎣1 0⎦ ⎣1⎦
# [0 0] [.]]

connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns))

sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned)
tspan = (0.0, 20.0)

prob = ODEProblem(sys,
# Fix the garbage state values
[:particle1₊x => 1.0, :particle2₊v => 0.0, :osc₊x => 0.0],
tspan,
# fix the garbage param values
[:osc₊m => 3.0])
sol = solve(prob, Tsit5())

@test sol[:particle1₊x][end] 1.4923823131014389
@test sol[:particle2₊x][end] -0.11189010002787175
@test sol[:osc₊x][end] 1.3175449091469553

function solve_particle_osc(;x1, x2)
# put some garbage values in here for states and params, but we'll set them to reasonable values later with the
# u0map and param_map
subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)),
Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))],
[Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))])

states_partitioned = map(v -> map(get_states, v), subsystems_partitioned)
params_partitioned = map(v -> map(get_params, v), subsystems_partitioned)
names_partitioned = ([:particle1, :particle2], [:osc])

spring_conns_par_par = NotConnected()
spring_conns_par_osc = [Spring(1)
Spring(1);;]
spring_conns_osc_par = [Spring(1) Spring(1)]
spring_conns_osc_osc = NotConnected()

spring_conns = ConnectionMatrix(((spring_conns_par_par, spring_conns_par_osc),
(spring_conns_osc_par, spring_conns_osc_osc)))

# Spring[⎡. .⎤ ⎡2⎤
# ⎣. .⎦ ⎣0⎦
# [2 0] [.]]

coulomb_conns_par_par = [Coulomb(0) Coulomb(.05)
Coulomb(.05) Coulomb(0)]
coulomb_conns_par_osc = [Coulomb(.05)
Coulomb(.05);;]
coulomb_conns_osc_par = [Coulomb(.05) Coulomb(.05)]
coulomb_conns_osc_osc = NotConnected()

coulomb_conns = ConnectionMatrix(((coulomb_conns_par_par, coulomb_conns_par_osc),
(coulomb_conns_osc_par, coulomb_conns_osc_osc)))

# Coulomb[⎡0 1⎤ ⎡1⎤
# ⎣1 0⎦ ⎣1⎦
# [0 0] [.]]

connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns))

sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned)
tspan = (0.0, 20.0)

prob = ODEProblem(sys,
# Fix the garbage state values
[:particle1₊x => x1, :particle2₊x => x2, :particle2₊v => 0.0, :osc₊x => 0.0],
tspan,
# fix the garbage param values
[:osc₊m => 3.0])
sol = solve(prob, Tsit5())
end

@testset "solutions" begin
sol = solve_particle_osc(;x1=1.0, x2=-1.0)
@test sol[:particle1₊x][end] 1.4923823131014389 rtol=1e-7
@test sol[:particle2₊x][end] -0.11189010002787175 rtol=1e-7
@test sol[:osc₊x][end] 1.3175449091469553 rtol=1e-7
end

@testset "sensitivies" begin
jac = ForwardDiff.jacobian([1.0, -1.0]) do (x1, x2)
sol = solve_particle_osc(;x1, x2)
[sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]]
end
@test jac [0.498565 -0.0161443
-1.92556 3.14649
-0.249007 0.808641] rtol=1e-5

end

0 comments on commit 9d4f60f

Please sign in to comment.