GP mean prediction as the mean function in another GP #122

Open
vedad opened this issue May 16, 2024 · 0 comments

Hi Dan and celerite team,

I have the following versions:

exoplanet.__version__ = '0.6.0'
celerite2.__version__ = '0.3.1'
pymc.__version__ = '5.10.4'

The goal

I want to model in-transit spot occultations with celerite, and that GP should be conditioned only on in-transit data. A separate GP with a different kernel and hyperparameters describes a longer-timescale correlated signal. I assume they have to be separate GPs because the two data sets have different lengths.
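To make the "different lengths" assumption concrete, here is a minimal dense NumPy sketch of one alternative: a single GP over all points whose covariance is the sum of a trend term and a spot term that is zeroed for any pair involving an out-of-transit point. This is not celerite2 code, and it does not assume celerite2 can represent the masked term; the `matern32` helper, the toy data (no transit model), and all hyperparameter values are illustrative assumptions.

```python
import numpy as np

def matern32(t1, t2, sigma, rho):
    """Dense Matern-3/2 covariance between two sets of times."""
    x = np.sqrt(3.0) * np.abs(t1[:, None] - t2[None, :]) / rho
    return sigma**2 * (1.0 + x) * np.exp(-x)

rng = np.random.default_rng(123)
t = np.arange(-0.2, 0.2, 2 / 60 / 24)
in_transit = np.abs(t) < 0.5 * 0.15  # same window as the transit mask M
yerr = 3e-4

# Toy data: offset, linear trend, white noise, and one in-transit bump.
y = 1.0 + 0.01 * t + yerr * rng.standard_normal(len(t))
y += 0.002 * np.exp(-((t + 0.005) / 0.008) ** 2)

# Trend term acts on every point (hyperparameters are made up).
K_trend = matern32(t, t, sigma=0.01, rho=1.0)

# Spot term only correlates pairs where both points are in transit.
pair_mask = np.outer(in_transit, in_transit)
K_spot = matern32(t, t, sigma=0.002, rho=0.02) * pair_mask

# One GP, one data vector: solve once, then split the prediction by term.
K = K_trend + K_spot + yerr**2 * np.eye(len(t))
alpha = np.linalg.solve(K, y - 1.0)
trend_pred = K_trend @ alpha
spot_pred = K_spot @ alpha
```

In this construction `spot_pred` is identically zero outside the transit window, so the spot component is informed only by in-transit data even though everything lives in one GP of full length. The cost is a dense O(N^3) solve, which is fine for a few hundred points but gives up celerite's scaling.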

Expected result

For a single GP (gp1), I followed the exoplanet examples: I created a function that returns the transit light curve and passed it to the mean keyword of celerite2.pymc.GaussianProcess. For a second GP (gp2), I expected to be able to create a mean function that is the sum of a light curve model and gp1.predict(). However, this doesn't work: it throws AttributeError: '_CeleriteOp' object has no attribute 'rev_op'. If I add an .eval(), it runs, but the maximum likelihood model isn't a good fit, so I suspect something goes wrong behind the scenes.

Is it possible to use the mean of a GP as a mean function in another GP?
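For reference, the two-stage idea in the question can be written out with plain dense linear algebra. The sketch below is NumPy only, with fixed hyperparameters and no transit model; `matern32`, `gp_predict`, the toy data, and every numeric value are assumptions made for illustration, not celerite2 or exoplanet API.

```python
import numpy as np

def matern32(t1, t2, sigma, rho):
    """Dense Matern-3/2 covariance between two sets of times."""
    x = np.sqrt(3.0) * np.abs(t1[:, None] - t2[None, :]) / rho
    return sigma**2 * (1.0 + x) * np.exp(-x)

def gp_predict(t, resid, sigma, rho, diag):
    """Conditional mean of a zero-mean GP, evaluated at the training times."""
    K = matern32(t, t, sigma, rho)
    return K @ np.linalg.solve(K + diag * np.eye(len(t)), resid)

rng = np.random.default_rng(123)
t = np.arange(-0.2, 0.2, 2 / 60 / 24)
yerr = 3e-4
M = np.abs(t) < 0.5 * 0.15  # transit mask

# Toy data: offset, linear trend, white noise, and one in-transit bump.
y = 1.0 + 0.01 * t + yerr * rng.standard_normal(len(t))
y += 0.002 * np.exp(-((t + 0.005) / 0.008) ** 2)

# Stage 1: long-timescale GP conditioned on all points.
mean1 = np.ones_like(t)
trend = gp_predict(t, y - mean1, sigma=0.01, rho=1.0, diag=yerr**2)

# Stage 2: the spot GP's mean is the stage-1 mean plus the stage-1
# prediction, both restricted to the in-transit points.
mean2 = mean1[M] + trend[M]
resid2 = y[M] - mean2
spot = gp_predict(t[M], resid2, sigma=0.002, rho=0.02, diag=yerr**2)
```

Two caveats on this sketch. Both stages are conditioned on the same y, so the in-transit points are used twice and the result is not a joint probabilistic model. And in the PyMC version, calling .eval() turns the symbolic prediction into a fixed NumPy array, so the optimizer no longer sees how the second GP's mean depends on the model parameters; that would be consistent with the poor fit described above, though it is not verified here.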

Here's what I'm working with so far:

import numpy as np
import matplotlib.pyplot as plt
from functools import partial

import pymc as pm
import pymc_ext as pmx
import exoplanet as xo
import pytensor.tensor as pt
from celerite2.pymc import GaussianProcess, terms

np.random.seed(123)
period = np.random.uniform(3,10)
t = np.arange(-0.2, 0.2, 2/60/24)

# The light curve calculation requires an orbit
orbit = xo.orbits.KeplerianOrbit(period=period, t0=0, b=0, duration=0.15, ror=0.1)

# Compute a limb-darkened light curve using starry
u = [0.3, 0.2]
light_curve = np.sum(
    xo.LimbDarkLightCurve(u[0], u[1])
    .get_light_curve(orbit=orbit, r=0.1, t=t, texp=2/60/24)
    .eval(),
    axis=-1
)

# Create simulated data
yerr = 3e-4
y = light_curve
M = (t > -0.5*0.15) & (t < 0.5*0.15) # transit mask
y += yerr * np.random.randn(len(y)) # add noise
y += 0.01*t # add linear term
y += 1

# add some spot occultations
locs = [-0.005, 0.03]
widths = [0.008, 0.01]
amps = [0.002, 0.001]

for i in range(len(locs)):
    m = (t > (locs[i]-widths[i])) & (t < (locs[i]+widths[i]))
    y[m] += amps[i] * np.exp(-(t[m]-locs[i])**2/widths[i]**2)


with pm.Model() as model:
    mean = pm.Normal("mean", mu=1, sigma=0.002, initval=1)

    # The time of a reference transit for each planet
    t0 = pm.Normal("t0", mu=0, sigma=0.01, initval=0)

    u = xo.quad_limb_dark("u", initval=[0.3, 0.2])

    log_dur = pm.Normal("log_dur", mu=np.log(0.13), sigma=0.1, initval=np.log(0.13))
    dur = pm.Deterministic("dur", pt.exp(log_dur))

    log_ror = pm.Normal("logr", mu=np.log(0.1), sigma=0.1, initval=np.log(0.1))
    ror = pm.Deterministic("r", pt.exp(log_ror))
    
    b = xo.impact_parameter("b", ror=ror, initval=0.3)

    star = xo.LimbDarkLightCurve(u[0], u[1])

    # Set up a Keplerian orbit for the planets
    orbit = xo.orbits.KeplerianOrbit(period=period, t0=t0, b=b, duration=dur, ror=ror)

    # Compute the model light curve using starry
    def _mean_fn(orbit, mean, r, star, t):
        return pt.sum(
            star.get_light_curve(orbit=orbit, r=r, t=t, texp=2/60/24),
            axis=-1,
        ) + mean
    mean_fn = partial(_mean_fn, orbit, mean, ror, star)
    pm.Deterministic("light_curves", mean_fn(t))

    # GP parameters for the linear trend and white noise
    log_sigma = pm.Normal("log_sigma", mu=np.log(0.5*yerr), sigma=0.1)
    sigma = pm.Deterministic("sigma", pt.exp(log_sigma))

    log_rho_gp = pm.Normal("log_rho_gp", mu=7, sigma=0.5, initval=7)
    rho_gp = pm.Deterministic("rho_gp", pt.exp(log_rho_gp))

    log_sigma_gp = pm.Normal("log_sigma_gp", mu=-4, sigma=0.5, initval=-4)
    sigma_gp = pm.Deterministic("sigma_gp", pt.exp(log_sigma_gp))

    kernel = terms.Matern32Term(rho=rho_gp, sigma=sigma_gp)

    gp = GaussianProcess(kernel, t=t, diag=yerr**2 + sigma**2,
                         mean=mean_fn, quiet=True)
    pm.Deterministic("gp_preds", gp.predict(y, include_mean=False))
    
    gp.marginal("obs", observed=y)

    ######################################################################
    # problematic part
    ###################################################################### 
    # GP parameters for spot occultations
    log_sigma_spot = pm.Normal("log_sigma_spot", mu=-10, sigma=5, initval=-10)
    sigma_spot = pm.Deterministic("sigma_spot", pt.exp(log_sigma_spot))

    log_rho_spot = pm.Normal("log_rho_spot", mu=np.log(0.02), sigma=0.5)
    rho_spot = pm.Deterministic("rho_spot", pt.exp(log_rho_spot))

    kernel2 = terms.Matern32Term(rho=rho_spot, sigma=sigma_spot)

    def _mean_fn_spot(gp, orbit, star, mean, y, r, t):

        gp_pred = gp.predict(y, t=t, include_mean=False).eval()
        lc_pred = (pt.sum(star.get_light_curve(
                orbit=orbit, r=r, t=t, texp=2/60/24),
                axis=-1
                ) + mean).eval()
        return pt.as_tensor_variable(lc_pred+gp_pred)

    spot_fn = partial(_mean_fn_spot, gp, orbit, star, mean, y, ror)

    gp2 = GaussianProcess(kernel2, t=t[M], diag=yerr**2 + sigma**2,
                         mean=spot_fn, quiet=False)
    pm.Deterministic("gp_preds_spot", gp2.predict(y[M], include_mean=False))
    
    gp2.marginal("obs_spot", observed=y[M])
    ###################################################################### 

    map_soln = pmx.optimize(start=model.initial_point())


# plot fit
spot_model = np.zeros_like(t)
spot_model[M] = map_soln["gp_preds_spot"]
full_mod = map_soln["light_curves"]+map_soln["gp_preds"]+spot_model

plt.figure()
plt.plot(t, y, ".k", ms=4, label="data")
plt.plot(t, full_mod, lw=1, label="full model")
plt.plot(t, map_soln["light_curves"], lw=1, ls='--', label="transit")
plt.plot(t, map_soln["gp_preds"]+map_soln["mean"], lw=1, ls=':', label="trend")
plt.plot(t, spot_model+map_soln["mean"], lw=1, ls='-.', label="spot")
plt.xlim(t.min(), t.max())
plt.ylabel("relative flux")
plt.xlabel("time [days]")
plt.legend(fontsize=10)
_ = plt.title("map model")
plt.show()

vedad changed the title from "A Gaussian process mean prediction as the mean function in another GP" to "GP mean prediction as the mean function in another GP" on May 16, 2024