Skip to content

Commit

Permalink
Remove redundant rrules (#36)
Browse files Browse the repository at this point in the history
* Remove redundant code + add new tests

* Delete dead code

* Remove redundant tests

* Move correctness tests

* Move Zygote tests around

* Move tests around

* Add Zygote hacks

* Update checkpointing

* Bump patch and fix test/Project.toml
  • Loading branch information
willtebbutt authored Dec 4, 2020
1 parent 4bdda0d commit c5e7cfa
Show file tree
Hide file tree
Showing 15 changed files with 408 additions and 1,657 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <[email protected]>"]
version = "0.3.7"
version = "0.3.8"

[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Expand Down
6 changes: 2 additions & 4 deletions src/TemporalGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module TemporalGPs

using FillArrays: AbstractFill
using Kronecker: KroneckerProduct
using Zygote: _pullback

import Stheno: mean, cov, pairwise, logpdf, AV, AM, FiniteGP, AbstractGP

Expand All @@ -32,10 +33,7 @@ module TemporalGPs

include(joinpath("models", "immutable_inference.jl"))
include(joinpath("models", "immutable_inference_pullbacks.jl"))
include(joinpath("models", "checkpointed_immutable_pullbacks.jl"))

include(joinpath("models", "mutable_inference.jl"))
include(joinpath("models", "mutable_inference_pullbacks.jl"))
include(joinpath("models", "checkpointed_immutable_pullbacks.jl"))

include(joinpath("models", "scalar_lgssm.jl"))

Expand Down
10 changes: 6 additions & 4 deletions src/models/checkpointed_immutable_pullbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ for (foo, step_foo, foo_pullback, step_foo_pullback) in [
# Grabs the penultimate filtering distribution, xs[end].
Δys = Vector{eltype(ys)}(undef, T)
(Δα, Δx__) = get_pb(f)(last(Δvs))
_, pullback_last = $step_foo_pullback(model[T], xs[end], ys[T])
Δmodel_at_T, Δx, Δy = pullback_last((Δlml, Δα, Δx__))
_, pullback_last = _pullback(NoContext(), $step_foo, model[T], xs[end], ys[T])
_, Δmodel_at_T, Δx, Δy = pullback_last((Δlml, Δα, Δx__))
Δmodel = get_adjoint_storage(model, Δmodel_at_T)
Δys[T] = Δy

Expand All @@ -144,8 +144,10 @@ for (foo, step_foo, foo_pullback, step_foo_pullback) in [
if t != T
Δα, Δx__ = get_pb(f)(Δvs[t])
Δx_ = Zygote.accum(Δx, Δx__)
_, pullback_t = $step_foo_pullback(model[t], xs_block[c], ys[t])
Δmodel_at_t, Δx, Δy = pullback_t((Δlml, Δα, Δx_))
_, pullback_t = _pullback(
NoContext(), $step_foo, model[t], xs_block[c], ys[t],
)
_, Δmodel_at_t, Δx, Δy = pullback_t((Δlml, Δα, Δx_))
Δmodel = _accum_at(Δmodel, t, Δmodel_at_t)
Δys[t] = Δy
end
Expand Down
31 changes: 6 additions & 25 deletions src/models/immutable_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,26 +72,6 @@ end
return A * mf + a, (A * Pf) * A' + Q
end

# # Immutable inference for heap-allocated arrays.
# @inline function predict(
# mf::StridedVector{T},
# Pf::StridedMatrix{T},
# A::StridedMatrix{T},
# a::StridedVector{T},
# Q::StridedMatrix{T},
# ) where {T<:Real}

# # Compute filtering mean vector.
# mp = A * mf + a

# # Compute filtering covariance matrix.
# Pp = similar(Pf)
# BLAS.copy!(Pp, Q)
# mul!(Pp, A * Symmetric(Pf), A', one(T), one(T))

# return mp, Pp
# end

@inline function update_decorrelate(
mp::AV{T}, Pp::AM{T}, H::AM{T}, h::AV{T}, Σ::AM{T}, y::AV{T},
) where {T<:Real}
Expand Down Expand Up @@ -126,8 +106,9 @@ end

_compute_Pf(Pp::AM{T}, B::AM{T}) where {T<:Real} = Pp - B'B

function _compute_Pf(Pp::Matrix{T}, B::Matrix{T}) where {T<:Real}
# Copy of Pp is necessary to ensure that the memory isn't modified.
# return BLAS.syrk!('U', 'T', -one(T), B, one(T), copy(Pp))
return LinearAlgebra.copytri!(BLAS.syrk!('U', 'T', -one(T), B, one(T), copy(Pp)), 'U')
end
# function _compute_Pf(Pp::Matrix{T}, B::Matrix{T}) where {T<:Real}
# # Copy of Pp is necessary to ensure that the memory isn't modified.
# # return BLAS.syrk!('U', 'T', -one(T), B, one(T), copy(Pp))
# # I probably _do_ need a custom adjoint for this...
# return LinearAlgebra.copytri!(BLAS.syrk!('U', 'T', -one(T), B, one(T), copy(Pp)), 'U')
# end
Loading

2 comments on commit c5e7cfa

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/25825

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.8 -m "<description of version>" c5e7cfa188cdde1633348c3864e41e91ac856862
git push origin v0.3.8

Please sign in to comment.