From 71c08948ebdbe4e6d0790f939cf08d2f77358a66 Mon Sep 17 00:00:00 2001 From: Thijs van de Laar Date: Tue, 2 May 2023 15:11:47 +0200 Subject: [PATCH] fix poisson node --- src/factor_nodes/poisson.jl | 34 +++++++++++++++++++++++++------ test/factor_nodes/test_poisson.jl | 5 ++++- test/test_helpers.jl | 2 +- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/factor_nodes/poisson.jl b/src/factor_nodes/poisson.jl index 1e56ddab..b25c333e 100644 --- a/src/factor_nodes/poisson.jl +++ b/src/factor_nodes/poisson.jl @@ -66,20 +66,42 @@ logPdf(V::Type{Univariate}, F::Type{Poisson}, x::Number; η::Vector) = -logfacto # ∑ [λ^k*log(k!)]/k! from k=0 to inf # Approximates the above sum for calculation of averageEnergy and differentialEntropy # @ref https://arxiv.org/pdf/1708.06394.pdf -function apprSum(l, j=100) - sum([(l)^(k)*logfactorial(k)/exp(logfactorial(k)) for k in collect(0:j)]) +function approximatePowerSum(l, j=150) + (l == 0.0) && return 0.0 + (l > 110.0) && error("Cannot approximate power sum for Poisson distribution with l>110") + + s = zero(BigFloat) + lk = one(BigFloat) + for k = 1:j + lk *= l + s += lk*loggamma(k + 1)/gamma(k + 1) + end + + return convert(Float64, s) end # Entropy functional # @ref https://en.wikipedia.org/wiki/Poisson_distribution function differentialEntropy(dist::Distribution{Univariate, Poisson}) - l = clamp(dist.params[:l], tiny, huge) - l*(1-log(l)) + exp(-l)*apprSum(l) + l = dist.params[:l] + (l == 0.0) && return 0.0 + + if l <= 50.0 + return l*(1-log(l)) + exp(-l)*approximatePowerSum(l) + else + return 0.5*log(2*pi*ℯ*l) - 1/(12*l) - 1/(24*l^2) - 19/(360*l^3) + end end -# Average energy functional +# Average energy functionals function averageEnergy(::Type{Poisson}, marg_out::Distribution{Univariate}, marg_l::Distribution{Univariate}) unsafeMean(marg_l) - unsafeMean(marg_out)*unsafeLogMean(marg_l) + - exp(-unsafeMean(marg_out))*apprSum(unsafeMean(marg_out)) + exp(-unsafeMean(marg_out))*approximatePowerSum(unsafeMean(marg_out)) end + +function averageEnergy(::Type{Poisson}, marg_out::Distribution{Univariate, PointMass}, marg_l::Distribution{Univariate}) + unsafeMean(marg_l) - + unsafeMean(marg_out)*unsafeLogMean(marg_l) + + sum(log.(1:unsafeMean(marg_out))) +end \ No newline at end of file diff --git a/test/factor_nodes/test_poisson.jl b/test/factor_nodes/test_poisson.jl index b7e3b0c2..3d16832e 100644 --- a/test/factor_nodes/test_poisson.jl +++ b/test/factor_nodes/test_poisson.jl @@ -99,7 +99,10 @@ end @testset "averageEnergy and differentialEntropy" begin @test isapprox(differentialEntropy(Distribution(Poisson, l=1.0)), averageEnergy(Poisson, Distribution(Poisson, l=1.0), Distribution(Univariate, PointMass, m=1.0))) @test isapprox(differentialEntropy(Distribution(Poisson, l=10.0)), averageEnergy(Poisson, Distribution(Poisson, l=10.0), Distribution(Univariate, PointMass, m=10.0))) - @test isapprox(differentialEntropy(Distribution(Poisson, l=100.0)), averageEnergy(Poisson, Distribution(Poisson, l=100.0), Distribution(Univariate, PointMass, m=100.0))) + @test isapprox(differentialEntropy(Distribution(Poisson, l=100.0)), averageEnergy(Poisson, Distribution(Poisson, l=100.0), Distribution(Univariate, PointMass, m=100.0)), atol=0.1) + + @test averageEnergy(Poisson, Distribution(PointMass, m=1.0), Distribution(Univariate, PointMass, m=1.0)) == 1.0 + @test averageEnergy(Poisson, Distribution(PointMass, m=2.0), Distribution(Univariate, PointMass, m=1.0)) == 1.0 + log(2) end end # module diff --git a/test/test_helpers.jl b/test/test_helpers.jl index 83762c0c..44e4d872 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -86,7 +86,7 @@ using LinearAlgebra: Diagonal, isposdef, I, Hermitian @test bar(1, Tuple{Float64, Float32}((1.0, 1.0f0))) === bar(Tuple{Float64, Float32}((1.0, 1.0f0)), 1) @test bar(1, Tuple{Float32, Float64}((1.0f0, 1.0))) === bar(Tuple{Float32, Float64}((1.0f0, 1.0)), 1) - @symmetrical function baz(a::Int, b::Float64, c::String) where A where B where C + @symmetrical function baz(a::Int, b::Float64, c::String) return 1 end