feat: update cpu backend and test.
paquiteau committed Jan 26, 2024
1 parent 422fe9e commit f379a72
Showing 7 changed files with 159 additions and 93 deletions.
55 changes: 30 additions & 25 deletions src/mrinufft/operators/base.py
@@ -5,11 +5,11 @@
:author: Pierre-Antoine Comby
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from functools import partial
import warnings
import numpy as np
from mrinufft._utils import power_method
from mrinufft._utils import power_method, auto_cast, get_array_module

from mrinufft.density import get_density

@@ -91,7 +91,7 @@ class FourierOperatorBase(ABC):
as required by ModOpt.
"""

interfaces = {}
interfaces: dict[str, tuple] = {}

def __init__(self):
if not self.available:
@@ -224,6 +224,11 @@ def uses_density(self):
"""Return True if the operator uses density compensation."""
return getattr(self, "density", None) is not None

@property
def ndim(self):
"""Number of dimensions in image space of the operator."""
return len(self._shape)

@property
def shape(self):
"""Shape of the image space of the operator."""
@@ -238,11 +243,6 @@ def n_coils(self):
"""Number of coils for the operator."""
return self._n_coils

@property
def ndim(self):
"""Number of dimensions in image space of the operator."""
return len(self._shape)

@n_coils.setter
def n_coils(self, n_coils):
if n_coils < 1 or not int(n_coils) == n_coils:
@@ -285,15 +285,15 @@ def dtype(self):
"""Return floating precision of the operator."""
return self._dtype

@dtype.setter
def dtype(self, dtype):
self._dtype = np.dtype(dtype)

@property
def cpx_dtype(self):
"""Return complex floating precision of the operator."""
return np.dtype(DTYPE_R2C[str(self.dtype)])

@dtype.setter
def dtype(self, dtype):
self._dtype = np.dtype(dtype)

@property
def samples(self):
"""Return the samples used by the operator."""
@@ -399,20 +399,23 @@ def op(self, data, ksp=None):
this performs for every coil \ell:
.. math:: \mathcal{F}\mathcal{S}_\ell x
"""
if data.dtype != self.cpx_dtype:
warnings.warn(
f"Data should be of dtype {self.cpx_dtype} (is {data.dtype}). "
"Casting it for you."
)
data = data.astype(self.cpx_dtype)
# sense
xp = get_array_module(data)
data = auto_cast(data, self.cpx_dtype)

if xp.__name__ == "torch":
data = data.to("cpu").numpy()
if self.uses_sense:
ret = self._op_sense(data, ksp)
# calibrationless or monocoil.
else:
ret = self._op_calibless(data, ksp)
ret /= self.norm_factor
return self._safe_squeeze(ret)

ret = self._safe_squeeze(ret)
if xp.__name__ == "torch":
return xp.from_numpy(ret)
return ret

def _op_sense(self, data, ksp=None):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
@@ -457,18 +460,20 @@ def adj_op(self, coeffs, img=None):
-------
Array in the same memory space of coeffs. (ie on cpu or gpu Memory).
"""
if coeffs.dtype != self.cpx_dtype:
warnings.warn(
f"coeffs should be of dtype {self.cpx_dtype}. Casting it for you."
)
coeffs = coeffs.astype(self.cpx_dtype)
coeffs = auto_cast(coeffs, self.cpx_dtype)
xp = get_array_module(coeffs)
if xp.__name__ == "torch":
coeffs = coeffs.to("cpu").numpy()
if self.uses_sense:
ret = self._adj_op_sense(coeffs, img)
# calibrationless or monocoil.
else:
ret = self._adj_op_calibless(coeffs, img)
ret /= self.norm_factor
return self._safe_squeeze(ret)
ret = self._safe_squeeze(ret)
if xp.__name__ == "torch":
return xp.from_numpy(ret)
return ret

def _adj_op_sense(self, coeffs, img=None):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
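
Note on the helpers used above: the new op/adj_op paths call get_array_module and auto_cast, imported from mrinufft._utils but not shown in this diff. The snippet below is only an illustrative sketch of what such helpers could look like; the actual implementations live in mrinufft._utils and may differ.

import warnings

import numpy as np


def get_array_module(array):
    """Return the array module (numpy, cupy or torch) owning ``array`` (sketch)."""
    module = type(array).__module__.split(".")[0]
    if module == "torch":
        import torch

        return torch
    if module == "cupy":
        import cupy

        return cupy
    return np


def auto_cast(array, dtype):
    """Cast ``array`` to ``dtype``, warning the user first (sketch)."""
    xp = get_array_module(array)
    if xp.__name__ == "torch":
        target = getattr(xp, np.dtype(dtype).name)  # e.g. complex64 -> torch.complex64
        if array.dtype != target:
            warnings.warn(f"Data should be of dtype {target} (is {array.dtype}). Casting it for you.")
            array = array.to(target)
    elif array.dtype != np.dtype(dtype):
        warnings.warn(f"Data should be of dtype {dtype} (is {array.dtype}). Casting it for you.")
        array = array.astype(dtype)
    return array
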
25 changes: 20 additions & 5 deletions src/mrinufft/operators/interfaces/pynufft_cpu.py
@@ -1,7 +1,8 @@
"""PyNUFFT CPU Interface."""

from ..base import FourierOperatorCPU

import numpy as np
from mrinufft._utils import proper_trajectory

PYNUFFT_CPU_AVAILABLE = True
try:
@@ -24,12 +25,12 @@ def __init__(self, samples, shape, osf=2, interpolator_shape=6):

def op(self, coeffs_data, grid_data):
"""Forward Operator."""
coeffs_data = self._nufft_obj.forward(grid_data)
coeffs_data = self._nufft_obj.forward(grid_data.squeeze())
return coeffs_data

def adj_op(self, coeffs_data, grid_data):
"""Adjoint Operator."""
grid_data = self._nufft_obj.backward(coeffs_data)
grid_data = self._nufft_obj.adjoint(coeffs_data.squeeze())
return grid_data


@@ -45,10 +46,24 @@ def __init__(
shape,
density=False,
n_coils=1,
n_batchs=1,
smaps=None,
osf=2,
**kwargs,
):
super().__init__(samples, shape, density, n_coils, smaps)
super().__init__(
proper_trajectory(samples, normalize="pi"),
shape,
density=density,
n_coils=n_coils,
n_batchs=n_batchs,
n_trans=1,
smaps=smaps,
)

self.raw_op = RawPyNUFFT(self.samples, shape, osf, **kwargs)

self.raw_op = RawPyNUFFT(samples, shape, osf, **kwargs)
# @property
# def norm_factor(self):
# """Normalization factor of the operator."""
# return np.sqrt(2 ** len(self.shape))
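
For context, RawPyNUFFT wraps a planned pynufft NUFFT object and exposes it through op/adj_op. A bare-bones, hypothetical sketch of such a wrapper, assuming pynufft's usual NUFFT.plan(om, Nd, Kd, Jd) / forward / adjoint interface (the oversampling and kernel-size choices below are illustrative):

import numpy as np
from pynufft import NUFFT


class MiniRawPyNUFFT:
    """Minimal pynufft wrapper (sketch, mirroring the role of RawPyNUFFT)."""

    def __init__(self, samples, shape, osf=2, interpolator_shape=6):
        # samples: (M, ndim) trajectory in radians, i.e. already normalized to [-pi, pi).
        self._nufft_obj = NUFFT()
        self._nufft_obj.plan(
            np.asarray(samples),
            tuple(shape),  # Nd: image grid size
            tuple(osf * s for s in shape),  # Kd: oversampled grid size
            (interpolator_shape,) * len(shape),  # Jd: interpolation kernel size per axis
        )

    def op(self, grid_data):
        """Image -> k-space (forward NUFFT)."""
        return self._nufft_obj.forward(grid_data.squeeze())

    def adj_op(self, coeffs_data):
        """k-space -> image (adjoint NUFFT)."""
        return self._nufft_obj.adjoint(coeffs_data.squeeze())
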
2 changes: 2 additions & 0 deletions tests/helpers/__init__.py
@@ -8,6 +8,7 @@
from_interface,
CUPY_AVAILABLE,
TORCH_AVAILABLE,
param_array_interface,
)

__all__ = [
@@ -19,4 +20,5 @@
"from_interface",
"CUPY_AVAILABLE",
"TORCH_AVAILABLE",
"param_array_interface",
]
48 changes: 43 additions & 5 deletions tests/helpers/factories.py
@@ -1,5 +1,7 @@
"""Useful factories to create matching data for an operator."""
from functools import wraps
import numpy as np
import pytest

CUPY_AVAILABLE = True
try:
@@ -12,8 +14,6 @@
import torch
except ImportError:
TORCH_AVAILABLE = False
else:
TORCH_AVAILABLE = torch.cuda.is_available()


def image_from_op(operator):
@@ -44,7 +44,9 @@ def to_interface(data, interface):
"""Make DATA an array from INTERFACE."""
if interface == "cupy":
return cp.array(data)
elif interface == "torch":
elif interface == "torch-cpu":
return torch.from_numpy(data)
elif interface == "torch-gpu":
return torch.from_numpy(data).to("cuda")
return data

@@ -53,6 +55,42 @@ def from_interface(data, interface):
"""Get DATA from INTERFACE as a numpy array."""
if interface == "cupy":
return data.get()
elif interface == "torch":
return data.to("cpu").numpy()
elif "torch" in interface:
return data.cpu().numpy()
return data


_param_array_interface = pytest.mark.parametrize(
"array_interface",
[
"numpy",
pytest.param(
"cupy",
marks=pytest.mark.skipif(
not CUPY_AVAILABLE,
reason="cupy not available",
),
),
pytest.param(
"torch-cpu",
marks=pytest.mark.skipif(not TORCH_AVAILABLE, reason="torch not available"),
),
pytest.param(
"torch-gpu",
marks=pytest.mark.skipif(
not (TORCH_AVAILABLE and torch.cuda.is_available()),
reason="torch not available",
),
),
],
)


def param_array_interface(func):
@wraps(func)
def wrapper(operator, array_interface, *args, **kwargs):
if operator.backend != "cufinufft" and array_interface in ["torch-gpu", "cupy"]:
pytest.skip("Uncompatible backend and array")
return func(operator, array_interface, *args, **kwargs)

return _param_array_interface(wrapper)
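
Usage-wise, param_array_interface is meant to decorate a test that takes the operator fixture followed by array_interface; to_interface/from_interface then move data in and out of the requested interface. A hypothetical example, assuming image_from_op (defined above) is also exported from the helpers package:

import numpy.testing as npt

from helpers import from_interface, image_from_op, param_array_interface, to_interface


@param_array_interface
def test_op_adj_op_interfaces(operator, array_interface):
    """Round-trip op/adj_op with data living on the requested array interface."""
    image = to_interface(image_from_op(operator), array_interface)
    kspace = operator.op(image)
    restored = from_interface(operator.adj_op(kspace), array_interface)
    # adj_op(op(x)) preserves the image-space dimensions of the operator.
    npt.assert_equal(restored.shape[-len(operator.shape):], operator.shape)
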
48 changes: 40 additions & 8 deletions tests/test_batch.py
@@ -5,7 +5,12 @@
import numpy as np
import numpy.testing as npt
from pytest_cases import parametrize_with_cases, parametrize, fixture
from helpers import assert_correlate
from helpers import (
assert_correlate,
param_array_interface,
to_interface,
from_interface,
)
from mrinufft import get_operator
from case_trajectories import CasesTrajectories

@@ -96,17 +101,23 @@ def kspace_data(operator):
return kspace


def test_batch_op(operator, flat_operator, image_data):
@param_array_interface
def test_batch_op(operator, array_interface, flat_operator, image_data):
"""Test the batch type 2 (forward)."""
kspace_batched = operator.op(image_data)

image_data = to_interface(image_data, array_interface)

kspace_batched = from_interface(operator.op(image_data), array_interface)

if operator.uses_sense:
image_flat = image_data.reshape(-1, *operator.shape)
else:
image_flat = image_data.reshape(-1, operator.n_coils, *operator.shape)
kspace_flat = [None] * operator.n_batchs
for i in range(len(kspace_flat)):
kspace_flat[i] = flat_operator.op(image_flat[i])
kspace_flat[i] = from_interface(
flat_operator.op(image_flat[i]), array_interface
)

kspace_flat = np.reshape(
np.concatenate(kspace_flat, axis=0),
@@ -116,12 +127,22 @@ def test_batch_op(operator, flat_operator, image_data):
npt.assert_array_almost_equal(kspace_batched, kspace_flat)


def test_batch_adj_op(operator, flat_operator, kspace_data):
@param_array_interface
def test_batch_adj_op(
operator,
array_interface,
flat_operator,
kspace_data,
):
"""Test the batch type 1 (adjoint)."""
kspace_data = to_interface(kspace_data, array_interface)

kspace_flat = kspace_data.reshape(-1, operator.n_coils, operator.n_samples)
image_flat = [None] * operator.n_batchs
for i in range(len(image_flat)):
image_flat[i] = flat_operator.adj_op(kspace_flat[i])
image_flat[i] = from_interface(
flat_operator.adj_op(kspace_flat[i]), array_interface
)

if operator.uses_sense:
shape = (operator.n_batchs, 1, *operator.shape)
@@ -133,14 +154,23 @@ def test_batch_adj_op(operator, flat_operator, kspace_data):
shape,
)

image_batched = operator.adj_op(kspace_data)
image_batched = from_interface(operator.adj_op(kspace_data), array_interface)
# Reduced accuracy for the GPU cases...
npt.assert_allclose(image_batched, image_flat, atol=1e-3, rtol=1e-3)


def test_data_consistency(operator, image_data, kspace_data):
@param_array_interface
def test_data_consistency(
operator,
array_interface,
image_data,
kspace_data,
):
"""Test the data consistency operation."""
# image_data = np.zeros_like(image_data)
image_data = to_interface(image_data, array_interface)
kspace_data = to_interface(kspace_data, array_interface)

res = operator.data_consistency(image_data, kspace_data)
tmp = operator.op(image_data)
res2 = operator.adj_op(tmp - kspace_data)
@@ -149,6 +179,8 @@ def test_data_consistency(operator, image_data, kspace_data):
res = res.reshape(-1, *operator.shape)
res2 = res2.reshape(-1, *operator.shape)

res = from_interface(res, array_interface)
res2 = from_interface(res2, array_interface)
slope_err = 1e-3
# FIXME 2D Sense is not very accurate...
if len(operator.shape) == 2 and operator.uses_sense:
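
The data-consistency test above checks that operator.data_consistency(x, y) agrees with the explicitly composed gradient of the least-squares data term, adj_op(op(x) - y); in isolation that reference computation is simply (sketch, under that assumption):

def explicit_data_consistency(operator, image, kspace):
    """Gradient of 0.5 * ||op(image) - kspace||^2 with respect to image (sketch)."""
    return operator.adj_op(operator.op(image) - kspace)
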