Skip to content

Commit

Permalink
get setp/setu and u0map working
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Sep 19, 2024
1 parent 61c341d commit b824c91
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 28 deletions.
15 changes: 9 additions & 6 deletions src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,14 @@ using SciMLBase:
using RecursiveArrayTools: ArrayPartition

using SymbolicIndexingInterface:
SymbolicIndexingInterface
SymbolicIndexingInterface,
setu,
setp

using Accessors:
Accessors,
@set
@set,
@reset

using ConstructionBase:
ConstructionBase
Expand All @@ -114,14 +117,14 @@ include("utils.jl")
#----------------------------------------------------------
# API functions to be implemented by new Systems

struct SubsystemParams{Name, Params <: NamedTuple}
params::Params
end

struct SubsystemStates{Name, Eltype, States <: NamedTuple} <: AbstractVector{Eltype}
states::States
end

struct SubsystemParams{Name, Params <: NamedTuple}
params::Params
end

struct Subsystem{Name, Eltype, States, Params}
states::SubsystemStates{Name, Eltype, States}
params::SubsystemParams{Name, Params}
Expand Down
28 changes: 19 additions & 9 deletions src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,25 @@ function (::Type{T})(g::Gsys, args...; kwargs...) where {T <: SciMLBase.Abstract
throw(ArgumentError("GraphSystems.jl does not yet support the use of $(Gsys.name.wrapper) in $T.\nThis is either a feature that is not yet supported, or you may have accidentally done something incorrect such as passing an ODEGraphSystem to an SDEProblem."))
end

function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[]; scheduler=SerialScheduler(), tstops=Float64[], kwargs...)
nt = _problem(g, u0map, tspan, param_map; scheduler)
function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[];
scheduler=SerialScheduler(), tstops=Float64[],
allow_nonconcrete=false, kwargs...)
nt = _problem(g, tspan; scheduler, allow_nonconcrete)
(; f, u, tspan, p, callback) = nt
tstops = vcat(tstops, nt.tstops)
ODEProblem(f, u, tspan, p; callback, tstops, kwargs...)
prob = ODEProblem(f, u, tspan, p; callback, tstops, kwargs...)
for (k, v) u0map
setu(prob, k)(prob, v)
end
for (k, v) param_map
setp(prob, k)(prob, v)
end
prob
end
function SciMLBase.SDEProblem(g::SDEGraphSystem, u0map, tspan, param_map=[];
scheduler=SerialScheduler(), tstops=Float64[], kwargs...)
nt = _problem(g, u0map, tspan, param_map; scheduler)
scheduler=SerialScheduler(), tstops=Float64[],
allow_nonconcrete=false, kwargs...)
nt = _problem(g, tspan; scheduler, allow_nonconcrete)
(; f, u, tspan, p, callback) = nt

noise_rate_prototype = nothing # zeros(length(u)) # this'll need to change once we support correlated noise
Expand All @@ -25,10 +35,7 @@ Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV}
state_types_val::STV
end

function _problem(g::GraphSystem, u0map, tspan, param_map=[]; scheduler=SerialScheduler)
isempty(u0map) || error("Specifying a state map is not yet implemented")
isempty(param_map) || error("Specifying a parameter map is not yet implemented")

function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete)
(; states_partitioned,
params_partitioned,
connection_matrices,
Expand Down Expand Up @@ -70,6 +77,9 @@ function _problem(g::GraphSystem, u0map, tspan, param_map=[]; scheduler=SerialSc
state_types_val = Val(Tuple{map(eltype, states_partitioned)...})

u = ArrayPartition(map(v -> stack(v), states_partitioned))
if !allow_nonconcrete && !isconcretetype(eltype(u)) && !all(isconcretetype eltype, states_partitioned)
error(ArgumentError("The provided subsystem states do not have a concrete eltype. All partitions must contain the same eltype. Got `eltype(u) = $(eltype(u))`."))
end

ce = nce > 0 ? VectorContinuousCallback(continuous_condition, continuous_affect!, nce) : nothing
de = nde > 0 ? DiscreteCallback(discrete_condition, discrete_affect!) : nothing
Expand Down
7 changes: 7 additions & 0 deletions src/subsystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ Base.NamedTuple(p::SubsystemParams) = getfield(p, :params)
Base.Tuple(s::SubsystemParams) = Tuple(getfield(s, :params))
Base.getproperty(p::SubsystemParams, prop::Symbol) = getproperty(NamedTuple(p), prop)
Base.propertynames(p::SubsystemParams) = propertynames(NamedTuple(p))
function Base.setindex(p::SubsystemParams{Name}, val, param) where {Name}
SubsystemParams{Name}(Base.setindex(NamedTuple(p), val, param))
end
function Base.convert(::Type{SubsystemParams{Name, NT}}, p::SubsystemParams{Name}) where {Name, NT}
SubsystemParams{Name}(convert(NT, NamedTuple(p)))
end


#------------------------------------------------------------
# Subsystem states
Expand Down
30 changes: 25 additions & 5 deletions src/symbolic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct ParamIndex #todo: this'll require some generalization to support weight p
prop::Symbol
end

function compute_namemap(names_partitioned, states_partitioned::Tuple{Vararg{<:AbstractVector{<:SubsystemStates}}})
function compute_namemap(names_partitioned, states_partitioned::Tuple{Vararg{AbstractVector{<:SubsystemStates}}})
state_namemap = Dict{Symbol, StateIndex}()
for i eachindex(names_partitioned, states_partitioned)
for j eachindex(names_partitioned[i], states_partitioned[i])
Expand All @@ -27,7 +27,7 @@ function compute_namemap(names_partitioned, states_partitioned::Tuple{Vararg{<:A
end
state_namemap
end
function compute_namemap(names_partitioned, params_partitioned::Tuple{Vararg{<:AbstractVector{<:SubsystemParams}}})
function compute_namemap(names_partitioned, params_partitioned::Tuple{Vararg{AbstractVector{<:SubsystemParams}}})
param_namemap = Dict{Symbol, ParamIndex}()
for i eachindex(names_partitioned, params_partitioned)
for j eachindex(names_partitioned[i], params_partitioned[i])
Expand All @@ -49,15 +49,32 @@ function Base.getindex(u::Tuple, (;tup_index, v_index, state_index)::StateIndex)
u[tup_index][state_index, v_index]
end

function Base.setindex!(u::ArrayPartition, val, idx::StateIndex)
setindex!(u.x, val, idx)
end
function Base.setindex!(u::Tuple, val, (;tup_index, v_index, state_index)::StateIndex)
setindex!(u[tup_index], val, state_index, v_index)
end

function Base.getindex(u::GraphSystemParameters, p::ParamIndex)
u.params_partitioned[p]
end
function Base.getindex(u::Tuple, (;tup_index, v_index, prop)::ParamIndex)
getproperty(u[tup_index][v_index], prop)
end
function Base.getindex(u::GraphSystemParameters, p::ParamIndex)
u.subsystem_params[p]


function Base.setindex!(u::GraphSystemParameters, val, p::ParamIndex)
setindex!(u.params_partitioned, val, p)
end
function Base.setindex!(u::Tuple, val, (;tup_index, v_index, prop)::ParamIndex)
params = u[tup_index][v_index]
@reset params[prop] = val
setindex!(u[tup_index], params, v_index)
end



function SymbolicIndexingInterface.is_variable(g::GraphSystem, sym)
haskey(g.state_namemap, sym)
end
Expand All @@ -76,8 +93,11 @@ function SymbolicIndexingInterface.parameter_index(g::GraphSystem, sym)
g.param_namemap[sym]
end

function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters)
p.params_partitioned
end
function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i::ParamIndex)
p.subsystem_params[i]
p.params_partitioned[i]
end

function SymbolicIndexingInterface.parameter_symbols(g::GraphSystem)
Expand Down
21 changes: 13 additions & 8 deletions test/particle_osc_example.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using GraphDynamics, OrdinaryDiffEq
using GraphDynamics, OrdinaryDiffEq, Test

struct Particle end
function GraphDynamics.subsystem_differential(sys::Subsystem{Particle}, F, t)
Expand Down Expand Up @@ -34,17 +34,16 @@ function ((;fac)::Coulomb)(a, b)
-fac * a.q * b.q * sign(a.x - b.x)/(abs(a.x - b.x) + 1e-10)^2
end

subsystems_partitioned = ([Subsystem{Particle}(states=(;x= 1.0, v=0.0), params=(;m=1.0, q=1.0)),
Subsystem{Particle}(states=(;x=-1.0, v=0.0), params=(;m=2.0, q=1.0))],
[Subsystem{Oscillator}(states=(;x=0.0, v=1.0), params=(;x₀=0.0, m=3.0, k=1.0, q=1.0))])
# put some garbage values in here for states and params, but we'll set them to reasonable values later with the
# u0map and param_map
subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)),
Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))],
[Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))])

states_partitioned = map(v -> map(get_states, v), subsystems_partitioned)
params_partitioned = map(v -> map(get_params, v), subsystems_partitioned)
names_partitioned = ([:particle1, :particle2], [:osc])




spring_conns_par_par = NotConnected()
spring_conns_par_osc = [Spring(1)
Spring(1);;]
Expand Down Expand Up @@ -76,7 +75,13 @@ connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns))

sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned)
tspan = (0.0, 20.0)
prob = ODEProblem(sys, [], tspan)

prob = ODEProblem(sys,
# Fix the garbage state values
[:particle1₊x => 1.0, :particle2₊v => 0.0, :osc₊x => 0.0],
tspan,
# fix the garbage param values
[:osc₊m => 3.0])
sol = solve(prob, Tsit5())

@test sol[:particle1₊x][end] 1.4923823131014389
Expand Down

0 comments on commit b824c91

Please sign in to comment.