Skip to content

Commit

Permalink
fixup! Add densities, random walk and travelling salesman trajectories
Browse files Browse the repository at this point in the history
  • Loading branch information
Daval-G committed Nov 7, 2024
1 parent 1665906 commit 55b3ba4
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 19 deletions.
16 changes: 8 additions & 8 deletions src/mrinufft/trajectories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
"""Collection of trajectories and tools used for non-Cartesian MRI."""

from .densities import (
create_chauffert_density,
create_cutoff_decay_density,
create_energy_density,
create_fast_chauffert_density,
create_polynomial_density,
sample_from_density,
)
from .display import display_2D_trajectory, display_3D_trajectory, displayConfig
from .gradients import patch_center_anomaly
from .inits import (
Expand All @@ -16,6 +8,14 @@
initialize_3D_random_walk,
initialize_3D_travelling_salesman,
)
from .sampling_densities import (
create_chauffert_density,
create_cutoff_decay_density,
create_energy_density,
create_fast_chauffert_density,
create_polynomial_density,
sample_from_density,
)
from .tools import (
conify,
duplicate_along_axes,
Expand Down
23 changes: 16 additions & 7 deletions src/mrinufft/trajectories/inits/travelling_salesman.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Trajectories based on the Travelig Salesman Problem."""
"""Trajectories based on the Travelling Salesman Problem."""

import numpy as np
import numpy.linalg as nl
from scipy.interpolate import CubicSpline
from tqdm.auto import tqdm

from ..densities import sample_from_density
from ..maths import solve_tsp_with_2opt
from ..sampling_densities import sample_from_density
from ..tools import oversample


def _get_approx_cluster_sizes(nb_total, nb_clusters):
# Give a list of cluster sizes close to sqrt(`nb_total`)
cluster_sizes = round(nb_total / nb_clusters) * np.ones(nb_clusters).astype(int)
delta_sum = nb_total - np.sum(cluster_sizes)
cluster_sizes[: int(np.abs(delta_sum))] += np.sign(delta_sum)
Expand Down Expand Up @@ -92,6 +94,8 @@ def _initialize_ND_travelling_salesman(
nb_tsp_points="auto",
sampling="random",
tsp_tol=1e-8,
*,
mask_density=False,
verbose=False,
):
# Handle variable inputs
Expand All @@ -105,7 +109,8 @@ def _initialize_ND_travelling_salesman(
Nd = len(density.shape)

# Select k-space locations
density = density / np.sum(density)
if mask_density:
density = density ** (Nd / (Nd - 1))
locations = sample_from_density(Nc * nb_tsp_points, density, method=sampling)

# Re-organise locations into Nc clusters
Expand Down Expand Up @@ -133,10 +138,8 @@ def _initialize_ND_travelling_salesman(
locations = locations.reshape((Nc, nb_tsp_points, Nd))

# Interpolate shot points up to full length
trajectory = np.zeros((Nc, Ns, Nd))
for i in range(Nc):
cbs = CubicSpline(np.linspace(0, 1, nb_tsp_points), locations[i])
trajectory[i] = cbs(np.linspace(0, 1, Ns))
trajectory = oversample(locations, Ns)

return trajectory


Expand All @@ -150,6 +153,8 @@ def initialize_2D_travelling_salesman(
nb_tsp_points="auto",
sampling="random",
tsp_tol=1e-8,
*,
mask_density=False,
verbose=False,
):
if len(density.shape) != 2:
Expand All @@ -164,6 +169,7 @@ def initialize_2D_travelling_salesman(
nb_tsp_points=nb_tsp_points,
sampling=sampling,
tsp_tol=tsp_tol,
mask_density=mask_density,
verbose=verbose,
)

Expand All @@ -178,6 +184,8 @@ def initialize_3D_travelling_salesman(
nb_tsp_points="auto",
sampling="random",
tsp_tol=1e-8,
*,
mask_density=False,
verbose=False,
):
if len(density.shape) != 3:
Expand All @@ -192,5 +200,6 @@ def initialize_3D_travelling_salesman(
nb_tsp_points=nb_tsp_points,
sampling=sampling,
tsp_tol=tsp_tol,
mask_density=mask_density,
verbose=verbose,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
def sample_from_density(nb_samples, density, method="random"):
rng = nr.default_rng()

shape = density.shape
density = density / np.sum(density)
shape = np.array(density.shape)
nb_dims = len(shape)
max_nb_samples = np.prod(shape)

Expand All @@ -29,8 +30,9 @@ def sample_from_density(nb_samples, density, method="random"):
replace=False,
)
locations = np.indices(shape).reshape((nb_dims, -1))[:, choices]
locations = locations.T
locations = 2 * KMAX * locations / np.max(locations) - KMAX
locations = locations.T + 0.5
locations = locations / shape[None, :]
locations = 2 * KMAX * locations - KMAX
elif method == "lloyd":
kmeans = (
KMeans(n_clusters=nb_samples)
Expand Down
7 changes: 6 additions & 1 deletion src/mrinufft/trajectories/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Functions to manipulate/modify trajectories."""

import numpy as np
from scipy.interpolate import CubicSpline
from scipy.interpolate import CubicSpline, interp1d

from .maths import Rv, Rx, Ry, Rz
from .utils import KMAX, initialize_tilt
Expand Down Expand Up @@ -429,6 +429,11 @@ def rewind(trajectory, Ns_transitions):
return assembled_trajectory


def oversample(trajectory, new_Ns, kind="cubic"):
f = interp1d(np.linspace(0, 1, trajectory.shape[1]), trajectory, axis=1, kind=kind)
return f(np.linspace(0, 1, new_Ns))


####################
# FUNCTIONAL TOOLS #
####################
Expand Down

0 comments on commit 55b3ba4

Please sign in to comment.