From 8788418c62484f154ab7306d133cbaf928dcb752 Mon Sep 17 00:00:00 2001 From: marius Date: Thu, 5 Oct 2023 21:46:46 +0300 Subject: [PATCH] fix some BlockDiagIEB AD issues --- src/field_vectors.jl | 8 ++++---- src/specialops.jl | 33 +++++++++++++++++++++++---------- test/runtests.jl | 4 +++- 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/field_vectors.jl b/src/field_vectors.jl index 3740e327..45a9af71 100644 --- a/src/field_vectors.jl +++ b/src/field_vectors.jl @@ -66,25 +66,25 @@ promote_rule(::Type{F}, ::Type{<:Scalar}) where {F<:Field} = F end @auto_adjoint function sqrt(A::SA) where {SA<:StaticMatrix{2,2,<:DiagOp}} - a, c, b, d = A[1,1], A[2,1], A[2,1], A[2,2] + a, c, b, d = A[1,1], A[2,1], A[1,2], A[2,2] s = sqrt(a*d-b*c) t = pinv(sqrt(a+(d+2s))) SA([t*(a+s) t*b; t*c t*(d+s)]) end @auto_adjoint function det(A::StaticMatrix{2,2,<:DiagOp}) - a, c, b, d = A[1,1], A[2,1], A[2,1], A[2,2] + a, c, b, d = A[1,1], A[2,1], A[1,2], A[2,2] a*d-b*c end @auto_adjoint function pinv(A::SA) where {SA<:StaticMatrix{2,2,<:DiagOp}} - a, c, b, d = A[1,1], A[2,1], A[2,1], A[2,2] + a, c, b, d = A[1,1], A[2,1], A[1,2], A[2,2] idet = pinv(a*d-b*c) SA([d*idet -(b*idet); -(c*idet) a*idet]) end function pinv!(dst::StaticMatrix{2,2,<:DiagOp}, src::StaticMatrix{2,2,<:DiagOp}) - a, c, b, d = src[1,1], src[2,1], src[2,1], src[2,2] + a, c, b, d = src[1,1], src[2,1], src[1,2], src[2,2] det⁻¹ = pinv(@. a*d-b*c) @. dst[1,1] = det⁻¹ * d @. dst[1,2] = -det⁻¹ * b diff --git a/src/specialops.jl b/src/specialops.jl index c3258c82..2b05fd35 100644 --- a/src/specialops.jl +++ b/src/specialops.jl @@ -56,12 +56,14 @@ end # We store the 2x2 block as a 2x2 SMatrix, ΣTE, so that we can easily # call sqrt/inv on it, and the ΣBB block separately as ΣB. This type # is generic with regards to the field type, F. -struct BlockDiagIEB{P,T,D1,D2} <: ImplicitOp{T} - ΣTE :: SizedMatrix{2,2,D1,2,Matrix{D1}} - ΣB :: D2 - function BlockDiagIEB(ΣTE::AbstractMatrix{D1}, ΣB::D2) where {T1,T2,P,F1<:BaseFourier{P},F2<:BaseFourier{P},D1<:Diagonal{T1,F1},D2<:Diagonal{T2,F2}} - T = promote_type(T1, T2) - new{P,T,D1,D2}(ΣTE, ΣB) +struct BlockDiagIEB{P,T,DTE,DB} <: ImplicitOp{T} + ΣTE :: SizedMatrix{2,2,DTE,2,Matrix{DTE}} + ΣB :: DB + function BlockDiagIEB(ΣTE::AbstractMatrix, ΣB) + ΣTE = SizedMatrix{2,2}(ΣTE) + T = promote_type(map(d -> d isa Diagonal ? eltype(d) : Union{}, (ΣTE...,ΣB))...) + P = promote_type(map(d -> d isa Diagonal ? typeof(diag(d).metadata) : Union{}, (ΣTE...,ΣB))...) + new{P,T,eltype(ΣTE),typeof(ΣB)}(ΣTE, ΣB) end end @adjoint function BlockDiagIEB(ΣTE, ΣB) @@ -83,23 +85,24 @@ end size(L::BlockDiagIEB) = 3 .* size(L.ΣB) adjoint(L::BlockDiagIEB) = BlockDiagIEB(adjoint(L.ΣTE), adjoint(L.ΣB)) sqrt(L::BlockDiagIEB) = BlockDiagIEB(sqrt(L.ΣTE), sqrt(L.ΣB)) -@auto_adjoint pinv(L::BlockDiagIEB) = BlockDiagIEB(pinv(L.ΣTE), pinv(L.ΣB)) +pinv(L::BlockDiagIEB) = BlockDiagIEB(pinv(L.ΣTE), pinv(L.ΣB)) diag(L::BlockDiagIEB{P}) where {P} = BaseIEBFourier{P}(diag(L.ΣTE[1,1]), diag(L.ΣTE[2,2]), diag(L.ΣB)) similar(L::BlockDiagIEB) = BlockDiagIEB(similar.(L.ΣTE), similar(L.ΣB)) get_storage(L::BlockDiagIEB) = get_storage(L.ΣB) adapt_structure(storage, L::BlockDiagIEB) = BlockDiagIEB(adapt.(Ref(storage), L.ΣTE), adapt(storage, L.ΣB)) simulate(rng::AbstractRNG, L::BlockDiagIEB; Nbatch=()) = sqrt(L) * randn!(rng, similar(diag(L), Nbatch...)) -logdet(L::BlockDiagIEB) = logdet(det(L.ΣTE)) + logdet(L.ΣB) +@auto_adjoint logdet(L::BlockDiagIEB) = logdet(det(L.ΣTE)) + logdet(L.ΣB) # arithmetic *(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB([L.ΣTE[1,1]*D[:I] L.ΣTE[1,2]*D[:E]; L.ΣTE[2,1]*D[:I] L.ΣTE[2,2]*D[:E]], L.ΣB*D[:B]) *(D::DiagOp{<:BaseIEBFourier}, L::BlockDiagIEB) = BlockDiagIEB([L.ΣTE[1,1]*D[:I] L.ΣTE[1,2]*D[:I]; L.ΣTE[2,1]*D[:E] L.ΣTE[2,2]*D[:E]], L.ΣB*D[:B]) +(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB([L.ΣTE[1,1]+D[:I] L.ΣTE[1,2]; L.ΣTE[2,1] L.ΣTE[2,2]+D[:E]], L.ΣB+D[:B]) *(La::BlockDiagIEB, Lb::BlockDiagIEB) = BlockDiagIEB(La.ΣTE * Lb.ΣTE, La.ΣB * Lb.ΣB) +(La::BlockDiagIEB, Lb::BlockDiagIEB) = BlockDiagIEB(La.ΣTE + Lb.ΣTE, La.ΣB + Lb.ΣB) +-(L::BlockDiagIEB) = BlockDiagIEB(.-(L.ΣTE), -L.ΣB) +(L::BlockDiagIEB, U::UniformScaling{<:Scalar}) = BlockDiagIEB([(L.ΣTE[1,1]+U) L.ΣTE[1,2]; L.ΣTE[2,1] (L.ΣTE[2,2]+U)], L.ΣB+U) -*(L::BlockDiagIEB, λ::Scalar) = BlockDiagIEB(L.ΣTE * λ, L.ΣB * λ) +@auto_adjoint *(L::BlockDiagIEB, λ::Scalar) = BlockDiagIEB([L.ΣTE[1,1]*λ L.ΣTE[1,2]*λ; L.ΣTE[2,1]*λ L.ΣTE[2,2]*λ], L.ΣB*λ) +(U::UniformScaling{<:Scalar}, L::BlockDiagIEB) = L + U -*(λ::Scalar, L::BlockDiagIEB) = L * λ +@auto_adjoint *(λ::Scalar, L::BlockDiagIEB) = L * λ # indexing function getindex(L::BlockDiagIEB{P}, k::Symbol) where {P} @match k begin @@ -112,6 +115,16 @@ function getindex(L::BlockDiagIEB{P}, k::Symbol) where {P} _ => throw(ArgumentError("Invalid BlockDiagIEB index: $k")) end end +@adjoint function Base.getproperty(L::BlockDiagIEB, k::Symbol) + function BlockDiagIEB_getproperty_pullback(Δ) + if k == :ΣTE + (BlockDiagIEB(Δ, zero(getfield(L,:ΣB))), nothing) + else + (BlockDiagIEB(zero.(getfield(L,:ΣTE)), Δ), nothing) + end + end + return getfield(L, k), BlockDiagIEB_getproperty_pullback +end # hashing hash(L::BlockDiagIEB, h::UInt64) = foldr(hash, (typeof(L), L.ΣTE[1,1], L.ΣTE[1,2], L.ΣTE[2,2], L.ΣB), init=h) diff --git a/test/runtests.jl b/test/runtests.jl index 4113de83..898f25aa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -610,7 +610,9 @@ end atol = pol==:IP ? 30 : 3 @test_real_gradient(α -> logpdf( ds; f = f + α * δf, ϕ = ϕ + α * δϕ), 0, atol=atol) @test_real_gradient(α -> logpdf(Mixed(ds); f° = f° + α * δf, ϕ° = ϕ° + α * δϕ), 0, atol=atol) - + @test_real_gradient(r -> logpdf( ds; f, ϕ, θ=(;r)), T(0.1), atol=atol) + @test_real_gradient(r -> logpdf(Mixed(ds); f°, ϕ°, θ=(;r)), T(0.1), atol=atol) + end end