Issues sampling with PyMCv5 (gp.marginal fails when combined modelling GP+transit) #124

Open
hposborn opened this issue May 28, 2024 · 4 comments


hposborn commented May 28, 2024

I am having issues using celerite2 with PyMC. In the past (with PyMC3) I always used gp.marginal(observed=y-extra_model) to sample models that included both a GP and other variables (e.g. a transit model), and this worked without issue. For whatever reason, that is no longer the case with PyMC v5, and I get TypeError: Variables that depend on other nodes cannot be used for observed data.
I thought an easy alternative would be to initialise with gp.compute(), generate a predicted GP curve with gp.predict(), and then model everything the "classical" way in PyMC using pm.Normal(mu=gp_pred+extra_model, sigma=y_err, observed=y). But this gives completely different, and horrendously overfitted, results compared with gp.marginal() for the same model (see below).

So I would love some advice on how to model celerite GPs combined with additional functions:
a) Is there any way to sample using gp.marginal() when the observed data depend on other PyMC parameters? For example, could the mean function have N_t values rather than a single value, so that the transit model can be included that way?
b) If gp.marginal() with y-extra_model is not possible, how should sampling within PyMC be done? Should we be using gp.predict() for this purpose at all, or is there a step I'm missing that causes the drastic overfitting?

Some code as a MWE:

```python
import pymc as pm
import pymc_ext as pmx
import celerite2.pymc
import arviz as az

import exoplanet as xo

import numpy as np
import matplotlib.pyplot as plt

# Initialising some sinusoidal terms to act as something for the GP to remove:
sin_amps = np.exp(np.random.normal(-3, 0.2, 5))
sin_t0s = np.random.normal(0, 15, 5)
sin_pers = np.exp(np.random.normal(2, 0.5, 5))

# Initialising transit parameters:
i_Rs = 0.8; i_Ms = 0.76
i_us = np.array([0.1, 0.3])
i_t0 = 3.197652; i_P = 12.59219  # days
i_b = 0.393
i_rpl = 3.1309  # Rearth
i_rprs = i_rpl / 109.2 / i_Rs  # Rp/Rs: convert Rearth -> Rsun, then divide by Rs

# Creating fake data with LimbDarkLightCurve (get_light_curve returns shape (N, 1), hence [:, 0]):
t = np.arange(0, 50, 1/50)
flux_err = np.full_like(t, 0.15)
orbit = xo.orbits.KeplerianOrbit(r_star=i_Rs, m_star=i_Ms, period=i_P, t0=i_t0, b=i_b)
transit = 1000*xo.LimbDarkLightCurve(i_us).get_light_curve(orbit=orbit, r=i_rprs*i_Rs, t=t).eval()[:, 0]
stellar_var = np.sum(sin_amps[:, None]*np.sin(2*np.pi*(t[None, :] - sin_t0s[:, None])/sin_pers[:, None]), axis=0)
pure_flux = transit + stellar_var
flux = pure_flux + np.random.normal(0.0, np.nanmedian(flux_err), len(t))

# Plotting to check:
plt.plot(t, flux, '.')
plt.plot(t, pure_flux, '--', alpha=0.7)
```

[Figure: MWE_true_variation, simulated flux with the noise-free signal overplotted]
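The radius bookkeeping in the simulation can be sanity-checked on its own. This is a minimal sketch that assumes only the 109.2 Earth-radii-per-solar-radius factor used in the issue and the model's own relation rpl = rprs·Rs·109.2, which implies rprs = rpl/(109.2·Rs); the helper name radius_ratio is hypothetical.

```python
# Hypothetical helper (not from the issue): unit bookkeeping for the radius ratio.
R_SUN_IN_R_EARTH = 109.2  # conversion factor used throughout the issue's code


def radius_ratio(rpl_earth: float, rs_sun: float) -> float:
    """Return Rp/Rs given Rp in Earth radii and Rs in solar radii."""
    rp_sun = rpl_earth / R_SUN_IN_R_EARTH  # planet radius in solar radii
    return rp_sun / rs_sun


i_Rs = 0.8      # stellar radius [Rsun]
i_rpl = 3.1309  # planet radius [Rearth]

ror = radius_ratio(i_rpl, i_Rs)
# Planet radius in Rsun: the quantity passed as r= to get_light_curve (r_star is in Rsun)
r_planet = ror * i_Rs
# Round trip through the model's relation rpl = rprs * Rs * 109.2
rpl_back = ror * i_Rs * R_SUN_IN_R_EARTH
# Rough transit depth in the issue's x1000 flux units (ignores limb darkening)
approx_depth = 1000 * ror**2

print(f"Rp/Rs = {ror:.4f}, depth ~ {approx_depth:.2f}, depth/noise ~ {approx_depth / 0.15:.1f}")
```

With these inputs Rp/Rs ≈ 0.0358, so the depth is roughly 1.3 in the ×1000 flux units (before limb darkening), about 8.6 times the 0.15 per-point noise.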

The anticipated behaviour, using gp.marginal() (no transit):

```python
with pm.Model() as model:
    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01))
    w0 = pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=(2*np.pi)/10, upper=(2*np.pi)/0.2, target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5*np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
    gp.compute(t, yerr=pm.math.sqrt(flux_err**2 + pm.math.exp(logjit)**2), quiet=True)
    loglik = gp.marginal("loglik", observed=flux)  # - light_curve)

    gp_pred = pm.Deterministic("gp_pred", gp.predict(flux, return_var=False))
    wmarg_init_soln = pm.find_MAP()
    # PyMC v5 takes the starting point via `initvals` (formerly `start`)
    wmarg_trace = pm.sample(initvals=wmarg_init_soln)

plt.plot(t, flux, '.', alpha=0.6)
plt.plot(t, np.nanmedian(wmarg_trace.posterior['gp_pred'], axis=(0, 1)), '-')
plt.savefig("MWE_fit_wmarg.png")
```

[Figure: MWE_fit_wmarg, data with the median GP prediction from gp.marginal()]

The arviz summary:

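|        | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|--------|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| logjit | -3.097 | 0.194 | -3.436 | -2.759  | 0.004     | 0.003   | 2858.0   | 2106.0   | 1.0   |
| mean   | -0.003 | 0.024 | -0.048 | 0.045   | 0.001     | 0.000   | 2031.0   | 1516.0   | 1.0   |
| sigma  | 0.119  | 0.017 | 0.089  | 0.150   | 0.000     | 0.000   | 1796.0   | 1514.0   | 1.0   |
| w0     | 1.599  | 0.428 | 0.916  | 2.432   | 0.011     | 0.008   | 1431.0   | 1696.0   | 1.0   |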

So a GP-only model works fine.

The behaviour when including an additional non-celerite mean function (with exoplanet transit):

```python
with pm.Model() as model:
    Rs = pm.Normal("Rs", mu=0.8, sigma=0.02)
    Ms = pm.Normal("Ms", mu=0.78, sigma=0.02)
    P = pm.Normal("P", mu=12.6, sigma=0.01)
    t0 = pm.Normal("t0", mu=3.21, sigma=0.04)
    log_rprs = pm.Normal("log_rprs", mu=-4, sigma=3)
    rprs = pm.Deterministic("rprs", pm.math.exp(log_rprs))
    rpl = pm.Deterministic("rpl", rprs*Rs*109.2)
    b = xo.distributions.ImpactParameter("b", ror=rprs)
    orb = xo.orbits.KeplerianOrbit(r_star=Rs, m_star=Ms, period=P, t0=t0, b=b)
    u_stars = xo.distributions.QuadLimbDark("u_star", testval=np.array([0.3, 0.2]))
    # get_light_curve returns one column per planet, shape (N, 1): sum it to shape (N,)
    light_curve = 1000*pm.math.sum(
        xo.LimbDarkLightCurve(u_stars).get_light_curve(orbit=orb, r=rprs*Rs, t=t), axis=1)

    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01))
    w0 = pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=(2*np.pi)/10, upper=(2*np.pi)/0.2, target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5*np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
    gp.compute(t, yerr=pm.math.sqrt(flux_err**2 + pm.math.exp(logjit)**2), quiet=True)
    # In PyMC v5 this raises the TypeError below, because `observed` depends on light_curve
    loglik = gp.marginal("loglik", observed=flux - light_curve)
    pm.find_MAP()
```

Output:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[25], line 26
     24 gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
     25 gp.compute(t, yerr=pm.math.sqrt(flux_err** 2 + pm.math.exp(logjit)**2), quiet=True)
---> 26 loglik=gp.marginal("loglik", observed=flux - light_curve)
     27 pm.find_MAP()

File ~/miniconda3/envs/chx/lib/python3.9/site-packages/celerite2/pymc/celerite2.py:96, in GaussianProcess.marginal(self, name, **kwargs)
     93 from celerite2.pymc.distribution import CeleriteNormal
     95 self._add_citations_to_pymc_model(**kwargs)
---> 96 return CeleriteNormal(
     97     name,
     98     self._mean_value,
     99     self._norm,
    100     self._t,
    101     self._c,
    102     self._U,
    103     self._W,
    104     self._d,
    105     **kwargs
    106 )

File ~/miniconda3/envs/chx/lib/python3.9/site-packages/pymc/distributions/distribution.py:413, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    409         kwargs["shape"] = tuple(observed.shape)
    411 rv_out = cls.dist(*args, **kwargs)
--> 413 rv_out = model.register_rv(
    414     rv_out,
    415     name,
    416     observed,
    417     total_size,
    418     dims=dims,
    419     transform=transform,
    420     initval=initval,
    421 )
    423 # add in pretty-printing support
    424 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~/miniconda3/envs/chx/lib/python3.9/site-packages/pymc/model/core.py:1265, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
   1252 else:
   1253     if (
   1254         isinstance(observed, Variable)
   1255         and not isinstance(observed, GenTensorVariable)
   (...)
   1263         and not is_minibatch(observed)
   1264     ):
-> 1265         raise TypeError(
   1266             "Variables that depend on other nodes cannot be used for observed data."
   1267             f"The data variable was: {observed}"
   1268         )
   1270     # `rv_var` is potentially changed by `make_obs_var`,
   1271     # for example into a new graph for imputation of missing data.
   1272     rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)

TypeError: Variables that depend on other nodes cannot be used for observed data.The data variable was: Sub.0
```

I have verified that the same error occurs on different computers (both my M2 Mac and a Linux server).

The behaviour when sampling with the output of gp.predict():

```python
with pm.Model() as model:
    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01))
    w0 = pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=(2*np.pi)/10, upper=(2*np.pi)/0.2, target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5*np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
    gp.compute(t, yerr=pm.math.sqrt(flux_err**2 + pm.math.exp(logjit)**2), quiet=True)
    # Note: gp_pred is conditioned on flux itself and is then used as the mean for observed=flux
    gp_pred = pm.Deterministic("gp_pred", gp.predict(flux, return_var=False))
    loglik = pm.Normal("loglik", mu=gp_pred, sigma=pm.math.sqrt(flux_err**2 + pm.math.exp(logjit)**2),
                       observed=flux)  # - light_curve)

    nomarg_init_soln = pm.find_MAP()
    nomarg_trace = pm.sample(initvals=nomarg_init_soln)

plt.plot(t, flux, '.', alpha=0.6)
plt.plot(t, np.nanmedian(nomarg_trace.posterior['gp_pred'], axis=(0, 1)), '-')
plt.savefig("MWE_fit_nomarg.png")
```

[Figure: MWE_fit_nomarg, data with the median GP prediction from the gp.predict() + pm.Normal model]

The arviz summary:

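|        | mean    | sd      | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|--------|---------|---------|--------|---------|-----------|---------|----------|----------|-------|
| logjit | -5.297  | 0.391   | -6.014 | -4.597  | 0.009     | 0.006   | 2263.0   | 1961.0   | 1.00  |
| mean   | -0.000  | 0.084   | -0.150 | 0.164   | 0.002     | 0.001   | 2608.0   | 2583.0   | 1.00  |
| sigma  | 1.559   | 0.874   | 0.652  | 2.833   | 0.030     | 0.021   | 1389.0   | 833.0    | 1.00  |
| w0     | 163.862 | 205.731 | 21.684 | 393.793 | 7.451     | 5.271   | 992.0    | 849.0    | 1.01  |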

This is clearly extremely over-fitted for some reason...
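One likely reason for the overfitting: gp.predict(flux) already conditions on flux, so using it as mu for observed=flux uses the data twice, and the resulting objective drops the log-determinant term of the GP marginal likelihood. The numpy sketch below is not celerite2 code; it assumes the dense closed form of the SHO kernel for Q = 1/√2, a fixed white-noise level and toy sinusoidal data, and compares the two objectives as the GP amplitude grows. Since mu = K(K + s²I)⁻¹y, the residual y − mu equals s²(K + s²I)⁻¹y, which shrinks monotonically as the amplitude grows, so the plug-in objective keeps improving while the marginal likelihood peaks at a moderate amplitude. That behaviour would be consistent with the inflated sigma and w0 in the summary above.

```python
import numpy as np


def sho_kernel(t, sigma, w0):
    """Dense SHO covariance for Q = 1/sqrt(2) (celerite closed form; sigma is the GP std)."""
    tau = np.abs(t[:, None] - t[None, :])
    x = w0 * tau / np.sqrt(2)
    return sigma**2 * np.exp(-x) * (np.cos(x) + np.sin(x))


def marginal_loglik(y, K, s):
    """log N(y | 0, K + s^2 I): the quantity gp.marginal targets (computed densely here)."""
    C = K + s**2 * np.eye(len(y))
    _, logdet = np.linalg.slogdet(C)
    return -0.5 * (y @ np.linalg.solve(C, y) + logdet + len(y) * np.log(2 * np.pi))


def plugin_loglik(y, K, s):
    """sum_i log N(y_i | mu_i, s^2) with mu = K (K + s^2 I)^-1 y, i.e. the
    'gp.predict(flux), then pm.Normal(observed=flux)' construction."""
    C = K + s**2 * np.eye(len(y))
    # y - mu = s^2 (K + s^2 I)^-1 y, computed directly to avoid cancellation
    resid = s**2 * np.linalg.solve(C, y)
    return -0.5 * np.sum(resid**2) / s**2 - len(y) * np.log(s) - 0.5 * len(y) * np.log(2 * np.pi)


rng = np.random.default_rng(42)
t = np.linspace(0, 10, 50)
s = 0.1                                   # white-noise level, held fixed
y = np.sin(t) + s * rng.normal(size=t.size)

sigmas = [0.1, 1.0, 10.0, 100.0]          # GP amplitudes to compare (w0 fixed at 1)
plugin = [plugin_loglik(y, sho_kernel(t, a, w0=1.0), s) for a in sigmas]
marg = [marginal_loglik(y, sho_kernel(t, a, w0=1.0), s) for a in sigmas]

for a, p, m in zip(sigmas, plugin, marg):
    print(f"sigma={a:7.1f}  plug-in={p:9.2f}  marginal={m:9.2f}")
```

The plug-in column increases at every step, whereas the marginal column falls off again for large amplitudes because the log-determinant penalises an over-flexible GP.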

hposborn changed the title from "Issues sampling with PyMCv5 (gp.marginal fails with additional parameter inputs)" to "Issues sampling with PyMCv5 (gp.marginal fails when combined modelling GP+transit)" on May 28, 2024
hposborn (author) commented:

OK, it looks like using pm.Potential("loglik", gp.log_likelihood(y - extra_model)) is the way to go:

```python
with pm.Model() as model:
    Rs = pm.Normal("Rs", mu=0.8, sigma=0.02)
    Ms = pm.Normal("Ms", mu=0.78, sigma=0.02)
    P = pm.Normal("P", mu=12.6, sigma=0.01)
    t0 = pm.Normal("t0", mu=3.21, sigma=0.04)
    log_rprs = pm.Normal("log_rprs", mu=-4, sigma=3, initval=-2)
    rprs = pm.Deterministic("rprs", pm.math.exp(log_rprs))
    rpl = pm.Deterministic("rpl", rprs*Rs*109.2)
    b = xo.distributions.ImpactParameter("b", ror=rprs, initval=0.4)
    orb = xo.orbits.KeplerianOrbit(r_star=Rs, m_star=Ms, period=P, t0=t0, b=b)
    u_stars = xo.distributions.QuadLimbDark("u_star", testval=np.array([0.3, 0.2]))
    lightcurve = pm.Deterministic('lightcurve', 1000*xo.LimbDarkLightCurve(u_stars).get_light_curve(
        orbit=orb, r=rprs*Rs, t=t))
    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01))
    w0 = pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=(2*np.pi)/10, upper=(2*np.pi)/0.2, target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean_flux", mu=0.0, sigma=0.5*np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux, t=t, diag=flux_err**2 + pm.math.exp(logjit)**2)
    # lightcurve has shape (N, 1): sum over the planet axis, then add the GP log-likelihood as a Potential
    resid = flux - pm.math.sum(lightcurve, axis=1)
    loglik = pm.Potential("loglik", gp.log_likelihood(resid))
    gp_pred = pm.Deterministic("gp_pred", gp.predict(resid, return_var=False))

    # nomarg_trans_init_soln = pm.find_MAP()
    nomarg_trans_trace = pm.sample()

plt.plot(t, flux, '.', alpha=0.6)
plt.plot(t, np.nanmedian(nomarg_trans_trace.posterior['gp_pred'].values
                         + nomarg_trans_trace.posterior['lightcurve'].values[:, :, :, 0], axis=(0, 1)), '-')
plt.savefig("MWE_fit_nomarg_trans.png")
```

[Figure: MWE_fit_nomarg_trans, data with the median GP + transit model from the pm.Potential approach]

So that seems to fix it! However, I am apprehensive about relying on what is a bit of a black-box likelihood function: sometimes that doesn't play ball with ArviZ functions like WAIC, so any advice on calling pm.Normal or gp.marginal directly would still be useful, IMHO.

TylerFair commented:

I'm also coming across this issue; thanks @hposborn for the current fix! I imagine the only workaround is for marginal to call something like pm.Potential("loglik", pm.logp(pm.MvNormal.dist(<gp.marginal's vars>), y - transit_model))... It seems that PyMC >= 5 requires observed data to stay fixed, i.e. they can no longer depend on other model variables.

hposborn (author) commented Oct 8, 2024

My "fix" doesn't really work: I end up with compilation errors when running large models that way. So I'd love to know the official way of sampling with PyMC v5 using both a GP and additional models...

TylerFair commented:

Hi @hposborn, I believe the solution is indeed to model it the classical way (light curve + GP = observed), but to do so by setting the GP mean to your light curve. Using your code:

```python
with pm.Model() as model:
    # ... Rs, Ms, P, t0, rprs, b, orb and u_stars defined as in the model above ...
    # Sum over the planet axis so the GP mean has shape (N,) rather than (N, 1)
    lightcurve = pm.Deterministic('lightcurve', 1000*pm.math.sum(
        xo.LimbDarkLightCurve(u_stars).get_light_curve(orbit=orb, r=rprs*Rs, t=t), axis=1))
    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma("sigma", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01))
    w0 = pm.InverseGamma("w0", **pmx.utils.estimate_inverse_gamma_parameters(
        lower=(2*np.pi)/10, upper=(2*np.pi)/0.2, target=0.01))
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1/np.sqrt(2))

    mean_flux = pm.Normal("mean_flux", mu=0.0, sigma=0.5*np.nanstd(flux))
    # The transit enters through the GP mean, so `observed` stays constant
    gp = celerite2.pymc.GaussianProcess(kernel, mean=(lightcurve + mean_flux))
    gp.compute(t, diag=flux_err**2 + pm.math.exp(logjit)**2, quiet=True)

    pm.Deterministic('gp_pred', gp.predict(flux))

    gp.marginal('obs', observed=flux)
```

(Obviously, make sure that flux and lightcurve + mean_flux are centred around the same value.)

Let me know if this works! I believe this should be the solution for PyMC >= 5.
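Why moving the transit into the GP mean is equivalent to subtracting it: a Gaussian log-likelihood depends on the data and the mean only through their difference, so log N(flux | lightcurve + mean_flux, K) is the same number as log N(flux − lightcurve | mean_flux, K), but flux itself stays constant, as PyMC v5 requires of observed data. The numpy sketch below checks this with a dense Cholesky-based log-likelihood (celerite2 evaluates the same quantity in O(N) for its kernels); the box-shaped "transit" and all parameter values are hypothetical. It also shows a shape pitfall: the thread's code indexes the light curve with [:, 0] or sums over axis=1, which implies an (N, 1) array, and adding such an array to an (N,) vector silently broadcasts to (N, N).

```python
import numpy as np


def gaussian_loglik(r, C):
    """log N(r | 0, C) via a Cholesky factorisation (dense, O(N^3))."""
    L = np.linalg.cholesky(C)
    z = np.linalg.solve(L, r)  # whitened residuals (L is lower-triangular)
    return -0.5 * (z @ z) - np.sum(np.log(np.diag(L))) - 0.5 * len(r) * np.log(2 * np.pi)


def sho_kernel(t, sigma, w0):
    """Dense SHO covariance for Q = 1/sqrt(2) (celerite closed form; sigma is the GP std)."""
    tau = np.abs(t[:, None] - t[None, :])
    x = w0 * tau / np.sqrt(2)
    return sigma**2 * np.exp(-x) * (np.cos(x) + np.sin(x))


rng = np.random.default_rng(0)
t = np.linspace(0, 10, 200)
C = sho_kernel(t, sigma=0.3, w0=2.0) + 0.15**2 * np.eye(t.size)

# Hypothetical box-shaped "transit" with a trailing planet axis, shape (N, 1)
lightcurve = np.where(np.abs(t - 5.0) < 0.2, -1.3, 0.0)[:, None]
mean_flux = 0.05
flux = lightcurve[:, 0] + mean_flux + rng.multivariate_normal(np.zeros(t.size), C)

# Pitfall: adding the (N, 1) light curve to an (N,) vector broadcasts to (N, N)
bad_shape = (lightcurve + np.zeros(t.size)).shape
lc = lightcurve.sum(axis=1)  # flatten to (N,), like pm.math.sum(..., axis=1)

# (1) residual formulation: data = flux - lc, GP mean = mean_flux
ll_residual = gaussian_loglik((flux - lc) - mean_flux, C)
# (2) mean formulation: data = flux, GP mean vector = lc + mean_flux
ll_mean = gaussian_loglik(flux - (lc + mean_flux), C)

print(bad_shape, ll_residual, ll_mean)
```

Because only the residual enters, the mean formulation should give the same posterior as the pm.Potential approach while keeping flux as genuine observed data; whether it also avoids the compilation problems reported above for large models is not something this sketch can show.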
