Skip to content

Commit

Permalink
Merge pull request #338 from ReactiveBayes/dev-project-constraint-sup…
Browse files Browse the repository at this point in the history
…port-ef

Skip projection for ef types
  • Loading branch information
bvdmitri authored Aug 1, 2024
2 parents 318bd43 + 239ab98 commit 04ea9fc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ext/ProjectionExt/ProjectionExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function ReactiveMP.prepare_context(constraint::ProjectedTo)
return ProjectionContext{T}(nothing)
end

function ReactiveMP.constrain_form(constraint::ProjectedTo, context::ProjectionContext, something::Distribution)
function ReactiveMP.constrain_form(constraint::ProjectedTo, context::ProjectionContext, something::Union{Distribution, ExponentialFamilyDistribution})
T = ExponentialFamilyProjection.get_projected_to_type(constraint)
D = ExponentialFamily.exponential_family_typetag(something)
if T === D
Expand Down
28 changes: 28 additions & 0 deletions test/ext/ProjectionExt/inference_with_projection_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,31 @@ end
plot(p1, p2, p3)
end
end

@testitem "Projection constraint should skip processing of `ExponentialFamilyDistribution` instances" begin
using BayesBase, ExponentialFamily, Distributions, ExponentialFamilyProjection

struct NodePrior end
struct NodeLikelihood end

@node NodePrior Stochastic [out, in]
@node NodeLikelihood Stochastic [out, in]

@rule NodePrior(:out, Marginalisation) (q_in::Any,) = NodePrior()
@rule NodeLikelihood(:in, Marginalisation) (q_out::Any,) = NodeLikelihood()

BayesBase.prod(::GenericProd, ::NodePrior, ::NodeLikelihood) = convert(ExponentialFamilyDistribution, Beta(1, 1))

@model function mymodel(y)
a ~ NodePrior(1)
y ~ NodeLikelihood(a)
end

constraints = @constraints begin
q(a)::ProjectedTo(Beta)
end

result = infer(model = mymodel(), data = (y = 1.0,), constraints = constraints)

@test result.posteriors[:a] == Beta(1, 1)
end

0 comments on commit 04ea9fc

Please sign in to comment.