I am following the mean function tutorial from the tinygp docs to implement an exoplanet transit model using the batman package, but I am running into an issue when vmapping the mean model. I am new to JAX, so there may be an easy solution. It is not clear to me how to implement a mean function that needs X_grid as a NumPy array when tinygp requires the mean function to operate on a single input coordinate. As a workaround, I tried taking the `val` attribute of X_grid in my transit model calculator, but I still don't get the expected output shape after vmapping. I think this is because my function operates on the whole X_grid array rather than on a single input coordinate. Do you have any advice on how I can adapt my mean function to operate on a single input coordinate, as tinygp expects? The code snippet below illustrates the issue:
```python
import numpy as np
import matplotlib.pyplot as plt
import batman
import jax
import jax.numpy as jnp
import tinygp
import jaxopt
from functools import partial
from tinygp import kernels, GaussianProcess

jax.config.update("jax_enable_x64", True)

### Keys of all transit parameters
all_transit_params = ['per', 'ecc', 'w', 'limb_dark', 't0', 'a', 'inc', 'u', 'rp']

### Dictionary of fixed transit parameters
fix_transit_params = {
    "per": 3.52474859,
    "ecc": 0.000,
    "w": 90.,
    "limb_dark": 'quadratic',
    "t0": 0.,
    "a": 8.807,
    "inc": 86.744,
    "u": [0.41739267, 0.15299455],
}

### Dictionary of free transit parameters
free_transit_params = {
    "rp": 0.12,
    "mean_level": 0.
}

### GP hyperparameters
params_gp = {
    "log_gp_amp": np.log(0.1),
    "log_gp_scale": np.log(3.0),
    "log_gp_diag": np.log(0.001),
}

### Merge the free transit params and the GP params
free_params_all = {**free_transit_params, **params_gp}

### Function to call batman and calculate exoplanet transit model
def tmodel(params, time):
    params_bat = batman.TransitParams()
    params_bat.__dict__.update(params)
    # Workaround: `.val` pulls the full batched array out of the vmap tracer
    flux_bat = batman.TransitModel(params_bat, time.val).light_curve(params_bat)
    return flux_bat

### Function to calculate the GP mean function
def mean_function(params, X):
    """
    params: Dictionary of all free parameters.
    X: (X_grid) here time
    """
    time = X
    mean_level = params["mean_level"]
    params_in_func = {}
    for tparam in all_transit_params:
        if tparam in fix_transit_params.keys():
            params_in_func[tparam] = fix_transit_params[tparam]
        else:
            params_in_func[tparam] = params[tparam]
    lc = tmodel(params_in_func, time) + mean_level
    return lc

transit_duration = 3. / 24.  ### in days
X_grid = np.linspace(-transit_duration * 2., transit_duration * 2., 100)
print("X_grid shape", X_grid.shape)
model = jax.vmap(partial(mean_function, free_params_all))(X_grid)
print("model shape", model.shape)
```
The output of the above is:

```
X_grid shape (100,)
model shape (100, 100)
```
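The (100, 100) shape can be reproduced without batman. The sketch below uses two hypothetical stand-in functions: `numpy_only` mimics a library that converts its input to a NumPy array, which fails on a vmap tracer, and `ignores_input` mimics the `.val` workaround, returning a full-length array that no longer depends on the per-element input. It assumes batman needs concrete NumPy values, which is consistent with the behavior above. `jax.vmap` treats an output that does not depend on the mapped input as unbatched and broadcasts it along the mapped axis, so a (100,) light curve becomes (100, 100).

```python
import numpy as np
import jax

# Same grid as in the question (in days)
transit_duration = 3.0 / 24.0
X_grid = np.linspace(-transit_duration * 2.0, transit_duration * 2.0, 100)


def numpy_only(x):
    # Hypothetical stand-in for a library (like batman) that needs a concrete
    # NumPy array: under vmap, x is a tracer and cannot be converted.
    return np.asarray(x) * 2.0


try:
    jax.vmap(numpy_only)(X_grid)
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True


def ignores_input(x):
    # Hypothetical stand-in for the `.val` workaround: the returned array has
    # shape (100,) and does not depend on the per-element tracer x.
    return np.ones(100)


# vmap broadcasts the unbatched (100,) output along the mapped axis
out = jax.vmap(ignores_input)(X_grid)
print(conversion_failed, out.shape)  # True (100, 100)
```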
Unfortunately, batman is never going to be compatible with JAX, so a function that calls out to batman can't be vmapped. Instead, you'll need to use a JAX-compatible library (I like exoplanet_jax or exoplanet_core.jax, both of which are pretty experimental and undocumented :/) to get all the benefits of tinygp.
I'd be happy to move this conversation to email (I know I've been slow to respond, and that will continue for the next 2 weeks, but I'll get there eventually!) because this is pretty domain-specific and not well suited to this issue tracker.
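To illustrate the JAX-compatible route suggested above, here is a minimal sketch of a mean function written entirely in `jax.numpy` that takes a single time value, so `jax.vmap` can map it over the grid and return one flux per time. The transit shape is a hypothetical toy: a box-shaped dip of depth `rp**2` with tanh-smoothed edges and a baseline of `1 + mean_level`. It has no limb darkening or orbital geometry, so it is not a substitute for batman or the libraries named above. The `duration` and `edge` parameters are illustrative assumptions.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_transit_mean(params, t):
    # Hypothetical toy transit (NOT batman): a smoothed box-shaped dip.
    # `t` is a single time value, as tinygp expects of a mean function.
    depth = params["rp"] ** 2
    half = 0.5 * params["duration"]  # half of the total transit duration
    edge = params["edge"]  # hypothetical smoothing width, in days
    x = jnp.abs(t - params["t0"])
    # window is ~1 inside |t - t0| < half and ~0 outside
    window = 0.5 * (1.0 - jnp.tanh((x - half) / edge))
    return 1.0 + params["mean_level"] - depth * window


transit_duration = 3.0 / 24.0  # in days
X_grid = jnp.linspace(-transit_duration * 2.0, transit_duration * 2.0, 100)
params = {
    "rp": 0.12,
    "mean_level": 0.0,
    "t0": 0.0,
    "duration": transit_duration,
    "edge": 0.005,
}

# Mapping over single time values gives one flux per time
model = jax.vmap(partial(toy_transit_mean, params))(X_grid)
print("model shape", model.shape)  # model shape (100,)

# Being pure JAX, the model is also differentiable with respect to rp
d_depth = jax.grad(lambda rp: toy_transit_mean({**params, "rp": rp}, 0.0))(0.12)
```

Because the toy model is built only from JAX operations, it can be vmapped, jitted, and differentiated, which is what gradient-based fitting with tinygp needs. Following the pattern of the tinygp mean-function tutorial, it could then be passed as `mean=partial(toy_transit_mean, params)` when constructing a `GaussianProcess`.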
Thanks! I did suspect that mixing a non-JAX library in here would not be the best approach (or even work!). And no worries, I'm happy to continue the discussion via email; please take your time.