Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CirculantNormal distribution. #1988

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ Chi2
:show-inheritance:
:member-order: bysource

CirculantNormal
^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.CirculantNormal
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Dirichlet
^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.Dirichlet
Expand Down Expand Up @@ -998,6 +1006,14 @@ OrderedTransform
:show-inheritance:
:member-order: bysource

PackRealFastFourierCoefficientsTransform
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.transforms.PackRealFastFourierCoefficientsTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

PermuteTransform
^^^^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.transforms.PermuteTransform
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ NumPyro documentation
tutorials/censoring
tutorials/hsgp_example
tutorials/other_samplers
tutorials/circulant_gp

.. nbgallery::
:maxdepth: 1
Expand Down
293 changes: 293 additions & 0 deletions notebooks/source/circulant_gp.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@

if "READTHEDOCS" not in os.environ:
# if developing locally, use numpyro.__version__ as version
from numpyro import __version__ # noqaE402
from numpyro import __version__ # noqa: E402

version = __version__

Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BetaProportion,
Cauchy,
Chi2,
CirculantNormal,
Dirichlet,
EulerMaruyama,
Exponential,
Expand Down Expand Up @@ -132,6 +133,7 @@
"CategoricalProbs",
"Cauchy",
"Chi2",
"CirculantNormal",
"Delta",
"Dirichlet",
"DirichletMultinomial",
Expand Down
15 changes: 15 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"nonnegative_integer",
"positive",
"positive_definite",
"positive_definite_circulant_vector",
"positive_semidefinite",
"positive_integer",
"real",
Expand Down Expand Up @@ -642,6 +643,19 @@ def feasible_like(self, prototype):
)


class _PositiveDefiniteCirculantVector(_SingletonConstraint):
    """
    Constraint for vectors that are the first row of a circulant covariance matrix.

    The check is performed on the real fast Fourier transform of the vector: every
    coefficient must have a vanishing imaginary part and a real part that is not
    negative, both up to a tolerance of ten machine epsilons of the input dtype.
    """

    event_dim = 1

    def __call__(self, x):
        # Dispatch to numpy for numpy inputs and to jax.numpy otherwise. A distinct
        # local name avoids shadowing the module-level `jnp` alias.
        xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
        tol = 10 * xp.finfo(x.dtype).eps
        rfft = xp.fft.rfft(x)
        coefficient_ok = (xp.abs(rfft.imag) < tol) & (rfft.real > -tol)
        # The constraint has `event_dim = 1`, so report one boolean per vector
        # rather than one per Fourier coefficient.
        return coefficient_ok.all(axis=-1)

    def feasible_like(self, prototype):
        """
        Return an array shaped like `prototype` whose trailing-axis vectors are
        `(1, 0, ..., 0)`.
        """
        return jnp.zeros_like(prototype).at[..., 0].set(1.0)


class _PositiveSemiDefinite(_SingletonConstraint):
event_dim = 2

Expand Down Expand Up @@ -792,6 +806,7 @@ def tree_flatten(self):
ordered_vector = _OrderedVector()
positive = _Positive()
positive_definite = _PositiveDefinite()
positive_definite_circulant_vector = _PositiveDefiniteCirculantVector()
positive_semidefinite = _PositiveSemiDefinite()
positive_integer = _IntegerPositive()
positive_ordered_vector = _PositiveOrderedVector()
Expand Down
107 changes: 106 additions & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import jax.nn as nn
import jax.numpy as jnp
import jax.random as random
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.scipy.linalg import cho_solve, solve_triangular, toeplitz
from jax.scipy.special import (
betaln,
digamma,
Expand All @@ -59,7 +59,9 @@
CholeskyTransform,
CorrMatrixCholeskyTransform,
ExpTransform,
PackRealFastFourierCoefficientsTransform,
PowerTransform,
RealFastFourierTransform,
RecursiveLinearTransform,
SigmoidTransform,
ZeroSumTransform,
Expand Down Expand Up @@ -3068,3 +3070,106 @@ def entropy(self) -> ArrayLike:
return jnp.broadcast_to(
0.5 + 1.5 * jnp.euler_gamma + 0.5 * jnp.log(16 * jnp.pi), self.batch_shape
) + jnp.log(self.scale)


class CirculantNormal(TransformedDistribution):
    """
    Multivariate normal distribution with circulant covariance matrix.

    The covariance is specified through one of :code:`covariance_row` or
    :code:`covariance_rfft`.

    :param loc: Mean of the distribution. Must have at least one dimension; its
        trailing dimension determines the event size.
    :param covariance_row: First row of the circulant covariance matrix.
    :param covariance_rfft: Real part of the Fourier transform of
        :code:`covariance_row`.
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_row": constraints.positive_definite_circulant_vector,
        "covariance_rfft": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector

    def __init__(
        self,
        loc: jnp.ndarray,
        covariance_row: jnp.ndarray = None,
        covariance_rfft: jnp.ndarray = None,
        *,
        validate_args=None,
    ) -> None:
        # We demand a one-dimensional input, because we cannot determine the event shape
        # if only the covariance_rfft is given.
        assert jnp.ndim(loc) > 0, "Location parameter must have at least one dimension."
        loc_shape = jnp.shape(loc)
        n = loc_shape[-1]
        assert_one_of(covariance_row=covariance_row, covariance_rfft=covariance_rfft)

        # Evaluate covariance_rfft if not provided and validate.
        if covariance_rfft is None:
            row_shape = jnp.shape(covariance_row)
            assert row_shape[-1] == n, (
                "Trailing dimension of `covariance_row` must match that of `loc`."
            )
            covariance_rfft = jnp.fft.rfft(covariance_row).real
            # Broadcast to the full shape instead of only promoting ranks so that
            # `covariance_row` always has shape `batch_shape + event_shape`. With
            # rank promotion alone, a `covariance_row` of shape `(n,)` combined
            # with a `loc` of shape `(a, b, n)` would end up as `(1, 1, n)`.
            # NOTE(review): the PR discussion suggests the same caveat may apply
            # to `MultivariateNormal.covariance_matrix`.
            shape = jnp.broadcast_shapes(loc_shape, row_shape)
            self.covariance_row = jnp.broadcast_to(covariance_row, shape)

        # The in-place style `.at[...]` updates below only exist on JAX arrays, so
        # coerce user-supplied numpy arrays / sequences first.
        covariance_rfft = jnp.asarray(covariance_rfft)
        self.loc = loc
        self.covariance_rfft = covariance_rfft

        # Construct the base distribution.
        n_real = n // 2 + 1
        n_imag = n - n_real
        assert covariance_rfft.shape[-1] == n_real, (
            "Trailing dimension of `covariance_rfft` must be `n // 2 + 1`."
        )
        # Variances of the real parts: `n * covariance_rfft / 2`, doubled for the
        # first coefficient and, when `n` is even, for the last coefficient.
        var_rfft = (n * covariance_rfft / 2).at[..., 0].mul(2)
        if n % 2 == 0:
            var_rfft = var_rfft.at[..., -1].mul(2)
        # Append the variances of the `n_imag` imaginary parts, which belong to
        # coefficients `1, ..., n_imag`.
        var_rfft = jnp.concatenate([var_rfft, var_rfft[..., 1 : 1 + n_imag]], axis=-1)
        assert var_rfft.shape[-1] == n
        base_distribution = Normal(scale=jnp.sqrt(var_rfft)).to_event(1)

        super().__init__(
            base_distribution,
            [
                PackRealFastFourierCoefficientsTransform((n,)),
                RealFastFourierTransform((n,)).inv,
                AffineTransform(loc, scale=1.0),
            ],
            validate_args=validate_args,
        )

    @property
    def mean(self) -> jnp.ndarray:
        """Location parameter broadcast to ``batch_shape + event_shape``."""
        return jnp.broadcast_to(self.loc, self.shape())

    @lazy_property
    def covariance_row(self) -> jnp.ndarray:
        """
        First row of the covariance matrix, recovered from
        :code:`covariance_rfft` by an inverse real FFT when it was not supplied
        to the constructor.
        """
        return jnp.broadcast_to(
            jnp.fft.irfft(self.covariance_rfft, n=self.event_shape[-1]), self.shape()
        )

    @lazy_property
    def covariance_matrix(self) -> jnp.ndarray:
        """Dense covariance matrix with shape ``batch_shape + (n, n)``."""
        if self.batch_shape:
            # `toeplitz` flattens the input, and we need to broadcast manually.
            (n,) = self.event_shape
            return vmap(toeplitz)(self.covariance_row.reshape((-1, n))).reshape(
                self.batch_shape + (n, n)
            )
        else:
            return toeplitz(self.covariance_row)

    @lazy_property
    def variance(self) -> jnp.ndarray:
        """
        Marginal variances: the first element of :code:`covariance_row`,
        broadcast across the event dimension.
        """
        return jnp.broadcast_to(self.covariance_row[..., 0, None], self.shape())

    @staticmethod
    def infer_shapes(
        loc: tuple = (), covariance_row: tuple = None, covariance_rfft: tuple = None
    ):
        """
        Infer the batch and event shapes from the shapes of the parameters.

        :param loc: Shape of the location parameter.
        :param covariance_row: Shape of :code:`covariance_row`, if given.
        :param covariance_rfft: Shape of :code:`covariance_rfft`, if given.
        :return: Tuple ``(batch_shape, event_shape)``.
        """
        assert_one_of(covariance_row=covariance_row, covariance_rfft=covariance_rfft)
        for cov in [covariance_rfft, covariance_row]:
            if cov is not None:
                # The event size always comes from `loc` because the length of
                # `covariance_rfft` does not determine it.
                batch_shape = jnp.broadcast_shapes(loc[:-1], cov[:-1])
                event_shape = loc[-1:]
                return batch_shape, event_shape

    def entropy(self):
        """
        Entropy of the distribution: the entropy of the base distribution plus
        half of ``2 * log(2) * ((n - 1) // 2) - n * log(n)``.
        """
        (n,) = self.event_shape
        log_abs_det_jacobian = 2 * jnp.log(2) * ((n - 1) // 2) - jnp.log(n) * n
        return self.base_dist.entropy() + log_abs_det_jacobian / 2
22 changes: 22 additions & 0 deletions numpyro/distributions/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from numpyro.distributions.continuous import (
Beta,
CirculantNormal,
Dirichlet,
Gamma,
Kumaraswamy,
Expand Down Expand Up @@ -183,6 +184,27 @@ def _shapes_are_broadcastable(first_shape, second_shape):
return 0.5 * (tr + t1 - D - log_det_ratio)


@dispatch(Independent, CirculantNormal)
def kl_divergence(p: Independent, q: CirculantNormal):
    """
    KL divergence from a diagonal normal ``p`` to a circulant normal ``q``.

    ``p`` must be a :code:`Normal` base distribution with exactly one
    reinterpreted batch dimension; any other ``Independent`` raises
    :code:`NotImplementedError`.
    """
    # We can only calculate the KL divergence if the base distribution is normal.
    if not isinstance(p.base_dist, Normal) or p.reinterpreted_batch_ndims != 1:
        raise NotImplementedError

    delta = q.mean - p.mean
    size = delta.shape[-1]
    spectrum = q.covariance_rfft
    log_spectrum = jnp.log(spectrum)

    # Quadratic form of the mean difference; the division by the spectrum is
    # carried out in the Fourier domain.
    scaled_delta = jnp.fft.irfft(jnp.fft.rfft(delta) / spectrum, size)
    quadratic = jnp.vecdot(delta, scaled_delta)

    # First element of the inverse transform of the reciprocal spectrum, times
    # the summed variances of `p`.
    trace = jnp.fft.irfft(1 / spectrum, size)[..., 0] * jnp.sum(p.variance, axis=-1)

    # Sum of log-spectrum entries, with entries `1, ..., (size + 1) // 2 - 1`
    # counted a second time.
    log_det_once = log_spectrum.sum(axis=-1)
    log_det_repeat = log_spectrum[..., 1 : (size + 1) // 2].sum(axis=-1)

    log_det_p = jnp.log(p.variance).sum(axis=-1)

    # Accumulate in the same order as a single left-to-right expression.
    total = quadratic + trace
    total = total + log_det_once
    total = total + log_det_repeat
    total = total - log_det_p
    total = total - size
    return total / 2


@dispatch(Beta, Beta)
def kl_divergence(p, q):
# From https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy)
Expand Down
78 changes: 76 additions & 2 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"LowerCholeskyTransform",
"ScaledUnitLowerCholeskyTransform",
"LowerCholeskyAffine",
"PackRealFastFourierCoefficientsTransform",
"PermuteTransform",
"PowerTransform",
"RealFastFourierTransform",
Expand Down Expand Up @@ -1311,10 +1312,15 @@ def inverse_shape(self, shape: tuple) -> tuple:
def log_abs_det_jacobian(
    self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
) -> jnp.ndarray:
    """
    Log absolute determinant of the Jacobian of the transform.

    The value depends only on the shape of the transformed (trailing
    ``self.transform_ndims``) axes of ``x`` and is broadcast over the batch shape
    shared by ``x`` and ``y``.

    :param x: Input of the transform.
    :param y: Output of the transform.
    :param intermediates: Unused.
    :return: Array with the broadcast batch shape of ``x`` and ``y``.
    """
    batch_shape = jnp.broadcast_shapes(
        x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
    )
    event_shape = x.shape[-self.transform_ndims :]
    size = math.prod(event_shape)
    # Product over the transformed axes of 2 for even-length axes and 1 for
    # odd-length axes.
    q = math.prod(2 - dim % 2 for dim in event_shape)
    return jnp.broadcast_to(
        (size * jnp.log(size) - jnp.log(2) * (size - q)) / 2, batch_shape
    )

def tree_flatten(self):
aux_data = {
Expand All @@ -1339,6 +1345,74 @@ def __eq__(self, other):
)


class PackRealFastFourierCoefficientsTransform(Transform):
    """
    Transform a real vector to complex coefficients of a real fast Fourier transform.

    For an input of size ``n``, the first ``n // 2 + 1`` elements become the real
    parts of the coefficients and the remaining elements become the imaginary parts
    of coefficients ``1, ..., n - (n // 2 + 1)``. All other imaginary parts are zero.

    :param transform_shape: Shape of the real vector, defaults to the input size.
        It must be given for the inverse transform because the input size cannot
        be recovered from the number of coefficients alone.
    """

    domain = constraints.real_vector
    codomain = constraints.independent(constraints.complex, 1)

    def __init__(self, transform_shape: tuple = None) -> None:
        assert transform_shape is None or len(transform_shape) == 1, (
            "Packing Fourier coefficients is only implemented for vectors."
        )
        self.shape = transform_shape

    def tree_flatten(self):
        # No array-valued parameters; the shape is carried as auxiliary data.
        return (), ((), {"shape": self.shape})

    def forward_shape(self, shape: tuple) -> tuple:
        """Map a real input shape ``(..., n)`` to ``(..., n // 2 + 1)``."""
        *batch_shape, n = shape
        assert self.shape is None or self.shape == (n,), (
            f"Trailing dimension of the input must match `{self.shape}`. "
            f"Got `{shape}`."
        )
        n_rfft = n // 2 + 1
        return (*batch_shape, n_rfft)

    def inverse_shape(self, shape: tuple) -> tuple:
        """Map a coefficient shape ``(..., n // 2 + 1)`` back to ``(..., n)``."""
        *batch_shape, n_rfft = shape
        assert self.shape is not None, (
            "Shape must be specified in `__init__` for inverse transform."
        )
        (n,) = self.shape
        assert n_rfft == n // 2 + 1
        return (*batch_shape, n)

    def log_abs_det_jacobian(
        self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
    ) -> jnp.ndarray:
        """Return zeros with the broadcast batch shape of ``x`` and ``y``."""
        shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1])
        return jnp.zeros_like(x, shape=shape)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # `.at[...]` below is only available on JAX arrays.
        x = jnp.asarray(x)
        assert self.shape is None or self.shape == x.shape[-1:]
        n = x.shape[-1]
        n_real = n // 2 + 1
        n_imag = n - n_real
        complex_dtype = jnp.result_type(x.dtype, jnp.complex64)
        # Start from the real parts and add the imaginary parts to coefficients
        # `1, ..., n_imag`.
        return (
            x[..., :n_real]
            .astype(complex_dtype)
            .at[..., 1 : 1 + n_imag]
            .add(1j * x[..., n_real:])
        )

    def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
        # Mirror the check in `inverse_shape` so a missing shape produces a clear
        # message instead of a failure to unpack `None`.
        assert self.shape is not None, (
            "Shape must be specified in `__init__` for inverse transform."
        )
        (n,) = self.shape
        n_real = n // 2 + 1
        n_imag = n - n_real
        return jnp.concatenate([y.real, y.imag[..., 1 : n_imag + 1]], axis=-1)

    def __eq__(self, other) -> bool:
        return (
            isinstance(other, PackRealFastFourierCoefficientsTransform)
            and self.shape == other.shape
        )


class RecursiveLinearTransform(Transform):
"""
Apply a linear transformation recursively such that
Expand Down
Loading
Loading