diff --git a/src/util/zygote_rules.jl b/src/util/zygote_rules.jl index b52324b5..a0a7b9a6 100644 --- a/src/util/zygote_rules.jl +++ b/src/util/zygote_rules.jl @@ -253,7 +253,8 @@ function Zygote._pullback( ctx::AContext, ::Type{Symmetric}, X::StridedMatrix{<:Real}, uplo=:U, ) function Symmetric_pullback(Δ) - return nothing, _symmetric_back(Δ, uplo), nothing + ΔX = Δ === nothing ? nothing : _symmetric_back(Δ, uplo) + return nothing, ΔX, nothing end return Symmetric(X, uplo), Symmetric_pullback end diff --git a/test/models/lgssm.jl b/test/models/lgssm.jl index c8de690a..fb426e6e 100644 --- a/test/models/lgssm.jl +++ b/test/models/lgssm.jl @@ -123,7 +123,6 @@ println("lgssm:") max_primal_allocs=10, max_forward_allocs=35, max_backward_allocs=50, - # check_allocs=false, check_allocs=storage.val isa SArrayStorage, ) end diff --git a/test/test_util.jl b/test/test_util.jl index 94128636..2a80d485 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -418,21 +418,20 @@ function test_interface( x_val = rand(rng, x) y = conditional_rand(rng, conditional, x_val) - # @testset "rand" begin - # @test length(y) == dim_out(conditional) - # args = (conditional, x_val) - # @code_warntype conditional_rand(y, args...) - # check_infers && @inferred conditional_rand(rng, args...) - # if check_adjoints - # adjoint_test( - # (f, x) -> conditional_rand(MersenneTwister(123456), f, x), args; - # check_infers=check_infers, kwargs..., - # ) - # end - # if check_allocs - # check_adjoint_allocations(conditional_rand, (rng, args...); kwargs...) - # end - # end + @testset "rand" begin + @test length(y) == dim_out(conditional) + args = (conditional, x_val) + check_infers && @inferred conditional_rand(rng, args...) + if check_adjoints + adjoint_test( + (f, x) -> conditional_rand(MersenneTwister(123456), f, x), args; + check_infers=check_infers, kwargs..., + ) + end + if check_allocs + check_adjoint_allocations(conditional_rand, (rng, args...); kwargs...) + end + end @testset "predict" begin @test predict(x, conditional) isa Gaussian @@ -450,26 +449,26 @@ function test_interface( @test cov(pred_marg) isa Diagonal end - # @testset "posterior_and_lml" begin - # args = (x, conditional, y) - # @test posterior_and_lml(args...) isa Tuple{Gaussian, Real} - # check_infers && @inferred posterior_and_lml(args...) - # if check_adjoints - # (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...)) - # ∂args = map(rand_tangent, args) - # adjoint_test(posterior_and_lml, (Δx, Δlml), args, ∂args) - # adjoint_test(posterior_and_lml, (Δx, nothing), args, ∂args) - # adjoint_test(posterior_and_lml, (nothing, Δlml), args, ∂args) - # adjoint_test(posterior_and_lml, (nothing, nothing), args, ∂args) - # end - # if check_allocs - # (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...)) - # check_adjoint_allocations(posterior_and_lml, (Δx, Δlml), args; kwargs...) - # check_adjoint_allocations(posterior_and_lml, (nothing, Δlml), args; kwargs...) - # check_adjoint_allocations(posterior_and_lml, (Δx, nothing), args; kwargs...) - # check_adjoint_allocations(posterior_and_lml, (nothing, nothing), args; kwargs...) - # end - # end + @testset "posterior_and_lml" begin + args = (x, conditional, y) + @test posterior_and_lml(args...) isa Tuple{Gaussian, Real} + check_infers && @inferred posterior_and_lml(args...) + if check_adjoints + (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...)) + ∂args = map(rand_tangent, args) + adjoint_test(posterior_and_lml, (Δx, Δlml), args, ∂args) + adjoint_test(posterior_and_lml, (Δx, nothing), args, ∂args) + adjoint_test(posterior_and_lml, (nothing, Δlml), args, ∂args) + adjoint_test(posterior_and_lml, (nothing, nothing), args, ∂args) + end + if check_allocs + (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...)) + check_adjoint_allocations(posterior_and_lml, (Δx, Δlml), args; kwargs...) + check_adjoint_allocations(posterior_and_lml, (nothing, Δlml), args; kwargs...) + check_adjoint_allocations(posterior_and_lml, (Δx, nothing), args; kwargs...) + check_adjoint_allocations(posterior_and_lml, (nothing, nothing), args; kwargs...) + end + end end """