diff --git a/src/mrinufft/density/utils.py b/src/mrinufft/density/utils.py index 1a774dfc..b11b696f 100644 --- a/src/mrinufft/density/utils.py +++ b/src/mrinufft/density/utils.py @@ -56,8 +56,12 @@ def normalize_weights(weights): def normalize_density(kspace_loc, shape, density, backend, **kwargs): """Normalize the density to ensure that the reconstruction is stable.""" from mrinufft import get_operator - + xp = np + if backend == "cufinufft": + import cupy as cp + xp = cp test_op = get_operator(backend)(samples=kspace_loc, shape=shape, **kwargs) - test_im = np.ones(shape, dtype=test_op.cpx_dtype) + test_im = xp.ones(shape, dtype=test_op.cpx_dtype) test_im_recon = test_op.adj_op(density * test_op.op(test_im)) - density /= np.mean(np.abs(test_im_recon)) + density /= xp.mean(xp.abs(test_im_recon)) + return density