Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CirculantNormal distribution. #1988

Open
wants to merge 12 commits into
base: master
Choose a base branch
from

Conversation

tillahoffmann
Copy link
Contributor

@tillahoffmann tillahoffmann commented Feb 26, 2025

This PR adds a CirculantNormal distribution which is a multivariate normal distribution where the covariance matrix has circular boundary conditions. It uses the Fourier transform introduced in #1762 to evaluate the log probability, scaling with n * log(n) instead of the usual n ** 3. The PR

  • Implements the CirculantNormal distribution.
  • Adds a PackRealFastFourierCoefficientsTransform transform. This is required to transform a real vector of size n to a complex vector of size n // 2 + 1 which can be passed to the rfft method. See Add complex constraint and real Fourier transform. #1762 (comment) for a brief initial discussion.
  • Adds a positive_definite_circulant_vector constraint. This is really just required for the test/test_distributions.py::test_distribution_constraints tests to pass.
  • Adds an example notebook, illustrating how the CirculantNormal can be used to accelerate Gaussian process inference.
  • Fix the Jacobian of the Fourier transform (which turns out not to be the identity but depends on the norm parameter).

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@@ -0,0 +1,293 @@
{
Copy link
Contributor

@juanitorduz juanitorduz Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to put the legends outside the plots as in https://stackoverflow.com/questions/4700614/how-to-put-the-legend-outside-the-plot :)


Reply via ReviewNB

assert covariance_row.shape[-1] == n
covariance_rfft = jnp.fft.rfft(covariance_row).real
shape = jnp.broadcast_shapes(loc.shape, covariance_row.shape)
self.covariance_row = jnp.broadcast_to(covariance_row, shape)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm broadcasting to the full shape here instead of just promoting the shapes. That's to guarantee that covariance_row always has shape batch_shape + event_shape. If we only promoted shapes, we might end up with a shape (1, 1, n) if the input covariance_row has shape (n,) but the input loc has shape (a, b, n).

As an aside, this may also be relevant to MultivariateNormal distribution where covariance_matrix may not have the right batch dimensions.

>>> from jax import numpy as jnp
>>> from numpyro.distributions import MultivariateNormal
>>> 
>>> d = MultivariateNormal(jnp.zeros((3, 4, 5)), jnp.eye(5))
>>> d
<numpyro.distributions.continuous.MultivariateNormal object at 0x145cc7d00 with 
batch shape (3, 4) and event shape (5,)>
>>> d.covariance_matrix.shape
(1, 1, 5, 5)  # Expected (3, 4, 5, 5)

@tillahoffmann
Copy link
Contributor Author

tillahoffmann commented Feb 27, 2025

I think the failing tests are due to issues with the data download and not related to the changes in this PR. There are a few further failing tests due to numerics (maybe associated with a new jax release?).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants