diff --git a/Project.toml b/Project.toml index 00ac75b3..cd15d2a5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" authors = ["willtebbutt "] -version = "0.3.7" +version = "0.3.8" [deps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" diff --git a/src/TemporalGPs.jl b/src/TemporalGPs.jl index fdf5d855..cd361c1c 100644 --- a/src/TemporalGPs.jl +++ b/src/TemporalGPs.jl @@ -14,6 +14,7 @@ module TemporalGPs using FillArrays: AbstractFill using Kronecker: KroneckerProduct + using Zygote: _pullback import Stheno: mean, cov, pairwise, logpdf, AV, AM, FiniteGP, AbstractGP @@ -32,10 +33,7 @@ module TemporalGPs include(joinpath("models", "immutable_inference.jl")) include(joinpath("models", "immutable_inference_pullbacks.jl")) - include(joinpath("models", "checkpointed_immutable_pullbacks.jl")) - - include(joinpath("models", "mutable_inference.jl")) - include(joinpath("models", "mutable_inference_pullbacks.jl")) + include(joinpath("models", "checkpointed_immutable_pullbacks.jl")) include(joinpath("models", "scalar_lgssm.jl")) diff --git a/src/models/checkpointed_immutable_pullbacks.jl b/src/models/checkpointed_immutable_pullbacks.jl index d9085c8c..da3c7887 100644 --- a/src/models/checkpointed_immutable_pullbacks.jl +++ b/src/models/checkpointed_immutable_pullbacks.jl @@ -121,8 +121,8 @@ for (foo, step_foo, foo_pullback, step_foo_pullback) in [ # Grabs the penultimate filtering distribution, xs[end]. Δys = Vector{eltype(ys)}(undef, T) (Δα, Δx__) = get_pb(f)(last(Δvs)) - _, pullback_last = $step_foo_pullback(model[T], xs[end], ys[T]) - Δmodel_at_T, Δx, Δy = pullback_last((Δlml, Δα, Δx__)) + _, pullback_last = _pullback(NoContext(), $step_foo, model[T], xs[end], ys[T]) + _, Δmodel_at_T, Δx, Δy = pullback_last((Δlml, Δα, Δx__)) Δmodel = get_adjoint_storage(model, Δmodel_at_T) Δys[T] = Δy @@ -144,8 +144,10 @@ for (foo, step_foo, foo_pullback, step_foo_pullback) in [ if t != T Δα, Δx__ = get_pb(f)(Δvs[t]) Δx_ = Zygote.accum(Δx, Δx__) - _, pullback_t = $step_foo_pullback(model[t], xs_block[c], ys[t]) - Δmodel_at_t, Δx, Δy = pullback_t((Δlml, Δα, Δx_)) + _, pullback_t = _pullback( + NoContext(), $step_foo, model[t], xs_block[c], ys[t], + ) + _, Δmodel_at_t, Δx, Δy = pullback_t((Δlml, Δα, Δx_)) Δmodel = _accum_at(Δmodel, t, Δmodel_at_t) Δys[t] = Δy end diff --git a/src/models/immutable_inference.jl b/src/models/immutable_inference.jl index 9d00a21f..b5360ed7 100644 --- a/src/models/immutable_inference.jl +++ b/src/models/immutable_inference.jl @@ -72,26 +72,6 @@ end return A * mf + a, (A * Pf) * A' + Q end -# # Immutable inference for heap-allocated arrays. -# @inline function predict( -# mf::StridedVector{T}, -# Pf::StridedMatrix{T}, -# A::StridedMatrix{T}, -# a::StridedVector{T}, -# Q::StridedMatrix{T}, -# ) where {T<:Real} - -# # Compute filtering mean vector. -# mp = A * mf + a - -# # Compute filtering covariance matrix. -# Pp = similar(Pf) -# BLAS.copy!(Pp, Q) -# mul!(Pp, A * Symmetric(Pf), A', one(T), one(T)) - -# return mp, Pp -# end - @inline function update_decorrelate( mp::AV{T}, Pp::AM{T}, H::AM{T}, h::AV{T}, Σ::AM{T}, y::AV{T}, ) where {T<:Real} @@ -126,8 +106,9 @@ end _compute_Pf(Pp::AM{T}, B::AM{T}) where {T<:Real} = Pp - B'B -function _compute_Pf(Pp::Matrix{T}, B::Matrix{T}) where {T<:Real} - # Copy of Pp is necessary to ensure that the memory isn't modified. - # return BLAS.syrk!('U', 'T', -one(T), B, one(T), copy(Pp)) - return LinearAlgebra.copytri!(BLAS.syrk!('U', 'T', -one(T), B, one(T), copy(Pp)), 'U') -end +# function _compute_Pf(Pp::Matrix{T}, B::Matrix{T}) where {T<:Real} +# # Copy of Pp is necessary to ensure that the memory isn't modified. +# # return BLAS.syrk!('U', 'T', -one(T), B, one(T), copy(Pp)) +# # I probably _do_ need a custom adjoint for this... +# return LinearAlgebra.copytri!(BLAS.syrk!('U', 'T', -one(T), B, one(T), copy(Pp)), 'U') +# end diff --git a/src/models/immutable_inference_pullbacks.jl b/src/models/immutable_inference_pullbacks.jl index 0d44086e..02e91a8c 100644 --- a/src/models/immutable_inference_pullbacks.jl +++ b/src/models/immutable_inference_pullbacks.jl @@ -1,17 +1,8 @@ # -# This file contains pullbacks for stuff in generic.jl. These are purely performance -# optimisations for algorithmic differentiation, and in no way important for understanding -# the structure of the package, or its functionality. +# This file contains pullbacks for stuff in immutable_inference.jl. There is no good reason +# to understand what's going on here. # -function Zygote.accum(a::UpperTriangular, b::UpperTriangular) - return UpperTriangular(Zygote.accum(a.data, b.data)) -end - -function Zygote.accum(D::Diagonal{<:Real}, U::UpperTriangular{<:Real, <:SMatrix}) - return UpperTriangular(D + U.data) -end - # # Objects in which to storage / accumulate the adjoint w.r.t. the hypers. # @@ -81,9 +72,9 @@ end get_pb(::typeof(pick_last)) = Δ->(nothing, Δ) -for (foo, step_foo, foo_pullback, step_foo_pullback) in [ - (:correlate, :step_correlate, :correlate_pullback, :step_correlate_pullback), - (:decorrelate, :step_decorrelate, :decorrelate_pullback, :step_decorrelate_pullback), +for (foo, step_foo, foo_pullback) in [ + (:correlate, :step_correlate, :correlate_pullback), + (:decorrelate, :step_decorrelate, :decorrelate_pullback), ] @eval @adjoint function $foo( ::Immutable, @@ -121,21 +112,21 @@ for (foo, step_foo, foo_pullback, step_foo_pullback) in [ vs[t] = f(α, x) end - function foo_pullback(Δ::Tuple{Any, Nothing}) - return foo_pullback((first(Δ), Fill(nothing, T))) - end + # function foo_pullback(Δ::Tuple{Any, Nothing}) + # return foo_pullback((Δ[1], Fill(nothing, T))) + # end - function foo_pullback(Δ::Tuple{Any, AbstractVector}) + function foo_pullback(Δ::Tuple{Any, Union{AbstractVector, Nothing}}) Δlml = Δ[1] - Δvs = Δ[2] + Δvs = Δ[2] isa Nothing ? Fill(nothing, T) : Δ[2] # Compute the pullback through the last element of the chain to get # initialisations for cotangents to accumulate. Δys = Vector{eltype(ys)}(undef, T) (Δα, Δx__) = get_pb(f)(last(Δvs)) - _, pullback_last = $step_foo_pullback(model[T], xs[T], ys[T]) - Δmodel_at_T, Δx, Δy = pullback_last((Δlml, Δα, Δx__)) + _, pullback_last = _pullback(NoContext(), $step_foo, model[T], xs[T], ys[T]) + _, Δmodel_at_T, Δx, Δy = pullback_last((Δlml, Δα, Δx__)) Δmodel = get_adjoint_storage(model, Δmodel_at_T) Δys[T] = Δy @@ -143,8 +134,8 @@ for (foo, step_foo, foo_pullback, step_foo_pullback) in [ for t in reverse(1:T-1) Δα, Δx__ = get_pb(f)(Δvs[t]) Δx_ = Zygote.accum(Δx, Δx__) - _, pullback_t = $step_foo_pullback(model[t], xs[t], ys[t]) - Δmodel_at_t, Δx, Δy = pullback_t((Δlml, Δα, Δx_)) + _, pullback_t = _pullback(NoContext(), $step_foo, model[t], xs[t], ys[t]) + _, Δmodel_at_t, Δx, Δy = pullback_t((Δlml, Δα, Δx_)) Δmodel = _accum_at(Δmodel, t, Δmodel_at_t) Δys[t] = Δy end @@ -161,333 +152,3 @@ for (foo, step_foo, foo_pullback, step_foo_pullback) in [ return (lml, vs), foo_pullback end end - - - -# -# AD-free pullbacks for a few things. These are primitives that will be used to write the -# gradients. -# - -function cholesky_pullback(Σ::Symmetric{<:Real, <:StridedMatrix}) - C = cholesky(Σ) - return C, function(Δ::NamedTuple) - U, Ū = C.U, Δ.factors - Σ̄ = Ū * U' - Σ̄ = LinearAlgebra.copytri!(Σ̄, 'U') - Σ̄ = ldiv!(U, Σ̄) - BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) - - @inbounds for n in diagind(Σ̄) - Σ̄[n] /= 2 - end - return (UpperTriangular(Σ̄),) - end -end - -function cholesky_pullback(S::Symmetric{<:Real, <:StaticMatrix{N, N}}) where {N} - C = cholesky(S) - return C, function(Δ::NamedTuple) - U, ΔU = C.U, Δ.factors - ΔS = U \ (U \ SMatrix{N, N}(Symmetric(ΔU * U')))' - ΔS = ΔS - Diagonal(ΔS ./ 2) - return (UpperTriangular(ΔS),) - end -end - -function logdet_pullback(C::Cholesky) - return logdet(C), function(Δ) - return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),) - end -end - -AtA_pullback(A::AbstractMatrix{<:Real}) = A'A, Δ->(A * (Δ + Δ'),) - - - -# -# substantial pullbacks -# - -@adjoint predict(m, P, A, a, Q) = predict_pullback(m, P, A, a, Q) - -function predict_pullback( - m::AbstractVector{T}, - P::AbstractMatrix{T}, - A::AbstractMatrix{T}, - a::AbstractVector{T}, - Q::AbstractMatrix{T}, -) where {T<:Real} - mp = A * m + a # 1 - tmp = A * P # 2 - Pp = tmp * A' + Q # 3 - return (mp, Pp), function(Δ) - Δmp = Δ[1] - ΔPp = Δ[2] - - # 3 - ΔQ = ΔPp - ΔA = ΔPp' * tmp - ΔT = ΔPp * A - - # 2 - ΔA += ΔT * P' - ΔP = A'ΔT - - # 1 - ΔA += Δmp * m' - Δm = A'Δmp - Δa = Δmp - - return Δm, ΔP, ΔA, Δa, ΔQ - end -end - -@adjoint function step_decorrelate(model, x::Gaussian, y::AV{<:Real}) - return step_decorrelate_pullback(model, x, y) -end - -function step_decorrelate_pullback( - model::NamedTuple{(:gmm, :Σ)}, - x::Gaussian, - y::AV{<:Real}, -) - # Evaluate function, keeping track of derivatives. - gmm = model.gmm - (mp, Pp), predict_pb = predict_pullback(x.m, x.P, gmm.A, gmm.a, gmm.Q) - (mf, Pf, lml, α), update_decorrelate_pb = - update_decorrelate_pullback(mp, Pp, gmm.H, gmm.h, model.Σ, y) - - return (lml, α, Gaussian(mf, Pf)), function(Δ) - - # Unpack stuff. - Δlml, Δα, Δx = Δ - Δmf = Δx === nothing ? zero(mp) : Δx.m - ΔPf = Δx === nothing ? zero(Pp) : Δx.P - - # Backprop through stuff. - Δmp, ΔPp, ΔH, Δh, ΔΣ, Δy = update_decorrelate_pb((Δmf, ΔPf, Δlml, Δα)) - Δmf, ΔPf, ΔA, Δa, ΔQ = predict_pb((Δmp, ΔPp)) - - Δx = (m=Δmf, P=ΔPf) - Δmodel = ( - gmm = (A=ΔA, a=Δa, Q=ΔQ, H=ΔH, h=Δh), - Σ=ΔΣ, - ) - return Δmodel, Δx, Δy - end -end - -@adjoint function update_decorrelate(m, P, H, h, Σ, y) - return update_decorrelate_pullback(m, P, H, h, Σ, y) -end - -function update_decorrelate_pullback( - mp::AbstractVector{T}, - Pp::AbstractMatrix{T}, - H::AbstractMatrix{T}, - h::AbstractVector{T}, - Σ::AbstractMatrix{T}, - y::AbstractVector{T}, -) where {T<:Real} - - V = H * Pp # 1 - S_1 = V * H' + Σ # 2 - S, S_pb = cholesky_pullback(Symmetric(S_1)) # 2.1 - U = S.U # 3 - B = U' \ V # 4 - η = y - H * mp - h # 5 - α = U' \ η # 6 - - mf = mp + B'α # 7 - BtB, BtB_pb = AtA_pullback(B) # 8 - Pf = Pp - BtB # 9 - - logdet_S, logdet_S_pb = logdet_pullback(S) # 10 - lml = -(length(y) * T(log(2π)) + logdet_S + α'α) / 2 # 11 - - return (mf, Pf, lml, α), function(Δ) - Δmf, ΔPf, Δlml, Δα = Δ - - Δlml = Δlml === nothing ? zero(lml) : Δlml - Δα = Δα === nothing ? zero(α) : Δα - - # 11 - Δα = Δα .- Δlml * α - Δlogdet_S = -Δlml / 2 - - # 10 - ΔS = first(logdet_S_pb(Δlogdet_S)) - - # 9 - ΔPp = ΔPf - ΔBtB = -ΔPf - - # 8 - ΔB = first(BtB_pb(ΔBtB)) - - # 7 - Δmp = Δmf - Δα += B * Δmf - ΔB += α * Δmf' - - # 6 - Δη = U \ Δα - ΔU = -α * Δη' - - # 5 - Δy = Δη - ΔH = -Δη * mp' - Δmp += -H'Δη - Δh = -Δη - - # 4 - ΔV = U \ ΔB - ΔU += -B * ΔV' - - # 3 - ΔS = (uplo=nothing, info=nothing, factors=get_ΔS(ΔS.factors, UpperTriangular(ΔU))) - - # 2.1 - ΔS_1 = first(S_pb(ΔS)) - - # 2 - ΔV += ΔS_1 * H - ΔH += ΔS_1'V - ΔΣ = my_collect(ΔS_1) - - # 1 - ΔH += ΔV * Pp' - ΔPp += H'ΔV - - return Δmp, ΔPp, ΔH, Δh, ΔΣ, Δy - end -end - -get_ΔS(A, B) = A + B - -function get_ΔS( - A::Diagonal{<:Any, <:SVector{D}}, - B::UpperTriangular{<:Any, <:SMatrix{D, D}}, -) where {D} - return SMatrix{D, D}(A) + SMatrix{D, D}(B) -end - -@adjoint function step_correlate(model, x::Gaussian, α::AV{<:Real}) - return step_correlate_pullback(model, x, α) -end - -function step_correlate_pullback(model, x::Gaussian, α::AV{<:Real}) - - # Evaluate function, keeping track of derivatives. - gmm = model.gmm - (mp, Pp), predict_pb = predict_pullback(x.m, x.P, gmm.A, gmm.a, gmm.Q) - (mf, Pf, lml, y), update_decorrelate_pb = - update_correlate_pullback(mp, Pp, gmm.H, gmm.h, model.Σ, α) - - return (lml, y, Gaussian(mf, Pf)), function(Δ) - - # Unpack stuff. - Δlml, Δy, Δx = Δ - Δmf = Δx === nothing ? zero(mp) : Δx.m - ΔPf = Δx === nothing ? zero(Pp) : Δx.P - - # Backprop through stuff. - Δmp, ΔPp, ΔH, Δh, ΔΣ, Δα = update_decorrelate_pb((Δmf, ΔPf, Δlml, Δy)) - Δmf, ΔPf, ΔA, Δa, ΔQ = predict_pb((Δmp, ΔPp)) - - Δx = (m=Δmf, P=ΔPf) - Δmodel = ( - gmm = (A=ΔA, a=Δa, Q=ΔQ, H=ΔH, h=Δh), - Σ=ΔΣ, - ) - return Δmodel, Δx, Δα - end -end - -@adjoint function update_correlate(mp, Pp, H, h, Σ, α) - return update_correlate_pullback(mp, Pp, H, h, Σ, α) -end - -function update_correlate_pullback( - mp::AbstractVector{T}, - Pp::AbstractMatrix{T}, - H::AbstractMatrix{T}, - h::AbstractVector{T}, - Σ::AbstractMatrix{T}, - α::AbstractVector{T}, -) where {T<:Real} - - V = H * Pp # 1 - S_1 = V * H' + Σ # 2 - S, S_pb = cholesky_pullback(Symmetric(S_1)) # 2.1 - U = S.U # 3 - B = U' \ V # 4 - y = U'α + H * mp + h # 5 - - mf = mp + B'α # 6 - BtB, BtB_pb = AtA_pullback(B) # 7 - Pf = Pp - BtB # 8 - - logdet_S, logdet_S_pb = logdet_pullback(S) # 9 - lml = -(length(y) * T(log(2π)) + logdet_S + α'α) / 2 # 10 - - return (mf, Pf, lml, y), function(Δ) - Δmf, ΔPf, Δlml, Δy = Δ - - Δlml = Δlml === nothing ? zero(lml) : Δlml - Δy = Δy === nothing ? zero(y) : Δy - - # 10 - Δα = (-Δlml) * α - Δlogdet_S = -Δlml / 2 - - # 9 - ΔS = first(logdet_S_pb(Δlogdet_S)) - - # 8 - ΔPp = ΔPf - ΔBtB = -ΔPf - - # 7 - ΔB = first(BtB_pb(ΔBtB)) - - # 6 - Δmp = Δmf - Δα += B * Δmf - ΔB += α * Δmf' - - # 5 - Δα += U * Δy - ΔU = α * Δy' - ΔH = Δy * mp' - Δmp += H'Δy - Δh = Δy - - # 4 - ΔV = U \ ΔB - ΔU += -B * ΔV' - - # 3 - ΔS = Zygote.accum(ΔS, (uplo=nothing, info=nothing, factors=UpperTriangular(ΔU),)) - - # 2.1 - ΔS_1 = my_collect(first(S_pb(ΔS))) - - # 2 - ΔV += ΔS_1 * H - ΔH += ΔS_1'V - ΔΣ = ΔS_1 - - # 1 - ΔH += ΔV * Pp' - ΔPp += H'ΔV - - return Δmp, ΔPp, ΔH, Δh, ΔΣ, Δα - end -end - -my_collect(A::AbstractMatrix) = collect(A) -function my_collect(A::UpperTriangular{T, <:SMatrix{D, D, T}}) where {T<:Real, D} - return SMatrix{D, D}(A) -end diff --git a/src/models/mutable_inference.jl b/src/models/mutable_inference.jl deleted file mode 100644 index b7c2a77f..00000000 --- a/src/models/mutable_inference.jl +++ /dev/null @@ -1,148 +0,0 @@ -""" - decorrelate(::Mutable, model::LGSSM, ys::AbstractVector, f=copy_first) - -Version of decorrelate used by `LGSSM`s whose `StorageType` is `Mutable`, as defined by -`mutability`. -""" -function decorrelate(::Mutable, model::LGSSM, ys::AbstractVector, f=copy_first) - @assert length(model) == length(ys) - - # Pre-allocate for intermediates. - α = Vector{eltype(first(ys))}(undef, length(first(ys))) - x0 = model.gmm.x0 - x0_sym = Gaussian(x0.m, x0.P) - mf = copy(x0.m) - Pf = copy(x0.P) - x = Gaussian(mf, Pf) - - # Process first latent. - (lml, α, x) = step_decorrelate!(α, x, x0_sym, model[1], ys[1]) - v = f(α, x) - vs = Vector{typeof(v)}(undef, length(model)) - vs[1] = v - - # Process remaining latents. - @inbounds for t in 2:length(model) - lml_, α, x = step_decorrelate!(α, x, x, model[t], ys[t]) - lml += lml_ - vs[t] = f(α, x) - end - return lml, vs -end - -""" - function step_decorrelate!( - α::Vector{T}, - x_filter_next::Gaussian{Vector{T}, Matrix{T}}, - x_filter::Gaussian{Vector{T}, Matrix{T}}, - model::NamedTuple{(:gmm, :Σ)}, - y::Vector{<:Real}, - ) where {T<:Real} - -Mutating version of `step_decorrelate`. Mutates both `α` and `x_filter_next`. -""" -function step_decorrelate!( - α::Vector{T}, - x_filter_next::Gaussian{Vector{T}, Matrix{T}}, - x_filter::Gaussian{Vector{T}, Matrix{T}}, - model::NamedTuple{(:gmm, :Σ)}, - y::Vector{<:Real}, -) where {T<:Real} - - # Preallocate for predictive distribution. - mp = Vector{T}(undef, dim(x_filter)) - Pp = Matrix{T}(undef, dim(x_filter), dim(x_filter)) - x_predict = Gaussian(mp, Pp) - - # Compute next filtering distribution. - gmm = model.gmm - x_predict = predict!(x_predict, x_filter, gmm.A, gmm.a, gmm.Q) - x_filter_next, lml, α = update_decorrelate!( - α, x_filter_next, x_predict, gmm.H, gmm.h, model.Σ, y, - ) - return lml, α, x_filter_next -end - -""" - predict!( - x_predict::Gaussian{Vector{T}, Matrix{T}}, - x_filter::Gaussian{Vector{T}, Matrix{T}}, - A::Matrix{T}, - a::Vector{T}, - Q::Matrix{T}, - ) where {T<:Real} - -Mutatiing version of `predict`. Modifies `x_predict`. -""" -function predict!( - x_predict::Gaussian{Vector{T}, Matrix{T}}, - x_filter::Gaussian{Vector{T}, Matrix{T}}, - A::Matrix{T}, - a::AbstractVector{T}, - Q::Matrix{T}, -) where {T<:Real} - - # Compute predictive mean. - x_predict.m .= a - mul!(x_predict.m, A, x_filter.m, one(T), one(T)) - - # Compute predictive covariance. - APf = Matrix{T}(undef, size(A)) - mul!(APf, A, x_filter.P) - - x_predict.P .= Q - mul!(x_predict.P, APf, A', one(T), one(T)) - - return x_predict -end - -""" - update_decorrelate!( - α::Vector{T}, - x_filter::Gaussian{Vector{T}, Matrix{T}}, - x_predict::Gaussian{Vector{T}, Matrix{T}}, - H::Matrix{T}, - h::Vector{T}, - Σ::Matrix{T}, - y::AbstractVector{T}, - ) where {T<:Real} - -Mutating version of `update_decorrelate`. Modifies `α` and `x_filter`. -""" -function update_decorrelate!( - α::Vector{T}, - x_filter::Gaussian{Vector{T}, Matrix{T}}, - x_predict::Gaussian{Vector{T}, Matrix{T}}, - H::Matrix{T}, - h::AbstractVector{T}, - Σ::AbstractMatrix{T}, - y::AbstractVector{T}, -) where {T<:Real} - - V = H * x_predict.P - S_1 = V * H' + Σ - S = cholesky(Symmetric(S_1)) - U = S.U - B = U' \ V - - # α = U' \ (y - H * x_predict.m - h) - α = ldiv!(α, U', y - H * x_predict.m - h) - - # x_filter.m .= x_predict.m + B'α - x_filter.m .= x_predict.m - mul!(x_filter.m, B', α, one(T), one(T)) - - # x_filter.P .= x_predict.P - B'B - x_filter.P .= x_predict.P - _compute_Pf!(x_filter.P, B) - - # Compute log marginal probablilty of observation `y`. - lml = -(length(y) * T(log(2π)) + logdet(S) + α'α) / 2 - - return x_filter, lml, α -end - -# Old method. Considering removing. -function _compute_Pf!(Pp::Matrix{T}, B::Matrix{T}) where {T<:Real} - LinearAlgebra.copytri!(BLAS.syrk!('U', 'T', -1.0, B, 1.0, Pp), 'U') -end diff --git a/src/models/mutable_inference_pullbacks.jl b/src/models/mutable_inference_pullbacks.jl deleted file mode 100644 index e69de29b..00000000 diff --git a/src/util/zygote_rules.jl b/src/util/zygote_rules.jl index 3d83307a..8dbd73a1 100644 --- a/src/util/zygote_rules.jl +++ b/src/util/zygote_rules.jl @@ -1,4 +1,24 @@ -using Zygote: @adjoint, accum +using Zygote: @adjoint, accum, AContext + + +# This context doesn't allow any globals. +struct NoContext <: Zygote.AContext end + +# Stupid implementation to obtain type-stability. +Zygote.cache(cx::NoContext) = (cache_fields=nothing) + +# Stupid implementation. +Base.haskey(cx::NoContext, x) = false + +Zygote.accum_param(::NoContext, x, Δ) = Δ + +function context_free_gradient(f, args...) + _, pb = Zygote._pullback(NoContext(), f, args...) + return pb(1.0) +end + + +Zygote.accum(as::Tuple...) = map(accum, as...) # Not a rule, but a helpful utility. show_grad_type(x, S) = Zygote.hook(x̄ -> ((@show S, typeof(x̄)); x̄), x) @@ -144,3 +164,104 @@ end end return x[n], getindex_FillArray end + + +# +# AD-free pullbacks for a few things. These are primitives that will be used to write the +# gradients. +# + +function cholesky_pullback(Σ::Symmetric{<:Real, <:StridedMatrix}) + C = cholesky(Σ) + return C, function(Δ::NamedTuple) + U, Ū = C.U, Δ.factors + Σ̄ = Ū * U' + Σ̄ = LinearAlgebra.copytri!(Σ̄, 'U') + Σ̄ = ldiv!(U, Σ̄) + BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) + + @inbounds for n in diagind(Σ̄) + Σ̄[n] /= 2 + end + return (UpperTriangular(Σ̄),) + end +end + +function cholesky_pullback(S::Symmetric{<:Real, <:StaticMatrix{N, N}}) where {N} + C = cholesky(S) + return C, function(Δ::NamedTuple) + U, ΔU = C.U, Δ.factors + ΔS = U \ (U \ SMatrix{N, N}(Symmetric(ΔU * U')))' + ΔS = ΔS - Diagonal(ΔS ./ 2) + return (UpperTriangular(ΔS),) + end +end + +@adjoint function cholesky(S::Symmetric{<:Real, <:StaticMatrix{N, N}}) where {N} + return cholesky_pullback(S) +end + +function logdet_pullback(C::Cholesky) + return logdet(C), function(Δ) + return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),) + end +end + +AtA_pullback(A::AbstractMatrix{<:Real}) = A'A, Δ->(A * (Δ + Δ'),) + + +function Zygote.accum(a::UpperTriangular, b::UpperTriangular) + return UpperTriangular(Zygote.accum(a.data, b.data)) +end + +function Zygote.accum(D::Diagonal{<:Real}, U::UpperTriangular{<:Real, <:SMatrix}) + return UpperTriangular(D + U.data) +end + +function Zygote.accum(a::Diagonal, b::UpperTriangular) + return UpperTriangular(a + b.data) +end + +Zygote.accum(a::UpperTriangular, b::Diagonal) = Zygote.accum(b, a) + +Zygote._symmetric_back(Δ::UpperTriangular{<:Any, <:SArray}, uplo) = Δ + + +# Temporary hacks. + +using Zygote: literal_getproperty, literal_indexed_iterate, literal_getindex + +function Zygote._pullback(::NoContext, ::typeof(literal_getproperty), C::Cholesky, ::Val{:U}) + function literal_getproperty_pullback(Δ) + return (nothing, (uplo=nothing, info=nothing, factors=UpperTriangular(Δ))) + end + literal_getproperty_pullback(Δ::Nothing) = nothing + return literal_getproperty(C, Val(:U)), literal_getproperty_pullback +end + + +function Zygote._pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, ::Val{i}) where i + y, b = Zygote._pullback(cx, literal_getindex, xs, Val(i)) + back(::Nothing) = nothing + back(ȳ) = b(ȳ[1]) + (y, i+1), back +end + +function Zygote._pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, ::Val{i}, st) where i + y, b = Zygote._pullback(cx, literal_getindex, xs, Val(i)) + back(::Nothing) = nothing + back(ȳ) = (b(ȳ[1])..., nothing) + (y, i+1), back +end + +Zygote._pullback(cx::AContext, ::typeof(getproperty), x, f::Symbol) = + Zygote._pullback(cx, Zygote.literal_getproperty, x, Val(f)) + +Zygote._pullback(cx::AContext, ::typeof(getfield), x, f::Symbol) = + Zygote._pullback(cx, Zygote.literal_getproperty, x, Val(f)) + +Zygote._pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f = + Zygote._pullback(cx, Zygote.literal_getproperty, x, Val(f)) + +Zygote._pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f = + Zygote._pullback(cx, Zygote.literal_getindex, x, Val(f)) diff --git a/test/models/immutable_inference.jl b/test/models/immutable_inference.jl index f5f7c7ed..7c895e55 100644 --- a/test/models/immutable_inference.jl +++ b/test/models/immutable_inference.jl @@ -1,4 +1,15 @@ -using TemporalGPs: predict +using TemporalGPs: + NoContext, + predict, + update_decorrelate, + update_correlate, + step_decorrelate, + step_correlate, + decorrelate, + correlate, + Immutable, + copy_first +using Zygote: _pullback naive_predict(mf, Pf, A, a, Q) = A * mf + a, (A * Pf) * A' + Q @@ -7,63 +18,172 @@ println("immutable inference:") @testset "immutable_inference" begin rng = MersenneTwister(123456) Dlats = [1, 3] + Dobss = [1, 2] Ts = [ # (T=Float32, atol=1e-5, rtol=1e-5), (T=Float64, atol=1e-9, rtol=1e-9), ] - @testset "$Dlat, $(T.T)" for Dlat in Dlats, T in Ts - - @testset "predict" begin - - # Construct a Gauss-Markov model and pull out the relevant parameters. - gmm = random_tv_gmm(rng, Dlat, 1, 1, SArrayStorage(T.T)) - A = first(gmm.A) - a = first(gmm.a) - Q = first(gmm.Q) - mf = gmm.x0.m - Pf = gmm.x0.P - - # Check agreement with the naive implementation. - mp, Pp = predict(mf, Pf, A, a, Q) - mp_naive, Pp_naive = naive_predict(mf, Pf, A, a, Q) - @test mp ≈ mp_naive - @test Pp ≈ Pp_naive - - # Verify approximate numerical correctness of pullback. - U_Pf = cholesky(Symmetric(Pf)).U - U_Q = cholesky(Symmetric(Q)).U - Δmp = SVector{Dlat}(randn(rng, T.T, Dlat)) - ΔPp = SMatrix{Dlat, Dlat}(randn(rng, T.T, Dlat, Dlat)) - adjoint_test( - (mf, U_Pf, A, a, U_Q) -> begin - U_Q = UpperTriangular(U_Q) - U_Pf = UpperTriangular(U_Pf) - return predict(mf, U_Pf'U_Pf, A, a, U_Q'U_Q) - end, - (Δmp, ΔPp), - mf, U_Pf, A, a, U_Q; - rtol=T.rtol, atol=T.atol - ) - - # Evaluate and pullback. - (mp, Pp), back = pullback(predict, mf, Pf, A, a, Q) - (Δmf, ΔPf, ΔA, Δa, ΔQ) = back((Δmp, ΔPp)) - - # Verify correct output types have been produced. - @test mp isa SVector{Dlat, T.T} - @test Pp isa SMatrix{Dlat, Dlat, T.T} - - # Verify the adjoints w.r.t. the inputs are of the correct type. - @test Δmf isa SVector{Dlat, T.T} - @test ΔPf isa SMatrix{Dlat, Dlat, T.T} - @test ΔA isa SMatrix{Dlat, Dlat, T.T} - @test Δa isa SVector{Dlat, T.T} - @test ΔQ isa SMatrix{Dlat, Dlat, T.T} - - # Check that pullback doesn't allocate because StaticArrays. - @test allocs(@benchmark pullback(predict, $mf, $Pf, $A, $a, $Q)) == 0 - @test allocs(@benchmark $back(($Δmp, $ΔPp))) == 0 + @testset "$Dlat, $Dobs, $(T.T)" for Dlat in Dlats, Dobs in Dobss, T in Ts + + # Construct a Gauss-Markov model and pull out the relevant parameters. + gmm = random_tv_gmm(rng, Dlat, Dobs, 1, SArrayStorage(T.T)) + A = first(gmm.A) + a = first(gmm.a) + Q = first(gmm.Q) + mf = gmm.x0.m + Pf = gmm.x0.P + + # Check agreement with the naive implementation. + mp, Pp = predict(mf, Pf, A, a, Q) + mp_naive, Pp_naive = naive_predict(mf, Pf, A, a, Q) + @test mp ≈ mp_naive + @test Pp ≈ Pp_naive + + # Verify approximate numerical correctness of pullback. + U_Pf = cholesky(Symmetric(Pf)).U + U_Q = cholesky(Symmetric(Q)).U + Δmp = SVector{Dlat}(randn(rng, T.T, Dlat)) + ΔPp = SMatrix{Dlat, Dlat}(randn(rng, T.T, Dlat, Dlat)) + adjoint_test( + (mf, U_Pf, A, a, U_Q) -> begin + U_Q = UpperTriangular(U_Q) + U_Pf = UpperTriangular(U_Pf) + return predict(mf, U_Pf'U_Pf, A, a, U_Q'U_Q) + end, + (Δmp, ΔPp), + mf, U_Pf, A, a, U_Q; + rtol=T.rtol, atol=T.atol + ) + + # Evaluate and pullback. + (mp, Pp), back = pullback(predict, mf, Pf, A, a, Q) + (Δmf, ΔPf, ΔA, Δa, ΔQ) = back((Δmp, ΔPp)) + + # Verify correct output types have been produced. + @test mp isa SVector{Dlat, T.T} + @test Pp isa SMatrix{Dlat, Dlat, T.T} + + # Verify the adjoints w.r.t. the inputs are of the correct type. + @test Δmf isa SVector{Dlat, T.T} + @test ΔPf isa SMatrix{Dlat, Dlat, T.T} + @test ΔA isa SMatrix{Dlat, Dlat, T.T} + @test Δa isa SVector{Dlat, T.T} + @test ΔQ isa SMatrix{Dlat, Dlat, T.T} + + @testset "predict AD infers" begin + (mp, Pp), pb = _pullback(NoContext(), predict, mf, Pf, A, a, Q) + @inferred _pullback(NoContext(), predict, mf, Pf, A, a, Q) + @inferred pb((Δmp, ΔPp)) + end + + @testset "predict doesn't allocate" begin + _, pb = _pullback(NoContext(), predict, mf, Pf, A, a, Q) + @test allocs(@benchmark( + _pullback(NoContext(), predict, $mf, $Pf, $A, $a, $Q), + samples=1, + evals=1, + )) == 0 + @test allocs(@benchmark $pb(($Δmp, $ΔPp)) samples=1 evals=1) == 0 + end + + H = first(gmm.H) + h = first(gmm.h) + Σ = random_nice_psd_matrix(rng, Dobs, SArrayStorage(T.T)) + y = random_vector(rng, Dobs, SArrayStorage(T.T)) + + Δmf = random_vector(rng, Dlat, SArrayStorage(T.T)) + ΔPf = random_matrix(rng, Dlat, Dlat, SArrayStorage(T.T)) + Δlml = randn(rng) + Δα = random_vector(rng, Dobs, SArrayStorage(T.T)) + + x = random_gaussian(rng, Dlat, SArrayStorage(T.T)) + lgssm = random_tv_lgssm(rng, Dlat, Dobs, 1_000, SArrayStorage(T.T)) + ys = rand(rng, lgssm) + αs = rand(rng, lgssm) + + @testset "$name performance" for (name, f, update_f, step_f) in [ + (:decorrelate, decorrelate, update_decorrelate, step_decorrelate), + (:correlate, correlate, update_correlate, step_correlate), + ] + + @testset "update_$name AD infers" begin + _, pb = _pullback(NoContext(), update_f, mp, Pp, H, h, Σ, y) + @inferred _pullback(NoContext(), update_f, mp, Pp, H, h, Σ, y) + @inferred pb((Δmf, ΔPf, Δlml, Δα)) + end + + @testset "update_$name doesn't allocate" begin + _, pb = _pullback(NoContext(), update_f, mp, Pp, H, h, Σ, y) + @test allocs(@benchmark( + _pullback(NoContext(), $update_f, $mp, $Pp, $H, $h, $Σ, $y), + samples=1, + evals=1, + )) == 0 + @test allocs(@benchmark $pb(($Δmf, $ΔPf, $Δlml, $Δα)) samples=1 evals=1) == 0 + end + + @testset "step_$name AD infers" begin + model = (gmm=lgssm.gmm[1], Σ=lgssm.Σ[1]) + Δ = (Δlml, Δα, (m=Δmf, P=ΔPf)) + out, pb = _pullback(NoContext(), step_f, model, x, y) + @inferred _pullback(NoContext(), step_f, model, x, y) + @inferred pb(Δ) + end + + @testset "step_$name doesn't allocate" begin + model = (gmm=lgssm.gmm[1], Σ=lgssm.Σ[1]) + Δ = (Δlml, Δα, (m=Δmf, P=ΔPf)) + _, pb = _pullback(NoContext(), step_f, model, x, y) + @test allocs(@benchmark( + _pullback(NoContext(), $step_f, $model, $x, $y), + samples=1, + evals=1, + )) == 0 + @test allocs(@benchmark $pb($Δ) samples=1 evals=1) == 0 + end + + @testset "$name infers" begin + _, pb = _pullback(NoContext(), f, Immutable(), lgssm, ys) + @inferred f(Immutable(), lgssm, ys, copy_first) + @inferred _pullback(NoContext(), f, Immutable(), lgssm, ys, copy_first) + @inferred pb((randn(), αs)) + end + + # These tests should pick up on any substantial changes in allocations. It's + # possible that they'll need to be modified in future / for different versions + # of Julia. + @testset "$name allocations are independent of length" begin + _, pb = _pullback(NoContext(), f, Immutable(), lgssm, ys, copy_first) + + @test allocs( + @benchmark( + $f(Immutable(), $lgssm, $ys, copy_first); + samples=1, evals=1, + ), + ) < 5 + @test allocs( + @benchmark( + _pullback(NoContext(), $f, Immutable(), $lgssm, $ys, copy_first); + samples=1, evals=1, + ), + ) < 10 + @test allocs(@benchmark($pb((randn(), $αs)); samples=1, evals=1)) < 20 + end + + # @testset "benchmarking $name" begin + # @show Dlat, Dobs, name, T.T + # _, pb = _pullback(NoContext(), f, Immutable(), lgssm, ys, copy_first) + + # display(@benchmark($f(Immutable(), $lgssm, $ys, copy_first))) + # println() + # display(@benchmark( + # _pullback(NoContext(), $f, Immutable(), $lgssm, $ys, copy_first), + # )) + # println() + # display(@benchmark($pb((randn(), $αs)))) + # println() + # end end end end diff --git a/test/models/immutable_inference_pullbacks.jl b/test/models/immutable_inference_pullbacks.jl index 8e2cf0bb..b44b30cc 100644 --- a/test/models/immutable_inference_pullbacks.jl +++ b/test/models/immutable_inference_pullbacks.jl @@ -1,282 +1,12 @@ -# This file contains a collection of optimisations for use with reveerse-mode AD. -# Consequently, it is not necessary to understand the contents of this file to understand -# the package as a whole. - -using TemporalGPs: - is_of_storage_type, - Gaussian, - cholesky_pullback, - logdet_pullback, - update_correlate, - update_correlate_pullback, - step_correlate, - step_correlate_pullback, - correlate, - correlate_pullback, - update_decorrelate, - update_decorrelate_pullback, - step_decorrelate, - step_decorrelate_pullback, - decorrelate, - decorrelate_pullback - -naive_predict(mf, Pf, A, a, Q) = A * mf + a, (A * Pf) * A' + Q - -function verify_pullback(f_pullback, input, Δoutput, storage) - output, _pb = f_pullback(input...) - Δinput = _pb(Δoutput) - - @test is_of_storage_type(input, storage.val) - @test is_of_storage_type(output, storage.val) - @test is_of_storage_type(Δinput, storage.val) - @test is_of_storage_type(Δoutput, storage.val) - - if storage.val isa SArrayStorage - @test allocs(@benchmark $f_pullback($input...)) == 0 - @test allocs(@benchmark $_pb($Δoutput)) == 0 - end -end +using TemporalGPs: is_of_storage_type, correlate, decorrelate @testset "immutable_inference_pullbacks" begin - @testset "$N, $T" for N in [1, 2, 3], T in [Float32, Float64] - - rng = MersenneTwister(123456) - - # Do dense stuff. - S_ = randn(rng, T, N, N) - S = S_ * S_' + I - C = cholesky(S) - Ss = SMatrix{N, N, T}(S) - Cs = cholesky(Ss) - - @testset "cholesky" begin - C_fwd, pb = cholesky_pullback(Symmetric(S)) - Cs_fwd, pbs = cholesky_pullback(Symmetric(Ss)) - - @test eltype(C_fwd) == T - @test eltype(Cs_fwd) == T - - ΔC = randn(rng, T, N, N) - ΔCs = SMatrix{N, N, T}(ΔC) - - @test C.U ≈ Cs.U - @test Cs.U ≈ Cs_fwd.U - - ΔS, = pb((factors=ΔC, )) - ΔSs, = pbs((factors=ΔCs, )) - - @test ΔS ≈ ΔSs - @test eltype(ΔS) == T - @test eltype(ΔSs) == T - - @test allocs(@benchmark cholesky(Symmetric($Ss))) == 0 - @test allocs(@benchmark cholesky_pullback(Symmetric($Ss))) == 0 - @test allocs(@benchmark $pbs((factors=$ΔCs,))) == 0 - end - @testset "logdet" begin - @test logdet(Cs) ≈ logdet(C) - C_fwd, pb = logdet_pullback(C) - Cs_fwd, pbs = logdet_pullback(Cs) - - @test eltype(C_fwd) == T - @test eltype(Cs_fwd) == T - - @test logdet(Cs) ≈ Cs_fwd - - Δ = randn(rng, T) - ΔC = first(pb(Δ)).factors - ΔCs = first(pbs(Δ)).factors - - @test ΔC ≈ ΔCs - @test eltype(ΔC) == T - @test eltype(ΔCs) == T - - @test allocs(@benchmark logdet($Cs)) == 0 - @test allocs(@benchmark logdet_pullback($Cs)) == 0 - @test allocs(@benchmark $pbs($Δ)) == 0 - end - end - - @testset "step pullbacks" begin - Dlats = [3] - Dobss = [2] - storages = [ - (name="heap - Float64", val=ArrayStorage(Float64)), - (name="stack - Float64", val=SArrayStorage(Float64)), - (name="heap - Float32", val=ArrayStorage(Float32)), - (name="stack - Float32", val=SArrayStorage(Float32)), - ] - - @testset "storage=$(storage.name), Dlat=$Dlat, Dobs=$Dobs" for - Dlat in Dlats, - Dobs in Dobss, - storage in storages - - rng = MersenneTwister(123456) - - # Specify LGSSM dynamics. - A = random_matrix(rng, Dlat, Dlat, storage.val) - a = random_vector(rng, Dlat, storage.val) - Q = random_nice_psd_matrix(rng, Dlat, storage.val) - U_Q = cholesky(Q).U - H = random_matrix(rng, Dobs, Dlat, storage.val) - h = random_vector(rng, Dobs, storage.val) - S = random_nice_psd_matrix(rng, Dobs, storage.val) - U_S = cholesky(S).U - - # Specify LGSSM initial state distribution. - m = random_vector(rng, Dlat, storage.val) - P = random_nice_psd_matrix(rng, Dlat, storage.val) - U_P = cholesky(P).U - - # Specify input-output pairs. - α = random_vector(rng, Dobs, storage.val) - y = random_vector(rng, Dobs, storage.val) - - @testset "update_correlate" begin - - # Specify adjoints for outputs. - Δmf = random_vector(rng, Dlat, storage.val) - ΔPf = random_matrix(rng, Dlat, Dlat, storage.val) - Δlml = randn(rng, eltype(storage.val)) - Δy = random_vector(rng, Dobs, storage.val) - Δoutput = (Δmf, ΔPf, Δlml, Δy) - - # Check reverse-mode agrees with finite differences. - if eltype(storage.val) == Float64 - adjoint_test( - (mp, U_Pp, H, h, U_S, α) -> begin - U_Pp = UpperTriangular(U_Pp) - Pp = U_Pp'U_Pp - - U_S = UpperTriangular(U_S) - S = U_S'U_S - - return update_correlate(mp, Pp, H, h, S, α) - end, - Δoutput, - m, U_P, H, h, U_S, α; - atol=1e-6, rtol=1e-6, - ) - end - - # Check that appropriate typoes are produced, and allocations are correct. - input = (m, P, H, h, S, α) - verify_pullback(update_correlate_pullback, input, Δoutput, storage) - end - @testset "update_decorrelate" begin - - # Specify adjoints for outputs. - Δmf = random_vector(rng, Dlat, storage.val) - ΔPf = random_matrix(rng, Dlat, Dlat, storage.val) - Δlml = randn(rng, eltype(storage.val)) - Δα = random_vector(rng, Dobs, storage.val) - Δoutput = (Δmf, ΔPf, Δlml, Δα) - - - # Check reverse-mode agrees with finite differences. - if eltype(storage.val) == Float64 - adjoint_test( - (mp, U_Pp, H, h, U_S, y) -> begin - U_Pp = UpperTriangular(U_Pp) - Pp = U_Pp'U_Pp - - U_S = UpperTriangular(U_S) - _S = U_S'U_S - - return update_decorrelate(mp, Pp, H, h, _S, y) - end, - Δoutput, - m, U_P, H, h, U_S, y; - atol=1e-6, rtol=1e-6, - ) - end - - # Check that appropriate typoes are produced, and allocations are correct. - input = (m, P, H, h, S, y) - verify_pullback(update_decorrelate_pullback, input, Δoutput, storage) - end - @testset "step_correlate" begin - - # Specify adjoints for outputs. - Δlml = randn(rng, eltype(storage.val)) - Δy = random_vector(rng, Dobs, storage.val) - Δx = ( - m = random_vector(rng, Dlat, storage.val), - P = random_matrix(rng, Dlat, Dlat, storage.val), - ) - Δoutput = (Δlml, Δy, Δx) - - # Check reverse-mode agress with finite differences. - if eltype(storage.val) == Float64 - adjoint_test( - (A, a, U_Q, H, h, U_S, mf, U_Pf, α) -> begin - U_Q = UpperTriangular(Q) - U_S = UpperTriangular(U_S) - U_Pf = UpperTriangular(U_Pf) - - model = ( - gmm = (A=A, a=a, Q=U_Q'U_Q, H=H, h=h), - Σ=U_S'U_S, - ) - x_ = Gaussian(mf, U_Pf'U_Pf) - return step_correlate(model, x_, α) - end, - Δoutput, - A, a, U_Q, H, h, U_S, m, U_P, α; - rtol=1e-6, atol=1e-6, - ) - end - - # Check that appropriate typoes are produced, and allocations are correct. - model = (gmm=(A=A, a=a, Q=Q, H=H, h=h), Σ=S) - x = Gaussian(m, P) - input = (model, x, α) - verify_pullback(step_correlate_pullback, input, Δoutput, storage) - end - @testset "step_decorrelate" begin - - # Specify adjoints for outputs. - Δlml = randn(rng, eltype(storage.val)) - Δα = random_vector(rng, Dobs, storage.val) - Δx = ( - m = random_vector(rng, Dlat, storage.val), - P = random_matrix(rng, Dlat, Dlat, storage.val), - ) - Δoutput = (Δlml, Δα, Δx) - - # Check reverse-mode agress with finite differences. - if eltype(storage.val) == Float64 - adjoint_test( - (A, a, U_Q, H, h, U_S, mf, U_Pf, y) -> begin - U_Q = UpperTriangular(Q) - Q = U_Q'U_Q - - U_S = UpperTriangular(U_S) - S = U_S'U_S - - U_Pf = UpperTriangular(U_Pf) - Pf = U_Pf'U_Pf - - model = (gmm=(A=A, a=a, Q=Q, H=H, h=h), Σ=S) - x = Gaussian(mf, Pf) - return step_decorrelate(model, x, y) - end, - Δoutput, - A, a, U_Q, H, h, U_S, m, U_P, y; - atol=1e-6, rtol=1e-6, - ) - end - - # Check that appropriate typoes are produced, and allocations are correct. - model = (gmm=(A=A, a=a, Q=Q, H=H, h=h), Σ=S) - x = Gaussian(m, P) - input = (model, x, y) - verify_pullback(step_decorrelate_pullback, input, Δoutput, storage) - end - end - end + # AD correctness testing. + fs = [ + (name="decorrelate", f=decorrelate), + (name="correlate", f=correlate), + ] Dlats = [3] Dobss = [2] storages = [ @@ -290,59 +20,36 @@ end (name = "time-invariant", build_model = random_ti_lgssm), ] - @testset "correlate: Dlat=$Dlat, Dobs=$Dobs, storage=$(storage.name), tv=$(tv.name)" for + @testset "$(f.name): Dlat=$Dlat, Dobs=$Dobs, storage=$(storage.name), tv=$(tv.name)" for + f in fs, Dlat in Dlats, Dobs in Dobss, storage in storages, tv in tvs - N_correctness = 10 - N_performance = 1_000 - - @testset "correctness" begin - rng = MersenneTwister(123456) - - model = tv.build_model(rng, Dlat, Dobs, N_correctness, storage.val) - - # We don't care about the statistical properties of the thing that correlate - # is applied to, just that it's the correct size / type, for which rand - # suffices. - α = rand(rng, model) - - input = (model, α) - Δoutput = (randn(rng, eltype(storage.val)), rand(rng, model)) - - output, _pb = Zygote.pullback(correlate, input...) - Δinput = _pb(Δoutput) - - @test is_of_storage_type(input, storage.val) - @test is_of_storage_type(output, storage.val) - @test is_of_storage_type(Δinput, storage.val) - @test is_of_storage_type(Δoutput, storage.val) + rng = MersenneTwister(123456) - # Only verify accuracy with Float64s. - if eltype(storage.val) == Float64 && storage.val isa SArrayStorage - adjoint_test(correlate, Δoutput, input...) - end - end + model = tv.build_model(rng, Dlat, Dobs, 10, storage.val) - # Only verify performance if StaticArrays are being used. - if storage.val isa SArrayStorage - @testset "performance" begin + # We don't care about the statistical properties of the thing that correlate + # is applied to, just that it's the correct size / type, for which rand + # suffices. + α = rand(rng, model) - rng = MersenneTwister(123456) - model = tv.build_model(rng, Dlat, Dobs, N_performance, storage.val) + input = (model, α) + Δoutput = (randn(rng, eltype(storage.val)), rand(rng, model)) - α = rand(rng, model) + output, _pb = Zygote.pullback(f.f, input...) + Δinput = _pb(Δoutput) - input = (model, α) - Δoutput = (randn(rng, eltype(storage.val)), rand(rng, model)) + @test is_of_storage_type(input, storage.val) + @test is_of_storage_type(output, storage.val) + @test is_of_storage_type(Δinput, storage.val) + @test is_of_storage_type(Δoutput, storage.val) - primal, forwards, pb = adjoint_allocs(correlate, Δoutput, input...) - @test primal < 100 - @test forwards < 100 - @test pb < 3 * N_performance - end + # Only verify accuracy with Float64s. + if eltype(storage.val) == Float64 && storage.val isa SArrayStorage + adjoint_test(f.f, Δoutput, input...) end end end diff --git a/test/models/mutable_inference.jl b/test/models/mutable_inference.jl deleted file mode 100644 index 10018ebc..00000000 --- a/test/models/mutable_inference.jl +++ /dev/null @@ -1,381 +0,0 @@ -using TemporalGPs: is_of_storage_type - -@testset "mutable_inference" begin - rng = MersenneTwister(123456) - Dlats = [1, 3] - Dobss = [1, 2] - Ts = [ - # (T=Float32, atol=1e-5, rtol=1e-5), - (storage=ArrayStorage(Float64), atol=1e-9, rtol=1e-9), - ] - - @testset "Matrix - $Dlat, $Dobs, $(T.storage)" for Dlat in Dlats, Dobs in Dobss, T in Ts - - storage = T.storage - - # Generate parameters for a transition model. - A = random_matrix(rng, Dlat, Dlat, storage) - a = random_vector(rng, Dlat, storage) - Q = random_nice_psd_matrix(rng, Dlat, storage) - - mf = random_vector(rng, Dlat, storage) - Pf = random_nice_psd_matrix(rng, Dlat, storage) - xf = Gaussian(mf, Pf) - - mp = random_vector(rng, Dlat, storage) - Pp = random_nice_psd_matrix(rng, Dlat, storage) - xp = Gaussian(mp, Pp) - - # Generate parameters for emission model. - α = random_vector(rng, Dobs, storage) - H = random_matrix(rng, Dobs, Dlat, storage) - h = random_vector(rng, Dobs, storage) - Σ = random_nice_psd_matrix(rng, Dobs, storage) - y = random_vector(rng, Dobs, storage) - - model = (gmm=(A=A, a=a, Q=Q, H=H, h=h), Σ=Σ) - - @testset "predict!" begin - mp_naive, Pp_naive = TemporalGPs.predict(mf, Pf, A, a, Q) - xp = TemporalGPs.predict!(xp, xf, A, a, Q) - - @test xp.m ≈ mp_naive - @test xp.P ≈ Pp_naive - @test is_of_storage_type(xp, storage) - end - - @testset "update_decorrelate!" begin - mf′_naive, Pf′_naive, lml_naive, α_naive = TemporalGPs.update_decorrelate( - mp, Pp, H, h, Σ, y, - ) - - xf′, lml, α = TemporalGPs.update_decorrelate!(α, copy(xf), xp, H, h, Σ, y) - - @test xf′.m ≈ mf′_naive - @test xf′.P ≈ Pf′_naive - @test is_of_storage_type(xf′, storage) - @test α ≈ α_naive - @test lml ≈ lml_naive - end - - @testset "step_decorrelate!" begin - lml_naive, α_naive, x_filter_next_naive = TemporalGPs.step_decorrelate( - model, xp, y, - ) - - x_filter_next = random_gaussian(rng, Dlat, storage) - lml, α, x_filter_next = TemporalGPs.step_decorrelate!( - α, x_filter_next, xp, model, y, - ) - - @test lml_naive ≈ lml - @test α_naive ≈ α - @test x_filter_next_naive.m ≈ x_filter_next.m - @test x_filter_next_naive.P ≈ x_filter_next.P - end - - @testset "decorrelate - mutable" begin - - model_lgssm = random_ti_lgssm(rng, Dlat, Dobs, 5, storage) - ys = rand(rng, model_lgssm) - - lml_naive, vs_naive = TemporalGPs.decorrelate( - TemporalGPs.Immutable(), model_lgssm, ys, - ) - lml, vs = TemporalGPs.decorrelate(TemporalGPs.Mutable(), model_lgssm, ys) - - @test lml_naive ≈ lml - @test all(vs_naive .≈ vs) - end - - @testset "decorrelate - mutable - scalar" begin - - model_lgssm = random_ti_scalar_lgssm(rng, Dlat, 5, storage) - ys = rand(rng, model_lgssm) - - lml_naive, vs_naive = TemporalGPs.decorrelate( - TemporalGPs.Immutable(), model_lgssm, ys, - ) - lml, vs = TemporalGPs.decorrelate(TemporalGPs.Mutable(), model_lgssm, ys) - - @test lml_naive ≈ lml - @test all(vs_naive .≈ vs) - end - - # @testset "predict" begin - - # # Check agreement with the naive implementation. - # mp, Pp = predict(mf, Pf, A, a, Q) - # mp_naive, Pp_naive = naive_predict(mf, Pf, A, a, Q) - # @test mp ≈ mp_naive - # @test Pp ≈ Pp_naive - # @test mp isa Vector{T.T} - # @test Pp isa Matrix{T.T} - - # # Verify approximate numerical correctness of pullback. - # U_Pf = cholesky(Symmetric(Pf)).U - # U_Q = cholesky(Symmetric(Q)).U - # Δmp = randn(rng, T.T, Dlat) - # ΔPp = randn(rng, T.T, Dlat, Dlat) - # adjoint_test( - # (mf, U_Pf, A, a, U_Q) -> begin - # U_Q = UpperTriangular(U_Q) - # U_Pf = UpperTriangular(U_Pf) - # return predict(mf, Symmetric(U_Pf'U_Pf), A, a, U_Q'U_Q) - # end, - # (Δmp, ΔPp), - # mf, U_Pf, A, a, U_Q; - # rtol=T.rtol, atol=T.atol - # ) - - # # Evaluate and pullback. - # (mp, Pp), back = pullback(predict, mf, Pf, A, a, Q) - # (Δmf, ΔPf, ΔA, Δa, ΔQ) = back((Δmp, ΔPp)) - - # # Verify correct output types have been produced. - # @test mp isa Vector{T.T} - # @test Pp isa Matrix{T.T} - - # # Verify the adjoints w.r.t. the inputs are of the correct type. - # @test Δmf isa Vector{T.T} - # @test ΔPf isa Matrix{T.T} - # @test ΔA isa Matrix{T.T} - # @test Δa isa Vector{T.T} - # @test ΔQ isa Matrix{T.T} - # end - end - - # n_blockss = [1, 3] - # @testset "BlockDiagonal - $Dlat_block, $(T.T), $n_blocks" for - # Dlat_block in Dlats, - # T in Ts, - # n_blocks in n_blockss - - # rng = MersenneTwister(123456) - - # # Compute the total number of dimensions. - # Dlat = n_blocks * Dlat_block - - # # Generate block-diagonal transition dynamics. - # As = map(_ -> randn(rng, T.T, Dlat_block, Dlat_block), 1:n_blocks) - # A = BlockDiagonal(As) - - # a = randn(rng, T.T, Dlat) - - # Qs = map( - # _ -> random_nice_psd_matrix(rng, Dlat_block, ArrayStorage(T.T)), - # 1:n_blocks, - # ) - # Q = BlockDiagonal(Qs) - - # # Generate filtering (input) distribution. - # mf = randn(rng, T.T, Dlat) - # Pf = Symmetric(random_nice_psd_matrix(rng, Dlat, ArrayStorage(T.T))) - - # # Check that predicting twice gives exactly the same answer. - # let - # mf_c = copy(mf) - # Pf_c = copy(Pf) - # A_c = BlockDiagonal(map(copy, As)) - # a_c = copy(a) - # Q_c = BlockDiagonal(map(copy, Qs)) - - # m1, P1 = predict(mf_c, Pf_c, A_c, a_c, Q_c) - # m2, P2 = predict(mf_c, Pf_c, A_c, a_c, Q_c) - - # @test m1 == m2 - # @test P1 == P2 - - # @test mf_c == mf - # @test Pf_c == Pf - # @test A_c == A - # @test a_c == a - # @test Q_c == Q - # end - - # # Generate corresponding dense dynamics. - # A_dense = collect(A) - # Q_dense = collect(Q) - - # # Check agreement with dense implementation. - # mp, Pp = predict(mf, Pf, A, a, Q) - # mp_dense_dynamics, Pp_dense_dynamics = predict(mf, Pf, A_dense, a, Q_dense) - # @test mp ≈ mp_dense_dynamics - # @test Symmetric(Pp) ≈ Symmetric(Pp_dense_dynamics) - # @test mp isa Vector{T.T} - # @test Pp isa Matrix{T.T} - - # # Verify approximate numerical correctness of pullback. - # U_Pf = collect(cholesky(Symmetric(Pf)).U) - # U_Q = map(Q -> collect(cholesky(Symmetric(Q)).U), Qs) - # Δmp = randn(rng, T.T, Dlat) - # ΔPp = randn(rng, T.T, Dlat, Dlat) - - # adjoint_test( - # (mf, U_Pf, A, a, U_Q) -> begin - # Qs = map(U -> UpperTriangular(U)'UpperTriangular(U), U_Q) - # Q = BlockDiagonal(Qs) - # U_Pf = UpperTriangular(U_Pf) - # return predict(mf, Symmetric(U_Pf'U_Pf), A, a, Q) - # end, - # (Δmp, ΔPp), - # mf, U_Pf, A, a, U_Q; - # rtol=T.rtol, atol=T.atol, - # ) - # end - - # Ns = [1, 2] - # Ds = [2, 3] - - # @testset "KroneckerProduct - $N, $D, $(T.T)" for N in Ns, D in Ds, T in Ts - - # rng = MersenneTwister(123456) - # storage = ArrayStorage(T.T) - - # # Compute the total number of dimensions. - # Dlat = N * D - - # # Generate Kronecker-Product transition dynamics. - # A_D = randn(rng, T.T, D, D) - # A = Eye{T.T}(N) ⊗ A_D - - # a = randn(rng, T.T, Dlat) - - # K_N = random_nice_psd_matrix(rng, N, storage) - # Q_D = random_nice_psd_matrix(rng, D, storage) - # Q = collect(K_N ⊗ Q_D) - - # # Generate filtering (input) distribution. - # mf = randn(rng, T.T, Dlat) - # Pf = Symmetric(random_nice_psd_matrix(rng, Dlat, storage)) - - # # Generate corresponding dense dynamics. - # A_dense = collect(A) - - # # Check agreement with dense implementation. - # mp, Pp = predict(mf, Pf, A, a, Q) - # mp_dense_dynamics, Pp_dense_dynamics = predict(mf, Pf, A_dense, a, Q) - # @test mp ≈ mp_dense_dynamics - # @test Symmetric(Pp) ≈ Symmetric(Pp_dense_dynamics) - # @test mp isa Vector{T.T} - # @test Pp isa Matrix{T.T} - - # # Check that predicting twice gives exactly the same answer. - # let - # mf_c = copy(mf) - # Pf_c = copy(Pf) - # A_D_c = copy(A_D) - # A_c = Eye(N) ⊗ A_D - # a_c = copy(a) - # Q_c = copy(Q) - - # m1, P1 = predict(mf_c, Pf_c, A_c, a_c, Q_c) - # m2, P2 = predict(mf_c, Pf_c, A_c, a_c, Q_c) - - # @test m1 == m2 - # @test P1 == P2 - - # @test mf_c == mf - # @test Pf_c == Pf - # @test A_c == A - # @test a_c == a - # @test Q_c == Q - - # (m3, P3), back = Zygote.pullback(predict, mf_c, Pf_c, A_c, a_c, Q_c) - # @test m1 == m3 - # @test P1 == P3 - - # back((m3, P3)) - - # @test mf_c == mf - # @test Pf_c == Pf - # @test A_c == A - # @test a_c == a - # @test Q_c == Q - # end - - # # Verify approximate numerical correctness of pullback. - # U_Pf = collect(cholesky(Symmetric(Pf)).U) - # U_Q = collect(cholesky(Symmetric(Q)).U) - # Δmp = randn(rng, T.T, Dlat) - # ΔPp = randn(rng, T.T, Dlat, Dlat) - - # adjoint_test( - # (mf, U_Pf, A_D, a, U_Q) -> begin - # U_Q = UpperTriangular(U_Q) - # Q = collect(Symmetric(U_Q'U_Q)) - # U_Pf = UpperTriangular(U_Pf) - # A = Eye{T.T}(N) ⊗ A_D - # return predict(mf, Symmetric(U_Pf'U_Pf), A, a, Q) - # end, - # (Δmp, ΔPp), - # mf, U_Pf, A_D, a, U_Q; - # rtol=T.rtol, atol=T.atol, - # ) - # end - - # Ns = [1, 2, 3] - # Ds = [1, 2, 3] - # N_blockss = [1, 2, 3] - - # @testset "BlockDiagonal of KroneckerProduct - $N, $D, $N_blocks, $(T.T)" for - # N in Ns, - # D in Ds, - # N_blocks in N_blockss, - # T in Ts - - # rng = MersenneTwister(123456) - # storage = ArrayStorage(T.T) - - # Dlat = N * D * N_blocks - - # # Generate BlockDiagonal-KroneckerProduct transition dynamics. - # A_Ds = [randn(rng, T.T, D, D) for _ in 1:N_blocks] - # As = [Eye{T.T}(N) ⊗ A_Ds[n] for n in 1:N_blocks] - # A = BlockDiagonal(As) - - # a = randn(rng, T.T, N * D * N_blocks) - - # Qs = [random_nice_psd_matrix(rng, N * D, storage) for _ in 1:N_blocks] - # Q = BlockDiagonal(Qs) - - # # Generate filtering (input) distribution. - # mf = randn(rng, T.T, Dlat) - # Pf = Symmetric(random_nice_psd_matrix(rng, Dlat, storage)) - - # # Generate corresponding dense dynamics. - # A_dense = collect(A) - # Q_dense = collect(Q) - - # # Check agreement with dense implementation. - # mp, Pp = predict(mf, Pf, A, a, Q) - # mp_dense_dynamics, Pp_dense_dynamics = predict(mf, Pf, A_dense, a, Q_dense) - # @test mp ≈ mp_dense_dynamics - # @test Symmetric(Pp) ≈ Symmetric(Pp_dense_dynamics) atol=1e-6 rtol=1e-6 - - # @test A_dense == A - # @test Q_dense == Q - - # @test mp isa Vector{T.T} - # @test Pp isa Matrix{T.T} - - # # Verify approximate numerical correctness of pullback. - # U_Pf = collect(cholesky(Symmetric(Pf)).U) - # U_Q = map(Q -> collect(cholesky(Symmetric(Q)).U), Qs) - # Δmp = randn(rng, T.T, Dlat) - # ΔPp = randn(rng, T.T, Dlat, Dlat) - - # adjoint_test( - # (mf, U_Pf, A_Ds, a, U_Q) -> begin - # Qs = map(U -> UpperTriangular(U)'UpperTriangular(U), U_Q) - # Q = BlockDiagonal(Qs) - # U_Pf = UpperTriangular(U_Pf) - # A = BlockDiagonal(map(A_D -> Eye{T.T}(N) ⊗ A_D, A_Ds)) - # return predict(mf, Symmetric(U_Pf'U_Pf), A, a, Q) - # end, - # (Δmp, ΔPp), - # mf, U_Pf, A_Ds, a, U_Q; - # rtol=T.rtol, atol=T.atol, - # ) - # end -end diff --git a/test/models/mutable_inference_pullbacks.jl b/test/models/mutable_inference_pullbacks.jl deleted file mode 100644 index 02f92975..00000000 --- a/test/models/mutable_inference_pullbacks.jl +++ /dev/null @@ -1,3 +0,0 @@ -@testset "mutable_inference_pullbacks" begin - -end diff --git a/test/models/predict.jl b/test/models/predict.jl deleted file mode 100644 index 4bfe5dc4..00000000 --- a/test/models/predict.jl +++ /dev/null @@ -1,362 +0,0 @@ -using TemporalGPs: predict - -naive_predict(mf, Pf, A, a, Q) = A * mf + a, (A * Pf) * A' + Q - -println("predict:") -@testset "predict" begin - - @testset "StaticArrays" begin - rng = MersenneTwister(123456) - Dlats = [1, 3] - Ts = [ - # (T=Float32, atol=1e-5, rtol=1e-5), - (T=Float64, atol=1e-9, rtol=1e-9), - ] - - @testset "$Dlat, $(T.T)" for Dlat in Dlats, T in Ts - - # Construct a Gauss-Markov model and pull out the relevant paramters. - gmm = random_tv_gmm(rng, Dlat, 1, 1, SArrayStorage(T.T)) - A = first(gmm.A) - a = first(gmm.a) - Q = first(gmm.Q) - mf = gmm.x0.m - Pf = gmm.x0.P - - # Check agreement with the naive implementation. - mp, Pp = predict(mf, Pf, A, a, Q) - mp_naive, Pp_naive = naive_predict(mf, Pf, A, a, Q) - @test mp ≈ mp_naive - @test Pp ≈ Pp_naive - - # Verify approximate numerical correctness of pullback. - U_Pf = cholesky(Symmetric(Pf)).U - U_Q = cholesky(Symmetric(Q)).U - Δmp = SVector{Dlat}(randn(rng, T.T, Dlat)) - ΔPp = SMatrix{Dlat, Dlat}(randn(rng, T.T, Dlat, Dlat)) - adjoint_test( - (mf, U_Pf, A, a, U_Q) -> begin - U_Q = UpperTriangular(U_Q) - U_Pf = UpperTriangular(U_Pf) - return predict(mf, U_Pf'U_Pf, A, a, U_Q'U_Q) - end, - (Δmp, ΔPp), - mf, U_Pf, A, a, U_Q; - rtol=T.rtol, atol=T.atol - ) - - # Evaluate and pullback. - (mp, Pp), back = pullback(predict, mf, Pf, A, a, Q) - (Δmf, ΔPf, ΔA, Δa, ΔQ) = back((Δmp, ΔPp)) - - # Verify correct output types have been produced. - @test mp isa SVector{Dlat, T.T} - @test Pp isa SMatrix{Dlat, Dlat, T.T} - - # Verify the adjoints w.r.t. the inputs are of the correct type. - @test Δmf isa SVector{Dlat, T.T} - @test ΔPf isa SMatrix{Dlat, Dlat, T.T} - @test ΔA isa SMatrix{Dlat, Dlat, T.T} - @test Δa isa SVector{Dlat, T.T} - @test ΔQ isa SMatrix{Dlat, Dlat, T.T} - - # Check that pullback doesn't allocate because StaticArrays. - @test allocs(@benchmark pullback(predict, $mf, $Pf, $A, $a, $Q)) == 0 - @test allocs(@benchmark $back(($Δmp, $ΔPp))) == 0 - end - end - - @testset "Dense" begin - - rng = MersenneTwister(123456) - Dlats = [1, 3] - Ts = [ - # (T=Float32, atol=1e-5, rtol=1e-5), - (T=Float64, atol=1e-9, rtol=1e-9), - ] - - @testset "Matrix - $Dlat, $(T.T)" for Dlat in Dlats, T in Ts - - # Generate parameters for a transition model. - storage = ArrayStorage(T.T) - A = randn(rng, T.T, Dlat, Dlat) - a = randn(rng, T.T, Dlat) - Q = random_nice_psd_matrix(rng, Dlat, storage) - mf = randn(rng, T.T, Dlat) - Pf = Symmetric(random_nice_psd_matrix(rng, Dlat, storage)) - - # Check agreement with the naive implementation. - mp, Pp = predict(mf, Pf, A, a, Q) - mp_naive, Pp_naive = naive_predict(mf, Pf, A, a, Q) - @test mp ≈ mp_naive - @test Pp ≈ Pp_naive - @test mp isa Vector{T.T} - @test Pp isa Matrix{T.T} - - # Verify approximate numerical correctness of pullback. - U_Pf = cholesky(Symmetric(Pf)).U - U_Q = cholesky(Symmetric(Q)).U - Δmp = randn(rng, T.T, Dlat) - ΔPp = randn(rng, T.T, Dlat, Dlat) - adjoint_test( - (mf, U_Pf, A, a, U_Q) -> begin - U_Q = UpperTriangular(U_Q) - U_Pf = UpperTriangular(U_Pf) - return predict(mf, Symmetric(U_Pf'U_Pf), A, a, U_Q'U_Q) - end, - (Δmp, ΔPp), - mf, U_Pf, A, a, U_Q; - rtol=T.rtol, atol=T.atol - ) - - # Evaluate and pullback. - (mp, Pp), back = pullback(predict, mf, Pf, A, a, Q) - (Δmf, ΔPf, ΔA, Δa, ΔQ) = back((Δmp, ΔPp)) - - # Verify correct output types have been produced. - @test mp isa Vector{T.T} - @test Pp isa Matrix{T.T} - - # Verify the adjoints w.r.t. the inputs are of the correct type. - @test Δmf isa Vector{T.T} - @test ΔPf isa Matrix{T.T} - @test ΔA isa Matrix{T.T} - @test Δa isa Vector{T.T} - @test ΔQ isa Matrix{T.T} - end - - n_blockss = [1, 3] - @testset "BlockDiagonal - $Dlat_block, $(T.T), $n_blocks" for - Dlat_block in Dlats, - T in Ts, - n_blocks in n_blockss - - rng = MersenneTwister(123456) - - # Compute the total number of dimensions. - Dlat = n_blocks * Dlat_block - - # Generate block-diagonal transition dynamics. - As = map(_ -> randn(rng, T.T, Dlat_block, Dlat_block), 1:n_blocks) - A = BlockDiagonal(As) - - a = randn(rng, T.T, Dlat) - - Qs = map( - _ -> random_nice_psd_matrix(rng, Dlat_block, ArrayStorage(T.T)), - 1:n_blocks, - ) - Q = BlockDiagonal(Qs) - - # Generate filtering (input) distribution. - mf = randn(rng, T.T, Dlat) - Pf = Symmetric(random_nice_psd_matrix(rng, Dlat, ArrayStorage(T.T))) - - # Check that predicting twice gives exactly the same answer. - let - mf_c = copy(mf) - Pf_c = copy(Pf) - A_c = BlockDiagonal(map(copy, As)) - a_c = copy(a) - Q_c = BlockDiagonal(map(copy, Qs)) - - m1, P1 = predict(mf_c, Pf_c, A_c, a_c, Q_c) - m2, P2 = predict(mf_c, Pf_c, A_c, a_c, Q_c) - - @test m1 == m2 - @test P1 == P2 - - @test mf_c == mf - @test Pf_c == Pf - @test A_c == A - @test a_c == a - @test Q_c == Q - end - - # Generate corresponding dense dynamics. - A_dense = collect(A) - Q_dense = collect(Q) - - # Check agreement with dense implementation. - mp, Pp = predict(mf, Pf, A, a, Q) - mp_dense_dynamics, Pp_dense_dynamics = predict(mf, Pf, A_dense, a, Q_dense) - @test mp ≈ mp_dense_dynamics - @test Symmetric(Pp) ≈ Symmetric(Pp_dense_dynamics) - @test mp isa Vector{T.T} - @test Pp isa Matrix{T.T} - - # Verify approximate numerical correctness of pullback. - U_Pf = collect(cholesky(Symmetric(Pf)).U) - U_Q = map(Q -> collect(cholesky(Symmetric(Q)).U), Qs) - Δmp = randn(rng, T.T, Dlat) - ΔPp = randn(rng, T.T, Dlat, Dlat) - - adjoint_test( - (mf, U_Pf, A, a, U_Q) -> begin - Qs = map(U -> UpperTriangular(U)'UpperTriangular(U), U_Q) - Q = BlockDiagonal(Qs) - U_Pf = UpperTriangular(U_Pf) - return predict(mf, Symmetric(U_Pf'U_Pf), A, a, Q) - end, - (Δmp, ΔPp), - mf, U_Pf, A, a, U_Q; - rtol=T.rtol, atol=T.atol, - ) - end - - Ns = [1, 2] - Ds = [2, 3] - - @testset "KroneckerProduct - $N, $D, $(T.T)" for N in Ns, D in Ds, T in Ts - - rng = MersenneTwister(123456) - storage = ArrayStorage(T.T) - - # Compute the total number of dimensions. - Dlat = N * D - - # Generate Kronecker-Product transition dynamics. - A_D = randn(rng, T.T, D, D) - A = Eye{T.T}(N) ⊗ A_D - - a = randn(rng, T.T, Dlat) - - K_N = random_nice_psd_matrix(rng, N, storage) - Q_D = random_nice_psd_matrix(rng, D, storage) - Q = collect(K_N ⊗ Q_D) - - # Generate filtering (input) distribution. - mf = randn(rng, T.T, Dlat) - Pf = Symmetric(random_nice_psd_matrix(rng, Dlat, storage)) - - # Generate corresponding dense dynamics. - A_dense = collect(A) - - # Check agreement with dense implementation. - mp, Pp = predict(mf, Pf, A, a, Q) - mp_dense_dynamics, Pp_dense_dynamics = predict(mf, Pf, A_dense, a, Q) - @test mp ≈ mp_dense_dynamics - @test Symmetric(Pp) ≈ Symmetric(Pp_dense_dynamics) - @test mp isa Vector{T.T} - @test Pp isa Matrix{T.T} - - # Check that predicting twice gives exactly the same answer. - let - mf_c = copy(mf) - Pf_c = copy(Pf) - A_D_c = copy(A_D) - A_c = Eye(N) ⊗ A_D - a_c = copy(a) - Q_c = copy(Q) - - m1, P1 = predict(mf_c, Pf_c, A_c, a_c, Q_c) - m2, P2 = predict(mf_c, Pf_c, A_c, a_c, Q_c) - - @test m1 == m2 - @test P1 == P2 - - @test mf_c == mf - @test Pf_c == Pf - @test A_c == A - @test a_c == a - @test Q_c == Q - - (m3, P3), back = Zygote.pullback(predict, mf_c, Pf_c, A_c, a_c, Q_c) - @test m1 == m3 - @test P1 == P3 - - back((m3, P3)) - - @test mf_c == mf - @test Pf_c == Pf - @test A_c == A - @test a_c == a - @test Q_c == Q - end - - # Verify approximate numerical correctness of pullback. - U_Pf = collect(cholesky(Symmetric(Pf)).U) - U_Q = collect(cholesky(Symmetric(Q)).U) - Δmp = randn(rng, T.T, Dlat) - ΔPp = randn(rng, T.T, Dlat, Dlat) - - adjoint_test( - (mf, U_Pf, A_D, a, U_Q) -> begin - U_Q = UpperTriangular(U_Q) - Q = collect(Symmetric(U_Q'U_Q)) - U_Pf = UpperTriangular(U_Pf) - A = Eye{T.T}(N) ⊗ A_D - return predict(mf, Symmetric(U_Pf'U_Pf), A, a, Q) - end, - (Δmp, ΔPp), - mf, U_Pf, A_D, a, U_Q; - rtol=T.rtol, atol=T.atol, - ) - end - - Ns = [1, 2, 3] - Ds = [1, 2, 3] - N_blockss = [1, 2, 3] - - @testset "BlockDiagonal of KroneckerProduct - $N, $D, $N_blocks, $(T.T)" for - N in Ns, - D in Ds, - N_blocks in N_blockss, - T in Ts - - rng = MersenneTwister(123456) - storage = ArrayStorage(T.T) - - Dlat = N * D * N_blocks - - # Generate BlockDiagonal-KroneckerProduct transition dynamics. - A_Ds = [randn(rng, T.T, D, D) for _ in 1:N_blocks] - As = [Eye{T.T}(N) ⊗ A_Ds[n] for n in 1:N_blocks] - A = BlockDiagonal(As) - - a = randn(rng, T.T, N * D * N_blocks) - - Qs = [random_nice_psd_matrix(rng, N * D, storage) for _ in 1:N_blocks] - Q = BlockDiagonal(Qs) - - # Generate filtering (input) distribution. - mf = randn(rng, T.T, Dlat) - Pf = Symmetric(random_nice_psd_matrix(rng, Dlat, storage)) - - # Generate corresponding dense dynamics. - A_dense = collect(A) - Q_dense = collect(Q) - - # Check agreement with dense implementation. - mp, Pp = predict(mf, Pf, A, a, Q) - mp_dense_dynamics, Pp_dense_dynamics = predict(mf, Pf, A_dense, a, Q_dense) - @test mp ≈ mp_dense_dynamics - @test Symmetric(Pp) ≈ Symmetric(Pp_dense_dynamics) atol=1e-6 rtol=1e-6 - - @test A_dense == A - @test Q_dense == Q - - @test mp isa Vector{T.T} - @test Pp isa Matrix{T.T} - - # Verify approximate numerical correctness of pullback. - U_Pf = collect(cholesky(Symmetric(Pf)).U) - U_Q = map(Q -> collect(cholesky(Symmetric(Q)).U), Qs) - Δmp = randn(rng, T.T, Dlat) - ΔPp = randn(rng, T.T, Dlat, Dlat) - - adjoint_test( - (mf, U_Pf, A_Ds, a, U_Q) -> begin - Qs = map(U -> UpperTriangular(U)'UpperTriangular(U), U_Q) - Q = BlockDiagonal(Qs) - U_Pf = UpperTriangular(U_Pf) - A = BlockDiagonal(map(A_D -> Eye{T.T}(N) ⊗ A_D, A_Ds)) - return predict(mf, Symmetric(U_Pf'U_Pf), A, a, Q) - end, - (Δmp, ΔPp), - mf, U_Pf, A_Ds, a, U_Q; - rtol=T.rtol, atol=T.atol, - ) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 91ceabf0..1c9a1133 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,9 +34,6 @@ include("test_util.jl") include(joinpath("models", "immutable_inference_pullbacks.jl")) include(joinpath("models", "checkpointed_immutable_pullbacks.jl")) - include(joinpath("models", "mutable_inference.jl")) - include(joinpath("models", "mutable_inference_pullbacks.jl")) - include(joinpath("models", "scalar_lgssm.jl")) end diff --git a/test/util/zygote_rules.jl b/test/util/zygote_rules.jl index c7420ad4..b9cb2475 100644 --- a/test/util/zygote_rules.jl +++ b/test/util/zygote_rules.jl @@ -1,5 +1,5 @@ using StaticArrays -using TemporalGPs: time_exp +using TemporalGPs: time_exp, logdet_pullback, cholesky_pullback @testset "zygote_rules" begin @testset "SVector" begin @@ -106,4 +106,62 @@ using TemporalGPs: time_exp end adjoint_test(foo, ȳ, randn(rng), x1, x2) end + @testset "$N, $T" for N in [1, 2, 3], T in [Float32, Float64] + + rng = MersenneTwister(123456) + + # Do dense stuff. + S_ = randn(rng, T, N, N) + S = S_ * S_' + I + C = cholesky(S) + Ss = SMatrix{N, N, T}(S) + Cs = cholesky(Ss) + + @testset "cholesky" begin + C_fwd, pb = cholesky_pullback(Symmetric(S)) + Cs_fwd, pbs = cholesky_pullback(Symmetric(Ss)) + + @test eltype(C_fwd) == T + @test eltype(Cs_fwd) == T + + ΔC = randn(rng, T, N, N) + ΔCs = SMatrix{N, N, T}(ΔC) + + @test C.U ≈ Cs.U + @test Cs.U ≈ Cs_fwd.U + + ΔS, = pb((factors=ΔC, )) + ΔSs, = pbs((factors=ΔCs, )) + + @test ΔS ≈ ΔSs + @test eltype(ΔS) == T + @test eltype(ΔSs) == T + + @test allocs(@benchmark cholesky(Symmetric($Ss))) == 0 + @test allocs(@benchmark cholesky_pullback(Symmetric($Ss))) == 0 + @test allocs(@benchmark $pbs((factors=$ΔCs,))) == 0 + end + @testset "logdet" begin + @test logdet(Cs) ≈ logdet(C) + C_fwd, pb = logdet_pullback(C) + Cs_fwd, pbs = logdet_pullback(Cs) + + @test eltype(C_fwd) == T + @test eltype(Cs_fwd) == T + + @test logdet(Cs) ≈ Cs_fwd + + Δ = randn(rng, T) + ΔC = first(pb(Δ)).factors + ΔCs = first(pbs(Δ)).factors + + @test ΔC ≈ ΔCs + @test eltype(ΔC) == T + @test eltype(ΔCs) == T + + @test allocs(@benchmark logdet($Cs)) == 0 + @test allocs(@benchmark logdet_pullback($Cs)) == 0 + @test allocs(@benchmark $pbs($Δ)) == 0 + end + end end