Skip to content

Commit

Permalink
Sync and speed-up trajectory separation via slice fft
Browse files Browse the repository at this point in the history
  • Loading branch information
mcencini committed Jan 22, 2024
1 parent e4917d0 commit cf93f8d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
6 changes: 3 additions & 3 deletions src/deepmr/io/generic/mrd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def read_mrd(filepath, external=False):

# sort
data, traj, dcf, ordering = _sort_data(data, trajdcf, acquisitions, mrdhead)

# get constrats info
TI = mrd._get_inversion_times(mrdhead)
TE = mrd._get_echo_times(mrdhead)
Expand Down Expand Up @@ -196,12 +196,12 @@ def _sort_data(data, trajdcf, acquisitions, mrdhead):
if data is not None:
datatmp = np.zeros([ncoils] + list(shape), dtype=np.complex64)
_data_sorting(datatmp, data, icontrast, iz, iview)
data = np.ascontiguousarray(datatmp.squeeze())
data = datatmp

if trajdcf is not None:
trajdcftmp = np.zeros(list(shape) + [ndims], dtype=np.float32)
_trajdcf_sorting(trajdcftmp, trajdcf, icontrast, iz, iview)
trajdcf = np.ascontiguousarray(trajdcftmp.squeeze())
trajdcf = trajdcftmp
traj, dcf = trajdcf[..., :-1], trajdcf[..., -1]
else:
# actual sorting
Expand Down
36 changes: 21 additions & 15 deletions src/deepmr/io/kspace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def read_rawdata(filepath, acqheader=None, device="cpu", verbose=0):
data, head = _gehc.read_gehc_rawdata(filepath, acqheader)
done = True
except Exception:
raise
pass

# siemens
# if not(done):
Expand All @@ -76,11 +76,11 @@ def read_rawdata(filepath, acqheader=None, device="cpu", verbose=0):
raise RuntimeError(f"File (={filepath}) not recognized!")
if verbose == 2:
t1 = time.time()
print(f"done! Elapsed time: {round(t1-t0, 2)} s...")
print(f"done! Elapsed time: {round(t1-t0, 2)} s")

# transpose
data = data.transpose(2, 0, 1, 3, 4) # (slice, coil, contrast, view, sample)

# select actual readout
if verbose == 2:
nsamples = data.shape[-1]
Expand All @@ -89,7 +89,7 @@ def read_rawdata(filepath, acqheader=None, device="cpu", verbose=0):
data = _select_readout(data, head)
if verbose == 2:
t1 = time.time()
print(f"done! Selected {data.shape[-1]} out of {nsamples} samples. Elapsed time: {round(t1-t0, 2)} s...")
print(f"done! Selected {data.shape[-1]} out of {nsamples} samples. Elapsed time: {round(t1-t0, 2)} s")

# center fov
if verbose == 2:
Expand All @@ -98,14 +98,14 @@ def read_rawdata(filepath, acqheader=None, device="cpu", verbose=0):
ndim = head.traj.shape[-1]
shift = head._shift[:ndim]
if ndim == 2:
print(f"Shifting FoV by (dx={shift[0]}, dy={shift[1]}) mm...", end="\t")
print(f"Shifting FoV by (dx={shift[0]}, dy={shift[1]}) mm", end="\t")
if ndim == 3:
print(f"Shifting FoV by (dx={shift[0]}, dy={shift[1]}, dz={shift[2]}) mm...", end="\t")
print(f"Shifting FoV by (dx={shift[0]}, dy={shift[1]}, dz={shift[2]}) mm", end="\t")
data = _fov_centering(data, head)
if verbose == 2:
if head.traj is not None:
t1 = time.time()
print(f"done! Elapsed time: {round(t1-t0, 2)} s...")
print(f"done! Elapsed time: {round(t1-t0, 2)} s")

# remove oversampling for Cartesian
if "mode" in head.user:
Expand All @@ -114,10 +114,10 @@ def read_rawdata(filepath, acqheader=None, device="cpu", verbose=0):
t0 = time.time()
ns1 = data.shape[0]
ns2 = head.shape[0]
print(f"Removing oversample along readout ({round(ns1/ns2, 2)})...", end="\t")
print(f"Removing oversampling along readout ({round(ns1/ns2, 2)})...", end="\t")
data, head = _remove_oversampling(data, head)
t1 = time.time()
print(f"done! Elapsed time: {round(t1-t0, 2)} s...")
print(f"done! Elapsed time: {round(t1-t0, 2)} s")

# transpose readout in slice direction for 3D Cartesian
if "mode" in head.user:
Expand All @@ -132,31 +132,35 @@ def read_rawdata(filepath, acqheader=None, device="cpu", verbose=0):
data = _fft(data, 0)
if verbose == 2:
t1 = time.time()
print(f"done! Elapsed time: {round(t1-t0, 4)} s...")
print(f"done! Elapsed time: {round(t1-t0, 4)} s")

# set-up transposition
if "mode" in head.user:
if head.user["mode"] == "2Dcart":
head.transpose = [1, 0, 2, 3]
if verbose == 2:
print("Acquisition mode: 2D Cartesian")
print(f"K-space shape: (nslices={data.shape[0]}, nchannels={data.shape[1]}, ncontrasts={data.shape[2]}, ny={data.shape[3]}, nx={data.shape[4]})")
print(f"Expected image shape: (nslices={data.shape[0]}, nchannels={data.shape[1]}, ncontrasts={data.shape[2]}, ny={head.shape[1]}, nx={head.shape[2]})")
elif head.user["mode"] == "2Dnoncart":
head.transpose = [1, 0, 2, 3]
if verbose == 2:
print("Acquisition mode: 2D Non-Cartesian")
print(f"K-space shape: (nslices={data.shape[0]}, nchannels={data.shape[1]}, ncontrasts={data.shape[2]}, nviews={data.shape[3]}, nsamples={data.shape[4]})")
print(f"Expected image shape: (nslices={data.shape[0]}, nchannels={data.shape[0]}, ncontrasts={data.shape[1]}, ny={head.shape[1]}, nx={head.shape[2]})")
elif head.user["mode"] == "3Dnoncart":
data = data[0]
head.transpose = [1, 0, 2, 3]
if verbose == 2:
print("Acquisition mode: 3D Non-Cartesian")
print(f"K-space shape: (nchannels={data.shape[0]}, ncontrasts={data.shape[1]}, nviews={data.shape[2]}, nsamples={data.shape[3]})")
print(f"Expected image shape: (nchannels={data.shape[0]}, ncontrasts={data.shape[1]}, nz={head.shape[0]}, ny={head.shape[1]}, nx={head.shape[2]})")
elif head.user["mode"] == "3Dcart":
head.transpose = [1, 2, 3, 0]
if verbose == 2:
print("Acquisition mode: 3D Cartesian")
print(f"K-space shape: (nx={data.shape[0]}, nchannels={data.shape[1]}, ncontrasts={data.shape[2]}, nz={data.shape[3]}, ny={data.shape[4]})")
print(f"Expected image shape: (nx={data.shape[0]}, nchannels={data.shape[1]}, ncontrasts={data.shape[2]}, nz={head.shape[3]}, ny={head.shape[4]})")
print(f"Expected image shape: (nx={head.shape[2]}, nchannels={data.shape[1]}, ncontrasts={data.shape[2]}, nz={head.shape[0]}, ny={head.shape[1]})")

# remove unused trajectory for cartesian
if head.user["mode"][2:] == "cart":
Expand Down Expand Up @@ -205,9 +209,9 @@ def read_rawdata(filepath, acqheader=None, device="cpu", verbose=0):

tend = time.time()
if verbose == 1:
print(f"done! Elapsed time: {round(tend-tstart, 2)} s...")
print(f"done! Elapsed time: {round(tend-tstart, 2)} s")
elif verbose == 2:
print(f"Total elapsed time: {round(tend-tstart, 2)} s...")
print(f"Total elapsed time: {round(tend-tstart, 2)} s")

return data, head

Expand All @@ -222,7 +226,7 @@ def _select_readout(data, head):

def _fov_centering(data, head):

if head.traj is not None:
if head.traj is not None and np.allclose(head._shift, 0.0) is False:

# ndimensions
ndim = head.traj.shape[-1]
Expand Down Expand Up @@ -251,4 +255,6 @@ def _remove_oversampling(data, head):
return data, head

def _fft(data, axis):
return np.fft.fftshift(np.fft.fft(np.fft.fftshift(data, axes=axis), axis=axis), axes=axis)
tmp = torch.as_tensor(data)
tmp = torch.fft.fftshift(torch.fft.fft(torch.fft.fftshift(tmp, dim=axis), dim=axis), dim=axis)
return tmp.numpy()
2 changes: 1 addition & 1 deletion src/deepmr/io/kspace/mrd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def read_mrd_rawdata(filepath):
head : deepmr.Header
Metadata for image reconstruction.
"""
data, head = mrd.read_mrd(filepath, external=True)
data, head = mrd.read_mrd(filepath)

return data, head

Expand Down

0 comments on commit cf93f8d

Please sign in to comment.