Commit

Add analytic KL divergence of diagonal Normal from CirculantNormal.
tillahoffmann committed Feb 27, 2025
1 parent 9f6dee2 commit c6e5092
Showing 2 changed files with 41 additions and 0 deletions.
22 changes: 22 additions & 0 deletions numpyro/distributions/kl.py
@@ -33,6 +33,7 @@

from numpyro.distributions.continuous import (
    Beta,
    CirculantNormal,
    Dirichlet,
    Gamma,
    Kumaraswamy,
@@ -183,6 +184,27 @@ def _shapes_are_broadcastable(first_shape, second_shape):
    return 0.5 * (tr + t1 - D - log_det_ratio)


@dispatch(Independent, CirculantNormal)
def kl_divergence(p: Independent, q: CirculantNormal):
    # The analytic form is only available if p is a diagonal normal, i.e., an
    # Independent Normal with exactly one reinterpreted batch dimension.
    if not isinstance(p.base_dist, Normal) or p.reinterpreted_batch_ndims != 1:
        raise NotImplementedError

    residual = q.mean - p.mean
    n = residual.shape[-1]
    log_covariance_rfft = jnp.log(q.covariance_rfft)
    return (
        # Mahalanobis term: the circulant precision acts as pointwise division
        # by the eigenvalues in the real Fourier domain.
        jnp.vecdot(
            residual, jnp.fft.irfft(jnp.fft.rfft(residual) / q.covariance_rfft, n)
        )
        # Trace term: the inverse of a circulant matrix is circulant, so its
        # diagonal is constant and equal to irfft(1 / covariance_rfft, n)[..., 0].
        + jnp.fft.irfft(1 / q.covariance_rfft, n)[..., 0] * jnp.sum(p.variance, axis=-1)
        # Log-determinant of q's covariance: the rfft coefficients are the
        # eigenvalues; all except the zero frequency (and, for even n, the
        # Nyquist frequency) occur twice.
        + log_covariance_rfft.sum(axis=-1)
        + log_covariance_rfft[..., 1 : (n + 1) // 2].sum(axis=-1)
        # Log-determinant of p's covariance and the dimensionality correction.
        - jnp.log(p.variance).sum(axis=-1)
        - n
    ) / 2


@dispatch(Beta, Beta)
def kl_divergence(p, q):
    # From https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy)
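For reference, the closed form the new dispatch evaluates is the standard Gaussian KL divergence specialized to a circulant covariance; the notation below is introduced here, not taken from the commit. With r = mu_q - mu_p, v the per-dimension variances of p, and lambda_k = covariance_rfft the eigenvalues of the circulant covariance Sigma of q,

    KL(p || q) = (1/2) [ r^T Sigma^{-1} r + (Sigma^{-1})_{00} sum_i v_i
                         + sum_k m_k log lambda_k - sum_i log v_i - n ],

where the multiplicity m_k is 1 for the zero frequency (and for the Nyquist frequency when n is even) and 2 otherwise, so the multiplicities sum to n. Because circulant matrices are diagonalized by the discrete Fourier basis, Sigma^{-1} r = irfft(rfft(r) / lambda, n) and the constant diagonal entry (Sigma^{-1})_{00} = irfft(1 / lambda, n)[0], which keeps the whole evaluation at O(n log n) rather than the O(n^3) of a dense solve.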
19 changes: 19 additions & 0 deletions test/test_distributions.py
@@ -3185,6 +3185,25 @@ def make_dist(dist_class):
    assert_allclose(actual, expected, rtol=0.05)


@pytest.mark.parametrize("shape", [(3, 2, 10), (3, 2, 11), (10,), (11,)], ids=str)
def test_kl_circulant_normal_consistency(shape: tuple) -> None:
    key1, key2, key3, key4 = random.split(random.key(9), 4)
    p = dist.Normal(random.normal(key1, shape), random.gamma(key2, 3, shape)).to_event(
        1
    )
    covariance_rfft = random.gamma(key4, 10, shape[:-1] + (shape[-1] // 2 + 1,)) / 10
    q = dist.CirculantNormal(
        random.normal(key3, shape), covariance_rfft=covariance_rfft
    )
    actual = kl_divergence(p, q)
    # The analytic result must agree with the dense multivariate-normal KL.
    expected = kl_divergence(
        dist.MultivariateNormal(p.mean, jnp.eye(shape[-1]) * p.variance[..., None]),
        dist.MultivariateNormal(q.mean, q.covariance_matrix),
    )
    assert_allclose(actual, expected, rtol=1e-6)


@pytest.mark.parametrize("shape", [(4,), (2, 3)], ids=str)
def test_kl_dirichlet_dirichlet(shape):
    p = dist.Dirichlet(np.exp(np.random.normal(size=shape)))
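Not part of the commit, but a quick way to sanity-check the analytic result outside the test suite is a Monte Carlo estimate of E_p[log p(x) - log q(x)], which should fluctuate around the analytic value; a minimal sketch, assuming a numpyro build that includes CirculantNormal and the dispatch above:

import jax.numpy as jnp
from jax import random

import numpyro.distributions as dist
from numpyro.distributions.kl import kl_divergence

n = 10
# Standard diagonal normal p and a CirculantNormal q with a random positive spectrum.
p = dist.Normal(jnp.zeros(n), jnp.ones(n)).to_event(1)
covariance_rfft = random.gamma(random.key(0), 10.0, (n // 2 + 1,)) / 10.0
q = dist.CirculantNormal(jnp.zeros(n), covariance_rfft=covariance_rfft)

# Compare a sample-based estimate against the analytic value.
x = p.sample(random.key(1), (100_000,))
print(jnp.mean(p.log_prob(x) - q.log_prob(x)))  # Monte Carlo estimate
print(kl_divergence(p, q))  # analytic value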
