From 258a793523e4b4bf4b101034874ec419e447001e Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 29 Jul 2024 14:08:35 +0200 Subject: [PATCH 1/2] skip projection for ef types --- ext/ProjectionExt/ProjectionExt.jl | 2 +- .../inference_with_projection_tests.jl | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/ext/ProjectionExt/ProjectionExt.jl b/ext/ProjectionExt/ProjectionExt.jl index 86e3821a6..f7e455da1 100644 --- a/ext/ProjectionExt/ProjectionExt.jl +++ b/ext/ProjectionExt/ProjectionExt.jl @@ -14,7 +14,7 @@ function ReactiveMP.prepare_context(constraint::ProjectedTo) return ProjectionContext{T}(nothing) end -function ReactiveMP.constrain_form(constraint::ProjectedTo, context::ProjectionContext, something::Distribution) +function ReactiveMP.constrain_form(constraint::ProjectedTo, context::ProjectionContext, something::Union{Distribution, ExponentialFamilyDistribution}) T = ExponentialFamilyProjection.get_projected_to_type(constraint) D = ExponentialFamily.exponential_family_typetag(something) if T === D diff --git a/test/ext/ProjectionExt/inference_with_projection_tests.jl b/test/ext/ProjectionExt/inference_with_projection_tests.jl index 16b6cbcf2..eb8e24d6b 100644 --- a/test/ext/ProjectionExt/inference_with_projection_tests.jl +++ b/test/ext/ProjectionExt/inference_with_projection_tests.jl @@ -494,3 +494,31 @@ end plot(p1, p2, p3) end end + +@testitem "Projection constraint should skip processing of `ExponentialFamilyDistribution` instances" begin + using BayesBase, ExponentialFamily, Distributions, ExponentialFamilyProjection + + struct NodePrior end + struct NodeLikelihood end + + @node NodePrior Stochastic [out, in] + @node NodeLikelihood Stochastic [out, in] + + @rule NodePrior(:out, Marginalisation) (q_in::Any,) = NodePrior() + @rule NodeLikelihood(:in, Marginalisation) (q_out::Any,) = NodeLikelihood() + + BayesBase.prod(::GenericProd, ::NodePrior, ::NodeLikelihood) = convert(ExponentialFamilyDistribution, Beta(1, 1)) + + @model function mymodel(y) + a ~ NodePrior(1) + y ~ NodeLikelihood(a) + end + + constraints = @constraints begin + q(a)::ProjectedTo(Beta) + end + + result = infer(model = mymodel(), data = (y = 1.0,), constraints = constraints) + + @test result.posteriors[:a] == Beta(1, 1) +end \ No newline at end of file From 239ab9859051ee5ab368c7a59f1497230b2896c7 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Mon, 29 Jul 2024 14:18:51 +0200 Subject: [PATCH 2/2] make format --- test/ext/ProjectionExt/inference_with_projection_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ext/ProjectionExt/inference_with_projection_tests.jl b/test/ext/ProjectionExt/inference_with_projection_tests.jl index eb8e24d6b..99ce1ac27 100644 --- a/test/ext/ProjectionExt/inference_with_projection_tests.jl +++ b/test/ext/ProjectionExt/inference_with_projection_tests.jl @@ -521,4 +521,4 @@ end result = infer(model = mymodel(), data = (y = 1.0,), constraints = constraints) @test result.posteriors[:a] == Beta(1, 1) -end \ No newline at end of file +end