From f0057bc4892f0b213fb1f6de444f9a3cbc802f41 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Thu, 19 Sep 2024 22:49:23 +0200 Subject: [PATCH] Get `setp`/`setu` working, implement `u0map` for `ODE/SDEProblem` (#3) * get setp/setu and u0map working * add github workflow * dont test 1.9 (not supported) * make night precompilation happy --- .github/workflows/CI.yml | 37 ++++++++++++++++++++++++++++++++++++ src/GraphDynamics.jl | 33 ++++++++++---------------------- src/problems.jl | 28 ++++++++++++++++++--------- src/subsystems.jl | 7 +++++++ src/symbolic_indexing.jl | 30 ++++++++++++++++++++++++----- test/particle_osc_example.jl | 21 ++++++++++++-------- 6 files changed, 111 insertions(+), 45 deletions(-) create mode 100644 .github/workflows/CI.yml diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 0000000..3e6182c --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,37 @@ +name: CI +on: + - push + - pull_request +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - '1.10' + - 'nightly' + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: actions/cache@v1 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 \ No newline at end of file diff --git a/src/GraphDynamics.jl b/src/GraphDynamics.jl index 66e99b3..65a95e1 100644 --- a/src/GraphDynamics.jl +++ b/src/GraphDynamics.jl @@ -10,15 +10,10 @@ macro public(ex) end @public ( - Subsystem, - SubsystemParams, - SubsystemStates, - subsystem_differential, apply_subsystem_noise!, subsystem_differential_requires_inputs, - initialize_input, combine, @@ -32,22 +27,11 @@ end apply_discrete_event!, discrete_events_require_inputs, - must_run_before, isstochastic, - - GraphSystem, - ODEGraphSystem, - SDEGraphSystem, - - get_states, - get_params, - event_times, - ConnectionMatrices, - ConnectionMatrix, - NotConnected, + event_times, ) export @@ -90,11 +74,14 @@ using SciMLBase: using RecursiveArrayTools: ArrayPartition using SymbolicIndexingInterface: - SymbolicIndexingInterface + SymbolicIndexingInterface, + setu, + setp using Accessors: Accessors, - @set + @set, + @reset using ConstructionBase: ConstructionBase @@ -114,14 +101,14 @@ include("utils.jl") #---------------------------------------------------------- # API functions to be implemented by new Systems -struct SubsystemParams{Name, Params <: NamedTuple} - params::Params -end - struct SubsystemStates{Name, Eltype, States <: NamedTuple} <: AbstractVector{Eltype} states::States end +struct SubsystemParams{Name, Params <: NamedTuple} + params::Params +end + struct Subsystem{Name, Eltype, States, Params} states::SubsystemStates{Name, Eltype, States} params::SubsystemParams{Name, Params} diff --git a/src/problems.jl b/src/problems.jl index ba4472d..02ea918 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -3,15 +3,25 @@ function (::Type{T})(g::Gsys, args...; kwargs...) where {T <: SciMLBase.Abstract throw(ArgumentError("GraphSystems.jl does not yet support the use of $(Gsys.name.wrapper) in $T.\nThis is either a feature that is not yet supported, or you may have accidentally done something incorrect such as passing an ODEGraphSystem to an SDEProblem.")) end -function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[]; scheduler=SerialScheduler(), tstops=Float64[], kwargs...) - nt = _problem(g, u0map, tspan, param_map; scheduler) +function SciMLBase.ODEProblem(g::ODEGraphSystem, u0map, tspan, param_map=[]; + scheduler=SerialScheduler(), tstops=Float64[], + allow_nonconcrete=false, kwargs...) + nt = _problem(g, tspan; scheduler, allow_nonconcrete) (; f, u, tspan, p, callback) = nt tstops = vcat(tstops, nt.tstops) - ODEProblem(f, u, tspan, p; callback, tstops, kwargs...) + prob = ODEProblem(f, u, tspan, p; callback, tstops, kwargs...) + for (k, v) ∈ u0map + setu(prob, k)(prob, v) + end + for (k, v) ∈ param_map + setp(prob, k)(prob, v) + end + prob end function SciMLBase.SDEProblem(g::SDEGraphSystem, u0map, tspan, param_map=[]; - scheduler=SerialScheduler(), tstops=Float64[], kwargs...) - nt = _problem(g, u0map, tspan, param_map; scheduler) + scheduler=SerialScheduler(), tstops=Float64[], + allow_nonconcrete=false, kwargs...) + nt = _problem(g, tspan; scheduler, allow_nonconcrete) (; f, u, tspan, p, callback) = nt noise_rate_prototype = nothing # zeros(length(u)) # this'll need to change once we support correlated noise @@ -25,10 +35,7 @@ Base.@kwdef struct GraphSystemParameters{PP, CM, S, STV} state_types_val::STV end -function _problem(g::GraphSystem, u0map, tspan, param_map=[]; scheduler=SerialScheduler) - isempty(u0map) || error("Specifying a state map is not yet implemented") - isempty(param_map) || error("Specifying a parameter map is not yet implemented") - +function _problem(g::GraphSystem, tspan; scheduler, allow_nonconcrete) (; states_partitioned, params_partitioned, connection_matrices, @@ -70,6 +77,9 @@ function _problem(g::GraphSystem, u0map, tspan, param_map=[]; scheduler=SerialSc state_types_val = Val(Tuple{map(eltype, states_partitioned)...}) u = ArrayPartition(map(v -> stack(v), states_partitioned)) + if !allow_nonconcrete && !isconcretetype(eltype(u)) && !all(isconcretetype ∘ eltype, states_partitioned) + error(ArgumentError("The provided subsystem states do not have a concrete eltype. All partitions must contain the same eltype. Got `eltype(u) = $(eltype(u))`.")) + end ce = nce > 0 ? VectorContinuousCallback(continuous_condition, continuous_affect!, nce) : nothing de = nde > 0 ? DiscreteCallback(discrete_condition, discrete_affect!) : nothing diff --git a/src/subsystems.jl b/src/subsystems.jl index 14a6f7c..55a56fe 100644 --- a/src/subsystems.jl +++ b/src/subsystems.jl @@ -27,6 +27,13 @@ Base.NamedTuple(p::SubsystemParams) = getfield(p, :params) Base.Tuple(s::SubsystemParams) = Tuple(getfield(s, :params)) Base.getproperty(p::SubsystemParams, prop::Symbol) = getproperty(NamedTuple(p), prop) Base.propertynames(p::SubsystemParams) = propertynames(NamedTuple(p)) +function Base.setindex(p::SubsystemParams{Name}, val, param) where {Name} + SubsystemParams{Name}(Base.setindex(NamedTuple(p), val, param)) +end +function Base.convert(::Type{SubsystemParams{Name, NT}}, p::SubsystemParams{Name}) where {Name, NT} + SubsystemParams{Name}(convert(NT, NamedTuple(p))) +end + #------------------------------------------------------------ # Subsystem states diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index 8131b75..f92a265 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -15,7 +15,7 @@ struct ParamIndex #todo: this'll require some generalization to support weight p prop::Symbol end -function compute_namemap(names_partitioned, states_partitioned::Tuple{Vararg{<:AbstractVector{<:SubsystemStates}}}) +function compute_namemap(names_partitioned, states_partitioned::Tuple{Vararg{AbstractVector{<:SubsystemStates}}}) state_namemap = Dict{Symbol, StateIndex}() for i ∈ eachindex(names_partitioned, states_partitioned) for j ∈ eachindex(names_partitioned[i], states_partitioned[i]) @@ -27,7 +27,7 @@ function compute_namemap(names_partitioned, states_partitioned::Tuple{Vararg{<:A end state_namemap end -function compute_namemap(names_partitioned, params_partitioned::Tuple{Vararg{<:AbstractVector{<:SubsystemParams}}}) +function compute_namemap(names_partitioned, params_partitioned::Tuple{Vararg{AbstractVector{<:SubsystemParams}}}) param_namemap = Dict{Symbol, ParamIndex}() for i ∈ eachindex(names_partitioned, params_partitioned) for j ∈ eachindex(names_partitioned[i], params_partitioned[i]) @@ -49,15 +49,32 @@ function Base.getindex(u::Tuple, (;tup_index, v_index, state_index)::StateIndex) u[tup_index][state_index, v_index] end +function Base.setindex!(u::ArrayPartition, val, idx::StateIndex) + setindex!(u.x, val, idx) +end +function Base.setindex!(u::Tuple, val, (;tup_index, v_index, state_index)::StateIndex) + setindex!(u[tup_index], val, state_index, v_index) +end +function Base.getindex(u::GraphSystemParameters, p::ParamIndex) + u.params_partitioned[p] +end function Base.getindex(u::Tuple, (;tup_index, v_index, prop)::ParamIndex) getproperty(u[tup_index][v_index], prop) end -function Base.getindex(u::GraphSystemParameters, p::ParamIndex) - u.subsystem_params[p] + + +function Base.setindex!(u::GraphSystemParameters, val, p::ParamIndex) + setindex!(u.params_partitioned, val, p) +end +function Base.setindex!(u::Tuple, val, (;tup_index, v_index, prop)::ParamIndex) + params = u[tup_index][v_index] + @reset params[prop] = val + setindex!(u[tup_index], params, v_index) end + function SymbolicIndexingInterface.is_variable(g::GraphSystem, sym) haskey(g.state_namemap, sym) end @@ -76,8 +93,11 @@ function SymbolicIndexingInterface.parameter_index(g::GraphSystem, sym) g.param_namemap[sym] end +function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters) + p.params_partitioned +end function SymbolicIndexingInterface.parameter_values(p::GraphSystemParameters, i::ParamIndex) - p.subsystem_params[i] + p.params_partitioned[i] end function SymbolicIndexingInterface.parameter_symbols(g::GraphSystem) diff --git a/test/particle_osc_example.jl b/test/particle_osc_example.jl index 13d4f19..36a583f 100644 --- a/test/particle_osc_example.jl +++ b/test/particle_osc_example.jl @@ -1,4 +1,4 @@ -using GraphDynamics, OrdinaryDiffEq +using GraphDynamics, OrdinaryDiffEq, Test struct Particle end function GraphDynamics.subsystem_differential(sys::Subsystem{Particle}, F, t) @@ -34,17 +34,16 @@ function ((;fac)::Coulomb)(a, b) -fac * a.q * b.q * sign(a.x - b.x)/(abs(a.x - b.x) + 1e-10)^2 end -subsystems_partitioned = ([Subsystem{Particle}(states=(;x= 1.0, v=0.0), params=(;m=1.0, q=1.0)), - Subsystem{Particle}(states=(;x=-1.0, v=0.0), params=(;m=2.0, q=1.0))], - [Subsystem{Oscillator}(states=(;x=0.0, v=1.0), params=(;x₀=0.0, m=3.0, k=1.0, q=1.0))]) +# put some garbage values in here for states and params, but we'll set them to reasonable values later with the +# u0map and param_map +subsystems_partitioned = ([Subsystem{Particle}(states=(;x= NaN, v=0.0), params=(;m=1.0, q=1.0)), + Subsystem{Particle}(states=(;x=-1.0, v=Inf), params=(;m=2.0, q=1.0))], + [Subsystem{Oscillator}(states=(;x=-Inf, v=1.0), params=(;x₀=0.0, m=-3000.0, k=1.0, q=1.0))]) states_partitioned = map(v -> map(get_states, v), subsystems_partitioned) params_partitioned = map(v -> map(get_params, v), subsystems_partitioned) names_partitioned = ([:particle1, :particle2], [:osc]) - - - spring_conns_par_par = NotConnected() spring_conns_par_osc = [Spring(1) Spring(1);;] @@ -76,7 +75,13 @@ connection_matrices = ConnectionMatrices((spring_conns, coulomb_conns)) sys = ODEGraphSystem(;connection_matrices, states_partitioned, params_partitioned, names_partitioned) tspan = (0.0, 20.0) -prob = ODEProblem(sys, [], tspan) + +prob = ODEProblem(sys, + # Fix the garbage state values + [:particle1₊x => 1.0, :particle2₊v => 0.0, :osc₊x => 0.0], + tspan, + # fix the garbage param values + [:osc₊m => 3.0]) sol = solve(prob, Tsit5()) @test sol[:particle1₊x][end] ≈ 1.4923823131014389