Skip to content

Commit

Permalink
Merge pull request #219 from biaslab/develop-delta-fn-eus-fix-indexed…
Browse files Browse the repository at this point in the history
…-call-rule

Develop delta fn eus fix indexed call rule
  • Loading branch information
bvdmitri authored Sep 28, 2022
2 parents 7786cbb + bf288d3 commit 9aa50b4
Show file tree
Hide file tree
Showing 35 changed files with 1,624 additions and 2,082 deletions.
3 changes: 1 addition & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ReactiveMP"
uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3"
authors = ["Dmitry Bagaev <[email protected]>", "Albert Podusenko <[email protected]>", "Bart van Erp <[email protected]>", "Ismail Senoz <[email protected]>"]
version = "2.4.1"
version = "2.5.1"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -25,7 +25,6 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"

Expand Down
757 changes: 7 additions & 750 deletions demo/Coin Flip Example.ipynb

Large diffs are not rendered by default.

388 changes: 388 additions & 0 deletions demo/GPRegression by SSM.ipynb

Large diffs are not rendered by default.

840 changes: 840 additions & 0 deletions demo/Invertible Neural Network Tutorial.ipynb

Large diffs are not rendered by default.

1,095 changes: 0 additions & 1,095 deletions demo/Normalizing Flow Tutorial.ipynb

This file was deleted.

2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ makedocs(
"Hidden Markov Model" => "examples/hidden_markov_model.md",
"Hierarchical Gaussian Filter" => "examples/hierarchical_gaussian_filter.md",
"Autoregressive Model" => "examples/autoregressive.md",
"Normalizing Flows Tutorial" => "examples/flow_tutorial.md",
"Invertible Neural Networks" => "examples/invertible_neural_network_tutorial.md",
"Univariate Normal Mixture" => "examples/univariate_normal_mixture.md",
"Multivariate Normal Mixture" => "examples/multivariate_normal_mixture.md",
"Gamma Mixture" => "examples/gamma_mixture.md",
Expand Down

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/src/examples/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ We are going to perform an exact inference to assess the skills of a student giv
- [Hidden Markov Model](@ref examples-hidden-markov-model): An example of structured variational Bayesian inference in Hidden Markov Model with unknown transition and observational matrices.
- [Hierarchical Gaussian Filter](@ref examples-hgf): An example of online inference procedure for Hierarchical Gaussian Filter with univariate noisy observations using Variational Message Passing algorithm. Reference: [Ismail Senoz, Online Message Passing-based Inference in the Hierarchical Gaussian Filter](https://ieeexplore.ieee.org/document/9173980).
- [Autoregressive Model](@ref examples-autoregressive): An example of variational Bayesian Inference on full graph for Autoregressive model. Reference: [Albert Podusenko, Message Passing-Based Inference for Time-Varying Autoregressive Models](https://www.mdpi.com/1099-4300/23/6/683).
- [Normalising Flows](@ref examples-flow): An example of variational Bayesian Inference with Normalizing Flows. Reference: Bard van Erp, Hybrid Inference with Invertible Neural Networks in Factor Graphs (submitted).
- [Invertible Neural Networks](@ref examples-inn): An example of variational Bayesian Inference with invertible neural networks. Reference: Bart van Erp, Hybrid Inference with Invertible Neural Networks in Factor Graphs (accepted).
- [Univariate Gaussian Mixture](@ref examples-univariate-gaussian-mixture): This example implements variational Bayesian inference in a univariate Gaussian mixture model with mean-field assumption.
- [Multivariate Gaussian Mixture](@ref examples-multivariate-gaussian-mixture): This example implements variational Bayesian inference in a multivariate Gaussian mixture model with mean-field assumption.
- [Gamma Mixture](@ref examples-gamma-mixture): This example implements one of the experiments outlined in https://biaslab.github.io/publication/mp-based-inference-in-gmm/ .
Expand Down
1 change: 1 addition & 0 deletions src/algebra/correction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct NoCorrection <: AbstractCorrection end

correction!(::NoCorrection, value::Real) = value
correction!(::NoCorrection, matrix::AbstractMatrix) = matrix
correction!(::Nothing, something) = correction!(NoCorrection(), something)

"""
TinyCorrection
Expand Down
8 changes: 7 additions & 1 deletion src/nodes/gcv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ const DefaultGCVNodeMetadata = GCVMetadata(GaussHermiteCubature(20))

default_meta(::Type{GCV}) = DefaultGCVNodeMetadata

@average_energy GCV (q_y_x::MultivariateNormalDistributionsFamily, q_z::NormalDistributionsFamily, q_κ::Any, q_ω::Any) =
@average_energy GCV (
q_y_x::MultivariateNormalDistributionsFamily,
q_z::NormalDistributionsFamily,
q_κ::Any,
q_ω::Any,
meta::Union{<:GCVMetadata, Nothing}
) =
begin
y_x_mean, y_x_cov = mean_cov(q_y_x)
z_mean, z_var = mean_var(q_z)
Expand Down
5 changes: 5 additions & 0 deletions src/nodes/uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ function prod(::ProdAnalytical, left::Uniform, right::Beta)
# The special case for `Uniform(0, 1)` which is essentially `p(x) = 1` and does not change anything
return right
end

@average_energy Uniform (q_out::Beta, q_a::PointMass, q_b::PointMass) = begin
@assert (mean(q_a), mean(q_b)) == (0.0, 1.0) "a and b must be equal to 0 and 1 respectively"
0.0
end
34 changes: 27 additions & 7 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ function rule_macro_parse_on_tag(on)
return :(Type{Val{$(QuoteNode(name))}}), nothing, nothing
elseif @capture(on, (:name_, index_Symbol))
return :(Tuple{Val{$(QuoteNode(name))}, Int}), index, :($index = on[2])
elseif @capture(on, (:name_, k_ = index_Int))
return :(Tuple{Val{$(QuoteNode(name))}, Int}),
index,
:(error(
"`k = ...` syntax in the edge specification is only allowed in the `@call_rule` and `@call_marginalrule` macros"
))
else
error(
"Error in macro. `on` specification is incorrect: $(on). Must be either a quoted symbol expression (e.g. `:out` or `:mean`) or tuple expression with quoted symbol and index identifier (e.g. `(:m, k)` or `(:w, k)`)"
Expand Down Expand Up @@ -197,6 +203,17 @@ function call_rule_macro_parse_fn_args(inputs; specname, prefix, proxy)
return names_arg, values_arg
end

call_rule_macro_construct_on_arg(on_type, on_index::Nothing) = MacroHelpers.bottom_type(on_type)

function call_rule_macro_construct_on_arg(on_type, on_index::Int)
bottomtype = MacroHelpers.bottom_type(on_type)
if @capture(bottomtype, Tuple{Val{R_}, Int})
return :((Val($R), $on_index))
else
error("Internal indexed call rule error: Invalid `on_type` in the `call_rule_macro_construct_on_arg` function.")
end
end

function rule_function_expression(
body::Function,
fuppertype,
Expand Down Expand Up @@ -273,7 +290,7 @@ macro rule(fform, lambda)
fuppertype = MacroHelpers.upper_type(fformtype)
on_type, on_index, on_index_init = rule_macro_parse_on_tag(on)
whereargs = whereargs === nothing ? [] : whereargs
metatype = metatype === nothing ? :Any : metatype
metatype = metatype === nothing ? :Nothing : metatype

options = map(options) do option
@capture(option, name_ = value_) || error("Error in macro. Option specification '$(option)' is incorrect")x
Expand Down Expand Up @@ -418,7 +435,7 @@ macro call_rule(fform, args)
q_names_arg, q_values_arg =
call_rule_macro_parse_fn_args(inputs, specname = :marginals, prefix = :q_, proxy = :(ReactiveMP.Marginal))

on_arg = MacroHelpers.bottom_type(on_type)
on_arg = call_rule_macro_construct_on_arg(on_type, on_index)

output = quote
ReactiveMP.rule(
Expand Down Expand Up @@ -589,7 +606,7 @@ macro marginalrule(fform, lambda)
fuppertype = MacroHelpers.upper_type(fformtype)
on_type, on_index, on_index_init = rule_macro_parse_on_tag(on)
whereargs = whereargs === nothing ? [] : whereargs
metatype = metatype === nothing ? :Any : metatype
metatype = metatype === nothing ? :Nothing : metatype

inputs = map(inputs) do input
@capture(input, iname_::itype_) || error("Error in macro. Input $(input) is incorrect")
Expand Down Expand Up @@ -647,7 +664,7 @@ macro call_marginalrule(fform, args)
q_names_arg, q_values_arg =
call_rule_macro_parse_fn_args(inputs, specname = :marginals, prefix = :q_, proxy = :(ReactiveMP.Marginal))

on_arg = MacroHelpers.bottom_type(on_type)
on_arg = call_rule_macro_construct_on_arg(on_type, on_index)

output = quote
ReactiveMP.marginalrule(
Expand Down Expand Up @@ -824,8 +841,11 @@ end
rule_method_error_extract_fform(f::Function) = string("typeof(", f, ")")
rule_method_error_extract_fform(f) = string(f)

rule_method_error_extract_on(::Type{Val{T}}) where {T} = T
rule_method_error_extract_on(on::Tuple{Val{T}, Int}) where {T} = string("(:", rule_method_error_extract_on(typeof(on[1])), ", k)")
rule_method_error_extract_on(::Type{Val{T}}) where {T} = string(":", T)
rule_method_error_extract_on(::Type{Tuple{Val{T}, Int}}) where {T} = string("(", rule_method_error_extract_on(Val{T}), ", k)")
rule_method_error_extract_on(::Type{Tuple{Val{T}, N}}) where {T, N} = string("(", rule_method_error_extract_on(Val{T}), ", ", convert(Int, N), ")")
rule_method_error_extract_on(::Tuple{Val{T}, Int}) where {T} = string("(", rule_method_error_extract_on(Val{T}), ", k)")
rule_method_error_extract_on(::Tuple{Val{T}, N}) where {T, N} = string("(", rule_method_error_extract_on(Val{T}), ", ", convert(Int, N), ")")

rule_method_error_extract_vconstraint(something) = typeof(something)

Expand Down Expand Up @@ -894,7 +914,7 @@ function Base.showerror(io::IO, error::RuleMethodError)
meta_spec = rule_method_error_extract_meta(error.meta)

possible_fix_definition = """
@rule $(spec_fform)(:$spec_on, $spec_vconstraint) ($arguments_spec, $meta_spec) = begin
@rule $(spec_fform)($spec_on, $spec_vconstraint) ($arguments_spec, $meta_spec) = begin
return ...
end
"""
Expand Down
2 changes: 1 addition & 1 deletion src/rules/and/in2.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
@rule AND(:in2, Marginalisation) (m_out::Bernoulli, m_in1::Bernoulli, meta::Any) = begin
@rule AND(:in2, Marginalisation) (m_out::Bernoulli, m_in1::Bernoulli) = begin
return @call_rule AND(:in1, Marginalisation) (m_out = m_out, m_in2 = m_in1, meta = meta)
end
4 changes: 2 additions & 2 deletions src/rules/delta/extended/in.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
m = A * μ_out + b
V = A * Σ_out * A'

F = val(m, Number) ? Univariate : Multivariate
F = isa(m, Number) ? Univariate : Multivariate

return convert(promote_variate_type(F, NormalMeanVariance), m, V)
end
Expand All @@ -26,7 +26,7 @@
m = A * μ_in + b
V = A * Σ_in * A'

F = val(m, Number) ? Univariate : Multivariate
F = isa(m, Number) ? Univariate : Multivariate

return convert(promote_variate_type(F, NormalMeanVariance), m, V)
end
Expand Down
2 changes: 1 addition & 1 deletion src/rules/delta/helpers/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end
"""
Return the marginalized statistics of the Gaussian corresponding to an inbound inx
"""
function marginalizeGaussianMV(m::Vector{T}, V::AbstractMatrix, ds::Vector, inx::Int64) where T<:Real
function marginalizeGaussianMV(m::Vector{T}, V::AbstractMatrix, ds::Vector, inx::Int64) where {T <: Real}
if ds[inx] == () # Univariate original
return (m[inx], V[inx, inx]) # Return scalars
else # Multivariate original
Expand Down
9 changes: 8 additions & 1 deletion src/rules/gcv/marginals.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@

@marginalrule GCV(:y_x) (m_y::UniNormalOrExpLinQuad, m_x::UniNormalOrExpLinQuad, q_z::Any, q_κ::Any, q_ω::Any) = begin
@marginalrule GCV(:y_x) (
m_y::UniNormalOrExpLinQuad,
m_x::UniNormalOrExpLinQuad,
q_z::Any,
q_κ::Any,
q_ω::Any,
meta::Union{<:GCVMetadata, Nothing}
) = begin
y_mean, y_precision = mean_precision(m_y)
x_mean, x_precision = mean_precision(m_x)

Expand Down
10 changes: 8 additions & 2 deletions src/rules/gcv/x.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
export rule

@rule GCV(:x, Marginalisation) (m_y::UniNormalOrExpLinQuad, q_z::Any, q_κ::Any, q_ω::Any) = begin
@rule GCV(:x, Marginalisation) (
m_y::UniNormalOrExpLinQuad,
q_z::Any,
q_κ::Any,
q_ω::Any,
meta::Union{<:GCVMetadata, Nothing}
) = begin
y_mean, y_var = mean_var(m_y)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
Expand All @@ -13,7 +19,7 @@ export rule
return NormalMeanVariance(y_mean, y_var + inv(A * B))
end

@rule GCV(:x, Marginalisation) (q_y::Any, q_z::Any, q_κ::Any, q_ω::Any) = begin
@rule GCV(:x, Marginalisation) (q_y::Any, q_z::Any, q_κ::Any, q_ω::Any, meta::Union{<:GCVMetadata, Nothing}) = begin
y_mean = mean(q_y)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
Expand Down
10 changes: 8 additions & 2 deletions src/rules/gcv/y.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
export rule

@rule GCV(:y, Marginalisation) (m_x::UniNormalOrExpLinQuad, q_z::Any, q_κ::Any, q_ω::Any) = begin
@rule GCV(:y, Marginalisation) (
m_x::UniNormalOrExpLinQuad,
q_z::Any,
q_κ::Any,
q_ω::Any,
meta::Union{<:GCVMetadata, Nothing}
) = begin
x_mean, x_var = mean_var(m_x)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
Expand All @@ -13,7 +19,7 @@ export rule
return NormalMeanVariance(x_mean, x_var + inv(A * B))
end

@rule GCV(:y, Marginalisation) (q_x::Any, q_z::Any, q_κ::Any, q_ω::Any) = begin
@rule GCV(:y, Marginalisation) (q_x::Any, q_z::Any, q_κ::Any, q_ω::Any, meta::Union{<:GCVMetadata, Nothing}) = begin
x_mean = mean(q_x)
z_mean, z_var = mean_var(q_z)
κ_mean, κ_var = mean_var(q_κ)
Expand Down
19 changes: 12 additions & 7 deletions src/rules/multiplication/A.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@

@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass) = PointMass(mean(m_in) \ mean(m_out))
@rule typeof(*)(:A, Marginalisation) (m_out::PointMass, m_in::PointMass, meta::Union{<:AbstractCorrection, Nothing}) =
PointMass(mean(m_in) \ mean(m_out))

@rule typeof(*)(:A, Marginalisation) (m_out::GammaDistributionsFamily, m_in::PointMass{<:Real}) = begin
@rule typeof(*)(:A, Marginalisation) (
m_out::GammaDistributionsFamily,
m_in::PointMass{<:Real},
meta::Union{<:AbstractCorrection, Nothing}
) = begin
return GammaShapeRate(shape(m_out), rate(m_out) * mean(m_in))
end

# if A is a matrix, then the result is multivariate
@rule typeof(*)(:A, Marginalisation) (
m_out::MultivariateNormalDistributionsFamily,
m_in::PointMass{<:AbstractMatrix},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
Expand All @@ -22,7 +27,7 @@ end
@rule typeof(*)(:A, Marginalisation) (
m_out::MultivariateNormalDistributionsFamily,
m_in::PointMass{<:AbstractVector},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
Expand All @@ -34,7 +39,7 @@ end
@rule typeof(*)(:A, Marginalisation) (
m_out::F,
m_in::PointMass{<:Real},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) where {F <: NormalDistributionsFamily} = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
Expand All @@ -46,7 +51,7 @@ end
@rule typeof(*)(:A, Marginalisation) (
m_out::MvNormalMeanCovariance,
m_in::PointMass{<:AbstractMatrix},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_in)
μ_out, Σ_out = mean_cov(m_out)
Expand All @@ -61,7 +66,7 @@ end
@rule typeof(*)(:A, Marginalisation) (
m_out::MvNormalMeanCovariance,
m_in::PointMass{<:AbstractVector},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_in)
μ_out, Σ_out = mean_cov(m_out)
Expand Down
19 changes: 12 additions & 7 deletions src/rules/multiplication/in.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@

@rule typeof(*)(:in, Marginalisation) (m_out::PointMass, m_A::PointMass) = PointMass(mean(m_A) \ mean(m_out))
@rule typeof(*)(:in, Marginalisation) (m_out::PointMass, m_A::PointMass, meta::Union{<:AbstractCorrection, Nothing}) =
PointMass(mean(m_A) \ mean(m_out))

@rule typeof(*)(:in, Marginalisation) (m_out::GammaDistributionsFamily, m_A::PointMass{<:Real}) = begin
@rule typeof(*)(:in, Marginalisation) (
m_out::GammaDistributionsFamily,
m_A::PointMass{<:Real},
meta::Union{<:AbstractCorrection, Nothing}
) = begin
return GammaShapeRate(shape(m_out), rate(m_out) * mean(m_A))
end

# if A is a matrix, then the result is multivariate
@rule typeof(*)(:in, Marginalisation) (
m_out::MultivariateNormalDistributionsFamily,
m_A::PointMass{<:AbstractMatrix},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_A)
ξ_out, W_out = weightedmean_precision(m_out)
Expand All @@ -22,7 +27,7 @@ end
@rule typeof(*)(:in, Marginalisation) (
m_out::MultivariateNormalDistributionsFamily,
m_A::PointMass{<:AbstractVector},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_A)
ξ_out, W_out = weightedmean_precision(m_out)
Expand All @@ -34,7 +39,7 @@ end
@rule typeof(*)(:in, Marginalisation) (
m_out::F,
m_A::PointMass{<:Real},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) where {F <: NormalDistributionsFamily} = begin
A = mean(m_A)
ξ_out, W_out = weightedmean_precision(m_out)
Expand All @@ -46,7 +51,7 @@ end
@rule typeof(*)(:in, Marginalisation) (
m_out::MvNormalMeanCovariance,
m_A::PointMass{<:AbstractMatrix},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_A)
μ_out, Σ_out = mean_cov(m_out)
Expand All @@ -61,7 +66,7 @@ end
@rule typeof(*)(:in, Marginalisation) (
m_out::MvNormalMeanCovariance,
m_A::PointMass{<:AbstractVector},
meta::AbstractCorrection
meta::Union{<:AbstractCorrection, Nothing}
) = begin
A = mean(m_A)
μ_out, Σ_out = mean_cov(m_out)
Expand Down
4 changes: 2 additions & 2 deletions src/rules/multiplication/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
m_out::NormalDistributionsFamily,
m_A::PointMass,
m_in::NormalDistributionsFamily,
meta::Any
meta::Union{<:AbstractCorrection, Nothing}
) = begin
b_in = @call_rule typeof(*)(:in, Marginalisation) (m_out = m_out, m_A = m_A, meta = meta)
q_in = prod(ProdAnalytical(), b_in, m_in)
Expand All @@ -17,7 +17,7 @@ end
m_out::UnivariateNormalDistributionsFamily,
m_A::UnivariateNormalDistributionsFamily,
m_in::PointMass{<:Real},
meta::Any
meta::Union{<:AbstractCorrection, Nothing}
) = begin
return @call_marginalrule typeof(*)(:A_in) (m_out = m_out, m_A = m_in, m_in = m_A, meta = meta)
end
Loading

0 comments on commit 9aa50b4

Please sign in to comment.