Skip to content

Commit

Permalink
Move cupy array declaration to when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
spxiwh committed Nov 25, 2024
1 parent ea040b2 commit 56bf071
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions pycbc/events/threshold_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,8 @@
from .eventmgr import _BaseThresholdCluster
import pycbc.scheme


# The CUDA version had globally defined storage arrays. We mimic that here
# but I think this should be more elegant (ie. don't define arrays unless
# needed)

val = cp.zeros(4096*256, dtype=cp.complex64)
loc = cp.zeros(4096*256, cp.int32)
# Shared output buffers. They start out unset and are replaced with cupy
# arrays the first time threshold_and_cluster or CUDAThresholdCluster
# needs them, so nothing is allocated just by importing this module.
val = None
loc = None

# https://stackoverflow.com/questions/77798014/cupy-rawkernel-cuda-error-not-found-named-symbol-not-found-cupy

Expand Down Expand Up @@ -200,6 +195,13 @@ def get_tkernel(slen, window):
return (fn, fn2), nt, nb

def threshold_and_cluster(series, threshold, window):
# Create the shared module-level buffers on first use; later calls reuse
# the same arrays instead of allocating again.
global val
global loc
if val is None:
    val = cp.zeros(4096*256, dtype=cp.complex64)
if loc is None:
    loc = cp.zeros(4096*256, cp.int32)

outl = loc
outv = val
slen = len(series)
Expand All @@ -220,6 +222,13 @@ class CUDAThresholdCluster(_BaseThresholdCluster):
def __init__(self, series):
self.series = series

# Create the shared module-level buffers on first use; later instances
# reuse the same arrays instead of allocating again.
global val
global loc
if val is None:
    val = cp.zeros(4096*256, dtype=cp.complex64)
if loc is None:
    loc = cp.zeros(4096*256, cp.int32)

self.outl = loc
self.outv = val
self.slen = len(series)
Expand Down

0 comments on commit 56bf071

Please sign in to comment.