From 48bc2ee8f3da02dc077d1904c5ada978fbe500a0 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Wed, 5 Feb 2025 14:38:40 +0100 Subject: [PATCH] fix: is_cuda_tensor guarded by TORCH_AVAILABLE. --- src/mrinufft/_array_compat.py | 6 +++--- src/mrinufft/operators/interfaces/utils/gpu_utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mrinufft/_array_compat.py b/src/mrinufft/_array_compat.py index 07ce88bc8..4347c8cdd 100644 --- a/src/mrinufft/_array_compat.py +++ b/src/mrinufft/_array_compat.py @@ -246,9 +246,9 @@ def _to_numpy_cupy(args, kwargs, leading_argument): This avoid transfers between different devices (e.g., CPU->GPU, GPU->CPU or different GPUs). """ - if is_cuda_array(leading_argument) and CUPY_AVAILABLE: - return _to_cupy(*args, **kwargs) - elif is_cuda_tensor(leading_argument) and CUPY_AVAILABLE: + if ( + is_cuda_array(leading_argument) or is_cuda_tensor(leading_argument) + ) and CUPY_AVAILABLE: return _to_cupy(*args, **kwargs) else: return _to_numpy(*args, **kwargs) diff --git a/src/mrinufft/operators/interfaces/utils/gpu_utils.py b/src/mrinufft/operators/interfaces/utils/gpu_utils.py index dcc12c91b..c19e65883 100644 --- a/src/mrinufft/operators/interfaces/utils/gpu_utils.py +++ b/src/mrinufft/operators/interfaces/utils/gpu_utils.py @@ -36,7 +36,7 @@ def is_cuda_array(var): def is_cuda_tensor(var): """Check if var is a CUDA tensor.""" - return isinstance(var, torch.Tensor) and var.is_cuda + return TORCH_AVAILABLE and isinstance(var, torch.Tensor) and var.is_cuda def is_host_array(var):