Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TOPUP fieldmap correction workflow #202

Merged
merged 8 commits into from
Jan 17, 2025
Merged
229 changes: 140 additions & 89 deletions PUMI/pipelines/func/deconfound.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
from pathlib import Path
from nipype import Function
from nipype.algorithms import confounds
from nipype.interfaces import afni, fsl, utility
Expand All @@ -10,167 +8,220 @@
from PUMI.plot.carpet_plot import plot_carpet


@QcPipeline(inputspec_fields=['func_1', 'func_2', 'func_corrected'],
@QcPipeline(inputspec_fields=['main', 'fmap', 'func_corrected'],
outputspec_fields=['out_file'])
def fieldmap_correction_qc(wf, volume='middle', **kwargs):
def qc_fieldmap_correction_topup(wf, volume='first', **kwargs):
"""

Quality check image generation for fieldmap correction pipeline.
Generate quality control image for the fieldmap correction consisting of a montage image
comparing a main volume, a fieldmap volume and a volume of the corrected fieldmap.

Parameters:
volume (str): The volume of the functional data to be used for comparison (e.g., 'middle').
Default is 'first'.

Inputs:
func_1 (str): Path to functional image (e.g. LR phase encoded rsfMRI).
func_2 (str): Path to functional image with another phase encoding than func_1 (e.g. RL phase encoded rsfMRI).
func_corrected (str): Path to fieldmap corrected functional image.
main (str): Path to the main sequence functional image (e.g., functional data).
fmap (str): Path to the fieldmap image (e.g., uncorrected fieldmap data).
func_corrected (str): Path to the fieldmap-corrected functional image.

Outputs:
out_file (str): Path to quality check image.
out_file (str): Path to the saved QC montage image comparing the original and corrected images.

Sinking:
- Quality check image.
- Path to QC comparison image (PNG file showing the original and corrected volumes).

"""

def get_cut_cords(func, n_slices=10):
import nibabel as nib
import numpy as np

func_img = nib.load(func)
y_dim = func_img.shape[1] # y-dimension (coronal direction) is the second dimension in the image shape

slices = np.linspace(-y_dim / 2, y_dim / 2, n_slices)
# slices might contain floats but this is not a problem since nilearn will round floats to the
# nearest integer value!
return slices

def create_montage(vol_1, vol_2, vol_corrected, n_slices=10):
def create_montage(vol_main, vol_fmap, vol_corrected, n_slices=3):
from matplotlib import pyplot as plt
from pathlib import Path
from nilearn import plotting
import os

fig, axes = plt.subplots(3, 1, facecolor='black', figsize=(10, 15))
def get_cut_cords(func, n_slices=3):
import nibabel as nib
import numpy as np

func_img = nib.load(func)
y_dim = func_img.shape[1] # y-dimension (coronal direction) is the second dimension in the image shape

plotting.plot_anat(vol_1, display_mode='y', cut_coords=get_cut_cords(vol_1, n_slices=n_slices),
slices = np.linspace(-y_dim / 2, y_dim / 2, n_slices)
# slices might contain floats but this is not a problem since nilearn will round floats to the
# nearest integer value!
return slices

fig, axes = plt.subplots(3, 1, facecolor='black', figsize=(12, 18))
plt.subplots_adjust(hspace=0.4)
plotting.plot_anat(vol_main, display_mode='y', cut_coords=get_cut_cords(vol_main, n_slices=n_slices),
title='Image #1', black_bg=True, axes=axes[0])
plotting.plot_anat(vol_2, display_mode='y', cut_coords=get_cut_cords(vol_2, n_slices=n_slices),
plotting.plot_anat(vol_fmap, display_mode='y', cut_coords=get_cut_cords(vol_fmap, n_slices=n_slices),
title='Image #2', black_bg=True, axes=axes[1])
plotting.plot_anat(vol_corrected, display_mode='y', cut_coords=get_cut_cords(vol_corrected, n_slices=n_slices),
title='Corrected', black_bg=True, axes=axes[2])

path = str(Path(os.getcwd() + '/fieldmap_correction_comparison.png'))
plt.savefig(path)
#path = Path.cwd() / 'fieldmap_correction_comparison.png'
path = os.path.join(os.getcwd(), 'fieldmap_correction_comparison.png')
plt.savefig(path, dpi=300)
plt.close(fig)
return path

vol_1 = pick_volume('vol_1', volume=volume)
wf.connect('inputspec', 'func_1', vol_1, 'in_file')
vol_main = pick_volume('vol_main', volume=volume)
wf.connect('inputspec', 'main', vol_main, 'in_file')

vol_2 = pick_volume('vol_2', volume=volume)
wf.connect('inputspec', 'func_2', vol_2, 'in_file')
vol_fmap = pick_volume('vol_fmap', volume=volume)
wf.connect('inputspec', 'fmap', vol_fmap, 'in_file')

vol_corrected = pick_volume('vol_corrected', volume=volume)
wf.connect('inputspec', 'func_corrected', vol_corrected, 'in_file')

montage = Node(Function(
input_names=['vol_1', 'vol_2', 'vol_corrected'],
input_names=['vol_main', 'vol_fmap', 'vol_corrected'],
output_names=['out_file'],
function=create_montage),
name='montage_node'
)
wf.connect(vol_1, 'out_file', montage, 'vol_1')
wf.connect(vol_2, 'out_file', montage, 'vol_2')
wf.connect(vol_main, 'out_file', montage, 'vol_main')
wf.connect(vol_fmap, 'out_file', montage, 'vol_fmap')
wf.connect(vol_corrected, 'out_file', montage, 'vol_corrected')

wf.connect(montage, 'out_file', 'outputspec', 'out_file')
wf.connect(montage, 'out_file', 'sinker', 'qc_fieldmap_correction')


@FuncPipeline(inputspec_fields=['func_1', 'func_2'],
@FuncPipeline(inputspec_fields=['main', 'main_json', 'fmap', 'fmap_json'],
outputspec_fields=['out_file'])
def fieldmap_correction(wf, encoding_direction=['x-', 'x'], trt=[0.0522, 0.0522], tr=0.72, **kwargs):
def fieldmap_correction_topup(wf, num_volumes=5, **kwargs):
"""

Fieldmap correction pipeline.
Perform fieldmap correction on the functional data using FSL's TOPUP.

Parameters:
encoding_direction (list): List of encoding directions (default is left-right and right-left phase encoding).
trt (list): List of total readout times (default adapted to rsfMRI data of the HCP WU 1200 dataset).
Default is:
1*(10**(-3))*EchoSpacingMS*EpiFactor = 1*(10**(-3))*0.58*90 = 0.0522 (for LR and RL image)
tr (float): Repetition time (default adapted to rsfMRI data of the HCP WU 1200 dataset).
num_volumes (int): Number of volumes to extract from the main functional sequence for averaging.
Default is 5.

Inputs:
func_1 (str): Path to functional image (e.g. LR phase encoded rsfMRI).
func_2 (str): Path to functional image with another phase encoding than func_1 (e.g. RL phase encoded rsfMRI).
main (str): Path to the main functional image (e.g., 4D functional MRI data).
main_json (str): Path to the JSON metadata for the main sequence.
fmap (str): Path to the fieldmap image (e.g., 4D fieldmap data).
fmap_json (str): Path to the JSON metadata for the fieldmap sequence.

Outputs:
out_file (str): 4d distortion corrected image.
out_file (str): Path to the corrected 4D functional image after fieldmap correction.

Sinking:
- 4d distortion corrected image.
- Corrected functional sequence after fieldmap correction.
- QC results for fieldmap correction.

"""

# Extract how many volumes from the main sequence we are told to extract
num_volumes = int(wf.cfg_parser.get('FIELDMAP-CORRECTION', 'num_volumes', fallback=num_volumes))

For more information:
https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/topup/ExampleTopupFollowedByApplytopup
https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/topup/Faq#How_do_I_know_what_phase-encode_vectors_to_put_into_my_--datain_text_file.3F
https://www.humanconnectome.org/storage/app/media/documentation/s1200/HCP_S1200_Release_Appendix_I.pdf
# Extract the first num_volumes volumes from main sequence
extract_main_volumes = Node(fsl.ExtractROI(t_min=0, t_size=num_volumes), name='extract_main_volumes')
wf.connect('inputspec', 'main', extract_main_volumes, 'in_file')

"""
# Compute the mean of extracted main volumes
mean_main = Node(fsl.MeanImage(), name='mean_main')
wf.connect(extract_main_volumes, 'roi_file', mean_main, 'in_file')

# Average all fieldmap volumes
mean_fmap = Node(fsl.MeanImage(), name='mean_fmap')
wf.connect('inputspec', 'fmap', mean_fmap, 'in_file')

# Retrieve encoding direction, total readout time and repetition time
def retrieve_image_params_function(main_json, fmap_json):
import json

with open(main_json, 'r') as f:
main_metadata = json.load(f)

with open(fmap_json, 'r') as f:
fmap_metadata = json.load(f)

for key in ['PhaseEncodingDirection', 'TotalReadoutTime', 'RepetitionTime']:
main_value = main_metadata.get(key, None)
fmap_value = fmap_metadata.get(key, None)

if main_value is None:
raise ValueError(f'JSON of main sequence is missing the key {key}!')

items_to_list_function = lambda item_1, item_2: [item_1, item_2] # helper function we will need later
if fmap_value is None:
raise ValueError(f'JSON of fieldmap sequence is missing the key {key}!')

# We use the first volume of func_1 and the first volume of the func_2 4D-image for the estimation of the field.
first_func1_vol = pick_volume('first_func1_vol', volume='first')
wf.connect('inputspec', 'func_1', first_func1_vol, 'in_file')
main_encoding_direction = main_metadata.get('PhaseEncodingDirection')
main_total_readout_time = main_metadata.get('TotalReadoutTime')
main_repetition_time = main_metadata.get('RepetitionTime')

first_func2_vol = pick_volume('first_func2_vol', volume='first')
wf.connect('inputspec', 'func_2', first_func2_vol, 'in_file')
fmap_encoding_direction = fmap_metadata.get('PhaseEncodingDirection')
fmap_total_readout_time = fmap_metadata.get('TotalReadoutTime')
fmap_repetition_time = fmap_metadata.get('RepetitionTime')

# We need to combine the two 3D images we extracted into one 4D image
# fsl.Merge expects a list as input, so we need to combine our two 3D images first into a list
first_volumes_to_list = Node(Function(
if main_encoding_direction == fmap_encoding_direction:
raise ValueError(f'Encoding direction of main sequence and fieldmap sequence are not allowed to be the same, but found {main_encoding_direction} and {fmap_encoding_direction}!')

if main_total_readout_time != fmap_total_readout_time:
raise ValueError(f'TRT of main sequence IS NOT EQUAL to fieldmap TRT ({main_total_readout_time}) != {fmap_total_readout_time})')

if main_repetition_time != fmap_repetition_time:
raise ValueError(f'TR of main sequence IS NOT EQUAL to fieldmap TR ({main_repetition_time}) != {fmap_repetition_time})')

# In case we have Siemens j-notation instead of y, replace j by y
main_encoding_direction = main_encoding_direction.replace('j', 'y')
fmap_encoding_direction = fmap_encoding_direction.replace('j', 'y')

encoding_direction = [main_encoding_direction, fmap_encoding_direction]
total_readout_time = main_total_readout_time
repetition_time = main_repetition_time

return encoding_direction, total_readout_time, repetition_time

retrieve_image_params = Node(
utility.Function(
input_names=['main_json', 'fmap_json'],
output_names=['encoding_direction', 'total_readout_time', 'repetition_time'],
function=retrieve_image_params_function
),
name='retrieve_image_params'
)
wf.connect('inputspec', 'main_json', retrieve_image_params, 'main_json')
wf.connect('inputspec', 'fmap_json', retrieve_image_params, 'fmap_json')

def combine_items_to_list(item_1, item_2):
return [item_1, item_2]

avg_volumes_to_list = Node(Function(
input_names=['item_1', 'item_2'],
output_names=['output'],
function=items_to_list_function),
name='first_volumes_to_list'
function=combine_items_to_list),
name='avg_volumes_to_list'
)
wf.connect(first_func1_vol, 'out_file', first_volumes_to_list, 'item_1')
wf.connect(first_func2_vol, 'out_file', first_volumes_to_list, 'item_2')
wf.connect(mean_main, 'out_file', avg_volumes_to_list, 'item_1')
wf.connect(mean_fmap, 'out_file', avg_volumes_to_list, 'item_2')

# Now combine 3D images to 4D image along the time axis
merger = Node(fsl.Merge(), name='merger')
merger.inputs.dimension = 't'
merger.inputs.output_type = 'NIFTI_GZ'
merger.inputs.tr = tr
wf.connect(first_volumes_to_list, 'output', merger, 'in_files')
# Combine averaged main and averaged fieldmap into a 4D image
merge_avg_images = Node(fsl.Merge(dimension='t'), name='merge_avg_images')
wf.connect(avg_volumes_to_list, 'output', merge_avg_images, 'in_files')
wf.connect(retrieve_image_params, 'repetition_time', merge_avg_images, 'tr')

# Estimate susceptibility induced distortions
topup = Node(fsl.TOPUP(), name='topup')
topup.inputs.encoding_direction = encoding_direction
topup.inputs.readout_times = trt
wf.connect(merger, 'merged_file', topup, 'in_file')

# The two original 4D files are also needed inside a list
func_files_to_list = Node(Function(
input_names=['item_1', 'item_2'],
output_names=['output'],
function=items_to_list_function),
name='func_files_to_list'
)
wf.connect('inputspec', 'func_1', func_files_to_list, 'item_1')
wf.connect('inputspec', 'func_2', func_files_to_list, 'item_2')
wf.connect(merge_avg_images, 'merged_file', topup, 'in_file')
wf.connect(retrieve_image_params, 'total_readout_time', topup, 'readout_times')
wf.connect(retrieve_image_params, 'encoding_direction', topup, 'encoding_direction')

# Apply result of fsl.TOPUP to our original data
# Result will be one 4D distortion corrected image
apply_topup = Node(fsl.ApplyTOPUP(), name='apply_topup')
wf.connect(func_files_to_list, 'output', apply_topup, 'in_files')
apply_topup = Node(fsl.ApplyTOPUP(method='jac'), name='apply_topup')
wf.connect('inputspec', 'main', apply_topup, 'in_files')
wf.connect(topup, 'out_fieldcoef', apply_topup, 'in_topup_fieldcoef')
wf.connect(topup, 'out_movpar', apply_topup, 'in_topup_movpar')
wf.connect(topup, 'out_enc_file', apply_topup, 'encoding_file')

qc_fieldmap_correction = fieldmap_correction_qc('qc_fieldmap_correction')
wf.connect('inputspec', 'func_1', qc_fieldmap_correction, 'func_1')
wf.connect('inputspec', 'func_2', qc_fieldmap_correction, 'func_2')
qc_fieldmap_correction = qc_fieldmap_correction_topup('qc_fieldmap_correction')
wf.connect('inputspec', 'main', qc_fieldmap_correction, 'main')
wf.connect('inputspec', 'fmap', qc_fieldmap_correction, 'fmap')
wf.connect(topup, 'out_corrected', qc_fieldmap_correction, 'func_corrected')

wf.connect(apply_topup, 'out_corrected', 'outputspec', 'out_file')
Expand Down
3 changes: 3 additions & 0 deletions PUMI/settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ save_mask = 1
# set overwrite_existing to '1' to overwrite existing predictions otherwise set to '0'
overwrite_existing = 1

[FIELDMAP-CORRECTION]
num_volumes = 5

[TEMPLATES]
head = data/standard/MNI152_T1_2mm.nii.gz
#also okay: head = tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz; source=templateflow
Expand Down
Loading
Loading