debugging of synthsr/synthseg/hypo_seg (freesurfer#885)
* fixed posteriors

* updated hypothalamic_subunits

* use None shape for UNet input

* cosmetic changes to hypo subunits

* added SynthSeg

* aligned SynthSR scripts to SynthSeg format

* changed order between alignment and resampling in SynthSeg

* changed mri_SynthSR to SynthSR

* very small changes to hypo subunits

* added resampling to hypo_seg and aligned prediction pipelines

* harmonized inputs between hypo_seg, SynthSeg, and SynthSR

* used consistent synthsr synthseg notations and used fs.fatal

* used fs.fatal for hypo_seg

* updated CMakeLists.txt

* fixed cpu call

* important debugging
BBillot authored Oct 13, 2021
1 parent 938c756 commit 932ba60
Showing 4 changed files with 79 additions and 60 deletions.
44 changes: 27 additions & 17 deletions mri_segment_hypothalamic_subunits/mri_segment_hypothalamic_subunits
@@ -160,7 +160,7 @@ def predict(name_subjects=None,

# preprocessing
try:
image, aff, h, im_res, shape, crop_idx = preprocess_image(path_image, crop)
image, aff, h, im_res, shape, crop_idx = preprocess_image(path_image, crop, path_resample=path_resample)
except Exception as e:
print('\nthe following problem occured when preprocessing image %s :' % path_image)
print(e)
@@ -188,9 +188,9 @@ def predict(name_subjects=None,
# write results to disk
try:
if path_segmentation is not None:
save_volume(seg.astype('int'), aff, h, path_segmentation)
save_volume(seg, aff, h, path_segmentation, dtype='int32')
if path_posterior is not None:
save_volume(posteriors.astype('float'), aff, h, path_posterior)
save_volume(posteriors, aff, h, path_posterior, dtype='float32')
except Exception as e:
print('\nthe following problem occured when saving the results for image %s :' % path_image)
print(e)
@@ -289,14 +289,21 @@ def prepare_output_files(name_subjects, subjects_dir, write_posteriors_FS, path_
# path_images is a folder
if ('.nii.gz' not in basename) & ('.nii' not in basename) & ('.mgz' not in basename) & ('.npz' not in basename):
if os.path.isfile(path_images):
fs.fatal('Extension not supported for %s, only use: nii.gz, .nii, .mgz, or .npz' % path_images)
fs.fatal('Extension not supported for %s, only use: .nii.gz, .nii, .mgz, or .npz' % path_images)
path_images = list_images_in_folder(path_images)
if (out_seg[-7:] == '.nii.gz') | (out_seg[-4:] == '.nii') | \
(out_seg[-4:] == '.mgz') | (out_seg[-4:] == '.npz'):
fs.fatal('Output folders cannot have extensions: .nii.gz, .nii, .mgz, or .npz, had %s' % out_seg)
mkdir(out_seg)
out_seg = [os.path.join(out_seg, os.path.basename(image)).replace('.nii', '_hypo_seg.nii') for image in
path_images]
out_seg = [seg_path.replace('.mgz', '_hypo_seg.mgz') for seg_path in out_seg]
out_seg = [seg_path.replace('.npz', '_hypo_seg.npz') for seg_path in out_seg]
if out_posteriors is not None:
if (out_posteriors[-7:] == '.nii.gz') | (out_posteriors[-4:] == '.nii') | \
(out_posteriors[-4:] == '.mgz') | (out_posteriors[-4:] == '.npz'):
fs.fatal('Output folders cannot have extensions: '
'.nii.gz, .nii, .mgz, or .npz, had %s' % out_posteriors)
mkdir(out_posteriors)
out_posteriors = [os.path.join(out_posteriors, os.path.basename(image)).replace('.nii',
'_posteriors.nii') for image in path_images]
@@ -307,6 +314,10 @@ def prepare_output_files(name_subjects, subjects_dir, write_posteriors_FS, path_
else:
out_posteriors = [out_posteriors] * len(path_images)
if out_resampled is not None:
if (out_resampled[-7:] == '.nii.gz') | (out_resampled[-4:] == '.nii') | \
(out_resampled[-4:] == '.mgz') | (out_resampled[-4:] == '.npz'):
fs.fatal('Output folders cannot have extensions: '
'.nii.gz, .nii, .mgz, or .npz, had %s' % out_resampled)
mkdir(out_resampled)
out_resampled = [os.path.join(out_resampled, os.path.basename(image)).replace('.nii',
'_resampled.nii') for image in path_images]
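The extension guards added above mean that when the input is a folder, the segmentation, posterior, and resampled outputs must also be folders, and one output file per input image is then derived by suffixing the image name as in the list comprehensions above. A small standalone sketch of that naming rule (paths are hypothetical, and this is not the repo helper itself):

import os

path_images = ['/data/subj01/t1.nii.gz', '/data/subj02/t1.mgz']   # hypothetical inputs
out_seg = '/data/hypo_outputs'                                     # must be a folder, not a .nii/.mgz path

out_paths = [os.path.join(out_seg, os.path.basename(p)).replace('.nii', '_hypo_seg.nii') for p in path_images]
out_paths = [p.replace('.mgz', '_hypo_seg.mgz') for p in out_paths]
print(out_paths)
# ['/data/hypo_outputs/t1_hypo_seg.nii.gz', '/data/hypo_outputs/t1_hypo_seg.mgz']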
@@ -516,24 +527,23 @@ def preprocess_image(im_path, crop=184, n_levels=3, path_resample=None):
if n_channels > 1:
print('WARNING: detected more than 1 channel, only keeping the first channel.')
im = im[..., 0]
shape = shape[:n_dims]
if n_dims != 3:
fs.fatal('Input images should be 3D, found %s dimensions.' % n_dims)

# resample image if necessary
if np.any((im_res > np.array([1.15]*3)) | (im_res < np.array([0.95]*3))):
im_res = np.array([1.]*3)
im, aff = resample_volume(im, aff, im_res)
shape = list(im.shape)
if path_resample is not None:
save_volume(im, aff, header, path_resample)

# align image
im = align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims)
shape = list(im.shape)

# check that shape is divisible by 2**n_levels
crop = reformat_to_list(crop, length=n_dims, dtype='int')
if not all([shape[i] >= crop[i] for i in range(len(shape))]):
if not all([shape[i] >= crop[i] for i in range(n_dims)]):
crop = [min(shape[i], crop[i]) for i in range(n_dims)]
if not all([size % (2**n_levels) == 0 for size in crop]):
crop = [find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in crop]
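Two things happen in this hunk: the shape check now loops over n_dims rather than len(shape), so a trailing channel axis cannot leak into the crop computation, and the crop is then forced to a multiple of 2**n_levels, because a U-Net with n_levels poolings halves each spatial dimension that many times. The defaults already satisfy the constraint (184 = 23 * 8 here, and 192 = 6 * 32 in synthseg below). A standalone illustration of the rounding, assuming find_closest_number_divisible_by_m picks the nearest multiple (the repo helper's exact tie-breaking is not verified):

n_levels = 3
m = 2 ** n_levels                         # 8
crop = [190, 184, 150]
crop = [int(round(c / m)) * m for c in crop]
print(crop)                               # [192, 184, 152]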
@@ -1024,8 +1034,6 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
"""

mkdir(os.path.dirname(path))
if dtype is not None:
volume = volume.astype(dtype=dtype)
if '.npz' in path:
np.savez_compressed(path, vol_data=volume)
else:
@@ -1037,6 +1045,8 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
elif aff is None:
aff = np.eye(4)
nifty = nib.Nifti1Image(volume, aff, header)
if dtype is not None:
nifty.set_data_dtype(dtype)
if res is not None:
if n_dims is None:
n_dims, _ = get_dims(volume.shape)
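The eager cast (volume.astype(dtype)) is removed; the requested dtype is instead recorded on the Nifti1Image header with set_data_dtype, and nibabel applies it when the file is written. That is why the predict() hunks above now pass dtype='int32' / dtype='float32' to save_volume rather than casting beforehand (note that, as far as the shown hunks go, the .npz branch now stores the array in its native dtype, since nothing casts it in memory any more). A minimal standalone nibabel sketch of the same idea, with a hypothetical output filename:

import numpy as np
import nibabel as nib

seg = np.random.rand(4, 4, 4) * 10                 # float64 array in memory
img = nib.Nifti1Image(seg, np.eye(4))
img.set_data_dtype(np.int32)                       # recorded in the header, not applied in memory
nib.save(img, 'seg_example.nii.gz')                # nibabel casts to the header dtype on write
print(nib.load('seg_example.nii.gz').get_data_dtype())   # int32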
@@ -1408,21 +1418,21 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret
# find cropping indices
if cropping_margin is not None:
cropping_margin = reformat_to_list(cropping_margin, length=n_dims)
min_crop_idx = cropping_margin
max_crop_idx = [vol_shape[i] - cropping_margin[i] for i in range(n_dims)]
assert (np.array(max_crop_idx) >= np.array(min_crop_idx)).all(), 'cropping_margin is larger than volume shape'
do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin)
min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)]
max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)]
else:
cropping_shape = reformat_to_list(cropping_shape, length=n_dims)
if mode == 'center':
min_crop_idx = [int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)]
max_crop_idx = [min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)]
min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0)
max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)],
np.array(vol_shape)[:n_dims])
elif mode == 'random':
crop_max_val = np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)])
crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0)
min_crop_idx = np.random.randint(0, high=crop_max_val + 1)
max_crop_idx = min_crop_idx + np.array(cropping_shape)
max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims])
else:
raise ValueError('mode should be either "center" or "random", had %s' % mode)
assert (np.array(min_crop_idx) >= 0).all(), 'cropping_shape is larger than volume shape'
crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)])

# crop volume
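The rewritten index computation clamps every crop to the volume instead of asserting: margins are only applied on axes where the volume is larger than twice the margin (the old 'cropping_margin is larger than volume shape' assertion goes away), a centre crop floors its start index at 0 and caps the end index at the axis length, and a random crop can no longer draw a negative start position. A standalone check of the clamped centre-crop arithmetic:

import numpy as np

vol_shape = np.array([160, 200, 180])
crop_shape = np.array([192, 192, 192])       # larger than two of the three axes

min_idx = np.maximum((vol_shape - crop_shape) // 2, 0)
max_idx = np.minimum(min_idx + crop_shape, vol_shape)
print(min_idx, max_idx)                      # [0 4 0] [160 196 180]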
49 changes: 28 additions & 21 deletions synthseg/synthseg
@@ -44,7 +44,7 @@ def main():
parser.add_argument("--vol", help="(optional) Output CSV file with volumes for all structures and subjects.")

# parameters
parser.add_argument("--crop", type=int, default=None, dest="crop", help="(optional) Size of 3D patches to analyse."
parser.add_argument("--crop", type=int, default=192, dest="crop", help="(optional) Size of 3D patches to analyse."
"Default is 192.")
parser.add_argument("--threads", type=int, default=1, help="(optional) Number of cores to be used. Default is 1.")
parser.add_argument("--cpu", action="store_true", help="(optional) Enforce running with CPU rather than GPU.")
@@ -197,9 +197,9 @@ def predict(path_images,
# write results to disk
try:
if path_segmentation is not None:
save_volume(seg.astype('int'), aff, h, path_segmentation)
save_volume(seg, aff, h, path_segmentation, dtype='int32')
if path_posterior is not None:
save_volume(posteriors.astype('float'), aff, h, path_posterior)
save_volume(posteriors, aff, h, path_posterior, dtype='float32')
except Exception as e:
print('\nthe following problem occured when saving the results for image %s :' % path_image)
print(e)
@@ -268,12 +268,17 @@ def prepare_output_files(path_images, out_seg, out_posteriors, out_resampled, ou
if os.path.isfile(path_images):
fs.fatal('Extension not supported for %s, only use: .nii.gz, .nii, .mgz, or .npz' % path_images)
path_images = list_images_in_folder(path_images)
if (out_seg[-7:] == '.nii.gz') | (out_seg[-4:] == '.nii') | (out_seg[-4:] == '.mgz') | (out_seg[-4:] == '.npz'):
fs.fatal('Output folders cannot have extensions: .nii.gz, .nii, .mgz, or .npz, had %s' % out_seg)
mkdir(out_seg)
out_seg = [os.path.join(out_seg, os.path.basename(image)).replace('.nii', '_synthseg.nii') for image in
path_images]
out_seg = [seg_path.replace('.mgz', '_synthseg.mgz') for seg_path in out_seg]
out_seg = [seg_path.replace('.npz', '_synthseg.npz') for seg_path in out_seg]
if out_posteriors is not None:
if (out_posteriors[-7:] == '.nii.gz') | (out_posteriors[-4:] == '.nii') | \
(out_posteriors[-4:] == '.mgz') | (out_posteriors[-4:] == '.npz'):
fs.fatal('Output folders cannot have extensions: .nii.gz, .nii, .mgz, or .npz, had %s' % out_posteriors)
mkdir(out_posteriors)
out_posteriors = [os.path.join(out_posteriors, os.path.basename(image)).replace('.nii',
'_posteriors.nii') for image in path_images]
@@ -282,6 +287,9 @@ def prepare_output_files(path_images, out_seg, out_posteriors, out_resampled, ou
else:
out_posteriors = [out_posteriors] * len(path_images)
if out_resampled is not None:
if (out_resampled[-7:] == '.nii.gz') | (out_resampled[-4:] == '.nii') | \
(out_resampled[-4:] == '.mgz') | (out_resampled[-4:] == '.npz'):
fs.fatal('Output folders cannot have extensions: .nii.gz, .nii, .mgz, or .npz, had %s' % out_resampled)
mkdir(out_resampled)
out_resampled = [os.path.join(out_resampled, os.path.basename(image)).replace('.nii',
'_resampled.nii') for image in path_images]
@@ -343,28 +351,27 @@ def preprocess_image(im_path, crop=192, n_levels=5, path_resample=None):
if n_channels > 1:
print('WARNING: detected more than 1 channel, only keeping the first channel.')
im = im[..., 0]
shape = shape[:n_dims]
if n_dims != 3:
fs.fatal('Input images should be 3D, found %s dimensions.' % n_dims)

# resample image if necessary
if np.any((np.array(im_res) > 1.05) | (np.array(im_res) < 0.95)):
im_res = np.array([1.]*3)
im, aff = resample_volume(im, aff, im_res)
shape = list(im.shape)
if path_resample is not None:
save_volume(im, aff, header, path_resample)

# align image
im = align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims)
shape = list(im.shape)

# pad image
crop = reformat_to_list(crop, length=n_dims, dtype='int')
im = pad_volume(im, padding_shape=crop)
pad_shape = im.shape[:n_dims]

# check that shape is divisible by 2**n_levels
if not all([pad_shape[i] >= crop[i] for i in range(len(pad_shape))]):
if not all([pad_shape[i] >= crop[i] for i in range(n_dims)]):
crop = [min(pad_shape[i], crop[i]) for i in range(n_dims)]
if not all([size % (2**n_levels) == 0 for size in crop]):
crop = [find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in crop]
@@ -873,8 +880,6 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
"""

mkdir(os.path.dirname(path))
if dtype is not None:
volume = volume.astype(dtype=dtype)
if '.npz' in path:
np.savez_compressed(path, vol_data=volume)
else:
@@ -886,6 +891,8 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
elif aff is None:
aff = np.eye(4)
nifty = nib.Nifti1Image(volume, aff, header)
if dtype is not None:
nifty.set_data_dtype(dtype)
if res is not None:
if n_dims is None:
n_dims, _ = get_dims(volume.shape)
@@ -1364,21 +1371,21 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret
# find cropping indices
if cropping_margin is not None:
cropping_margin = reformat_to_list(cropping_margin, length=n_dims)
min_crop_idx = cropping_margin
max_crop_idx = [vol_shape[i] - cropping_margin[i] for i in range(n_dims)]
assert (np.array(max_crop_idx) >= np.array(min_crop_idx)).all(), 'cropping_margin is larger than volume shape'
do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin)
min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)]
max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)]
else:
cropping_shape = reformat_to_list(cropping_shape, length=n_dims)
if mode == 'center':
min_crop_idx = [int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)]
max_crop_idx = [min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)]
min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0)
max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)],
np.array(vol_shape)[:n_dims])
elif mode == 'random':
crop_max_val = np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)])
crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0)
min_crop_idx = np.random.randint(0, high=crop_max_val + 1)
max_crop_idx = min_crop_idx + np.array(cropping_shape)
max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims])
else:
raise ValueError('mode should be either "center" or "random", had %s' % mode)
assert (np.array(min_crop_idx) >= 0).all(), 'cropping_shape is larger than volume shape'
crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)])

# crop volume
@@ -1444,13 +1451,12 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx=
padding_shape = reformat_to_list(padding_shape, length=n_dims, dtype='int')

# check if need to pad
if not np.array_equal(np.array(padding_shape, dtype='int32'), np.array(vol_shape[:n_dims], dtype='int32')):
if np.any(np.array(padding_shape, dtype='int32') > np.array(vol_shape[:n_dims], dtype='int32')):

# get padding margins
min_margins = np.maximum(np.int32(np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0)
max_margins = np.minimum(np.int32(np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)),
np.array(vol_shape)[:n_dims])
pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape)])
max_margins = np.maximum(np.int32(np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0)
pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])])
pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)])
if n_channels > 1:
pad_margins = tuple(list(pad_margins) + [[0, 0]])
@@ -1464,7 +1470,7 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx=
aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_margins

else:
pad_idx = np.concatenate([np.array([0] * n_dims), np.array(vol_shape)])
pad_idx = np.concatenate([np.array([0] * n_dims), np.array(vol_shape[:n_dims])])

# sort outputs
output = [new_volume]
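pad_volume now pads only when the requested shape is strictly larger than the current shape on some axis, rather than whenever the two shapes differ; the per-side margins are clamped at zero (previously the upper margin was clipped against the volume shape), and pad_idx is built from vol_shape[:n_dims] so a channel axis no longer corrupts the returned indices. When padding does occur, the unchanged context line above offsets the affine translation by -aff[:-1, :-1] @ min_margins so the original voxels keep their world coordinates. A standalone check of the margin arithmetic, with made-up shapes:

import numpy as np

vol_shape = np.array([150, 200, 170])
padding_shape = np.array([192, 192, 192])

min_margins = np.maximum(np.int32(np.floor((padding_shape - vol_shape) / 2)), 0)
max_margins = np.maximum(np.int32(np.ceil((padding_shape - vol_shape) / 2)), 0)
print(min_margins, max_margins)              # [21  0 11] [21  0 11]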
@@ -1474,6 +1480,7 @@ def flip_volume(volume, axis=None, direction=None, aff=None):
output.append(pad_idx)
return output[0] if len(output) == 1 else tuple(output)


def flip_volume(volume, axis=None, direction=None, aff=None):
"""Flip volume along a specified axis.
If unknown, this axis can be inferred from an affine matrix with a specified anatomical direction.
