diff --git a/Project.toml b/Project.toml index 5a5429aa..341f3699 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TemporalGPs" uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" authors = ["willtebbutt "] -version = "0.5.1" +version = "0.5.2" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" diff --git a/src/gp/lti_sde.jl b/src/gp/lti_sde.jl index 2d950c75..c26ea0e4 100644 --- a/src/gp/lti_sde.jl +++ b/src/gp/lti_sde.jl @@ -246,7 +246,25 @@ Zygote.@adjoint function stationary_distribution(k::Matern52Kernel, storage_type return stationary_distribution(k, storage_type), Δ->(nothing, nothing) end +# Constant +function TemporalGPs.to_sde(k::ConstantKernel, ::SArrayStorage{T}) where {T<:Real} + F = SMatrix{1, 1, T}(0) + q = convert(T, 0) + H = SVector{1, T}(1) + return F, q, H +end + +function TemporalGPs.stationary_distribution(k::ConstantKernel, ::SArrayStorage{T}) where {T<:Real} + return TemporalGPs.Gaussian( + SVector{1, T}(0), + SMatrix{1, 1, T}( T(only(k.c)) ), + ) +end + +Zygote.@adjoint function to_sde(k::ConstantKernel, storage_type) + return to_sde(k, storage_type), Δ->(nothing, nothing) +end # Scaled diff --git a/src/util/zygote_rules.jl b/src/util/zygote_rules.jl index a0a7b9a6..0ad8797b 100644 --- a/src/util/zygote_rules.jl +++ b/src/util/zygote_rules.jl @@ -38,11 +38,11 @@ end return SMatrix{D1, D2}(X), SMatrix_pullback end -@adjoint function SMatrix{1, 1}(a) - SMatrix_pullback(Δ::AbstractMatrix) = (first(Δ), ) +function Zygote._pullback(::AContext, ::Type{<:SMatrix{1, 1}}, a) + SMatrix_pullback(::Nothing) = nothing + SMatrix_pullback(Δ::AbstractMatrix) = (nothing, first(Δ), ) return SMatrix{1, 1}(a), SMatrix_pullback end - # Implementation of the matrix exponential that assumes one doesn't require access to the # gradient w.r.t. `A`, only `t`. The former is a bit compute-intensive to get at, while the # latter is very cheap. diff --git a/test/gp/lti_sde.jl b/test/gp/lti_sde.jl index 44479e0e..3e22dd67 100644 --- a/test/gp/lti_sde.jl +++ b/test/gp/lti_sde.jl @@ -26,7 +26,7 @@ println("lti_sde:") # (name="static storage Float32", val=SArrayStorage(Float32)), ) - kernels = [Matern12Kernel(), Matern32Kernel(), Matern52Kernel()] + kernels = [Matern12Kernel(), Matern32Kernel(), Matern52Kernel(), ConstantKernel(c=1.5)] @testset "$kernel, $(storage.name)" for kernel in kernels, storage in storages F, q, H = TemporalGPs.to_sde(kernel, storage.val)