From c56f1ba33e087ec1a05b48df8955b3676aeb5b20 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 10 Feb 2025 13:43:20 +0100 Subject: [PATCH] refactor: consistent name in functions --- pyproject.toml | 6 +++ src/mrinufft/operators/base.py | 42 ++++++++++--------- .../operators/interfaces/cufinufft.py | 7 ++-- src/mrinufft/operators/interfaces/gpunufft.py | 10 ++--- src/mrinufft/operators/interfaces/tfnufft.py | 4 +- 5 files changed, 39 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 187910627..4a44d4c81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,3 +76,9 @@ live_mode = false [tool.mypy] ignore_missing_imports = true + +[tool.pyright] +reportPossiblyUnboundVariable = false +typeCheckingMode = "basic" +reportOptionalSubscript = false +reportOptionalMemberAccess = false diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 6fbbb73c0..62cbe2d5f 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -10,8 +10,9 @@ from abc import ABC, abstractmethod from functools import partial - +from typing import ClassVar, Callable import numpy as np +from numpy.typing import NDArray from mrinufft._array_compat import with_numpy, with_numpy_cupy, AUTOGRAD_AVAILABLE from mrinufft._utils import auto_cast, power_method @@ -19,10 +20,6 @@ from mrinufft.extras import get_smaps from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array -if AUTOGRAD_AVAILABLE: - from mrinufft.operators.autodiff import MRINufftAutoGrad - - # Mapping between numpy float and complex types. DTYPE_R2C = {"float32": "complex64", "float64": "complex128"} @@ -122,6 +119,9 @@ class FourierOperatorBase(ABC): _grad_wrt_data = False _grad_wrt_traj = False + backend: ClassVar[str] + available: ClassVar[bool] + def __init__(self): if not self.available: raise RuntimeError(f"'{self.backend}' backend is not available.") @@ -207,21 +207,21 @@ def adj_op(self, coeffs): """ pass - def data_consistency(self, image, obs_data): + def data_consistency(self, image_data, obs_data): """Compute the gradient data consistency. This is the naive implementation using adj_op(op(x)-y). Specific backend can (and should!) implement a more efficient version. """ - return self.adj_op(self.op(image) - obs_data) + return self.adj_op(self.op(image_data) - obs_data) def with_off_resonance_correction(self, B, C, indices): """Return a new operator with Off Resonnance Correction.""" - from ..off_resonance import MRIFourierCorrected + from .off_resonance import MRIFourierCorrected return MRIFourierCorrected(self, B, C, indices) - def compute_smaps(self, method=None): + def compute_smaps(self, method: NDArray | Callable | str | dict | None = None): """Compute the sensitivity maps and set it. Parameters @@ -286,6 +286,8 @@ def make_autograd(self, wrt_data=True, wrt_traj=False): if not self.autograd_available: raise ValueError("Backend does not support auto-differentiation.") + from mrinufft.operators.autodiff import MRINufftAutoGrad + return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj) def compute_density(self, method=None): @@ -401,9 +403,9 @@ def smaps(self): return self._smaps @smaps.setter - def smaps(self, smaps): - self._check_smaps_shape(smaps) - self._smaps = smaps + def smaps(self, new_smaps): + self._check_smaps_shape(new_smaps) + self._smaps = new_smaps def _check_smaps_shape(self, smaps): """Check the shape of the sensitivity maps.""" @@ -421,13 +423,13 @@ def density(self): return self._density @density.setter - def density(self, density): - if density is None: + def density(self, new_density): + if new_density is None: self._density = None - elif len(density) != self.n_samples: + elif len(new_density) != self.n_samples: raise ValueError("Density and samples should have the same length") else: - self._density = density + self._density = new_density @property def dtype(self): @@ -435,8 +437,8 @@ def dtype(self): return self._dtype @dtype.setter - def dtype(self, dtype): - self._dtype = np.dtype(dtype) + def dtype(self, new_dtype): + self._dtype = np.dtype(new_dtype) @property def cpx_dtype(self): @@ -449,8 +451,8 @@ def samples(self): return self._samples @samples.setter - def samples(self, samples): - self._samples = samples + def samples(self, new_samples): + self._samples = new_samples @property def n_samples(self): diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 9f3c360f4..db1eaab75 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -26,7 +26,6 @@ except ImportError: CUFINUFFT_AVAILABLE = False - OPTS_FIELD_DECODE = { "gpu_method": {1: "nonuniform pts driven", 2: "shared memory"}, "gpu_sort": {0: "no sort (GM)", 1: "sort (GM-sort)"}, @@ -269,10 +268,12 @@ def smaps(self, new_smaps): self._smaps = new_smaps @FourierOperatorBase.samples.setter - def samples(self, samples): + def samples(self, new_samples): """Update the plans when changing the samples.""" self._samples = np.asfortranarray( - proper_trajectory(samples, normalize="pi").astype(np.float32, copy=False) + proper_trajectory(new_samples, normalize="pi").astype( + np.float32, copy=False + ) ) for typ in [1, 2, "grad"]: if typ == "grad" and not self._grad_wrt_traj: diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index fcaae3e9f..ec55d5c91 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -532,7 +532,7 @@ def smaps(self, new_smaps): self.raw_op.set_smaps(smaps=new_smaps) @FourierOperatorBase.samples.setter - def samples(self, samples): + def samples(self, new_samples): """Set the samples for the Fourier Operator. Parameters @@ -541,7 +541,7 @@ def samples(self, samples): The samples for the Fourier Operator. """ self._samples = proper_trajectory( - samples.astype(np.float32, copy=False), normalize="unit" + new_samples.astype(np.float32, copy=False), normalize="unit" ) # TODO: gpuNUFFT needs to sort the points twice in this case. # It could help to have access to directly dorted arrays from gpuNUFFT. @@ -552,7 +552,7 @@ def samples(self, samples): ) @FourierOperatorBase.density.setter - def density(self, density): + def density(self, new_density): """Set the density for the Fourier Operator. Parameters @@ -560,11 +560,11 @@ def density(self, density): density: np.ndarray The density for the Fourier Operator. """ - self._density = density + self._density = new_density if hasattr(self, "raw_op"): # edge case for init self.raw_op.set_pts( self._samples, - density=density, + density=new_density, ) @classmethod diff --git a/src/mrinufft/operators/interfaces/tfnufft.py b/src/mrinufft/operators/interfaces/tfnufft.py index d82192737..75d322091 100644 --- a/src/mrinufft/operators/interfaces/tfnufft.py +++ b/src/mrinufft/operators/interfaces/tfnufft.py @@ -134,7 +134,7 @@ def norm_factor(self): return np.sqrt(np.prod(self.shape) * 2 ** len(self.shape)) @with_tensorflow - def data_consistency(self, data, obs_data): + def data_consistency(self, image_data, obs_data): """Compute the data consistency. Parameters @@ -149,7 +149,7 @@ def data_consistency(self, data, obs_data): Tensor The data consistency error in image space. """ - return self.adj_op(self.op(data) - obs_data) + return self.adj_op(self.op(image_data) - obs_data) @classmethod def pipe(