Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/hersle/SymBoltz.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
hersle committed Jan 9, 2025
2 parents 588f70a + d2b7dd8 commit a781cdb
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 31 deletions.
6 changes: 3 additions & 3 deletions docs/src/comparison.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
```@example class
using SymBoltz
lmax = 6
M = SymBoltz.ΛCDM(; lmax) # TODO: fix perturbations when massive neutrinos are present
M = SymBoltz.ΛCDM(; lmax)
pars = SymBoltz.parameters_Planck18(M)
```
```@setup class
Expand Down Expand Up @@ -189,7 +189,6 @@ function plot_compare(xlabel, ylabels; lgx=false, lgy=false, errtype=:auto, errl
xlab(x) = lgx ? "lg(|$x|)" : x
ylab(y) = lgy ? "lg(|$y|)" : y
# TODO: relative or absolute comparison (of quantities close to 0)
p = plot(; layout=grid(2, 1, heights=(3/4, 1/4)), size = (800, 600))
plot!(p[1]; titlefontsize = 8, ylabel = join(ylab.(ylabels), ", "))
maxerr = 0.0
Expand All @@ -214,6 +213,8 @@ function plot_compare(xlabel, ylabels; lgx=false, lgy=false, errtype=:auto, errl
# TODO: use built-in CosmoloySolution interpolation
y1 = LinearInterpolation(y1, x1; extrapolate=true).(x)
y2 = LinearInterpolation(y2, x2; extrapolate=true).(x)
# Compare absolute error if quantity crosses zero, otherwise relative error (unless overridden)
abserr = (errtype == :abs) || (errtype == :auto && (any(y1 .<= 0) || any(y2 .<= 0)))
if abserr
err = y2 .- y1
Expand All @@ -232,7 +233,6 @@ function plot_compare(xlabel, ylabels; lgx=false, lgy=false, errtype=:auto, errl
a = ceil(maxerr / 10^b)
errlim = a * 10^b
end
#println("maxerr = $maxerr ≈ $a * 10^$b")
plot!(p[end], ylims = (-errlim, +errlim))
return p
Expand Down
2 changes: 1 addition & 1 deletion src/components/gravity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function general_relativity(g; acceleration = false, name = :G, kwargs...)
D(g.Φ) ~ -4*Num(π)/3*a^2/g.*δρ - k^2/(3*g.ℰ)*g.Φ - g.*g.Ψ
k^2 * (g.Φ - g.Ψ) ~ 12*Num(π) * a^2 * Π
] .|> O^1)
guesses ==> 1, D(a) => +1]
guesses ==> 0.1, D(a) => +1]
description = "General relativity gravity"
return ODESystem([eqs0; eqs1], t, vars, pars; initialization_eqs = ics0, guesses, name, description, kwargs...)
end
Expand Down
60 changes: 34 additions & 26 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Base: nameof
import CommonSolve: solve
import SymbolicIndexingInterface: setp # all_variable_symbols, getname
import OhMyThreads: TaskLocalValue
Expand Down Expand Up @@ -34,12 +35,16 @@ struct CosmologyModel
end

function CosmologyModel(sys::ODESystem; initE = true, spline_thermo = true, debug = false)
if debug
sys = debugize(sys) # TODO: make work with massive neutrinos
end
bg = structural_simplify(background(sys))
th = structural_simplify(thermodynamics(sys))
pt = structural_simplify(perturbations(sys; spline_thermo))

if debug
bg = debug_system(bg)
th = debug_system(th)
pt = debug_system(pt)
end

sys = complete(sys; flatten = false)
return CosmologyModel(sys, bg, th, pt, initE, spline_thermo)
end
Expand All @@ -55,6 +60,7 @@ function Base.getproperty(M::CosmologyModel, prop::Symbol)
end

# Forward inspection functions to full system
nameof(M::CosmologyModel) = nameof(M.sys)
equations(M::CosmologyModel) = equations(M.sys)
observed(M::CosmologyModel) = observed(M.sys)
unknowns(M::CosmologyModel) = unknowns(M.sys)
Expand All @@ -70,22 +76,31 @@ struct CosmologySolution
th::ODESolution
ks::AbstractArray
pts::Union{EnsembleSolution, Nothing}
pars::Dict
end

solvername(alg) = string(nameof(typeof(alg)))
solvername(alg::CompositeAlgorithm) = join(solvername.(alg.algs), "+")

function Base.show(io::IO, sol::CosmologySolution)
print(io, "Cosmology solution with stages")
print(io, "Cosmology solution for model ")
printstyled(io, nameof(sol.M), '\n'; bold = true)

printstyled(io, "Parameters:\n"; bold = true)
for (par, val) in sol.pars
print(io, " $par = $val\n")
end

printstyled(io, "Stages:"; bold = true)
print(io, "\n 1. background: solved with $(solvername(sol.bg.alg)), $(length(sol.bg)) points")
if !isnothing(sol.th)
print(io, "\n 2. thermodynamics: solved with $(solvername(sol.th.alg)), $(length(sol.th)) points")
end
if !isnothing(sol.pts)
kmin, kmax = extrema(map(pt -> pt.prob.ps[SymBoltz.k], sol.pts))
nmin, nmax = extrema(map(pt -> length(pt), sol.pts))
kmin, kmax = extrema(sol.ks)
nmin, nmax = extrema(map(length, sol.pts))
n = length(sol.pts)
print(io, "\n 3. perturbations: solved with $(solvername(sol.pts[1].alg)), $nmin-$nmax points, x$n k ∈ [$kmin, $kmax] H₀/c (linear interpolation in-between)")
print(io, "\n 3. perturbations: solved with $(solvername(sol.pts[1].alg)), $nmin-$nmax points, x$n k ∈ [$kmin, $kmax] H₀/c (linear interpolation between log(k))")
end
end

Expand Down Expand Up @@ -166,10 +181,9 @@ function solve(M::CosmologyModel, pars; aini = 1e-8, solver = Rodas4P(), reltol
th_sol = bg_sol
end

return CosmologySolution(M, bg_sol, th_sol, [], nothing)
return CosmologySolution(M, bg_sol, th_sol, [], nothing, pars)
end

# TODO: pass background solution to avoid recomputing it
"""
solve(M::CosmologyModel, pars, ks; aini = 1e-8, solver = KenCarp4(), reltol = 1e-8, backwards = true, verbose = false, thread = true, jac = false, sparse = false, kwargs...)
Expand All @@ -188,17 +202,16 @@ function solve(M::CosmologyModel, pars, ks::AbstractArray; aini = 1e-8, solver =
th_sol = solve(M, pars; aini, backwards, jac, sparse, reltol = reltol_bg, kwargs...)
tini, tend = extrema(th_sol.th[t])


# TODO: can I exploit that the structure of the perturbation ODEs is ẏ = J * y with "constant" J?
kset! = setp(M.pt, M.k) # function that sets k on a problem
ics0 = unknowns(M.bg) .=> th_sol.bg[unknowns(M.bg)][backwards ? end : begin]
ics0 = filter(ic -> !contains(String(Symbol(ic.first)), "aˍt"), ics0) # remove D(a)
ics0 = Dict(ics0)
pars = merge(pars, Dict(k => NaN)) # must be set, even for the uninitialized problem
ode_prob0 = ODEProblem(M.pt, ics0, (tini, tend), pars; fully_determined = true, jac, sparse)
parsk = merge(pars, Dict(k => NaN)) # must be set, even for the uninitialized problem
ode_prob0 = ODEProblem(M.pt, ics0, (tini, tend), parsk; fully_determined = true, jac, sparse)

# If the thermodynamics solution should be splined,
# solve it again and update the spline parameters
# solve it again (if needed) and update the spline parameters
if M.spline_thermo
th_sol_spline = isempty(kwargs) ? th_sol : solve(M, pars; aini, backwards, jac, sparse, reltol = reltol_bg) # should solve again if given keyword arguments, like saveat
τspline = spline(th_sol_spline[M.b.rec.τ], th_sol_spline[M.t]) # TODO: when solving thermo with low reltol: even though the solution is correct, just taking its points for splining can be insufficient. should increase number of points, so it won't mess up the perturbations
Expand All @@ -209,8 +222,7 @@ function solve(M::CosmologyModel, pars, ks::AbstractArray; aini = 1e-8, solver =
ode_prob0 = remake(ode_prob0, u0 = newu0, p = newp)
end

# TODO: just copy p and u0: https://github.com/SciML/ModelingToolkit.jl/issues/3056
ode_prob_tlv = TaskLocalValue{ODEProblem}(() -> deepcopy(ode_prob0)) # https://discourse.julialang.org/t/solving-ensembleproblem-efficiently-for-large-systems-memory-issues/116146/11 # TODO: avoid copying whole problem
ode_prob_tlv = TaskLocalValue{ODEProblem}(() -> deepcopy(ode_prob0)) # prevent conflicts where different tasks modify same problem: https://discourse.julialang.org/t/solving-ensembleproblem-efficiently-for-large-systems-memory-issues/116146/11 (alternatively copy just p and u0: https://github.com/SciML/ModelingToolkit.jl/issues/3056)
ode_probs = EnsembleProblem(; safetycopy = false, prob = ode_prob0, prob_func = (_, i, _) -> begin
ode_prob = ode_prob_tlv[]
verbose && println("$i/$(length(ks)) k = $(ks[i]*k0) Mpc/h")
Expand All @@ -222,9 +234,8 @@ function solve(M::CosmologyModel, pars, ks::AbstractArray; aini = 1e-8, solver =
for i in 1:length(ode_sols)
check_solution(ode_sols[i].retcode)
end
return CosmologySolution(M, th_sol.bg, th_sol.th, ks, ode_sols)
return CosmologySolution(M, th_sol.bg, th_sol.th, ks, ode_sols, pars)
end

function solve(M::CosmologyModel, pars, k::Number; kwargs...)
return solve(M, pars, [k]; kwargs...)
end
Expand All @@ -239,7 +250,7 @@ end
const SymbolicIndex = Union{Num, AbstractArray{Num}}
function Base.getindex(sol::CosmologySolution, i::SymbolicIndex)
if ModelingToolkit.isparameter(i) && i !== t # don't catch independent variable as parameter
return sol.th.ps[i] # assume all parameters are in background/thermodynamics # TODO: index sol directly?
return sol.th.ps[i] # assume all parameters are in background/thermodynamics # TODO: index sol directly when this is fixed? https://github.com/SciML/ModelingToolkit.jl/issues/3267
else
return sol.th[i]
end
Expand All @@ -253,8 +264,8 @@ Base.getindex(sol::CosmologySolution, i::Colon, j::SymbolicIndex, k = :) = sol[1

function (sol::CosmologySolution)(ts::AbstractArray, is::AbstractArray)
tmin, tmax = extrema(sol.th.t[[begin, end]])
minimum(ts) >= tmin || throw("Requested time t = $(minimum(ts)) is before initial time $tmin")
maximum(ts) <= tmax || throw("Requested time t = $(maximum(ts)) is after final time $tmax")
minimum(ts) >= tmin || minimum(ts) tmin || throw("Requested time t = $(minimum(ts)) is before initial time $tmin")
maximum(ts) <= tmax || maximum(ts) tmax || throw("Requested time t = $(maximum(ts)) is after final time $tmax")
return permutedims(sol.th(ts, idxs=is)[:, :])
end

Expand Down Expand Up @@ -351,13 +362,10 @@ function timeseries(sol::CosmologySolution, k; kwargs...)
return timeseries(ts; kwargs...)
end
function timeseries(ts::AbstractArray; Nextra = 0)
if Nextra == 0
return ts
if Nextra > 0
ts = exp.(extend_array(log.(ts), Nextra))
end
ts_new = exp.(extend_array(log.(ts), Nextra))
ts_new[begin] = ts[begin] # ensure exp(log(t)) transformation
ts_new[end] = ts[end] # # leaves original endpoints intact (for bounds checking)
return ts_new
return ts
end
"""
timeseries(sol::CosmologySolution, var, val::Number)
Expand Down
2 changes: 1 addition & 1 deletion src/spectra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ function los_integrate(Ss::AbstractArray, ls::AbstractArray, ks::AbstractRange,
end
for il in eachindex(ls)
integrand = @view ∂I_∂lnt[il,:]
Is[ik,il] = integrate(lnts, integrand, integrator) # integrate over t # TODO: add starting I(tini) # TODO: calculate ∂Θ_∂logΘ and use Even() methods
Is[ik,il] = integrate(lnts, integrand, integrator) # integrate over t # TODO: add starting I(tini)
end
end

Expand Down

0 comments on commit a781cdb

Please sign in to comment.