From e257a13b940db9c1405707eed2bfd8d8ed4e8acc Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 5 Jan 2024 13:18:54 +0100 Subject: [PATCH] docs(density): add example for density. --- examples/example_density.py | 148 +++++++++++++++++++++++++ src/mrinufft/density/geometry_based.py | 8 +- 2 files changed, 152 insertions(+), 4 deletions(-) create mode 100644 examples/example_density.py diff --git a/examples/example_density.py b/examples/example_density.py new file mode 100644 index 000000000..1ca774c22 --- /dev/null +++ b/examples/example_density.py @@ -0,0 +1,148 @@ +# %% +""" +============================= +Density Compensation Routines +============================= + +Examples of differents density compensation methods. + +Density compensation depends on the sampling trajectory,and is apply before the +adjoint operation to act as preconditioner, and should make the lipschitz constant +of the operator roughly equal to 1. + +""" +import brainweb_dl as bwdl + +import matplotlib.pyplot as plt +import numpy as np +from mrinufft import get_density, get_operator +from mrinufft.trajectories import initialize_2D_radial +from mrinufft.trajectories.display import display_2D_trajectory + + +# %% +# Create sample data +# ------------------ + +mri_2D = bwdl.get_mri(4, "T1")[80, ...] + +print(mri_2D.shape) + +traj = initialize_2D_radial(192, 192) + +nufft = get_operator("finufft")(traj, mri_2D.shape, density=False) +kspace = nufft.op(mri_2D) +adjoint = nufft.adj_op(kspace) + +fig, axs = plt.subplots(1, 3, figsize=(15, 5)) +axs[0].imshow(abs(mri_2D)) +display_2D_trajectory(traj, subfigure=axs[1]) +axs[2].imshow(abs(adjoint)) + +# %% +# As you can see, the radial sampling pattern as a strong concentration of sampling point in the center, resulting in a low-frequency biased adjoint reconstruction. + +# %% +# Geometry based methods +# ====================== +# +# Voronoi +# ------- +# +# Voronoi Parcellation attribute a weights to each k-space coordinate, inversely +# proportional to its voronoi cell area. + + +# .. warning:: +# The current implementation of voronoi parcellation is CPU only, and is thus +# **very** slow in 3D ( > 1h). + +# %% +voronoi_weights = get_density("voronoi", traj) + +nufft_voronoi = get_operator("finufft")(traj, shape=mri_2D.shape, density=voronoi_weights) +adjoint_voronoi = nufft_voronoi.adj_op(kspace) +fig, axs = plt.subplots(1, 3, figsize=(15, 5)) +axs[0].imshow(abs(mri_2D)) +axs[0].set_title("Ground Truth") +axs[1].imshow(abs(adjoint)) +axs[1].set_title("no density compensation") +axs[2].imshow(abs(adjoint_voronoi)) +axs[2].set_title("Voronoi density compensation") + + +# %% +# Cell Counting +# ------------- +# +# Cell Counting attributes weights based on the number of trajectory point lying in a same k-space nyquist voxel. +# This can be viewed as an approximation to the voronoi neth + +# .. note:: +# Cell counting is faster than voronoi (especially in 3D), but is less precise. + +# The size of the niquist voxel can be tweak by using the osf parameter. Typically as the NUFFT (and by default in MRI-NUFFT) is performed at an OSF of 2 + + +# %% +cell_count_weights = get_density("cell_count", traj, shape=mri_2D.shape, osf=2.0) + +nufft_cell_count = get_operator("finufft")(traj, shape=mri_2D.shape, density=cell_count_weights, upsampfac=2.0) +adjoint_cell_count = nufft_cell_count.adj_op(kspace) +fig, axs = plt.subplots(1, 3, figsize=(15, 5)) +axs[0].imshow(abs(mri_2D)) +axs[0].set_title("Ground Truth") +axs[1].imshow(abs(adjoint)) +axs[1].set_title("no density compensation") +axs[2].imshow(abs(adjoint_cell_count)) +axs[2].set_title("cell_count density compensation") + +# %% +# Manual Density Estimation +# ------------------------- +# +# For some analytical trajectory it is also possible to determine the density compensation vector directly. +# In radial trajectory for instance, a sample's weight can be determined from its distance to the center. + + +# %% +flat_traj = traj.reshape(-1, 2) +weights = np.sqrt(np.sum(flat_traj ** 2, axis=1)) +nufft = get_operator("finufft")(traj, shape=mri_2D.shape, density=weights) +adjoint_manual = nufft.adj_op(kspace) +fig, axs = plt.subplots(1, 3, figsize=(15, 5)) +axs[0].imshow(abs(mri_2D)) +axs[0].set_title("Ground Truth") +axs[1].imshow(abs(adjoint)) +axs[1].set_title("no density compensation") +axs[2].imshow(abs(adjoint_manual)) +axs[2].set_title("manual density compensation") + +# %% +# Operator-based method +# ===================== +# +# Pipe's Method +# ------------- +# Pipe's method is an iterative scheme, that use the interpolation and spreading kernel operator for computing the density compensation. +# +# .. warning:: +# If this method is widely used in the literature, there exists no convergence guarantees for it. + +# .. note:: +# The Pipe method is currently only implemented for gpuNUFFT. + +# %% +if check_backend("gpunufft"): + flat_traj = traj.reshape(-1, 2) + nufft = get_operator("gpunufft")(traj, shape=mri_2D.shape, density="pipe") + adjoint_manual = nufft.adj_op(kspace) + fig, axs = plt.subplots(1, 3, figsize=(15, 5)) + axs[0].imshow(abs(mri_2D)) + axs[0].set_title("Ground Truth") + axs[1].imshow(abs(adjoint)) + axs[1].set_title("no density compensation") + axs[2].imshow(abs(adjoint_manual)) + axs[2].set_title("manual density compensation") + +# %% diff --git a/src/mrinufft/density/geometry_based.py b/src/mrinufft/density/geometry_based.py index 0ffcfe67e..bbc13af01 100644 --- a/src/mrinufft/density/geometry_based.py +++ b/src/mrinufft/density/geometry_based.py @@ -46,8 +46,8 @@ def _vol2d(points): return abs(area) / 2.0 -@flat_traj @register_density +@flat_traj def voronoi_unique(traj, *args, **kwargs): """Estimate density compensation weight using voronoi parcellation. @@ -87,8 +87,8 @@ def voronoi_unique(traj, *args, **kwargs): return wi -@flat_traj @register_density +@flat_traj def voronoi(traj, *args, **kwargs): """Estimate density compensation weight using voronoi parcellation. @@ -118,11 +118,11 @@ def voronoi(traj, *args, **kwargs): wi[i0] = wi[i0f] / np.sum(i0) else: wi = voronoi_unique(traj) - return normalize_weights(wi) + return 1 / normalize_weights(wi) -@flat_traj @register_density +@flat_traj def cell_count(traj, shape, osf=1.0): """ Compute the number of points in each cell of the grid.