Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update density compensation in nufft operator for rotated sampling pattern #8

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
10 changes: 9 additions & 1 deletion src/cli-conf/scenario2-2d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defaults:
- handlers:
- activation-block
- sampler:
- stack-of-spiral
- rotated-stack-of-spiral
- reconstructors:
- adjoint
#- sequential
Expand Down Expand Up @@ -35,6 +35,7 @@ handlers:
block_on: 20 # seconds
block_off: 20 #seconds
duration: 360 # seconds
delta_r2s: 1000 # millisecond^-1

sampler:
stack-of-spiral:
Expand All @@ -43,6 +44,13 @@ sampler:
nb_revolutions: 10
constant: true
spiral_name: "galilean"
rotated-stack-of-spiral:
acsz: 1
accelz: 1
nb_revolutions: 10
constant: false
spiral_name: "galilean"
rotate_frame_angle: 0

engine:
n_jobs: 1
Expand Down
2 changes: 2 additions & 0 deletions src/snake/core/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .samplers import (
EPI3dAcquisitionSampler,
StackOfSpiralSampler,
RotatedStackOfSpiralSampler,
NonCartesianAcquisitionSampler,
EVI3dAcquisitionSampler,
LoadTrajectorySampler,
Expand All @@ -15,5 +16,6 @@
"EPI3dAcquisitionSampler",
"EVI3dAcquisitionSampler",
"StackOfSpiralSampler",
"RotatedStackOfSpiralSampler",
"NonCartesianAcquisitionSampler",
]
44 changes: 44 additions & 0 deletions src/snake/core/sampling/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
stack_spiral_factory,
stacked_epi_factory,
evi_factory,
rotate_trajectory,
)
from snake.mrd_utils.utils import ACQ
from snake._meta import batched, EnvConfig
from mrinufft.io import read_trajectory
from collections.abc import Generator


class NonCartesianAcquisitionSampler(BaseSampler):
Expand Down Expand Up @@ -276,6 +278,48 @@ def _single_frame(self, sim_conf: SimConfig) -> NDArray:
)


class RotatedStackOfSpiralSampler(StackOfSpiralSampler):
"""
Spiral 2D Acquisition Handler to generate k-space data.

Parameters
----------
rotate_frame_angle: AngleRotation | int
Angle of rotation of the frame.
frame_index: int
Index of the frame.
**kwargs:
Extra arguments (smaps, n_jobs, backend etc...)
"""

__sampler_name__ = "rotated-stack-of-spiral"
rotate_frame_angle: AngleRotation | int = 0
frame_index: int = 0

def fix_angle_rotation(
self, frame: Generator[np.ndarray, None, None], angle: AngleRotation | float = 0
) -> Generator[np.ndarray, None, None]:
"""Rotate the trajectory by a given angle."""
for traj in frame:
yield from rotate_trajectory((x for x in [traj]), angle)

def get_next_frame(self, sim_conf: SimConfig) -> NDArray:
"""Generate the next rotated frame."""
base_frame = self._single_frame(sim_conf)
if self.constant or self.rotate_frame_angle == 0:
return base_frame
else:
self.frame_index += 1
rotate_frame_angle = np.pi * (self.rotate_frame_angle / 180)
base_frame_gen = (traj[None, ...] for traj in base_frame)
rotated_frame = self.fix_angle_rotation(
base_frame_gen, float(rotate_frame_angle * self.frame_index)
)
return np.concatenate(
[traj.astype(np.float32) for traj in rotated_frame], axis=0
)


class EPI3dAcquisitionSampler(BaseSampler):
"""Sampling pattern for EPI-3D."""

Expand Down
2 changes: 1 addition & 1 deletion src/snake/toolkit/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,5 @@ def cleanup_cuda() -> None:
def make_hydra_cli(fun: callable) -> callable:
"""Create a Hydra CLI for the function."""
return hydra.main(
version_base=None, config_path="../../../cli-conf", config_name="config"
version_base=None, config_path="../../../cli-conf", config_name="scenario2-2d"
)(fun)
15 changes: 12 additions & 3 deletions src/snake/toolkit/reconstructors/pysap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
from numpy.typing import NDArray
from mrinufft.density import get_density

# Local imports
from snake.mrd_utils import (
Expand Down Expand Up @@ -155,9 +156,15 @@ def _reconstruct_nufft(
kwargs["density"] = None
else:
kwargs["density"] = self.density_compensation
method = self.density_compensation
if "stacked" in self.nufft_backend:
kwargs["z_index"] = "auto"

if isinstance(method, str):
method = get_density(method)
if not callable(method):
raise ValueError(f"Unknown density method: {method}")

nufft_operator = get_operator(
self.nufft_backend,
samples=traj,
Expand All @@ -171,9 +178,11 @@ def _reconstruct_nufft(
for i in tqdm(range(data_loader.n_frames)):
traj, data = data_loader.get_kspace_frame(i)
if data_loader.slice_2d:
nufft_operator.samples = traj.reshape(
data_loader.n_shots, -1, traj.shape[-1]
)[0, :, :2]
traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1])
nufft_operator.samples = traj[0, :, :2]
nufft_operator.density = method(
traj[:, :, :2], shape, backend=self.nufft_backend
)
data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1))
for j in range(data.shape[1]):
final_images[i, :, :, j] = abs(nufft_operator.adj_op(data[:, j]))
Expand Down
Loading