Skip to content

Commit

Permalink
Fix min/max time/wavenumber tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hersle committed Dec 19, 2024
1 parent e73e350 commit b0aa505
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
15 changes: 9 additions & 6 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ Base.getindex(sol::CosmologySolution, i, j::SymbolicIndex, k = :) = [stack(sol[_
Base.getindex(sol::CosmologySolution, i::Colon, j::SymbolicIndex, k = :) = sol[1:length(sol.pts), j, k]

function (sol::CosmologySolution)(ts::AbstractArray, is::AbstractArray)
minimum(ts) >= sol.th.t[begin] || throw("Requested time t = $(minimum(ts)) is before initial time $(sol.th.t[begin])")
maximum(ts) <= sol.th.t[end] || throw("Requested time t = $(maximum(ts)) is before final time $(sol.th.t[end])")
tmin, tmax = extrema(sol.th.t)
minimum(ts) >= tmin || throw("Requested time t = $(minimum(ts)) is below minimum solved time $tmin")
maximum(ts) <= tmax || throw("Requested time t = $(maximum(ts)) is above maximum solved time $tmin")
return permutedims(sol.th(ts, idxs=is)[:, :])
end

Expand All @@ -273,10 +274,12 @@ end
function (sol::CosmologySolution)(ks::AbstractArray, ts::AbstractArray, is::AbstractArray)
ks = k_dimensionless.(ks, sol.bg.ps[:h])
isempty(sol.ks) && throw(error("No perturbations solved for. Pass ks to solve()."))
minimum(ks) >= sol.ks[begin] || throw("Requested wavenumber k = $(minimum(ks)) is outside solved range k ≥ $(sol.ks[begin])")
maximum(ks) <= sol.ks[end] || throw("Requested wavenumber k = $(maximum(ks)) is outside solved range k ≤ $(sol.ks[end])")
minimum(ts) >= sol.th.t[begin] || throw("Requested time t = $(minimum(ts)) is before initial time $(sol.th.t[begin])")
maximum(ts) <= sol.th.t[end] || throw("Requested time t = $(maximum(ts)) is before final time $(sol.th.t[end])")
kmin, kmax = extrema(sol.ks)
minimum(ks) >= kmin || throw("Requested wavenumber k = $(minimum(ks)) is below the minimum solved wavenumber $kmin")
maximum(ks) <= kmax || throw("Requested wavenumber k = $(maximum(ks)) is above the maximum solved wavenumber $kmax")
tmin, tmax = extrema(sol.th.t)
minimum(ts) >= tmin || throw("Requested time t = $(minimum(ts)) is below minimum solved time $tmin")
maximum(ts) <= tmax || throw("Requested time t = $(maximum(ts)) is above maximum solved time $tmin")

# Pre-allocate intermediate and output arrays
T = eltype(sol.pts[1])
Expand Down
12 changes: 6 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pars = SymBoltz.parameters_Planck18(M)
@test size(sol(ts, is[1])) == (nt,)
@test size(sol(ts, is)) == (nt, ni)
@test_throws "No perturbations" sol(ks, ts, is)
@test_throws "before initial time" sol(sol[M.t][begin]-1, is)
@test_throws "after final time" sol(sol[M.t][end]+1, is)
@test_throws "below minimum solved time" sol(sol[M.t][begin]-1, is)
@test_throws "above maximum solved time" sol(sol[M.t][end]+1, is)

sol = solve(M, pars, ks)
@test size(sol(ks[1], ts[1], is[1])) == () # perturbations
Expand All @@ -29,10 +29,10 @@ pars = SymBoltz.parameters_Planck18(M)
@test size(sol(ks, ts[1], is)) == (nk, ni)
@test size(sol(ks, ts, is[1])) == (nk, nt)
@test size(sol(ks, ts, is)) == (nk, nt, ni)
@test_throws "outside range" sol(ks[begin]-1, ts, is)
@test_throws "outside range" sol(ks[end]+1, ts, is)
@test_throws "before initial time" sol(ks[1], sol[M.t][begin]-1, is)
@test_throws "after final time" sol(ks[1], sol[M.t][end]+1, is)
@test_throws "below minimum solved wavenumber" sol(ks[begin]-1, ts, is)
@test_throws "above maximum solved wavenumber" sol(ks[end]+1, ts, is)
@test_throws "below minimum solved time" sol(ks[1], sol[M.t][begin]-1, is)
@test_throws "above maximum solved time" sol(ks[1], sol[M.t][end]+1, is)

# TODO: also test array indexing
end
Expand Down

0 comments on commit b0aa505

Please sign in to comment.