Develop delta fn eus fix indexed call rule #219

Merged
Changes from all commits (36 commits)
1125a33 Change default meta spec from `Any` to `Nothing` (bvdmitri, Sep 15, 2022)
952c22b style: make format (bvdmitri, Sep 15, 2022)
471f706 Update test_rule.jl (bvdmitri, Sep 15, 2022)
ecaca38 Fix meta spec in various rules (bvdmitri, Sep 16, 2022)
6732440 Set default meta of average energy to `Nothing` (bvdmitri, Sep 16, 2022)
9f71f48 style: make format (bvdmitri, Sep 16, 2022)
b6f1ddc Update gcv.jl (bvdmitri, Sep 16, 2022)
c3760af Merge pull request #206 from biaslab/dev-issue-205 (bvdmitri, Sep 16, 2022)
ce0326a update: Bump version to 2.5.0 (bvdmitri, Sep 16, 2022)
cb44b8d Add average energy for uniform dist (albertpod, Sep 16, 2022)
cc8e31b Update AE uniform (albertpod, Sep 18, 2022)
a41b1b2 add GPRegression by SSM demo (HoangMHNguyen, Sep 19, 2022)
e85f918 Update (albertpod, Sep 19, 2022)
5b3041a update coin flip demo (bartvanerp, Sep 19, 2022)
ee4dd00 Merge branch 'master' into ae_uniform (bvdmitri, Sep 19, 2022)
423bf5f Merge branch 'master' into develop-gp2 (bvdmitri, Sep 19, 2022)
f95dbe4 Update (albertpod, Sep 19, 2022)
0c34f6a Merge branch 'ae_uniform' of https://github.com/biaslab/ReactiveMP.jl… (albertpod, Sep 19, 2022)
7630990 update GPR_by_SSM notebook (HoangMHNguyen, Sep 20, 2022)
e70b673 delete old notebooks (HoangMHNguyen, Sep 20, 2022)
a356f32 update demo with Plots.jl (bartvanerp, Sep 22, 2022)
d12f388 update docs (bartvanerp, Sep 22, 2022)
371a23d Fix example path (bvdmitri, Sep 22, 2022)
23a7c95 Merge pull request #214 from biaslab/fix-pyplot (bvdmitri, Sep 22, 2022)
e1d504a change true process f_true (HoangMHNguyen, Sep 22, 2022)
f63174d Merge branch 'master' into develop-gp2 (bvdmitri, Sep 22, 2022)
5c18073 Merge branch 'master' into ae_uniform (bvdmitri, Sep 22, 2022)
6c22cb9 Merge pull request #209 from biaslab/develop-gp2 (bvdmitri, Sep 22, 2022)
836b1ff Merge pull request #208 from biaslab/ae_uniform (bvdmitri, Sep 22, 2022)
3a53428 update: Bump version to 2.5.1 (bvdmitri, Sep 23, 2022)
215cc74 Update rule.jl (bvdmitri, Sep 27, 2022)
0d31a3f merge call rule-fix (bvdmitri, Sep 28, 2022)
3b01a35 Add backward test for the delta node (bvdmitri, Sep 28, 2022)
e17cc51 Update Project.toml (bvdmitri, Sep 28, 2022)
8f4b89b style: make format (bvdmitri, Sep 28, 2022)
bf288d3 style: make format (bvdmitri, Sep 28, 2022)
3 changes: 1 addition & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "ReactiveMP"
uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3"
authors = ["Dmitry Bagaev <[email protected]>", "Albert Podusenko <[email protected]>", "Bart van Erp <[email protected]>", "Ismail Senoz <[email protected]>"]
version = "2.4.1"
version = "2.5.1"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -25,7 +25,6 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"

757 changes: 7 additions & 750 deletions demo/Coin Flip Example.ipynb

Large diffs are not rendered by default.

388 changes: 388 additions & 0 deletions demo/GPRegression by SSM.ipynb

Large diffs are not rendered by default.

840 changes: 840 additions & 0 deletions demo/Invertible Neural Network Tutorial.ipynb

Large diffs are not rendered by default.

1,095 changes: 0 additions & 1,095 deletions demo/Normalizing Flow Tutorial.ipynb

This file was deleted.

2 changes: 1 addition & 1 deletion docs/make.jl
@@ -44,7 +44,7 @@ makedocs(
"Hidden Markov Model" => "examples/hidden_markov_model.md",
"Hierarchical Gaussian Filter" => "examples/hierarchical_gaussian_filter.md",
"Autoregressive Model" => "examples/autoregressive.md",
"Normalizing Flows Tutorial" => "examples/flow_tutorial.md",
"Invertible Neural Networks" => "examples/invertible_neural_network_tutorial.md",
"Univariate Normal Mixture" => "examples/univariate_normal_mixture.md",
"Multivariate Normal Mixture" => "examples/multivariate_normal_mixture.md",
"Gamma Mixture" => "examples/gamma_mixture.md",

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/src/examples/overview.md
@@ -12,7 +12,7 @@ We are going to perform an exact inference to assess the skills of a student giv
- [Hidden Markov Model](@ref examples-hidden-markov-model): An example of structured variational Bayesian inference in Hidden Markov Model with unknown transition and observational matrices.
- [Hierarchical Gaussian Filter](@ref examples-hgf): An example of online inference procedure for Hierarchical Gaussian Filter with univariate noisy observations using Variational Message Passing algorithm. Reference: [Ismail Senoz, Online Message Passing-based Inference in the Hierarchical Gaussian Filter](https://ieeexplore.ieee.org/document/9173980).
- [Autoregressive Model](@ref examples-autoregressive): An example of variational Bayesian Inference on full graph for Autoregressive model. Reference: [Albert Podusenko, Message Passing-Based Inference for Time-Varying Autoregressive Models](https://www.mdpi.com/1099-4300/23/6/683).
- [Normalising Flows](@ref examples-flow): An example of variational Bayesian Inference with Normalizing Flows. Reference: Bard van Erp, Hybrid Inference with Invertible Neural Networks in Factor Graphs (submitted).
- [Invertible Neural Networks](@ref examples-inn): An example of variational Bayesian Inference with invertible neural networks. Reference: Bart van Erp, Hybrid Inference with Invertible Neural Networks in Factor Graphs (accepted).
- [Univariate Gaussian Mixture](@ref examples-univariate-gaussian-mixture): This example implements variational Bayesian inference in a univariate Gaussian mixture model with mean-field assumption.
- [Multivariate Gaussian Mixture](@ref examples-multivariate-gaussian-mixture): This example implements variational Bayesian inference in a multivariate Gaussian mixture model with mean-field assumption.
- [Gamma Mixture](@ref examples-gamma-mixture): This example implements one of the experiments outlined in https://biaslab.github.io/publication/mp-based-inference-in-gmm/ .
1 change: 1 addition & 0 deletions src/algebra/correction.jl
@@ -26,6 +26,7 @@ struct NoCorrection <: AbstractCorrection end

correction!(::NoCorrection, value::Real) = value
correction!(::NoCorrection, matrix::AbstractMatrix) = matrix
correction!(::Nothing, something) = correction!(NoCorrection(), something)

"""
TinyCorrection
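The added `::Nothing` method lets call sites pass `meta = nothing` and still reach the identity behaviour of `NoCorrection`. A self-contained sketch of the dispatch chain (mirroring, not reusing, the ReactiveMP definitions):

```julia
using LinearAlgebra

# `nothing` forwards to `NoCorrection`, which returns its argument unchanged.
abstract type AbstractCorrection end
struct NoCorrection <: AbstractCorrection end

correction!(::NoCorrection, value::Real)          = value
correction!(::NoCorrection, matrix::AbstractMatrix) = matrix
correction!(::Nothing, something)                 = correction!(NoCorrection(), something)

W = Matrix(1.0I, 2, 2)
@assert correction!(nothing, W) === W   # same object, no copy, no modification
```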
8 changes: 7 additions & 1 deletion src/nodes/gcv.jl
@@ -18,7 +18,13 @@ const DefaultGCVNodeMetadata = GCVMetadata(GaussHermiteCubature(20))

default_meta(::Type{GCV}) = DefaultGCVNodeMetadata

@average_energy GCV (q_y_x::MultivariateNormalDistributionsFamily, q_z::NormalDistributionsFamily, q_κ::Any, q_ω::Any) =
@average_energy GCV (
q_y_x::MultivariateNormalDistributionsFamily,
q_z::NormalDistributionsFamily,
q_κ::Any,
q_ω::Any,
meta::Union{<:GCVMetadata, Nothing}
) =
begin
y_x_mean, y_x_cov = mean_cov(q_y_x)
z_mean, z_var = mean_var(q_z)
5 changes: 5 additions & 0 deletions src/nodes/uniform.jl
@@ -11,3 +11,8 @@ function prod(::ProdAnalytical, left::Uniform, right::Beta)
# The special case for `Uniform(0, 1)` which is essentially `p(x) = 1` and does not change anything
return right
end

@average_energy Uniform (q_out::Beta, q_a::PointMass, q_b::PointMass) = begin
@assert (mean(q_a), mean(q_b)) == (0.0, 1.0) "a and b must be equal to 0 and 1 respectively"
0.0
end
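The zero average energy follows because `Uniform(0, 1)` has density 1 on its support, so `-E_q[log p(x | a, b)] = log(b - a) = log 1 = 0` for any marginal `q` on `[0, 1]`. A quick Monte Carlo sanity check, assuming Distributions.jl and a hypothetical `Beta(2, 3)` marginal:

```julia
using Distributions, Statistics

q_out = Beta(2.0, 3.0)            # hypothetical posterior marginal on [0, 1]
xs = rand(q_out, 10_000)

# logpdf(Uniform(0, 1), x) == 0 for every x in [0, 1], so the average energy
# -E_q[log p(x)] is exactly zero regardless of q.
@assert isapprox(mean(-logpdf.(Uniform(0.0, 1.0), xs)), 0.0; atol = 1e-12)
```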
34 changes: 27 additions & 7 deletions src/rule.jl
@@ -86,6 +86,12 @@ function rule_macro_parse_on_tag(on)
return :(Type{Val{$(QuoteNode(name))}}), nothing, nothing
elseif @capture(on, (:name_, index_Symbol))
return :(Tuple{Val{$(QuoteNode(name))}, Int}), index, :($index = on[2])
elseif @capture(on, (:name_, k_ = index_Int))
return :(Tuple{Val{$(QuoteNode(name))}, Int}),
index,
:(error(
"`k = ...` syntax in the edge specification is only allowed in the `@call_rule` and `@call_marginalrule` macros"
))
else
error(
"Error in macro. `on` specification is incorrect: $(on). Must be either a quoted symbol expression (e.g. `:out` or `:mean`) or tuple expression with quoted symbol and index identifier (e.g. `(:m, k)` or `(:w, k)`)"
Expand Down Expand Up @@ -197,6 +203,17 @@ function call_rule_macro_parse_fn_args(inputs; specname, prefix, proxy)
return names_arg, values_arg
end

call_rule_macro_construct_on_arg(on_type, on_index::Nothing) = MacroHelpers.bottom_type(on_type)

function call_rule_macro_construct_on_arg(on_type, on_index::Int)
bottomtype = MacroHelpers.bottom_type(on_type)
if @capture(bottomtype, Tuple{Val{R_}, Int})
return :((Val($R), $on_index))
else
error("Internal indexed call rule error: Invalid `on_type` in the `call_rule_macro_construct_on_arg` function.")
end
end

function rule_function_expression(
body::Function,
fuppertype,
Expand Down Expand Up @@ -273,7 +290,7 @@ macro rule(fform, lambda)
fuppertype = MacroHelpers.upper_type(fformtype)
on_type, on_index, on_index_init = rule_macro_parse_on_tag(on)
whereargs = whereargs === nothing ? [] : whereargs
metatype = metatype === nothing ? :Any : metatype
metatype = metatype === nothing ? :Nothing : metatype

options = map(options) do option
@capture(option, name_ = value_) || error("Error in macro. Option specification '$(option)' is incorrect")
Expand Down Expand Up @@ -418,7 +435,7 @@ macro call_rule(fform, args)
q_names_arg, q_values_arg =
call_rule_macro_parse_fn_args(inputs, specname = :marginals, prefix = :q_, proxy = :(ReactiveMP.Marginal))

on_arg = MacroHelpers.bottom_type(on_type)
on_arg = call_rule_macro_construct_on_arg(on_type, on_index)

output = quote
ReactiveMP.rule(
Expand Down Expand Up @@ -589,7 +606,7 @@ macro marginalrule(fform, lambda)
fuppertype = MacroHelpers.upper_type(fformtype)
on_type, on_index, on_index_init = rule_macro_parse_on_tag(on)
whereargs = whereargs === nothing ? [] : whereargs
metatype = metatype === nothing ? :Any : metatype
metatype = metatype === nothing ? :Nothing : metatype

inputs = map(inputs) do input
@capture(input, iname_::itype_) || error("Error in macro. Input $(input) is incorrect")
Expand Down Expand Up @@ -647,7 +664,7 @@ macro call_marginalrule(fform, args)
q_names_arg, q_values_arg =
call_rule_macro_parse_fn_args(inputs, specname = :marginals, prefix = :q_, proxy = :(ReactiveMP.Marginal))

on_arg = MacroHelpers.bottom_type(on_type)
on_arg = call_rule_macro_construct_on_arg(on_type, on_index)

output = quote
ReactiveMP.marginalrule(
Expand Down Expand Up @@ -824,8 +841,11 @@ end
rule_method_error_extract_fform(f::Function) = string("typeof(", f, ")")
rule_method_error_extract_fform(f) = string(f)

rule_method_error_extract_on(::Type{Val{T}}) where {T} = T
rule_method_error_extract_on(on::Tuple{Val{T}, Int}) where {T} = string("(:", rule_method_error_extract_on(typeof(on[1])), ", k)")
rule_method_error_extract_on(::Type{Val{T}}) where {T} = string(":", T)
rule_method_error_extract_on(::Type{Tuple{Val{T}, Int}}) where {T} = string("(", rule_method_error_extract_on(Val{T}), ", k)")
rule_method_error_extract_on(::Type{Tuple{Val{T}, N}}) where {T, N} = string("(", rule_method_error_extract_on(Val{T}), ", ", convert(Int, N), ")")
rule_method_error_extract_on(::Tuple{Val{T}, Int}) where {T} = string("(", rule_method_error_extract_on(Val{T}), ", k)")
rule_method_error_extract_on(::Tuple{Val{T}, N}) where {T, N} = string("(", rule_method_error_extract_on(Val{T}), ", ", convert(Int, N), ")")

rule_method_error_extract_vconstraint(something) = typeof(something)

Expand Down Expand Up @@ -894,7 +914,7 @@ function Base.showerror(io::IO, error::RuleMethodError)
meta_spec = rule_method_error_extract_meta(error.meta)

possible_fix_definition = """
@rule $(spec_fform)(:$spec_on, $spec_vconstraint) ($arguments_spec, $meta_spec) = begin
@rule $(spec_fform)($spec_on, $spec_vconstraint) ($arguments_spec, $meta_spec) = begin
return ...
end
"""
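The new `call_rule_macro_construct_on_arg` is the core of this fix: it turns the indexed edge *type* into a runtime *value*, which is what makes `k = ...` specifications usable from `@call_rule` and `@call_marginalrule`. A standalone sketch of that lowering (the helper below mirrors, not reuses, the ReactiveMP internals):

```julia
# Given the bottom type of an indexed edge spec and a literal index,
# produce the concrete value that `rule` can dispatch on. Before this PR
# the macro emitted the bottom type `Tuple{Val{:m}, Int}` itself, which
# is not a valid runtime argument.
construct_on_arg(::Type{Tuple{Val{R}, Int}}, index::Int) where {R} = (Val(R), index)

@assert construct_on_arg(Tuple{Val{:m}, Int}, 1) === (Val(:m), 1)
```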
2 changes: 1 addition & 1 deletion src/rules/and/in2.jl
@@ -1,3 +1,3 @@
@rule AND(:in2, Marginalisation) (m_out::Bernoulli, m_in1::Bernoulli, meta::Any) = begin
@rule AND(:in2, Marginalisation) (m_out::Bernoulli, m_in1::Bernoulli) = begin
return @call_rule AND(:in1, Marginalisation) (m_out = m_out, m_in2 = m_in1, meta = meta)
end
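Dropping `meta::Any` here works because this PR changes the default meta type from `Any` to `Nothing`: a rule declared without a meta argument now matches only `meta === nothing`, so stray metadata fails loudly instead of being silently accepted. A toy illustration of the dispatch consequence (`myrule` is hypothetical):

```julia
# New default (this PR): a rule without an explicit `meta` spec matches `Nothing`.
myrule(m_out, meta::Nothing) = m_out

@assert myrule(1.0, nothing) == 1.0
# Passing unexpected metadata is now a MethodError rather than a silent match:
@assert try myrule(1.0, :somemeta); false catch e; e isa MethodError end
```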
4 changes: 2 additions & 2 deletions src/rules/delta/extended/in.jl
@@ -9,7 +9,7 @@
m = A * μ_out + b
V = A * Σ_out * A'

F = val(m, Number) ? Univariate : Multivariate
F = isa(m, Number) ? Univariate : Multivariate

return convert(promote_variate_type(F, NormalMeanVariance), m, V)
end
@@ -26,7 +26,7 @@
m = A * μ_in + b
V = A * Σ_in * A'

F = val(m, Number) ? Univariate : Multivariate
F = isa(m, Number) ? Univariate : Multivariate

return convert(promote_variate_type(F, NormalMeanVariance), m, V)
end
2 changes: 1 addition & 1 deletion src/rules/delta/helpers/shared.jl
@@ -76,7 +76,7 @@ end
"""
Return the marginalized statistics of the Gaussian corresponding to an inbound inx
"""
function marginalizeGaussianMV(m::Vector{T}, V::AbstractMatrix, ds::Vector, inx::Int64) where T<:Real
function marginalizeGaussianMV(m::Vector{T}, V::AbstractMatrix, ds::Vector, inx::Int64) where {T <: Real}
if ds[inx] == () # Univariate original
return (m[inx], V[inx, inx]) # Return scalars
else # Multivariate original
9 changes: 8 additions & 1 deletion src/rules/gcv/marginals.jl
@@ -1,5 +1,12 @@

@marginalrule GCV(:y_x) (m_y::UniNormalOrExpLinQuad, m_x::UniNormalOrExpLinQuad, q_z::Any, q_κ::Any, q_ω::Any) = begin
@marginalrule GCV(:y_x) (
m_y::UniNormalOrExpLinQuad,
m_x::UniNormalOrExpLinQuad,
q_z::Any,
q_κ::Any,
q_ω::Any,
meta::Union{<:GCVMetadata, Nothing}
) = begin
y_mean, y_precision = mean_precision(m_y)
x_mean, x_precision = mean_precision(m_x)

10 changes: 8 additions & 2 deletions src/rules/gcv/x.jl
@@ -1,6 +1,12 @@
export rule

@rule GCV(:x, Marginalisation) (m_y::UniNormalOrExpLinQuad, q_z::Any, q_κ::Any, q_ω::Any) = begin
@rule GCV(:x, Marginalisation) (
m_y::UniNormalOrExpLinQuad,
q_z::Any,
q_κ::Any,
q_ω::Any,
meta::Union{<:GCVMetadata, Nothing}
) = begin
y_mean, y_var = mean_var(m_y)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
@@ -13,7 +19,7 @@ export rule
return NormalMeanVariance(y_mean, y_var + inv(A * B))
end

@rule GCV(:x, Marginalisation) (q_y::Any, q_z::Any, q_κ::Any, q_ω::Any) = begin
@rule GCV(:x, Marginalisation) (q_y::Any, q_z::Any, q_κ::Any, q_ω::Any, meta::Union{<:GCVMetadata, Nothing}) = begin
y_mean = mean(q_y)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
10 changes: 8 additions & 2 deletions src/rules/gcv/y.jl
@@ -1,6 +1,12 @@
export rule

@rule GCV(:y, Marginalisation) (m_x::UniNormalOrExpLinQuad, q_z::Any, q_κ::Any, q_ω::Any) = begin
@rule GCV(:y, Marginalisation) (
m_x::UniNormalOrExpLinQuad,
q_z::Any,
q_κ::Any,
q_ω::Any,
meta::Union{<:GCVMetadata, Nothing}
) = begin
x_mean, x_var = mean_var(m_x)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
@@ -13,7 +19,7 @@ export rule
return NormalMeanVariance(x_mean, x_var + inv(A * B))
end

@rule GCV(:y, Marginalisation) (q_x::Any, q_z::Any, q_κ::Any, q_ω::Any) = begin
@rule GCV(:y, Marginalisation) (q_x::Any, q_z::Any, q_κ::Any, q_ω::Any, meta::Union{<:GCVMetadata, Nothing}) = begin
x_mean = mean(q_x)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
19 changes: 12 additions & 7 deletions src/rules/multiplication/A.jl
@@ -1,15 +1,20 @@

@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass) = PointMass(mean(m_in) \ mean(m_out))
@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrection, Nothing}) =
PointMass(mean(m_in) \ mean(m_out))

@rule typeof(*)(:A, Marginalisation) (m_out::GammaDistributionsFamily, m_in::PointMass{<:Real}) = begin
@rule typeof(*)(:A, Marginalisation) (
m_out::GammaDistributionsFamily,
m_in::PointMass{<:Real},
meta::Union{<:AbstractCorrection, Nothing}
) = begin
return GammaShapeRate(shape(m_out), rate(m_out) * mean(m_in))
end

# if A is a matrix, then the result is multivariate
@rule typeof(*)(:A, Marginalisation) (
m_out::MultivariateNormalDistributionsFamily,
m_in::PointMass{<:AbstractMatrix},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
@@ -22,7 +27,7 @@ end
@rule typeof(*)(:A, Marginalisation) (
m_out::MultivariateNormalDistributionsFamily,
m_in::PointMass{<:AbstractVector},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
@@ -34,7 +39,7 @@ end
@rule typeof(*)(:A, Marginalisation) (
m_out::F,
m_in::PointMass{<:Real},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) where {F <: NormalDistributionsFamily} = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
@@ -46,7 +51,7 @@ end
@rule typeof(*)(:A, Marginalisation) (
m_out::MvNormalMeanCovariance,
m_in::PointMass{<:AbstractMatrix},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_in)
μ_out, Σ_out = mean_cov(m_out)
@@ -61,7 +66,7 @@ end
@rule typeof(*)(:A, Marginalisation) (
m_out::MvNormalMeanCovariance,
m_in::PointMass{<:AbstractVector},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_in)
μ_out, Σ_out = mean_cov(m_out)
19 changes: 12 additions & 7 deletions src/rules/multiplication/in.jl
@@ -1,15 +1,20 @@

@rule typeof(*)(:in, Marginalisation) (m_out::PointMass, m_A::PointMass) = PointMass(mean(m_A) \ mean(m_out))
@rule typeof(*)(:in, Marginalisation) (m_out::PointMass, m_A::PointMass, meta::Union{<:AbstractCorrection, Nothing}) =
PointMass(mean(m_A) \ mean(m_out))

@rule typeof(*)(:in, Marginalisation) (m_out::GammaDistributionsFamily, m_A::PointMass{<:Real}) = begin
@rule typeof(*)(:in, Marginalisation) (
m_out::GammaDistributionsFamily,
m_A::PointMass{<:Real},
meta::Union{<:AbstractCorrection, Nothing}
) = begin
return GammaShapeRate(shape(m_out), rate(m_out) * mean(m_A))
end

# if A is a matrix, then the result is multivariate
@rule typeof(*)(:in, Marginalisation) (
m_out::MultivariateNormalDistributionsFamily,
m_A::PointMass{<:AbstractMatrix},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_A)
ξ_out, W_out = weightedmean_precision(m_out)
@@ -22,7 +27,7 @@ end
@rule typeof(*)(:in, Marginalisation) (
m_out::MultivariateNormalDistributionsFamily,
m_A::PointMass{<:AbstractVector},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_A)
ξ_out, W_out = weightedmean_precision(m_out)
@@ -34,7 +39,7 @@ end
@rule typeof(*)(:in, Marginalisation) (
m_out::F,
m_A::PointMass{<:Real},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) where {F <: NormalDistributionsFamily} = begin
A = mean(m_A)
ξ_out, W_out = weightedmean_precision(m_out)
@@ -46,7 +51,7 @@ end
@rule typeof(*)(:in, Marginalisation) (
m_out::MvNormalMeanCovariance,
m_A::PointMass{<:AbstractMatrix},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_A)
μ_out, Σ_out = mean_cov(m_out)
@@ -61,7 +66,7 @@ end
@rule typeof(*)(:in, Marginalisation) (
m_out::MvNormalMeanCovariance,
m_A::PointMass{<:AbstractVector},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_A)
μ_out, Σ_out = mean_cov(m_out)
4 changes: 2 additions & 2 deletions src/rules/multiplication/marginals.jl
@@ -3,7 +3,7 @@
m_out::NormalDistributionsFamily,
m_A::PointMass,
m_in::NormalDistributionsFamily,
meta::Any
meta::Union{<:AbstractCorrection, Nothing}
) = begin
b_in = @call_rule typeof(*)(:in, Marginalisation) (m_out = m_out, m_A = m_A, meta = meta)
q_in = prod(ProdAnalytical(), b_in, m_in)
@@ -17,7 +17,7 @@ end
m_out::UnivariateNormalDistributionsFamily,
m_A::UnivariateNormalDistributionsFamily,
m_in::PointMass{<:Real},
meta::Any
meta::Union{<:AbstractCorrection, Nothing}
) = begin
return @call_marginalrule typeof(*)(:A_in) (m_out = m_out, m_A = m_in, m_in = m_A, meta = meta)
end