Skip to content

Commit

Permalink
Adds partials to ScalarField function to preserve subsystem namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
cadojo committed Dec 3, 2024
1 parent d0feae3 commit c407a55
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
23 changes: 17 additions & 6 deletions src/potentials.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
function ScalarField(value, t, u, p; name, simplify = true, kwargs...)
eqs::Vector{Equation} = vcat(
(Differential(t)^2).(u) .~ -ModelingToolkit.gradient(value, u, simplify = simplify)
)
return ODESystem(vcat~ value), t; name = name, kwargs...)
function ScalarField(value, t, u, p; name, gradient = true, simplify = true, kwargs...)
if gradient
symbols = [Symbol(:∂Φ∂, Symbol(first(split(string(x), "($(Symbolics.value(t)))"))))
for x in u]

∂Φ∂u = getfield.(
vcat((@variables($(x)(t)) for x in symbols)...),
:val
)

eqs = vcat~ value, ∂Φ∂u .~ Symbolics.gradient(value, u, simplify = simplify))
else
eqs =~ value]
end

return ODESystem(eqs, t; name = name, kwargs...)
end

"""
Expand Down Expand Up @@ -306,7 +317,7 @@ function Bovy2014(; name = :BovyMilkyWayPotential, kwargs...)
u = [x, y, z]
du = [ẋ, ẏ, ż]

grad(sys) = calculate_jacobian(sys; simplify = true)[(begin + 1):end] # TODO remove manual indexing
grad(sys) = [sys.∂Φ∂x, sys.∂Φ∂y, sys.∂Φ∂z]

eqs = vcat(
Φ ~ disk.Φ + bulge.Φ + halo.Φ,
Expand Down
7 changes: 2 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ end
mws = structural_simplify(mw)

u = randn(6)
p = randn(21)
t = 0.0

f = ODEFunction(mws)

@test f(u, p, t) isa AbstractVector
problem = ODEProblem(mws, u, (0.0, 10.0), [])
@test problem.f(u, problem.p, 0.0) isa AbstractVector
end

0 comments on commit c407a55

Please sign in to comment.