Skip to content

Commit

Permalink
support forwardiff of states (#4)
Browse files Browse the repository at this point in the history
* support forwardiff of states

* bump version
  • Loading branch information
MasonProtter authored Sep 27, 2024
1 parent 983cada commit 8d79d33
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 84 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GraphDynamics"
uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
version = "0.1.1"
version = "0.1.2"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
27 changes: 20 additions & 7 deletions src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export
GraphSystem,
ODEGraphSystem,
SDEGraphSystem,
get_tag,
get_states,
get_params,
ODEProblem,
Expand Down Expand Up @@ -69,7 +70,8 @@ using SciMLBase:
CallbackSet,
VectorContinuousCallback,
ContinuousCallback,
DiscreteCallback
DiscreteCallback,
remake

using RecursiveArrayTools: ArrayPartition

Expand Down Expand Up @@ -101,20 +103,31 @@ include("utils.jl")
#----------------------------------------------------------
# API functions to be implemented by new Systems

struct SubsystemStates{Name, Eltype, States <: NamedTuple} <: AbstractVector{Eltype}
struct SubsystemStates{T, Eltype, States <: NamedTuple} <: AbstractVector{Eltype}
states::States
end

struct SubsystemParams{Name, Params <: NamedTuple}
struct SubsystemParams{T, Params <: NamedTuple}
params::Params
end

struct Subsystem{Name, Eltype, States, Params}
states::SubsystemStates{Name, Eltype, States}
params::SubsystemParams{Name, Params}
"""
Subsystem{T, Eltype, StateNT, ParamNT}
A `Subsystem` struct describes a complete subcomponent to an `GraphSystem`. This stores a `SubsystemStates` to describe the continuous dynamical state of the subsystem, and a `GraphSystemParams` which describes various non-dynamical parameters of the subsystem. The type parameter `T` is the subsystem's \"tag\" which labels what sort of subsystem it is.
See also `subsystem_differential`, `SubsystemStates`, `SubsystemParams`.
For example, if we wanted to describe a system where one sub-component is a billiard ball,
"""
struct Subsystem{T, Eltype, States, Params}
states::SubsystemStates{T, Eltype, States}
params::SubsystemParams{T, Params}
end

function get_name end
function get_tag end
function get_params end
function get_states end

Expand Down
10 changes: 0 additions & 10 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)
@nexprs $Len k -> begin
M = connection_matrices[nc][k, i]
if has_discrete_events(eltype(M))
#tany(foo(Val(k), Val(i), Val(NConn), M, t), eachindex(states_partitioned[i])) && return true
for j eachindex(states_partitioned[i])
for (l, Mlj) maybe_sparse_enumerate_col(M, j)
discrete_event_condition(Mlj, t) && return true
Expand All @@ -313,15 +312,6 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)
end
end

function foo(::Val{k}, ::Val{i}, ::Val{NConn}, M, t) where {i, NConn, k}
function f(j)
for (l, Mlj) maybe_sparse_enumerate_col(M, j)
discrete_event_condition(Mlj, t) && return true
end
false
end
end

function discrete_affect!(integrator)
(;params_partitioned, state_types_val, connection_matrices) = integrator.p
state_data = integrator.u.x
Expand Down
34 changes: 30 additions & 4 deletions src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ end
function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[];
scheduler=SerialScheduler(), tstops=Float64[],
allow_nonconcrete=false, kwargs...)
nt = _problem(g, tspan; scheduler, allow_nonconcrete)
nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; f, u, tspan, p, callback) = nt
tstops = vcat(tstops, nt.tstops)
prob = ODEProblem(f, u, tspan, p; callback, tstops, kwargs...)
Expand All @@ -21,11 +21,18 @@ end
function SciMLBase.SDEProblem(g::SDEGraphSystem, u0map, tspan, param_map=[];
scheduler=SerialScheduler(), tstops=Float64[],
allow_nonconcrete=false, kwargs...)
nt = _problem(g, tspan; scheduler, allow_nonconcrete)
nt = _problem(g, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; f, u, tspan, p, callback) = nt

noise_rate_prototype = nothing # zeros(length(u)) # this'll need to change once we support correlated noise
SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...)
prob = SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...)
for (k, v) u0map
setu(prob, k)(prob, v)
end
for (k, v) param_map
setp(prob, k)(prob, v)
end
prob
end

Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV}
Expand All @@ -35,14 +42,33 @@ Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV}
state_types_val::STV
end

function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete)
function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete, u0map, param_map)
(; states_partitioned,
params_partitioned,
connection_matrices,
tstops,
composite_discrete_events_partitioned,
composite_continuous_events_partitioned,) = g

total_eltype = let
states_eltype = mapreduce(promote_type, states_partitioned) do v
eltype(eltype(v))
end
u0map_eltype = mapreduce(promote_type, u0map; init=Union{}) do (k, v)
typeof(v)
end
promote_type(states_eltype, u0map_eltype)
end

re_eltype(s::SubsystemStates{T}) where {T} = convert(SubsystemStates{T, total_eltype}, s)
states_partitioned = map(states_partitioned) do v
if eltype(eltype(v)) <: total_eltype && eltype(eltype(v)) !== Union{}
v
else
re_eltype.(v)
end
end

length(states_partitioned) == length(params_partitioned) ||
error("Incompatible state and parameter lengths")
for i eachindex(states_partitioned, params_partitioned)
Expand Down
58 changes: 50 additions & 8 deletions src/subsystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ function ConstructionBase.setproperties(s::SubsystemParams{T}, patch::NamedTuple
SubsystemParams{T}(props′)
end

get_name(::SubsystemParams{Name}) where {Name} = Name
get_name(::Type{<:SubsystemParams{Name}}) where {Name} = Name
get_tag(::SubsystemParams{Name}) where {Name} = Name
get_tag(::Type{<:SubsystemParams{Name}}) where {Name} = Name
Base.NamedTuple(p::SubsystemParams) = getfield(p, :params)
Base.Tuple(s::SubsystemParams) = Tuple(getfield(s, :params))
Base.getproperty(p::SubsystemParams, prop::Symbol) = getproperty(NamedTuple(p), prop)
Expand All @@ -43,6 +43,10 @@ end
function SubsystemStates{Name}(nt::NamedTuple{state_names, NTuple{N, Eltype}}) where {Name, state_names, N, Eltype}
SubsystemStates{Name, Eltype, typeof(nt)}(nt)
end
function SubsystemStates{Name}(nt::NamedTuple{state_names, <:NTuple{N, Any}}) where {Name, state_names, N}
nt_promoted = NamedTuple{state_names}(promote(nt...))
SubsystemStates{Name}(nt_promoted)
end
function SubsystemStates{Name}(nt::NamedTuple{(), Tuple{}}) where {Name}
SubsystemStates{Name, Union{}, NamedTuple{(), Tuple{}}}(nt)
end
Expand Down Expand Up @@ -81,8 +85,8 @@ function ConstructionBase.setproperties(s::SubsystemStates{T}, patch::NamedTuple
SubsystemStates{T}(props′)
end

get_name(::SubsystemStates{Name}) where {Name} = Name
get_name(::Type{<:SubsystemStates{Name}}) where {Name} = Name
get_tag(::SubsystemStates{Name}) where {Name} = Name
get_tag(::Type{<:SubsystemStates{Name}}) where {Name} = Name
Base.NamedTuple(s::SubsystemStates) = getfield(s, :states)
Base.Tuple(s::SubsystemStates) = Tuple(getfield(s, :states))
Base.getproperty(s::SubsystemStates, prop::Symbol) = getproperty(NamedTuple(s), prop)
Expand All @@ -93,12 +97,25 @@ function state_ind(::Type{SubsystemStates{Name, Eltype, NamedTuple{names, Tup}}}
i = findfirst(==(s), names)
end

function Base.convert(::Type{SubsystemStates{Name, Eltype, NT}}, s::SubsystemStates{Name}) where {Name, Eltype, NT}
SubsystemStates{Name}(convert(NT, NamedTuple(s)))
end
function Base.convert(::Type{SubsystemStates{Name, Eltype}},
s::SubsystemStates{Name, <:Any, <:NamedTuple{state_names}}) where {Name, Eltype, state_names}
nt = NamedTuple{state_names}(convert.(Eltype, Tuple(s)))
SubsystemStates{Name, Eltype, typeof(nt)}(nt)
end

#------------------------------------------------------------
# Subsystem
function Subsystem{T}(;states, params) where {T}
ET = eltype(states)
Subsystem{T, ET, typeof(states), typeof(params)}(SubsystemStates{T}(states), SubsystemParams{T}(params))
Subsystem{T}(SubsystemStates{T}(states), SubsystemParams{T}(params))
end
function Subsystem{T}(states::SubsystemStates{T, Eltype, States},
params::SubsystemParams{T, Params}) where {T, Eltype, States, Params}
Subsystem{T, Eltype, States, Params}(states, params)
end

function Base.show(io::IO, sys::Subsystem{Name, Eltype}) where {Name, Eltype}
print(io,
"$Subsystem{$Name, $Eltype}(states = ",
Expand All @@ -124,15 +141,40 @@ function ConstructionBase.setproperties(s::Subsystem{T, Eltype, States, Params},
Subsystem{T, Eltype, States, Params}(SubsystemStates{T}(states′), SubsystemParams{T}(params′))
end

function Base.convert(::Type{Subsystem{Name, Eltype, SNT, PNT}}, s::Subsystem{Name}) where {Name, Eltype, SNT, PNT}
Subsystem{Name}(convert(SubsystemStates{Name, Eltype, SNT}, get_states(s)),
convert(SubsystemParams{Name, PNT}, get_params(s)))
end
function Base.convert(::Type{Subsystem{Name, Eltype}}, s::Subsystem{Name}) where {Name, Eltype}
Subsystem{Name}(convert(SubsystemStates{Name, Eltype}, get_states(s)), get_params(s))
end

@generated function promote_nt_type(::Type{NamedTuple{names, T1}},
::Type{NamedTuple{names, T2}}) where {names, T1, T2}
NamedTuple{names, Tuple{(promote_type(T1.parameters[i], T2.parameters[i]) for i eachindex(names))...}}
end

function Base.promote_rule(::Type{SubsystemParams{Name, NT1}},
::Type{SubsystemParams{Name, NT2}}) where {Name, NT1, NT2}
SubsystemParams{Name, promote_nt_type(NT1, NT2)}
end
function Base.promote_rule(::Type{SubsystemStates{Name, ET1, NT1}},
::Type{SubsystemStates{Name, ET2, NT2}}) where {Name, ET1, ET2, NT1, NT2}
SubsystemStates{Name, promote_type(ET1, ET2), promote_nt_type(NT1, NT2)}
end

function Base.promote_rule(::Type{Subsystem{Name, ET1, SNT1, PNT1}},
::Type{Subsystem{Name, ET2, SNT2, PNT2}}) where {Name, ET1, SNT1, PNT1, ET2, SNT2, PNT2}
Subsystem{Name, promote_type(ET1, ET2), promote_nt_type(SNT1, SNT2), promote_nt_type(PNT1, PNT2)}
end

get_states(s::Subsystem) = getfield(s, :states)
get_params(s::Subsystem) = getfield(s, :params)
get_name(::Subsystem{Name}) where {Name} = Name
get_tag(::Subsystem{Name}) where {Name} = Name



get_name(::Type{<:Subsystem{Name}}) where {Name} = Name
get_tag(::Type{<:Subsystem{Name}}) where {Name} = Name


function Base.getproperty(s::Subsystem{<:Any, States, Params},
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GraphDynamics = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
125 changes: 71 additions & 54 deletions test/particle_osc_example.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using GraphDynamics, OrdinaryDiffEq, Test
using GraphDynamics, OrdinaryDiffEq, Test, ForwardDiff

struct Particle end
function GraphDynamics.subsystem_differential(sys::Subsystem{Particle}, F, t)
Expand Down Expand Up @@ -34,56 +34,73 @@ function ((;fac)::Coulomb)(a, b)
-fac * a.q * b.q * sign(a.x - b.x)/(abs(a.x - b.x) + 1e-10)^2
end

# put some garbage values in here for states and params, but we'll set them to reasonable values later with the
# u0map and param_map
subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)),
Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))],
[Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))])

states_partitioned = map(v -> map(get_states, v), subsystems_partitioned)
params_partitioned = map(v -> map(get_params, v), subsystems_partitioned)
names_partitioned = ([:particle1, :particle2], [:osc])

spring_conns_par_par = NotConnected()
spring_conns_par_osc = [Spring(1)
Spring(1);;]
spring_conns_osc_par = [Spring(1) Spring(1)]
spring_conns_osc_osc = NotConnected()

spring_conns = ConnectionMatrix(((spring_conns_par_par, spring_conns_par_osc),
(spring_conns_osc_par, spring_conns_osc_osc)))

# Spring[⎡. .⎤ ⎡2⎤
# ⎣. .⎦ ⎣0⎦
# [2 0] [.]]

coulomb_conns_par_par = [Coulomb(0) Coulomb(.05)
Coulomb(.05) Coulomb(0)]
coulomb_conns_par_osc = [Coulomb(.05)
Coulomb(.05);;]
coulomb_conns_osc_par = [Coulomb(.05) Coulomb(.05)]
coulomb_conns_osc_osc = NotConnected()

coulomb_conns = ConnectionMatrix(((coulomb_conns_par_par, coulomb_conns_par_osc),
(coulomb_conns_osc_par, coulomb_conns_osc_osc)))

# Coulomb[⎡0 1⎤ ⎡1⎤
# ⎣1 0⎦ ⎣1⎦
# [0 0] [.]]

connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns))

sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned)
tspan = (0.0, 20.0)

prob = ODEProblem(sys,
# Fix the garbage state values
[:particle1₊x => 1.0, :particle2₊v => 0.0, :osc₊x => 0.0],
tspan,
# fix the garbage param values
[:osc₊m => 3.0])
sol = solve(prob, Tsit5())

@test sol[:particle1₊x][end] 1.4923823131014389
@test sol[:particle2₊x][end] -0.11189010002787175
@test sol[:osc₊x][end] 1.3175449091469553

function solve_particle_osc(;x1, x2)
# put some garbage values in here for states and params, but we'll set them to reasonable values later with the
# u0map and param_map
subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)),
Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))],
[Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))])

states_partitioned = map(v -> map(get_states, v), subsystems_partitioned)
params_partitioned = map(v -> map(get_params, v), subsystems_partitioned)
names_partitioned = ([:particle1, :particle2], [:osc])

spring_conns_par_par = NotConnected()
spring_conns_par_osc = [Spring(1)
Spring(1);;]
spring_conns_osc_par = [Spring(1) Spring(1)]
spring_conns_osc_osc = NotConnected()

spring_conns = ConnectionMatrix(((spring_conns_par_par, spring_conns_par_osc),
(spring_conns_osc_par, spring_conns_osc_osc)))

# Spring[⎡. .⎤ ⎡2⎤
# ⎣. .⎦ ⎣0⎦
# [2 0] [.]]

coulomb_conns_par_par = [Coulomb(0) Coulomb(.05)
Coulomb(.05) Coulomb(0)]
coulomb_conns_par_osc = [Coulomb(.05)
Coulomb(.05);;]
coulomb_conns_osc_par = [Coulomb(.05) Coulomb(.05)]
coulomb_conns_osc_osc = NotConnected()

coulomb_conns = ConnectionMatrix(((coulomb_conns_par_par, coulomb_conns_par_osc),
(coulomb_conns_osc_par, coulomb_conns_osc_osc)))

# Coulomb[⎡0 1⎤ ⎡1⎤
# ⎣1 0⎦ ⎣1⎦
# [0 0] [.]]

connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns))

sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned)
tspan = (0.0, 20.0)

prob = ODEProblem(sys,
# Fix the garbage state values
[:particle1₊x => x1, :particle2₊x => x2, :particle2₊v => 0.0, :osc₊x => 0.0],
tspan,
# fix the garbage param values
[:osc₊m => 3.0])
sol = solve(prob, Tsit5())
end

@testset "solutions" begin
sol = solve_particle_osc(;x1=1.0, x2=-1.0)
@test sol[:particle1₊x][end] 1.4923823131014389 rtol=1e-7
@test sol[:particle2₊x][end] -0.11189010002787175 rtol=1e-7
@test sol[:osc₊x][end] 1.3175449091469553 rtol=1e-7
end

@testset "sensitivies" begin
jac = ForwardDiff.jacobian([1.0, -1.0]) do (x1, x2)
sol = solve_particle_osc(;x1, x2)
[sol[:particle1₊x][end], sol[:particle2₊x][end], sol[:osc₊x][end]]
end
@test jac [0.498565 -0.0161443
-1.92556 3.14649
-0.249007 0.808641] rtol=1e-5

end

2 comments on commit 8d79d33

@MasonProtter
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/116148

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.2 -m "<description of version>" 8d79d33a73d7546e4f6b4539d6647bdf22d1a441
git push origin v0.1.2

Please sign in to comment.