Skip to content

Commit

Permalink
add ForeachConnectedSubsystem for effects modifying downstream Subsys…
Browse files Browse the repository at this point in the history
…tems
  • Loading branch information
MasonProtter committed Nov 4, 2024
1 parent 8edbfe5 commit a58ebed
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ end

isstochastic,

event_times
event_times,
ForeachConnectedSubsystem
)

export
Expand Down Expand Up @@ -231,6 +232,7 @@ add methods to this function if a subsystem or connection type has a discrete ev
event_times(::Any) = ()

abstract type ConnectionRule end
Base.zero(::T) where {T <: ConnectionRule} = zero(T)
struct NotConnected <: ConnectionRule end
(::NotConnected)(l, r) = zero(promote_type(eltype(l), eltype(r)))
struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected, AbstractMatrix{CR}}}}}
Expand Down
104 changes: 93 additions & 11 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,12 @@ function _continuous_affect!(integrator,
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if continuous_events_require_inputs(sys)
input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices)
apply_continuous_event!(integrator, sview, pview, sys, input)
apply_continuous_event!(integrator, sview, pview, sys, F, input)
else
apply_continuous_event!(integrator, sview, pview, sys)
apply_continuous_event!(integrator, sview, pview, sys, F)
end
end
offset += N
Expand Down Expand Up @@ -326,34 +327,38 @@ end
t) where {Len, NConn}
quote
@nexprs $Len i -> begin
# First we apply events to the states
if has_discrete_events(eltype(states_partitioned[i]))
for j eachindex(states_partitioned[i])
sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview_dst = @view states_partitioned[i][j]
pview_dst = @view params_partitioned[i][j]
if discrete_event_condition(sys_dst, t)
if discrete_events_require_inputs(sys_dst)
@inbounds for j eachindex(states_partitioned[i])
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
if discrete_event_condition(sys, t)
# println("helllllo")
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if discrete_events_require_inputs(sys)
input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices)
apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst, input)
apply_discrete_event!(integrator, sview, pview, sys, F, input)
else
apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst)
apply_discrete_event!(integrator, sview, pview, sys, F)
end
end
end
end
# Then we do the connection events
@nexprs $NConn nc -> begin
@nexprs $Len k -> begin
f = _discrete_connection_affect!(Val(i), Val(k), Val(nc), t,
states_partitioned, params_partitioned, connection_matrices,
integrator)
foreach(f, eachindex(states_partitioned[i]))

end
end
end
end
end


function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t,
states_partitioned::NTuple{Len, Any},
params_partitioned::NTuple{Len, Any},
Expand Down Expand Up @@ -397,3 +402,80 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t,
end
end
end


#-----------------------------------------------------------------------

"""
ForeachConnectedSubsystem
This is a callable struct which takes in a function, and then calls that function on each subsystem which has a connection leading to it
from some previously specified subsystem.
That is, writing
```julia
F = ForeachConnectedSubsystem{k}(l, states_partitioned, params_partitioned, connection_matrices)
F() do conn, sys_dst, states_view_dst, params_view_dst
[...]
end
```
is like a type stable version of writing
```
for i in eachindex(states_partitioned)
for nc in eachindex(connection_matrices)
M = connection_matrices[nc][i, k]
for j in eachindex(states_partitioned[k])
conn = M[l, j]
if !iszero(conn)
states_view_dst = @view states_partitioned[i][j]
params_view_dst = @view params_partitioned[i][j]
sys_dst = Subsystem(states_view_dst[], params_view_dst[])
[...] # <------- User code here
ends
end
end
end
```
"""
struct ForeachConnectedSubsystem{k, Len, NConn, S, P, CMs}
l::Int
states_partitioned::S
params_partitioned::P
connection_matrices::CMs
function ForeachConnectedSubsystem{k}(l,
states_partitioned::NTuple{Len, Any},
params_partitioned::NTuple{Len, Any},
connection_matrices::ConnectionMatrices{NConn}) where {k, Len, NConn}
S = typeof(states_partitioned)
P = typeof(params_partitioned)
CMs = typeof(connection_matrices)
new{k, Len, NConn, S, P, CMs}(l, states_partitioned, params_partitioned, connection_matrices)
end
end

@generated function ((;l,
states_partitioned,
params_partitioned,
connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F}
quote
@nexprs $Len i -> begin
@nexprs $NConn nc -> begin
M = connection_matrices[nc][k, i]
if M isa NotConnected
nothing
else
for j eachindex(states_partitioned[i])
@inbounds conn = M[l, j]
if !iszero(conn)
@inbounds states_view_dst = @view states_partitioned[i][j]
@inbounds params_view_dst = @view params_partitioned[i][j]
sys_dst = Subsystem(states_view_dst[], params_view_dst[])
f(conn, sys_dst, states_view_dst, params_view_dst)
end
end
end
end
end
end
end
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ valueof(x) = x

# this just makes it so that I can easily replace all uses of `@inbounds ex` with just `ex`.
macro inbounds(ex)
# ex
#esc(ex)
esc(:($Base.@inbounds $ex))
end

Expand Down

0 comments on commit a58ebed

Please sign in to comment.