Skip to content

Commit

Permalink
Type-safe when calling solution with single wavenumber/time/index
Browse files Browse the repository at this point in the history
  • Loading branch information
hersle committed Jan 30, 2025
1 parent b21fb6b commit b66b51c
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ function (sol::CosmologySolution)(ts::AbstractArray, is::AbstractArray)
maximum(ts) <= tmax || maximum(ts) tmax || throw("Requested time t = $(maximum(ts)) is after final time $tmax")
return permutedims(sol.th(ts, idxs=is)[:, :])
end
(sol::CosmologySolution)(ts::AbstractArray, i) = sol(ts, [i])[:, 1]
(sol::CosmologySolution)(t, is::AbstractArray) = sol([t], is)[1, :]
(sol::CosmologySolution)(t, i) = sol([t], [i])[1, 1]

function neighboring_modes_indices(sol::CosmologySolution, k)
k = k_dimensionless.(k, sol.bg.ps[:h])
Expand Down Expand Up @@ -401,6 +404,13 @@ function (sol::CosmologySolution)(ks::AbstractArray, ts::AbstractArray, is::Abst

return out
end
(sol::CosmologySolution)(k::Number, ts::AbstractArray, is::AbstractArray) = sol([k], ts, is)[1, :, :]
(sol::CosmologySolution)(ks::AbstractArray, t::Number, is::AbstractArray) = sol(ks, [t], is)[:, 1, :]
(sol::CosmologySolution)(ks::AbstractArray, ts::AbstractArray, i) = sol(ks, ts, [i])[:, :, 1]
(sol::CosmologySolution)(k::Number, t::Number, is::AbstractArray) = sol([k], [t], is)[1, 1, :]
(sol::CosmologySolution)(k::Number, ts::AbstractArray, i) = sol([k], ts, [i])[1, :, 1]
(sol::CosmologySolution)(ks::AbstractArray, t::Number, i) = sol(ks, [t], [i])[:, 1, 1]
(sol::CosmologySolution)(k::Number, t::Number, i) = sol([k], [t], [i])[1, 1, 1]

# Handle (ts, is) or (ks, ts, is) of arbitrary 0-dimensional and 1-dimensional combinations
function (sol::CosmologySolution)(args...)
Expand Down

0 comments on commit b66b51c

Please sign in to comment.