Skip to content

Commit

Permalink
Fix multi-threading
Browse files Browse the repository at this point in the history
  • Loading branch information
hersle committed Nov 27, 2024
1 parent 4dc6875 commit 8e26de7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
NumericalIntegration = "e7bfaba1-d571-5449-8927-abc22e82249b"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PeriodicTable = "7b2266bf-644c-5ea3-82d8-af4bbd25a884"
PhysicalConstants = "5ad8b20f-a522-5ce9-bfc9-ddf1d5bda6ab"
Expand All @@ -26,5 +27,6 @@ UnitfulAstro = "6112ee07-acf9-5e0f-b108-d242c714bf9f"

[compat]
ModelingToolkit = "9.50.0"
OhMyThreads = "0.7.0"
OrdinaryDiffEq = "6.90.1"
Symbolics = "6.19.0"
2 changes: 1 addition & 1 deletion docs/src/comparison.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ ks = sol1["P"]["k(h/Mpc)"] * h # 1/Mpc
Ps_class = sol1["P"]["P(Mpc/h)^3"] / h^3
Ps_class = Ps_class[ks .> 1e-3]
ks = ks[ks .> 1e-3]
Ps = power_spectrum(M, pars, ks / u"Mpc") / u"Mpc^3"
Ps = power_spectrum(M, pars, ks / u"Mpc"; verbose=true) / u"Mpc^3"
sol = merge(sol, Dict(
"k" => (ks, ks),
"P" => (Ps_class, Ps)
Expand Down
15 changes: 12 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import CommonSolve: solve
import SymbolicIndexingInterface: setp # all_variable_symbols, getname
import OhMyThreads: TaskLocalValue

function background(sys)
sys = thermodynamics(sys)
Expand Down Expand Up @@ -149,7 +150,7 @@ function solve(M::CosmologyModel, pars; aini = 1e-7, solver = Rodas5P(), reltol
return CosmologySolution(bg_sol, th_sol, [], nothing)
end

function solve(M::CosmologyModel, pars, ks::AbstractArray; aini = 1e-7, solver = KenCarp4(), reltol = 1e-11, backwards = true, verbose = false, kwargs...)
function solve(M::CosmologyModel, pars, ks::AbstractArray; aini = 1e-7, solver = KenCarp4(), reltol = 1e-11, backwards = true, verbose = false, thread = true, kwargs...)
ks = k_dimensionless.(ks, pars[M.g.h])

!issorted(ks) && throw(error("ks = $ks are not sorted in ascending order"))
Expand All @@ -164,18 +165,26 @@ function solve(M::CosmologyModel, pars, ks::AbstractArray; aini = 1e-7, solver =
))
end

if Threads.nthreads() == 1 && thread
@warn "Multi-threading was requested, but Julia is running with 1 thread."

Check warning on line 169 in src/solve.jl

View workflow job for this annotation

GitHub Actions / Documentation

Multi-threading was requested, but Julia is running with 1 thread.

Check warning on line 169 in src/solve.jl

View workflow job for this annotation

GitHub Actions / Documentation

Multi-threading was requested, but Julia is running with 1 thread.
thread = false
end

kset! = setp(M.pt, M.k) # function that sets k on a problem
ics0 = unknowns(M.bg) .=> th_sol.bg[unknowns(M.bg)][backwards ? end : begin]
ics0 = filter(ic -> !contains(String(Symbol(ic.first)), "aˍt"), ics0) # remove D(a)
ics0 = Dict(ics0)
push!(pars, k => ks[1])
ode_prob0 = ODEProblem(M.pt, ics0, (tini, tend), pars; fully_determined = true) # TODO: why do I need this???
ode_probs = EnsembleProblem(; safetycopy = false, prob = ode_prob0, prob_func = (ode_prob, i, _) -> begin
ode_prob_tlv = TaskLocalValue{ODEProblem}(() -> deepcopy(ode_prob0)) # https://discourse.julialang.org/t/solving-ensembleproblem-efficiently-for-large-systems-memory-issues/116146/11 # TODO: avoid copying whole problem
ode_probs = EnsembleProblem(; safetycopy = false, prob = ode_prob0, prob_func = (_, i, _) -> begin
ode_prob = ode_prob_tlv[]
verbose && println("$i/$(length(ks)) k = $(ks[i]*k0) Mpc/h")
kset!(ode_prob, ks[i])
return ode_prob
end)
ode_sols = solve(ode_probs, solver, EnsembleThreads(), trajectories = length(ks); reltol, kwargs...) # TODO: test GPU parallellization
alg = thread ? EnsembleThreads() : EnsembleSerial()
ode_sols = solve(ode_probs, solver, alg, trajectories = length(ks); reltol, kwargs...) # TODO: test GPU parallellization
for i in 1:length(ode_sols)
check_solution(ode_sols[i].retcode)
end
Expand Down

0 comments on commit 8e26de7

Please sign in to comment.