Implementing a simple exoplanet transit model mean function #94

Closed
vatsalpanwar opened this issue May 18, 2022 · 3 comments

@vatsalpanwar

I am following the mean function tutorial from the tinygp docs to implement an exoplanet transit model using the batman package, but I am running into an issue when vmapping the mean model. I am new to JAX, so there may be an easy solution. It is not clear to me how to implement a mean function that needs X_grid as a NumPy array when tinygp requires the mean function to operate on a single input coordinate. As a workaround, I tried taking the val attribute of X_grid inside my transit model calculator, but the output still does not have the desired shape after vmapping. I think this is because my function operates on the whole X_grid array rather than on a single input coordinate. Do you have any advice on how I can adapt my mean function to operate on a single input coordinate, as required by JAX? The code snippet below illustrates the issue:

import numpy as np 
import matplotlib.pyplot as plt
import batman 

import jax
import jax.numpy as jnp

import tinygp
import jaxopt

from functools import partial
from tinygp import kernels, GaussianProcess

jax.config.update("jax_enable_x64", True)

### Keys of all transit parameters
all_transit_params = ['per', 'ecc', 'w', 'limb_dark', 't0', 'a', 'inc', 'u','rp']

### Dictionary of fixed transit parameters 
fix_transit_params = {
    "per": 3.52474859,
    "ecc": 0.000,
    "w": 90.,
    "limb_dark": 'quadratic',
    "t0": 0.,
    "a": 8.807,
    "inc": 86.744,
    "u": [0.41739267, 0.15299455],
}

### Dictionary of free transit parameters 
free_transit_params = {
    "rp": 0.12,
    "mean_level": 0.
}

### GP hyperparameters
params_gp = {
    "log_gp_amp": np.log(0.1),
    "log_gp_scale": np.log(3.0),
    "log_gp_diag": np.log(0.001),
}

### Merge the free transit params and the GP params 
free_params_all = {**free_transit_params, **params_gp}

### Function to call batman and calculate exoplanet transit model 
def tmodel(params,time):
    params_bat = batman.TransitParams()
    params_bat.__dict__.update(params)
    flux_bat = batman.TransitModel(params_bat, time.val ).light_curve(params_bat)
    return flux_bat
    

### Function to calculate the GP mean function 
def mean_function(params,X):
    """
    params: Dictionary of all free parameters.
    X: (X_grid) here time
    """
    time = X
    mean_level = params["mean_level"]
    
    params_in_func = {}
    
    for tparam in all_transit_params:
        if tparam in fix_transit_params.keys():
            params_in_func[tparam] = fix_transit_params[tparam]
        else:
            params_in_func[tparam] = params[tparam]
    
    lc = tmodel(params_in_func,time) + mean_level
        
    return lc

transit_duration = 3./24. ### in days 
X_grid = np.linspace(-transit_duration*2., transit_duration*2., 100)
print("X_grid shape", X_grid.shape)

model = jax.vmap(partial(mean_function, free_params_all))(X_grid)
print("model shape", model.shape) 

The output of the above is:

X_grid shape (100,)
model shape (100, 100)
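
A likely reading of the (100, 100) shape: under jax.vmap the function is called once per batch element in spirit, but here it returns a full-grid light curve of length 100 regardless of which element it was given, so vmap stacks 100 copies of that array. The sketch below reproduces the same shapes in pure JAX, without batman. The names whole_grid_model and per_point_model are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp

X_grid = jnp.linspace(-0.25, 0.25, 100)

def whole_grid_model(t):
    # Ignores the single coordinate t and returns a value for every grid
    # point, mimicking a model that is evaluated on the whole array at once.
    return jnp.ones(X_grid.shape[0])

def per_point_model(t):
    # Depends only on the single coordinate t and returns a scalar.
    return 1.0 + 0.0 * t

stacked = jax.vmap(whole_grid_model)(X_grid)
per_point = jax.vmap(per_point_model)(X_grid)

print("whole-grid model shape", stacked.shape)
print("per-point model shape", per_point.shape)
```

Only the second function has the scalar-in, scalar-out form that vmapping over X_grid turns into a length-100 output.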
@dfm
Owner

dfm commented May 18, 2022

Unfortunately batman is never going to be compatible with JAX, so a function that calls out to batman can't be vmappable. Instead you'll need to use a JAX-compatible library (I like exoplanet_jax, or exoplanet_core.jax, both of which are pretty experimental and undocumented :/) to get all the benefits of tinygp.
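
To illustrate what a JAX-compatible mean function could look like, here is a minimal sketch that replaces the batman call with a toy transit shape written in jax.numpy. The toy_transit function, its smooth box profile, and the "duration" and "edge" parameters are assumptions made for this example only. It is not the batman model, and it is not the API of exoplanet_jax or exoplanet_core.jax.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def toy_transit(params, t):
    # Hypothetical stand-in for a real transit model: a box-shaped dip of
    # depth rp**2 with sigmoid-smoothed ingress and egress. No limb darkening.
    depth = params["rp"] ** 2
    half_duration = 0.5 * params["duration"]
    dt = t - params["t0"]
    ingress = jax.nn.sigmoid((dt + half_duration) / params["edge"])
    egress = jax.nn.sigmoid((dt - half_duration) / params["edge"])
    return 1.0 - depth * (ingress - egress)

def mean_function(params, t):
    # Operates on a single input coordinate t and returns a scalar.
    return toy_transit(params, t) + params["mean_level"]

params = {
    "rp": 0.12,
    "t0": 0.0,
    "duration": 3.0 / 24.0,  # days (illustrative value)
    "edge": 0.002,           # days, smoothing scale (illustrative value)
    "mean_level": 0.0,
}

X_grid = jnp.linspace(-0.25, 0.25, 100)
model = jax.vmap(partial(mean_function, params))(X_grid)
print("model shape", model.shape)
```

Because every operation here is a JAX primitive acting on one coordinate, vmapping over X_grid gives one flux value per time. A callable like partial(mean_function, params) should then fit the single-coordinate form that the tinygp mean function tutorial expects.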

I'd be happy to move this conversation to email (I know I've been slow to respond and that'll continue for the next 2 weeks, but eventually!) because this is pretty domain specific and not so suited to this issue tracker.

@dfm dfm closed this as completed May 18, 2022
@vatsalpanwar
Author

Thanks! I did suspect that bringing a non-JAX library in here would probably not be well suited (or even work!). And no worries, happy to continue the discussion via email, please take your time.

@dfm
Owner

dfm commented May 18, 2022

I've opened a more general issue over in #96 because it would be good to demonstrate how to do something like this in general in the docs!
