Skip to content

Commit

Permalink
Merge pull request #418 from ReactiveBayes/transition-rules
Browse files Browse the repository at this point in the history
Relax rules for structured VMP in Transition node
  • Loading branch information
wouterwln authored Oct 3, 2024
2 parents dac637e + 929d4f8 commit 99a00a6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/rules/transition/in.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import Base.Broadcast: BroadcastFunction

@rule Transition(:in, Marginalisation) (m_out::Categorical, m_a::PointMass) = begin
@rule Transition(:in, Marginalisation) (m_out::Union{DiscreteNonParametric, PointMass}, m_a::PointMass) = begin
@logscale log(sum(mean(A)' * probvec(m_out)))
p = mean(m_a)' * probvec(m_out)
normalize!(p, 1)
Expand All @@ -12,12 +12,12 @@ end
return Categorical(a ./ sum(a))
end

@rule Transition(:in, Marginalisation) (m_out::Categorical, q_a::MatrixDirichlet) = begin
@rule Transition(:in, Marginalisation) (m_out::Union{DiscreteNonParametric, PointMass}, q_a::MatrixDirichlet) = begin
a = clamp.(exp.(mean(BroadcastFunction(log), q_a))' * probvec(m_out), tiny, Inf)
return Categorical(a ./ sum(a))
end

@rule Transition(:in, Marginalisation) (m_out::Categorical, q_a::PointMass, meta::Any) = begin
@rule Transition(:in, Marginalisation) (m_out::Union{DiscreteNonParametric, PointMass}, q_a::PointMass, meta::Any) = begin
return @call_rule Transition(:in, Marginalisation) (m_out = m_out, m_a = q_a, meta = meta)
end

Expand Down
8 changes: 4 additions & 4 deletions src/rules/transition/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Base.Broadcast: BroadcastFunction
# Belief Propagation #
# --------------------------------- #

@rule Transition(:out, Marginalisation) (m_in::Categorical, m_a::PointMass) = begin
@rule Transition(:out, Marginalisation) (m_in::Union{PointMass, DiscreteNonParametric}, m_a::PointMass) = begin
@logscale 0
p = mean(m_a) * probvec(m_in)
normalize!(p, 1)
Expand All @@ -20,17 +20,17 @@ end
# Variational #
# --------------------------------- #

@rule Transition(:out, Marginalisation) (q_in::Categorical, q_a::Any) = begin
@rule Transition(:out, Marginalisation) (q_in::DiscreteNonParametric, q_a::Any) = begin
a = clamp.(exp.(mean(BroadcastFunction(log), q_a) * probvec(q_in)), tiny, Inf)
return Categorical(a ./ sum(a))
end

@rule Transition(:out, Marginalisation) (m_in::Categorical, q_a::ContinuousMatrixDistribution) = begin
@rule Transition(:out, Marginalisation) (m_in::DiscreteNonParametric, q_a::ContinuousMatrixDistribution) = begin
a = clamp.(exp.(mean(BroadcastFunction(log), q_a)) * probvec(m_in), tiny, Inf)
return Categorical(a ./ sum(a))
end

@rule Transition(:out, Marginalisation) (m_in::DiscreteNonParametric, q_a::PointMass, meta::Any) = begin
@rule Transition(:out, Marginalisation) (m_in::Union{PointMass, DiscreteNonParametric}, q_a::PointMass, meta::Any) = begin
@logscale 0
return @call_rule Transition(:out, Marginalisation) (m_in = m_in, m_a = q_a, meta = meta, addons = getaddons())
end
7 changes: 7 additions & 0 deletions test/rules/transition/out_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,11 @@
)
]
end

@testset "Variational Bayes: (m_in::PointMass, q_a::PointMass)" begin
@test_rules [check_type_promotion = false] Transition(:out, Marginalisation) [
(input = (m_in = PointMass([0, 1, 0]), q_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), output = Categorical([0.1, 0.3, 0.6])),
(input = (m_in = PointMass([1, 0, 0]), q_a = PointMass([0.1 0.8 0.1; 0.6 0.3 0.1; 0.2 0.4 0.4])), output = Categorical([0.1 / 0.9, 0.6 / 0.9, 0.2 / 0.9]))
]
end
end

0 comments on commit 99a00a6

Please sign in to comment.