Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add getgradlogpartition function for Poisson #151

Merged
merged 8 commits into from
Dec 15, 2023
2 changes: 2 additions & 0 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ isproper
getbasemeasure
getsufficientstatistics
getlogpartition
getgradlogpartition
getfisherinformation
getsupport
basemeasure
sufficientstatistics
logpartition
gradlogpartition
fisherinformation
isbasemeasureconstant
ConstantBaseMeasure
Expand Down
5 changes: 5 additions & 0 deletions src/distributions/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Exponential}) = (η) -> begin
return -log(-η₁)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Exponential}) = (η) -> begin
(η₁,) = unpack_parameters(Exponential, η)
return SA[-1/η₁]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Exponential}) = (η) -> begin
(η₁,) = unpack_parameters(Exponential, η)
SA[inv(η₁^2);;]
Expand Down
7 changes: 7 additions & 0 deletions src/distributions/normal_family/normal_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,13 @@ getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) = (η)
return (dot(η₁, Cinv, η₁) / 2 - (k * log(2) + l)) / 2
end

getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(MvNormalMeanCovariance, η)
Cinv, _ = cholinv_logdet(-η₂)
return pack_parameters(MvNormalMeanCovariance, (0.5 * Cinv * η₁, 0.25 * Cinv * η₁ * η₁' * Cinv + 0.5 * Cinv))
end

getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(MvNormalMeanCovariance, η)
Expand Down
5 changes: 5 additions & 0 deletions src/distributions/poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Poisson}) = (η) -> begin
return exp(η1)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Poisson}) = (η) -> begin
(η1,) = unpack_parameters(Poisson, η)
return SA[exp(η1)]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Poisson}) = (η) -> begin
(η1,) = unpack_parameters(Poisson, η)
SA[exp(η1);;]
Expand Down
6 changes: 5 additions & 1 deletion src/distributions/von_mises.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ isbasemeasureconstant(::Type{VonMises}) = ConstantBaseMeasure()

getbasemeasure(::Type{VonMises}, _) = (x) -> inv(twoπ)
getsufficientstatistics(::Type{VonMises}, _) = (cos, sin)

getgradlogpartition(::NaturalParametersSpace, ::Type{VonMises}, _) = (η) -> begin
u = sqrt(dot(η, η))
same_part = besseli(1, u) / (u * besseli(0, u))
return SA[η[1] * same_part, η[2] * same_part]
end
getlogpartition(::NaturalParametersSpace, ::Type{VonMises}, _) = (η) -> begin
return log(besseli(0, sqrt(dot(η, η))))
end
Expand Down
38 changes: 36 additions & 2 deletions src/exponential_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ export ExponentialFamilyDistribution

export ExponentialFamilyDistribution, ExponentialFamilyDistributionAttributes, getnaturalparameters, getattributes
export MeanToNatural, NaturalToMean, MeanParametersSpace, NaturalParametersSpace
export getbasemeasure, getsufficientstatistics, getlogpartition, getfisherinformation, getsupport, getmapping, getconditioner
export basemeasure, sufficientstatistics, logpartition, fisherinformation, insupport, isproper
export getbasemeasure, getsufficientstatistics, getlogpartition, getgradlogpartition, getfisherinformation, getsupport, getmapping, getconditioner
export basemeasure, sufficientstatistics, logpartition, gradlogpartition, fisherinformation, insupport, isproper
export isbasemeasureconstant, ConstantBaseMeasure, NonConstantBaseMeasure

using LoopVectorization
Expand Down Expand Up @@ -301,6 +301,18 @@ function logpartition(ef::ExponentialFamilyDistribution, η = getnaturalparamete
return getlogpartition(ef)(η)
end

"""
gradlogpartition(::ExponentialFamilyDistribution, η)

Return the computed value of `gradlogpartition` of the exponential family distribution at the point `η`.
By default `η = getnaturalparameters(ef)`.

See also: [`getgradlogpartition`](@ref)
"""
function gradlogpartition(ef::ExponentialFamilyDistribution, η = getnaturalparameters(ef))
return getgradlogpartition(ef)(η)
end

"""
fisherinformation(distribution, η)

Expand Down Expand Up @@ -329,6 +341,12 @@ getlogpartition(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} = get
getlogpartition(attributes::ExponentialFamilyDistributionAttributes, ::ExponentialFamilyDistribution) =
getlogpartition(attributes)

getgradlogpartition(ef::ExponentialFamilyDistribution) = getgradlogpartition(ef.attributes, ef)
getgradlogpartition(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} =
getgradlogpartition(T, getconditioner(ef))
getgradlogpartition(attributes::ExponentialFamilyDistributionAttributes, ::ExponentialFamilyDistribution) =
error("TODO: not implemented. Should we use monte-carlo estimator here: the mean of the sufficient statistics here?")

getfisherinformation(ef::ExponentialFamilyDistribution) = getfisherinformation(ef.attributes, ef)
getfisherinformation(::Nothing, ef::ExponentialFamilyDistribution{T}) where {T} =
getfisherinformation(T, getconditioner(ef))
Expand Down Expand Up @@ -423,6 +441,22 @@ getlogpartition(
::Nothing
) where {T <: Distribution} = getlogpartition(space, T)

"""
getgradlogpartition([ space = NaturalParametersSpace() ], ::Type{T}, [ conditioner ]) where { T <: Distribution }

A specific verion of `getgradlogpartition` defined particularly for distribution types from `Distributions.jl` package.
Does not require an instance of the `ExponentialFamilyDistribution` and can be called directly with a specific distribution type instead.
Optionally, accepts the `space` parameter, which defines the parameters space.
For conditional exponential family distributions requires an extra `conditioner` argument.
"""
getgradlogpartition(::Type{T}, conditioner = nothing) where {T <: Distribution} =
getgradlogpartition(NaturalParametersSpace(), T, conditioner)
getgradlogpartition(
space::Union{MeanParametersSpace, NaturalParametersSpace},
::Type{T},
::Nothing
) where {T <: Distribution} = getgradlogpartition(space, T)

"""
getfisherinformation([ space = NaturalParametersSpace() ], ::Type{T}) where { T <: Distribution }

Expand Down
20 changes: 19 additions & 1 deletion test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ function test_exponentialfamily_interface(distribution;
test_packing_unpacking = true,
test_isproper = true,
test_basic_functions = true,
test_gradlogpartition_against_expectation = true,
test_fisherinformation_properties = true,
test_fisherinformation_against_hessian = true,
test_fisherinformation_against_jacobian = true,
option_assume_no_allocations = false,
option_assume_no_allocations = false
)
T = ExponentialFamily.exponential_family_typetag(distribution)

Expand All @@ -71,6 +72,7 @@ function test_exponentialfamily_interface(distribution;
test_packing_unpacking && run_test_packing_unpacking(distribution)
test_isproper && run_test_isproper(distribution; assume_no_allocations = option_assume_no_allocations)
test_basic_functions && run_test_basic_functions(distribution; assume_no_allocations = option_assume_no_allocations)
test_gradlogpartition_against_expectation && run_test_gradlogpartition_against_expectation(distribution)
test_fisherinformation_properties && run_test_fisherinformation_properties(distribution)
test_fisherinformation_against_hessian && run_test_fisherinformation_against_hessian(distribution; assume_no_allocations = option_assume_no_allocations)
test_fisherinformation_against_jacobian && run_test_fisherinformation_against_jacobian(distribution; assume_no_allocations = option_assume_no_allocations)
Expand Down Expand Up @@ -302,6 +304,22 @@ function run_test_fisherinformation_properties(distribution; test_properties_in_
end
end

function run_test_gradlogpartition_against_expectation(distribution; nsamples = 5000)
ef = @inferred(convert(ExponentialFamilyDistribution, distribution))

(η, conditioner) = (getnaturalparameters(ef), getconditioner(ef))

samples = rand(distribution, nsamples)
_, samples = ExponentialFamily.check_logpdf(variate_form(typeof(ef)), typeof(samples), eltype(samples), ef, samples)
sample_sufficient_statistics = map((s) -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), samples)
expectation_of_sufficient_statistics = mean(sample_sufficient_statistics)
gradient = gradlogpartition(ef)
inverse_fisher = cholinv(fisherinformation(ef))
@test length(gradient) === length(η)
@test dot(gradient - expectation_of_sufficient_statistics, inverse_fisher, gradient - expectation_of_sufficient_statistics) ≈ 0 atol = 0.01
0.01
end

function run_test_fisherinformation_against_hessian(distribution; assume_ours_faster = true, assume_no_allocations = true)
T = ExponentialFamily.exponential_family_typetag(distribution)

Expand Down
Loading