Skip to content

Commit

Permalink
Use general function for interpolating background/perturbations with …
Browse files Browse the repository at this point in the history
…scalar arguments
  • Loading branch information
hersle committed Dec 19, 2024
1 parent 52645da commit de5f72b
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,6 @@ Base.getindex(sol::CosmologySolution, i::Colon, j::SymbolicIndex, k = :) = sol[1
function (sol::CosmologySolution)(ts::AbstractArray, is::AbstractArray)
return permutedims(sol.th(ts, idxs=is)[:, :])
end
function (sol::CosmologySolution)(ts, is)
ts_arr, ts_outi = ts isa Number ? ([ts], 1) : (ts, Colon())
is_arr, is_outi = is isa Number ? ([is], 1) : (is, Colon())
out = sol(ts_arr, is_arr) # convert to all-array call
return out[ts_outi, is_outi] # pick out dimensions for scalar ks/ts/is
end

function get_neighboring_wavenumber_indices(sol::CosmologySolution, k)
i2 = max(searchsortedfirst(sol.ks, k), 2) # index above target k (or 2nd index if k == kmin)
Expand Down Expand Up @@ -298,12 +292,14 @@ function (sol::CosmologySolution)(ks::AbstractArray, ts::AbstractArray, is::Abst

return out
end
function (sol::CosmologySolution)(ks, ts, is)
ks_arr, ks_outi = ks isa Number ? ([ks], 1) : (ks, Colon())
ts_arr, ts_outi = ts isa Number ? ([ts], 1) : (ts, Colon())
is_arr, is_outi = is isa Number ? ([is], 1) : (is, Colon())
out = sol(ks_arr, ts_arr, is_arr) # convert to all-array call
return out[ks_outi, ts_outi, is_outi] # pick out dimensions for scalar ks/ts/is

# Handle (ts, is) or (ks, ts, is) of arbitrary 0-dimensional and 1-dimensional combinations
function (sol::CosmologySolution)(args...)
# Please read this function with a pirate's voice
args_arr = [arg isa Number ? [arg] : arg for arg in args]
args_outi = [arg isa Number ? 1 : Colon() for arg in args]
out = sol(args_arr...) # convert to all-array call
return out[args_outi...] # pick out dimensions for scalar ks/ts/is
end

function (sol::CosmologySolution)(tvar::Num, t, idxs)
Expand Down

0 comments on commit de5f72b

Please sign in to comment.