diff --git a/src/solve.jl b/src/solve.jl index 90f2611d..93b0d6f9 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -247,12 +247,6 @@ Base.getindex(sol::CosmologySolution, i::Colon, j::SymbolicIndex, k = :) = sol[1 function (sol::CosmologySolution)(ts::AbstractArray, is::AbstractArray) return permutedims(sol.th(ts, idxs=is)[:, :]) end -function (sol::CosmologySolution)(ts, is) - ts_arr, ts_outi = ts isa Number ? ([ts], 1) : (ts, Colon()) - is_arr, is_outi = is isa Number ? ([is], 1) : (is, Colon()) - out = sol(ts_arr, is_arr) # convert to all-array call - return out[ts_outi, is_outi] # pick out dimensions for scalar ks/ts/is -end function get_neighboring_wavenumber_indices(sol::CosmologySolution, k) i2 = max(searchsortedfirst(sol.ks, k), 2) # index above target k (or 2nd index if k == kmin) @@ -298,12 +292,14 @@ function (sol::CosmologySolution)(ks::AbstractArray, ts::AbstractArray, is::Abst return out end -function (sol::CosmologySolution)(ks, ts, is) - ks_arr, ks_outi = ks isa Number ? ([ks], 1) : (ks, Colon()) - ts_arr, ts_outi = ts isa Number ? ([ts], 1) : (ts, Colon()) - is_arr, is_outi = is isa Number ? ([is], 1) : (is, Colon()) - out = sol(ks_arr, ts_arr, is_arr) # convert to all-array call - return out[ks_outi, ts_outi, is_outi] # pick out dimensions for scalar ks/ts/is + +# Handle (ts, is) or (ks, ts, is) of arbitrary 0-dimensional and 1-dimensional combinations +function (sol::CosmologySolution)(args...) + # Please read this function with a pirate's voice + args_arr = [arg isa Number ? [arg] : arg for arg in args] + args_outi = [arg isa Number ? 1 : Colon() for arg in args] + out = sol(args_arr...) # convert to all-array call + return out[args_outi...] # pick out dimensions for scalar ks/ts/is end function (sol::CosmologySolution)(tvar::Num, t, idxs)