Skip to content

Commit

Permalink
Tests green.
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-bocharov committed Aug 28, 2018
1 parent c212930 commit 07655c4
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 20 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
julia 0.6
julia 0.7
SpecialFunctions
8 changes: 6 additions & 2 deletions src/dependency_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,19 @@ Optional keyword arguments:
This function can be used to generate message passing schedules
if `graph` is a dependency graph.
"""
function find_vertex_indexes(vertex::V, graph::DependencyGraph{V}) where V
return something(findfirst(isequal(vertex), graph.vertices), 0)
end

function children( vertices::Vector{V},
graph::DependencyGraph{V};
allow_cycles::Bool=false,
breaker_sites::Set{V}=Set{V}(),
restrict_to::Set{V}=Set{V}()) where V

# Find vertex indexes of breaker_sites
breaker_vertices = Set{Int}(map((v) -> something(findfirst(isequal(v), graph.vertices), 0), breaker_sites))
restrict_to_vertices = Set{Int}(map((v) -> something(findfirst(isequal(v), graph.vertices), 0), restrict_to))
breaker_vertices = Set{Int}((find_vertex_indexes(v, graph) for v=breaker_sites))
restrict_to_vertices = Set{Int}((find_vertex_indexes(v, graph) for v=restrict_to))

visited = Int[] # Hold topologically sorted list of indices of child vertices

Expand Down
2 changes: 1 addition & 1 deletion src/factor_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ end

edges(graph::FactorGraph=currentGraph()) = Set{Edge}(graph.edges)
edges(node::FactorNode) = Set{Edge}([intf.edge for intf in node.interfaces])
edges(nodeset::Set{FactorNode}) = union(map(edges, nodeset)...)
edges(nodeset::Set{FactorNode}) = union(Set((edges(node) for node=nodeset))...)

"""
Description:
Expand Down
2 changes: 1 addition & 1 deletion src/factor_nodes/gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ end
function prod!(
x::ProbabilityDistribution{Multivariate, F1},
y::ProbabilityDistribution{Multivariate, F2},
z::ProbabilityDistribution{Multivariate, GaussianWeightedMeanPrecision}=ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[NaN], w=[NaN].')) where {F1<:Gaussian, F2<:Gaussian}
z::ProbabilityDistribution{Multivariate, GaussianWeightedMeanPrecision}=ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[NaN], w=transpose([NaN]))) where {F1<:Gaussian, F2<:Gaussian}

z.params[:xi] = unsafeWeightedMean(x) + unsafeWeightedMean(y)
z.params[:w] = unsafePrecision(x) + unsafePrecision(y)
Expand Down
2 changes: 1 addition & 1 deletion src/factor_nodes/gaussian_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ format(dist::ProbabilityDistribution{V, GaussianMeanPrecision}) where V<:Variate

ProbabilityDistribution(::Type{Univariate}, ::Type{GaussianMeanPrecision}; m=0.0, w=1.0) = ProbabilityDistribution{Univariate, GaussianMeanPrecision}(Dict(:m=>m, :w=>w))
ProbabilityDistribution(::Type{GaussianMeanPrecision}; m::Number=0.0, w::Number=1.0) = ProbabilityDistribution{Univariate, GaussianMeanPrecision}(Dict(:m=>m, :w=>w))
ProbabilityDistribution(::Type{Multivariate}, ::Type{GaussianMeanPrecision}; m=[0.0], w=[1.0].') = ProbabilityDistribution{Multivariate, GaussianMeanPrecision}(Dict(:m=>m, :w=>w))
ProbabilityDistribution(::Type{Multivariate}, ::Type{GaussianMeanPrecision}; m=[0.0], w=transpose([1.0])) = ProbabilityDistribution{Multivariate, GaussianMeanPrecision}(Dict(:m=>m, :w=>w))

dims(dist::ProbabilityDistribution{V, GaussianMeanPrecision}) where V<:VariateType = length(dist.params[:m])

Expand Down
2 changes: 1 addition & 1 deletion src/factor_nodes/gaussian_weighted_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ format(dist::ProbabilityDistribution{V, GaussianWeightedMeanPrecision}) where V<

ProbabilityDistribution(::Type{Univariate}, ::Type{GaussianWeightedMeanPrecision}; xi=0.0, w=1.0) = ProbabilityDistribution{Univariate, GaussianWeightedMeanPrecision}(Dict(:xi=>xi, :w=>w))
ProbabilityDistribution(::Type{GaussianWeightedMeanPrecision}; xi::Number=0.0, w::Number=1.0) = ProbabilityDistribution{Univariate, GaussianWeightedMeanPrecision}(Dict(:xi=>xi, :w=>w))
ProbabilityDistribution(::Type{Multivariate}, ::Type{GaussianWeightedMeanPrecision}; xi=[0.0], w=[1.0].') = ProbabilityDistribution{Multivariate, GaussianWeightedMeanPrecision}(Dict(:xi=>xi, :w=>w))
ProbabilityDistribution(::Type{Multivariate}, ::Type{GaussianWeightedMeanPrecision}; xi=[0.0], w=transpose([1.0])) = ProbabilityDistribution{Multivariate, GaussianWeightedMeanPrecision}(Dict(:xi=>xi, :w=>w))

dims(dist::ProbabilityDistribution{V, GaussianWeightedMeanPrecision}) where V<:VariateType = length(dist.params[:xi])

Expand Down
2 changes: 1 addition & 1 deletion src/probability_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ dims(dist::ProbabilityDistribution{MatrixVariate, PointMass}) = size(dist.params
ProbabilityDistribution(::Type{Univariate}, ::Type{PointMass}; m::Number=1.0) = ProbabilityDistribution{Univariate, PointMass}(Dict(:m=>m))
ProbabilityDistribution(::Type{PointMass}; m::Number=1.0) = ProbabilityDistribution{Univariate, PointMass}(Dict(:m=>m))
ProbabilityDistribution(::Type{Multivariate}, ::Type{PointMass}; m::Vector=[1.0]) = ProbabilityDistribution{Multivariate, PointMass}(Dict(:m=>m))
ProbabilityDistribution(::Type{MatrixVariate}, ::Type{PointMass}; m::AbstractMatrix=[1.0].') = ProbabilityDistribution{MatrixVariate, PointMass}(Dict(:m=>m))
ProbabilityDistribution(::Type{MatrixVariate}, ::Type{PointMass}; m::AbstractMatrix=transpose([1.0])) = ProbabilityDistribution{MatrixVariate, PointMass}(Dict(:m=>m))

unsafeMean(dist::ProbabilityDistribution{T, PointMass}) where T<:VariateType = deepcopy(dist.params[:m])

Expand Down
2 changes: 1 addition & 1 deletion src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@ end
Collect all edges corresponding with variable(s)
"""
edges(variable::Variable) = Set{Edge}(variable.edges)
edges(variables::Set{Variable}) = union(map(edges, variables)...)
edges(variables::Set{Variable}) = union(Set((edges(v) for v=variables))...)

Base.isless(v1::Variable, v2::Variable) = isless("$(v1.id)", "$(v2.id)")
10 changes: 5 additions & 5 deletions test/factor_nodes/test_wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ end

@testset "isProper" begin
@test isProper(ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0), nu=1.0)) == true
@test isProper(ProbabilityDistribution(MatrixVariate, Wishart, v=[-1.0].', nu=2.0)) == false
@test isProper(ProbabilityDistribution(MatrixVariate, Wishart, v=transpose([-1.0]), nu=2.0)) == false
@test isProper(ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0), nu=0.0)) == false
end

@testset "prod!" begin
@test ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0), nu=2.0) * ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0),nu=2.0) == ProbabilityDistribution(MatrixVariate, Wishart, v=[0.4999999999999999].',nu=2.0)
@test ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0), nu=2.0) * ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0),nu=2.0) == ProbabilityDistribution(MatrixVariate, Wishart, v=transpose([0.4999999999999999]),nu=2.0)
@test ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0), nu=2.0) * ProbabilityDistribution(MatrixVariate, PointMass, m=mat(1.0)) == ProbabilityDistribution(MatrixVariate, PointMass, m=mat(1.0))
@test ProbabilityDistribution(MatrixVariate, PointMass, m=mat(1.0)) * ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0), nu=2.0) == ProbabilityDistribution(MatrixVariate, PointMass, m=mat(1.0))
@test_throws Exception ProbabilityDistribution(MatrixVariate, PointMass, m=[-1.0].') * ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0), nu=2.0)
@test_throws Exception ProbabilityDistribution(MatrixVariate, PointMass, m=transpose([-1.0])) * ProbabilityDistribution(MatrixVariate, Wishart, v=mat(1.0), nu=2.0)
end

@testset "unsafe mean and variance" begin
Expand All @@ -45,7 +45,7 @@ end
@test outboundType(SPWishartOutVPP) == Message{Wishart}
@test isApplicable(SPWishartOutVPP, [Nothing, Message{PointMass}, Message{PointMass}])

@test ruleSPWishartOutVPP(nothing, Message(MatrixVariate, PointMass, m=[1.0].'), Message(Univariate, PointMass, m=2.0)) == Message(MatrixVariate, Wishart, v=[1.0].', nu=2.0)
@test ruleSPWishartOutVPP(nothing, Message(MatrixVariate, PointMass, m=transpose([1.0])), Message(Univariate, PointMass, m=2.0)) == Message(MatrixVariate, Wishart, v=transpose([1.0]), nu=2.0)
end

@testset "VBWishartOut" begin
Expand All @@ -54,7 +54,7 @@ end
@test isApplicable(VBWishartOut, [Nothing, ProbabilityDistribution, ProbabilityDistribution])
@test !isApplicable(VBWishartOut, [ProbabilityDistribution, ProbabilityDistribution, Nothing])

@test ruleVBWishartOut(nothing, ProbabilityDistribution(MatrixVariate, PointMass, m=[1.5].'), ProbabilityDistribution(Univariate, PointMass, m=3.0)) == Message(MatrixVariate, Wishart, v=[1.5].', nu=3.0)
@test ruleVBWishartOut(nothing, ProbabilityDistribution(MatrixVariate, PointMass, m=transpose([1.5])), ProbabilityDistribution(Univariate, PointMass, m=3.0)) == Message(MatrixVariate, Wishart, v=transpose([1.5]), nu=3.0)
end

@testset "averageEnergy and differentialEntropy" begin
Expand Down
2 changes: 1 addition & 1 deletion test/test_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import LinearAlgebra: Diagonal, isposdef, I, Hermitian
@testset "diageye" begin
# should be shorthand for Diagonal(eye(M))
M = diageye(3)
@test typeof(M) == Diagonal{Float64}
@test typeof(M) == Diagonal{Float64, Array{Float64,1}}
@test M == Diagonal(ones(3))
end

Expand Down
10 changes: 5 additions & 5 deletions test/test_probability_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,22 @@ end
@test isProper(point_mass)
@test mean(point_mass) == [0.0]
@test var(point_mass) == [0.0]
@test cov(point_mass) == [0.0].'
@test cov(point_mass) == transpose([0.0])
end

@testset "MatrixVariate" begin
point_mass = ProbabilityDistribution(MatrixVariate, PointMass, m=[0.0].')
point_mass = ProbabilityDistribution(MatrixVariate, PointMass, m=transpose([0.0]))
@test isa(point_mass, ProbabilityDistribution{MatrixVariate, PointMass})
@test point_mass.params == Dict(:m=>[0.0].')
@test point_mass.params == Dict(:m=>transpose([0.0]))
@test isProper(point_mass)
@test mean(point_mass) == [0.0].'
@test mean(point_mass) == transpose([0.0])
end

@testset "PointMass ProbabilityDistribution and Message construction" begin
@test ProbabilityDistribution(Univariate, PointMass, m=0.2) == ProbabilityDistribution{Univariate, PointMass}(Dict(:m=>0.2))
@test_throws Exception ProbabilityDistribution(Multivariate, PointMass, m=0.2)
@test ProbabilityDistribution(Multivariate, PointMass, m=[0.2]) == ProbabilityDistribution{Multivariate, PointMass}(Dict(:m=>[0.2]))
@test ProbabilityDistribution(MatrixVariate, PointMass, m=[0.2].') == ProbabilityDistribution{MatrixVariate, PointMass}(Dict(:m=>[0.2].'))
@test ProbabilityDistribution(MatrixVariate, PointMass, m=transpose([0.2])) == ProbabilityDistribution{MatrixVariate, PointMass}(Dict(:m=>transpose([0.2])))
@test ProbabilityDistribution(PointMass, m=0.2) == ProbabilityDistribution{Univariate, PointMass}(Dict(:m=>0.2))
@test ProbabilityDistribution(Univariate, PointMass) == ProbabilityDistribution{Univariate, PointMass}(Dict(:m=>1.0))
@test ProbabilityDistribution(PointMass) == ProbabilityDistribution{Univariate, PointMass}(Dict(:m=>1.0))
Expand Down

0 comments on commit 07655c4

Please sign in to comment.