diff --git a/PUMI/pipelines/func/deconfound.py b/PUMI/pipelines/func/deconfound.py index 9e8cb5f..dbf97af 100644 --- a/PUMI/pipelines/func/deconfound.py +++ b/PUMI/pipelines/func/deconfound.py @@ -9,6 +9,87 @@ from PUMI.plot.carpet_plot import plot_carpet +@QcPipeline(inputspec_fields=['main', 'fmap', 'func_corrected'], + outputspec_fields=['out_file']) +def qc_fieldmap_correction_topup(wf, volume='first', **kwargs): + """ + + Generate quality control image for the fieldmap correction consisting of a montage image + comparing a main volume, a fieldmap volume and a volume of the corrected fieldmap. + + Parameters: + volume (str): The volume of the functional data to be used for comparison (e.g., 'middle'). + Default is 'first'. + + Inputs: + main (str): Path to the main sequence functional image (e.g., functional data). + fmap (str): Path to the fieldmap image (e.g., uncorrected fieldmap data). + func_corrected (str): Path to the fieldmap-corrected functional image. + + Outputs: + out_file (str): Path to the saved QC montage image comparing the original and corrected images. + + Sinking: + - Path to QC comparison image (PNG file showing the original and corrected volumes). + + """ + + def create_montage(vol_main, vol_fmap, vol_corrected, n_slices=3): + from matplotlib import pyplot as plt + from pathlib import Path + from nilearn import plotting + import os + + def get_cut_cords(func, n_slices=3): + import nibabel as nib + import numpy as np + + func_img = nib.load(func) + y_dim = func_img.shape[1] # y-dimension (coronal direction) is the second dimension in the image shape + + slices = np.linspace(-y_dim / 2, y_dim / 2, n_slices) + # slices might contain floats but this is not a problem since nilearn will round floats to the + # nearest integer value! + return slices + + fig, axes = plt.subplots(3, 1, facecolor='black', figsize=(12, 18)) + plt.subplots_adjust(hspace=0.4) + plotting.plot_anat(vol_main, display_mode='y', cut_coords=get_cut_cords(vol_main, n_slices=n_slices), + title='Image #1', black_bg=True, axes=axes[0]) + plotting.plot_anat(vol_fmap, display_mode='y', cut_coords=get_cut_cords(vol_fmap, n_slices=n_slices), + title='Image #2', black_bg=True, axes=axes[1]) + plotting.plot_anat(vol_corrected, display_mode='y', cut_coords=get_cut_cords(vol_corrected, n_slices=n_slices), + title='Corrected', black_bg=True, axes=axes[2]) + + #path = Path.cwd() / 'fieldmap_correction_comparison.png' + path = os.path.join(os.getcwd(), 'fieldmap_correction_comparison.png') + plt.savefig(path, dpi=300) + plt.close(fig) + return path + + vol_main = pick_volume('vol_main', volume=volume) + wf.connect('inputspec', 'main', vol_main, 'in_file') + + vol_fmap = pick_volume('vol_fmap', volume=volume) + wf.connect('inputspec', 'fmap', vol_fmap, 'in_file') + + vol_corrected = pick_volume('vol_corrected', volume=volume) + wf.connect('inputspec', 'func_corrected', vol_corrected, 'in_file') + + montage = Node(Function( + input_names=['vol_main', 'vol_fmap', 'vol_corrected'], + output_names=['out_file'], + function=create_montage), + name='montage_node' + ) + wf.connect(vol_main, 'out_file', montage, 'vol_main') + wf.connect(vol_fmap, 'out_file', montage, 'vol_fmap') + wf.connect(vol_corrected, 'out_file', montage, 'vol_corrected') + + wf.connect(montage, 'out_file', 'outputspec', 'out_file') + wf.connect(montage, 'out_file', 'sinker', 'qc_fieldmap_correction') + + @QcPipeline(inputspec_fields=['background', 'overlay'], outputspec_fields=['out_file']) def qc_fieldmap_correction_fugue(wf, overlay_volume='middle', **kwargs): @@ -57,7 +138,146 @@ def create_fieldmap_plot(overlay, background): # output wf.connect(plot, 'out_file', 'outputspec', 'out_file') + +@FuncPipeline(inputspec_fields=['main', 'main_json', 'fmap', 'fmap_json'], + outputspec_fields=['out_file']) +def fieldmap_correction_topup(wf, num_volumes=5, **kwargs): + """ + + Perform fieldmap correction on the functional data using FSL's TOPUP. + + Parameters: + num_volumes (int): Number of volumes to extract from the main functional sequence for averaging. + Default is 5. + + Inputs: + main (str): Path to the main functional image (e.g., 4D functional MRI data). + main_json (str): Path to the JSON metadata for the main sequence. + fmap (str): Path to the fieldmap image (e.g., 4D fieldmap data). + fmap_json (str): Path to the JSON metadata for the fieldmap sequence. + + Outputs: + out_file (str): Path to the corrected 4D functional image after fieldmap correction. + + Sinking: + - Corrected functional sequence after fieldmap correction. + - QC results for fieldmap correction. + + """ + + # Extract how many volumes from the main sequence we are told to extract + num_volumes = int(wf.cfg_parser.get('FIELDMAP-CORRECTION', 'num_volumes', fallback=num_volumes)) + + # Extract the first num_volumes volumes from main sequence + extract_main_volumes = Node(fsl.ExtractROI(t_min=0, t_size=num_volumes), name='extract_main_volumes') + wf.connect('inputspec', 'main', extract_main_volumes, 'in_file') + + # Compute the mean of extracted main volumes + mean_main = Node(fsl.MeanImage(), name='mean_main') + wf.connect(extract_main_volumes, 'roi_file', mean_main, 'in_file') + + # Average all fieldmap volumes + mean_fmap = Node(fsl.MeanImage(), name='mean_fmap') + wf.connect('inputspec', 'fmap', mean_fmap, 'in_file') + + # Retrieve encoding direction, total readout time and repetition time + def retrieve_image_params_function(main_json, fmap_json): + import json + + with open(main_json, 'r') as f: + main_metadata = json.load(f) + + with open(fmap_json, 'r') as f: + fmap_metadata = json.load(f) + + for key in ['PhaseEncodingDirection', 'TotalReadoutTime', 'RepetitionTime']: + main_value = main_metadata.get(key, None) + fmap_value = fmap_metadata.get(key, None) + + if main_value is None: + raise ValueError(f'JSON of main sequence is missing the key {key}!') + + if fmap_value is None: + raise ValueError(f'JSON of fieldmap sequence is missing the key {key}!') + + main_encoding_direction = main_metadata.get('PhaseEncodingDirection') + main_total_readout_time = main_metadata.get('TotalReadoutTime') + main_repetition_time = main_metadata.get('RepetitionTime') + fmap_encoding_direction = fmap_metadata.get('PhaseEncodingDirection') + fmap_total_readout_time = fmap_metadata.get('TotalReadoutTime') + fmap_repetition_time = fmap_metadata.get('RepetitionTime') + + if main_encoding_direction == fmap_encoding_direction: + raise ValueError(f'Encoding direction of main sequence and fieldmap sequence are not allowed to be the same, but found {main_encoding_direction} and {fmap_encoding_direction}!') + + if main_total_readout_time != fmap_total_readout_time: + raise ValueError(f'TRT of main sequence IS NOT EQUAL to fieldmap TRT ({main_total_readout_time}) != {fmap_total_readout_time})') + + if main_repetition_time != fmap_repetition_time: + raise ValueError(f'TR of main sequence IS NOT EQUAL to fieldmap TR ({main_repetition_time}) != {fmap_repetition_time})') + + # In case we have Siemens j-notation instead of y, replace j by y + main_encoding_direction = main_encoding_direction.replace('j', 'y') + fmap_encoding_direction = fmap_encoding_direction.replace('j', 'y') + + encoding_direction = [main_encoding_direction, fmap_encoding_direction] + total_readout_time = main_total_readout_time + repetition_time = main_repetition_time + + return encoding_direction, total_readout_time, repetition_time + + retrieve_image_params = Node( + utility.Function( + input_names=['main_json', 'fmap_json'], + output_names=['encoding_direction', 'total_readout_time', 'repetition_time'], + function=retrieve_image_params_function + ), + name='retrieve_image_params' + ) + wf.connect('inputspec', 'main_json', retrieve_image_params, 'main_json') + wf.connect('inputspec', 'fmap_json', retrieve_image_params, 'fmap_json') + + def combine_items_to_list(item_1, item_2): + return [item_1, item_2] + + avg_volumes_to_list = Node(Function( + input_names=['item_1', 'item_2'], + output_names=['output'], + function=combine_items_to_list), + name='avg_volumes_to_list' + ) + wf.connect(mean_main, 'out_file', avg_volumes_to_list, 'item_1') + wf.connect(mean_fmap, 'out_file', avg_volumes_to_list, 'item_2') + + # Combine averaged main and averaged fieldmap into a 4D image + merge_avg_images = Node(fsl.Merge(dimension='t'), name='merge_avg_images') + wf.connect(avg_volumes_to_list, 'output', merge_avg_images, 'in_files') + wf.connect(retrieve_image_params, 'repetition_time', merge_avg_images, 'tr') + + # Estimate susceptibility induced distortions + topup = Node(fsl.TOPUP(), name='topup') + wf.connect(merge_avg_images, 'merged_file', topup, 'in_file') + wf.connect(retrieve_image_params, 'total_readout_time', topup, 'readout_times') + wf.connect(retrieve_image_params, 'encoding_direction', topup, 'encoding_direction') + + # Apply result of fsl.TOPUP to our original data + # Result will be one 4D distortion corrected image + apply_topup = Node(fsl.ApplyTOPUP(method='jac'), name='apply_topup') + wf.connect('inputspec', 'main', apply_topup, 'in_files') + wf.connect(topup, 'out_fieldcoef', apply_topup, 'in_topup_fieldcoef') + wf.connect(topup, 'out_movpar', apply_topup, 'in_topup_movpar') + wf.connect(topup, 'out_enc_file', apply_topup, 'encoding_file') + + qc_fieldmap_correction = qc_fieldmap_correction_topup('qc_fieldmap_correction') + wf.connect('inputspec', 'main', qc_fieldmap_correction, 'main') + wf.connect('inputspec', 'fmap', qc_fieldmap_correction, 'fmap') + wf.connect(topup, 'out_corrected', qc_fieldmap_correction, 'func_corrected') + + wf.connect(apply_topup, 'out_corrected', 'outputspec', 'out_file') + wf.connect(apply_topup, 'out_corrected', 'sinker', 'out_file') + + @FuncPipeline(inputspec_fields=['main_img', 'main_json', 'anat_img', 'phasediff_img', 'phasediff_json', 'magnitude_img'], outputspec_fields=['out_file']) @@ -142,7 +362,7 @@ def get_fieldmap_parameters(main_json, phasediff_json): wf.connect('inputspec', 'anat_img', qc, 'background') wf.connect(fugue, 'unwarped_file', 'outputspec', 'out_file') - + @FuncPipeline(inputspec_fields=['in_file'], outputspec_fields=['out_file']) diff --git a/PUMI/settings.ini b/PUMI/settings.ini index 5f67e9a..c10b005 100644 --- a/PUMI/settings.ini +++ b/PUMI/settings.ini @@ -24,6 +24,9 @@ save_mask = 1 # set overwrite_existing to '1' to overwrite existing predictions otherwise set to '0' overwrite_existing = 1 +[FIELDMAP-CORRECTION] +num_volumes = 5 + [TEMPLATES] head = data/standard/MNI152_T1_2mm.nii.gz #also okay: head = tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz; source=templateflow diff --git a/pipelines/rcpl-unittests.py b/pipelines/rcpl-unittests.py new file mode 100644 index 0000000..c96cfcb --- /dev/null +++ b/pipelines/rcpl-unittests.py @@ -0,0 +1,464 @@ +#!/usr/bin/env python3 + +from nipype.interfaces.fsl import Reorient2Std +from nipype.interfaces import afni +from PUMI.engine import BidsPipeline, NestedNode as Node, FuncPipeline, GroupPipeline, BidsApp, \ + create_dataset_description +from PUMI.pipelines.anat.anat_proc import anat_proc +from PUMI.pipelines.func.compcor import anat_noise_roi, compcor +from PUMI.pipelines.anat.func_to_anat import func2anat +from nipype.interfaces import utility +from PUMI.pipelines.func.func_proc import func_proc_despike_afni +from PUMI.pipelines.func.timeseries_extractor import pick_atlas, extract_timeseries_nativespace +from PUMI.utils import mist_modules, mist_labels, get_reference +from PUMI.pipelines.func.func2standard import func2standard +from PUMI.pipelines.multimodal.image_manipulation import pick_volume +import traits +import os + + +def relabel_mist_atlas(atlas_file, modules, labels): + """ + Relabel MIST atlas + * Beware : currently works only with labelmap!! + Parameters: + atlas_file(str): Path to the atlas file + modules ([str]): List containing the modules in MIST + labels ([str]): List containing the labels in MIST + Returns: + relabel_file (str): Path to relabeld atlas file + reordered_modules ([str]): list containing reordered module names + reordered_labels ([str]): list containing reordered label names + new_labels (str): Path to .tsv-file with the new labels + """ + + import os + import numpy as np + import pandas as pd + import nibabel as nib + + df = pd.DataFrame({'modules': modules, 'labels': labels}) + df.index += 1 # indexing from 1 + + reordered = df.sort_values(by='modules') + + # relabel labelmap + img = nib.load(atlas_file) + if len(img.shape) != 3: + raise Exception("relabeling does not work for probability maps!") + + lut = reordered.reset_index().sort_values(by="index").index.values + 1 + lut = np.array([0] + lut.tolist()) + # maybe this is a bit complicated, but believe me it does what it should + + data = img.get_fdata() + newdata = lut[np.array(data, dtype=np.int32)] # apply lookup table to swap labels + + img = nib.Nifti1Image(newdata.astype(np.float64), img.affine) + nib.save(img, 'relabeled_atlas.nii.gz') + + out = reordered.reset_index() + out.index = out.index + 1 + relabel_file = os.path.join(os.getcwd(), 'relabeled_atlas.nii.gz') + reordered_modules = reordered['modules'].values.tolist() + reordered_labels = reordered['labels'].values.tolist() + + newlabels_file = os.path.join(os.getcwd(), 'newlabels.tsv') + out.to_csv(newlabels_file, sep='\t') + return relabel_file, reordered_modules, reordered_labels, newlabels_file + +@GroupPipeline(inputspec_fields=['labelmap', 'modules', 'labels'], + outputspec_fields=['relabeled_atlas', 'reordered_labels', 'reordered_modules']) +def mist_atlas(wf, reorder=True, **kwargs): + + resample_atlas = Node( + interface=afni.Resample( + outputtype='NIFTI_GZ', + master=get_reference(wf, 'brain'), + ), + name='resample_atlas' + ) + + if reorder: + # reorder if modules is given (like for MIST atlases) + relabel_atls = Node( + interface=utility.Function( + input_names=['atlas_file', 'modules', 'labels'], + output_names=['relabelled_atlas_file', 'reordered_modules', 'reordered_labels', 'newlabels_file'], + function=relabel_mist_atlas + ), + name='relabel_atls' + ) + wf.connect('inputspec', 'labelmap', relabel_atls, 'atlas_file') + wf.connect('inputspec', 'modules', relabel_atls, 'modules') + wf.connect('inputspec', 'labels', relabel_atls, 'labels') + + wf.connect(relabel_atls, 'relabelled_atlas_file', resample_atlas, 'in_file') + else: + wf.connect('inputspec', 'labelmap', resample_atlas, 'in_file') + + # Sinking + wf.connect(resample_atlas, 'out_file', 'sinker', 'atlas') + if reorder: + wf.connect(relabel_atls, 'newlabels_file', 'sinker', 'reordered_labels') + else: + wf.connect('inputspec', 'labels', 'sinker', 'atlas_labels') + + # Output + wf.connect(resample_atlas, 'out_file', 'outputspec', 'relabeled_atlas') + if reorder: + wf.connect(relabel_atls, 'reordered_labels', 'outputspec', 'reordered_labels') + wf.connect(relabel_atls, 'reordered_modules', 'outputspec', 'reordered_modules') + else: + wf.connect('inputspec', 'labels', 'outputspec', 'reordered_labels') + wf.connect('inputspec', 'modules', 'outputspec', 'reordered_modules') + + +@FuncPipeline(inputspec_fields=['ts_files', 'fd_files', 'scrub_threshold'], + outputspec_fields=['features', 'out_file']) +def calculate_connectivity(wf, **kwargs): + + def calc_connectivity(ts_files, fd_files, scrub_threshold): + import os + import pandas as pd + import numpy as np + from PUMI.PAINTeR import load_timeseries, connectivity_matrix + + if not isinstance(ts_files, (list, np.ndarray)): # in this case we assume we have a string or path-like object + ts_files = [ts_files] + if not isinstance(fd_files, (list, np.ndarray)): # in this case we assume we have a string or path-like object + fd_files = [fd_files] + FD = [] + mean_FD = [] + median_FD = [] + max_FD = [] + perc_scrubbed = [] + for f in fd_files: + fd = pd.read_csv(f, sep="\t").values.flatten() + fd = np.insert(fd, 0, 0) + FD.append(fd.ravel()) + mean_FD.append(fd.mean()) + median_FD.append(np.median(fd)) + max_FD.append(fd.max()) + perc_scrubbed.append(100 - 100 * len(fd) / len(fd[fd <= scrub_threshold])) + + df = pd.DataFrame() + df['ts_file'] = ts_files + df['fd_file'] = fd_files + df['meanFD'] = mean_FD + df['medianFD'] = median_FD + df['maxFD'] = max_FD + df['perc_scrubbed'] = perc_scrubbed + + ts, labels = load_timeseries(ts_files, df, scrubbing=True, scrub_threshold=scrub_threshold) + features, cm = connectivity_matrix(np.array(ts)) + + path = os.path.abspath('motion.csv') + df.to_csv(path) + return features, path + + connectivity_wf = Node( + utility.Function( + input_names=['ts_files', 'fd_files', 'scrub_threshold'], + output_names=['features', 'out_file'], + function=calc_connectivity + ), + name="connectivity_wf" + ) + wf.connect('inputspec', 'ts_files', connectivity_wf, 'ts_files') + wf.connect('inputspec', 'fd_files', connectivity_wf, 'fd_files') + + if isinstance(wf.get_node('inputspec').inputs.scrub_threshold, traits.trait_base._Undefined): + connectivity_wf.inputs.scrub_threshold = .15 + else: + wf.connect('inputspec', 'scrub_threshold', connectivity_wf, 'scrub_threshold') + + wf.connect(connectivity_wf, 'features', 'outputspec', 'features') + wf.connect(connectivity_wf, 'out_file', 'outputspec', 'out_file') + + wf.connect(connectivity_wf, 'out_file', 'sinker', 'connectivity') + + +@FuncPipeline(inputspec_fields=['X', 'in_file'], + outputspec_fields=['score', 'out_file']) +def predict_pain_sensitivity_rpn(wf, **kwargs): + """ + + Perform pain sensitivity prediction using the RPN signature + (Resting-state Pain susceptibility Network signature). + Further information regarding the signature: https://spisakt.github.io/RPN-signature/ + + Inputs: + X (array-like): Input data for pain sensitivity prediction + in_file (str): Path to the bold file that was used to create X + + Outputs: + predicted (float): Predicted pain sensitivity score + out_file (str): Absolute path to the output CSV file containing the prediction result + + Sinking: + CSV file containing the prediction result + + """ + + def predict(X, in_file): + from PUMI.utils import rpn_model + import pandas as pd + import PUMI + import os + import importlib + + with importlib.resources.path('resources', 'model_rpn.json') as file: + model_json = file + + model = rpn_model(file=model_json) + predicted = model.predict(X) + + path = os.path.abspath('rpn-prediction.csv') + df = pd.DataFrame() + df['in_file'] = [in_file] + df['RPN'] = predicted + df.to_csv(path, index=False) + return predicted, path + + predict_wf = Node( + utility.Function( + input_names=['X', 'in_file'], + output_names=['score', 'out_file'], + function=predict + ), + name="predict_wf" + ) + wf.connect('inputspec', 'X', predict_wf, 'X') + wf.connect('inputspec', 'in_file', predict_wf, 'in_file') + + wf.connect(predict_wf, 'score', 'outputspec', 'score') + wf.connect(predict_wf, 'out_file', 'outputspec', 'out_file') + wf.connect(predict_wf, 'out_file', 'sinker', 'rpn') + + +@FuncPipeline(inputspec_fields=['X', 'in_file'], + outputspec_fields=['score', 'out_file']) +def predict_pain_sensitivity_rcpl(wf, model_path=None, **kwargs): + """ + + Perform pain sensitivity prediction using RCPL signature + (Resting-state functional Connectivity signature of Pain-related Learning). + Further information regarding the signature: https://github.com/kincsesbalint/paintone_rsn + + Parameters: + model_path (str, optional): Path to the pre-trained model relative to PUMI's data_in folder. + If set to None, PUMI's build in RCPL model is used. + + Inputs: + X (array-like): Input data for pain sensitivity prediction + in_file (str): Path to the bold file that was used to create X + + Outputs: + predicted (float): Predicted pain sensitivity score + out_file (str): Absolute path to the output CSV file containing the prediction result + + Sinking: + CSV file containing the prediction result + + """ + + def predict(X, in_file, model_path): + import pandas as pd + import os + import PUMI + import joblib + import importlib + + if model_path is None: + with importlib.resources.path('resources', 'rcpl_model.sav') as file: + model = joblib.load(file) + else: + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model file not found: {model_path}") + + data_in_folder = os.path.join(os.path.dirname(os.path.abspath(PUMI.__file__)), '..', 'data_in') + model_path = os.path.join(data_in_folder, model_path) + model = joblib.load(model_path) + + predicted = model.predict(X) + + path = os.path.abspath('rcpl-prediction.csv') + df = pd.DataFrame() + df['in_file'] = [in_file] + df['RCPL'] = predicted + df.to_csv(path, index=False) + + return predicted, path + + predict_wf = Node( + utility.Function( + input_names=['X', 'in_file', 'model_path'], + output_names=['score', 'out_file'], + function=predict + ), + name="predict_wf" + ) + wf.connect('inputspec', 'X', predict_wf, 'X') + wf.connect('inputspec', 'in_file', predict_wf, 'in_file') + predict_wf.inputs.model_path = model_path + + wf.connect(predict_wf, 'score', 'outputspec', 'score') + wf.connect(predict_wf, 'out_file', 'outputspec', 'out_file') + wf.connect(predict_wf, 'out_file', 'sinker', 'rcpl') + + +@FuncPipeline(inputspec_fields=['rpn_out_file', 'rcpl_out_file'], + outputspec_fields=['out_file']) +def collect_pain_predictions(wf, **kwargs): + """ + + Merge the out_file's of pain sensitivity predictions generated using the RCPL and RPN methods into one file + + Inputs: + rpn_out_file (str): Path to the out_file generated by the RPN method + rcpl_out_file (str): Path to the out_file generated by the RCPL method + + Outputs: + out_file (str): Absolute path to the output CSV file containing the RPN and RCPL predictions. + + Sinking: + CSV file containing RPN and RCPL predictions + + """ + + def merge_predictions(rpn_out_file, rcpl_out_file): + import pandas as pd + import os + + df_rpn = pd.read_csv(rpn_out_file) + df_rcpl = pd.read_csv(rcpl_out_file) + + # Check if in_file columns are the same + if df_rpn['in_file'].iloc[0] != df_rcpl['in_file'].iloc[0]: + raise ValueError("The 'in_file' columns in the two CSV files are not the same!") + + merged_df = pd.DataFrame() + merged_df['in_file'] = df_rpn['in_file'] + merged_df['RPN'] = df_rpn['RPN'] + merged_df['RCPL'] = df_rcpl['RCPL'] + + path = os.path.abspath('pain-sensitivity-predictions.csv') + merged_df.to_csv(path, index=False) + + return path + + merge_predictions_wf = Node( + utility.Function( + input_names=['rpn_out_file', 'rcpl_out_file'], + output_names=['out_file'], + function=merge_predictions + ), + name="merge_predictions_wf" + ) + wf.connect('inputspec', 'rpn_out_file', merge_predictions_wf, 'rpn_out_file') + wf.connect('inputspec', 'rcpl_out_file', merge_predictions_wf, 'rcpl_out_file') + + wf.connect(merge_predictions_wf, 'out_file', 'outputspec', 'out_file') + wf.connect(merge_predictions_wf, 'out_file', 'sinker', 'pain_predictions') + + +@BidsPipeline(output_query={ + 'T1w': dict( + datatype='anat', + suffix="T1w", + extension=['nii', 'nii.gz'] + ), + 'bold': dict( + datatype='func', + suffix="bold", + extension=['nii', 'nii.gz'] + ) +}) +def rcpl(wf, bbr=True, **kwargs): + + print('* bbr:', bbr) + + reorient_struct_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_struct_wf") + wf.connect('inputspec', 'T1w', reorient_struct_wf, 'in_file') + + reorient_func_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_wf") + wf.connect('inputspec', 'bold', reorient_func_wf, 'in_file') + + anatomical_preprocessing_wf = anat_proc(name='anatomical_preprocessing_wf', bet_tool='deepbet') + wf.connect(reorient_struct_wf, 'out_file', anatomical_preprocessing_wf, 'in_file') + + func2anat_wf = func2anat(name='func2anat_wf', bbr=bbr) + wf.connect(reorient_func_wf, 'out_file', func2anat_wf, 'func') + wf.connect(anatomical_preprocessing_wf, 'brain', func2anat_wf, 'head') + wf.connect(anatomical_preprocessing_wf, 'probmap_wm', func2anat_wf, 'anat_wm_segmentation') + wf.connect(anatomical_preprocessing_wf, 'probmap_csf', func2anat_wf, 'anat_csf_segmentation') + wf.connect(anatomical_preprocessing_wf, 'probmap_gm', func2anat_wf, 'anat_gm_segmentation') + wf.connect(anatomical_preprocessing_wf, 'probmap_ventricle', func2anat_wf, 'anat_ventricle_segmentation') + + compcor_roi_wf = anat_noise_roi('compcor_roi_wf') + wf.connect(func2anat_wf, 'wm_mask_in_funcspace', compcor_roi_wf, 'wm_mask') + wf.connect(func2anat_wf, 'ventricle_mask_in_funcspace', compcor_roi_wf, 'ventricle_mask') + + func_proc_wf = func_proc_despike_afni('func_proc_wf', bet_tool='deepbet', deepbet_n_dilate=2) + wf.connect(reorient_func_wf, 'out_file', func_proc_wf, 'func') + wf.connect(compcor_roi_wf, 'out_file', func_proc_wf, 'cc_noise_roi') + + pick_atlas_wf = mist_atlas('pick_atlas_wf') + mist_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data_in/atlas/MIST")) + pick_atlas_wf.get_node('inputspec').inputs.labelmap = os.path.join(mist_dir, 'Parcellations/MIST_122.nii.gz') + pick_atlas_wf.get_node('inputspec').inputs.modules = mist_modules(mist_directory=mist_dir, resolution="122") + pick_atlas_wf.get_node('inputspec').inputs.labels = mist_labels(mist_directory=mist_dir, resolution="122") + + extract_timeseries = extract_timeseries_nativespace('extract_timeseries') + wf.connect(pick_atlas_wf, 'relabeled_atlas', extract_timeseries, 'atlas') + wf.connect(pick_atlas_wf, 'reordered_labels', extract_timeseries, 'labels') + wf.connect(pick_atlas_wf, 'reordered_modules', extract_timeseries, 'modules') + wf.connect(anatomical_preprocessing_wf, 'brain', extract_timeseries, 'anat') + wf.connect(func2anat_wf, 'anat_to_func_linear_xfm', extract_timeseries, 'inv_linear_reg_mtrx') + wf.connect(anatomical_preprocessing_wf, 'mni2anat_warpfield', extract_timeseries, 'inv_nonlinear_reg_mtrx') + wf.connect(func2anat_wf, 'gm_mask_in_funcspace', extract_timeseries, 'gm_mask') + wf.connect(func_proc_wf, 'func_preprocessed', extract_timeseries, 'func') + wf.connect(func_proc_wf, 'FD', extract_timeseries, 'confounds') + + func2std = func2standard('func2std') + wf.connect(anatomical_preprocessing_wf, 'brain', func2std, 'anat') + wf.connect(func2anat_wf, 'func_to_anat_linear_xfm', func2std, 'linear_reg_mtrx') + wf.connect(anatomical_preprocessing_wf, 'anat2mni_warpfield', func2std, 'nonlinear_reg_mtrx') + wf.connect(anatomical_preprocessing_wf, 'std_template', func2std, 'reference_brain') + wf.connect(func_proc_wf, 'func_preprocessed', func2std, 'func') + wf.connect(func_proc_wf, 'mc_ref_vol', func2std, 'bbr2ants_source_file') + + calculate_connectivity_wf = calculate_connectivity('calculate_connectivity_wf') + wf.connect(extract_timeseries, 'timeseries', calculate_connectivity_wf, 'ts_files') + wf.connect(func_proc_wf, 'FD', calculate_connectivity_wf, 'fd_files') + + predict_pain_sensitivity_rpn_wf = predict_pain_sensitivity_rpn('predict_pain_sensitivity_rpn_wf') + wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rpn_wf, 'X') + wf.connect('inputspec', 'bold', predict_pain_sensitivity_rpn_wf, 'in_file') + + predict_pain_sensitivity_rcpl_wf = predict_pain_sensitivity_rcpl('predict_pain_sensitivity_rcpl_wf') + wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rcpl_wf, 'X') + wf.connect('inputspec', 'bold', predict_pain_sensitivity_rcpl_wf, 'in_file') + + collect_pain_predictions_wf = collect_pain_predictions('collect_pain_predictions_wf') + wf.connect(predict_pain_sensitivity_rpn_wf, 'out_file', collect_pain_predictions_wf, 'rpn_out_file') + wf.connect(predict_pain_sensitivity_rcpl_wf, 'out_file', collect_pain_predictions_wf, 'rcpl_out_file') + + wf.write_graph('RCPL-pipeline.png') + create_dataset_description(wf, pipeline_description_name='RCPL-pipeline') + + +rcpl_app = BidsApp( + pipeline=rcpl, + name='rcpl', + bids_dir='../data_in/pumi-unittest' # if you pass a cli argument this will be written over! +) +rcpl_app.parser.add_argument( + '--bbr', + default='yes', + type=lambda x: (str(x).lower() in ['true', '1', 'yes']), + help="Use BBR registration: yes/no (default: yes)" +) + +rcpl_app.run() diff --git a/pipelines/rcpl.py b/pipelines/rcpl.py old mode 100755 new mode 100644 index c96cfcb..c22d697 --- a/pipelines/rcpl.py +++ b/pipelines/rcpl.py @@ -8,6 +8,8 @@ from PUMI.pipelines.func.compcor import anat_noise_roi, compcor from PUMI.pipelines.anat.func_to_anat import func2anat from nipype.interfaces import utility + +from PUMI.pipelines.func.deconfound import fieldmap_correction_topup from PUMI.pipelines.func.func_proc import func_proc_despike_afni from PUMI.pipelines.func.timeseries_extractor import pick_atlas, extract_timeseries_nativespace from PUMI.utils import mist_modules, mist_labels, get_reference @@ -366,13 +368,30 @@ def merge_predictions(rpn_out_file, rcpl_out_file): @BidsPipeline(output_query={ 'T1w': dict( datatype='anat', - suffix="T1w", + suffix='T1w', extension=['nii', 'nii.gz'] ), 'bold': dict( datatype='func', - suffix="bold", + suffix='bold', + extension=['nii', 'nii.gz'] + ), + 'bold_json': dict( + datatype='func', + suffix='bold', + extension='.json' + ), + 'fmap': dict( + datatype='fmap', + acquisition='bold', + suffix='epi', extension=['nii', 'nii.gz'] + ), + 'fmap_json': dict( + datatype='fmap', + acquisition='bold', + suffix='epi', + extension='.json' ) }) def rcpl(wf, bbr=True, **kwargs): @@ -385,11 +404,20 @@ def rcpl(wf, bbr=True, **kwargs): reorient_func_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_func_wf") wf.connect('inputspec', 'bold', reorient_func_wf, 'in_file') + reorient_fmap_wf = Node(Reorient2Std(output_type='NIFTI_GZ'), name="reorient_fmap_wf") + wf.connect('inputspec', 'fmap', reorient_fmap_wf, 'in_file') + + fieldmap_corr = fieldmap_correction_topup('fieldmap_corr') + wf.connect(reorient_func_wf, 'out_file', fieldmap_corr, 'main') + wf.connect('inputspec', 'bold_json', fieldmap_corr, 'main_json') + wf.connect(reorient_fmap_wf, 'out_file', fieldmap_corr, 'fmap') + wf.connect('inputspec', 'fmap_json', fieldmap_corr, 'fmap_json') + anatomical_preprocessing_wf = anat_proc(name='anatomical_preprocessing_wf', bet_tool='deepbet') wf.connect(reorient_struct_wf, 'out_file', anatomical_preprocessing_wf, 'in_file') func2anat_wf = func2anat(name='func2anat_wf', bbr=bbr) - wf.connect(reorient_func_wf, 'out_file', func2anat_wf, 'func') + wf.connect(fieldmap_corr, 'out_file', func2anat_wf, 'func') wf.connect(anatomical_preprocessing_wf, 'brain', func2anat_wf, 'head') wf.connect(anatomical_preprocessing_wf, 'probmap_wm', func2anat_wf, 'anat_wm_segmentation') wf.connect(anatomical_preprocessing_wf, 'probmap_csf', func2anat_wf, 'anat_csf_segmentation') @@ -401,7 +429,7 @@ def rcpl(wf, bbr=True, **kwargs): wf.connect(func2anat_wf, 'ventricle_mask_in_funcspace', compcor_roi_wf, 'ventricle_mask') func_proc_wf = func_proc_despike_afni('func_proc_wf', bet_tool='deepbet', deepbet_n_dilate=2) - wf.connect(reorient_func_wf, 'out_file', func_proc_wf, 'func') + wf.connect(fieldmap_corr, 'out_file', func_proc_wf, 'func') wf.connect(compcor_roi_wf, 'out_file', func_proc_wf, 'cc_noise_roi') pick_atlas_wf = mist_atlas('pick_atlas_wf') @@ -435,11 +463,11 @@ def rcpl(wf, bbr=True, **kwargs): predict_pain_sensitivity_rpn_wf = predict_pain_sensitivity_rpn('predict_pain_sensitivity_rpn_wf') wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rpn_wf, 'X') - wf.connect('inputspec', 'bold', predict_pain_sensitivity_rpn_wf, 'in_file') + wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rpn_wf, 'in_file') predict_pain_sensitivity_rcpl_wf = predict_pain_sensitivity_rcpl('predict_pain_sensitivity_rcpl_wf') wf.connect(calculate_connectivity_wf, 'features', predict_pain_sensitivity_rcpl_wf, 'X') - wf.connect('inputspec', 'bold', predict_pain_sensitivity_rcpl_wf, 'in_file') + wf.connect(fieldmap_corr, 'out_file', predict_pain_sensitivity_rcpl_wf, 'in_file') collect_pain_predictions_wf = collect_pain_predictions('collect_pain_predictions_wf') wf.connect(predict_pain_sensitivity_rpn_wf, 'out_file', collect_pain_predictions_wf, 'rpn_out_file')