Gradient of CARMA log probability wrt kernel parameters produces NaNs #228

Open
davecwright3 opened this issue Dec 4, 2024 · 1 comment

davecwright3 commented Dec 4, 2024

I originally found this issue when trying to use a CARMA kernel in NumPyro HMC: evaluating the gradient of the log probability with respect to the kernel parameters returns only NaNs.

Things I've tried to resolve or narrow down the issue:

  • adding increasingly large diag values to the GP
  • many different values for the CARMA parameters
  • double vs. single precision
  • CARMA(p, q) models other than (1, 0); they still produce NaNs
  • other quasiseparable kernels (these work as they should)

Below is a minimal reproducible example that doesn't involve NumPyro.

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from tinygp import GaussianProcess, kernels


# CARMA(1,0)
def build_gp_drw(params, x):
    kernel = kernels.quasisep.CARMA(params["alpha"], params["beta"])
    gp = GaussianProcess(kernel, x)
    return gp


x = jnp.linspace(1, 100)
y = jnp.sin(x) + 1e-2*jax.random.normal(jax.random.key(5), x.shape)

params = {"alpha": jnp.array([0.01]), "beta": jnp.array([0.1])}
drw_gp = build_gp_drw(params, x)


@jax.jit
def loss(params):
    gp = build_gp_drw(params, x)
    return -gp.log_probability(y)

>>> loss(params)
Array(194.17045899, dtype=float64)
>>> jax.grad(loss)(params)
{'alpha': Array([nan], dtype=float64), 'beta': Array([nan], dtype=float64)}
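
A finite-difference estimate can help tell whether the loss itself is smooth at these parameter values, which would point at the autodiff path rather than the likelihood. The sketch below is a minimal, library-free helper under stated assumptions: `central_difference_grad`, `all_finite`, and `toy_loss` are hypothetical names introduced for illustration, and `toy_loss` is only a stand-in, not the CARMA likelihood.

```python
import math


def central_difference_grad(loss_fn, params, step=1e-6):
    """Estimate d(loss)/d(param) for each entry of a dict of floats."""
    grads = {}
    for name, value in params.items():
        # Scale the step with the parameter, but never below `step` itself.
        h = step * max(abs(value), 1.0)
        upper = dict(params, **{name: value + h})
        lower = dict(params, **{name: value - h})
        grads[name] = (loss_fn(upper) - loss_fn(lower)) / (2.0 * h)
    return grads


def all_finite(grads):
    """Return True if every gradient entry is a finite number."""
    return all(math.isfinite(g) for g in grads.values())


# Stand-in loss (NOT the CARMA likelihood), used only to exercise the helper.
def toy_loss(p):
    return p["alpha"] ** 2 + math.log(p["beta"])


toy_params = {"alpha": 0.01, "beta": 0.1}
toy_grads = central_difference_grad(toy_loss, toy_params)
print(toy_grads, all_finite(toy_grads))
```

To apply it to the reproduction above, one option (an assumption, not tested here) is to wrap the jitted loss as `lambda p: float(loss({k: jnp.array([v]) for k, v in p.items()}))`. Finite central differences alongside NaNs from `jax.grad` would suggest the problem lies in how the gradient is computed rather than in the loss value.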

davecwright3 commented Dec 4, 2024

I did test this with the celerite2.jax DRW kernel (RealTerm), and it works as expected.
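
For anyone repeating that comparison, the two parameterizations need to be matched. The sketch below assumes the textbook CARMA(1,0) autocovariance k(tau) = beta^2 / (2 * alpha) * exp(-alpha * |tau|) and the RealTerm form k(tau) = a * exp(-c * |tau|); whether tinygp's `alpha` and `beta` follow exactly this convention is an assumption to check against its documentation. `drw_from_carma10` and `drw_kernel` are hypothetical helper names, not part of either library.

```python
import math


def drw_from_carma10(alpha, beta):
    """Map CARMA(1,0) coefficients to (a, c) in k(tau) = a * exp(-c * |tau|).

    Assumes the autocovariance beta**2 / (2 * alpha) * exp(-alpha * |tau|).
    """
    amplitude = beta**2 / (2.0 * alpha)
    decay_rate = alpha
    return amplitude, decay_rate


def drw_kernel(tau, a, c):
    """Evaluate the damped-random-walk (exponential) kernel at lag tau."""
    return a * math.exp(-c * abs(tau))


# Parameter values taken from the reproduction above.
a, c = drw_from_carma10(alpha=0.01, beta=0.1)
print(a, c, drw_kernel(10.0, a, c))
```

Under these assumptions the values from the reproduction correspond to a variance of 0.5 and a decay rate of 0.01, which gives both kernels the same covariance to be differentiated.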
