Skip to content

Commit

Permalink
Optimize memory usage of pycbc_fit_sngls_split_binned (#4543)
Browse files Browse the repository at this point in the history
* Optimize memory usage of pycbc_fit_sngls_split_binned

* Need to add option here

* Comments on PR

* Example doesn't use code being edited!

* Update bin/all_sky_search/pycbc_fit_sngls_split_binned

Co-authored-by: Thomas Dent <[email protected]>

---------

Co-authored-by: Thomas Dent <[email protected]>
  • Loading branch information
spxiwh and tdent authored Nov 1, 2023
1 parent 7d24aa0 commit 7fb6d91
Showing 1 changed file with 135 additions and 43 deletions.
178 changes: 135 additions & 43 deletions bin/all_sky_search/pycbc_fit_sngls_split_binned
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ from matplotlib import pyplot as plt
import numpy as np

from pycbc import events, bin_utils, results
from pycbc.io import HFile, SingleDetTriggers
from pycbc.events import triggers as trigs
from pycbc.events import trigger_fits as trstats
from pycbc.events import stat as pystat
Expand Down Expand Up @@ -82,6 +83,10 @@ parser.add_argument("--gating-veto-windows", nargs='+',
parser.add_argument("--stat-fit-threshold", type=float, required=True,
help="Only fit triggers with statistic value above this "
"threshold. Required")
# Lower statistic bound for plotting. The same value is used further down to
# pre-filter triggers when the trigger file is read, so it also controls how
# many triggers are held in memory.
# NOTE: adjacent string literals are concatenated verbatim, so each fragment
# must carry its own trailing space.
parser.add_argument("--plot-lower-stat-limit", type=float, required=True,
                    help="Plot triggers down to this value. Setting this too "
                         "low will incur huge memory usage in a full search. "
                         "To avoid this, choose 5.5 or larger.")
parser.add_argument("--fit-function",
choices=["exponential", "rayleigh", "power"],
help="Functional form for the maximum likelihood fit")
Expand All @@ -96,6 +101,8 @@ pystat.insert_statistic_option_group(parser,
default_ranking_statistic='single_ranking_only')
args = parser.parse_args()

# The fit threshold must not lie below the plotting limit. Use parser.error
# instead of assert: asserts are removed under "python -O", and parser.error
# prints the usage message with a clear explanation rather than a traceback.
# "not (a >= b)" keeps the original condition exactly (including for NaN).
if not (args.stat_fit_threshold >= args.plot_lower_stat_limit):
    parser.error(
        "--stat-fit-threshold (%s) must be greater than or equal to "
        "--plot-lower-stat-limit (%s)"
        % (args.stat_fit_threshold, args.plot_lower_stat_limit)
    )

pycbc.init_logging(args.verbose)

logging.info('Opening trigger file: %s' % args.trigger_file)
Expand All @@ -122,8 +129,12 @@ for ex_p in extparams:
logging.info('Reading duration from trigger file')
# List comprehension loops over templates; if a template has no triggers, accessing
# the 0th entry of its region reference will return zero due to a quirk of h5py.
params[ex_p] = np.array([trigf[args.ifo + '/template_duration'][ref][0]
for ref in trigf[args.ifo + '/template_duration_template'][:]])
params[ex_p] = np.array(
[
trigf[args.ifo + '/template_duration'][ref][0]
for ref in trigf[args.ifo + '/template_duration_template'][:]
]
)
else:
logging.info("Calculating " + ex_p + " from template parameters")
params[ex_p] = trigs.get_param(ex_p, args, params['mass1'],
Expand Down Expand Up @@ -198,10 +209,65 @@ for i, lower_2, upper_2 in zip(range(args.split_two_nbins),

logging.info('Getting template boundaries from trigger file')
boundaries = trigf[args.ifo + '/template_boundaries'][:]

trigf.close()

# Build a boolean pre-filter over every trigger in the file, keeping only
# those whose SNR reaches the plotting limit. SNR is used as a cheap stand-in
# for the ranking: it is always greater than or equal to sngl_ranking, except
# in the psdvar case, where it could increase, so when psd_var_val is present
# the SNR is rescaled by it before comparing.
snr_dset = f'{args.ifo}/snr'
psdvar_dset = f'{args.ifo}/psd_var_val'
with HFile(args.trigger_file, 'r') as trig_file:
    n_triggers_orig = trig_file[snr_dset].size
    logging.info("Trigger file has %d triggers", n_triggers_orig)
    logging.info('Generating trigger mask')
    if psdvar_dset not in trig_file:
        # psd_var_val may not have been calculated
        idx, _ = trig_file.select(
            lambda snr: snr >= args.plot_lower_stat_limit,
            snr_dset,
            return_indices=True
        )
    else:
        idx, _, _ = trig_file.select(
            lambda snr, psdvar: snr / psdvar ** 0.5 >= args.plot_lower_stat_limit,
            snr_dset,
            psdvar_dset,
            return_indices=True
        )
    # Start with everything excluded, then switch on the selected triggers.
    data_mask = np.zeros(n_triggers_orig, dtype=bool)
    data_mask[idx] = True


logging.info('Calculating single stat values from trigger file')
trigs = SingleDetTriggers(
args.trigger_file,
None,
None,
None,
None,
args.ifo,
premask=data_mask
)
# This is the direct pointer to the HDF file, used later on
trigf = trigs.trigs_f

stat = trigs.get_ranking(args.sngl_ranking)
time = trigs.end_time


logging.info('Processing template boundaries')
max_boundary_id = np.argmax(boundaries)
sorted_boundary_list = np.sort(boundaries)

logging.info('Processing template boundaries')
# In the two blocks of code that follow we are trying to figure out the index
# ranges in the masked trigger lists corresponding to the "boundaries".
# We will do this by walking over the boundaries in the order they're
# stored in the file, adding in the number of triggers not removed by the
# mask every time.

# First step is to loop over the "boundaries" which gives the start position
# of each block of triggers (corresponding to one template) in the full trigger
# merge file.
# Here we identify the end_idx for the triggers corresponding to each template.
where_idx_end = np.zeros_like(boundaries)
for idx, idx_start in enumerate(boundaries):
if idx == max_boundary_id:
Expand All @@ -210,22 +276,41 @@ for idx, idx_start in enumerate(boundaries):
where_idx_end[idx] = sorted_boundary_list[
np.argmax(sorted_boundary_list == idx_start) + 1]

logging.info('Calculating single stat values from trigger file')
rank_method = pystat.get_statistic_from_opts(args, [args.ifo])
stat = rank_method.get_sngl_ranking(trigf[args.ifo])
# Next we need to map these start/stop indices in the full file to the
# start/stop indices in the masked list of triggers. We do this by figuring out
# how many triggers are in the masked list for each template in the order they
# are stored in the trigger merge file, and keep a running sum.
curr_count = 0
mask_start_idx = np.zeros_like(boundaries)
mask_end_idx = np.zeros_like(boundaries)
for idx_start in sorted_boundary_list:
boundary_idx = np.argmax(boundaries == idx_start)
idx_end = where_idx_end[boundary_idx]
mask_start_idx[boundary_idx] = curr_count
curr_count += np.sum(trigs.mask[idx_start:idx_end])
mask_end_idx[boundary_idx] = curr_count


if args.veto_file:
logging.info('Applying DQ vetoes')
time = trigf[args.ifo + '/end_time'][:]
remove, junk = events.veto.indices_within_segments(time, [args.veto_file],
ifo=args.ifo, segment_name=args.veto_segment_name)
# Set stat to zero for triggers being vetoed: given that the fit threshold is
# >0 these will not be fitted or plotted. Avoids complications from changing
# the number of triggers, ie changes of template boundary.
stat[remove] = np.zeros_like(remove)
time[remove] = np.zeros_like(remove)
logging.info('{} out of {} trigs removed after vetoing with {} from {}'.format(
remove.size, stat.size, args.veto_segment_name, args.veto_file))
remove, junk = events.veto.indices_within_segments(
time,
[args.veto_file],
ifo=args.ifo,
segment_name=args.veto_segment_name
)
# Set stat to zero for triggers being vetoed: given that the fit threshold
# is >0 these will not be fitted or plotted. Avoids complications from
# changing the number of triggers, ie changes of template boundary.
stat[remove] = 0.
time[remove] = 0.
logging.info(
'%d out of %d trigs removed after vetoing with %s from %s',
remove.size,
stat.size,
args.veto_segment_name,
args.veto_file
)

if args.gating_veto_windows:
logging.info('Applying veto to triggers near gates')
Expand All @@ -236,19 +321,24 @@ if args.gating_veto_windows:
raise ValueError("Gating veto window values must be negative before "
"gates and positive after gates.")
if not (gveto_before == 0 and gveto_after == 0):
time = trigf[args.ifo + '/end_time'][:]
autogate_times = np.unique(trigf[args.ifo + '/gating/auto/time'][:])
if args.ifo + '/gating/file' in trigf:
detgate_times = trigf[args.ifo + '/gating/file/time'][:]
else:
detgate_times = []
gate_times = np.concatenate((autogate_times, detgate_times))
gveto_remove = events.veto.indices_within_times(time, gate_times + gveto_before,
gate_times + gveto_after)
stat[gveto_remove] = np.zeros_like(gveto_remove)
time[gveto_remove] = np.zeros_like(gveto_remove)
logging.info('{} out of {} trigs removed after vetoing triggers near gates'.format(
gveto_remove.size, stat.size))
gveto_remove = events.veto.indices_within_times(
time,
gate_times + gveto_before,
gate_times + gveto_after
)
stat[gveto_remove] = 0.
time[gveto_remove] = 0.
logging.info(
'%d out of %d trigs removed after vetoing triggers near gates',
gveto_remove.size,
stat.size
)

for x in range(args.split_one_nbins):
if not args.prune_number:
Expand All @@ -266,42 +356,45 @@ for x in range(args.split_one_nbins):
time_inbin = []
# getting triggers that are in these templates
for idx in id_in_both:
where_idx_start = boundaries[idx]
vals_inbin += list(stat[where_idx_start:where_idx_end[idx]])
time_inbin += list(time[where_idx_start:where_idx_end[idx]])
vals_inbin += list(stat[mask_start_idx[idx]:mask_end_idx[idx]])
time_inbin += list(time[mask_start_idx[idx]:mask_end_idx[idx]])

vals_inbin = np.array(vals_inbin)
time_inbin = np.array(time_inbin)

count_pruned = 0
logging.info('Pruning in split {}-{} {}-{}'.format(
args.split_param_one, x, args.split_param_two, y))
logging.info('Currently have %d triggers', len(vals_inbin))
while count_pruned < args.prune_number:
# Getting loudest statistic value in split
max_val_arg = vals_inbin.argmax()
max_statval = vals_inbin[max_val_arg]

remove = np.nonzero(abs(time_inbin[max_val_arg] - time)
< args.prune_window)[0]
remove = np.nonzero(
abs(time_inbin[max_val_arg] - time) < args.prune_window
)[0]
# Remove from inbin triggers as well in case there
# are more pruning iterations
remove_inbin = np.nonzero(abs(time_inbin[max_val_arg] - time_inbin)
< args.prune_window)[0]
logging.info('Prune {}: removing {} triggers around time {:.2f},'
' {} in this split'.format(count_pruned, remove.size,
time[max_val_arg],
remove_inbin.size))
remove_inbin = np.nonzero(
abs(time_inbin[max_val_arg] - time_inbin) < args.prune_window
)[0]
logging.info(
'Prune %d: removing %d triggers around %.2f, %d in this split',
count_pruned,
remove.size,
time[max_val_arg],
remove_inbin.size
)
# Set pruned triggers' stat values to zero, as above for vetoes
vals_inbin[remove_inbin] = np.zeros_like(remove_inbin)
time_inbin[remove_inbin] = np.zeros_like(remove_inbin)
stat[remove] = np.zeros_like(remove)
time[remove] = np.zeros_like(remove)
vals_inbin[remove_inbin] = 0.
time_inbin[remove_inbin] = 0.
stat[remove] = 0.
time[remove] = 0.
count_pruned += 1

trigf.close()

logging.info('Setting up plotting and fitting limit values')
minplot = max(stat[np.nonzero(stat)].min(), args.stat_fit_threshold - 1)
minplot = max(stat[np.nonzero(stat)].min(), args.plot_lower_stat_limit)
min_fit = max(minplot, args.stat_fit_threshold)
max_fit = 1.05 * stat.max()
if args.plot_max_x:
Expand Down Expand Up @@ -355,8 +448,7 @@ for x in range(args.split_one_nbins):
if len(indices_all_conditions) == 0: continue
vals_inbin = []
for idx in indices_all_conditions:
where_idx_start = boundaries[idx]
vals_inbin += list(stat[where_idx_start:where_idx_end[idx]])
vals_inbin += list(stat[mask_start_idx[idx]:mask_end_idx[idx]])

vals_inbin = np.array(vals_inbin)
vals_above_thresh = vals_inbin[vals_inbin >= args.stat_fit_threshold]
Expand Down

0 comments on commit 7fb6d91

Please sign in to comment.