diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 2095587bfc..fe9f796aba 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -367,6 +367,8 @@ def f_pow2(x_tm2, x_tm1): state_val = np.array([1.0, 2.0]) numba_mode = get_mode("NUMBA").including("scan_save_mem") + # multi-output Elemwise not supported in NUMBA + numba_mode = numba_mode.excluding("fusion") py_mode = Mode("py").including("scan_save_mem") out_fg = FunctionGraph([init_x, n_steps], [output]) @@ -406,6 +408,8 @@ def inner_fct(seq, state_old, state_current): g_outs = grad(out.sum(), [seq, init_x]) numba_mode = get_mode("NUMBA").including("scan_save_mem") + # multi-output Elemwise not supported in NUMBA + numba_mode = numba_mode.excluding("fusion") py_mode = Mode("py").including("scan_save_mem") out_fg = FunctionGraph([seq, init_x], g_outs)