Skip to content

Commit

Permalink
Testing chisq optimization - DO NOT MERGE
Browse files Browse the repository at this point in the history
  • Loading branch information
spxiwh committed Nov 28, 2024
1 parent 7850921 commit 62a59e5
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 10 deletions.
15 changes: 15 additions & 0 deletions pycbc/types/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,21 @@ def check_same_len_precision(a, b):
a.precision, b.precision)
raise TypeError(msg)


def get_array_module(arr):
    """Return the array library (numpy or cupy) that ``arr`` belongs to.

    Parameters
    ----------
    arr : object
        The raw array to inspect.

    Returns
    -------
    module
        ``numpy`` when ``arr`` is a ``numpy.ndarray``; ``cupy`` when cupy can
        be imported and ``arr`` is a ``cupy.ndarray``.

    Raises
    ------
    ValueError
        If ``arr`` is neither a numpy nor a cupy array (including the case
        where cupy is not installed).
    """
    if isinstance(arr, _numpy.ndarray):
        return _numpy

    # cupy is an optional dependency. Guard only the import so that nothing
    # other than "cupy is unavailable" is swallowed here.
    try:
        import cupy
    except ImportError:
        cupy = None

    if cupy is not None and isinstance(arr, cupy.ndarray):
        return cupy

    raise ValueError(f"Cannot identify type of {type(arr)} {arr}")


class Array(object):
"""Array used to do numeric calculations on a various compute
devices. It is a convience wrapper around numpy, and
Expand Down
47 changes: 42 additions & 5 deletions pycbc/vetoes/chisq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
#
import numpy, logging, math, pycbc.fft

from pycbc.types import zeros, real_same_precision_as, TimeSeries, complex_same_precision_as
from pycbc.types import zeros, real_same_precision_as, TimeSeries, complex_same_precision_as, get_array_module
from pycbc.filter import sigmasq_series, make_frequency_series, matched_filter_core, get_cutoff_indices
from pycbc.scheme import schemed
import pycbc.pnutils

# NOTE(review): presumably the module-name prefix that `schemed` uses to find
# the backend-specific chisq implementations — confirm against pycbc.scheme.
BACKEND_PREFIX="pycbc.vetoes.chisq_"

# Scratch buffers reused by power_chisq_bins_from_sigmasq_series. That function
# allocates them on first use and replaces them with larger ones when a call
# needs more than _CACHED_BIN_NUM entries.
# Output buffer holding the bin edges written by the most recent call.
_CACHED_BINS = None
# Current length of each cached buffer; doubled when more room is needed.
_CACHED_BIN_NUM = 128
# Buffer for the power thresholds that are searched for in the sigmasq series.
_CACHED_EDGE_VEC = None
# Precomputed 0, 1, 2, ... ramp used to fill _CACHED_EDGE_VEC.
_CACHED_EDGE_ARANGE = None

def power_chisq_bins_from_sigmasq_series(sigmasq_series, num_bins, kmin, kmax):
    """Returns bins of equal power for use with the chisq functions

    Parameters
    ----------
    sigmasq_series: series exposing ``.data``
        Power series that is searched with searchsorted over the index range
        ``[kmin, kmax)``, so it is expected to be non-decreasing there. The
        value at ``kmax - 1`` is taken as the total power.
    num_bins: int
        The number of chisq bins to calculate.
    kmin: int
        First index of ``sigmasq_series`` to consider.
    kmax: int
        One past the last index to consider; also used as the final bin edge.

    Returns
    -------
    bins: array of ints
        The ``num_bins + 1`` edges of the chisq bins, expressed as indices
        into ``sigmasq_series``. For numpy-backed input this is a freshly
        allocated numpy array. For cupy-backed input it is a *view* into a
        module-level scratch buffer that the next call overwrites, so callers
        that keep the result must copy it.
    """
    global _CACHED_BINS
    global _CACHED_BIN_NUM
    global _CACHED_EDGE_VEC
    global _CACHED_EDGE_ARANGE

    sigmasq = sigmasq_series[kmax - 1]
    xp = get_array_module(sigmasq_series.data)

    if xp is numpy:
        # Host data: cupy may not even be installed, so stay on plain numpy.
        edge_vec = numpy.arange(0, num_bins) * sigmasq / num_bins
        bins = numpy.searchsorted(
            sigmasq_series[kmin:kmax], edge_vec, side='right'
        )
        bins += kmin
        return numpy.append(bins, kmax)

    # Device data: reuse preallocated buffers to avoid per-call allocations.
    # The buffers are only ever created on this path, so they always live on
    # the same backend as the input.
    needed = num_bins + 1
    if _CACHED_BINS is None or needed > _CACHED_BIN_NUM:
        # Work out the final size first so we allocate only once.
        while needed > _CACHED_BIN_NUM:
            _CACHED_BIN_NUM *= 2
        _CACHED_BINS = xp.zeros(_CACHED_BIN_NUM, dtype=xp.int64)
        _CACHED_EDGE_ARANGE = xp.arange(0, _CACHED_BIN_NUM, dtype=xp.int64)
        _CACHED_EDGE_VEC = xp.zeros(_CACHED_BIN_NUM, dtype=xp.float64)

    bins = _CACHED_BINS[:num_bins]
    edge_vec = _CACHED_EDGE_VEC[:num_bins]
    edge_vec[:] = _CACHED_EDGE_ARANGE[:num_bins] * (sigmasq / num_bins)

    # HACK: call cupy's private searchsorted kernel so the result is written
    # straight into the cached output buffer. The public cupy.searchsorted
    # always allocates its result.
    # NOTE(review): private API; its argument order may change between cupy
    # releases — confirm against the installed version.
    from cupy._sorting.search import _searchsorted_kernel
    _searchsorted_kernel(
        edge_vec,
        sigmasq_series.data[kmin:kmax],
        kmax - kmin,
        True,   # side == 'right'
        True,   # assume the searched data is increasing
        bins
    )
    # The search ran on the [kmin:kmax) slice; shift back to absolute indices.
    bins += kmin
    _CACHED_BINS[num_bins] = kmax
    return _CACHED_BINS[:needed]


def power_chisq_bins(htilde, num_bins, psd, low_frequency_cutoff=None,
Expand Down
36 changes: 35 additions & 1 deletion pycbc/vetoes/chisq_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@
"accum_diff_sq_kernel"
)

# With thanks to Google Gemini
# Reduction kernel: maps every input element x to x * conj(x), sums the mapped
# values starting from 0, and stores the real part of the total.
# NOTE(review): the output is declared with the same type placeholder T as the
# input, so for complex input the result presumably keeps a complex dtype,
# whereas the expression it replaces in shift_sum
# (`(outc.conj() * outc).sum(axis=1).real`) is real-valued — confirm.
sum_real_reduction_kernel = cp.ReductionKernel(
    'T x',  # input parameter
    'T y',  # output parameter
    'x * conj(x)',  # map expression applied to each element
    'a + b',  # reduction of two mapped values
    'y = real(a)',  # post-reduction step writing the output
    '0',  # identity value of the reduction
    'norm_sum'  # kernel name
)


def chisq_accum_bin(chisq, q):
    """Run ``accum_diff_sq_kernel`` on the raw arrays of ``q`` and ``chisq``.

    ``q.data`` is passed as the first kernel argument and ``chisq.data`` as
    the second. Nothing is returned.

    NOTE(review): presumably the kernel accumulates into ``chisq`` in place —
    confirm against the kernel definition.
    """
    source = q.data
    target = chisq.data
    accum_diff_sq_kernel(source, target)

Expand Down Expand Up @@ -235,7 +247,28 @@ def get_pchisq_fn_pow2(np, fuse_correlate=False):
)
return fn, nt

# Reusable output buffers for get_cached_bin_layout. They start with room for
# 512 bins and are replaced by larger ones when a call needs more.
_CACHED_BIN_BV = cp.zeros(512, dtype=cp.uint32)
_CACHED_BIN_KMIN = cp.zeros(512, dtype=cp.uint32)
_CACHED_BIN_KMAX = cp.zeros(512, dtype=cp.uint32)

# Built once at import time rather than on every call. Output element i
# receives the bin number i and the two edges of bin i.
_BIN_LAYOUT_KERNEL = cp.ElementwiseKernel(
    'raw T bins',
    'X bv, X kmin, X kmax',
    'bv = i; kmin = bins[i]; kmax = bins[i+1];',
    'get_bin_layout'
)


def get_cached_bin_layout(bins):
    """Expand an array of bin edges into per-bin (kmin, kmax, bin-number) arrays.

    Parameters
    ----------
    bins : device array of ints
        The ``N + 1`` edges delimiting ``N`` bins. It is passed to the kernel
        as a raw argument, so it is expected to already live on the GPU.

    Returns
    -------
    kmin, kmax, bv : cupy arrays of uint32, each of length ``N``
        ``kmin[i] = bins[i]``, ``kmax[i] = bins[i + 1]`` and ``bv[i] = i``.
        All three are *views* into module-level buffers that the next call
        overwrites, so callers that keep them must copy.
    """
    global _CACHED_BIN_BV, _CACHED_BIN_KMIN, _CACHED_BIN_KMAX

    # N + 1 edges describe N bins. Sizing the outputs by the number of bins,
    # not edges, keeps the kernel's bins[i + 1] read inside the input.
    num_bins = max(len(bins) - 1, 0)

    capacity = len(_CACHED_BIN_BV)
    if num_bins > capacity:
        # Grow the buffers instead of silently truncating the layout.
        while capacity < num_bins:
            capacity *= 2
        _CACHED_BIN_BV = cp.zeros(capacity, dtype=cp.uint32)
        _CACHED_BIN_KMIN = cp.zeros(capacity, dtype=cp.uint32)
        _CACHED_BIN_KMAX = cp.zeros(capacity, dtype=cp.uint32)

    bv = _CACHED_BIN_BV[:num_bins]
    kmin = _CACHED_BIN_KMIN[:num_bins]
    kmax = _CACHED_BIN_KMAX[:num_bins]

    _BIN_LAYOUT_KERNEL(bins, bv, kmin, kmax)

    # The older host-side loop that used to follow here was unreachable and
    # has been dropped (see version control). It might give a faster later
    # GPU kernel, but building the layout that way was itself a time sink.
    return kmin, kmax, bv

def shift_sum_points(num, N, arg_tuple):
Expand Down Expand Up @@ -341,5 +375,5 @@ def shift_sum(corr, points, bins):
elif np == 1:
outp, phase, np = shift_sum_points(1, cargs) # pylint:disable=no-value-for-parameter

return cp.asnumpy((outc.conj() * outc).sum(axis=1).real)
return cp.asnumpy(sum_real_reduction_kernel(outc, axis=1))

2 changes: 1 addition & 1 deletion pycbc/waveform/bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def sigma_cached(self, psd):

kmin = int(self.f_lower / psd.delta_f)
self._sigmasq[key] = self.sigma_scale * \
(curr_sigmasq[self.end_idx-1] - curr_sigmasq[kmin])
float(curr_sigmasq[self.end_idx-1] - curr_sigmasq[kmin])

else:
if not hasattr(self, 'sigma_view'):
Expand Down
6 changes: 3 additions & 3 deletions pycbc/waveform/spa_tmplt.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def spa_tmplt_precondition(length, delta_f, kmin=0):
def spa_tmplt_norm(psd, length, delta_f, f_lower):
    """Return the cumulative normalisation series for the SPA template.

    Parameters
    ----------
    psd : series exposing ``.data``
        Power spectral density; samples ``k_min`` to ``length - 1`` are read.
    length : int
        Number of frequency samples in the returned series.
    delta_f : float
        Frequency spacing of the series.
    f_lower : float
        Low-frequency cutoff; converted to the index
        ``k_min = int(f_lower / delta_f)``.

    Returns
    -------
    FrequencySeries
        Series of ``length`` samples that is zero below ``k_min`` and, from
        ``k_min`` on, holds the running sum of ``amp ** 2 / psd`` multiplied
        by ``4 * delta_f``, where ``amp`` comes from
        ``spa_tmplt_precondition``.
    """
    amp = spa_tmplt_precondition(length, delta_f)
    k_min = int(f_lower / delta_f)
    # Operate on the underlying arrays directly so the computation stays on
    # whatever backend holds the data, with no host copy via .numpy().
    sigma = (amp.data[k_min:length] ** 2. / psd.data[k_min:length])
    # NOTE(review): the result is stored as float32, whereas a plain
    # numpy.zeros(length) buffer would be float64 — confirm this precision is
    # sufficient for consumers that difference entries of this series.
    norm_vec = FrequencySeries(zeros(length), delta_f=delta_f, dtype=float32)
    norm_vec.data[k_min:length] = sigma.cumsum() * 4. * delta_f
    return norm_vec


Expand Down

0 comments on commit 62a59e5

Please sign in to comment.