Skip to content

Commit

Permalink
Also update pseudo_point dtc
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf committed Apr 4, 2023
1 parent e93cf85 commit 8505dc8
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,12 @@ function dtc_post_emissions(k::ScaledKernel, x_new::AbstractVector, storage::Sto
end

function dtc_post_emissions(k::KernelSum, x_new::AbstractVector, storage::StorageType)
(Cs_l, cs_l, Hs_l, hs_l), Σs_l = dtc_post_emissions(k.kernels[1], x_new, storage)
(Cs_r, cs_r, Hs_r, hs_r), Σs_r = dtc_post_emissions(k.kernels[2], x_new, storage)
Cs = _map(vcat, Cs_l, Cs_r)
cs = cs_l + cs_r
Hs = _map(block_diagonal, Hs_l, Hs_r)
hs = _map(vcat, hs_l, hs_r)
return (Cs, cs, Hs, hs), _map(+, Σs_l, Σs_r)
post_emissions = dtc_post_emissions.(k.kernels, Ref(x_new), Ref(storage))
Cs_cs_Hs_hs = getindex.(post_emissions, 1)
Σs = getindex.(post_emissions, 2)
Cs = _map(vcat, getindex.(Cs_cs_Hs_hs, 1)...)
cs = sum(getindex.(Cs_cs_Hs_hs, 2))
Hs = _map(block_diagonal, getindex.(Cs_cs_Hs_hs, 3)...)
hs = _map(vcat, getindex.(Cs_cs_Hs_hs, 4)...)
return (Cs, cs, Hs, hs), sum(Σs)
end

0 comments on commit 8505dc8

Please sign in to comment.