Skip to content

Commit

Permalink
Added a timeout option for the PyROA module, which will cut off the M…
Browse files Browse the repository at this point in the history
…CMC and continue to the next line if the specified runtime is reached. Also added an option for the solver to use for the DRW fitting when determining the initial conditions for the MCMC.
  • Loading branch information
Zstone19 committed Dec 29, 2024
1 parent d3ec8b2 commit 9c19e22
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 42 deletions.
8 changes: 8 additions & 0 deletions docs/pl_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ Module: DRW Rejection (``run_drw_rej``)
- If :python:`True`, the resulting DRW parameters :math:`(\sigma_{\rm DRW}, $\tau_{\rm DRW})`, will used as input to the JAVELIN module of pyPetal. The DRW parameters in each fit will be fixed to the results obtained in this module.
- :python:`bool`
- :python:`False`
* - ``solver``
- How the module will obtain the initial conditions used in the DRW fitting process. ``minimize`` uses ``scipy.optimize.minimize``, ``diff_evo`` uses ``scipy.optimize.differential_evolution``, and ``none`` randomly selects using the uniform priors.
- :python:`str`
- :python:`"none"`



Expand Down Expand Up @@ -286,6 +290,10 @@ Module: pyROA (``run_pyroa``)
- A function used to get the priors for PyROA. Must have the same arguments as ``pypetal.pyroa.utils.get_priors`` except for the ``delimiter`` argument. If :python:`None`, will use the default priors.
- :python:`function`, :python:`None`
- :python:`None`
* - ``timeout``
- The maximum time to run the PyROA analysis in seconds.
- :python:`int`
- :python:`10800` (3 hours)



Expand Down
2 changes: 2 additions & 0 deletions docs/pl_output.rst
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ The ``MyFit`` object will have the following attributes:
- The error in the driving light curve model fit.
- list of :python:`float`

.. note:: If the PyROA module times out, the output will be :python:`None`.


Module: MICA2
--------------
Expand Down
7 changes: 4 additions & 3 deletions src/pypetal/drw_rej/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def drw_rej_tot(cont_fname, line_fnames, line_names, output_dir,
#Read kwargs

jitter, nsig, nwalker, nburn, nchain, clip, \
reject_data, use_for_javelin = defaults.set_drw_rej(kwargs,
reject_data, use_for_javelin, solver = defaults.set_drw_rej(kwargs,
np.hstack([ [cont_fname], line_fnames ])
)

Expand Down Expand Up @@ -83,7 +83,8 @@ def drw_rej_tot(cont_fname, line_fnames, line_names, output_dir,
'nchain': nchain,
'clip': clip_str,
'reject_data': reject_data,
'use_for_javelin': use_for_javelin
'use_for_javelin': use_for_javelin,
'solver': solver
}

print_subheader('Performing DRW Rejection', 35, print_dict)
Expand All @@ -100,7 +101,7 @@ def drw_rej_tot(cont_fname, line_fnames, line_names, output_dir,


drw_rej_func = partial( drw_flag, nwalkers=nwalker, nburn=nburn, nsamp=nchain,
nsig=nsig, jitter=jitter, plot=plot)
nsig=nsig, jitter=jitter, solver=solver, plot=plot)



Expand Down
34 changes: 30 additions & 4 deletions src/pypetal/drw_rej/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
###################### ASSIST FUNCTIONS ######################
##############################################################

def celerite_fit(x, y, yerr, kernel, nwalkers, nburn, nsamp,
def celerite_fit(x, y, yerr, kernel, nwalkers, nburn, nsamp, mbh_est=None,
solver='minimize', suppress_warnings=True, jitter=True):

"""Fit time-series data to a given Gaussian process kernel using celerite.
Expand Down Expand Up @@ -139,6 +139,31 @@ def log_probability(params):
soln = differential_evolution(neg_ll, bounds=bounds, args=(y.value, gp))
initial = np.array(soln.x)

#Use mass as an initial guess
elif solver == 'mbh':
assert (mbh_est is not None)
tau_est = 107* (mbh_est/1e8)**(.38) #d

c_est = 1/tau_est
loga_est = np.random.uniform(bounds[0][0], bounds[0][1])

initial = [a_est, np.log(c_est)]
if jitter:
logn_est = np.random.uniform(bounds[2][0], bounds[2][1])
initial.append(logn_est)

#Use the priors
elif solver == 'none':
loga_est = np.random.uniform(bounds[0][0], bounds[0][1])
logc_est = np.random.uniform(bounds[1][0], bounds[1][1])
initial = [loga_est, logc_est]

if jitter:
logn_est = np.random.uniform(bounds[2][0], bounds[2][1])
initial.append(logn_est)


initial = np.array(initial)

#Set parameter vector to fit params
gp.set_parameter_vector(initial)
Expand Down Expand Up @@ -301,8 +326,9 @@ def MCMC_fit(x, y, yerr, nwalkers=32, nburn=300, nsamp=1000,
bounds = dict(log_sigma=(smin, smax))
kernel += terms.JitterTerm(log_sigma=sval, bounds=bounds)

mbh_est = None
samples, gp, statuses = celerite_fit(x, y, yerr, kernel, nwalkers,
nburn, nsamp, solver,
nburn, nsamp, mbh_est, solver,
suppress_warnings, jitter)

return samples, gp, statuses
Expand Down Expand Up @@ -690,7 +716,7 @@ def psd_data(x, y, yerr, samples, gp, nsamp=20):
def drw_flag(times, data, error,
target=None, fname=None,
nwalkers=32, nburn=300, nsamp=1000,
nsig=1, jitter=True, clip=True,
nsig=1, jitter=True, clip=True, solver='none',
plot=True):


Expand Down Expand Up @@ -791,7 +817,7 @@ def drw_flag(times, data, error,
#Fit to the DRW model
samples, gp, statuses = MCMC_fit(times, data, error,
nwalkers=nwalkers, nburn=nburn, nsamp=nsamp,
jitter=jitter, clip=clip)
jitter=jitter, solver=solver, clip=clip)


fig, ax = plot_outcome(times, data, error, samples, gp, data_unit,
Expand Down
6 changes: 3 additions & 3 deletions src/pypetal/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ def run_pipeline(output_dir, arg2,


#Get "reject_data"
_, _, _, _, _, _, reject_data, _ = defaults.set_drw_rej(drw_rej_params, fnames)
_, _, _, _, _, _, reject_data, _, _ = defaults.set_drw_rej(drw_rej_params, fnames)

#Get "together_pyroa"
_, _, _, _, _, _, _, _, together_pyroa, _, _ = defaults.set_pyroa(pyroa_params, len(fnames))
_, _, _, _, _, _, _, _, together_pyroa, _, _, _ = defaults.set_pyroa(pyroa_params, len(fnames))

#Get "together_mica2"
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, together_mica2, no_order_mica2 = defaults.set_mica2(mica2_params)
Expand Down Expand Up @@ -462,7 +462,7 @@ def run_weighting(output_dir, line_names,
together_jav = False

#Get "together" for pyroa
_, _, _, _, _, _, _, _, together_pyroa, _, _ = defaults.set_pyroa( pyroa_params, len(line_names) )
_, _, _, _, _, _, _, _, together_pyroa, _, _, _ = defaults.set_pyroa( pyroa_params, len(line_names) )

#Get "together" for mica2
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, together_mica2, _ = defaults.set_mica2(mica2_params)
Expand Down
15 changes: 12 additions & 3 deletions src/pypetal/pyroa/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def pyroa_tot(cont_fname, line_fnames, line_names, output_dir,

nchain, nburn, init_tau, subtract_mean, div_mean, \
add_var, delay_dist, psi_types, together, \
objname, prior_func = defaults.set_pyroa( kwargs, len(line_names) )
objname, prior_func, timeout = defaults.set_pyroa( kwargs, len(line_names) )

if verbose:

Expand All @@ -38,7 +38,8 @@ def pyroa_tot(cont_fname, line_fnames, line_names, output_dir,
'delay_dist': delay_dist,
'psi_types': psi_types,
'together': together,
'objname': objname
'objname': objname,
'timeout': timeout
}
print_subheader('Running PyROA', 35, print_dict)

Expand All @@ -58,11 +59,15 @@ def pyroa_tot(cont_fname, line_fnames, line_names, output_dir,
together=together, subtract_mean=subtract_mean,
div_mean=div_mean, add_var=add_var,
delay_dist=delay_dist, psi_types=psi_types,
objname=objname, prior_func=prior_func, verbose=verbose)
objname=objname, prior_func=prior_func, timeout=timeout,
verbose=verbose)

lc_fnames = [ lc_dir + objname + '_' + x + '.dat' for x in line_names ]

if together:
if res is None:
return res

pyroa_trace_plot( res.samples, line_names, add_var=add_var,
delay_dist=delay_dist, nburn=nburn,
fname = output_dir + 'pyroa/trace_plot.pdf',
Expand Down Expand Up @@ -97,6 +102,10 @@ def pyroa_tot(cont_fname, line_fnames, line_names, output_dir,
else:

for i, res_i in enumerate(res):
if res_i is None:
continue


names_i = [ line_names[0], line_names[i+1] ]
fnames_i = [ lc_fnames[0], lc_fnames[i+1] ]

Expand Down
87 changes: 62 additions & 25 deletions src/pypetal/pyroa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import numpy as np
import PyROA
from astropy.table import Table
from pypetal.utils.petalio import print_error

import signal

##############################################################
####################### SILENCE OUTPUT #######################
Expand Down Expand Up @@ -451,7 +454,8 @@ def run_pyroa(fnames, lc_dir, line_dir, line_names,
init_tau=None, init_delta=10, sig_level=100,
together=True, subtract_mean=True, div_mean=False,
add_var=False, delay_dist=False, psi_types='Gaussian',
objname=None, prior_func=None, verbose=True):
objname=None, prior_func=None, timeout=60*60*3,
verbose=True):


"""Run PyROA for a number of input light curves.
Expand Down Expand Up @@ -530,6 +534,9 @@ def run_pyroa(fnames, lc_dir, line_dir, line_names,
"""

def handler(signum, frame):
raise Exception("Timed out")


if objname is None:
objname = 'pyroa'
Expand Down Expand Up @@ -600,28 +607,44 @@ def run_pyroa(fnames, lc_dir, line_dir, line_names,
kwargs = {'add_var':add_var[i], 'init_tau':[init_tau[i]], 'init_delta':init_delta, 'sig_level':sig_level,
'delay_dist':delay_dist[i], 'psi_types':[psi_types[i]], 'Nsamples':nchain, 'Nburnin':nburn}

if verbose:
proc = mp.get_context('fork').Process(target=PyROA.Fit, args=args, kwargs=kwargs)

proc.start()
while proc.is_alive():
proc.is_alive()
proc.terminate()

else:
with suppress_stdout_stderr():
try:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)

if verbose:
proc = mp.get_context('fork').Process(target=PyROA.Fit, args=args, kwargs=kwargs)

proc.start()
while proc.is_alive():
proc.is_alive()
proc.terminate()

else:
with suppress_stdout_stderr():
proc = mp.get_context('fork').Process(target=PyROA.Fit, args=args, kwargs=kwargs)

proc.start()
while proc.is_alive():
proc.is_alive()
proc.terminate()


move_output_files(cwd, line_dir[i])
fit = MyFit(line_dir[i])
move_output_files(cwd, line_dir[i])
fit = MyFit(line_dir[i])

fit_arr.append(fit)
fit_arr.append(fit)


signal.alarm(0)

except Exception as e:
proc.terminate()

print_error('PyROA timed out for line: {}'.format(line_names[i+1]))
print_error('Skipping and continuing to next line')

fit_arr.append(None)
continue

return fit_arr

Expand All @@ -634,24 +657,38 @@ def run_pyroa(fnames, lc_dir, line_dir, line_names,
kwargs = {'add_var':add_var, 'init_tau':init_tau, 'init_delta':init_delta, 'sig_level':sig_level,
'delay_dist':delay_dist, 'psi_types':psi_types, 'Nsamples':nchain, 'Nburnin':nburn}

if verbose:
proc = mp.get_context('fork').Process(target=PyROA.Fit, args=args, kwargs=kwargs)

proc.start()
while proc.is_alive():
proc.is_alive()
proc.terminate()
try:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)

else:
with suppress_stdout_stderr():
if verbose:
proc = mp.get_context('fork').Process(target=PyROA.Fit, args=args, kwargs=kwargs)

proc.start()
while proc.is_alive():
proc.is_alive()
proc.terminate()

move_output_files(cwd, line_dir)
fit = MyFit(line_dir)
else:
with suppress_stdout_stderr():
proc = mp.get_context('fork').Process(target=PyROA.Fit, args=args, kwargs=kwargs)

proc.start()
while proc.is_alive():
proc.is_alive()
proc.terminate()

move_output_files(cwd, line_dir)
fit = MyFit(line_dir)

signal.alarm(0)

except Exception as e:
proc.terminate()

print_error('PyROA timed out'.format(line_names[i+1]))

fit = None


return fit
12 changes: 8 additions & 4 deletions src/pypetal/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def set_drw_rej(input_args, fnames):
'nchain': 1000,
'clip': np.full( len(fnames), True),
'reject_data': np.hstack([ [True], np.full( len(fnames)-1, False) ]),
'use_for_javelin': False
'use_for_javelin': False,
'solver': 'none'
}

params = { **default_kwargs, **input_args }
Expand All @@ -133,6 +134,7 @@ def set_drw_rej(input_args, fnames):
clip = params['clip']
reject_data = params['reject_data']
use_for_javelin = params['use_for_javelin']
solver = params['solver']

if isinstance(clip, bool):
clip = np.full( len(fnames), clip)
Expand All @@ -145,7 +147,7 @@ def set_drw_rej(input_args, fnames):


return jitter, nsig, nwalker, nburn, nchain, clip, \
reject_data, use_for_javelin
reject_data, use_for_javelin, solver


def set_detrend(input_args):
Expand Down Expand Up @@ -233,7 +235,8 @@ def set_pyroa(input_args, nlc):
'psi_types': 'Gaussian',
'together': True,
'objname': None,
'prior_func': None
'prior_func': None,
'timeout': 60*60*3
}

params = { **default_kwargs, **input_args }
Expand All @@ -249,6 +252,7 @@ def set_pyroa(input_args, nlc):
together = params['together']
objname = params['objname']
prior_func = params['prior_func']
timeout = params['timeout']

if init_tau is None:
init_tau = [10.] * (nlc-1)
Expand All @@ -267,7 +271,7 @@ def set_pyroa(input_args, nlc):


return nchain, nburn, init_tau, subtract_mean, div_mean, \
add_var, delay_dist, psi_types, together, objname, prior_func
add_var, delay_dist, psi_types, together, objname, prior_func, timeout


def set_mica2(input_args):
Expand Down

0 comments on commit 9c19e22

Please sign in to comment.