From 119ee92cf5bca65779987674686462984d33ca16 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Fri, 20 Dec 2024 14:52:51 +0100 Subject: [PATCH] Fix --- src/mrinufft/density/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/density/utils.py b/src/mrinufft/density/utils.py index 1a774dfc..b11b696f 100644 --- a/src/mrinufft/density/utils.py +++ b/src/mrinufft/density/utils.py @@ -56,8 +56,12 @@ def normalize_weights(weights): def normalize_density(kspace_loc, shape, density, backend, **kwargs): """Normalize the density to ensure that the reconstruction is stable.""" from mrinufft import get_operator - + xp = np + if backend == "cufinufft": + import cupy as cp + xp = cp test_op = get_operator(backend)(samples=kspace_loc, shape=shape, **kwargs) - test_im = np.ones(shape, dtype=test_op.cpx_dtype) + test_im = xp.ones(shape, dtype=test_op.cpx_dtype) test_im_recon = test_op.adj_op(density * test_op.op(test_im)) - density /= np.mean(np.abs(test_im_recon)) + density /= xp.mean(xp.abs(test_im_recon)) + return density