Making non-separable FilterNet faster (AllenInstitute#335)
* Enabling numba speedup on non-separable FilterNet

* Minor change to reduce characters per line

* Adding numba as a requirement

* Making MPI optional
shixnya authored Nov 1, 2023
1 parent 8b0c5e5 commit bb8dce1
Showing 2 changed files with 39 additions and 8 deletions.
46 changes: 38 additions & 8 deletions bmtk/simulator/filternet/lgnmodel/cursor.py
@@ -1,5 +1,14 @@
import numpy as np
import scipy.signal as spsig
from numba import njit, prange

try:
    from mpi4py import MPI
    mpi_size = MPI.COMM_WORLD.Get_size()
    numba_parallel = mpi_size == 1  # if there is only 1 MPI rank, turn on numba parallel
except ImportError:
    numba_parallel = True  # if there is no MPI, turn on numba parallel
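# (Presumably this avoids oversubscribing cores: with several MPI ranks, each
# rank would otherwise also spawn its own pool of numba threads.)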


from .utilities import convert_tmin_tmax_framerate_to_trange

@@ -76,18 +85,39 @@ def apply_dot_product(self, ti_offset):
            return self.cache[ti_offset]

        except KeyError:
            t_inds = self.kernel.t_inds + ti_offset + 1  # Offset by one nhc 14 Apr '17
            min_ind, max_ind = 0, self.movie.data.shape[0]
            allowed_inds = np.where(np.logical_and(min_ind <= t_inds, t_inds < max_ind))
            t_inds = t_inds[allowed_inds]
            row_inds = self.kernel.row_inds[allowed_inds]
            col_inds = self.kernel.col_inds[allowed_inds]
            kernel_vector = self.kernel.kernel[allowed_inds]
            result = np.dot(self.movie[t_inds, row_inds, col_inds], kernel_vector)
            # This part is rewritten with numba below.
            # t_inds = self.kernel.t_inds + ti_offset + 1  # Offset by one nhc 14 Apr '17
            # min_ind, max_ind = 0, self.movie.data.shape[0]
            # allowed_inds = np.where(np.logical_and(min_ind <= t_inds, t_inds < max_ind))
            # t_inds = t_inds[allowed_inds]
            # row_inds = self.kernel.row_inds[allowed_inds]
            # col_inds = self.kernel.col_inds[allowed_inds]
            # kernel_vector = self.kernel.kernel[allowed_inds]
            # result = np.dot(self.movie[t_inds, row_inds, col_inds], kernel_vector)
            result = fast_dot_product(
                self.movie.data,
                ti_offset,
                self.kernel.t_inds,
                self.kernel.row_inds,
                self.kernel.col_inds,
                self.kernel.kernel,
            )
            self.cache[ti_offset] = result
            return result


# A faster version of the commented-out part of the above class method.
# Results agree up to round-off error (see the standalone check after this file's diff).
@njit(parallel=numba_parallel)
def fast_dot_product(movie_data, ti_offset, kernel_t_inds, kernel_row_inds, kernel_col_inds, kernel_kernel):
    t_inds = kernel_t_inds + ti_offset + 1
    result = 0.0
    for i in prange(len(t_inds)):
        if t_inds[i] >= 0 and t_inds[i] < movie_data.shape[0]:
            result = result + movie_data[t_inds[i], kernel_row_inds[i], kernel_col_inds[i]] * kernel_kernel[i]
    return result


class FilterCursor(KernelCursor):
    def __init__(self, spatiotemporal_filter, movie, threshold=0):
        # TODO: not sure why this needs to have its own class and shouldn't be merged into parent?
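
A minimal standalone sketch, not part of the commit, for checking the claim above that the numba kernel agrees with the original NumPy dot product up to round-off error. It assumes a bmtk install that includes this commit; the numpy_dot_product helper and the movie/kernel index arrays below are made up for this check, not taken from bmtk.

import numpy as np
from bmtk.simulator.filternet.lgnmodel.cursor import fast_dot_product

def numpy_dot_product(movie_data, ti_offset, t_inds, row_inds, col_inds, kernel):
    # Mirrors the commented-out NumPy path in apply_dot_product above.
    t_inds = t_inds + ti_offset + 1
    ok = np.logical_and(t_inds >= 0, t_inds < movie_data.shape[0])
    return np.dot(movie_data[t_inds[ok], row_inds[ok], col_inds[ok]], kernel[ok])

rng = np.random.default_rng(0)
movie = rng.random((100, 32, 32))   # (frames, rows, cols), arbitrary test sizes
n = 5000
t_inds = rng.integers(-10, 110, n)  # some temporal indices fall outside the movie
row_inds = rng.integers(0, 32, n)
col_inds = rng.integers(0, 32, n)
kernel = rng.random(n)

fast = fast_dot_product(movie, 3, t_inds, row_inds, col_inds, kernel)
slow = numpy_dot_product(movie, 3, t_inds, row_inds, col_inds, kernel)
assert np.isclose(fast, slow)       # agree up to round-off error

The first call to fast_dot_product also pays numba's one-time JIT compilation cost, so any timing comparison should exclude that call.
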
1 change: 1 addition & 0 deletions setup.py
@@ -58,6 +58,7 @@ def read(*filenames, **kwargs):
        'scipy',
        'scikit-image',  # Only required for filternet, consider making optional
        'sympy',  # For FilterNet
        'numba',  # For FilterNet
        'pynrrd'  # For nrrd reader
    ],
    extras_require={
