From 1d4c70c319050cc277e7b7e94b7bee4db1614035 Mon Sep 17 00:00:00 2001 From: BBillot <31892068+BBillot@users.noreply.github.com> Date: Fri, 5 Nov 2021 15:21:35 +0000 Subject: [PATCH] code maintenance for SynthSeg (#895) * fixed posteriors * updated hypothalamic_subunits * use None shape for UNet input * cosmetic changes to hypo subunits * added SynthSeg * aligned SynthSR scripts to SynthSeg format * changed order between alignment and resampling in SynthSeg * changed mri_SynthSR to SynthSR * very small changes to hypo subunits * added resampling to hypo_seg and aligned prediction pipelines * harmonized inputs between hypo_seg, SynthSeg, and SynthSR * used consistent synthsr synthseg notations and used fs.fatal * used fs.fatal for hypo_seg * updated CMakeLists.txt * fixed cpu call * important debugging * enabled synthseg crop to be a list * renamed synthseg mri_synthseg * code maintenance for n_neutral_labels * updated lab2im --- .../mri_segment_hypothalamic_subunits | 132 +++++++++++++----- mri_synthseg/mri_synthseg | 122 ++++++++-------- mri_synthsr/mri_synthsr | 17 ++- mri_synthsr/mri_synthsr_hyperfine | 50 ++++--- 4 files changed, 199 insertions(+), 122 deletions(-) diff --git a/mri_segment_hypothalamic_subunits/mri_segment_hypothalamic_subunits b/mri_segment_hypothalamic_subunits/mri_segment_hypothalamic_subunits index d9e81f1d1ad..6926362964a 100755 --- a/mri_segment_hypothalamic_subunits/mri_segment_hypothalamic_subunits +++ b/mri_segment_hypothalamic_subunits/mri_segment_hypothalamic_subunits @@ -134,7 +134,7 @@ def predict(name_subjects=None, path_volumes) # get label and classes lists - label_list = np.concatenate([np.zeros(1, dtype='int32'), np.arange(801, 811)]) + segmentation_labels = np.concatenate([np.zeros(1, dtype='int32'), np.arange(801, 811)]) # prepare volume file if needed if path_main_volumes is not None: @@ -150,7 +150,7 @@ def predict(name_subjects=None, # build network _, _, n_dims, n_channels, _, _ = get_volume_info(path_images[0]) model_input_shape = [None] * n_dims + [n_channels] - net = build_model(path_model, model_input_shape, len(label_list)) + net = build_model(path_model, model_input_shape, len(segmentation_labels)) # perform segmentation loop_info = LoopInfo(len(path_images), 10, 'predicting', True) @@ -178,7 +178,7 @@ def predict(name_subjects=None, # postprocessing try: - seg, posteriors = postprocess(prediction_patch, shape, crop_idx, label_list, aff) + seg, posteriors = postprocess(prediction_patch, shape, crop_idx, segmentation_labels, aff) except Exception as e: print('\nthe following problem occured when postprocessing segmentation %s :' % path_segmentation) print(e) @@ -1055,7 +1055,7 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): nib.save(nifty, path) -def get_volume_info(path_volume, return_volume=False, aff_ref=None): +def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10): """ Gather information about a volume: shape, affine matrix, number of dimensions and channels, header, and resolution. :param path_volume: path of the volume to get information form. @@ -1070,7 +1070,7 @@ def get_volume_info(path_volume, return_volume=False, aff_ref=None): # understand if image is multichannel im_shape = list(im.shape) - n_dims, n_channels = get_dims(im_shape, max_channels=10) + n_dims, n_channels = get_dims(im_shape, max_channels=max_channels) im_shape = im_shape[:n_dims] # get labels res @@ -1132,7 +1132,10 @@ def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): elif isinstance(var, tuple): var = list(var) elif isinstance(var, np.ndarray): - var = np.squeeze(var).tolist() + if var.shape == (1,): + var = [var[0]] + else: + var = np.squeeze(var).tolist() elif isinstance(var, str): var = [var] elif isinstance(var, bool): @@ -1252,7 +1255,7 @@ def get_dims(shape, max_channels=10): return n_dims, n_channels -def add_axis(x, axis): +def add_axis(x, axis=0): """Add axis to a numpy array. :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time.""" axis = reformat_to_list(axis) @@ -1323,23 +1326,23 @@ class LoopInfo: print(self.text + ' {}'.format(iteration)) -def find_closest_number_divisible_by_m(n, m, smaller_ans=True): - """Return the closest integer to n that is divisible by m. - If smaller_ans is True, only values lower than n are considered.""" - # quotient - q = int(n / m) - # 1st possible closest number - n1 = m * q - # 2nd possible closest number - if (n * m) > 0: - n2 = (m * (q + 1)) - else: - n2 = (m * (q - 1)) - # find closest solution - if (abs(n - n1) < abs(n - n2)) | smaller_ans: - return n1 +def find_closest_number_divisible_by_m(n, m, answer_type='lower'): + """Return the closest integer to n that is divisible by m. answer_type can either be 'closer', 'lower' (only returns + values lower than n), or 'higher (only returns values higher than m).""" + if n % m == 0: + return n else: - return n2 + q = int(n / m) + lower = q * m + higher = (q + 1) * m + if answer_type == 'lower': + return lower + elif answer_type == 'higher': + return higher + elif answer_type == 'closer': + return lower if (n - lower) < (higher - n) else higher + else: + raise Exception('answer_type should be lower, higher, or closer, had : %s' % answer_type) def build_binary_structure(connectivity, n_dims, shape=None): @@ -1451,30 +1454,55 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret return output[0] if len(output) == 1 else tuple(output) -def crop_volume_around_region(volume, mask=None, threshold=0.1, masking_labels=None, margin=0, aff=None): - """Crop a volume around a specific region. This region is defined by a mask obtained by either +def crop_volume_around_region(volume, + mask=None, + masking_labels=None, + threshold=0.1, + margin=0, + cropping_shape=None, + cropping_shape_div_by=None, + aff=None): + """Crop a volume around a specific region. + This region is defined by a mask obtained by either: 1) directly specifying it as input - 2) thresholding the input volume - 3) keeping a set of label values if the volume is a label map. + 2) keeping a set of label values (defined by masking_labels) if the volume is a label map. + 3) thresholding the input volume + The cropped region is defined by either: + 1) cropping around the non-zero values in the above-defined mask (possibly with a margin) + 2) cropping to a specified shape, centered around the middle of the above-defined mask + 3) cropping to a shape divisible by the given number, centered around the middle of the above-defined mask :param volume: a 2d or 3d numpy array :param mask: (optional) mask of region to crop around. Must be same size as volume. Can either be boolean or 0/1. - it defaults to masking around all values above threshold. - :param threshold: (optional) if mask is None, lower bound to determine values to crop around + If no mask is given, it will be computed by either thresholding the input volume or using masking_labels. :param masking_labels: (optional) if mask is None, and if the volume is a label map, it can be cropped around a set of labels specified in masking_labels, which can either be a single int, a sequence or a 1d numpy array. + :param threshold: (optional) if mask amd masking_labels are None, lower bound to determine values to crop around. :param margin: (optional) add margin around mask + :param cropping_shape: (optional) shape to which the input volumes must be cropped. Volumes are padded around the + centre of the above-defined mask is they are too small for the given shape. Can be an integer or sequence. + Cannot be given at the same time as margin or cropping_shape_div_by. + :param cropping_shape_div_by: (optional) makes sure the shape of the cropped region is divisible by the provided + number. If it is not, then we enlarge the cropping area. If the enlarged area is too big fort he input volume, we + pad it with 0. Must be a integer. Cannot be given at the same time as margin or cropping_shape. :param aff: (optional) if specified, this function returns an updated affine matrix of the volume after cropping. :return: the cropped volume, the cropping indices (in the order [lower_bound_dim_1, ..., upper_bound_dim_1, ...]), and the updated affine matrix if aff is not None. """ + assert not ((margin > 0) & (cropping_shape is not None)), "margin and cropping_shape can't be given together." + assert not ((margin > 0) & (cropping_shape_div_by is not None)), \ + "margin and cropping_shape_div_by can't be given together." + assert not ((cropping_shape_div_by is not None) & (cropping_shape is not None)), \ + "cropping_shape_div_by and cropping_shape can't be given together." + new_vol = volume.copy() - n_dims, _ = get_dims(new_vol.shape) + n_dims, n_channels = get_dims(new_vol.shape) + vol_shape = np.array(new_vol.shape[:n_dims]) # mask ROIs for cropping if mask is None: if masking_labels is not None: - masked_volume, mask = mask_label_map(new_vol, masking_values=masking_labels, return_mask=True) + _, mask = mask_label_map(new_vol, masking_values=masking_labels, return_mask=True) else: mask = new_vol > threshold @@ -1482,16 +1510,48 @@ def crop_volume_around_region(volume, mask=None, threshold=0.1, masking_labels=N if np.any(mask): indices = np.nonzero(mask) min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) - max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1 + margin, np.array(new_vol.shape[:n_dims])) + max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape) cropping = np.concatenate([min_idx, max_idx]) + # modify the cropping indices if we want the output to have a given shape + if (cropping_shape is not None) | (cropping_shape_div_by is not None): + + # expand/retract (depending on the desired shape) the cropping region around the centre + intermediate_vol_shape = max_idx - min_idx + if cropping_shape is not None: + cropping_shape = np.array(reformat_to_list(cropping_shape, length=n_dims)) + else: + cropping_shape = [find_closest_number_divisible_by_m(s, cropping_shape_div_by, answer_type='higher') + for s in intermediate_vol_shape] + min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape)/2)) + max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape)/2)) + + # check if we need to pad the output to the desired shape + min_padding = np.abs(np.minimum(min_idx, 0)) + max_padding = np.maximum(max_idx - vol_shape, 0) + if np.any(min_padding > 0) | np.any(max_padding > 0): + pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) + else: + pad_margins = None + cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) + + else: + pad_margins = None + # crop volume if n_dims == 3: - new_vol = new_vol[min_idx[0]:max_idx[0], min_idx[1]:max_idx[1], min_idx[2]:max_idx[2], ...] + new_vol = new_vol[cropping[0]:cropping[3], cropping[1]:cropping[4], cropping[2]:cropping[5], ...] elif n_dims == 2: - new_vol = new_vol[min_idx[0]:max_idx[0], min_idx[1]:max_idx[1], ...] + new_vol = new_vol[cropping[0]:cropping[2], cropping[1]:cropping[3], ...] else: raise ValueError('cannot crop volumes with more than 3 dimensions') + + # pad volume if necessary + if pad_margins is not None: + pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins + new_vol = np.pad(new_vol, pad_margins, mode='constant', constant_values=0) + + # if there's nothing to crop around, we return the input as is else: min_idx = np.zeros((3, 1)) cropping = None @@ -1536,7 +1596,7 @@ def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None): return new_volume -def resample_volume(volume, aff, new_vox_size): +def resample_volume(volume, aff, new_vox_size, interpolation='linear'): """This function resizes the voxels of a volume to a new provided size, while adjusting the header to keep the RAS :param volume: a numpy array :param aff: affine matrix of the volume @@ -1557,7 +1617,7 @@ def resample_volume(volume, aff, new_vox_size): y = np.arange(0, volume_filt.shape[1]) z = np.arange(0, volume_filt.shape[2]) - my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt) + my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt, method=interpolation) start = - (factor - 1) / (2 * factor) step = 1.0 / factor diff --git a/mri_synthseg/mri_synthseg b/mri_synthseg/mri_synthseg index e5c3c1b1341..df2983e7e4f 100644 --- a/mri_synthseg/mri_synthseg +++ b/mri_synthseg/mri_synthseg @@ -64,12 +64,12 @@ def main(): # locate model weights if not fs.fshome(): - path_model = segmentation_label_list = segmentation_names_list = topology_classes = None + path_model = segmentation_labels = segmentation_names_list = topology_classes = None fs.fatal('FREESURFER_HOME is not set. Please source freesurfer.') else: path_model = os.path.join(fs.fshome(), 'models', 'synthseg_1.0.h5') - segmentation_label_list = os.path.join(fs.fshome(), 'models', 'synthseg_segmentation_labels.npy') - segmentation_names_list = os.path.join(fs.fshome(), 'models', 'synthseg_segmentation_names.npy') + segmentation_labels = os.path.join(fs.fshome(), 'models', 'synthseg_segmentation_labels.npy') + segmentation_label_names = os.path.join(fs.fshome(), 'models', 'synthseg_segmentation_names.npy') topology_classes = os.path.join(fs.fshome(), 'models', 'synthseg_topological_classes.npy') # run prediction @@ -77,8 +77,8 @@ def main(): path_images=args.i, path_segmentations=args.o, path_model=path_model, - segmentation_label_list=segmentation_label_list, - segmentation_names_list=segmentation_names_list, + segmentation_labels=segmentation_labels, + segmentation_label_names=segmentation_label_names, topology_classes=topology_classes, path_posteriors=args.post, path_resampled=args.resample, @@ -96,8 +96,8 @@ def main(): def predict(path_images, path_segmentations, path_model, - segmentation_label_list, - segmentation_names_list, + segmentation_labels, + segmentation_label_names, topology_classes, path_posteriors=None, path_resampled=None, @@ -122,41 +122,47 @@ def predict(path_images, tf.config.threading.set_inter_op_parallelism_threads(threads) tf.config.threading.set_intra_op_parallelism_threads(threads) - # build correspondance table between contralateral structures - label_list, n_neutral = get_list_labels(label_list=segmentation_label_list, FS_sort=True) - n_labels = len(label_list) - n_side_labels = int((n_labels - n_neutral) / 2) - lr_corresp = np.stack([label_list[n_neutral:n_neutral + n_side_labels], label_list[n_neutral + n_side_labels:]]) + # get label list + segmentation_labels, _ = get_list_labels(label_list=segmentation_labels) + n_labels = len(segmentation_labels) + n_neutral_labels = 18 - # get final version of label list - label_list, _ = get_list_labels(label_list=segmentation_label_list, FS_sort=False) - label_list, indices = np.unique(label_list, return_index=True) + # build correspondance table between contralateral structures + n_sided_labels = int((n_labels - n_neutral_labels) / 2) + lr_corresp = np.stack([segmentation_labels[n_neutral_labels:n_neutral_labels + n_sided_labels], + segmentation_labels[n_neutral_labels + n_sided_labels:]]) + + # get unique label values + segmentation_labels, indices = np.unique(segmentation_labels, return_index=True) + + # get indices of corresponding contralateral structures in new label order + lr_corresp_unique, lr_corresp_indices = np.unique(lr_corresp[0, :], return_index=True) + lr_corresp_unique = np.stack([lr_corresp_unique, lr_corresp[1, lr_corresp_indices]]) + lr_corresp_unique = lr_corresp_unique[:, 1:] if not np.all(lr_corresp_unique[:, 0]) else lr_corresp_unique + lr_indices = np.zeros_like(lr_corresp_unique) + for i in range(lr_corresp_unique.shape[0]): + for j, lab in enumerate(lr_corresp_unique[i]): + lr_indices[i, j] = np.where(segmentation_labels == lab)[0] # prepare topology classes if topology_classes is not None: topology_classes = load_array_if_path(topology_classes, load_as_numpy=True)[indices] - # get correspondance for labels with different right/left values - lr_indices = np.zeros_like(lr_corresp) - for i in range(lr_corresp.shape[0]): - for j, lab in enumerate(lr_corresp[i]): - lr_indices[i, j] = np.where(label_list == lab)[0] - # prepare volume file if needed if path_volumes is not None: - if segmentation_names_list is not None: - segmentation_names_list = load_array_if_path(segmentation_names_list)[indices] - csv_header = [[''] + segmentation_names_list[1:].tolist()] - csv_header += [[''] + [str(lab) for lab in label_list[1:]]] + if segmentation_label_names is not None: + segmentation_label_names = load_array_if_path(segmentation_label_names)[indices] + csv_header = [[''] + segmentation_label_names[1:].tolist()] + csv_header += [[''] + [str(lab) for lab in segmentation_labels[1:]]] else: - csv_header = [['subjects'] + [str(lab) for lab in label_list[1:]]] + csv_header = [['subjects'] + [str(lab) for lab in segmentation_labels[1:]]] with open(path_volumes, 'w') as csvFile: writer = csv.writer(csvFile) writer.writerows(csv_header) csvFile.close() # build network - net = build_model(path_model, len(label_list)) + net = build_model(path_model, len(segmentation_labels)) # perform segmentation loop_info = LoopInfo(len(path_images), 10, 'predicting', True) @@ -186,7 +192,7 @@ def predict(path_images, # postprocessing try: - seg, posteriors = postprocess(prediction_patch, pad_shape, shape, crop_idx, label_list, lr_indices, + seg, posteriors = postprocess(prediction_patch, pad_shape, shape, crop_idx, segmentation_labels, lr_indices, aff, topology_classes, prediction_patch_flip) except Exception as e: print('\nthe following problem occured when postprocessing segmentation %s :' % path_segmentation) @@ -419,7 +425,7 @@ def build_model(model_file, n_lab): return net -def postprocess(post_patch, pad_shape, im_shape, crop, labels, left_right_indices, aff, +def postprocess(post_patch, pad_shape, im_shape, crop, segmentation_labels, left_right_indices, aff, topology_classes, post_patch_flip): # get posteriors @@ -451,11 +457,11 @@ def postprocess(post_patch, pad_shape, im_shape, crop, labels, left_right_indice # paste patches back to matrix of original image size seg = np.zeros(shape=pad_shape, dtype='int32') - posteriors = np.zeros(shape=[*pad_shape, labels.shape[0]]) + posteriors = np.zeros(shape=[*pad_shape, segmentation_labels.shape[0]]) posteriors[..., 0] = np.ones(pad_shape) # place background around patch seg[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5]] = seg_patch posteriors[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], :] = post_patch - seg = labels[seg.astype('int')].astype('int') + seg = segmentation_labels[seg.astype('int')].astype('int') if im_shape != pad_shape: bounds = [int((p-i)/2) for (p, i) in zip(pad_shape, im_shape)] @@ -901,7 +907,7 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): nib.save(nifty, path) -def get_volume_info(path_volume, return_volume=False, aff_ref=None): +def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10): """ Gather information about a volume: shape, affine matrix, number of dimensions and channels, header, and resolution. :param path_volume: path of the volume to get information form. @@ -916,7 +922,7 @@ def get_volume_info(path_volume, return_volume=False, aff_ref=None): # understand if image is multichannel im_shape = list(im.shape) - n_dims, n_channels = get_dims(im_shape, max_channels=10) + n_dims, n_channels = get_dims(im_shape, max_channels=max_channels) im_shape = im_shape[:n_dims] # get labels res @@ -984,8 +990,9 @@ def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_s # sort labels in neutral/left/right according to FS labels n_neutral_labels = 0 if FS_sort: - neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 101, 102, 103, 104, 105, 165, 251, 252, 253, - 254, 255, 258, 259, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, + neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108, + 109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530, 531, 532, 533, 534, 535, 536, 537] neutral = list() @@ -1053,7 +1060,10 @@ def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): elif isinstance(var, tuple): var = list(var) elif isinstance(var, np.ndarray): - var = np.squeeze(var).tolist() + if var.shape == (1,): + var = [var[0]] + else: + var = np.squeeze(var).tolist() elif isinstance(var, str): var = [var] elif isinstance(var, bool): @@ -1139,7 +1149,7 @@ def get_dims(shape, max_channels=10): return n_dims, n_channels -def add_axis(x, axis): +def add_axis(x, axis=0): """Add axis to a numpy array. :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time.""" axis = reformat_to_list(axis) @@ -1224,23 +1234,23 @@ class LoopInfo: print(self.text + ' {}'.format(iteration)) -def find_closest_number_divisible_by_m(n, m, smaller_ans=True): - """Return the closest integer to n that is divisible by m. - If smaller_ans is True, only values lower than n are considered.""" - # quotient - q = int(n / m) - # 1st possible closest number - n1 = m * q - # 2nd possible closest number - if (n * m) > 0: - n2 = (m * (q + 1)) - else: - n2 = (m * (q - 1)) - # find closest solution - if (abs(n - n1) < abs(n - n2)) | smaller_ans: - return n1 +def find_closest_number_divisible_by_m(n, m, answer_type='lower'): + """Return the closest integer to n that is divisible by m. answer_type can either be 'closer', 'lower' (only returns + values lower than n), or 'higher (only returns values higher than m).""" + if n % m == 0: + return n else: - return n2 + q = int(n / m) + lower = q * m + higher = (q + 1) * m + if answer_type == 'lower': + return lower + elif answer_type == 'higher': + return higher + elif answer_type == 'closer': + return lower if (n - lower) < (higher - n) else higher + else: + raise Exception('answer_type should be lower, higher, or closer, had : %s' % answer_type) def build_binary_structure(connectivity, n_dims, shape=None): @@ -1459,7 +1469,7 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx= pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])]) pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)]) if n_channels > 1: - pad_margins = tuple(list(pad_margins) + [[0, 0]]) + pad_margins = tuple(list(pad_margins) + [(0, 0)]) # pad volume new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value) @@ -1512,7 +1522,7 @@ def flip_volume(volume, axis=None, direction=None, aff=None): return np.flip(new_volume, axis=axis) -def resample_volume(volume, aff, new_vox_size): +def resample_volume(volume, aff, new_vox_size, interpolation='linear'): """This function resizes the voxels of a volume to a new provided size, while adjusting the header to keep the RAS :param volume: a numpy array :param aff: affine matrix of the volume @@ -1533,7 +1543,7 @@ def resample_volume(volume, aff, new_vox_size): y = np.arange(0, volume_filt.shape[1]) z = np.arange(0, volume_filt.shape[2]) - my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt) + my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt, method=interpolation) start = - (factor - 1) / (2 * factor) step = 1.0 / factor diff --git a/mri_synthsr/mri_synthsr b/mri_synthsr/mri_synthsr index 8d59059e062..3af91e0779c 100644 --- a/mri_synthsr/mri_synthsr +++ b/mri_synthsr/mri_synthsr @@ -702,7 +702,7 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): nib.save(nifty, path) -def get_volume_info(path_volume, return_volume=False, aff_ref=None): +def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10): """ Gather information about a volume: shape, affine matrix, number of dimensions and channels, header, and resolution. :param path_volume: path of the volume to get information form. @@ -717,7 +717,7 @@ def get_volume_info(path_volume, return_volume=False, aff_ref=None): # understand if image is multichannel im_shape = list(im.shape) - n_dims, n_channels = get_dims(im_shape, max_channels=10) + n_dims, n_channels = get_dims(im_shape, max_channels=max_channels) im_shape = im_shape[:n_dims] # get labels res @@ -779,7 +779,10 @@ def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): elif isinstance(var, tuple): var = list(var) elif isinstance(var, np.ndarray): - var = np.squeeze(var).tolist() + if var.shape == (1,): + var = [var[0]] + else: + var = np.squeeze(var).tolist() elif isinstance(var, str): var = [var] elif isinstance(var, bool): @@ -865,7 +868,7 @@ def get_dims(shape, max_channels=10): return n_dims, n_channels -def add_axis(x, axis): +def add_axis(x, axis=0): """Add axis to a numpy array. :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time.""" axis = reformat_to_list(axis) @@ -994,7 +997,7 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx= pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])]) pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)]) if n_channels > 1: - pad_margins = tuple(list(pad_margins) + [[0, 0]]) + pad_margins = tuple(list(pad_margins) + [(0, 0)]) # pad volume new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value) @@ -1016,7 +1019,7 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx= return output[0] if len(output) == 1 else tuple(output) -def resample_volume(volume, aff, new_vox_size): +def resample_volume(volume, aff, new_vox_size, interpolation='linear'): """This function resizes the voxels of a volume to a new provided size, while adjusting the header to keep the RAS :param volume: a numpy array :param aff: affine matrix of the volume @@ -1037,7 +1040,7 @@ def resample_volume(volume, aff, new_vox_size): y = np.arange(0, volume_filt.shape[1]) z = np.arange(0, volume_filt.shape[2]) - my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt) + my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt, method=interpolation) start = - (factor - 1) / (2 * factor) step = 1.0 / factor diff --git a/mri_synthsr/mri_synthsr_hyperfine b/mri_synthsr/mri_synthsr_hyperfine index ed17c0e1ed1..ad6083f363c 100644 --- a/mri_synthsr/mri_synthsr_hyperfine +++ b/mri_synthsr/mri_synthsr_hyperfine @@ -257,7 +257,7 @@ def preprocess_image(path_t1_image, path_t2_image, n_levels=5): # resample and align image im_t1, aff_t1 = resample_volume(im_t1, aff_t1, [1.0, 1.0, 1.0]) im_t1, aff_t1_mod = align_volume_to_ref(im_t1, aff_t1, aff_ref=np.eye(4), return_aff=True, n_dims=3) - im_t2 = resample_like(im_t1, aff_t1_mod, im_t2, aff_t2) + im_t2 = resample_volume_like(im_t1, aff_t1_mod, im_t2, aff_t2) # normalise images minimum = np.min(im_t1) @@ -748,7 +748,7 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): nib.save(nifty, path) -def get_volume_info(path_volume, return_volume=False, aff_ref=None): +def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10): """ Gather information about a volume: shape, affine matrix, number of dimensions and channels, header, and resolution. :param path_volume: path of the volume to get information form. @@ -763,7 +763,7 @@ def get_volume_info(path_volume, return_volume=False, aff_ref=None): # understand if image is multichannel im_shape = list(im.shape) - n_dims, n_channels = get_dims(im_shape, max_channels=10) + n_dims, n_channels = get_dims(im_shape, max_channels=max_channels) im_shape = im_shape[:n_dims] # get labels res @@ -825,7 +825,10 @@ def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): elif isinstance(var, tuple): var = list(var) elif isinstance(var, np.ndarray): - var = np.squeeze(var).tolist() + if var.shape == (1,): + var = [var[0]] + else: + var = np.squeeze(var).tolist() elif isinstance(var, str): var = [var] elif isinstance(var, bool): @@ -911,7 +914,7 @@ def get_dims(shape, max_channels=10): return n_dims, n_channels -def add_axis(x, axis): +def add_axis(x, axis=0): """Add axis to a numpy array. :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time.""" axis = reformat_to_list(axis) @@ -1040,7 +1043,7 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx= pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])]) pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)]) if n_channels > 1: - pad_margins = tuple(list(pad_margins) + [[0, 0]]) + pad_margins = tuple(list(pad_margins) + [(0, 0)]) # pad volume new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value) @@ -1062,7 +1065,7 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx= return output[0] if len(output) == 1 else tuple(output) -def resample_volume(volume, aff, new_vox_size): +def resample_volume(volume, aff, new_vox_size, interpolation='linear'): """This function resizes the voxels of a volume to a new provided size, while adjusting the header to keep the RAS :param volume: a numpy array :param aff: affine matrix of the volume @@ -1083,7 +1086,7 @@ def resample_volume(volume, aff, new_vox_size): y = np.arange(0, volume_filt.shape[1]) z = np.arange(0, volume_filt.shape[2]) - my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt) + my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt, method=interpolation) start = - (factor - 1) / (2 * factor) step = 1.0 / factor @@ -1110,21 +1113,9 @@ def resample_volume(volume, aff, new_vox_size): return volume2, aff2 -def get_ras_axes(aff, n_dims=3): - """This function finds the RAS axes corresponding to each dimension of a volume, based on its affine matrix. - :param aff: affine matrix Can be a 2d numpy array of size n_dims*n_dims, n_dims+1*n_dims+1, or n_dims*n_dims+1. - :param n_dims: number of dimensions (excluding channels) of the volume corresponding to the provided affine matrix. - :return: two numpy 1d arrays of lengtn n_dims, one with the axes corresponding to RAS orientations, - and one with their corresponding direction. - """ - aff_inverted = np.linalg.inv(aff) - img_ras_axes = np.argmax(np.absolute(aff_inverted[0:n_dims, 0:n_dims]), axis=0) - return img_ras_axes - - -def resample_like(vol_ref, aff_ref, vol_flo, aff_flo): +def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation='linear'): """This function reslices a floating image to the space of a reference image - :param vol_res: a numpy array with the reference volume + :param vol_ref: a numpy array with the reference volume :param aff_ref: affine matrix of the reference volume :param vol_flo: a numpy array with the floating volume :param aff_flo: affine matrix of the floating volume @@ -1137,7 +1128,8 @@ def resample_like(vol_ref, aff_ref, vol_flo, aff_flo): yf = np.arange(0, vol_flo.shape[1]) zf = np.arange(0, vol_flo.shape[2]) - my_interpolating_function = RegularGridInterpolator((xf, yf, zf), vol_flo, bounds_error=False, fill_value=0.0) + my_interpolating_function = RegularGridInterpolator((xf, yf, zf), vol_flo, bounds_error=False, fill_value=0.0, + method=interpolation) xr = np.arange(0, vol_ref.shape[0]) yr = np.arange(0, vol_ref.shape[1]) @@ -1156,6 +1148,18 @@ def resample_like(vol_ref, aff_ref, vol_flo, aff_flo): return result.reshape(vol_ref.shape) +def get_ras_axes(aff, n_dims=3): + """This function finds the RAS axes corresponding to each dimension of a volume, based on its affine matrix. + :param aff: affine matrix Can be a 2d numpy array of size n_dims*n_dims, n_dims+1*n_dims+1, or n_dims*n_dims+1. + :param n_dims: number of dimensions (excluding channels) of the volume corresponding to the provided affine matrix. + :return: two numpy 1d arrays of lengtn n_dims, one with the axes corresponding to RAS orientations, + and one with their corresponding direction. + """ + aff_inverted = np.linalg.inv(aff) + img_ras_axes = np.argmax(np.absolute(aff_inverted[0:n_dims, 0:n_dims]), axis=0) + return img_ras_axes + + def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None): """This function aligns a volume to a reference orientation (axis and direction) specified by an affine matrix. :param volume: a numpy array