From d526791ea10e55180520a33099c088c2faa41829 Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Wed, 18 Dec 2024 21:05:48 +0100 Subject: [PATCH 1/6] Clean documentation and add type hints --- src/mrinufft/trajectories/display.py | 109 ++++----- src/mrinufft/trajectories/gradients.py | 24 +- .../trajectories/inits/random_walk.py | 59 +++-- .../trajectories/inits/travelling_salesman.py | 76 +++--- src/mrinufft/trajectories/maths/fibonacci.py | 12 +- src/mrinufft/trajectories/maths/primes.py | 6 +- src/mrinufft/trajectories/maths/rotations.py | 16 +- src/mrinufft/trajectories/maths/tsp_solver.py | 19 +- src/mrinufft/trajectories/sampling.py | 2 +- src/mrinufft/trajectories/tools.py | 228 +++++++++++------- src/mrinufft/trajectories/trajectory2D.py | 80 +++--- src/mrinufft/trajectories/trajectory3D.py | 175 ++++++++------ src/mrinufft/trajectories/utils.py | 84 ++++--- 13 files changed, 516 insertions(+), 374 deletions(-) diff --git a/src/mrinufft/trajectories/display.py b/src/mrinufft/trajectories/display.py index 50f20344b..973568765 100644 --- a/src/mrinufft/trajectories/display.py +++ b/src/mrinufft/trajectories/display.py @@ -1,6 +1,7 @@ """Display functions for trajectories.""" import itertools +from typing import Any import matplotlib as mpl import matplotlib.pyplot as plt @@ -124,7 +125,7 @@ def get_colorlist(cls): ############## -def _setup_2D_ticks(figsize, fig=None): +def _setup_2D_ticks(figsize: float, fig: plt.Figure | None = None) -> plt.Axes: """Add ticks to 2D plot.""" if fig is None: fig = plt.figure(figsize=(figsize, figsize)) @@ -139,7 +140,7 @@ def _setup_2D_ticks(figsize, fig=None): return ax -def _setup_3D_ticks(figsize, fig=None): +def _setup_3D_ticks(figsize: float, fig: plt.Figure | None = None) -> plt.Axes: """Add ticks to 3D plot.""" if fig is None: fig = plt.figure(figsize=(figsize, figsize)) @@ -163,21 +164,21 @@ def _setup_3D_ticks(figsize, fig=None): def display_2D_trajectory( - trajectory, - figsize=5, - one_shot=False, - subfigure=None, - show_constraints=False, - gmax=DEFAULT_GMAX, - smax=DEFAULT_SMAX, - constraints_order=None, - **constraints_kwargs, -): + trajectory: np.ndarray, + figsize: float = 5, + one_shot: bool | int = False, + subfigure: plt.Figure | plt.Axes | None = None, + show_constraints: bool = False, + gmax: float = DEFAULT_GMAX, + smax: float = DEFAULT_SMAX, + constraints_order: int | str | None = None, + **constraints_kwargs: Any, +) -> plt.Axes: """Display 2D trajectories. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to display. figsize : float, optional Size of the figure. @@ -278,23 +279,23 @@ def display_2D_trajectory( def display_3D_trajectory( - trajectory, - nb_repetitions=None, - figsize=5, - per_plane=True, - one_shot=False, - subfigure=None, - show_constraints=False, - gmax=DEFAULT_GMAX, - smax=DEFAULT_SMAX, - constraints_order=None, - **constraints_kwargs, -): + trajectory: np.ndarray, + nb_repetitions: int | None = None, + figsize: float = 5, + per_plane: bool = True, + one_shot: bool | int = False, + subfigure: plt.Figure | plt.Axes | None = None, + show_constraints: bool = False, + gmax: float = DEFAULT_GMAX, + smax: float = DEFAULT_SMAX, + constraints_order: int | str | None = None, + **constraints_kwargs: dict, +) -> plt.Axes: """Display 3D trajectories. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to display. nb_repetitions : int Number of repetitions (planes, cones, shells, etc). @@ -417,22 +418,22 @@ def display_3D_trajectory( def display_gradients_simply( - trajectory, - shot_ids=(0,), - figsize=5, - fill_area=True, - show_signal=True, - uni_signal="gray", - uni_gradient=None, - subfigure=None, -): + trajectory: np.ndarray, + shot_ids: tuple[int, ...] = (0,), + figsize: float = 5, + fill_area: bool = True, + show_signal: bool = True, + uni_signal: str | None = "gray", + uni_gradient: str | None = None, + subfigure: plt.Figure | None = None, +) -> tuple[plt.Axes]: """Display gradients based on trajectory of any dimension. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to display. - shot_ids : list of int + shot_ids : tuple[int, ...], optional Indices of the shots to display. The default is `[0]`. figsize : float, optional @@ -455,7 +456,7 @@ def display_gradients_simply( unique color given as argument or just by the default color cycle when `None`. The default is `None`. - subfigure: plt.Figure or plt.SubFigure, optional + subfigure: plt.Figure, optional The figure where the trajectory should be displayed. The default is `None`. @@ -531,26 +532,26 @@ def display_gradients_simply( def display_gradients( - trajectory, - shot_ids=(0,), - figsize=5, - fill_area=True, - show_signal=True, - uni_signal="gray", - uni_gradient=None, - subfigure=None, - show_constraints=False, - gmax=DEFAULT_GMAX, - smax=DEFAULT_SMAX, - constraints_order=None, - raster_time=DEFAULT_RASTER_TIME, - **constraints_kwargs, -): + trajectory: np.ndarray, + shot_ids: tuple[int, ...] = (0,), + figsize: float = 5, + fill_area: bool = True, + show_signal: bool = True, + uni_signal: str | None = "gray", + uni_gradient: str | None = None, + subfigure: plt.Figure | plt.Axes | None = None, + show_constraints: bool = False, + gmax: float = DEFAULT_GMAX, + smax: float = DEFAULT_SMAX, + constraints_order: int | str | None = None, + raster_time: float = DEFAULT_RASTER_TIME, + **constraints_kwargs: Any, +) -> tuple[plt.Axes]: """Display gradients based on trajectory of any dimension. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to display. shot_ids : list of int Indices of the shots to display. diff --git a/src/mrinufft/trajectories/gradients.py b/src/mrinufft/trajectories/gradients.py index 39ac934c0..932c22ee5 100644 --- a/src/mrinufft/trajectories/gradients.py +++ b/src/mrinufft/trajectories/gradients.py @@ -1,24 +1,26 @@ """Functions to improve/modify gradients.""" +from typing import Callable + import numpy as np import numpy.linalg as nl from scipy.interpolate import CubicSpline def patch_center_anomaly( - shot_or_params, - update_shot=None, - update_parameters=None, - in_out=False, - learning_rate=1e-1, -): + shot_or_params: np.ndarray | list, + update_shot: Callable[..., np.ndarray] | None = None, + update_parameters: Callable[..., list] | None = None, + in_out: bool = False, + learning_rate: float = 1e-1, +) -> tuple[np.ndarray, list]: """Re-position samples to avoid center anomalies. Some trajectories behave slightly differently from expected when approaching definition bounds, most often the k-space center as for spirals in some cases. - This function enforces non-strictly increasing monoticity of + This function enforces non-strictly increasing monotonicity of sample distances from the center, effectively reducing slew rates and smoothing gradient transitions locally. @@ -41,7 +43,7 @@ def patch_center_anomaly( If None, cubic spline parameterization is used instead, by default None in_out : bool, optional - Whether the shot is going in-and-out or start from the center, + Whether the shot is going in-and-out or starts from the center, by default False learning_rate : float, optional Learning rate used in the iterative optimization process, @@ -49,7 +51,7 @@ def patch_center_anomaly( Returns ------- - array_like + np.ndarray N-D trajectory based on ``shot_or_params`` if a shot or update_shot otherwise. list @@ -70,7 +72,7 @@ def patch_center_anomaly( if update_shot is None or update_parameters is None: - def _default_update_parameters(shot, *parameters): + def _default_update_parameters(shot: np.ndarray, *parameters: list) -> list: return parameters update_parameters = _default_update_parameters @@ -114,5 +116,5 @@ def _default_update_parameters(shot, *parameters): single_shot = cbs(x_axis).T parameters = update_parameters(single_shot, *parameters) - single_shot = single_shot = update_shot(*parameters) + single_shot = update_shot(*parameters) return single_shot, parameters diff --git a/src/mrinufft/trajectories/inits/random_walk.py b/src/mrinufft/trajectories/inits/random_walk.py index da68feb1f..5f2745068 100644 --- a/src/mrinufft/trajectories/inits/random_walk.py +++ b/src/mrinufft/trajectories/inits/random_walk.py @@ -1,16 +1,18 @@ """Trajectories based on random walks.""" +from typing import Any + import numpy as np from ..sampling import sample_from_density from ..utils import KMAX -def _get_adjacent_neighbors_offsets(shape): +def _get_adjacent_neighbors_offsets(shape: tuple[int, ...]) -> np.ndarray: return np.concatenate([np.eye(len(shape)), -np.eye(len(shape))], axis=0).astype(int) -def _get_neighbors_offsets(shape): +def _get_neighbors_offsets(shape: tuple[int, ...]) -> np.ndarray: nb_dims = len(shape) neighbors = (np.indices([3] * nb_dims) - 1).reshape((nb_dims, -1)).T nb_half = neighbors.shape[0] // 2 @@ -20,8 +22,14 @@ def _get_neighbors_offsets(shape): def _initialize_ND_random_walk( - Nc, Ns, density, *, diagonals=True, pseudo_random=True, **sampling_kwargs -): + Nc: int, + Ns: int, + density: np.ndarray, + *, + diagonals: bool = True, + pseudo_random: bool = True, + **sampling_kwargs: Any, +) -> np.ndarray: density = density / np.sum(density) flat_density = np.copy(density.flatten()) shape = np.array(density.shape) @@ -41,7 +49,6 @@ def _initialize_ND_random_walk( locations = sample_from_density(Nc, density, **sampling_kwargs) choices = np.around((locations + KMAX) * (np.array(density.shape) - 1)).astype(int) choices = np.ravel_multi_index(choices.T, density.shape) - # choices = np.random.choice(np.arange(len(flat_density)), size=Nc, p=flat_density) routes = [choices] # Walk @@ -71,7 +78,7 @@ def _initialize_ND_random_walk( choices = neighbors[np.arange(Nc), indices] routes.append(choices) - # Update density to account for already drawed positions + # Update density to account for already drawn positions if pseudo_random: flat_density[choices] = ( mask[choices] * flat_density[choices] / (mask[choices] + 1) @@ -88,8 +95,14 @@ def _initialize_ND_random_walk( def initialize_2D_random_walk( - Nc, Ns, density, *, diagonals=True, pseudo_random=True, **sampling_kwargs -): + Nc: int, + Ns: int, + density: np.ndarray, + *, + diagonals: bool = True, + pseudo_random: bool = True, + **sampling_kwargs: Any, +) -> np.ndarray: """Initialize a 2D random walk trajectory. This is an adaptation of the proposition from [Cha+14]_. @@ -107,31 +120,31 @@ def initialize_2D_random_walk( Number of shots Ns : int Number of samples per shot - density : array_like + density : np.ndarray Sampling density used to determine the walk probabilities, normalized automatically by its sum during the call for convenience. diagonals : bool, optional - Whether to draw the next walk step from the diagional neighbors + Whether to draw the next walk step from the diagonal neighbors on top of the adjacent ones. Default to ``True``. pseudo_random : bool, optional Whether to adapt the density dynamically to reduce areas already covered. The density is still statistically followed for undersampled acquisitions. Default to ``True``. - **sampling_kwargs + **sampling_kwargs : Any Sampling parameters in ``mrinufft.trajectories.sampling.sample_from_density`` used for the shot starting positions. Returns ------- - array_like + np.ndarray 2D random walk trajectory References ---------- .. [Cha+14] Chauffert, Nicolas, Philippe Ciuciu, Jonas Kahn, and Pierre Weiss. - "Variable density sampling with continuous trajectories." + "Variable density sampling with continuous trajectories" SIAM Journal on Imaging Sciences 7, no. 4 (2014): 1962-1992. """ if len(density.shape) != 2: @@ -147,8 +160,14 @@ def initialize_2D_random_walk( def initialize_3D_random_walk( - Nc, Ns, density, *, diagonals=True, pseudo_random=True, **sampling_kwargs -): + Nc: int, + Ns: int, + density: np.ndarray, + *, + diagonals: bool = True, + pseudo_random: bool = True, + **sampling_kwargs: Any, +) -> np.ndarray: """Initialize a 3D random walk trajectory. This is an adaptation of the proposition from [Cha+14]_. @@ -166,31 +185,31 @@ def initialize_3D_random_walk( Number of shots Ns : int Number of samples per shot - density : array_like + density : np.ndarray Sampling density used to determine the walk probabilities, normalized automatically by its sum during the call for convenience. diagonals : bool, optional - Whether to draw the next walk step from the diagional neighbors + Whether to draw the next walk step from the diagonal neighbors on top of the adjacent ones. Default to ``True``. pseudo_random : bool, optional Whether to adapt the density dynamically to reduce areas already covered. The density is still statistically followed for undersampled acquisitions. Default to ``True``. - **sampling_kwargs + **sampling_kwargs : Any Sampling parameters in ``mrinufft.trajectories.sampling.sample_from_density`` used for the shot starting positions. Returns ------- - array_like + np.ndarray 3D random walk trajectory References ---------- .. [Cha+14] Chauffert, Nicolas, Philippe Ciuciu, Jonas Kahn, and Pierre Weiss. - "Variable density sampling with continuous trajectories." + "Variable density sampling with continuous trajectories" SIAM Journal on Imaging Sciences 7, no. 4 (2014): 1962-1992. """ if len(density.shape) != 3: diff --git a/src/mrinufft/trajectories/inits/travelling_salesman.py b/src/mrinufft/trajectories/inits/travelling_salesman.py index 6b4f7764b..adfff6ca0 100644 --- a/src/mrinufft/trajectories/inits/travelling_salesman.py +++ b/src/mrinufft/trajectories/inits/travelling_salesman.py @@ -1,5 +1,7 @@ """Trajectories based on the Travelling Salesman Problem.""" +from typing import Any + import numpy as np import numpy.linalg as nl from scipy.interpolate import CubicSpline @@ -10,7 +12,7 @@ from ..tools import oversample -def _get_approx_cluster_sizes(nb_total, nb_clusters): +def _get_approx_cluster_sizes(nb_total: int, nb_clusters: int) -> np.ndarray: # Give a list of cluster sizes close to sqrt(`nb_total`) cluster_sizes = round(nb_total / nb_clusters) * np.ones(nb_clusters).astype(int) delta_sum = nb_total - np.sum(cluster_sizes) @@ -18,7 +20,7 @@ def _get_approx_cluster_sizes(nb_total, nb_clusters): return cluster_sizes -def _sort_by_coordinate(array, coord): +def _sort_by_coordinate(array: np.ndarray, coord: str) -> np.ndarray: # Sort a list of N-D locations by a Cartesian/spherical coordinate if array.shape[-1] < 3 and coord.lower() in ["z", "theta"]: raise ValueError( @@ -47,8 +49,12 @@ def _sort_by_coordinate(array, coord): def _cluster_by_coordinate( - locations, nb_clusters, cluster_by, second_cluster_by=None, sort_by=None -): + locations: np.ndarray, + nb_clusters: int, + cluster_by: str, + second_cluster_by: str | None = None, + sort_by: str | None = None, +) -> np.ndarray: # Cluster approximately a list of N-D locations by Cartesian/spherical coordinates # Gather dimension variables nb_dims = locations.shape[-1] @@ -87,17 +93,17 @@ def _cluster_by_coordinate( def _initialize_ND_travelling_salesman( - Nc, - Ns, - density, - first_cluster_by=None, - second_cluster_by=None, - sort_by=None, - tsp_tol=1e-8, + Nc: int, + Ns: int, + density: np.ndarray, + first_cluster_by: str | None = None, + second_cluster_by: str | None = None, + sort_by: str | None = None, + tsp_tol: float = 1e-8, *, - verbose=False, - **sampling_kwargs, -): + verbose: bool = False, + **sampling_kwargs: Any, +) -> np.ndarray: # Check arguments validity if Nc * Ns > np.prod(density.shape): raise ValueError("`density` array not large enough to pick `Nc` * `Ns` points.") @@ -134,17 +140,17 @@ def _initialize_ND_travelling_salesman( def initialize_2D_travelling_salesman( - Nc, - Ns, - density, - first_cluster_by=None, - second_cluster_by=None, - sort_by=None, - tsp_tol=1e-8, + Nc: int, + Ns: int, + density: np.ndarray, + first_cluster_by: str | None = None, + second_cluster_by: str | None = None, + sort_by: str | None = None, + tsp_tol: float = 1e-8, *, - verbose=False, - **sampling_kwargs, -): + verbose: bool = False, + **sampling_kwargs: Any, +) -> np.ndarray: """ Initialize a 2D trajectory using a Travelling Salesman Problem (TSP)-based path. @@ -192,7 +198,7 @@ def initialize_2D_travelling_salesman( ---------- .. [Cha+14] Chauffert, Nicolas, Philippe Ciuciu, Jonas Kahn, and Pierre Weiss. - "Variable density sampling with continuous trajectories." + "Variable density sampling with continuous trajectories" SIAM Journal on Imaging Sciences 7, no. 4 (2014): 1962-1992. """ if len(density.shape) != 2: @@ -211,17 +217,17 @@ def initialize_2D_travelling_salesman( def initialize_3D_travelling_salesman( - Nc, - Ns, - density, - first_cluster_by=None, - second_cluster_by=None, - sort_by=None, - tsp_tol=1e-8, + Nc: int, + Ns: int, + density: np.ndarray, + first_cluster_by: str | None = None, + second_cluster_by: str | None = None, + sort_by: str | None = None, + tsp_tol: float = 1e-8, *, - verbose=False, - **sampling_kwargs, -): + verbose: bool = False, + **sampling_kwargs: Any, +) -> np.ndarray: """ Initialize a 3D trajectory using a Travelling Salesman Problem (TSP)-based path. diff --git a/src/mrinufft/trajectories/maths/fibonacci.py b/src/mrinufft/trajectories/maths/fibonacci.py index d7cc20a88..69e621210 100644 --- a/src/mrinufft/trajectories/maths/fibonacci.py +++ b/src/mrinufft/trajectories/maths/fibonacci.py @@ -3,7 +3,7 @@ import numpy as np -def is_from_fibonacci_sequence(n): +def is_from_fibonacci_sequence(n: int) -> bool: """Check if an integer belongs to the Fibonacci sequence. An integer belongs to the Fibonacci sequence if either @@ -21,14 +21,14 @@ def is_from_fibonacci_sequence(n): Whether or not ``n`` belongs to the Fibonacci sequence. """ - def _is_perfect_square(n): + def _is_perfect_square(n: int) -> bool: r = int(np.sqrt(n)) return r * r == n return _is_perfect_square(5 * n**2 + 4) or _is_perfect_square(5 * n**2 - 4) -def get_closest_fibonacci_number(x): +def get_closest_fibonacci_number(x: float) -> int: """Provide the closest Fibonacci number. Parameters @@ -52,7 +52,7 @@ def get_closest_fibonacci_number(x): return xf -def generate_fibonacci_lattice(nb_points, epsilon=0.25): +def generate_fibonacci_lattice(nb_points: int, epsilon: float = 0.25) -> np.ndarray: """Generate 2D Cartesian coordinates using the Fibonacci lattice. Place 2D points over a 1x1 square following the Fibonacci lattice. @@ -78,7 +78,7 @@ def generate_fibonacci_lattice(nb_points, epsilon=0.25): return fibonacci_square -def generate_fibonacci_circle(nb_points, epsilon=0.25): +def generate_fibonacci_circle(nb_points: int, epsilon: float = 0.25) -> np.ndarray: """Generate 2D Cartesian coordinates shaped as Fibonacci spirals. Place 2D points structured as Fibonacci spirals by distorting @@ -106,7 +106,7 @@ def generate_fibonacci_circle(nb_points, epsilon=0.25): return fibonacci_circle -def generate_fibonacci_sphere(nb_points, epsilon=0.25): +def generate_fibonacci_sphere(nb_points: int, epsilon: float = 0.25) -> np.ndarray: """Generate 3D Cartesian coordinates as a Fibonacci sphere. Place 3D points almost evenly over a sphere surface of radius diff --git a/src/mrinufft/trajectories/maths/primes.py b/src/mrinufft/trajectories/maths/primes.py index 3a9df50cb..3ed36fb5d 100644 --- a/src/mrinufft/trajectories/maths/primes.py +++ b/src/mrinufft/trajectories/maths/primes.py @@ -3,7 +3,9 @@ import numpy as np -def compute_coprime_factors(Nc, length, start=1, update=1): +def compute_coprime_factors( + Nc: int, length: int, start: int = 1, update: int = 1 +) -> list[int]: """Compute a list of coprime factors of Nc. Parameters @@ -19,7 +21,7 @@ def compute_coprime_factors(Nc, length, start=1, update=1): Returns ------- - list + list[int] List of coprime factors of Nc. """ count = start diff --git a/src/mrinufft/trajectories/maths/rotations.py b/src/mrinufft/trajectories/maths/rotations.py index 568780786..5bd165701 100644 --- a/src/mrinufft/trajectories/maths/rotations.py +++ b/src/mrinufft/trajectories/maths/rotations.py @@ -4,7 +4,7 @@ import numpy.linalg as nl -def R2D(theta): +def R2D(theta: float) -> np.ndarray: """Initialize 2D rotation matrix. Parameters @@ -20,7 +20,7 @@ def R2D(theta): return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) -def Rx(theta): +def Rx(theta: float) -> np.ndarray: """Initialize 3D rotation matrix around x axis. Parameters @@ -42,7 +42,7 @@ def Rx(theta): ) -def Ry(theta): +def Ry(theta: float) -> np.ndarray: """Initialize 3D rotation matrix around y axis. Parameters @@ -64,7 +64,7 @@ def Ry(theta): ) -def Rz(theta): +def Rz(theta: float) -> np.ndarray: """Initialize 3D rotation matrix around z axis. Parameters @@ -86,7 +86,9 @@ def Rz(theta): ) -def Rv(v1, v2, normalize=True, eps=1e-8): +def Rv( + v1: np.ndarray, v2: np.ndarray, eps: float = 1e-8, *, normalize: bool = True +) -> np.ndarray: """Initialize 3D rotation matrix from two vectors. Initialize a 3D rotation matrix from two vectors using Rodrigues's rotation @@ -101,6 +103,8 @@ def Rv(v1, v2, normalize=True, eps=1e-8): Source vector. v2 : np.ndarray Target vector. + eps : float, optional + Tolerance to consider two vectors as colinear. The default is 1e-8. normalize : bool, optional Normalize the vectors. The default is True. @@ -122,7 +126,7 @@ def Rv(v1, v2, normalize=True, eps=1e-8): return np.identity(3) + cross_matrix + cross_matrix @ cross_matrix / (1 + cos_theta) -def Ra(vector, theta): +def Ra(vector: np.ndarray, theta: float) -> np.ndarray: """Initialize 3D rotation matrix around an arbitrary vector. Initialize a 3D rotation matrix to rotate around `vector` by an angle `theta`. diff --git a/src/mrinufft/trajectories/maths/tsp_solver.py b/src/mrinufft/trajectories/maths/tsp_solver.py index b4bcf859d..7bad67662 100644 --- a/src/mrinufft/trajectories/maths/tsp_solver.py +++ b/src/mrinufft/trajectories/maths/tsp_solver.py @@ -3,27 +3,28 @@ import numpy as np -def solve_tsp_with_2opt(locations, improvement_threshold=1e-8): +def solve_tsp_with_2opt( + locations: np.ndarray, improvement_threshold: float = 1e-8 +) -> np.ndarray: """Solve the TSP problem using a 2-opt approach. A sub-optimal solution to the Travelling Salesman Problem (TSP) is provided using the 2-opt approach in O(n²) where chunks of an initially random route are reversed, and selected if the - total distance is reduced. As a result the route solution + total distance is reduced. As a result, the route solution does not cross its own path in 2D. Parameters ---------- - locations : array_like - An array of N points with shape (N, D) with D - the space dimension. - improvement_threshold: float - Threshold used as progress criterion to stop the optimization - process. + locations : np.ndarray + An array of N points with shape (N, D) where D is the space dimension. + improvement_threshold : float, optional + Threshold used as progress criterion to stop the optimization process. + The default is 1e-8. Returns ------- - array_like + np.ndarray The new positions order of shape (N,). """ route = np.arange(locations.shape[0]) diff --git a/src/mrinufft/trajectories/sampling.py b/src/mrinufft/trajectories/sampling.py index 14ed18fc0..d0cceb795 100644 --- a/src/mrinufft/trajectories/sampling.py +++ b/src/mrinufft/trajectories/sampling.py @@ -50,7 +50,7 @@ def sample_from_density( ---------- .. [Cha+14] Chauffert, Nicolas, Philippe Ciuciu, Jonas Kahn, and Pierre Weiss. - "Variable density sampling with continuous trajectories." + "Variable density sampling with continuous trajectories" SIAM Journal on Imaging Sciences 7, no. 4 (2014): 1962-1992. """ try: diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index cc099c07a..fd4f041c0 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -1,5 +1,7 @@ """Functions to manipulate/modify trajectories.""" +from typing import Any, Callable + import numpy as np from scipy.interpolate import CubicSpline, interp1d @@ -11,12 +13,18 @@ ################ -def stack(trajectory, nb_stacks, z_tilt=None, hard_bounded=True): +def stack( + trajectory: np.ndarray, + nb_stacks: int, + z_tilt: str | None = None, + *, + hard_bounded: bool = True, +) -> np.ndarray: """Stack 2D or 3D trajectories over the :math:`k_z`-axis. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory in 2D or 3D to stack. nb_stacks : int Number of stacks repeating the provided trajectory. @@ -27,7 +35,7 @@ def stack(trajectory, nb_stacks, z_tilt=None, hard_bounded=True): Returns ------- - array_like + np.ndarray Stacked trajectory. """ # Check dimensionality and initialize output @@ -55,12 +63,18 @@ def stack(trajectory, nb_stacks, z_tilt=None, hard_bounded=True): return new_trajectory.reshape(nb_stacks * Nc, Ns, 3) -def rotate(trajectory, nb_rotations, x_tilt=None, y_tilt=None, z_tilt=None): +def rotate( + trajectory: np.ndarray, + nb_rotations: int, + x_tilt: str | None = None, + y_tilt: str | None = None, + z_tilt: str | None = None, +) -> np.ndarray: """Rotate 2D or 3D trajectories over the different axes. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory in 2D or 3D to rotate. nb_rotations : int Number of rotations repeating the provided trajectory. @@ -73,7 +87,7 @@ def rotate(trajectory, nb_rotations, x_tilt=None, y_tilt=None, z_tilt=None): Returns ------- - array_like + np.ndarray Rotated trajectory. """ # Check dimensionality and initialize output @@ -96,18 +110,18 @@ def rotate(trajectory, nb_rotations, x_tilt=None, y_tilt=None, z_tilt=None): def precess( - trajectory, - nb_rotations, - tilt="golden", - half_sphere=False, - partition="axial", - axis=None, -): + trajectory: np.ndarray, + nb_rotations: int, + tilt: str = "golden", + half_sphere: bool = False, + partition: str = "axial", + axis: int | np.ndarray | None = None, +) -> np.ndarray: """Rotate trajectories as a precession around the :math:`k_z`-axis. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory in 2D or 3D to rotate. nb_rotations : int Number of rotations repeating the provided trajectory while precessing. @@ -122,7 +136,7 @@ def precess( Partition type between an "axial" or "polar" split of the :math:`k_z`-axis, designating whether the axis should be fragmented by radius or angle respectively, by default "axial". - axis : int, array_like, optional + axis : int, np.ndarray, optional Axis selected for alignment reference when rotating the trajectory around the :math:`k_z`-axis, generally corresponding to the shot direction for single shot ``trajectory`` inputs. It can either @@ -132,7 +146,7 @@ def precess( Returns ------- - array_like + np.ndarray Precessed trajectory. """ # Check for partition option error @@ -175,18 +189,18 @@ def precess( def conify( - trajectory, - nb_cones, - z_tilt=None, - in_out=False, - max_angle=np.pi / 2, - borderless=True, -): + trajectory: np.ndarray, + nb_cones: int, + z_tilt: str | None = None, + in_out: bool = False, + max_angle: float = np.pi / 2, + borderless: bool = True, +) -> np.ndarray: """Distort 2D or 3D trajectories into cones along the :math:`k_z`-axis. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to conify. nb_cones : int Number of cones repeating the provided trajectory. @@ -203,7 +217,7 @@ def conify( Returns ------- - array_like + np.ndarray Conified trajectory. """ # Check dimensionality and initialize output @@ -253,7 +267,13 @@ def conify( return new_trajectory -def epify(trajectory, Ns_transitions, nb_trains, reverse_odd_shots=False): +def epify( + trajectory: np.ndarray, + Ns_transitions: int, + nb_trains: int, + *, + reverse_odd_shots: bool = False, +) -> np.ndarray: """Create multi-readout shots from trajectory composed of single-readouts. Assemble multiple single-readout shots together by adding transition @@ -261,7 +281,7 @@ def epify(trajectory, Ns_transitions, nb_trains, reverse_odd_shots=False): Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to change by prolonging and merging the shots. Ns_transitions : int Number of samples/steps between the merged readouts. @@ -274,7 +294,7 @@ def epify(trajectory, Ns_transitions, nb_trains, reverse_odd_shots=False): Returns ------- - array_like + np.ndarray Trajectory with fewer but longer multi-readout shots. """ Nc, Ns, Nd = trajectory.shape @@ -307,7 +327,9 @@ def epify(trajectory, Ns_transitions, nb_trains, reverse_odd_shots=False): return assembled_trajectory -def unepify(trajectory, Ns_readouts, Ns_transitions): +def unepify( + trajectory: np.ndarray, Ns_readouts: int, Ns_transitions: int +) -> np.ndarray: """Recover single-readout shots from multi-readout trajectory. Reformat an EPI-like trajectory with multiple readouts and transitions @@ -319,7 +341,7 @@ def unepify(trajectory, Ns_readouts, Ns_transitions): Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to reduce by discarding transitions between readouts. Ns_readouts : int Number of samples within a single readout. @@ -328,7 +350,7 @@ def unepify(trajectory, Ns_readouts, Ns_transitions): Returns ------- - array_like + np.ndarray Trajectory with more but shorter single shots. """ Nc, Ns, Nd = trajectory.shape @@ -349,7 +371,7 @@ def unepify(trajectory, Ns_readouts, Ns_transitions): return trajectory -def prewind(trajectory, Ns_transitions): +def prewind(trajectory: np.ndarray, Ns_transitions: int) -> np.ndarray: """Add pre-winding/positioning to the trajectory. The trajectory is extended to start before the readout @@ -358,7 +380,7 @@ def prewind(trajectory, Ns_transitions): Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to extend with rewind gradients. Ns_transitions : int Number of pre-winding/positioning steps used to leave the @@ -366,7 +388,7 @@ def prewind(trajectory, Ns_transitions): Returns ------- - array_like + np.ndarray Extended trajectory with pre-winding/positioning. """ Nc, Ns, Nd = trajectory.shape @@ -389,7 +411,7 @@ def prewind(trajectory, Ns_transitions): return assembled_trajectory -def rewind(trajectory, Ns_transitions): +def rewind(trajectory: np.ndarray, Ns_transitions: int) -> np.ndarray: """Add rewinding to the trajectory. The trajectory is extended to come back to the k-space center @@ -397,14 +419,14 @@ def rewind(trajectory, Ns_transitions): Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to extend with rewind gradients. Ns_transitions : int Number of rewinding steps used to come back to the k-space center. Returns ------- - array_like + np.ndarray Extended trajectory with rewinding. """ Nc, Ns, Nd = trajectory.shape @@ -429,7 +451,7 @@ def rewind(trajectory, Ns_transitions): return assembled_trajectory -def oversample(trajectory, new_Ns, kind="cubic"): +def oversample(trajectory: np.ndarray, new_Ns: int, kind: str = "cubic") -> np.ndarray: """ Resample a trajectory to increase the number of samples using interpolation. @@ -475,32 +497,36 @@ def oversample(trajectory, new_Ns, kind="cubic"): def stack_spherically( - trajectory_func, Nc, nb_stacks, z_tilt=None, hard_bounded=True, **traj_kwargs -): + trajectory_func: Callable[..., np.ndarray], + Nc: int, + nb_stacks: int, + z_tilt: str | None = None, + hard_bounded: bool = True, + **traj_kwargs: Any, +) -> np.ndarray: """Stack 2D or 3D trajectories over the :math:`k_z`-axis to make a sphere. Parameters ---------- - trajectory_func : function + trajectory_func : Callable[..., np.ndarray] Trajectory function that should return an array-like with the usual (Nc, Ns, Nd) size. Nc : int - Number of shots to use for the whole spherically - stacked trajectory. + Number of shots to use for the whole spherically stacked trajectory. nb_stacks : int Number of stacks of trajectories. - z_tilt : str, optional + z_tilt : str | None, optional Tilt of the stacks, by default `None`. hard_bounded : bool, optional - Whether the stacks should be strictly within the limits of the k-space, - by default `True`. - **kwargs - Trajectory initialization parameters for the function provided - with `trajectory_func`. + Whether the stacks should be strictly within the limits + of the k-space, by default `True`. + **traj_kwargs : Any + Trajectory initialization parameters for the function + provided with `trajectory_func`. Returns ------- - array_like + np.ndarray Stacked trajectory. """ # Handle argument errors @@ -558,41 +584,39 @@ def stack_spherically( def shellify( - trajectory_func, - Nc, - nb_shells, - z_tilt="golden", - hemisphere_mode="symmetric", - **traj_kwargs, -): + trajectory_func: Callable[..., np.ndarray], + Nc: int, + nb_shells: int, + z_tilt: str | float = "golden", + hemisphere_mode: str = "symmetric", + **traj_kwargs: Any, +) -> np.ndarray: """Stack 2D or 3D trajectories over the :math:`k_z`-axis to make a sphere. Parameters ---------- - trajectory_func : function - Trajectory function that should return an array-like - with the usual (Nc, Ns, Nd) size. + trajectory_func : Callable[..., np.ndarray] + Trajectory function that should return an array-like with the usual + (Nc, Ns, Nd) size. Nc : int - Number of shots to use for the whole spherically - stacked trajectory. + Number of shots to use for the whole spherically stacked trajectory. nb_shells : int - Number of shells of distorded trajectories. - z_tilt : str, float, optional + Number of shells of distorted trajectories. + z_tilt : str | float, optional Tilt of the shells, by default "golden". hemisphere_mode : str, optional - Define how the lower hemisphere should be oriented - relatively to the upper one, with "symmetric" providing - a :math:`k_x-k_y` planar symmetry by changing the polar angle, - and with "reversed" promoting continuity (for example - in spirals) by reversing the azimuth angle. + Define how the lower hemisphere should be oriented relatively to the + upper one, with "symmetric" providing a :math:`k_x-k_y` planar symmetry + by changing the polar angle, and with "reversed" promoting continuity + (for example in spirals) by reversing the azimuth angle. The default is "symmetric". - **kwargs - Trajectory initialization parameters for the function provided - with `trajectory_func`. + **traj_kwargs : Any + Trajectory initialization parameters for the function + provided with `trajectory_func`. Returns ------- - array_like + np.ndarray Concentric shell trajectory. """ # Handle argument errors @@ -658,7 +682,9 @@ def shellify( ######### -def duplicate_along_axes(trajectory, axes=(0, 1, 2)): +def duplicate_along_axes( + trajectory: np.ndarray, axes: tuple[int, ...] = (0, 1, 2) +) -> np.ndarray: """ Duplicate a trajectory along the specified axes. @@ -668,14 +694,14 @@ def duplicate_along_axes(trajectory, axes=(0, 1, 2)): Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to duplicate. - axes : tuple, optional + axes : tuple[int, ...], optional Axes along which to duplicate the trajectory, by default (0, 1, 2) Returns ------- - array_like + np.ndarray Duplicated trajectory along the specified axes. """ # Copy input trajectory along other axes @@ -694,8 +720,21 @@ def duplicate_along_axes(trajectory, axes=(0, 1, 2)): return new_trajectory -def _radialize_center_out(trajectory, nb_samples): - """Radialize a trajectory from the center to the outside.""" +def _radialize_center_out(trajectory: np.ndarray, nb_samples: int) -> np.ndarray: + """Radialize a trajectory from the center to the outside. + + Parameters + ---------- + trajectory : np.ndarray + Trajectory to radialize. + nb_samples : int + Number of samples to radialize from the center. + + Returns + ------- + np.ndarray + Radialized trajectory. + """ Nc, Ns = trajectory.shape[:2] new_trajectory = np.copy(trajectory) for i in range(Nc): @@ -706,8 +745,21 @@ def _radialize_center_out(trajectory, nb_samples): return new_trajectory -def _radialize_in_out(trajectory, nb_samples): - """Radialize a trajectory from the inside to the outside.""" +def _radialize_in_out(trajectory: np.ndarray, nb_samples: int) -> np.ndarray: + """Radialize a trajectory from the inside to the outside. + + Parameters + ---------- + trajectory : np.ndarray + Trajectory to radialize. + nb_samples : int + Number of samples to radialize from the inside out. + + Returns + ------- + np.ndarray + Radialized trajectory. + """ Nc, Ns = trajectory.shape[:2] new_trajectory = np.copy(trajectory) first, half, second = (Ns - nb_samples) // 2, Ns // 2, (Ns + nb_samples) // 2 @@ -723,20 +775,26 @@ def _radialize_in_out(trajectory, nb_samples): return new_trajectory -def radialize_center(trajectory, nb_samples, in_out=False): +def radialize_center( + trajectory: np.ndarray, nb_samples: int, in_out: bool = False +) -> np.ndarray: """Radialize a trajectory. Parameters ---------- - trajectory : array_like + trajectory : np.ndarray Trajectory to radialize. nb_samples : int Number of samples to keep. in_out : bool, optional Whether the radialization is from the inside to the outside, by default False + + Returns + ------- + np.ndarray + Radialized trajectory. """ # Make nb_samples into straight lines around the center if in_out: return _radialize_in_out(trajectory, nb_samples) - else: - return _radialize_center_out(trajectory, nb_samples) + return _radialize_center_out(trajectory, nb_samples) diff --git a/src/mrinufft/trajectories/trajectory2D.py b/src/mrinufft/trajectories/trajectory2D.py index e930e90b3..e23f12bf8 100644 --- a/src/mrinufft/trajectories/trajectory2D.py +++ b/src/mrinufft/trajectories/trajectory2D.py @@ -14,7 +14,9 @@ ##################### -def initialize_2D_radial(Nc, Ns, tilt="uniform", in_out=False): +def initialize_2D_radial( + Nc: int, Ns: int, tilt: str | float = "uniform", in_out: bool = False +) -> np.ndarray: """Initialize a 2D radial trajectory. Parameters @@ -30,7 +32,7 @@ def initialize_2D_radial(Nc, Ns, tilt="uniform", in_out=False): Returns ------- - array_like + np.ndarray 2D radial trajectory """ # Initialize a first shot @@ -47,14 +49,14 @@ def initialize_2D_radial(Nc, Ns, tilt="uniform", in_out=False): def initialize_2D_spiral( - Nc, - Ns, - tilt="uniform", - in_out=False, - nb_revolutions=1, - spiral="archimedes", - patch_center=True, -): + Nc: int, + Ns: int, + tilt: str | float = "uniform", + in_out: bool = False, + nb_revolutions: int = 1, + spiral: str | float = "archimedes", + patch_center: bool = True, +) -> np.ndarray: """Initialize a 2D algebraic spiral trajectory. A generalized function that generates algebraic spirals defined @@ -82,7 +84,7 @@ def initialize_2D_spiral( Returns ------- - array_like + np.ndarray 2D spiral trajectory Raises @@ -144,7 +146,9 @@ def _update_parameters(single_shot, angles, radius, spiral_power): return trajectory -def initialize_2D_fibonacci_spiral(Nc, Ns, spiral_reduction=1, patch_center=True): +def initialize_2D_fibonacci_spiral( + Nc: int, Ns: int, spiral_reduction: float = 1, patch_center: bool = True +) -> np.ndarray: """Initialize a 2D Fibonacci spiral trajectory. A non-algebraic spiral trajectory based on the Fibonacci sequence, @@ -168,7 +172,7 @@ def initialize_2D_fibonacci_spiral(Nc, Ns, spiral_reduction=1, patch_center=True Returns ------- - array_like + np.ndarray 2D Fibonacci spiral trajectory References @@ -213,7 +217,14 @@ def initialize_2D_fibonacci_spiral(Nc, Ns, spiral_reduction=1, patch_center=True return trajectory -def initialize_2D_cones(Nc, Ns, tilt="uniform", in_out=False, nb_zigzags=5, width=1): +def initialize_2D_cones( + Nc: int, + Ns: int, + tilt: str = "uniform", + in_out: bool = False, + nb_zigzags: float = 5, + width: float = 1, +) -> np.ndarray: """Initialize a 2D cone trajectory. Parameters @@ -233,7 +244,7 @@ def initialize_2D_cones(Nc, Ns, tilt="uniform", in_out=False, nb_zigzags=5, widt Returns ------- - array_like + np.ndarray 2D cone trajectory """ @@ -253,8 +264,13 @@ def initialize_2D_cones(Nc, Ns, tilt="uniform", in_out=False, nb_zigzags=5, widt def initialize_2D_sinusoide( - Nc, Ns, tilt="uniform", in_out=False, nb_zigzags=5, width=1 -): + Nc: int, + Ns: int, + tilt: str | float = "uniform", + in_out: bool = False, + nb_zigzags: float = 5, + width: float = 1, +) -> np.ndarray: """Initialize a 2D sinusoide trajectory. Parameters @@ -274,7 +290,7 @@ def initialize_2D_sinusoide( Returns ------- - array_like + np.ndarray 2D sinusoide trajectory """ @@ -293,7 +309,7 @@ def initialize_2D_sinusoide( return trajectory -def initialize_2D_propeller(Nc, Ns, nb_strips): +def initialize_2D_propeller(Nc: int, Ns: int, nb_strips: int) -> np.ndarray: """Initialize a 2D PROPELLER trajectory, as proposed in [Pip99]_. The PROPELLER trajectory is generally used along a specific @@ -341,7 +357,7 @@ def initialize_2D_propeller(Nc, Ns, nb_strips): return KMAX * trajectory -def initialize_2D_rings(Nc, Ns, nb_rings): +def initialize_2D_rings(Nc: int, Ns: int, nb_rings: int) -> np.ndarray: """Initialize a 2D ring trajectory, as proposed in [HHN08]_. Parameters @@ -355,7 +371,7 @@ def initialize_2D_rings(Nc, Ns, nb_rings): Returns ------- - array_like + np.ndarray 2D ring trajectory References @@ -387,7 +403,9 @@ def initialize_2D_rings(Nc, Ns, nb_rings): return KMAX * np.array(trajectory) -def initialize_2D_rosette(Nc, Ns, in_out=False, coprime_index=0): +def initialize_2D_rosette( + Nc: int, Ns: int, in_out: bool = False, coprime_index: int = 0 +) -> np.ndarray: """Initialize a 2D rosette trajectory. Parameters @@ -403,7 +421,7 @@ def initialize_2D_rosette(Nc, Ns, in_out=False, coprime_index=0): Returns ------- - array_like + np.ndarray 2D rosette trajectory """ @@ -428,7 +446,9 @@ def initialize_2D_rosette(Nc, Ns, in_out=False, coprime_index=0): return trajectory -def initialize_2D_polar_lissajous(Nc, Ns, in_out=False, nb_segments=1, coprime_index=0): +def initialize_2D_polar_lissajous( + Nc: int, Ns: int, in_out: bool = False, nb_segments: int = 1, coprime_index: int = 0 +) -> np.ndarray: """Initialize a 2D polar Lissajous trajectory. Parameters @@ -446,7 +466,7 @@ def initialize_2D_polar_lissajous(Nc, Ns, in_out=False, nb_segments=1, coprime_i Returns ------- - array_like + np.ndarray 2D polar Lissajous trajectory """ # Adapt the parameters to subcases @@ -482,7 +502,7 @@ def initialize_2D_polar_lissajous(Nc, Ns, in_out=False, nb_segments=1, coprime_i ######################### -def initialize_2D_lissajous(Nc, Ns, density=1): +def initialize_2D_lissajous(Nc: int, Ns: int, density: float = 1) -> np.ndarray: """Initialize a 2D Lissajous trajectory. Parameters @@ -496,7 +516,7 @@ def initialize_2D_lissajous(Nc, Ns, density=1): Returns ------- - array_like + np.ndarray 2D Lissajous trajectory """ # Define the whole curve in Cartesian coordinates @@ -512,7 +532,9 @@ def initialize_2D_lissajous(Nc, Ns, density=1): return trajectory -def initialize_2D_waves(Nc, Ns, nb_zigzags=5, width=1): +def initialize_2D_waves( + Nc: int, Ns: int, nb_zigzags: float = 5, width: float = 1 +) -> np.ndarray: """Initialize a 2D waves trajectory. Parameters @@ -528,7 +550,7 @@ def initialize_2D_waves(Nc, Ns, nb_zigzags=5, width=1): Returns ------- - array_like + np.ndarray 2D waves trajectory """ # Initialize a first shot diff --git a/src/mrinufft/trajectories/trajectory3D.py b/src/mrinufft/trajectories/trajectory3D.py index cdcf4a0ad..a6cbd5e5a 100644 --- a/src/mrinufft/trajectories/trajectory3D.py +++ b/src/mrinufft/trajectories/trajectory3D.py @@ -24,7 +24,9 @@ ############## -def initialize_3D_phyllotaxis_radial(Nc, Ns, nb_interleaves=1, in_out=False): +def initialize_3D_phyllotaxis_radial( + Nc: int, Ns: int, nb_interleaves: int = 1, in_out: bool = False +) -> np.ndarray: """Initialize 3D radial trajectories with phyllotactic structure. The radial shots are oriented according to a Fibonacci sphere @@ -53,7 +55,7 @@ def initialize_3D_phyllotaxis_radial(Nc, Ns, nb_interleaves=1, in_out=False): Returns ------- - array_like + np.ndarray 3D phyllotaxis radial trajectory References @@ -71,7 +73,9 @@ def initialize_3D_phyllotaxis_radial(Nc, Ns, nb_interleaves=1, in_out=False): return trajectory -def initialize_3D_golden_means_radial(Nc, Ns, in_out=False): +def initialize_3D_golden_means_radial( + Nc: int, Ns: int, in_out: bool = False +) -> np.ndarray: """Initialize 3D radial trajectories with golden means-based structure. The radial shots are oriented using multidimensional golden means, @@ -95,7 +99,7 @@ def initialize_3D_golden_means_radial(Nc, Ns, in_out=False): Returns ------- - array_like + np.ndarray 3D golden means radial trajectory References @@ -123,7 +127,9 @@ def initialize_3D_golden_means_radial(Nc, Ns, in_out=False): return KMAX * trajectory -def initialize_3D_wong_radial(Nc, Ns, nb_interleaves=1, in_out=False): +def initialize_3D_wong_radial( + Nc: int, Ns: int, nb_interleaves: int = 1, in_out: bool = False +) -> np.ndarray: """Initialize 3D radial trajectories with a spiral structure. The radial shots are oriented according to an archimedean spiral @@ -149,7 +155,7 @@ def initialize_3D_wong_radial(Nc, Ns, nb_interleaves=1, in_out=False): Returns ------- - array_like + np.ndarray 3D Wong radial trajectory References @@ -181,7 +187,9 @@ def initialize_3D_wong_radial(Nc, Ns, nb_interleaves=1, in_out=False): return trajectory -def initialize_3D_park_radial(Nc, Ns, nb_interleaves=1, in_out=False): +def initialize_3D_park_radial( + Nc: int, Ns: int, nb_interleaves: int = 1, in_out: bool = False +) -> np.ndarray: """Initialize 3D radial trajectories with a spiral structure. The radial shots are oriented according to an archimedean spiral @@ -208,7 +216,7 @@ def initialize_3D_park_radial(Nc, Ns, nb_interleaves=1, in_out=False): Returns ------- - array_like + np.ndarray 3D Park radial trajectory References @@ -233,8 +241,14 @@ def initialize_3D_park_radial(Nc, Ns, nb_interleaves=1, in_out=False): def initialize_3D_cones( - Nc, Ns, tilt="golden", in_out=False, nb_zigzags=5, spiral="archimedes", width=1 -): + Nc: int, + Ns: int, + tilt: str | float = "golden", + in_out: bool = False, + nb_zigzags: float = 5, + spiral: str | float = "archimedes", + width: float = 1, +) -> np.ndarray: """Initialize 3D trajectories with cones. Initialize a trajectory consisting of 3D cones duplicated @@ -264,7 +278,7 @@ def initialize_3D_cones( Returns ------- - array_like + np.ndarray 3D cones trajectory References @@ -308,15 +322,15 @@ def initialize_3D_cones( def initialize_3D_floret( - Nc, - Ns, - in_out=False, - nb_revolutions=1, - spiral="fermat", - cone_tilt="golden", - max_angle=np.pi / 2, - axes=(2,), -): + Nc: int, + Ns: int, + in_out: bool = False, + nb_revolutions: float = 1, + spiral: str | float = "fermat", + cone_tilt: str | float = "golden", + max_angle: float = np.pi / 2, + axes: tuple[int, ...] = (2,), +) -> np.ndarray: """Initialize 3D trajectories with FLORET. This implementation is based on the work from [Pip+11]_. @@ -346,7 +360,7 @@ def initialize_3D_floret( Returns ------- - array_like + np.ndarray 3D FLORET trajectory References @@ -387,14 +401,14 @@ def initialize_3D_floret( def initialize_3D_wave_caipi( - Nc, - Ns, - nb_revolutions=5, - width=1, - packing="triangular", - shape="square", - spacing=(1, 1), -): + Nc: int, + Ns: int, + nb_revolutions: float = 5, + width: float = 1, + packing: str = "triangular", + shape: str | float = "square", + spacing: tuple[int, int] = (1, 1), +) -> np.ndarray: """Initialize 3D trajectories with Wave-CAIPI. This implementation is based on the work from [Bil+15]_. @@ -428,7 +442,7 @@ def initialize_3D_wave_caipi( Returns ------- - array_like + np.ndarray 3D wave-CAIPI trajectory References @@ -506,14 +520,14 @@ def initialize_3D_wave_caipi( def initialize_3D_seiffert_spiral( - Nc, - Ns, - curve_index=0.2, - nb_revolutions=1, - axis_tilt="golden", - spiral_tilt="golden", - in_out=False, -): + Nc: int, + Ns: int, + curve_index: float = 0.2, + nb_revolutions: float = 1, + axis_tilt: str | float = "golden", + spiral_tilt: str | float = "golden", + in_out: bool = False, +) -> np.ndarray: """Initialize 3D trajectories with modulated Seiffert spirals. Initially introduced in [SMR18]_, but also proposed later as "Yarnball" @@ -543,7 +557,7 @@ def initialize_3D_seiffert_spiral( Returns ------- - array_like + np.ndarray 3D Seiffert spiral trajectory References @@ -611,8 +625,13 @@ def initialize_3D_seiffert_spiral( def initialize_3D_helical_shells( - Nc, Ns, nb_shells, spiral_reduction=1, shell_tilt="intergaps", shot_tilt="uniform" -): + Nc: int, + Ns: int, + nb_shells: int, + spiral_reduction: float = 1, + shell_tilt: str = "intergaps", + shot_tilt: str = "uniform", +) -> np.ndarray: """Initialize 3D trajectories with helical shells. The implementation follows the proposition from [YRB06]_ @@ -635,7 +654,7 @@ def initialize_3D_helical_shells( Returns ------- - array_like + np.ndarray 3D helical shell trajectory References @@ -694,8 +713,12 @@ def initialize_3D_helical_shells( def initialize_3D_annular_shells( - Nc, Ns, nb_shells, shell_tilt=np.pi, ring_tilt=np.pi / 2 -): + Nc: int, + Ns: int, + nb_shells: int, + shell_tilt: float = np.pi, + ring_tilt: float = np.pi / 2, +) -> np.ndarray: """Initialize 3D trajectories with annular shells. An exclusive trajectory inspired from the work proposed in [HM11]_. @@ -715,7 +738,7 @@ def initialize_3D_annular_shells( Returns ------- - array_like + np.ndarray 3D annular shell trajectory References @@ -807,14 +830,14 @@ def initialize_3D_annular_shells( def initialize_3D_seiffert_shells( - Nc, - Ns, - nb_shells, - curve_index=0.5, - nb_revolutions=1, - shell_tilt="uniform", - shot_tilt="uniform", -): + Nc: int, + Ns: int, + nb_shells: int, + curve_index: float = 0.5, + nb_revolutions: float = 1, + shell_tilt: str = "uniform", + shot_tilt: str = "uniform", +) -> np.ndarray: """Initialize 3D trajectories with Seiffert shells. The implementation is based on work from [Er00]_ and [Br09]_, @@ -841,7 +864,7 @@ def initialize_3D_seiffert_shells( Returns ------- - array_like + np.ndarray 3D Seiffert shell trajectory References @@ -905,15 +928,15 @@ def initialize_3D_seiffert_shells( def initialize_3D_turbine( - Nc, - Ns_readouts, - Ns_transitions, - nb_blades, - blade_tilt="uniform", - nb_trains="auto", - skip_factor=1, - in_out=True, -): + Nc: int, + Ns_readouts: int, + Ns_transitions: int, + nb_blades: int, + blade_tilt: str = "uniform", + nb_trains: int | str = "auto", + skip_factor: int = 1, + in_out: bool = True, +) -> np.ndarray: """Initialize 3D TURBINE trajectory. This is an implementation of the TURBINE (Trajectory Using Radially @@ -951,7 +974,7 @@ def initialize_3D_turbine( Returns ------- - array_like + np.ndarray 3D TURBINE trajectory References @@ -1013,17 +1036,17 @@ def initialize_3D_turbine( def initialize_3D_repi( - Nc, - Ns_readouts, - Ns_transitions, - nb_blades, - nb_blade_revolutions=0, - blade_tilt="uniform", - nb_spiral_revolutions=0, - spiral="archimedes", - nb_trains="auto", - in_out=True, -): + Nc: int, + Ns_readouts: int, + Ns_transitions: int, + nb_blades: int, + nb_blade_revolutions: float = 0, + blade_tilt: str = "uniform", + nb_spiral_revolutions: float = 0, + spiral: str = "archimedes", + nb_trains: int | str = "auto", + in_out: bool = True, +) -> np.ndarray: """Initialize 3D REPI trajectory. This is an implementation of the REPI (Radial Echo Planar Imaging) @@ -1067,7 +1090,7 @@ def initialize_3D_repi( Returns ------- - array_like + np.ndarray 3D REPI trajectory References diff --git a/src/mrinufft/trajectories/utils.py b/src/mrinufft/trajectories/utils.py index e2c9ff4fb..a9023d2bc 100644 --- a/src/mrinufft/trajectories/utils.py +++ b/src/mrinufft/trajectories/utils.py @@ -150,10 +150,10 @@ class Packings(str, Enum, metaclass=CaseInsensitiveEnumMeta): def normalize_trajectory( - trajectory, - norm_factor=KMAX, - resolution=DEFAULT_RESOLUTION, -): + trajectory: np.ndarray, + norm_factor: float = KMAX, + resolution: float | np.ndarray = DEFAULT_RESOLUTION, +) -> np.ndarray: """Normalize an un-normalized/natural trajectory for NUFFT use. Parameters @@ -176,10 +176,10 @@ def normalize_trajectory( def unnormalize_trajectory( - trajectory, - norm_factor=KMAX, - resolution=DEFAULT_RESOLUTION, -): + trajectory: np.ndarray, + norm_factor: float = KMAX, + resolution: float | np.ndarray = DEFAULT_RESOLUTION, +) -> np.ndarray: """Un-normalize a NUFFT-normalized trajectory. Parameters @@ -202,13 +202,13 @@ def unnormalize_trajectory( def convert_trajectory_to_gradients( - trajectory, - norm_factor=KMAX, - resolution=DEFAULT_RESOLUTION, - raster_time=DEFAULT_RASTER_TIME, - gamma=Gammas.HYDROGEN, - get_final_positions=False, -): + trajectory: np.ndarray, + norm_factor: float = KMAX, + resolution: float | np.ndarray = DEFAULT_RESOLUTION, + raster_time: float = DEFAULT_RASTER_TIME, + gamma: float = Gammas.HYDROGEN, + get_final_positions: bool = False, +) -> tuple[np.ndarray, ...]: """Derive a normalized trajectory over time to provide gradients. Parameters @@ -249,13 +249,13 @@ def convert_trajectory_to_gradients( def convert_gradients_to_trajectory( - gradients, - initial_positions=None, - norm_factor=KMAX, - resolution=DEFAULT_RESOLUTION, - raster_time=DEFAULT_RASTER_TIME, - gamma=Gammas.HYDROGEN, -): + gradients: np.ndarray, + initial_positions: np.ndarray | None = None, + norm_factor: float = KMAX, + resolution: float | np.ndarray = DEFAULT_RESOLUTION, + raster_time: float = DEFAULT_RASTER_TIME, + gamma: float = Gammas.HYDROGEN, +) -> np.ndarray: """Integrate gradients over time to provide a normalized trajectory. Parameters @@ -299,9 +299,9 @@ def convert_gradients_to_trajectory( def convert_gradients_to_slew_rates( - gradients, - raster_time=DEFAULT_RASTER_TIME, -): + gradients: np.ndarray, + raster_time: float = DEFAULT_RASTER_TIME, +) -> tuple[np.ndarray, np.ndarray]: """Derive the gradients over time to provide slew rates. Parameters @@ -327,10 +327,10 @@ def convert_gradients_to_slew_rates( def convert_slew_rates_to_gradients( - slewrates, - initial_gradients=None, - raster_time=DEFAULT_RASTER_TIME, -): + slewrates: np.ndarray, + initial_gradients: np.ndarray | None = None, + raster_time: float = DEFAULT_RASTER_TIME, +) -> np.ndarray: """Integrate slew rates over time to provide gradients. Parameters @@ -362,12 +362,12 @@ def convert_slew_rates_to_gradients( def compute_gradients_and_slew_rates( - trajectory, - norm_factor=KMAX, - resolution=DEFAULT_RESOLUTION, - raster_time=DEFAULT_RASTER_TIME, - gamma=Gammas.HYDROGEN, -): + trajectory: np.ndarray, + norm_factor: float = KMAX, + resolution: float | np.ndarray = DEFAULT_RESOLUTION, + raster_time: float = DEFAULT_RASTER_TIME, + gamma: float = Gammas.HYDROGEN, +) -> tuple[np.ndarray, np.ndarray]: """Compute the gradients and slew rates from a normalized trajectory. Parameters @@ -411,8 +411,12 @@ def compute_gradients_and_slew_rates( def check_hardware_constraints( - gradients, slewrates, gmax=DEFAULT_GMAX, smax=DEFAULT_SMAX, order=None -): + gradients: np.ndarray, + slewrates: np.ndarray, + gmax: float = DEFAULT_GMAX, + smax: float = DEFAULT_SMAX, + order: int | str | None = None, +) -> tuple[bool, float, float]: """Check if a trajectory satisfies the gradient hardware constraints. Parameters @@ -450,7 +454,7 @@ def check_hardware_constraints( ########### -def initialize_tilt(tilt, nb_partitions=1): +def initialize_tilt(tilt: str | float, nb_partitions: int = 1) -> float: r"""Initialize the tilt angle. Parameters @@ -493,7 +497,7 @@ def initialize_tilt(tilt, nb_partitions=1): raise NotImplementedError(f"Unknown tilt name: {tilt}") -def initialize_algebraic_spiral(spiral): +def initialize_algebraic_spiral(spiral: str | float) -> float: """Initialize the algebraic spiral type. Parameters @@ -511,7 +515,7 @@ def initialize_algebraic_spiral(spiral): return Spirals[spiral] -def initialize_shape_norm(shape): +def initialize_shape_norm(shape: str | float) -> float: """Initialize the norm for a given shape. Parameters From a0031a3267ef103b493f597cb41802175b20b381 Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Thu, 19 Dec 2024 23:15:00 +0100 Subject: [PATCH 2/6] Apply ruff ANN rules to trajectories --- src/mrinufft/trajectories/display.py | 22 +++---- .../trajectories/inits/random_walk.py | 8 +-- .../trajectories/inits/travelling_salesman.py | 46 ++++++------- src/mrinufft/trajectories/sampling.py | 52 +++++++++++---- src/mrinufft/trajectories/tools.py | 52 ++++++++------- src/mrinufft/trajectories/trajectory2D.py | 33 ++++++---- src/mrinufft/trajectories/trajectory3D.py | 65 ++++++++++--------- src/mrinufft/trajectories/utils.py | 17 ++--- 8 files changed, 166 insertions(+), 129 deletions(-) diff --git a/src/mrinufft/trajectories/display.py b/src/mrinufft/trajectories/display.py index 973568765..f900c3f9b 100644 --- a/src/mrinufft/trajectories/display.py +++ b/src/mrinufft/trajectories/display.py @@ -47,7 +47,7 @@ class displayConfig: """Font size for most labels and texts, by default ``18``.""" small_fontsize: int = 14 """Font size for smaller texts, by default ``14``.""" - nb_colors = 10 + nb_colors: int = 10 """Number of colors to use in the color cycle, by default ``10``.""" palette: str = "tab10" """Name of the color palette to use, by default ``"tab10"``. @@ -59,33 +59,33 @@ class displayConfig: slewrate_point_color: str = "b" """Matplotlib color for slew rate constraint points, by default ``"b"`` (blue).""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: # noqa ANN401 """Update the display configuration.""" self.update(**kwargs) - def update(self, **kwargs): + def update(self, **kwargs: Any) -> None: # noqa ANN401 """Update the display configuration.""" self._old_values = {} for key, value in kwargs.items(): self._old_values[key] = getattr(displayConfig, key) setattr(displayConfig, key, value) - def reset(self): + def reset(self) -> None: """Restore the display configuration.""" for key, value in self._old_values.items(): setattr(displayConfig, key, value) delattr(self, "_old_values") - def __enter__(self): + def __enter__(self) -> "displayConfig": """Enter the context manager.""" return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: # noqa ANN401 """Exit the context manager.""" self.reset() @classmethod - def get_colorlist(cls): + def get_colorlist(cls) -> list[str | np.ndarray]: """Extract a list of colors from a matplotlib palette. If the palette is continuous, the colors will be sampled from it. @@ -172,7 +172,7 @@ def display_2D_trajectory( gmax: float = DEFAULT_GMAX, smax: float = DEFAULT_SMAX, constraints_order: int | str | None = None, - **constraints_kwargs: Any, + **constraints_kwargs: float | np.ndarray, ) -> plt.Axes: """Display 2D trajectories. @@ -205,7 +205,7 @@ def display_2D_trajectory( typically 2 or `np.inf`, following the `numpy.linalg.norm` conventions on parameter `ord`. The default is None. - **kwargs + **constraints_kwargs Acquisition parameters used to check on hardware constraints, following the parameter convention from `mrinufft.trajectories.utils.compute_gradients_and_slew_rates`. @@ -545,7 +545,7 @@ def display_gradients( smax: float = DEFAULT_SMAX, constraints_order: int | str | None = None, raster_time: float = DEFAULT_RASTER_TIME, - **constraints_kwargs: Any, + **constraints_kwargs: float | np.ndarray, ) -> tuple[plt.Axes]: """Display gradients based on trajectory of any dimension. @@ -598,7 +598,7 @@ def display_gradients( Amount of time between the acquisition of two consecutive samples in ms. The default is `DEFAULT_RASTER_TIME`. - **kwargs + **constraints_kwargs Acquisition parameters used to check on hardware constraints, following the parameter convention from `mrinufft.trajectories.utils.compute_gradients_and_slew_rates`. diff --git a/src/mrinufft/trajectories/inits/random_walk.py b/src/mrinufft/trajectories/inits/random_walk.py index 5f2745068..0ad947bbd 100644 --- a/src/mrinufft/trajectories/inits/random_walk.py +++ b/src/mrinufft/trajectories/inits/random_walk.py @@ -1,6 +1,6 @@ """Trajectories based on random walks.""" -from typing import Any +from typing import Literal import numpy as np @@ -28,7 +28,7 @@ def _initialize_ND_random_walk( *, diagonals: bool = True, pseudo_random: bool = True, - **sampling_kwargs: Any, + **sampling_kwargs: Literal | bool, ) -> np.ndarray: density = density / np.sum(density) flat_density = np.copy(density.flatten()) @@ -101,7 +101,7 @@ def initialize_2D_random_walk( *, diagonals: bool = True, pseudo_random: bool = True, - **sampling_kwargs: Any, + **sampling_kwargs: Literal | bool, ) -> np.ndarray: """Initialize a 2D random walk trajectory. @@ -166,7 +166,7 @@ def initialize_3D_random_walk( *, diagonals: bool = True, pseudo_random: bool = True, - **sampling_kwargs: Any, + **sampling_kwargs: Literal | bool, ) -> np.ndarray: """Initialize a 3D random walk trajectory. diff --git a/src/mrinufft/trajectories/inits/travelling_salesman.py b/src/mrinufft/trajectories/inits/travelling_salesman.py index adfff6ca0..411534b77 100644 --- a/src/mrinufft/trajectories/inits/travelling_salesman.py +++ b/src/mrinufft/trajectories/inits/travelling_salesman.py @@ -1,6 +1,6 @@ """Trajectories based on the Travelling Salesman Problem.""" -from typing import Any +from typing import Literal import numpy as np import numpy.linalg as nl @@ -20,7 +20,7 @@ def _get_approx_cluster_sizes(nb_total: int, nb_clusters: int) -> np.ndarray: return cluster_sizes -def _sort_by_coordinate(array: np.ndarray, coord: str) -> np.ndarray: +def _sort_by_coordinate(array: np.ndarray, coord: Literal) -> np.ndarray: # Sort a list of N-D locations by a Cartesian/spherical coordinate if array.shape[-1] < 3 and coord.lower() in ["z", "theta"]: raise ValueError( @@ -51,9 +51,9 @@ def _sort_by_coordinate(array: np.ndarray, coord: str) -> np.ndarray: def _cluster_by_coordinate( locations: np.ndarray, nb_clusters: int, - cluster_by: str, - second_cluster_by: str | None = None, - sort_by: str | None = None, + cluster_by: Literal, + second_cluster_by: Literal | None = None, + sort_by: Literal | None = None, ) -> np.ndarray: # Cluster approximately a list of N-D locations by Cartesian/spherical coordinates # Gather dimension variables @@ -96,13 +96,13 @@ def _initialize_ND_travelling_salesman( Nc: int, Ns: int, density: np.ndarray, - first_cluster_by: str | None = None, - second_cluster_by: str | None = None, - sort_by: str | None = None, + first_cluster_by: Literal | None = None, + second_cluster_by: Literal | None = None, + sort_by: Literal | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False, - **sampling_kwargs: Any, + **sampling_kwargs: Literal | bool, ) -> np.ndarray: # Check arguments validity if Nc * Ns > np.prod(density.shape): @@ -143,13 +143,13 @@ def initialize_2D_travelling_salesman( Nc: int, Ns: int, density: np.ndarray, - first_cluster_by: str | None = None, - second_cluster_by: str | None = None, - sort_by: str | None = None, + first_cluster_by: Literal | None = None, + second_cluster_by: Literal | None = None, + sort_by: Literal | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False, - **sampling_kwargs: Any, + **sampling_kwargs: Literal | bool, ) -> np.ndarray: """ Initialize a 2D trajectory using a Travelling Salesman Problem (TSP)-based path. @@ -168,12 +168,12 @@ def initialize_2D_travelling_salesman( The number of points per cluster. density : np.ndarray A 2-dimensional density array from which points are sampled. - first_cluster_by : str, optional + first_cluster_by : Literal, optional The coordinate used to cluster points initially, by default ``None``. - second_cluster_by : str, optional + second_cluster_by : Literal, optional A secondary coordinate used for clustering within primary clusters, by default ``None``. - sort_by : str, optional + sort_by : Literal, optional The coordinate by which to order points within each cluster, by default ``None``. tsp_tol : float, optional @@ -220,13 +220,13 @@ def initialize_3D_travelling_salesman( Nc: int, Ns: int, density: np.ndarray, - first_cluster_by: str | None = None, - second_cluster_by: str | None = None, - sort_by: str | None = None, + first_cluster_by: Literal | None = None, + second_cluster_by: Literal | None = None, + sort_by: Literal | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False, - **sampling_kwargs: Any, + **sampling_kwargs: Literal | bool, ) -> np.ndarray: """ Initialize a 3D trajectory using a Travelling Salesman Problem (TSP)-based path. @@ -247,12 +247,12 @@ def initialize_3D_travelling_salesman( The number of points per cluster. density : np.ndarray A 3-dimensional density array from which points are sampled. - first_cluster_by : str, optional + first_cluster_by : Literal, optional The coordinate used to cluster points initially, by default ``None``. - second_cluster_by : str, optional + second_cluster_by : Literal, optional A secondary coordinate used for clustering within primary clusters, by default ``None``. - sort_by : str, optional + sort_by : Literal, optional The coordinate by which to order points within each cluster, by default ``None``. tsp_tol : float, optional diff --git a/src/mrinufft/trajectories/sampling.py b/src/mrinufft/trajectories/sampling.py index d0cceb795..eff2be0f3 100644 --- a/src/mrinufft/trajectories/sampling.py +++ b/src/mrinufft/trajectories/sampling.py @@ -1,17 +1,24 @@ """Sampling densities and methods.""" +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + import pywt as pw + import numpy as np import numpy.fft as nf -import numpy.linalg as nl -import numpy.random as nr from tqdm.auto import tqdm from .utils import KMAX def sample_from_density( - nb_samples, density, method="random", *, dim_compensation="auto" -): + nb_samples: int, + density: np.ndarray, + method: Literal = "random", + *, + dim_compensation: Literal | bool = "auto", +) -> np.ndarray: """ Sample points based on a given density distribution. @@ -81,7 +88,7 @@ def sample_from_density( density = density / np.sum(density) # Sample using specified method - rng = nr.default_rng() + rng = np.random.default_rng() if method == "random": choices = rng.choice( np.arange(max_nb_samples), @@ -110,7 +117,12 @@ def sample_from_density( return locations -def create_cutoff_decay_density(shape, cutoff, decay, resolution=None): +def create_cutoff_decay_density( + shape: tuple[int, ...], + cutoff: float, + decay: float, + resolution: np.ndarray | None = None, +) -> np.ndarray: """ Create a density with central plateau and polynomial decay. @@ -120,7 +132,7 @@ def create_cutoff_decay_density(shape, cutoff, decay, resolution=None): Parameters ---------- - shape : tuple of int + shape : tuple[int, ...] The shape of the density grid, analog to the field-of-view as opposed to ``resolution`` below. cutoff : float @@ -156,7 +168,7 @@ def create_cutoff_decay_density(shape, cutoff, decay, resolution=None): for i in range(nb_dims): differences[i] = differences[i] + 0.5 - shape[i] / 2 differences[i] = differences[i] / shape[i] / resolution[i] - distances = nl.norm(differences, axis=0) + distances = np.linalg.norm(differences, axis=0) cutoff = cutoff * np.max(differences) if cutoff else np.min(differences) density = np.ones(shape) @@ -167,7 +179,9 @@ def create_cutoff_decay_density(shape, cutoff, decay, resolution=None): return density -def create_polynomial_density(shape, decay, resolution=None): +def create_polynomial_density( + shape: tuple[int, ...], decay: float, resolution: np.ndarray | None = None +) -> np.ndarray: """ Create a density with polynomial decay from the center. @@ -191,7 +205,7 @@ def create_polynomial_density(shape, decay, resolution=None): ) -def create_energy_density(dataset): +def create_energy_density(dataset: np.ndarray) -> np.ndarray: """ Create a density based on energy in the Fourier spectrum. @@ -221,7 +235,13 @@ def create_energy_density(dataset): return density -def create_chauffert_density(shape, wavelet_basis, nb_wavelet_scales, verbose=False): +def create_chauffert_density( + shape: tuple[int, ...], + wavelet_basis: Literal | pw.Wavelet, + nb_wavelet_scales: int, + *, + verbose: bool = False, +) -> np.ndarray: """Create a density based on Chauffert's method. This is a reproduction of the proposition from [CCW13]_. @@ -231,7 +251,7 @@ def create_chauffert_density(shape, wavelet_basis, nb_wavelet_scales, verbose=Fa Parameters ---------- - shape : tuple of int + shape : tuple[int, ...] The shape of the density grid. wavelet_basis : str, pywt.Wavelet The wavelet basis to use for wavelet decomposition, either @@ -290,7 +310,11 @@ def create_chauffert_density(shape, wavelet_basis, nb_wavelet_scales, verbose=Fa return nf.ifftshift(density) -def create_fast_chauffert_density(shape, wavelet_basis, nb_wavelet_scales): +def create_fast_chauffert_density( + shape: tuple[int, ...], + wavelet_basis: Literal | pw.Wavelet, + nb_wavelet_scales: int, +) -> np.ndarray: """Create a density based on an approximated Chauffert method. This implementation is based on this @@ -306,7 +330,7 @@ def create_fast_chauffert_density(shape, wavelet_basis, nb_wavelet_scales): Parameters ---------- - shape : tuple of int + shape : tuple[int, ...] The shape of the density grid. wavelet_basis : str, pywt.Wavelet The wavelet basis to use for wavelet decomposition, either diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index fd4f041c0..334e245f2 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -1,6 +1,6 @@ """Functions to manipulate/modify trajectories.""" -from typing import Any, Callable +from typing import Any, Callable, Literal import numpy as np from scipy.interpolate import CubicSpline, interp1d @@ -16,7 +16,7 @@ def stack( trajectory: np.ndarray, nb_stacks: int, - z_tilt: str | None = None, + z_tilt: Literal | float | None = None, *, hard_bounded: bool = True, ) -> np.ndarray: @@ -28,7 +28,7 @@ def stack( Trajectory in 2D or 3D to stack. nb_stacks : int Number of stacks repeating the provided trajectory. - z_tilt : str, optional + z_tilt : Literal, float, optional Tilt of the stacks, by default `None`. hard_bounded : bool, optional Whether the stacks should be strictly within the limits of the k-space. @@ -66,9 +66,9 @@ def stack( def rotate( trajectory: np.ndarray, nb_rotations: int, - x_tilt: str | None = None, - y_tilt: str | None = None, - z_tilt: str | None = None, + x_tilt: Literal | float | None = None, + y_tilt: Literal | float | None = None, + z_tilt: Literal | float | None = None, ) -> np.ndarray: """Rotate 2D or 3D trajectories over the different axes. @@ -78,11 +78,11 @@ def rotate( Trajectory in 2D or 3D to rotate. nb_rotations : int Number of rotations repeating the provided trajectory. - x_tilt : str, optional + x_tilt : Literal, optional Tilt of the trajectory over the :math:`k_x`-axis, by default `None`. - y_tilt : str, optional + y_tilt : Literal, optional Tilt of the trajectory over the :math:`k_y`-axis, by default `None`. - z_tilt : str, optional + z_tilt : Literal, optional Tilt of the trajectory over the :math:`k_z`-axis, by default `None`. Returns @@ -112,9 +112,9 @@ def rotate( def precess( trajectory: np.ndarray, nb_rotations: int, - tilt: str = "golden", + tilt: Literal | float = "golden", half_sphere: bool = False, - partition: str = "axial", + partition: Literal = "axial", axis: int | np.ndarray | None = None, ) -> np.ndarray: """Rotate trajectories as a precession around the :math:`k_z`-axis. @@ -125,14 +125,14 @@ def precess( Trajectory in 2D or 3D to rotate. nb_rotations : int Number of rotations repeating the provided trajectory while precessing. - tilt : str, optional + tilt : Literal, float, optional Angle tilt between consecutive rotations around the :math:`k_z`-axis, by default "golden". half_sphere : bool, optional Whether the precession should be limited to the upper half of the k-space sphere. It is typically used for in-out trajectories or planes. - partition : str, optional + partition : Literal, optional Partition type between an "axial" or "polar" split of the :math:`k_z`-axis, designating whether the axis should be fragmented by radius or angle respectively, by default "axial". @@ -191,7 +191,7 @@ def precess( def conify( trajectory: np.ndarray, nb_cones: int, - z_tilt: str | None = None, + z_tilt: Literal | float | None = None, in_out: bool = False, max_angle: float = np.pi / 2, borderless: bool = True, @@ -204,7 +204,7 @@ def conify( Trajectory to conify. nb_cones : int Number of cones repeating the provided trajectory. - z_tilt : str, optional + z_tilt : Literal, float, optional Tilt of the trajectory over the :math:`k_z`-axis, by default `None`. in_out : bool, optional Whether to account for the in-out nature of some trajectories @@ -451,7 +451,9 @@ def rewind(trajectory: np.ndarray, Ns_transitions: int) -> np.ndarray: return assembled_trajectory -def oversample(trajectory: np.ndarray, new_Ns: int, kind: str = "cubic") -> np.ndarray: +def oversample( + trajectory: np.ndarray, new_Ns: int, kind: Literal = "cubic" +) -> np.ndarray: """ Resample a trajectory to increase the number of samples using interpolation. @@ -462,7 +464,7 @@ def oversample(trajectory: np.ndarray, new_Ns: int, kind: str = "cubic") -> np.n is applied along the second axis. new_Ns : int The desired number of samples in the resampled trajectory. - kind : str, optional + kind : Literal, optional The type of interpolation to use, such as 'linear', 'quadratic', or 'cubic', by default "cubic". @@ -500,9 +502,9 @@ def stack_spherically( trajectory_func: Callable[..., np.ndarray], Nc: int, nb_stacks: int, - z_tilt: str | None = None, + z_tilt: Literal | float | None = None, hard_bounded: bool = True, - **traj_kwargs: Any, + **traj_kwargs: Any, # noqa ANN401 ) -> np.ndarray: """Stack 2D or 3D trajectories over the :math:`k_z`-axis to make a sphere. @@ -515,7 +517,7 @@ def stack_spherically( Number of shots to use for the whole spherically stacked trajectory. nb_stacks : int Number of stacks of trajectories. - z_tilt : str | None, optional + z_tilt : Literal, float, optional Tilt of the stacks, by default `None`. hard_bounded : bool, optional Whether the stacks should be strictly within the limits @@ -587,9 +589,9 @@ def shellify( trajectory_func: Callable[..., np.ndarray], Nc: int, nb_shells: int, - z_tilt: str | float = "golden", - hemisphere_mode: str = "symmetric", - **traj_kwargs: Any, + z_tilt: Literal | float = "golden", + hemisphere_mode: Literal = "symmetric", + **traj_kwargs: Any, # noqa ANN401 ) -> np.ndarray: """Stack 2D or 3D trajectories over the :math:`k_z`-axis to make a sphere. @@ -602,9 +604,9 @@ def shellify( Number of shots to use for the whole spherically stacked trajectory. nb_shells : int Number of shells of distorted trajectories. - z_tilt : str | float, optional + z_tilt : Literal, float, optional Tilt of the shells, by default "golden". - hemisphere_mode : str, optional + hemisphere_mode : Literal, optional Define how the lower hemisphere should be oriented relatively to the upper one, with "symmetric" providing a :math:`k_x-k_y` planar symmetry by changing the polar angle, and with "reversed" promoting continuity diff --git a/src/mrinufft/trajectories/trajectory2D.py b/src/mrinufft/trajectories/trajectory2D.py index e23f12bf8..e9d9c492c 100644 --- a/src/mrinufft/trajectories/trajectory2D.py +++ b/src/mrinufft/trajectories/trajectory2D.py @@ -1,5 +1,7 @@ """Functions to initialize 2D trajectories.""" +from typing import Any, Literal + import numpy as np import numpy.linalg as nl from scipy.interpolate import CubicSpline @@ -15,7 +17,7 @@ def initialize_2D_radial( - Nc: int, Ns: int, tilt: str | float = "uniform", in_out: bool = False + Nc: int, Ns: int, tilt: Literal | float = "uniform", in_out: bool = False ) -> np.ndarray: """Initialize a 2D radial trajectory. @@ -25,7 +27,7 @@ def initialize_2D_radial( Number of shots Ns : int Number of samples per shot - tilt : str, float, optional + tilt : Literal, float, optional Tilt of the shots, by default "uniform" in_out : bool, optional Whether to start from the center or not, by default False @@ -51,10 +53,10 @@ def initialize_2D_radial( def initialize_2D_spiral( Nc: int, Ns: int, - tilt: str | float = "uniform", + tilt: Literal | float = "uniform", in_out: bool = False, nb_revolutions: int = 1, - spiral: str | float = "archimedes", + spiral: Literal | float = "archimedes", patch_center: bool = True, ) -> np.ndarray: """Initialize a 2D algebraic spiral trajectory. @@ -70,13 +72,13 @@ def initialize_2D_spiral( Number of shots Ns : int Number of samples per shot - tilt : str, float, optional + tilt : Literal, float, optional Tilt of the shots, by default "uniform" in_out : bool, optional Whether to start from the center or not, by default False nb_revolutions : int, optional Number of revolutions, by default 1 - spiral : str, float, optional + spiral : Literal, float, optional Spiral type or algebraic power, by default "archimedes" patch_center : bool, optional Whether the spiral anomaly at the center should be patched @@ -111,11 +113,18 @@ def initialize_2D_spiral( # Algebraic spirals with power coefficients superior to 1 # have a non-monotonic gradient norm when varying the angle # over [0, +inf) - def _update_shot(angles, radius, *args): + def _update_shot( + angles: np.ndarray, radius: np.ndarray, *args: Any # noqa ANN401 + ) -> np.ndarray: shot = np.sign(angles) * np.abs(radius) * np.exp(1j * np.abs(angles)) return np.stack([shot.real, shot.imag], axis=-1) - def _update_parameters(single_shot, angles, radius, spiral_power): + def _update_parameters( + single_shot: np.ndarray, + angles: np.ndarray, + radius: np.ndarray, + spiral_power: float, + ) -> tuple[np.ndarray, np.ndarray, float]: radius = nl.norm(single_shot, axis=-1) angles = np.sign(angles) * np.abs(radius) ** (1 / spiral_power) return angles, radius, spiral_power @@ -220,7 +229,7 @@ def initialize_2D_fibonacci_spiral( def initialize_2D_cones( Nc: int, Ns: int, - tilt: str = "uniform", + tilt: Literal = "uniform", in_out: bool = False, nb_zigzags: float = 5, width: float = 1, @@ -233,7 +242,7 @@ def initialize_2D_cones( Number of shots Ns : int Number of samples per shot - tilt : str, optional + tilt : Literal, optional Tilt of the shots, by default "uniform" in_out : bool, optional Whether to start from the center or not, by default False @@ -266,7 +275,7 @@ def initialize_2D_cones( def initialize_2D_sinusoide( Nc: int, Ns: int, - tilt: str | float = "uniform", + tilt: Literal | float = "uniform", in_out: bool = False, nb_zigzags: float = 5, width: float = 1, @@ -279,7 +288,7 @@ def initialize_2D_sinusoide( Number of shots Ns : int Number of samples per shot - tilt : str, float, optional + tilt : Literal, float, optional Tilt of the shots, by default "uniform" in_out : bool, optional Whether to start from the center or not, by default False diff --git a/src/mrinufft/trajectories/trajectory3D.py b/src/mrinufft/trajectories/trajectory3D.py index a6cbd5e5a..725131043 100644 --- a/src/mrinufft/trajectories/trajectory3D.py +++ b/src/mrinufft/trajectories/trajectory3D.py @@ -1,6 +1,7 @@ """Functions to initialize 3D trajectories.""" from functools import partial +from typing import Literal import numpy as np import numpy.linalg as nl @@ -243,10 +244,10 @@ def initialize_3D_park_radial( def initialize_3D_cones( Nc: int, Ns: int, - tilt: str | float = "golden", + tilt: Literal | float = "golden", in_out: bool = False, nb_zigzags: float = 5, - spiral: str | float = "archimedes", + spiral: Literal | float = "archimedes", width: float = 1, ) -> np.ndarray: """Initialize 3D trajectories with cones. @@ -264,14 +265,14 @@ def initialize_3D_cones( Number of shots Ns : int Number of samples per shot - tilt : str, float, optional + tilt : Literal, float, optional Tilt of the cones, by default "golden" in_out : bool, optional Whether the curves are going in-and-out or start from the center, by default False nb_zigzags : float, optional Number of zigzags of the cones, by default 5 - spiral : str, float, optional + spiral : Literal, float, optional Spiral type, by default "archimedes" width : float, optional Cone width normalized such that `width=1` avoids cone overlaps, by default 1 @@ -326,8 +327,8 @@ def initialize_3D_floret( Ns: int, in_out: bool = False, nb_revolutions: float = 1, - spiral: str | float = "fermat", - cone_tilt: str | float = "golden", + spiral: Literal | float = "fermat", + cone_tilt: Literal | float = "golden", max_angle: float = np.pi / 2, axes: tuple[int, ...] = (2,), ) -> np.ndarray: @@ -348,9 +349,9 @@ def initialize_3D_floret( Whether to start from the center or not, by default False nb_revolutions : float, optional Number of revolutions of the spirals, by default 1 - spiral : str, float, optional + spiral : Literal, float, optional Spiral type, by default "fermat" - cone_tilt : str, float, optional + cone_tilt : Literal, float, optional Tilt of the cones around the :math:`k_z`-axis, by default "golden" max_angle : float, optional Maximum polar angle starting from the :math:`k_x-k_y` plane, @@ -405,8 +406,8 @@ def initialize_3D_wave_caipi( Ns: int, nb_revolutions: float = 5, width: float = 1, - packing: str = "triangular", - shape: str | float = "square", + packing: Literal = "triangular", + shape: Literal | float = "square", spacing: tuple[int, int] = (1, 1), ) -> np.ndarray: """Initialize 3D trajectories with Wave-CAIPI. @@ -425,11 +426,11 @@ def initialize_3D_wave_caipi( Diameter of the helices normalized such that `width=1` densely covers the k-space without overlap for square packing, by default 1. - packing : str, optional + packing : Literal, optional Packing method used to position the helices: "triangular"/"hexagonal", "square", "circular" or "random"/"uniform", by default "triangular". - shape : str or float, optional + shape : Literal or float, optional Shape over the 2D :math:`k_x-k_y` plane to pack with shots, either defined as `str` ("circle", "square", "diamond") or as `float` through p-norms following the conventions @@ -524,8 +525,8 @@ def initialize_3D_seiffert_spiral( Ns: int, curve_index: float = 0.2, nb_revolutions: float = 1, - axis_tilt: str | float = "golden", - spiral_tilt: str | float = "golden", + axis_tilt: Literal | float = "golden", + spiral_tilt: Literal | float = "golden", in_out: bool = False, ) -> np.ndarray: """Initialize 3D trajectories with modulated Seiffert spirals. @@ -546,9 +547,9 @@ def initialize_3D_seiffert_spiral( nb_revolutions : float Number of revolutions, i.e. times the polar angle of the curves passes through 0, by default 1 - axis_tilt : str, float, optional + axis_tilt : Literal, float, optional Angle between shots over a precession around the z-axis, by default "golden" - spiral_tilt : str, float, optional + spiral_tilt : Literal, float, optional Angle of the spiral within its own axis defined from center to its outermost point, by default "golden" in_out : bool @@ -629,8 +630,8 @@ def initialize_3D_helical_shells( Ns: int, nb_shells: int, spiral_reduction: float = 1, - shell_tilt: str = "intergaps", - shot_tilt: str = "uniform", + shell_tilt: Literal = "intergaps", + shot_tilt: Literal = "uniform", ) -> np.ndarray: """Initialize 3D trajectories with helical shells. @@ -647,9 +648,9 @@ def initialize_3D_helical_shells( Number of concentric shells/spheres spiral_reduction : float, optional Factor used to reduce the automatic spiral length, by default 1 - shell_tilt : str, float, optional + shell_tilt : Literal, float, optional Angle between consecutive shells along z-axis, by default "intergaps" - shot_tilt : str, float, optional + shot_tilt : Literal, float, optional Angle between shots over a shell surface along z-axis, by default "uniform" Returns @@ -731,9 +732,9 @@ def initialize_3D_annular_shells( Number of samples per shot nb_shells : int Number of concentric shells/spheres - shell_tilt : str, float, optional + shell_tilt : Literal, float, optional Angle between consecutive shells along z-axis, by default pi - ring_tilt : str, float, optional + ring_tilt : Literal, float, optional Angle controlling approximately the ring halves rotation, by default pi / 2 Returns @@ -835,8 +836,8 @@ def initialize_3D_seiffert_shells( nb_shells: int, curve_index: float = 0.5, nb_revolutions: float = 1, - shell_tilt: str = "uniform", - shot_tilt: str = "uniform", + shell_tilt: Literal = "uniform", + shot_tilt: Literal = "uniform", ) -> np.ndarray: """Initialize 3D trajectories with Seiffert shells. @@ -857,9 +858,9 @@ def initialize_3D_seiffert_shells( nb_revolutions : float Number of revolutions, i.e. times the curve passes through the upper-half of the z-axis, by default 1 - shell_tilt : str, float, optional + shell_tilt : Literal, float, optional Angle between consecutive shells along z-axis, by default "uniform" - shot_tilt : str, float, optional + shot_tilt : Literal, float, optional Angle between shots over a shell surface along z-axis, by default "uniform" Returns @@ -932,7 +933,7 @@ def initialize_3D_turbine( Ns_readouts: int, Ns_transitions: int, nb_blades: int, - blade_tilt: str = "uniform", + blade_tilt: Literal = "uniform", nb_trains: int | str = "auto", skip_factor: int = 1, in_out: bool = True, @@ -959,7 +960,7 @@ def initialize_3D_turbine( Number of samples per transition between two readouts nb_blades : int Number of line stacks over the :math:`k_z`-axis axis - blade_tilt : str, float, optional + blade_tilt : Literal, float, optional Tilt between individual blades, by default "uniform" nb_trains : int, str, optional Number of resulting shots, or readout trains, such that each of @@ -1041,9 +1042,9 @@ def initialize_3D_repi( Ns_transitions: int, nb_blades: int, nb_blade_revolutions: float = 0, - blade_tilt: str = "uniform", + blade_tilt: Literal = "uniform", nb_spiral_revolutions: float = 0, - spiral: str = "archimedes", + spiral: Literal = "archimedes", nb_trains: int | str = "auto", in_out: bool = True, ) -> np.ndarray: @@ -1075,11 +1076,11 @@ def initialize_3D_repi( nb_blade_revolutions : float Number of revolutions over lines/spirals within a blade over the kz axis. - blade_tilt : str, float, optional + blade_tilt : Literal, float, optional Tilt between individual blades, by default "uniform" nb_spiral_revolutions : float, optional Number of revolutions of the spirals over the readouts, by default 0 - spiral : str, float, optional + spiral : Literal, float, optional Spiral type, by default "archimedes" nb_trains : int, str Number of trains dividing the readouts, such that each diff --git a/src/mrinufft/trajectories/utils.py b/src/mrinufft/trajectories/utils.py index a9023d2bc..880bb96a3 100644 --- a/src/mrinufft/trajectories/utils.py +++ b/src/mrinufft/trajectories/utils.py @@ -2,6 +2,7 @@ from enum import Enum, EnumMeta from numbers import Real +from typing import Literal import numpy as np @@ -26,11 +27,11 @@ class CaseInsensitiveEnumMeta(EnumMeta): """A case-insensitive EnumMeta.""" - def __getitem__(self, name): + def __getitem__(self, name: str) -> Enum: """Allow ``MyEnum['Member'] == MyEnum['MEMBER']`` .""" return super().__getitem__(name.upper()) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # noqa ANN401 """Allow ``MyEnum.Member == MyEnum.MEMBER`` .""" return super().__getattr__(name.upper()) @@ -454,12 +455,12 @@ def check_hardware_constraints( ########### -def initialize_tilt(tilt: str | float, nb_partitions: int = 1) -> float: +def initialize_tilt(tilt: Literal | float, nb_partitions: int = 1) -> float: r"""Initialize the tilt angle. Parameters ---------- - tilt : str or float + tilt : Literal or float Tilt angle in rad or name of the tilt. nb_partitions : int, optional Number of partitions. The default is 1. @@ -497,12 +498,12 @@ def initialize_tilt(tilt: str | float, nb_partitions: int = 1) -> float: raise NotImplementedError(f"Unknown tilt name: {tilt}") -def initialize_algebraic_spiral(spiral: str | float) -> float: +def initialize_algebraic_spiral(spiral: Literal | float) -> float: """Initialize the algebraic spiral type. Parameters ---------- - spiral : str or float + spiral : Literal or float Spiral type or spiral power value. Returns @@ -515,12 +516,12 @@ def initialize_algebraic_spiral(spiral: str | float) -> float: return Spirals[spiral] -def initialize_shape_norm(shape: str | float) -> float: +def initialize_shape_norm(shape: Literal | float) -> float: """Initialize the norm for a given shape. Parameters ---------- - shape : str or float + shape : Literal or float Shape name or p-norm value. Returns From d5b3e095efb490d6cdf1e5643f00dc5b68c3a831 Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Thu, 19 Dec 2024 23:23:13 +0100 Subject: [PATCH 3/6] Replace np.ndarray type hints with np.typing.NDArray --- src/mrinufft/trajectories/display.py | 8 +++--- src/mrinufft/trajectories/gradients.py | 10 ++++--- .../trajectories/inits/random_walk.py | 6 ++--- .../trajectories/inits/travelling_salesman.py | 10 +++---- src/mrinufft/trajectories/maths/rotations.py | 8 ++++-- src/mrinufft/trajectories/maths/tsp_solver.py | 2 +- src/mrinufft/trajectories/sampling.py | 8 +++--- src/mrinufft/trajectories/tools.py | 26 +++++++++---------- src/mrinufft/trajectories/trajectory2D.py | 8 +++--- src/mrinufft/trajectories/utils.py | 22 ++++++++-------- 10 files changed, 57 insertions(+), 51 deletions(-) diff --git a/src/mrinufft/trajectories/display.py b/src/mrinufft/trajectories/display.py index f900c3f9b..b6ca11876 100644 --- a/src/mrinufft/trajectories/display.py +++ b/src/mrinufft/trajectories/display.py @@ -164,7 +164,7 @@ def _setup_3D_ticks(figsize: float, fig: plt.Figure | None = None) -> plt.Axes: def display_2D_trajectory( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, figsize: float = 5, one_shot: bool | int = False, subfigure: plt.Figure | plt.Axes | None = None, @@ -279,7 +279,7 @@ def display_2D_trajectory( def display_3D_trajectory( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, nb_repetitions: int | None = None, figsize: float = 5, per_plane: bool = True, @@ -418,7 +418,7 @@ def display_3D_trajectory( def display_gradients_simply( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, shot_ids: tuple[int, ...] = (0,), figsize: float = 5, fill_area: bool = True, @@ -532,7 +532,7 @@ def display_gradients_simply( def display_gradients( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, shot_ids: tuple[int, ...] = (0,), figsize: float = 5, fill_area: bool = True, diff --git a/src/mrinufft/trajectories/gradients.py b/src/mrinufft/trajectories/gradients.py index 932c22ee5..ad24bc187 100644 --- a/src/mrinufft/trajectories/gradients.py +++ b/src/mrinufft/trajectories/gradients.py @@ -8,12 +8,12 @@ def patch_center_anomaly( - shot_or_params: np.ndarray | list, - update_shot: Callable[..., np.ndarray] | None = None, + shot_or_params: np.typing.NDArray | list, + update_shot: Callable[..., np.typing.NDArray] | None = None, update_parameters: Callable[..., list] | None = None, in_out: bool = False, learning_rate: float = 1e-1, -) -> tuple[np.ndarray, list]: +) -> tuple[np.typing.NDArray, list]: """Re-position samples to avoid center anomalies. Some trajectories behave slightly differently from expected when @@ -72,7 +72,9 @@ def patch_center_anomaly( if update_shot is None or update_parameters is None: - def _default_update_parameters(shot: np.ndarray, *parameters: list) -> list: + def _default_update_parameters( + shot: np.typing.NDArray, *parameters: list + ) -> list: return parameters update_parameters = _default_update_parameters diff --git a/src/mrinufft/trajectories/inits/random_walk.py b/src/mrinufft/trajectories/inits/random_walk.py index 0ad947bbd..296e55f26 100644 --- a/src/mrinufft/trajectories/inits/random_walk.py +++ b/src/mrinufft/trajectories/inits/random_walk.py @@ -24,7 +24,7 @@ def _get_neighbors_offsets(shape: tuple[int, ...]) -> np.ndarray: def _initialize_ND_random_walk( Nc: int, Ns: int, - density: np.ndarray, + density: np.typing.NDArray, *, diagonals: bool = True, pseudo_random: bool = True, @@ -97,7 +97,7 @@ def _initialize_ND_random_walk( def initialize_2D_random_walk( Nc: int, Ns: int, - density: np.ndarray, + density: np.typing.NDArray, *, diagonals: bool = True, pseudo_random: bool = True, @@ -162,7 +162,7 @@ def initialize_2D_random_walk( def initialize_3D_random_walk( Nc: int, Ns: int, - density: np.ndarray, + density: np.typing.NDArray, *, diagonals: bool = True, pseudo_random: bool = True, diff --git a/src/mrinufft/trajectories/inits/travelling_salesman.py b/src/mrinufft/trajectories/inits/travelling_salesman.py index 411534b77..c0db6564b 100644 --- a/src/mrinufft/trajectories/inits/travelling_salesman.py +++ b/src/mrinufft/trajectories/inits/travelling_salesman.py @@ -20,7 +20,7 @@ def _get_approx_cluster_sizes(nb_total: int, nb_clusters: int) -> np.ndarray: return cluster_sizes -def _sort_by_coordinate(array: np.ndarray, coord: Literal) -> np.ndarray: +def _sort_by_coordinate(array: np.typing.NDArray, coord: Literal) -> np.ndarray: # Sort a list of N-D locations by a Cartesian/spherical coordinate if array.shape[-1] < 3 and coord.lower() in ["z", "theta"]: raise ValueError( @@ -49,7 +49,7 @@ def _sort_by_coordinate(array: np.ndarray, coord: Literal) -> np.ndarray: def _cluster_by_coordinate( - locations: np.ndarray, + locations: np.typing.NDArray, nb_clusters: int, cluster_by: Literal, second_cluster_by: Literal | None = None, @@ -95,7 +95,7 @@ def _cluster_by_coordinate( def _initialize_ND_travelling_salesman( Nc: int, Ns: int, - density: np.ndarray, + density: np.typing.NDArray, first_cluster_by: Literal | None = None, second_cluster_by: Literal | None = None, sort_by: Literal | None = None, @@ -142,7 +142,7 @@ def _initialize_ND_travelling_salesman( def initialize_2D_travelling_salesman( Nc: int, Ns: int, - density: np.ndarray, + density: np.typing.NDArray, first_cluster_by: Literal | None = None, second_cluster_by: Literal | None = None, sort_by: Literal | None = None, @@ -219,7 +219,7 @@ def initialize_2D_travelling_salesman( def initialize_3D_travelling_salesman( Nc: int, Ns: int, - density: np.ndarray, + density: np.typing.NDArray, first_cluster_by: Literal | None = None, second_cluster_by: Literal | None = None, sort_by: Literal | None = None, diff --git a/src/mrinufft/trajectories/maths/rotations.py b/src/mrinufft/trajectories/maths/rotations.py index 5bd165701..59e2558d1 100644 --- a/src/mrinufft/trajectories/maths/rotations.py +++ b/src/mrinufft/trajectories/maths/rotations.py @@ -87,7 +87,11 @@ def Rz(theta: float) -> np.ndarray: def Rv( - v1: np.ndarray, v2: np.ndarray, eps: float = 1e-8, *, normalize: bool = True + v1: np.typing.NDArray, + v2: np.typing.NDArray, + eps: float = 1e-8, + *, + normalize: bool = True, ) -> np.ndarray: """Initialize 3D rotation matrix from two vectors. @@ -126,7 +130,7 @@ def Rv( return np.identity(3) + cross_matrix + cross_matrix @ cross_matrix / (1 + cos_theta) -def Ra(vector: np.ndarray, theta: float) -> np.ndarray: +def Ra(vector: np.typing.NDArray, theta: float) -> np.ndarray: """Initialize 3D rotation matrix around an arbitrary vector. Initialize a 3D rotation matrix to rotate around `vector` by an angle `theta`. diff --git a/src/mrinufft/trajectories/maths/tsp_solver.py b/src/mrinufft/trajectories/maths/tsp_solver.py index 7bad67662..92a6709e1 100644 --- a/src/mrinufft/trajectories/maths/tsp_solver.py +++ b/src/mrinufft/trajectories/maths/tsp_solver.py @@ -4,7 +4,7 @@ def solve_tsp_with_2opt( - locations: np.ndarray, improvement_threshold: float = 1e-8 + locations: np.typing.NDArray, improvement_threshold: float = 1e-8 ) -> np.ndarray: """Solve the TSP problem using a 2-opt approach. diff --git a/src/mrinufft/trajectories/sampling.py b/src/mrinufft/trajectories/sampling.py index eff2be0f3..6d011bf4b 100644 --- a/src/mrinufft/trajectories/sampling.py +++ b/src/mrinufft/trajectories/sampling.py @@ -14,7 +14,7 @@ def sample_from_density( nb_samples: int, - density: np.ndarray, + density: np.typing.NDArray, method: Literal = "random", *, dim_compensation: Literal | bool = "auto", @@ -121,7 +121,7 @@ def create_cutoff_decay_density( shape: tuple[int, ...], cutoff: float, decay: float, - resolution: np.ndarray | None = None, + resolution: np.typing.NDArray | None = None, ) -> np.ndarray: """ Create a density with central plateau and polynomial decay. @@ -180,7 +180,7 @@ def create_cutoff_decay_density( def create_polynomial_density( - shape: tuple[int, ...], decay: float, resolution: np.ndarray | None = None + shape: tuple[int, ...], decay: float, resolution: np.typing.NDArray | None = None ) -> np.ndarray: """ Create a density with polynomial decay from the center. @@ -205,7 +205,7 @@ def create_polynomial_density( ) -def create_energy_density(dataset: np.ndarray) -> np.ndarray: +def create_energy_density(dataset: np.typing.NDArray) -> np.ndarray: """ Create a density based on energy in the Fourier spectrum. diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index 334e245f2..25d173937 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -14,7 +14,7 @@ def stack( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, nb_stacks: int, z_tilt: Literal | float | None = None, *, @@ -64,7 +64,7 @@ def stack( def rotate( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, nb_rotations: int, x_tilt: Literal | float | None = None, y_tilt: Literal | float | None = None, @@ -110,7 +110,7 @@ def rotate( def precess( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, nb_rotations: int, tilt: Literal | float = "golden", half_sphere: bool = False, @@ -189,7 +189,7 @@ def precess( def conify( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, nb_cones: int, z_tilt: Literal | float | None = None, in_out: bool = False, @@ -268,7 +268,7 @@ def conify( def epify( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, Ns_transitions: int, nb_trains: int, *, @@ -328,7 +328,7 @@ def epify( def unepify( - trajectory: np.ndarray, Ns_readouts: int, Ns_transitions: int + trajectory: np.typing.NDArray, Ns_readouts: int, Ns_transitions: int ) -> np.ndarray: """Recover single-readout shots from multi-readout trajectory. @@ -371,7 +371,7 @@ def unepify( return trajectory -def prewind(trajectory: np.ndarray, Ns_transitions: int) -> np.ndarray: +def prewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: """Add pre-winding/positioning to the trajectory. The trajectory is extended to start before the readout @@ -411,7 +411,7 @@ def prewind(trajectory: np.ndarray, Ns_transitions: int) -> np.ndarray: return assembled_trajectory -def rewind(trajectory: np.ndarray, Ns_transitions: int) -> np.ndarray: +def rewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: """Add rewinding to the trajectory. The trajectory is extended to come back to the k-space center @@ -452,7 +452,7 @@ def rewind(trajectory: np.ndarray, Ns_transitions: int) -> np.ndarray: def oversample( - trajectory: np.ndarray, new_Ns: int, kind: Literal = "cubic" + trajectory: np.typing.NDArray, new_Ns: int, kind: Literal = "cubic" ) -> np.ndarray: """ Resample a trajectory to increase the number of samples using interpolation. @@ -685,7 +685,7 @@ def shellify( def duplicate_along_axes( - trajectory: np.ndarray, axes: tuple[int, ...] = (0, 1, 2) + trajectory: np.typing.NDArray, axes: tuple[int, ...] = (0, 1, 2) ) -> np.ndarray: """ Duplicate a trajectory along the specified axes. @@ -722,7 +722,7 @@ def duplicate_along_axes( return new_trajectory -def _radialize_center_out(trajectory: np.ndarray, nb_samples: int) -> np.ndarray: +def _radialize_center_out(trajectory: np.typing.NDArray, nb_samples: int) -> np.ndarray: """Radialize a trajectory from the center to the outside. Parameters @@ -747,7 +747,7 @@ def _radialize_center_out(trajectory: np.ndarray, nb_samples: int) -> np.ndarray return new_trajectory -def _radialize_in_out(trajectory: np.ndarray, nb_samples: int) -> np.ndarray: +def _radialize_in_out(trajectory: np.typing.NDArray, nb_samples: int) -> np.ndarray: """Radialize a trajectory from the inside to the outside. Parameters @@ -778,7 +778,7 @@ def _radialize_in_out(trajectory: np.ndarray, nb_samples: int) -> np.ndarray: def radialize_center( - trajectory: np.ndarray, nb_samples: int, in_out: bool = False + trajectory: np.typing.NDArray, nb_samples: int, in_out: bool = False ) -> np.ndarray: """Radialize a trajectory. diff --git a/src/mrinufft/trajectories/trajectory2D.py b/src/mrinufft/trajectories/trajectory2D.py index e9d9c492c..2e1e5a89a 100644 --- a/src/mrinufft/trajectories/trajectory2D.py +++ b/src/mrinufft/trajectories/trajectory2D.py @@ -114,15 +114,15 @@ def initialize_2D_spiral( # have a non-monotonic gradient norm when varying the angle # over [0, +inf) def _update_shot( - angles: np.ndarray, radius: np.ndarray, *args: Any # noqa ANN401 + angles: np.typing.NDArray, radius: np.typing.NDArray, *args: Any # noqa ANN401 ) -> np.ndarray: shot = np.sign(angles) * np.abs(radius) * np.exp(1j * np.abs(angles)) return np.stack([shot.real, shot.imag], axis=-1) def _update_parameters( - single_shot: np.ndarray, - angles: np.ndarray, - radius: np.ndarray, + single_shot: np.typing.NDArray, + angles: np.typing.NDArray, + radius: np.typing.NDArray, spiral_power: float, ) -> tuple[np.ndarray, np.ndarray, float]: radius = nl.norm(single_shot, axis=-1) diff --git a/src/mrinufft/trajectories/utils.py b/src/mrinufft/trajectories/utils.py index 880bb96a3..81cd2e73e 100644 --- a/src/mrinufft/trajectories/utils.py +++ b/src/mrinufft/trajectories/utils.py @@ -151,7 +151,7 @@ class Packings(str, Enum, metaclass=CaseInsensitiveEnumMeta): def normalize_trajectory( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, norm_factor: float = KMAX, resolution: float | np.ndarray = DEFAULT_RESOLUTION, ) -> np.ndarray: @@ -177,7 +177,7 @@ def normalize_trajectory( def unnormalize_trajectory( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, norm_factor: float = KMAX, resolution: float | np.ndarray = DEFAULT_RESOLUTION, ) -> np.ndarray: @@ -203,7 +203,7 @@ def unnormalize_trajectory( def convert_trajectory_to_gradients( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, norm_factor: float = KMAX, resolution: float | np.ndarray = DEFAULT_RESOLUTION, raster_time: float = DEFAULT_RASTER_TIME, @@ -250,8 +250,8 @@ def convert_trajectory_to_gradients( def convert_gradients_to_trajectory( - gradients: np.ndarray, - initial_positions: np.ndarray | None = None, + gradients: np.typing.NDArray, + initial_positions: np.typing.NDArray | None = None, norm_factor: float = KMAX, resolution: float | np.ndarray = DEFAULT_RESOLUTION, raster_time: float = DEFAULT_RASTER_TIME, @@ -300,7 +300,7 @@ def convert_gradients_to_trajectory( def convert_gradients_to_slew_rates( - gradients: np.ndarray, + gradients: np.typing.NDArray, raster_time: float = DEFAULT_RASTER_TIME, ) -> tuple[np.ndarray, np.ndarray]: """Derive the gradients over time to provide slew rates. @@ -328,8 +328,8 @@ def convert_gradients_to_slew_rates( def convert_slew_rates_to_gradients( - slewrates: np.ndarray, - initial_gradients: np.ndarray | None = None, + slewrates: np.typing.NDArray, + initial_gradients: np.typing.NDArray | None = None, raster_time: float = DEFAULT_RASTER_TIME, ) -> np.ndarray: """Integrate slew rates over time to provide gradients. @@ -363,7 +363,7 @@ def convert_slew_rates_to_gradients( def compute_gradients_and_slew_rates( - trajectory: np.ndarray, + trajectory: np.typing.NDArray, norm_factor: float = KMAX, resolution: float | np.ndarray = DEFAULT_RESOLUTION, raster_time: float = DEFAULT_RASTER_TIME, @@ -412,8 +412,8 @@ def compute_gradients_and_slew_rates( def check_hardware_constraints( - gradients: np.ndarray, - slewrates: np.ndarray, + gradients: np.typing.NDArray, + slewrates: np.typing.NDArray, gmax: float = DEFAULT_GMAX, smax: float = DEFAULT_SMAX, order: int | str | None = None, From e4314acf0454585e0fad397ecde29eecbf1836a1 Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Thu, 19 Dec 2024 23:24:41 +0100 Subject: [PATCH 4/6] Fix missing import --- src/mrinufft/trajectories/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/trajectories/utils.py b/src/mrinufft/trajectories/utils.py index 81cd2e73e..0bb4cce2c 100644 --- a/src/mrinufft/trajectories/utils.py +++ b/src/mrinufft/trajectories/utils.py @@ -2,7 +2,7 @@ from enum import Enum, EnumMeta from numbers import Real -from typing import Literal +from typing import Literal, Any import numpy as np From 914fc3fcead8f9fd6b39499c4eafdd15dafe677a Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Wed, 25 Dec 2024 21:43:29 +0100 Subject: [PATCH 5/6] Fix Literal use, change types to NDArray and reduce mypy errors --- src/mrinufft/trajectories/display.py | 29 +-- src/mrinufft/trajectories/gradients.py | 21 +- .../trajectories/inits/random_walk.py | 56 ++--- .../trajectories/inits/travelling_salesman.py | 75 +++---- src/mrinufft/trajectories/maths/primes.py | 2 +- src/mrinufft/trajectories/maths/rotations.py | 35 ++-- src/mrinufft/trajectories/maths/tsp_solver.py | 9 +- src/mrinufft/trajectories/sampling.py | 62 +++--- src/mrinufft/trajectories/tools.py | 196 ++++++++---------- src/mrinufft/trajectories/trajectory2D.py | 75 +++---- src/mrinufft/trajectories/trajectory3D.py | 167 +++++++-------- src/mrinufft/trajectories/utils.py | 111 +++++----- 12 files changed, 419 insertions(+), 419 deletions(-) diff --git a/src/mrinufft/trajectories/display.py b/src/mrinufft/trajectories/display.py index b6ca11876..42e07cbe1 100644 --- a/src/mrinufft/trajectories/display.py +++ b/src/mrinufft/trajectories/display.py @@ -1,5 +1,7 @@ """Display functions for trajectories.""" +from __future__ import annotations + import itertools from typing import Any @@ -7,6 +9,7 @@ import matplotlib.pyplot as plt import matplotlib.ticker as mticker import numpy as np +from numpy.typing import NDArray from .utils import ( DEFAULT_GMAX, @@ -76,7 +79,7 @@ def reset(self) -> None: setattr(displayConfig, key, value) delattr(self, "_old_values") - def __enter__(self) -> "displayConfig": + def __enter__(self) -> displayConfig: """Enter the context manager.""" return self @@ -85,7 +88,7 @@ def __exit__(self, *args: Any) -> None: # noqa ANN401 self.reset() @classmethod - def get_colorlist(cls) -> list[str | np.ndarray]: + def get_colorlist(cls) -> list[str | NDArray]: """Extract a list of colors from a matplotlib palette. If the palette is continuous, the colors will be sampled from it. @@ -164,7 +167,7 @@ def _setup_3D_ticks(figsize: float, fig: plt.Figure | None = None) -> plt.Axes: def display_2D_trajectory( - trajectory: np.typing.NDArray, + trajectory: NDArray, figsize: float = 5, one_shot: bool | int = False, subfigure: plt.Figure | plt.Axes | None = None, @@ -172,13 +175,13 @@ def display_2D_trajectory( gmax: float = DEFAULT_GMAX, smax: float = DEFAULT_SMAX, constraints_order: int | str | None = None, - **constraints_kwargs: float | np.ndarray, + **constraints_kwargs: Any, # noqa ANN401 ) -> plt.Axes: """Display 2D trajectories. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to display. figsize : float, optional Size of the figure. @@ -279,7 +282,7 @@ def display_2D_trajectory( def display_3D_trajectory( - trajectory: np.typing.NDArray, + trajectory: NDArray, nb_repetitions: int | None = None, figsize: float = 5, per_plane: bool = True, @@ -289,13 +292,13 @@ def display_3D_trajectory( gmax: float = DEFAULT_GMAX, smax: float = DEFAULT_SMAX, constraints_order: int | str | None = None, - **constraints_kwargs: dict, + **constraints_kwargs: Any, # noqa ANN401 ) -> plt.Axes: """Display 3D trajectories. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to display. nb_repetitions : int Number of repetitions (planes, cones, shells, etc). @@ -418,7 +421,7 @@ def display_3D_trajectory( def display_gradients_simply( - trajectory: np.typing.NDArray, + trajectory: NDArray, shot_ids: tuple[int, ...] = (0,), figsize: float = 5, fill_area: bool = True, @@ -431,7 +434,7 @@ def display_gradients_simply( Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to display. shot_ids : tuple[int, ...], optional Indices of the shots to display. @@ -532,7 +535,7 @@ def display_gradients_simply( def display_gradients( - trajectory: np.typing.NDArray, + trajectory: NDArray, shot_ids: tuple[int, ...] = (0,), figsize: float = 5, fill_area: bool = True, @@ -545,13 +548,13 @@ def display_gradients( smax: float = DEFAULT_SMAX, constraints_order: int | str | None = None, raster_time: float = DEFAULT_RASTER_TIME, - **constraints_kwargs: float | np.ndarray, + **constraints_kwargs: Any, # noqa ANN401 ) -> tuple[plt.Axes]: """Display gradients based on trajectory of any dimension. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to display. shot_ids : list of int Indices of the shots to display. diff --git a/src/mrinufft/trajectories/gradients.py b/src/mrinufft/trajectories/gradients.py index ad24bc187..ac205cab1 100644 --- a/src/mrinufft/trajectories/gradients.py +++ b/src/mrinufft/trajectories/gradients.py @@ -4,16 +4,17 @@ import numpy as np import numpy.linalg as nl +from numpy.typing import NDArray from scipy.interpolate import CubicSpline def patch_center_anomaly( - shot_or_params: np.typing.NDArray | list, - update_shot: Callable[..., np.typing.NDArray] | None = None, - update_parameters: Callable[..., list] | None = None, + shot_or_params: NDArray | tuple, + update_shot: Callable[..., NDArray] | None = None, + update_parameters: Callable[..., tuple] | None = None, in_out: bool = False, learning_rate: float = 1e-1, -) -> tuple[np.typing.NDArray, list]: +) -> tuple[NDArray, tuple]: """Re-position samples to avoid center anomalies. Some trajectories behave slightly differently from expected when @@ -33,11 +34,11 @@ def patch_center_anomaly( shot_or_params : np.array, list Either a single shot of shape (Ns, Nd), or a list of arbitrary arguments used by ``update_shot`` to initialize a single shot. - update_shot : function, optional + update_shot : Callable[..., NDArray], optional Function used to initialize a single shot based on parameters provided by ``update_parameters``. If None, cubic splines are used as an approximation instead, by default None - update_parameters : function, optional + update_parameters : Callable[..., tuple], optional Function used to update shot parameters when provided in ``shot_or_params`` from an updated shot and parameters. If None, cubic spline parameterization is used instead, @@ -51,9 +52,9 @@ def patch_center_anomaly( Returns ------- - np.ndarray + NDArray N-D trajectory based on ``shot_or_params`` if a shot or - update_shot otherwise. + ``update_shot`` otherwise. list Updated parameters either in the ``shot_or_params`` format if params, or cubic spline parameterization as an array of @@ -72,9 +73,7 @@ def patch_center_anomaly( if update_shot is None or update_parameters is None: - def _default_update_parameters( - shot: np.typing.NDArray, *parameters: list - ) -> list: + def _default_update_parameters(shot: NDArray, *parameters: list) -> list: return parameters update_parameters = _default_update_parameters diff --git a/src/mrinufft/trajectories/inits/random_walk.py b/src/mrinufft/trajectories/inits/random_walk.py index 296e55f26..889fd20dd 100644 --- a/src/mrinufft/trajectories/inits/random_walk.py +++ b/src/mrinufft/trajectories/inits/random_walk.py @@ -1,19 +1,19 @@ """Trajectories based on random walks.""" -from typing import Literal +from typing import Any, Literal import numpy as np +from numpy.typing import NDArray from ..sampling import sample_from_density from ..utils import KMAX -def _get_adjacent_neighbors_offsets(shape: tuple[int, ...]) -> np.ndarray: - return np.concatenate([np.eye(len(shape)), -np.eye(len(shape))], axis=0).astype(int) +def _get_adjacent_neighbors_offsets(nb_dims: int) -> NDArray: + return np.concatenate([np.eye(nb_dims), -np.eye(nb_dims)], axis=0).astype(int) -def _get_neighbors_offsets(shape: tuple[int, ...]) -> np.ndarray: - nb_dims = len(shape) +def _get_neighbors_offsets(nb_dims: int) -> NDArray: neighbors = (np.indices([3] * nb_dims) - 1).reshape((nb_dims, -1)).T nb_half = neighbors.shape[0] // 2 # Remove full zero entry @@ -24,22 +24,23 @@ def _get_neighbors_offsets(shape: tuple[int, ...]) -> np.ndarray: def _initialize_ND_random_walk( Nc: int, Ns: int, - density: np.typing.NDArray, + density: NDArray, *, diagonals: bool = True, pseudo_random: bool = True, - **sampling_kwargs: Literal | bool, -) -> np.ndarray: + **sampling_kwargs: Any, # noqa ANN401 +) -> NDArray: density = density / np.sum(density) flat_density = np.copy(density.flatten()) - shape = np.array(density.shape) + shape = density.shape + nb_dims = len(shape) mask = np.ones_like(flat_density) # Prepare neighbor offsets once offsets = ( - _get_neighbors_offsets(shape) + _get_neighbors_offsets(nb_dims) if diagonals - else _get_adjacent_neighbors_offsets(shape) + else _get_adjacent_neighbors_offsets(nb_dims) ) # Make all random draws at once for performance @@ -47,8 +48,8 @@ def _initialize_ND_random_walk( # Initialize shot starting points locations = sample_from_density(Nc, density, **sampling_kwargs) - choices = np.around((locations + KMAX) * (np.array(density.shape) - 1)).astype(int) - choices = np.ravel_multi_index(choices.T, density.shape) + choices = np.around((locations + KMAX) * (np.array(shape) - 1)).astype(int) + choices = np.ravel_multi_index(choices.T, shape) routes = [choices] # Walk @@ -59,7 +60,7 @@ def _initialize_ND_random_walk( # Find out-of-bound neighbors and ignore them invalids = (neighbors < 0).any(axis=0) | ( - neighbors >= shape[:, None, None] + neighbors >= np.array(shape)[:, None, None] ).any(axis=0) neighbors[:, invalids] = 0 invalids = invalids.T @@ -84,25 +85,24 @@ def _initialize_ND_random_walk( mask[choices] * flat_density[choices] / (mask[choices] + 1) ) mask[choices] += 1 - routes = np.array(routes).T # Create trajectory from routes locations = np.indices(shape) - locations = locations.reshape((len(shape), -1)) - trajectory = np.array([locations[:, r].T for r in routes]) - trajectory = 2 * KMAX * trajectory / (shape - 1) - KMAX + locations = locations.reshape((nb_dims, -1)) + trajectory = np.array([locations[:, r].T for r in np.array(routes).T]) + trajectory = 2 * KMAX * trajectory / (np.array(shape) - 1) - KMAX return trajectory def initialize_2D_random_walk( Nc: int, Ns: int, - density: np.typing.NDArray, + density: NDArray, *, diagonals: bool = True, pseudo_random: bool = True, - **sampling_kwargs: Literal | bool, -) -> np.ndarray: + **sampling_kwargs: Any, # noqa ANN401 +) -> NDArray: """Initialize a 2D random walk trajectory. This is an adaptation of the proposition from [Cha+14]_. @@ -120,7 +120,7 @@ def initialize_2D_random_walk( Number of shots Ns : int Number of samples per shot - density : np.ndarray + density : NDArray Sampling density used to determine the walk probabilities, normalized automatically by its sum during the call for convenience. diagonals : bool, optional @@ -137,7 +137,7 @@ def initialize_2D_random_walk( Returns ------- - np.ndarray + NDArray 2D random walk trajectory References @@ -162,12 +162,12 @@ def initialize_2D_random_walk( def initialize_3D_random_walk( Nc: int, Ns: int, - density: np.typing.NDArray, + density: NDArray, *, diagonals: bool = True, pseudo_random: bool = True, - **sampling_kwargs: Literal | bool, -) -> np.ndarray: + **sampling_kwargs: Any, # noqa ANN401 +) -> NDArray: """Initialize a 3D random walk trajectory. This is an adaptation of the proposition from [Cha+14]_. @@ -185,7 +185,7 @@ def initialize_3D_random_walk( Number of shots Ns : int Number of samples per shot - density : np.ndarray + density : NDArray Sampling density used to determine the walk probabilities, normalized automatically by its sum during the call for convenience. diagonals : bool, optional @@ -202,7 +202,7 @@ def initialize_3D_random_walk( Returns ------- - np.ndarray + NDArray 3D random walk trajectory References diff --git a/src/mrinufft/trajectories/inits/travelling_salesman.py b/src/mrinufft/trajectories/inits/travelling_salesman.py index c0db6564b..063527f14 100644 --- a/src/mrinufft/trajectories/inits/travelling_salesman.py +++ b/src/mrinufft/trajectories/inits/travelling_salesman.py @@ -1,9 +1,10 @@ """Trajectories based on the Travelling Salesman Problem.""" -from typing import Literal +from typing import Any, Literal import numpy as np import numpy.linalg as nl +from numpy.typing import NDArray from scipy.interpolate import CubicSpline from tqdm.auto import tqdm @@ -12,7 +13,7 @@ from ..tools import oversample -def _get_approx_cluster_sizes(nb_total: int, nb_clusters: int) -> np.ndarray: +def _get_approx_cluster_sizes(nb_total: int, nb_clusters: int) -> NDArray: # Give a list of cluster sizes close to sqrt(`nb_total`) cluster_sizes = round(nb_total / nb_clusters) * np.ones(nb_clusters).astype(int) delta_sum = nb_total - np.sum(cluster_sizes) @@ -20,7 +21,9 @@ def _get_approx_cluster_sizes(nb_total: int, nb_clusters: int) -> np.ndarray: return cluster_sizes -def _sort_by_coordinate(array: np.typing.NDArray, coord: Literal) -> np.ndarray: +def _sort_by_coordinate( + array: NDArray, coord: Literal["x", "y", "z", "r", "phi", "theta"] +) -> NDArray: # Sort a list of N-D locations by a Cartesian/spherical coordinate if array.shape[-1] < 3 and coord.lower() in ["z", "theta"]: raise ValueError( @@ -49,12 +52,12 @@ def _sort_by_coordinate(array: np.typing.NDArray, coord: Literal) -> np.ndarray: def _cluster_by_coordinate( - locations: np.typing.NDArray, + locations: NDArray, nb_clusters: int, - cluster_by: Literal, - second_cluster_by: Literal | None = None, - sort_by: Literal | None = None, -) -> np.ndarray: + cluster_by: Literal["x", "y", "z", "r", "phi", "theta"], + second_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + sort_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, +) -> NDArray: # Cluster approximately a list of N-D locations by Cartesian/spherical coordinates # Gather dimension variables nb_dims = locations.shape[-1] @@ -95,15 +98,15 @@ def _cluster_by_coordinate( def _initialize_ND_travelling_salesman( Nc: int, Ns: int, - density: np.typing.NDArray, - first_cluster_by: Literal | None = None, - second_cluster_by: Literal | None = None, - sort_by: Literal | None = None, + density: NDArray, + first_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + second_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + sort_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False, - **sampling_kwargs: Literal | bool, -) -> np.ndarray: + **sampling_kwargs: Any, # noqa ANN401 +) -> NDArray: # Check arguments validity if Nc * Ns > np.prod(density.shape): raise ValueError("`density` array not large enough to pick `Nc` * `Ns` points.") @@ -142,15 +145,15 @@ def _initialize_ND_travelling_salesman( def initialize_2D_travelling_salesman( Nc: int, Ns: int, - density: np.typing.NDArray, - first_cluster_by: Literal | None = None, - second_cluster_by: Literal | None = None, - sort_by: Literal | None = None, + density: NDArray, + first_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + second_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + sort_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False, - **sampling_kwargs: Literal | bool, -) -> np.ndarray: + **sampling_kwargs: Any, # noqa ANN401 +) -> NDArray: """ Initialize a 2D trajectory using a Travelling Salesman Problem (TSP)-based path. @@ -166,14 +169,14 @@ def initialize_2D_travelling_salesman( The number of clusters (or shots) to divide the trajectory into. Ns : int The number of points per cluster. - density : np.ndarray + density : NDArray A 2-dimensional density array from which points are sampled. - first_cluster_by : Literal, optional + first_cluster_by : {"x", "y", "z", "r", "phi", "theta"}, optional The coordinate used to cluster points initially, by default ``None``. - second_cluster_by : Literal, optional + second_cluster_by : {"x", "y", "z", "r", "phi", "theta"}, optional A secondary coordinate used for clustering within primary clusters, by default ``None``. - sort_by : Literal, optional + sort_by : {"x", "y", "z", "r", "phi", "theta"}, optional The coordinate by which to order points within each cluster, by default ``None``. tsp_tol : float, optional @@ -186,7 +189,7 @@ def initialize_2D_travelling_salesman( Returns ------- - np.ndarray + NDArray A 2D array representing the TSP-ordered trajectory. Raises @@ -219,15 +222,15 @@ def initialize_2D_travelling_salesman( def initialize_3D_travelling_salesman( Nc: int, Ns: int, - density: np.typing.NDArray, - first_cluster_by: Literal | None = None, - second_cluster_by: Literal | None = None, - sort_by: Literal | None = None, + density: NDArray, + first_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + second_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + sort_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False, - **sampling_kwargs: Literal | bool, -) -> np.ndarray: + **sampling_kwargs: Any, # noqa ANN401 +) -> NDArray: """ Initialize a 3D trajectory using a Travelling Salesman Problem (TSP)-based path. @@ -245,14 +248,14 @@ def initialize_3D_travelling_salesman( The number of clusters (or shots) to divide the trajectory into. Ns : int The number of points per cluster. - density : np.ndarray + density : NDArray A 3-dimensional density array from which points are sampled. - first_cluster_by : Literal, optional + first_cluster_by : {"x", "y", "z", "r", "phi", "theta"}, optional The coordinate used to cluster points initially, by default ``None``. - second_cluster_by : Literal, optional + second_cluster_by : {"x", "y", "z", "r", "phi", "theta"}, optional A secondary coordinate used for clustering within primary clusters, by default ``None``. - sort_by : Literal, optional + sort_by : {"x", "y", "z", "r", "phi", "theta"}, optional The coordinate by which to order points within each cluster, by default ``None``. tsp_tol : float, optional @@ -265,7 +268,7 @@ def initialize_3D_travelling_salesman( Returns ------- - np.ndarray + NDArray A 3D array representing the TSP-ordered trajectory. Raises diff --git a/src/mrinufft/trajectories/maths/primes.py b/src/mrinufft/trajectories/maths/primes.py index 3ed36fb5d..9db01c483 100644 --- a/src/mrinufft/trajectories/maths/primes.py +++ b/src/mrinufft/trajectories/maths/primes.py @@ -25,7 +25,7 @@ def compute_coprime_factors( List of coprime factors of Nc. """ count = start - coprimes = [] + coprimes: list[int] = [] while len(coprimes) < length: # Check greatest common divider (gcd) if np.gcd(Nc, count) == 1: diff --git a/src/mrinufft/trajectories/maths/rotations.py b/src/mrinufft/trajectories/maths/rotations.py index 59e2558d1..f91cd37b8 100644 --- a/src/mrinufft/trajectories/maths/rotations.py +++ b/src/mrinufft/trajectories/maths/rotations.py @@ -2,9 +2,10 @@ import numpy as np import numpy.linalg as nl +from numpy.typing import NDArray -def R2D(theta: float) -> np.ndarray: +def R2D(theta: float) -> NDArray: """Initialize 2D rotation matrix. Parameters @@ -14,13 +15,13 @@ def R2D(theta: float) -> np.ndarray: Returns ------- - np.ndarray + NDArray 2D rotation matrix. """ return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) -def Rx(theta: float) -> np.ndarray: +def Rx(theta: float) -> NDArray: """Initialize 3D rotation matrix around x axis. Parameters @@ -30,7 +31,7 @@ def Rx(theta: float) -> np.ndarray: Returns ------- - np.ndarray + NDArray 3D rotation matrix. """ return np.array( @@ -42,7 +43,7 @@ def Rx(theta: float) -> np.ndarray: ) -def Ry(theta: float) -> np.ndarray: +def Ry(theta: float) -> NDArray: """Initialize 3D rotation matrix around y axis. Parameters @@ -52,7 +53,7 @@ def Ry(theta: float) -> np.ndarray: Returns ------- - np.ndarray + NDArray 3D rotation matrix. """ return np.array( @@ -64,7 +65,7 @@ def Ry(theta: float) -> np.ndarray: ) -def Rz(theta: float) -> np.ndarray: +def Rz(theta: float) -> NDArray: """Initialize 3D rotation matrix around z axis. Parameters @@ -74,7 +75,7 @@ def Rz(theta: float) -> np.ndarray: Returns ------- - np.ndarray + NDArray 3D rotation matrix. """ return np.array( @@ -87,12 +88,12 @@ def Rz(theta: float) -> np.ndarray: def Rv( - v1: np.typing.NDArray, - v2: np.typing.NDArray, + v1: NDArray, + v2: NDArray, eps: float = 1e-8, *, normalize: bool = True, -) -> np.ndarray: +) -> NDArray: """Initialize 3D rotation matrix from two vectors. Initialize a 3D rotation matrix from two vectors using Rodrigues's rotation @@ -103,9 +104,9 @@ def Rv( Parameters ---------- - v1 : np.ndarray + v1 : NDArray Source vector. - v2 : np.ndarray + v2 : NDArray Target vector. eps : float, optional Tolerance to consider two vectors as colinear. The default is 1e-8. @@ -114,7 +115,7 @@ def Rv( Returns ------- - np.ndarray + NDArray 3D rotation matrix. """ # Check for colinearity, not handled by Rodrigues' coefficients @@ -130,7 +131,7 @@ def Rv( return np.identity(3) + cross_matrix + cross_matrix @ cross_matrix / (1 + cos_theta) -def Ra(vector: np.typing.NDArray, theta: float) -> np.ndarray: +def Ra(vector: NDArray, theta: float) -> NDArray: """Initialize 3D rotation matrix around an arbitrary vector. Initialize a 3D rotation matrix to rotate around `vector` by an angle `theta`. @@ -138,14 +139,14 @@ def Ra(vector: np.typing.NDArray, theta: float) -> np.ndarray: Parameters ---------- - vector : np.ndarray + vector : NDArray Vector defining the rotation axis, automatically normalized. theta : float Angle in radians defining the rotation around `vector`. Returns ------- - np.ndarray + NDArray 3D rotation matrix. """ cos_t = np.cos(theta) diff --git a/src/mrinufft/trajectories/maths/tsp_solver.py b/src/mrinufft/trajectories/maths/tsp_solver.py index 92a6709e1..0b40a53aa 100644 --- a/src/mrinufft/trajectories/maths/tsp_solver.py +++ b/src/mrinufft/trajectories/maths/tsp_solver.py @@ -1,11 +1,12 @@ """Solver for the Travelling Salesman Problem.""" import numpy as np +from numpy.typing import NDArray def solve_tsp_with_2opt( - locations: np.typing.NDArray, improvement_threshold: float = 1e-8 -) -> np.ndarray: + locations: NDArray, improvement_threshold: float = 1e-8 +) -> NDArray: """Solve the TSP problem using a 2-opt approach. A sub-optimal solution to the Travelling Salesman Problem (TSP) @@ -16,7 +17,7 @@ def solve_tsp_with_2opt( Parameters ---------- - locations : np.ndarray + locations : NDArray An array of N points with shape (N, D) where D is the space dimension. improvement_threshold : float, optional Threshold used as progress criterion to stop the optimization process. @@ -24,7 +25,7 @@ def solve_tsp_with_2opt( Returns ------- - np.ndarray + NDArray The new positions order of shape (N,). """ route = np.arange(locations.shape[0]) diff --git a/src/mrinufft/trajectories/sampling.py b/src/mrinufft/trajectories/sampling.py index 6d011bf4b..6b4c643ad 100644 --- a/src/mrinufft/trajectories/sampling.py +++ b/src/mrinufft/trajectories/sampling.py @@ -1,5 +1,7 @@ """Sampling densities and methods.""" +from __future__ import annotations + from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: @@ -7,6 +9,7 @@ import numpy as np import numpy.fft as nf +from numpy.typing import NDArray from tqdm.auto import tqdm from .utils import KMAX @@ -14,11 +17,11 @@ def sample_from_density( nb_samples: int, - density: np.typing.NDArray, - method: Literal = "random", + density: NDArray, + method: Literal["random", "lloyd"] = "random", *, - dim_compensation: Literal | bool = "auto", -) -> np.ndarray: + dim_compensation: Literal["auto"] | bool = "auto", +) -> NDArray: """ Sample points based on a given density distribution. @@ -26,14 +29,14 @@ def sample_from_density( ---------- nb_samples : int The number of samples to draw. - density : np.ndarray + density : NDArray An array representing the density distribution from which samples are drawn, normalized automatically by its sum during the call for convenience. - method : str, optional + method : Literal["random", "lloyd"], optional The sampling method to use, either 'random' for random sampling over the discrete grid defined by the density or 'lloyd' for Lloyd's algorithm over a continuous space, by default "random". - dim_compensation : str, bool, optional + dim_compensation : Literal["auto"], bool, optional Whether to apply a specific dimensionality compensation introduced in [Cha+14]_. An exponent ``N/(N-1)`` with ``N`` the number of dimensions in ``density`` is applied to fix the observed @@ -44,7 +47,7 @@ def sample_from_density( Returns ------- - np.ndarray + NDArray An array of range-normalized sampled locations. Raises @@ -70,7 +73,7 @@ def sample_from_density( ) from err # Define dimension variables - shape = np.array(density.shape) + shape = density.shape nb_dims = len(shape) max_nb_samples = np.prod(shape) density = density / np.sum(density) @@ -98,7 +101,7 @@ def sample_from_density( ) locations = np.indices(shape).reshape((nb_dims, -1))[:, choices] locations = locations.T + 0.5 - locations = locations / shape[None, :] + locations = locations / np.array(shape)[None, :] locations = 2 * KMAX * locations - KMAX elif method == "lloyd": kmeans = ( @@ -107,10 +110,10 @@ def sample_from_density( else BisectingKMeans(n_clusters=nb_samples) ) kmeans.fit( - np.indices(density.shape).reshape((nb_dims, -1)).T, + np.indices(shape).reshape((nb_dims, -1)).T, sample_weight=density.flatten(), ) - locations = kmeans.cluster_centers_ - np.array(density.shape) / 2 + locations = kmeans.cluster_centers_ - np.array(shape) / 2 locations = KMAX * locations / np.max(np.abs(locations)) else: raise ValueError(f"Unknown sampling method {method}.") @@ -121,8 +124,8 @@ def create_cutoff_decay_density( shape: tuple[int, ...], cutoff: float, decay: float, - resolution: np.typing.NDArray | None = None, -) -> np.ndarray: + resolution: NDArray | None = None, +) -> NDArray: """ Create a density with central plateau and polynomial decay. @@ -140,13 +143,13 @@ def create_cutoff_decay_density( and 1 within which density remains uniform and beyond which it decays. decay : float The polynomial decay in density beyond the cutoff ratio. - resolution : np.ndarray, optional + resolution : NDArray, optional Resolution scaling factors for each dimension of the density grid, by default ``None``. Returns ------- - np.ndarray + NDArray A density array with values decaying based on the specified cutoff ratio and decay rate. @@ -158,7 +161,6 @@ def create_cutoff_decay_density( magnetic resonance imaging." IEEE Transactions on Medical Imaging 41, no. 8 (2022): 2105-2117. """ - shape = np.array(shape) nb_dims = len(shape) if not resolution: @@ -180,8 +182,8 @@ def create_cutoff_decay_density( def create_polynomial_density( - shape: tuple[int, ...], decay: float, resolution: np.typing.NDArray | None = None -) -> np.ndarray: + shape: tuple[int, ...], decay: float, resolution: NDArray | None = None +) -> NDArray: """ Create a density with polynomial decay from the center. @@ -191,13 +193,13 @@ def create_polynomial_density( The shape of the density grid. decay : float The exponent that controls the rate of decay for density. - resolution : np.ndarray, optional + resolution : NDArray, optional Resolution scaling factors for each dimension of the density grid, by default None. Returns ------- - np.ndarray + NDArray A density array with polynomial decay. """ return create_cutoff_decay_density( @@ -205,7 +207,7 @@ def create_polynomial_density( ) -def create_energy_density(dataset: np.typing.NDArray) -> np.ndarray: +def create_energy_density(dataset: NDArray) -> NDArray: """ Create a density based on energy in the Fourier spectrum. @@ -214,7 +216,7 @@ def create_energy_density(dataset: np.typing.NDArray) -> np.ndarray: Parameters ---------- - dataset : np.ndarray + dataset : NDArray The dataset from which to calculate the density based on its Fourier transform, with an expected shape (nb_volumes, dim_1, ..., dim_N). @@ -222,7 +224,7 @@ def create_energy_density(dataset: np.typing.NDArray) -> np.ndarray: Returns ------- - np.ndarray + NDArray A density array derived from the mean energy in the Fourier domain of the input dataset. """ @@ -237,11 +239,11 @@ def create_energy_density(dataset: np.typing.NDArray) -> np.ndarray: def create_chauffert_density( shape: tuple[int, ...], - wavelet_basis: Literal | pw.Wavelet, + wavelet_basis: str | pw.Wavelet, nb_wavelet_scales: int, *, verbose: bool = False, -) -> np.ndarray: +) -> NDArray: """Create a density based on Chauffert's method. This is a reproduction of the proposition from [CCW13]_. @@ -264,7 +266,7 @@ def create_chauffert_density( Returns ------- - np.ndarray + NDArray A density array created based on wavelet transform coefficients. See Also @@ -312,9 +314,9 @@ def create_chauffert_density( def create_fast_chauffert_density( shape: tuple[int, ...], - wavelet_basis: Literal | pw.Wavelet, + wavelet_basis: str | pw.Wavelet, nb_wavelet_scales: int, -) -> np.ndarray: +) -> NDArray: """Create a density based on an approximated Chauffert method. This implementation is based on this @@ -341,7 +343,7 @@ def create_fast_chauffert_density( Returns ------- - np.ndarray + NDArray A density array created using a faster approximation based on 1D projections of the wavelet transform. diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index 25d173937..02ca23ae0 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Literal import numpy as np +from numpy.typing import NDArray from scipy.interpolate import CubicSpline, interp1d from .maths import Rv, Rx, Ry, Rz @@ -14,28 +15,28 @@ def stack( - trajectory: np.typing.NDArray, + trajectory: NDArray, nb_stacks: int, - z_tilt: Literal | float | None = None, + z_tilt: str | float | None = None, *, hard_bounded: bool = True, -) -> np.ndarray: +) -> NDArray: """Stack 2D or 3D trajectories over the :math:`k_z`-axis. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory in 2D or 3D to stack. nb_stacks : int Number of stacks repeating the provided trajectory. - z_tilt : Literal, float, optional + z_tilt : str | float, optional Tilt of the stacks, by default `None`. hard_bounded : bool, optional Whether the stacks should be strictly within the limits of the k-space. Returns ------- - np.ndarray + NDArray Stacked trajectory. """ # Check dimensionality and initialize output @@ -64,30 +65,30 @@ def stack( def rotate( - trajectory: np.typing.NDArray, + trajectory: NDArray, nb_rotations: int, - x_tilt: Literal | float | None = None, - y_tilt: Literal | float | None = None, - z_tilt: Literal | float | None = None, -) -> np.ndarray: + x_tilt: str | float | None = None, + y_tilt: str | float | None = None, + z_tilt: str | float | None = None, +) -> NDArray: """Rotate 2D or 3D trajectories over the different axes. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory in 2D or 3D to rotate. nb_rotations : int Number of rotations repeating the provided trajectory. - x_tilt : Literal, optional + x_tilt : str | float, optional Tilt of the trajectory over the :math:`k_x`-axis, by default `None`. - y_tilt : Literal, optional + y_tilt : str | float, optional Tilt of the trajectory over the :math:`k_y`-axis, by default `None`. - z_tilt : Literal, optional + z_tilt : str | float, optional Tilt of the trajectory over the :math:`k_z`-axis, by default `None`. Returns ------- - np.ndarray + NDArray Rotated trajectory. """ # Check dimensionality and initialize output @@ -110,33 +111,33 @@ def rotate( def precess( - trajectory: np.typing.NDArray, + trajectory: NDArray, nb_rotations: int, - tilt: Literal | float = "golden", + tilt: str | float = "golden", half_sphere: bool = False, - partition: Literal = "axial", - axis: int | np.ndarray | None = None, -) -> np.ndarray: + partition: Literal["axial", "polar"] = "axial", + axis: int | NDArray | None = None, +) -> NDArray: """Rotate trajectories as a precession around the :math:`k_z`-axis. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory in 2D or 3D to rotate. nb_rotations : int Number of rotations repeating the provided trajectory while precessing. - tilt : Literal, float, optional + tilt : str | float, optional Angle tilt between consecutive rotations around the :math:`k_z`-axis, by default "golden". half_sphere : bool, optional Whether the precession should be limited to the upper half of the k-space sphere. It is typically used for in-out trajectories or planes. - partition : Literal, optional + partition : Literal["axial", "polar"], optional Partition type between an "axial" or "polar" split of the :math:`k_z`-axis, designating whether the axis should be fragmented by radius or angle respectively, by default "axial". - axis : int, np.ndarray, optional + axis : int, NDArray, optional Axis selected for alignment reference when rotating the trajectory around the :math:`k_z`-axis, generally corresponding to the shot direction for single shot ``trajectory`` inputs. It can either @@ -146,7 +147,7 @@ def precess( Returns ------- - np.ndarray + NDArray Precessed trajectory. """ # Check for partition option error @@ -189,22 +190,22 @@ def precess( def conify( - trajectory: np.typing.NDArray, + trajectory: NDArray, nb_cones: int, - z_tilt: Literal | float | None = None, + z_tilt: str | float | None = None, in_out: bool = False, max_angle: float = np.pi / 2, borderless: bool = True, -) -> np.ndarray: +) -> NDArray: """Distort 2D or 3D trajectories into cones along the :math:`k_z`-axis. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to conify. nb_cones : int Number of cones repeating the provided trajectory. - z_tilt : Literal, float, optional + z_tilt : str | float, optional Tilt of the trajectory over the :math:`k_z`-axis, by default `None`. in_out : bool, optional Whether to account for the in-out nature of some trajectories @@ -217,7 +218,7 @@ def conify( Returns ------- - np.ndarray + NDArray Conified trajectory. """ # Check dimensionality and initialize output @@ -268,12 +269,12 @@ def conify( def epify( - trajectory: np.typing.NDArray, + trajectory: NDArray, Ns_transitions: int, nb_trains: int, *, reverse_odd_shots: bool = False, -) -> np.ndarray: +) -> NDArray: """Create multi-readout shots from trajectory composed of single-readouts. Assemble multiple single-readout shots together by adding transition @@ -281,7 +282,7 @@ def epify( Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to change by prolonging and merging the shots. Ns_transitions : int Number of samples/steps between the merged readouts. @@ -294,7 +295,7 @@ def epify( Returns ------- - np.ndarray + NDArray Trajectory with fewer but longer multi-readout shots. """ Nc, Ns, Nd = trajectory.shape @@ -322,14 +323,10 @@ def epify( for i_c in range(nb_trains): spline = CubicSpline(source_sample_ids, np.concatenate(trajectory[i_c], axis=0)) assembled_trajectory.append(spline(target_sample_ids)) - assembled_trajectory = np.array(assembled_trajectory) + return np.array(assembled_trajectory) - return assembled_trajectory - -def unepify( - trajectory: np.typing.NDArray, Ns_readouts: int, Ns_transitions: int -) -> np.ndarray: +def unepify(trajectory: NDArray, Ns_readouts: int, Ns_transitions: int) -> NDArray: """Recover single-readout shots from multi-readout trajectory. Reformat an EPI-like trajectory with multiple readouts and transitions @@ -341,7 +338,7 @@ def unepify( Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to reduce by discarding transitions between readouts. Ns_readouts : int Number of samples within a single readout. @@ -350,7 +347,7 @@ def unepify( Returns ------- - np.ndarray + NDArray Trajectory with more but shorter single shots. """ Nc, Ns, Nd = trajectory.shape @@ -371,7 +368,7 @@ def unepify( return trajectory -def prewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: +def prewind(trajectory: NDArray, Ns_transitions: int) -> NDArray: """Add pre-winding/positioning to the trajectory. The trajectory is extended to start before the readout @@ -380,7 +377,7 @@ def prewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to extend with rewind gradients. Ns_transitions : int Number of pre-winding/positioning steps used to leave the @@ -388,7 +385,7 @@ def prewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: Returns ------- - np.ndarray + NDArray Extended trajectory with pre-winding/positioning. """ Nc, Ns, Nd = trajectory.shape @@ -406,12 +403,10 @@ def prewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: np.concatenate([np.zeros((2, Nd)), trajectory[i_c]], axis=0), ) assembled_trajectory.append(spline(target_sample_ids)) - assembled_trajectory = np.array(assembled_trajectory) - - return assembled_trajectory + return np.array(assembled_trajectory) -def rewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: +def rewind(trajectory: NDArray, Ns_transitions: int) -> NDArray: """Add rewinding to the trajectory. The trajectory is extended to come back to the k-space center @@ -419,14 +414,14 @@ def rewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to extend with rewind gradients. Ns_transitions : int Number of rewinding steps used to come back to the k-space center. Returns ------- - np.ndarray + NDArray Extended trajectory with rewinding. """ Nc, Ns, Nd = trajectory.shape @@ -446,20 +441,20 @@ def rewind(trajectory: np.typing.NDArray, Ns_transitions: int) -> np.ndarray: np.concatenate([trajectory[i_c], np.zeros((2, Nd))], axis=0), ) assembled_trajectory.append(spline(target_sample_ids)) - assembled_trajectory = np.array(assembled_trajectory) - - return assembled_trajectory + return np.array(assembled_trajectory) def oversample( - trajectory: np.typing.NDArray, new_Ns: int, kind: Literal = "cubic" -) -> np.ndarray: + trajectory: NDArray, + new_Ns: int, + kind: Literal["linear", "quadratic", "cubic"] = "cubic", +) -> NDArray: """ Resample a trajectory to increase the number of samples using interpolation. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray The original trajectory array, where interpolation is applied along the second axis. new_Ns : int @@ -470,7 +465,7 @@ def oversample( Returns ------- - np.ndarray + NDArray The resampled trajectory array with ``new_Ns`` points along the second axis. Notes @@ -499,25 +494,25 @@ def oversample( def stack_spherically( - trajectory_func: Callable[..., np.ndarray], + trajectory_func: Callable[..., NDArray], Nc: int, nb_stacks: int, - z_tilt: Literal | float | None = None, + z_tilt: str | float | None = None, hard_bounded: bool = True, **traj_kwargs: Any, # noqa ANN401 -) -> np.ndarray: +) -> NDArray: """Stack 2D or 3D trajectories over the :math:`k_z`-axis to make a sphere. Parameters ---------- - trajectory_func : Callable[..., np.ndarray] + trajectory_func : Callable[..., NDArray] Trajectory function that should return an array-like with the usual (Nc, Ns, Nd) size. Nc : int Number of shots to use for the whole spherically stacked trajectory. nb_stacks : int Number of stacks of trajectories. - z_tilt : Literal, float, optional + z_tilt : str | float, optional Tilt of the stacks, by default `None`. hard_bounded : bool, optional Whether the stacks should be strictly within the limits @@ -528,7 +523,7 @@ def stack_spherically( Returns ------- - np.ndarray + NDArray Stacked trajectory. """ # Handle argument errors @@ -575,38 +570,33 @@ def stack_spherically( # Concatenate or handle varying Ns value Ns_values = np.array([stk.shape[1] for stk in new_trajectory]) if (Ns_values == Ns_values[0]).all(): - new_trajectory = np.concatenate(new_trajectory, axis=0) - new_trajectory = new_trajectory.reshape(Nc, Ns_values[0], 3) - else: - new_trajectory = np.concatenate( - [stk.reshape((-1, 3)) for stk in new_trajectory], axis=0 - ) - - return new_trajectory + output = np.concatenate(new_trajectory, axis=0) + return output.reshape(Nc, Ns_values[0], 3) + return np.concatenate([stk.reshape((-1, 3)) for stk in new_trajectory], axis=0) def shellify( - trajectory_func: Callable[..., np.ndarray], + trajectory_func: Callable[..., NDArray], Nc: int, nb_shells: int, - z_tilt: Literal | float = "golden", - hemisphere_mode: Literal = "symmetric", + z_tilt: str | float = "golden", + hemisphere_mode: Literal["symmetric", "reversed"] = "symmetric", **traj_kwargs: Any, # noqa ANN401 -) -> np.ndarray: +) -> NDArray: """Stack 2D or 3D trajectories over the :math:`k_z`-axis to make a sphere. Parameters ---------- - trajectory_func : Callable[..., np.ndarray] + trajectory_func : Callable[..., NDArray] Trajectory function that should return an array-like with the usual (Nc, Ns, Nd) size. Nc : int Number of shots to use for the whole spherically stacked trajectory. nb_shells : int Number of shells of distorted trajectories. - z_tilt : Literal, float, optional + z_tilt : str | float, optional Tilt of the shells, by default "golden". - hemisphere_mode : Literal, optional + hemisphere_mode : Literal["symmetric", "reversed"], optional Define how the lower hemisphere should be oriented relatively to the upper one, with "symmetric" providing a :math:`k_x-k_y` planar symmetry by changing the polar angle, and with "reversed" promoting continuity @@ -618,7 +608,7 @@ def shellify( Returns ------- - np.ndarray + NDArray Concentric shell trajectory. """ # Handle argument errors @@ -669,14 +659,9 @@ def shellify( # Concatenate or handle varying Ns value Ns_values = np.array([hem.shape[1] for hem in new_trajectory]) if (Ns_values == Ns_values[0]).all(): - new_trajectory = np.concatenate(new_trajectory, axis=0) - new_trajectory = new_trajectory.reshape(Nc, Ns_values[0], 3) - else: - new_trajectory = np.concatenate( - [hem.reshape((-1, 3)) for hem in new_trajectory], axis=0 - ) - - return new_trajectory + output = np.concatenate(new_trajectory, axis=0) + return output.reshape(Nc, Ns_values[0], 3) + return np.concatenate([hem.reshape((-1, 3)) for hem in new_trajectory], axis=0) ######### @@ -685,8 +670,8 @@ def shellify( def duplicate_along_axes( - trajectory: np.typing.NDArray, axes: tuple[int, ...] = (0, 1, 2) -) -> np.ndarray: + trajectory: NDArray, axes: tuple[int, ...] = (0, 1, 2) +) -> NDArray: """ Duplicate a trajectory along the specified axes. @@ -696,14 +681,14 @@ def duplicate_along_axes( Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to duplicate. axes : tuple[int, ...], optional Axes along which to duplicate the trajectory, by default (0, 1, 2) Returns ------- - np.ndarray + NDArray Duplicated trajectory along the specified axes. """ # Copy input trajectory along other axes @@ -718,23 +703,22 @@ def duplicate_along_axes( dp_trajectory = np.copy(trajectory) dp_trajectory[..., [2, 0]] = dp_trajectory[..., [0, 2]] new_trajectory.append(dp_trajectory) - new_trajectory = np.concatenate(new_trajectory, axis=0) - return new_trajectory + return np.concatenate(new_trajectory, axis=0) -def _radialize_center_out(trajectory: np.typing.NDArray, nb_samples: int) -> np.ndarray: +def _radialize_center_out(trajectory: NDArray, nb_samples: int) -> NDArray: """Radialize a trajectory from the center to the outside. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to radialize. nb_samples : int Number of samples to radialize from the center. Returns ------- - np.ndarray + NDArray Radialized trajectory. """ Nc, Ns = trajectory.shape[:2] @@ -747,19 +731,19 @@ def _radialize_center_out(trajectory: np.typing.NDArray, nb_samples: int) -> np. return new_trajectory -def _radialize_in_out(trajectory: np.typing.NDArray, nb_samples: int) -> np.ndarray: +def _radialize_in_out(trajectory: NDArray, nb_samples: int) -> NDArray: """Radialize a trajectory from the inside to the outside. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to radialize. nb_samples : int Number of samples to radialize from the inside out. Returns ------- - np.ndarray + NDArray Radialized trajectory. """ Nc, Ns = trajectory.shape[:2] @@ -778,13 +762,13 @@ def _radialize_in_out(trajectory: np.typing.NDArray, nb_samples: int) -> np.ndar def radialize_center( - trajectory: np.typing.NDArray, nb_samples: int, in_out: bool = False -) -> np.ndarray: + trajectory: NDArray, nb_samples: int, in_out: bool = False +) -> NDArray: """Radialize a trajectory. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Trajectory to radialize. nb_samples : int Number of samples to keep. @@ -793,7 +777,7 @@ def radialize_center( Returns ------- - np.ndarray + NDArray Radialized trajectory. """ # Make nb_samples into straight lines around the center diff --git a/src/mrinufft/trajectories/trajectory2D.py b/src/mrinufft/trajectories/trajectory2D.py index 2e1e5a89a..a9029e25c 100644 --- a/src/mrinufft/trajectories/trajectory2D.py +++ b/src/mrinufft/trajectories/trajectory2D.py @@ -4,6 +4,7 @@ import numpy as np import numpy.linalg as nl +from numpy.typing import NDArray from scipy.interpolate import CubicSpline from .gradients import patch_center_anomaly @@ -17,8 +18,8 @@ def initialize_2D_radial( - Nc: int, Ns: int, tilt: Literal | float = "uniform", in_out: bool = False -) -> np.ndarray: + Nc: int, Ns: int, tilt: str | float = "uniform", in_out: bool = False +) -> NDArray: """Initialize a 2D radial trajectory. Parameters @@ -27,14 +28,14 @@ def initialize_2D_radial( Number of shots Ns : int Number of samples per shot - tilt : Literal, float, optional + tilt : str | float, optional Tilt of the shots, by default "uniform" in_out : bool, optional Whether to start from the center or not, by default False Returns ------- - np.ndarray + NDArray 2D radial trajectory """ # Initialize a first shot @@ -53,12 +54,12 @@ def initialize_2D_radial( def initialize_2D_spiral( Nc: int, Ns: int, - tilt: Literal | float = "uniform", + tilt: str | float = "uniform", in_out: bool = False, - nb_revolutions: int = 1, - spiral: Literal | float = "archimedes", + nb_revolutions: float = 1.0, + spiral: str | float = "archimedes", patch_center: bool = True, -) -> np.ndarray: +) -> NDArray: """Initialize a 2D algebraic spiral trajectory. A generalized function that generates algebraic spirals defined @@ -76,7 +77,7 @@ def initialize_2D_spiral( Tilt of the shots, by default "uniform" in_out : bool, optional Whether to start from the center or not, by default False - nb_revolutions : int, optional + nb_revolutions : float, optional Number of revolutions, by default 1 spiral : Literal, float, optional Spiral type or algebraic power, by default "archimedes" @@ -86,7 +87,7 @@ def initialize_2D_spiral( Returns ------- - np.ndarray + NDArray 2D spiral trajectory Raises @@ -114,17 +115,17 @@ def initialize_2D_spiral( # have a non-monotonic gradient norm when varying the angle # over [0, +inf) def _update_shot( - angles: np.typing.NDArray, radius: np.typing.NDArray, *args: Any # noqa ANN401 - ) -> np.ndarray: + angles: NDArray, radius: NDArray, *args: Any # noqa ANN401 + ) -> NDArray: shot = np.sign(angles) * np.abs(radius) * np.exp(1j * np.abs(angles)) return np.stack([shot.real, shot.imag], axis=-1) def _update_parameters( - single_shot: np.typing.NDArray, - angles: np.typing.NDArray, - radius: np.typing.NDArray, + single_shot: NDArray, + angles: NDArray, + radius: NDArray, spiral_power: float, - ) -> tuple[np.ndarray, np.ndarray, float]: + ) -> tuple[NDArray, NDArray, float]: radius = nl.norm(single_shot, axis=-1) angles = np.sign(angles) * np.abs(radius) ** (1 / spiral_power) return angles, radius, spiral_power @@ -157,7 +158,7 @@ def _update_parameters( def initialize_2D_fibonacci_spiral( Nc: int, Ns: int, spiral_reduction: float = 1, patch_center: bool = True -) -> np.ndarray: +) -> NDArray: """Initialize a 2D Fibonacci spiral trajectory. A non-algebraic spiral trajectory based on the Fibonacci sequence, @@ -181,7 +182,7 @@ def initialize_2D_fibonacci_spiral( Returns ------- - np.ndarray + NDArray 2D Fibonacci spiral trajectory References @@ -229,11 +230,11 @@ def initialize_2D_fibonacci_spiral( def initialize_2D_cones( Nc: int, Ns: int, - tilt: Literal = "uniform", + tilt: str | float = "uniform", in_out: bool = False, nb_zigzags: float = 5, width: float = 1, -) -> np.ndarray: +) -> NDArray: """Initialize a 2D cone trajectory. Parameters @@ -242,7 +243,7 @@ def initialize_2D_cones( Number of shots Ns : int Number of samples per shot - tilt : Literal, optional + tilt : str | float, optional Tilt of the shots, by default "uniform" in_out : bool, optional Whether to start from the center or not, by default False @@ -253,7 +254,7 @@ def initialize_2D_cones( Returns ------- - np.ndarray + NDArray 2D cone trajectory """ @@ -275,11 +276,11 @@ def initialize_2D_cones( def initialize_2D_sinusoide( Nc: int, Ns: int, - tilt: Literal | float = "uniform", + tilt: str | float = "uniform", in_out: bool = False, nb_zigzags: float = 5, width: float = 1, -) -> np.ndarray: +) -> NDArray: """Initialize a 2D sinusoide trajectory. Parameters @@ -288,7 +289,7 @@ def initialize_2D_sinusoide( Number of shots Ns : int Number of samples per shot - tilt : Literal, float, optional + tilt : str | float, optional Tilt of the shots, by default "uniform" in_out : bool, optional Whether to start from the center or not, by default False @@ -299,7 +300,7 @@ def initialize_2D_sinusoide( Returns ------- - np.ndarray + NDArray 2D sinusoide trajectory """ @@ -318,7 +319,7 @@ def initialize_2D_sinusoide( return trajectory -def initialize_2D_propeller(Nc: int, Ns: int, nb_strips: int) -> np.ndarray: +def initialize_2D_propeller(Nc: int, Ns: int, nb_strips: int) -> NDArray: """Initialize a 2D PROPELLER trajectory, as proposed in [Pip99]_. The PROPELLER trajectory is generally used along a specific @@ -366,7 +367,7 @@ def initialize_2D_propeller(Nc: int, Ns: int, nb_strips: int) -> np.ndarray: return KMAX * trajectory -def initialize_2D_rings(Nc: int, Ns: int, nb_rings: int) -> np.ndarray: +def initialize_2D_rings(Nc: int, Ns: int, nb_rings: int) -> NDArray: """Initialize a 2D ring trajectory, as proposed in [HHN08]_. Parameters @@ -380,7 +381,7 @@ def initialize_2D_rings(Nc: int, Ns: int, nb_rings: int) -> np.ndarray: Returns ------- - np.ndarray + NDArray 2D ring trajectory References @@ -414,7 +415,7 @@ def initialize_2D_rings(Nc: int, Ns: int, nb_rings: int) -> np.ndarray: def initialize_2D_rosette( Nc: int, Ns: int, in_out: bool = False, coprime_index: int = 0 -) -> np.ndarray: +) -> NDArray: """Initialize a 2D rosette trajectory. Parameters @@ -430,7 +431,7 @@ def initialize_2D_rosette( Returns ------- - np.ndarray + NDArray 2D rosette trajectory """ @@ -457,7 +458,7 @@ def initialize_2D_rosette( def initialize_2D_polar_lissajous( Nc: int, Ns: int, in_out: bool = False, nb_segments: int = 1, coprime_index: int = 0 -) -> np.ndarray: +) -> NDArray: """Initialize a 2D polar Lissajous trajectory. Parameters @@ -475,7 +476,7 @@ def initialize_2D_polar_lissajous( Returns ------- - np.ndarray + NDArray 2D polar Lissajous trajectory """ # Adapt the parameters to subcases @@ -511,7 +512,7 @@ def initialize_2D_polar_lissajous( ######################### -def initialize_2D_lissajous(Nc: int, Ns: int, density: float = 1) -> np.ndarray: +def initialize_2D_lissajous(Nc: int, Ns: int, density: float = 1) -> NDArray: """Initialize a 2D Lissajous trajectory. Parameters @@ -525,7 +526,7 @@ def initialize_2D_lissajous(Nc: int, Ns: int, density: float = 1) -> np.ndarray: Returns ------- - np.ndarray + NDArray 2D Lissajous trajectory """ # Define the whole curve in Cartesian coordinates @@ -543,7 +544,7 @@ def initialize_2D_lissajous(Nc: int, Ns: int, density: float = 1) -> np.ndarray: def initialize_2D_waves( Nc: int, Ns: int, nb_zigzags: float = 5, width: float = 1 -) -> np.ndarray: +) -> NDArray: """Initialize a 2D waves trajectory. Parameters @@ -559,7 +560,7 @@ def initialize_2D_waves( Returns ------- - np.ndarray + NDArray 2D waves trajectory """ # Initialize a first shot diff --git a/src/mrinufft/trajectories/trajectory3D.py b/src/mrinufft/trajectories/trajectory3D.py index 725131043..5ec9f2bca 100644 --- a/src/mrinufft/trajectories/trajectory3D.py +++ b/src/mrinufft/trajectories/trajectory3D.py @@ -5,6 +5,7 @@ import numpy as np import numpy.linalg as nl +from numpy.typing import NDArray from scipy.special import ellipj, ellipk from .maths import ( @@ -18,7 +19,7 @@ ) from .tools import conify, duplicate_along_axes, epify, precess, stack from .trajectory2D import initialize_2D_radial, initialize_2D_spiral -from .utils import KMAX, Packings, initialize_shape_norm, initialize_tilt +from .utils import KMAX, Packings, Spirals, initialize_shape_norm, initialize_tilt ############## # 3D RADIALS # @@ -27,7 +28,7 @@ def initialize_3D_phyllotaxis_radial( Nc: int, Ns: int, nb_interleaves: int = 1, in_out: bool = False -) -> np.ndarray: +) -> NDArray: """Initialize 3D radial trajectories with phyllotactic structure. The radial shots are oriented according to a Fibonacci sphere @@ -56,7 +57,7 @@ def initialize_3D_phyllotaxis_radial( Returns ------- - np.ndarray + NDArray 3D phyllotaxis radial trajectory References @@ -76,7 +77,7 @@ def initialize_3D_phyllotaxis_radial( def initialize_3D_golden_means_radial( Nc: int, Ns: int, in_out: bool = False -) -> np.ndarray: +) -> NDArray: """Initialize 3D radial trajectories with golden means-based structure. The radial shots are oriented using multidimensional golden means, @@ -100,7 +101,7 @@ def initialize_3D_golden_means_radial( Returns ------- - np.ndarray + NDArray 3D golden means radial trajectory References @@ -130,7 +131,7 @@ def initialize_3D_golden_means_radial( def initialize_3D_wong_radial( Nc: int, Ns: int, nb_interleaves: int = 1, in_out: bool = False -) -> np.ndarray: +) -> NDArray: """Initialize 3D radial trajectories with a spiral structure. The radial shots are oriented according to an archimedean spiral @@ -156,7 +157,7 @@ def initialize_3D_wong_radial( Returns ------- - np.ndarray + NDArray 3D Wong radial trajectory References @@ -190,7 +191,7 @@ def initialize_3D_wong_radial( def initialize_3D_park_radial( Nc: int, Ns: int, nb_interleaves: int = 1, in_out: bool = False -) -> np.ndarray: +) -> NDArray: """Initialize 3D radial trajectories with a spiral structure. The radial shots are oriented according to an archimedean spiral @@ -217,7 +218,7 @@ def initialize_3D_park_radial( Returns ------- - np.ndarray + NDArray 3D Park radial trajectory References @@ -244,12 +245,12 @@ def initialize_3D_park_radial( def initialize_3D_cones( Nc: int, Ns: int, - tilt: Literal | float = "golden", + tilt: str | float = "golden", in_out: bool = False, nb_zigzags: float = 5, - spiral: Literal | float = "archimedes", + spiral: str | float = "archimedes", width: float = 1, -) -> np.ndarray: +) -> NDArray: """Initialize 3D trajectories with cones. Initialize a trajectory consisting of 3D cones duplicated @@ -265,21 +266,21 @@ def initialize_3D_cones( Number of shots Ns : int Number of samples per shot - tilt : Literal, float, optional + tilt : str, float, optional Tilt of the cones, by default "golden" in_out : bool, optional Whether the curves are going in-and-out or start from the center, by default False nb_zigzags : float, optional Number of zigzags of the cones, by default 5 - spiral : Literal, float, optional + spiral : str, float, optional Spiral type, by default "archimedes" width : float, optional Cone width normalized such that `width=1` avoids cone overlaps, by default 1 Returns ------- - np.ndarray + NDArray 3D cones trajectory References @@ -289,7 +290,7 @@ def initialize_3D_cones( Journal of mathematical chemistry 6, no. 1 (1991): 325-349. """ # Initialize first spiral - spiral = initialize_2D_spiral( + single_spiral = initialize_2D_spiral( Nc=1, Ns=Ns, spiral=spiral, @@ -305,8 +306,8 @@ def initialize_3D_cones( # Initialize first cone ## Create three cones for proper partitioning, but only one is needed - cone = conify( - spiral, + cones = conify( + single_spiral, nb_cones=3, z_tilt=None, in_out=in_out, @@ -316,7 +317,12 @@ def initialize_3D_cones( # Apply precession to the first cone trajectory = precess( - cone, tilt=tilt, nb_rotations=Nc, half_sphere=in_out, partition="axial", axis=2 + cones, + tilt=tilt, + nb_rotations=Nc, + half_sphere=in_out, + partition="axial", + axis=2, ) return trajectory @@ -327,11 +333,11 @@ def initialize_3D_floret( Ns: int, in_out: bool = False, nb_revolutions: float = 1, - spiral: Literal | float = "fermat", - cone_tilt: Literal | float = "golden", + spiral: str | float = "fermat", + cone_tilt: str | float = "golden", max_angle: float = np.pi / 2, axes: tuple[int, ...] = (2,), -) -> np.ndarray: +) -> NDArray: """Initialize 3D trajectories with FLORET. This implementation is based on the work from [Pip+11]_. @@ -349,9 +355,9 @@ def initialize_3D_floret( Whether to start from the center or not, by default False nb_revolutions : float, optional Number of revolutions of the spirals, by default 1 - spiral : Literal, float, optional + spiral : str, float, optional Spiral type, by default "fermat" - cone_tilt : Literal, float, optional + cone_tilt : str, float, optional Tilt of the cones around the :math:`k_z`-axis, by default "golden" max_angle : float, optional Maximum polar angle starting from the :math:`k_x-k_y` plane, @@ -361,7 +367,7 @@ def initialize_3D_floret( Returns ------- - np.ndarray + NDArray 3D FLORET trajectory References @@ -378,7 +384,7 @@ def initialize_3D_floret( raise ValueError("Nc should be divisible by len(axes).") # Initialize first spiral - spiral = initialize_2D_spiral( + single_spiral = initialize_2D_spiral( Nc=1, Ns=Ns, spiral=spiral, @@ -387,8 +393,8 @@ def initialize_3D_floret( ) # Initialize first cone - cone = conify( - spiral, + cones = conify( + single_spiral, nb_cones=Nc_per_axis, z_tilt=cone_tilt, in_out=in_out, @@ -396,8 +402,8 @@ def initialize_3D_floret( ) # Duplicate cone along axes - axes = [2 - ax for ax in axes] # Default axis is kz, not kx - trajectory = duplicate_along_axes(cone, axes=axes) + axes = tuple(2 - ax for ax in axes) # Default axis is kz, not kx + trajectory = duplicate_along_axes(cones, axes=axes) return trajectory @@ -406,10 +412,10 @@ def initialize_3D_wave_caipi( Ns: int, nb_revolutions: float = 5, width: float = 1, - packing: Literal = "triangular", - shape: Literal | float = "square", + packing: str = "triangular", + shape: str | float = "square", spacing: tuple[int, int] = (1, 1), -) -> np.ndarray: +) -> NDArray: """Initialize 3D trajectories with Wave-CAIPI. This implementation is based on the work from [Bil+15]_. @@ -426,24 +432,24 @@ def initialize_3D_wave_caipi( Diameter of the helices normalized such that `width=1` densely covers the k-space without overlap for square packing, by default 1. - packing : Literal, optional + packing : str, optional Packing method used to position the helices: "triangular"/"hexagonal", "square", "circular" or "random"/"uniform", by default "triangular". - shape : Literal or float, optional + shape : str | float, optional Shape over the 2D :math:`k_x-k_y` plane to pack with shots, either defined as `str` ("circle", "square", "diamond") or as `float` through p-norms following the conventions of the `ord` parameter from `numpy.linalg.norm`, by default "circle". - spacing : tuple(int, int) + spacing : tuple[int, int] Spacing between helices over the 2D :math:`k_x-k_y` plane normalized similarly to `width` to correspond to helix diameters, by default (1, 1). Returns ------- - np.ndarray + NDArray 3D wave-CAIPI trajectory References @@ -454,7 +460,6 @@ def initialize_3D_wave_caipi( Magnetic resonance in medicine 73, no. 6 (2015): 2152-2162. """ trajectory = np.zeros((Nc, Ns, 3)) - spacing = np.array(spacing) # Initialize first shot angles = nb_revolutions * 2 * np.pi * np.arange(0, Ns) / Ns @@ -463,11 +468,11 @@ def initialize_3D_wave_caipi( trajectory[0, :, 2] = np.linspace(-1, 1, Ns) # Choose the helix positions according to packing - packing = Packings[packing] + packing_enum = Packings[packing] side = 2 * int(np.ceil(np.sqrt(Nc))) * np.max(spacing) - if packing == Packings.RANDOM: + if packing_enum == Packings.RANDOM: positions = 2 * side * (np.random.random((side * side, 2)) - 0.5) - elif packing == Packings.CIRCLE: + elif packing_enum == Packings.CIRCLE: positions = [[0, 0]] counter = 0 while len(positions) < side**2: @@ -481,21 +486,21 @@ def initialize_3D_wave_caipi( positions = np.concatenate( [positions, np.array([circle.real, circle.imag]).T], axis=0 ) - elif packing in [Packings.SQUARE, Packings.TRIANGLE, Packings.HEXAGONAL]: + elif packing_enum in [Packings.SQUARE, Packings.TRIANGLE, Packings.HEXAGONAL]: # Square packing or initialize hexagonal/triangular packing px, py = np.meshgrid( np.arange(-side + 1, side, 2), np.arange(-side + 1, side, 2) ) positions = np.stack([px.flatten(), py.flatten()], axis=-1).astype(float) - if packing in [Packings.HEXAGON, Packings.TRIANGLE]: + if packing_enum in [Packings.HEXAGON, Packings.TRIANGLE]: # Hexagonal/triangular packing based on square packing positions[::2, 1] += 1 / 2 positions[1::2, 1] -= 1 / 2 ratio = nl.norm(np.diff(positions[:2], axis=-1)) positions[:, 0] /= ratio / 2 - if packing == Packings.FIBONACCI: + if packing_enum == Packings.FIBONACCI: # Estimate helix width based on the k-space 2D surface # and an optimal circle packing positions = np.sqrt( @@ -505,7 +510,7 @@ def initialize_3D_wave_caipi( # Remove points by distance to fit both shape and Nc main_order = initialize_shape_norm(shape) tie_order = 2 if (main_order != 2) else np.inf # breaking ties - positions = np.array(positions) * spacing + positions = np.array(positions) * np.array(spacing) positions = sorted(positions, key=partial(nl.norm, ord=tie_order)) positions = sorted(positions, key=partial(nl.norm, ord=main_order)) positions = positions[:Nc] @@ -525,10 +530,10 @@ def initialize_3D_seiffert_spiral( Ns: int, curve_index: float = 0.2, nb_revolutions: float = 1, - axis_tilt: Literal | float = "golden", - spiral_tilt: Literal | float = "golden", + axis_tilt: str | float = "golden", + spiral_tilt: str | float = "golden", in_out: bool = False, -) -> np.ndarray: +) -> NDArray: """Initialize 3D trajectories with modulated Seiffert spirals. Initially introduced in [SMR18]_, but also proposed later as "Yarnball" @@ -547,9 +552,9 @@ def initialize_3D_seiffert_spiral( nb_revolutions : float Number of revolutions, i.e. times the polar angle of the curves passes through 0, by default 1 - axis_tilt : Literal, float, optional + axis_tilt : str, float, optional Angle between shots over a precession around the z-axis, by default "golden" - spiral_tilt : Literal, float, optional + spiral_tilt : str, float, optional Angle of the spiral within its own axis defined from center to its outermost point, by default "golden" in_out : bool @@ -558,7 +563,7 @@ def initialize_3D_seiffert_spiral( Returns ------- - np.ndarray + NDArray 3D Seiffert spiral trajectory References @@ -630,9 +635,9 @@ def initialize_3D_helical_shells( Ns: int, nb_shells: int, spiral_reduction: float = 1, - shell_tilt: Literal = "intergaps", - shot_tilt: Literal = "uniform", -) -> np.ndarray: + shell_tilt: str = "intergaps", + shot_tilt: str = "uniform", +) -> NDArray: """Initialize 3D trajectories with helical shells. The implementation follows the proposition from [YRB06]_ @@ -648,14 +653,14 @@ def initialize_3D_helical_shells( Number of concentric shells/spheres spiral_reduction : float, optional Factor used to reduce the automatic spiral length, by default 1 - shell_tilt : Literal, float, optional + shell_tilt : str, float, optional Angle between consecutive shells along z-axis, by default "intergaps" - shot_tilt : Literal, float, optional + shot_tilt : str, float, optional Angle between shots over a shell surface along z-axis, by default "uniform" Returns ------- - np.ndarray + NDArray 3D helical shell trajectory References @@ -719,7 +724,7 @@ def initialize_3D_annular_shells( nb_shells: int, shell_tilt: float = np.pi, ring_tilt: float = np.pi / 2, -) -> np.ndarray: +) -> NDArray: """Initialize 3D trajectories with annular shells. An exclusive trajectory inspired from the work proposed in [HM11]_. @@ -732,14 +737,14 @@ def initialize_3D_annular_shells( Number of samples per shot nb_shells : int Number of concentric shells/spheres - shell_tilt : Literal, float, optional + shell_tilt : str, float, optional Angle between consecutive shells along z-axis, by default pi - ring_tilt : Literal, float, optional + ring_tilt : str, float, optional Angle controlling approximately the ring halves rotation, by default pi / 2 Returns ------- - np.ndarray + NDArray 3D annular shell trajectory References @@ -836,9 +841,9 @@ def initialize_3D_seiffert_shells( nb_shells: int, curve_index: float = 0.5, nb_revolutions: float = 1, - shell_tilt: Literal = "uniform", - shot_tilt: Literal = "uniform", -) -> np.ndarray: + shell_tilt: str = "uniform", + shot_tilt: str = "uniform", +) -> NDArray: """Initialize 3D trajectories with Seiffert shells. The implementation is based on work from [Er00]_ and [Br09]_, @@ -858,14 +863,14 @@ def initialize_3D_seiffert_shells( nb_revolutions : float Number of revolutions, i.e. times the curve passes through the upper-half of the z-axis, by default 1 - shell_tilt : Literal, float, optional + shell_tilt : str, float, optional Angle between consecutive shells along z-axis, by default "uniform" - shot_tilt : Literal, float, optional + shot_tilt : str, float, optional Angle between shots over a shell surface along z-axis, by default "uniform" Returns ------- - np.ndarray + NDArray 3D Seiffert shell trajectory References @@ -933,11 +938,11 @@ def initialize_3D_turbine( Ns_readouts: int, Ns_transitions: int, nb_blades: int, - blade_tilt: Literal = "uniform", - nb_trains: int | str = "auto", + blade_tilt: str = "uniform", + nb_trains: int | Literal["auto"] = "auto", skip_factor: int = 1, in_out: bool = True, -) -> np.ndarray: +) -> NDArray: """Initialize 3D TURBINE trajectory. This is an implementation of the TURBINE (Trajectory Using Radially @@ -960,9 +965,9 @@ def initialize_3D_turbine( Number of samples per transition between two readouts nb_blades : int Number of line stacks over the :math:`k_z`-axis axis - blade_tilt : Literal, float, optional + blade_tilt : str, float, optional Tilt between individual blades, by default "uniform" - nb_trains : int, str, optional + nb_trains : int, Literal["auto"], optional Number of resulting shots, or readout trains, such that each of them will be composed of :math:`n` readouts with ``Nc = n * nb_trains``. If ``"auto"`` then ``nb_trains`` is set @@ -975,7 +980,7 @@ def initialize_3D_turbine( Returns ------- - np.ndarray + NDArray 3D TURBINE trajectory References @@ -1042,12 +1047,12 @@ def initialize_3D_repi( Ns_transitions: int, nb_blades: int, nb_blade_revolutions: float = 0, - blade_tilt: Literal = "uniform", + blade_tilt: str = "uniform", nb_spiral_revolutions: float = 0, - spiral: Literal = "archimedes", - nb_trains: int | str = "auto", + spiral: str = "archimedes", + nb_trains: int | Literal["auto"] = "auto", in_out: bool = True, -) -> np.ndarray: +) -> NDArray: """Initialize 3D REPI trajectory. This is an implementation of the REPI (Radial Echo Planar Imaging) @@ -1076,13 +1081,13 @@ def initialize_3D_repi( nb_blade_revolutions : float Number of revolutions over lines/spirals within a blade over the kz axis. - blade_tilt : Literal, float, optional + blade_tilt : str, float, optional Tilt between individual blades, by default "uniform" nb_spiral_revolutions : float, optional Number of revolutions of the spirals over the readouts, by default 0 - spiral : Literal, float, optional + spiral : str, float, optional Spiral type, by default "archimedes" - nb_trains : int, str + nb_trains : int, Literal["auto"], optional Number of trains dividing the readouts, such that each shot will be composed of `n` readouts with `Nc = n * nb_trains`. If "auto" then `nb_trains` is set to `nb_blades`. @@ -1091,7 +1096,7 @@ def initialize_3D_repi( Returns ------- - np.ndarray + NDArray 3D REPI trajectory References diff --git a/src/mrinufft/trajectories/utils.py b/src/mrinufft/trajectories/utils.py index 0bb4cce2c..c1124e116 100644 --- a/src/mrinufft/trajectories/utils.py +++ b/src/mrinufft/trajectories/utils.py @@ -2,9 +2,10 @@ from enum import Enum, EnumMeta from numbers import Real -from typing import Literal, Any +from typing import Any, Literal import numpy as np +from numpy.typing import NDArray ############# # CONSTANTS # @@ -122,7 +123,7 @@ class Tilts(str, Enum): class Packings(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Enumerate available packing method for shots. - It is mostly use for wave-CAIPI trajectory + It is mostly used for wave-CAIPI trajectory See Also -------- @@ -151,15 +152,15 @@ class Packings(str, Enum, metaclass=CaseInsensitiveEnumMeta): def normalize_trajectory( - trajectory: np.typing.NDArray, + trajectory: NDArray, norm_factor: float = KMAX, - resolution: float | np.ndarray = DEFAULT_RESOLUTION, -) -> np.ndarray: + resolution: float | NDArray = DEFAULT_RESOLUTION, +) -> NDArray: """Normalize an un-normalized/natural trajectory for NUFFT use. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Un-normalized trajectory consisting of k-space coordinates in 2D or 3D. norm_factor : float, optional Trajectory normalization factor, by default KMAX. @@ -170,22 +171,22 @@ def normalize_trajectory( Returns ------- - trajectory : np.ndarray + trajectory : NDArray Normalized trajectory corresponding to `trajectory` input. """ return trajectory * norm_factor * (2 * resolution) def unnormalize_trajectory( - trajectory: np.typing.NDArray, + trajectory: NDArray, norm_factor: float = KMAX, - resolution: float | np.ndarray = DEFAULT_RESOLUTION, -) -> np.ndarray: + resolution: float | NDArray = DEFAULT_RESOLUTION, +) -> NDArray: """Un-normalize a NUFFT-normalized trajectory. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Normalized trajectory consisting of k-space coordinates in 2D or 3D. norm_factor : float, optional Trajectory normalization factor, by default KMAX. @@ -196,25 +197,25 @@ def unnormalize_trajectory( Returns ------- - trajectory : np.ndarray + trajectory : NDArray Un-normalized trajectory corresponding to `trajectory` input. """ return trajectory / norm_factor / (2 * resolution) def convert_trajectory_to_gradients( - trajectory: np.typing.NDArray, + trajectory: NDArray, norm_factor: float = KMAX, - resolution: float | np.ndarray = DEFAULT_RESOLUTION, + resolution: float | NDArray = DEFAULT_RESOLUTION, raster_time: float = DEFAULT_RASTER_TIME, gamma: float = Gammas.HYDROGEN, get_final_positions: bool = False, -) -> tuple[np.ndarray, ...]: +) -> tuple[NDArray, ...]: """Derive a normalized trajectory over time to provide gradients. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Normalized trajectory consisting of k-space coordinates in 2D or 3D. norm_factor : float, optional Trajectory normalization factor, by default KMAX. @@ -235,7 +236,7 @@ def convert_trajectory_to_gradients( Returns ------- - gradients : np.ndarray + gradients : NDArray Gradients corresponding to `trajectory`. """ # Un-normalize the trajectory from NUFFT usage @@ -250,20 +251,20 @@ def convert_trajectory_to_gradients( def convert_gradients_to_trajectory( - gradients: np.typing.NDArray, - initial_positions: np.typing.NDArray | None = None, + gradients: NDArray, + initial_positions: NDArray | None = None, norm_factor: float = KMAX, - resolution: float | np.ndarray = DEFAULT_RESOLUTION, + resolution: float | NDArray = DEFAULT_RESOLUTION, raster_time: float = DEFAULT_RASTER_TIME, gamma: float = Gammas.HYDROGEN, -) -> np.ndarray: +) -> NDArray: """Integrate gradients over time to provide a normalized trajectory. Parameters ---------- - gradients : np.ndarray + gradients : NDArray Gradients over 2 or 3 directions. - initial_positions: np.ndarray, optional + initial_positions: NDArray, optional Positions in k-space at the beginning of the readout window. The default is `None`. norm_factor : float, optional @@ -282,7 +283,7 @@ def convert_gradients_to_trajectory( Returns ------- - trajectory : np.ndarray + trajectory : NDArray Normalized trajectory corresponding to `gradients`. """ # Handle no initial positions @@ -300,14 +301,14 @@ def convert_gradients_to_trajectory( def convert_gradients_to_slew_rates( - gradients: np.typing.NDArray, + gradients: NDArray, raster_time: float = DEFAULT_RASTER_TIME, -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[NDArray, NDArray]: """Derive the gradients over time to provide slew rates. Parameters ---------- - gradients : np.ndarray + gradients : NDArray Gradients over 2 or 3 directions. raster_time : float, optional Amount of time between the acquisition of two @@ -316,9 +317,9 @@ def convert_gradients_to_slew_rates( Returns ------- - slewrates : np.ndarray + slewrates : NDArray Slew rates corresponding to `gradients`. - initial_gradients : np.ndarray + initial_gradients : NDArray Gradients at the beginning of the readout window. """ # Compute slew rates and starting gradients @@ -328,17 +329,17 @@ def convert_gradients_to_slew_rates( def convert_slew_rates_to_gradients( - slewrates: np.typing.NDArray, - initial_gradients: np.typing.NDArray | None = None, + slewrates: NDArray, + initial_gradients: NDArray | None = None, raster_time: float = DEFAULT_RASTER_TIME, -) -> np.ndarray: +) -> NDArray: """Integrate slew rates over time to provide gradients. Parameters ---------- - slewrates : np.ndarray + slewrates : NDArray Slew rates over 2 or 3 directions. - initial_gradients: np.ndarray, optional + initial_gradients: NDArray, optional Gradients at the beginning of the readout window. The default is `None`. raster_time : float, optional @@ -348,7 +349,7 @@ def convert_slew_rates_to_gradients( Returns ------- - gradients : np.ndarray + gradients : NDArray Gradients corresponding to `slewrates`. """ # Handle no initial gradients @@ -363,17 +364,17 @@ def convert_slew_rates_to_gradients( def compute_gradients_and_slew_rates( - trajectory: np.typing.NDArray, + trajectory: NDArray, norm_factor: float = KMAX, - resolution: float | np.ndarray = DEFAULT_RESOLUTION, + resolution: float | NDArray = DEFAULT_RESOLUTION, raster_time: float = DEFAULT_RASTER_TIME, gamma: float = Gammas.HYDROGEN, -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[NDArray, NDArray]: """Compute the gradients and slew rates from a normalized trajectory. Parameters ---------- - trajectory : np.ndarray + trajectory : NDArray Normalized trajectory consisting of k-space coordinates in 2D or 3D. norm_factor : float, optional Trajectory normalization factor, by default KMAX. @@ -391,9 +392,9 @@ def compute_gradients_and_slew_rates( Returns ------- - gradients : np.ndarray + gradients : NDArray Gradients corresponding to `trajectory`. - slewrates : np.ndarray + slewrates : NDArray Slew rates corresponding to `trajectory` gradients. """ # Convert normalized trajectory to gradients @@ -412,8 +413,8 @@ def compute_gradients_and_slew_rates( def check_hardware_constraints( - gradients: np.typing.NDArray, - slewrates: np.typing.NDArray, + gradients: NDArray, + slewrates: NDArray, gmax: float = DEFAULT_GMAX, smax: float = DEFAULT_SMAX, order: int | str | None = None, @@ -422,9 +423,9 @@ def check_hardware_constraints( Parameters ---------- - gradients : np.ndarray + gradients : NDArray Gradients to check - slewrates: np.ndarray + slewrates: NDArray Slewrates to check gmax : float, optional Maximum gradient amplitude in T/m. The default is DEFAULT_GMAX. @@ -455,12 +456,12 @@ def check_hardware_constraints( ########### -def initialize_tilt(tilt: Literal | float, nb_partitions: int = 1) -> float: +def initialize_tilt(tilt: str | float | None, nb_partitions: int = 1) -> float: r"""Initialize the tilt angle. Parameters ---------- - tilt : Literal or float + tilt : str | float | None Tilt angle in rad or name of the tilt. nb_partitions : int, optional Number of partitions. The default is 1. @@ -498,12 +499,12 @@ def initialize_tilt(tilt: Literal | float, nb_partitions: int = 1) -> float: raise NotImplementedError(f"Unknown tilt name: {tilt}") -def initialize_algebraic_spiral(spiral: Literal | float) -> float: +def initialize_algebraic_spiral(spiral: str | float) -> float: """Initialize the algebraic spiral type. Parameters ---------- - spiral : Literal or float + spiral : str | float Spiral type or spiral power value. Returns @@ -512,16 +513,16 @@ def initialize_algebraic_spiral(spiral: Literal | float) -> float: Spiral power value. """ if isinstance(spiral, Real): - return spiral - return Spirals[spiral] + return float(spiral) + return Spirals[str(spiral)] -def initialize_shape_norm(shape: Literal | float) -> float: +def initialize_shape_norm(shape: str | float) -> float: """Initialize the norm for a given shape. Parameters ---------- - shape : Literal or float + shape : str | float Shape name or p-norm value. Returns @@ -530,5 +531,5 @@ def initialize_shape_norm(shape: Literal | float) -> float: Shape p-norm value. """ if isinstance(shape, Real): - return shape - return NormShapes[shape] + return float(shape) + return NormShapes[str(shape)] From ff493e95e267dcf0638803168177df56e67cded6 Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Fri, 27 Dec 2024 17:24:54 +0100 Subject: [PATCH 6/6] Add type alias for clustering coordinates --- .../trajectories/inits/travelling_salesman.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/mrinufft/trajectories/inits/travelling_salesman.py b/src/mrinufft/trajectories/inits/travelling_salesman.py index 063527f14..9e3c35762 100644 --- a/src/mrinufft/trajectories/inits/travelling_salesman.py +++ b/src/mrinufft/trajectories/inits/travelling_salesman.py @@ -1,6 +1,6 @@ """Trajectories based on the Travelling Salesman Problem.""" -from typing import Any, Literal +from typing import Any, Literal, TypeAlias import numpy as np import numpy.linalg as nl @@ -12,6 +12,8 @@ from ..sampling import sample_from_density from ..tools import oversample +Coordinate: TypeAlias = Literal["x", "y", "z", "r", "phi", "theta"] + def _get_approx_cluster_sizes(nb_total: int, nb_clusters: int) -> NDArray: # Give a list of cluster sizes close to sqrt(`nb_total`) @@ -21,9 +23,7 @@ def _get_approx_cluster_sizes(nb_total: int, nb_clusters: int) -> NDArray: return cluster_sizes -def _sort_by_coordinate( - array: NDArray, coord: Literal["x", "y", "z", "r", "phi", "theta"] -) -> NDArray: +def _sort_by_coordinate(array: NDArray, coord: Coordinate) -> NDArray: # Sort a list of N-D locations by a Cartesian/spherical coordinate if array.shape[-1] < 3 and coord.lower() in ["z", "theta"]: raise ValueError( @@ -54,9 +54,9 @@ def _sort_by_coordinate( def _cluster_by_coordinate( locations: NDArray, nb_clusters: int, - cluster_by: Literal["x", "y", "z", "r", "phi", "theta"], - second_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, - sort_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + cluster_by: Coordinate, + second_cluster_by: Coordinate | None = None, + sort_by: Coordinate | None = None, ) -> NDArray: # Cluster approximately a list of N-D locations by Cartesian/spherical coordinates # Gather dimension variables @@ -99,9 +99,9 @@ def _initialize_ND_travelling_salesman( Nc: int, Ns: int, density: NDArray, - first_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, - second_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, - sort_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + first_cluster_by: Coordinate | None = None, + second_cluster_by: Coordinate | None = None, + sort_by: Coordinate | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False, @@ -146,9 +146,9 @@ def initialize_2D_travelling_salesman( Nc: int, Ns: int, density: NDArray, - first_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, - second_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, - sort_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + first_cluster_by: Coordinate | None = None, + second_cluster_by: Coordinate | None = None, + sort_by: Coordinate | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False, @@ -223,9 +223,9 @@ def initialize_3D_travelling_salesman( Nc: int, Ns: int, density: NDArray, - first_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, - second_cluster_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, - sort_by: Literal["x", "y", "z", "r", "phi", "theta"] | None = None, + first_cluster_by: Coordinate | None = None, + second_cluster_by: Coordinate | None = None, + sort_by: Coordinate | None = None, tsp_tol: float = 1e-8, *, verbose: bool = False,