-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CirculantNormal
distribution.
#1988
base: master
Are you sure you want to change the base?
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
@@ -0,0 +1,293 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to put the legends outside the plots as in https://stackoverflow.com/questions/4700614/how-to-put-the-legend-outside-the-plot :)
Reply via ReviewNB
assert covariance_row.shape[-1] == n | ||
covariance_rfft = jnp.fft.rfft(covariance_row).real | ||
shape = jnp.broadcast_shapes(loc.shape, covariance_row.shape) | ||
self.covariance_row = jnp.broadcast_to(covariance_row, shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm broadcasting to the full shape here instead of just promoting the shapes. That's to guarantee that covariance_row
always has shape batch_shape + event_shape
. If we only promoted shapes, we might end up with a shape (1, 1, n)
if the input covariance_row
has shape (n,)
but the input loc
has shape (a, b, n)
.
As an aside, this may also be relevant to MultivariateNormal
distribution where covariance_matrix
may not have the right batch dimensions.
>>> from jax import numpy as jnp
>>> from numpyro.distributions import MultivariateNormal
>>>
>>> d = MultivariateNormal(jnp.zeros((3, 4, 5)), jnp.eye(5))
>>> d
<numpyro.distributions.continuous.MultivariateNormal object at 0x145cc7d00 with
batch shape (3, 4) and event shape (5,)>
>>> d.covariance_matrix.shape
(1, 1, 5, 5) # Expected (3, 4, 5, 5)
I think the failing tests are due to issues with the data download and not related to the changes in this PR. There are a few further failing tests due to numerics (maybe associated with a new jax release?). |
This PR adds a
CirculantNormal
distribution which is a multivariate normal distribution where the covariance matrix has circular boundary conditions. It uses the Fourier transform introduced in #1762 to evaluate the log probability, scaling withn * log(n)
instead of the usualn ** 3
. The PRCirculantNormal
distribution.PackRealFastFourierCoefficientsTransform
transform. This is required to transform a real vector of sizen
to a complex vector of sizen // 2 + 1
which can be passed to therfft
method. See Add complex constraint and real Fourier transform. #1762 (comment) for a brief initial discussion.positive_definite_circulant_vector
constraint. This is really just required for thetest/test_distributions.py::test_distribution_constraints
tests to pass.CirculantNormal
can be used to accelerate Gaussian process inference.norm
parameter).