From 821d11de3b19c8c06934fea8b565dd45caa29ddc Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Thu, 23 Jan 2020 15:47:17 +0000 Subject: [PATCH] Remove workarounds now that sparse-0.9.1 has numba support --- clifford/__init__.py | 27 ++++++++++----------------- clifford/tools/g3c/__init__.py | 7 +------ setup.py | 2 +- 3 files changed, 12 insertions(+), 24 deletions(-) diff --git a/clifford/__init__.py b/clifford/__init__.py index fee1bd19..a68c7fc7 100644 --- a/clifford/__init__.py +++ b/clifford/__init__.py @@ -196,20 +196,16 @@ def _get_mult_function(mt: sparse.COO): func : function (array_like (n_dims,), array_like (n_dims,)) -> array_like (n_dims,) A function that computes the appropriate multiplication """ - # unpack for numba - dims = mt.shape[1] - k_list, l_list, m_list = mt.coords - mult_table_vals = mt.data - @numba.generated_jit(nopython=True) def mv_mult(value, other_value): # this casting will be done at jit-time - ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype) - mult_table_vals_t = mult_table_vals.astype(ret_dtype) + ret_dtype = _get_mult_function_result_type(value, other_value, mt.dtype) + mt_t = mt.astype(ret_dtype) def mult_inner(value, other_value): - output = np.zeros(dims, dtype=ret_dtype) - for k, l, m, val in zip(k_list, l_list, m_list, mult_table_vals_t): + output = np.zeros(mt_t.shape[1], dtype=ret_dtype) + k_list, l_list, m_list = mt_t.coords[0], mt_t.coords[1], mt_t.coords[2] + for k, l, m, val in zip(k_list, l_list, m_list, mt_t.data): output[l] += value[k] * val * other_value[m] return output @@ -227,19 +223,16 @@ def _get_mult_function_runtime_sparse(mt: sparse.COO): TODO: determine if this actually helps. """ - # unpack for numba - dims = mt.shape[1] - k_list, l_list, m_list = mt.coords - mult_table_vals = mt.data @numba.generated_jit(nopython=True) def mv_mult(value, other_value): # this casting will be done at jit-time - ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype) - mult_table_vals_t = mult_table_vals.astype(ret_dtype) + ret_dtype = _get_mult_function_result_type(value, other_value, mt.dtype) + mt_t = mt.astype(ret_dtype) def mult_inner(value, other_value): - output = np.zeros(dims, dtype=ret_dtype) + output = np.zeros(mt_t.shape[1], dtype=ret_dtype) + k_list, l_list, m_list = mt_t.coords[0], mt_t.coords[1], mt_t.coords[2] for ind, k in enumerate(k_list): v_val = value[k] if v_val != 0.0: @@ -247,7 +240,7 @@ def mult_inner(value, other_value): ov_val = other_value[m] if ov_val != 0.0: l = l_list[ind] - output[l] += v_val * mult_table_vals_t[ind] * ov_val + output[l] += v_val * mt_t.data[ind] * ov_val return output return mult_inner diff --git a/clifford/tools/g3c/__init__.py b/clifford/tools/g3c/__init__.py index 22f0b82b..517ba2fb 100644 --- a/clifford/tools/g3c/__init__.py +++ b/clifford/tools/g3c/__init__.py @@ -499,15 +499,10 @@ def scale_TR_translation(TR, scale): def left_gmt_generator(mt=layout.gmt): # unpack for numba - k_list, l_list, m_list = mt.coords - mult_table_vals = mt.data - gaDims = mt.shape[1] val_get_left_gmt_matrix = cf._numba_val_get_left_gmt_matrix - @numba.njit def get_left_gmt(x_val): - return val_get_left_gmt_matrix( - x_val, k_list, l_list, m_list, mult_table_vals, gaDims) + return val_get_left_gmt_matrix(x_val, mt) return get_left_gmt diff --git a/setup.py b/setup.py index 8b009932..51716832 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ 'scipy', 'numba>=0.45.1', 'h5py', - 'sparse', + 'sparse>=0.9.1', ], package_dir={'clifford':'clifford'},