Skip to content

Commit

Permalink
update synthstrip with new model
Browse files Browse the repository at this point in the history
  • Loading branch information
ahoopes committed Dec 25, 2021
1 parent b58dadd commit 96cdeb4
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 25 deletions.
2 changes: 1 addition & 1 deletion mri_synth_strip/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
project(mri_synth_strip)

install_pyscript(mri_synth_strip)
install(FILES synthstrip.dtrans.gen2.00.mse.h5 DESTINATION models)
install(FILES synthstrip.dtrans.h5 DESTINATION models)
104 changes: 80 additions & 24 deletions mri_synth_strip/mri_synth_strip
Original file line number Diff line number Diff line change
Expand Up @@ -20,59 +20,115 @@ parser.add_argument('-m', '--mask', help='Output mask filename.')
parser.add_argument('-b', '--border', default=1, type=int, help='Mask border threshold. Default is 1.')
parser.add_argument('--model', help='Alternative model file.')
parser.add_argument('--uthresh', type=float, help='Intensity threshold to erase from input image.')
parser.add_argument('--norm', help='Save the conformed, normalized input data for debugging.')
parser.add_argument('-v', '--verbose', action='store_true', help='Verbose output for debugging.')
parser.add_argument('-g', '--gpu', help='GPU ID. CPU is used by default.')
parser.add_argument('--remove_neck', default=None, type=int,
help='remove bottom part of 2nd axis from listed locations to end')
parser.add_argument('--max-norm', action='store_true', help='Normalize by max intensity instead of 97th percentile.')
parser.add_argument('--remove-neck', type=int,
help='Remove bottom part of 2nd axis from listed locations to end.')
args = parser.parse_args()

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if args.verbose else '3'

# defer slow tensorflow imports for faster argument parsing
import neurite as ne
import voxelmorph as vxm
import tensorflow as tf
import tensorflow.keras.layers as KL

# set verbosity
if not args.verbose:
warnings.filterwarnings('ignore')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# configure model file
if args.model is not None:
model_file = args.model
else:
fshome = fs.fshome()
if fshome is None:
fs.fatal('FREESURFER_HOME must be configured to find model files.')
model_file = os.path.join(fshome, 'models', 'synthstrip.dtrans.h5')

# configure Strip model
# TODO: needs to be moved to neurite
class Strip(ne.tf.modelio.LoadableModel):
"""
SynthStrip model for learning subject-to-subject registration from images
with arbitrary contrasts synthesized from label maps.
"""
@ne.tf.modelio.store_config_args
def __init__(self,
inshape,
nb_unet_features=None,
nb_unet_levels=None,
unet_feat_mult=1,
nb_unet_conv_per_level=1,
src_feats=1,
segout=False):
ndims = len(inshape)
assert ndims in [1, 2, 3], 'ndims should be one of 1, 2, or 3. found: %d' % ndims
unet = vxm.networks.Unet(
inshape=(*inshape, src_feats),
nb_features=nb_unet_features,
nb_levels=nb_unet_levels,
feat_mult=unet_feat_mult,
nb_conv_per_level=nb_unet_conv_per_level,
name='strip_unet')
conv = getattr(KL, 'Conv%dD' % ndims)
if segout:
prob = conv(2, kernel_size=3, padding='same', name='seg_prob')(unet.output)
prob = tf.keras.layers.Softmax()(prob)
else:
prob = conv(1, kernel_size=3, padding='same', name='strip_prob')(unet.output)
super().__init__(inputs=unet.input, outputs=prob)

class SynthStripTrainer(ne.tf.modelio.LoadableModel):
"""
SynthStrip model for learning subject-to-subject registration from images
with arbitrary contrasts synthesized from label maps.
"""
@ne.tf.modelio.store_config_args
def __init__(self,
inshape,
labels_in,
labels_out,
old_synth=False,
synth_params={},
**kwargs):
inshape = tuple(inshape)
image = KL.Input(inshape)
strip_model = Strip(inshape, **kwargs)
prob = strip_model(image)
super().__init__(inputs=image, outputs=prob)

# load model
device, ngpus = ne.tf.utils.setup_device(args.gpu)
with tf.device(device):
model = SynthStripTrainer.load(model_file)

# load input volume
in_img = fs.Volume.read(args.input)

# threshold input
if args.uthresh is not None:
in_img.data[in_img.data > args.uthresh] = 0

# remove y-axis voxels
if args.remove_neck is not None:
print(f'removing neck from {args.remove_neck} to bottom of image')
in_img.data[:, args.remove_neck:, :] = 0 # remove some of neck
print(f'Removing neck from {args.remove_neck} to bottom of image.')
in_img.data[:, args.remove_neck:, :] = 0

# conform image and normalize
conf_img = in_img.reslice(1.0).fit_to_shape((256, 256, 256))
in_data = conf_img.data.astype('float32')

conf_img = in_img.conform(shape=(256, 256, 256), voxsize=1.0, dtype='float32', interp_method='nearest')
in_data = conf_img.data
in_data -= in_data.min()
in_data = np.clip(in_data / np.percentile(in_data, 97), 0, 1)

# save normalized input data
if args.norm:
conf_img.copy(in_data).write(args.norm)

# configure model file
if args.model is not None:
model_file = args.model
if args.max_norm:
in_data /= in_data.max()
else:
fshome = fs.fshome()
if fshome is None:
fs.fatal('FREESURFER_HOME must be configured to find model files.')
model_file = os.path.join(fshome, 'models', 'synthstrip.dtrans.gen2.00.mse.h5')
in_data = np.clip(in_data / np.percentile(in_data, 99), 0, 1)

# load model and predict
device, ngpus = ne.tf.utils.setup_device(args.gpu)
with tf.device(device):
model = ne.models.SynthStrip.load(model_file).get_strip_model()
pred = model.predict(in_data[np.newaxis, ..., np.newaxis]).squeeze()

# unconform the predicted mask
Expand All @@ -84,5 +140,5 @@ if args.mask:

# mask the input image
masked_img = in_img.copy()
masked_img.data[mask.data >= args.border] = 0
masked_img.data[mask.data.squeeze() >= args.border] = 0
masked_img.write(args.output)
Binary file removed mri_synth_strip/synthstrip.dtrans.gen2.00.mse.h5
Binary file not shown.
Binary file added mri_synth_strip/synthstrip.dtrans.h5
Binary file not shown.

0 comments on commit 96cdeb4

Please sign in to comment.