Skip to content

Commit

Permalink
fix: stacked adjoint with cpu or gpu.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Feb 8, 2024
1 parent e191706 commit 450ebe3
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/mrinufft/operators/stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import scipy as sp

from mrinufft._utils import proper_trajectory, power_method, get_array_module
from mrinufft._utils import proper_trajectory, power_method, get_array_module, auto_cast
from mrinufft.operators.base import FourierOperatorBase, check_backend, get_operator
from mrinufft.operators.interfaces.utils import (
is_cuda_array,
Expand Down Expand Up @@ -373,6 +373,7 @@ def op(self, data, ksp=None):
xp = get_array_module(data)
if xp.__name__ == "torch" and data.is_cpu:
data = data.numpy()
data = auto_cast(data, self.cpx_dtype)

if self.uses_sense and is_cuda_array(data):
op_func = self._op_sense_device
Expand Down Expand Up @@ -530,6 +531,11 @@ def _op_calibless_device(self, data, ksp=None):
def adj_op(self, coeffs, img=None):
"""Adjoint operator."""
# Dispatch to special case.
xp = get_array_module(coeffs)
if xp.__name__ == "torch" and coeffs.is_cpu:
coeffs = coeffs.numpy()
coeffs = auto_cast(coeffs, self.cpx_dtype)

if self.uses_sense and is_cuda_array(coeffs):
adj_op_func = self._adj_op_sense_device
elif self.uses_sense:
Expand All @@ -541,6 +547,11 @@ def adj_op(self, coeffs, img=None):

ret = adj_op_func(coeffs, img)

if xp.__name__ == "torch" and is_cuda_array(ret):
ret = xp.as_tensor(ret, device=coeffs.device)
elif xp.__name__ == "torch":
ret = xp.from_numpy(ret)

return self._safe_squeeze(ret)

def _adj_op_sense_host(self, coeffs, img_d=None):
Expand Down

0 comments on commit 450ebe3

Please sign in to comment.