Skip to content

Commit

Permalink
Merge pull request #147 from biaslab/fast-logpdf
Browse files Browse the repository at this point in the history
Faster logpdf implementation for container based inputs
  • Loading branch information
bvdmitri authored Dec 4, 2023
2 parents 12b7797 + 784b7e0 commit 7348a2c
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 25 deletions.
4 changes: 4 additions & 0 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ fisherinformation
isbasemeasureconstant
ConstantBaseMeasure
NonConstantBaseMeasure
ExponentialFamily._logpdf
ExponentialFamily.check_logpdf
ExponentialFamily.PointBasedLogpdfCall
ExponentialFamily.MapBasedLogpdfCall
```

## Interfacing with Distributions Defined in the `Distributions.jl` Package
Expand Down
4 changes: 2 additions & 2 deletions src/distributions/gamma_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ end
getfisherinformation(::NaturalParametersSpace, ::Type{GammaInverse}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(GammaInverse, η)
return SA[polygamma(1, -η₁ - one(η₁)) -inv(-η₂); -inv(-η₂) (-η₁-one(η₁))/(η₂^2) ]
return SA[polygamma(1, -η₁ - one(η₁)) -inv(-η₂); -inv(-η₂) (-η₁-one(η₁))/(η₂^2)]
end

# Mean parametrization
Expand All @@ -79,5 +79,5 @@ end
getfisherinformation(::MeanParametersSpace, ::Type{GammaInverse}) =
(θ) -> begin
(shape, scale) = unpack_parameters(Gamma, θ)
return SA[polygamma(1, shape) -inv(scale); -inv(scale) shape/(scale^2) ]
return SA[polygamma(1, shape) -inv(scale); -inv(scale) shape/(scale^2)]
end
21 changes: 21 additions & 0 deletions src/distributions/mv_normal_wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,24 @@ getfisherinformation(::MeanParametersSpace, ::Type{MvNormalWishart}) = (θ) -> b

# return blockdiag(sparse(ν*κ*T) , sparse((ν/2)*kronT) , sparse(Diagonal([d/(2κ^2), mvtrigamma(d, ν/2)/4])))
end

function _logpdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x)
# TODO: Think of what to do with this assert
@assert insupport(ef, x)
_logpartition = logpartition(ef)
return _logpdf(ef, x, _logpartition)
end

function _logpdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x, logpartition)
# TODO: Think of what to do with this assert
@assert insupport(ef, x)
η = getnaturalparameters(ef)
# Use `_` to avoid name collisions with the actual functions
_statistics = sufficientstatistics(ef, x)
_basemeasure = basemeasure(ef, x)
return log(_basemeasure) + dot(η, flatten_parameters(MvNormalWishart, _statistics)) - logpartition
end

function _pdf(ef::ExponentialFamilyDistribution{MvNormalWishart}, x)
exp(_logpdf(ef, x))
end
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ BayesBase.std(dist::MvNormalMeanCovariance) = cholsqrt(cov(dist))
BayesBase.logdetcov(dist::MvNormalMeanCovariance) = chollogdet(cov(dist))
BayesBase.params(dist::MvNormalMeanCovariance) = (mean(dist), cov(dist))

function Distributions.sqmahal(dist::MvNormalMeanCovariance, x::AbstractVector)
function Distributions.sqmahal(dist::MvNormalMeanCovariance, x::AbstractVector)
T = promote_type(eltype(x), paramfloattype(dist))
return sqmahal!(similar(x, T), dist, x)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ BayesBase.std(dist::MvNormalMeanPrecision) = cholsqrt(cov(dist))
BayesBase.logdetcov(dist::MvNormalMeanPrecision) = -chollogdet(invcov(dist))
BayesBase.params(dist::MvNormalMeanPrecision) = (mean(dist), invcov(dist))

function Distributions.sqmahal(dist::MvNormalMeanPrecision, x::AbstractVector)
function Distributions.sqmahal(dist::MvNormalMeanPrecision, x::AbstractVector)
T = promote_type(eltype(x), paramfloattype(dist))
return sqmahal!(similar(x, T), dist, x)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ BayesBase.std(dist::MvNormalWeightedMeanPrecision) = cholsqrt(cov(dist))
BayesBase.logdetcov(dist::MvNormalWeightedMeanPrecision) = -chollogdet(invcov(dist))
BayesBase.params(dist::MvNormalWeightedMeanPrecision) = (weightedmean(dist), invcov(dist))

function Distributions.sqmahal(dist::MvNormalWeightedMeanPrecision, x::AbstractVector)
function Distributions.sqmahal(dist::MvNormalWeightedMeanPrecision, x::AbstractVector)
T = promote_type(eltype(x), paramfloattype(dist))
return sqmahal!(similar(x, T), dist, x)
end
Expand Down
8 changes: 4 additions & 4 deletions src/distributions/normal_gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ BayesBase.location(d::NormalGamma) = first(params(d))
BayesBase.scale(d::NormalGamma) = getindex(params(d), 2)
BayesBase.shape(d::NormalGamma) = getindex(params(d), 3)
BayesBase.rate(d::NormalGamma) = getindex(params(d), 4)
BayesBase.mean(d::NormalGamma) = [d.μ, d.α / d.β]
BayesBase.var(d::NormalGamma) = d.α > one(d.α) ? [d.β / (d.λ * (d.α - one(d.α))), d.α / (d.β^2)] : error("`var` of `NormalGamma` is not defined for `α < 1`")
BayesBase.cov(d::NormalGamma) = d.α > one(d.α) ? [d.β/(d.λ*(d.α-one(d.α))) 0.0; 0.0 d.α/(d.β^2)] : error("`cov` of `NormalGamma` is not defined for `α < 1`")
BayesBase.std(d::NormalGamma) = d.α > one(d.α) ? sqrt.(var(d)) : error("`std` of `NormalGamma` is not defined for `α < 1`")
BayesBase.mean(d::NormalGamma) = [d.μ, d.α / d.β]
BayesBase.var(d::NormalGamma) = d.α > one(d.α) ? [d.β / (d.λ * (d.α - one(d.α))), d.α / (d.β^2)] : error("`var` of `NormalGamma` is not defined for `α < 1`")
BayesBase.cov(d::NormalGamma) = d.α > one(d.α) ? [d.β/(d.λ*(d.α-one(d.α))) 0.0; 0.0 d.α/(d.β^2)] : error("`cov` of `NormalGamma` is not defined for `α < 1`")
BayesBase.std(d::NormalGamma) = d.α > one(d.α) ? sqrt.(var(d)) : error("`std` of `NormalGamma` is not defined for `α < 1`")

function BayesBase.rand!(rng::AbstractRNG, dist::NormalGamma, container::AbstractVector)
container[2] = rand(rng, GammaShapeRate(dist.α, dist.β))
Expand Down
6 changes: 5 additions & 1 deletion src/distributions/pareto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ end
BayesBase.default_prod_rule(::Type{<:ExponentialFamilyDistribution{T}}, ::Type{<:ExponentialFamilyDistribution{T}}) where {T <: Pareto} =
PreserveTypeProd(ExponentialFamilyDistribution{Pareto})

function BayesBase.prod!(container::ExponentialFamilyDistribution{Pareto}, left::ExponentialFamilyDistribution{Pareto}, right::ExponentialFamilyDistribution{Pareto})
function BayesBase.prod!(
container::ExponentialFamilyDistribution{Pareto},
left::ExponentialFamilyDistribution{Pareto},
right::ExponentialFamilyDistribution{Pareto}
)
(η_container, conditioner_container) = (getnaturalparameters(container), getconditioner(container))
(η_left, conditioner_left) = (getnaturalparameters(left), getconditioner(left))
(η_right, conditioner_right) = (getnaturalparameters(right), getconditioner(right))
Expand Down
150 changes: 142 additions & 8 deletions src/exponential_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,26 +475,160 @@ isbasemeasureconstant(::Function) = NonConstantBaseMeasure()
Evaluates and returns the log-density of the exponential family distribution for the input `x`.
"""
function BayesBase.logpdf(ef::ExponentialFamilyDistribution{T}, x) where {T}
# TODO: Think of what to do with this assert
@assert insupport(ef, x)
function BayesBase.logpdf(ef::ExponentialFamilyDistribution, x)
return _logpdf(ef, x)
end

η = getnaturalparameters(ef)
"""
A trait object, signifying that the _logpdf method should treat it second argument as one point from the distrubution domain.
"""
struct PointBasedLogpdfCall end

"""
A trait object, signifying that the _logpdf method should treat it second argument as a container of points from the distrubution domain.
"""
struct MapBasedLogpdfCall end

function _logpdf(::PointBasedLogpdfCall, ef, x)
_plogpdf(ef, x)
end

function _logpdf(::MapBasedLogpdfCall, ef, container)
_vlogpdf(ef, container)
end

"""
_logpdf(ef::ExponentialFamilyDistribution, x)
Evaluates and returns the log-density of the exponential family distribution for the input `x`.
This inner function dispatches to the appropriate version of `_logpdf` based on the types of `x` and `ef`, utilizing the `check_logpdf` function. The dispatch mechanism ensures that `_logpdf` correctly handles the input `x`, whether it is a single point or a container of points, according to the nature of the exponential family distribution and `x`.
For instance, with a `Univariate` distribution, `_logpdf` evaluates the log-density for a single point if `x` is a `Number`, and for a container of points if `x` is an `AbstractVector`.
### Examples
Evaluate the log-density of a Gamma distribution at a single point:
```jldoctest
using ExponentialFamily, Distributions;
gamma = convert(ExponentialFamilyDistribution, Gamma(1, 1))
ExponentialFamily._logpdf(gamma, 1.0)
# output
-1.0
```
Evaluate the log-density of a Gamma distribution at multiple points:
```jldoctest
using ExponentialFamily, Distributions
gamma = convert(ExponentialFamilyDistribution, Gamma(1, 1))
ExponentialFamily._logpdf(gamma, [1, 2, 3])
# output
3-element Vector{Float64}:
-1.0
-2.0
-3.0
```
For details on the dispatch mechanism of `_logpdf`, refer to the `check_logpdf` function.
See also: [`check_logpdf`](@ref)
"""
function _logpdf(ef::ExponentialFamilyDistribution{T}, x) where {T}
vartype, _x = check_logpdf(variate_form(typeof(ef)), typeof(x), eltype(x), ef, x)
_logpdf(vartype, ef, _x)
end

# Use `_` to avoid name collisions with the actual functions
function _plogpdf(ef, x)
@assert insupport(ef, x) "Point $(x) does not belong to the support of $(ef)"
return _plogpdf(ef, x, logpartition(ef))
end

_scalarproduct(::Type{T}, η, statistics) where {T} = _scalarproduct(variate_form(T), T, η, statistics)
_scalarproduct(::Type{Univariate}, η, statistics) = dot(η, flatten_parameters(statistics))
_scalarproduct(::Type{Univariate}, ::Type{T}, η, statistics) where {T} = dot(η, flatten_parameters(T, statistics))
_scalarproduct(_, ::Type{T}, η, statistics) where {T} = dot(η, pack_parameters(T, statistics))

function _plogpdf(ef::ExponentialFamilyDistribution{T}, x, logpartition) where {T}
# TODO: Think of what to do with this assert
@assert insupport(ef, x) "Point $(x) does not belong to the support of $(ef)"
η = getnaturalparameters(ef)
_statistics = sufficientstatistics(ef, x)
_basemeasure = basemeasure(ef, x)
_logpartition = logpartition(ef)
return log(_basemeasure) + _scalarproduct(T, η, _statistics) - logpartition
end

"""
check_logpdf(variate_form, typeof(x), eltype(x), ef, x)
Determines an appropriate strategy of evaluation of `_logpdf` (`PointBasedLogpdfCall` or `MapBasedLogpdfCall`) to use based on the types of `x` and `ef`. This function employs a dispatch mechanism that adapts to the input `x`, whether it is a single point or a container of points, in accordance with the characteristics of the exponential family distribution (`ef`) and the variate form of `x`.
### Strategies
- For a `Univariate` distribution:
- If `x` is a `Number`, `_logpdf` is invoked with `PointBasedLogpdfCall()`.
- If `x` is an `AbstractVector` containing `Number`s, `_logpdf` is invoked with `MapBasedLogpdfCall()`.
- For a `Multivariate` distribution:
- If `x` is an `AbstractVector` containing `Number`s, `_logpdf` is invoked with `PointBasedLogpdfCall()`.
- If `x` is an `AbstractVector` containing `AbstractVector`s, `_logpdf` is invoked with `MapBasedLogpdfCall()`.
- If `x` is an `AbstractMatrix` containing `Number`s, `_logpdf` is invoked with `MapBasedLogpdfCall()`, transforming `x` to `eachcol(x)`.
- For a `Matrixvariate` distribution:
- If `x` is an `AbstractMatrix` containing `Number`s, `_logpdf` is invoked with `PointBasedLogpdfCall()`.
- If `x` is an `AbstractVector` containing `AbstractMatrix`s, `_logpdf` is invoked with `MapBasedLogpdfCall()`.
### Examples
```jldoctest
using ExponentialFamily
ExponentialFamily.check_logpdf(Univariate, typeof(1.0), eltype(1.0), Gamma(1, 1), 1.0)
# output
(ExponentialFamily.PointBasedLogpdfCall(), 1.0)
```
```jldoctest
using ExponentialFamily
ExponentialFamily.check_logpdf(Univariate, typeof([1.0, 2.0, 3.0]), eltype([1.0, 2.0, 3.0]), Gamma(1, 1), [1.0, 2.0, 3.0])
# output
(ExponentialFamily.MapBasedLogpdfCall(), [1.0, 2.0, 3.0])
```
return log(_basemeasure) + dot(η, flatten_parameters(T, _statistics)) - _logpartition
See also: [`_logpdf`](@ref) [`PointBasedLogpdfCall`](@ref) [`MapBasedLogpdfCall`](@ref)
"""
function check_logpdf end

check_logpdf(::Type{Univariate}, ::Type{<:Number}, ::Type{<:Number}, ef, x) = (PointBasedLogpdfCall(), x)
check_logpdf(::Type{Multivariate}, ::Type{<:AbstractVector}, ::Type{<:Number}, ef, x) = (PointBasedLogpdfCall(), x)
check_logpdf(::Type{Matrixvariate}, ::Type{<:AbstractMatrix}, ::Type{<:Number}, ef, x) = (PointBasedLogpdfCall(), x)

function _vlogpdf(ef, container)
_logpartition = logpartition(ef)
return map(x -> _plogpdf(ef, x, _logpartition), container)
end

check_logpdf(::Type{Univariate}, ::Type{<:AbstractVector}, ::Type{<:Number}, ef, container) = (MapBasedLogpdfCall(), container)
check_logpdf(::Type{Multivariate}, ::Type{<:AbstractVector}, ::Type{<:AbstractVector}, ef, container) = (MapBasedLogpdfCall(), container)
check_logpdf(::Type{Multivariate}, ::Type{<:AbstractMatrix}, ::Type{<:Number}, ef, container) = (MapBasedLogpdfCall(), eachcol(container))
check_logpdf(::Type{Matrixvariate}, ::Type{<:AbstractVector}, ::Type{<:AbstractMatrix}, ef, container) = (MapBasedLogpdfCall(), container)

"""
pdf(ef::ExponentialFamilyDistribution, x)
Evaluates and returns the probability density function of the exponential family distribution for the input `x`.
"""
BayesBase.pdf(ef::ExponentialFamilyDistribution, x) = exp(logpdf(ef, x))
BayesBase.pdf(ef::ExponentialFamilyDistribution, x) = _pdf(ef, x)

function _pdf(ef, x)
vartype, _x = check_logpdf(variate_form(typeof(ef)), typeof(x), eltype(x), ef, x)
_pdf(vartype, ef, _x)
end

function _pdf(::PointBasedLogpdfCall, ef, x)
exp(logpdf(ef, x))
end

function _pdf(::MapBasedLogpdfCall, ef, x)
exp.(logpdf(ef, x))
end

"""
cdf(ef::ExponentialFamilyDistribution{D}, x) where { D <: Distribution }
Expand Down
3 changes: 1 addition & 2 deletions test/common_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@
@test all(ForwardDiff.hessian((x) -> dot3arg(x, A, x), x) .!== 0)
@test all(ForwardDiff.hessian((x) -> dot3arg(x, A, x), y) .!== 0)
end

end
end
9 changes: 7 additions & 2 deletions test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function test_exponentialfamily_interface(distribution;
test_fisherinformation_properties = true,
test_fisherinformation_against_hessian = true,
test_fisherinformation_against_jacobian = true,
option_assume_no_allocations = false
option_assume_no_allocations = false,
)
T = ExponentialFamily.exponential_family_typetag(distribution)

Expand Down Expand Up @@ -192,7 +192,7 @@ function run_test_isproper(distribution; assume_no_allocations = true)
end
end

function run_test_basic_functions(distribution; nsamples = 10, test_gradients = true, assume_no_allocations = true)
function run_test_basic_functions(distribution; nsamples = 10, test_gradients = true, test_samples_logpdf = true, assume_no_allocations = true)
T = ExponentialFamily.exponential_family_typetag(distribution)

ef = @inferred(convert(ExponentialFamilyDistribution, distribution))
Expand Down Expand Up @@ -266,6 +266,11 @@ function run_test_basic_functions(distribution; nsamples = 10, test_gradients =
@test @allocated(sufficientstatistics(ef, x)) === 0
end
end

if test_samples_logpdf
@test @inferred(logpdf(ef, samples)) map((s) -> logpdf(distribution, s), samples)
@test @inferred(pdf(ef, samples)) map((s) -> pdf(distribution, s), samples)
end
end

function run_test_fisherinformation_properties(distribution; test_properties_in_natural_space = true, test_properties_in_mean_space = true)
Expand Down
2 changes: 1 addition & 1 deletion test/distributions/matrix_dirichlet_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ end

@testitem "MatrixDirichlet: mean(::typeof(log))" begin
include("distributions_setuptests.jl")

import Base.Broadcast: BroadcastFunction

@test mean(BroadcastFunction(log), MatrixDirichlet([1.0 1.0; 1.0 1.0; 1.0 1.0])) [
Expand Down
3 changes: 3 additions & 0 deletions test/distributions/mv_normal_wishart_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ end
ef = test_exponentialfamily_interface(
d;
option_assume_no_allocations = false,
test_basic_functions = false,
test_fisherinformation_against_hessian = false,
test_fisherinformation_against_jacobian = false
)

run_test_basic_functions(ef; assume_no_allocations = false, test_samples_logpdf = false)
end
end
end
Expand Down
1 change: 0 additions & 1 deletion test/exponential_family_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ end
@test all(@inferred(sufficientstatistics(_similar, 2.0)) .≈ (2.0, log(2.0)))
@test @inferred(logpartition(_similar, η)) 0.25
@test @inferred(getsupport(_similar)) == RealInterval(0, Inf)

end
end

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Aqua, CpuId, ReTestItems, ExponentialFamily

# `ambiguities = false` - there are quite some ambiguities, but these should be normal and should not be encountered under normal circumstances
# `piracy = false` - we extend/add some of the methods to the objects defined in the Distributions.jl
Aqua.test_all(ExponentialFamily, ambiguities = false, piracy = false)
Aqua.test_all(ExponentialFamily, ambiguities = false, deps_compat = (; check_extras = false, check_weakdeps = true), piracy = false)

nthreads = max(cputhreads(), 1)
ncores = max(cpucores(), 1)
Expand Down

0 comments on commit 7348a2c

Please sign in to comment.