diff --git a/bmtk/simulator/filternet/lgnmodel/cursor.py b/bmtk/simulator/filternet/lgnmodel/cursor.py index 06649873d..ee5aa08d3 100644 --- a/bmtk/simulator/filternet/lgnmodel/cursor.py +++ b/bmtk/simulator/filternet/lgnmodel/cursor.py @@ -1,5 +1,14 @@ import numpy as np import scipy.signal as spsig +from numba import njit, prange + +try: + from mpi4py import MPI + mpi_size = MPI.COMM_WORLD.Get_size() + numba_parallel = mpi_size == 1 # only one MPI process (rank) is running: turn on numba parallel +except ImportError: + numba_parallel = True # mpi4py is not installed, so there is no MPI: turn on numba parallel + from .utilities import convert_tmin_tmax_framerate_to_trange @@ -76,18 +85,39 @@ def apply_dot_product(self, ti_offset): return self.cache[ti_offset] except KeyError: - t_inds = self.kernel.t_inds + ti_offset + 1 # Offset by one nhc 14 Apr '17 - min_ind, max_ind = 0, self.movie.data.shape[0] - allowed_inds = np.where(np.logical_and(min_ind <= t_inds, t_inds < max_ind)) - t_inds = t_inds[allowed_inds] - row_inds = self.kernel.row_inds[allowed_inds] - col_inds = self.kernel.col_inds[allowed_inds] - kernel_vector = self.kernel.kernel[allowed_inds] - result = np.dot(self.movie[t_inds, row_inds, col_inds], kernel_vector) + # This part is rewritten with numba below. + # t_inds = self.kernel.t_inds + ti_offset + 1 # Offset by one nhc 14 Apr '17 + # min_ind, max_ind = 0, self.movie.data.shape[0] + # allowed_inds = np.where(np.logical_and(min_ind <= t_inds, t_inds < max_ind)) + # t_inds = t_inds[allowed_inds] + # row_inds = self.kernel.row_inds[allowed_inds] + # col_inds = self.kernel.col_inds[allowed_inds] + # kernel_vector = self.kernel.kernel[allowed_inds] + # result = np.dot(self.movie[t_inds, row_inds, col_inds], kernel_vector) + result = fast_dot_product( + self.movie.data, + ti_offset, + self.kernel.t_inds, + self.kernel.row_inds, + self.kernel.col_inds, + self.kernel.kernel, + ) self.cache[ti_offset] = result return result +# a faster version of the commented out part of the above class method.
+# Results agree with the commented-out NumPy code up to floating-point round-off.
+@njit(parallel=numba_parallel)
+def fast_dot_product(movie_data, ti_offset, kernel_t_inds, kernel_row_inds, kernel_col_inds, kernel_kernel):
+    """Return the kernel-weighted sum of movie samples at one time offset.
+
+    Every kernel entry ``i`` addresses the sample
+    ``movie_data[kernel_t_inds[i] + ti_offset + 1, kernel_row_inds[i], kernel_col_inds[i]]``
+    and contributes that sample multiplied by ``kernel_kernel[i]``.
+
+    Entries whose shifted first-axis index falls outside
+    ``[0, movie_data.shape[0])`` are skipped, so a kernel that overhangs
+    either end of the movie is truncated instead of being wrapped around by
+    negative indexing.
+
+    The accumulator starts at ``0.0``, so the function returns ``0.0`` when
+    no kernel entry lands inside the movie.  When ``numba_parallel`` is true
+    the loop is compiled as a parallel reduction, so the result can differ
+    from a serial sum by floating-point round-off.
+    """
+    # Hoisted loop invariants: the bound on the first axis and the index shift.
+    n_frames = movie_data.shape[0]
+    shift = ti_offset + 1  # same one-sample offset the NumPy code applied
+    result = 0.0
+    for i in prange(len(kernel_t_inds)):
+        # Shift one index at a time rather than allocating a shifted copy
+        # of the whole index array on every call.
+        t = kernel_t_inds[i] + shift
+        if 0 <= t < n_frames:
+            # `+=` is the accumulation form Numba documents for prange reductions.
+            result += movie_data[t, kernel_row_inds[i], kernel_col_inds[i]] * kernel_kernel[i]
+    return result
+
+
 class FilterCursor(KernelCursor):
     def __init__(self, spatiotemporal_filter, movie, threshold=0):
         # TODO: not sure why this needs to have it's own class and shouldn't be merged into parent?
diff --git a/setup.py b/setup.py
index c906a2b39..313c10a45 100644
--- a/setup.py
+++ b/setup.py
@@ -58,6 +58,7 @@ def read(*filenames, **kwargs):
         'scipy',
         'scikit-image',  # Only required for filternet, consider making optional
         'sympy',  # For FilterNet
+        'numba',  # For FilterNet
         'pynrrd'  # For nrrd reader
     ],
     extras_require={