Skip to content

Commit

Permalink
Fixes updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaithyagr committed Apr 11, 2024
1 parent ced9962 commit cbbb79f
Showing 1 changed file with 28 additions and 58 deletions.
86 changes: 28 additions & 58 deletions src/mrinufft/extras/smaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,8 @@
register_smaps = MethodRegister("sensitivity_maps")


@flat_traj
def _get_centeral_index(kspace_loc, threshold):
r"""
Extract the index of the k-space center.
Parameters
----------
kspace_loc: numpy.ndarray
The samples location in the k-sapec domain (between [-0.5, 0.5[)
threshold: tuple or float
The threshold used to extract the k_space center (between (0, 1])
Returns
-------
The index of the k-space center.
"""
xp = get_array_module(kspace_loc)
radius = xp.linalg.norm(kspace_loc, axis=-1)

if isinstance(threshold, float):
threshold = (threshold,) * kspace_loc.shape[-1]
condition = xp.logical_and.reduce(tuple(
xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold))
))
index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64)
index = xp.extract(condition, index)
return index

def extract_k_space_center_and_locations(
kspace_data, kspace_loc, threshold=None, window_fun=None,
def extract_kspace_center(
kspace_data, kspace_loc, threshold=None, window_fun="ellipse",
):
r"""
Extract k-space center and corresponding sampling locations.
Expand All @@ -48,7 +20,8 @@ def extract_k_space_center_and_locations(
The samples location in the k-sapec domain (between [-0.5, 0.5[)
threshold: tuple or float
The threshold used to extract the k_space center (between (0, 1])
window_fun: "Hann", "Hanning", "Hamming", or a callable, default None.
window_fun: "hann" / "hanning", "hamming", "ellipse", "rect", or a callable,
default "ellipse".
The window function to apply to the selected data. It is computed with
the center locations selected. Only works with circular mask.
If window_fun is a callable, it takes as input the array (n_samples x n_dims)
Expand Down Expand Up @@ -78,50 +51,47 @@ def extract_k_space_center_and_locations(
"""
xp = get_array_module(kspace_data)
radius = xp.linalg.norm(center_locations, axis=1)
data_ordered = xp.copy(kspace_data)
if isinstance(threshold, float):
threshold = (threshold,) * kspace_loc.shape[1]
condition = xp.logical_and.reduce(tuple(
xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold))
))
index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64)
index = xp.extract(condition, index)
center_locations = kspace_loc[index, :]
data_thresholded = data_ordered[:, index]
if window_fun is not None:

if window_fun == "rect":
data_ordered = xp.copy(kspace_data)
index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64)
condition = xp.logical_and.reduce(tuple(
xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold))
))
index = xp.extract(condition, index)
center_locations = kspace_loc[index, :]
data_thresholded = data_ordered[:, index]
else:
if callable(window_fun):
window = window_fun(center_locations)
else:
if window_fun == "Hann" or window_fun == "Hanning":
a_0 = 0.5
elif window_fun == "Hamming":
a_0 = 0.53836
if window_fun in ["hann", "hanning", "hamming"]:
radius = xp.linalg.norm(kspace_loc, axis=1)
a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836
window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold)
elif window_fun == "ellipse":
window = xp.sum(kspace_loc**2/ xp.asarray(threshold)**2, axis=1) <= 1
else:
raise ValueError("Unsupported window function.")

window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold)
data_thresholded = window * data_thresholded

if density_comp is not None:
density_comp = density_comp[index]
return data_thresholded, center_locations, density_comp
else:
return data_thresholded, center_locations
# Return k-space locations just for consistency
return data_thresholded, kspace_loc


@register_smaps
@with_numpy_cupy
@flat_traj
def low_frequency(traj, kspace_data, shape, backend, theshold, *args, **kwargs):
def low_frequency(traj, kspace_data, shape, backend, threshold, *args, **kwargs):
xp = get_array_module(kspace_data)
k_space, samples, dc = extract_k_space_center_and_locations(
k_space, traj = extract_kspace_center(
kspace_data=kspace_data,
kspace_loc=traj,
threshold=threshold,
img_shape=traj_params['img_size'],
img_shape=shape,
**kwargs,
)
smaps_adj_op = get_operator('gpunufft')(
smaps_adj_op = get_operator(backend)(
samples,
shape,
density=dc,
Expand Down

0 comments on commit cbbb79f

Please sign in to comment.