Skip to content

Commit

Permalink
Merge branch 'dev' of github.com:freesurfer/freesurfer into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
Douglas Greve committed Jan 14, 2022
2 parents b9aed15 + 9cfa6f3 commit 16e731d
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 23 deletions.
19 changes: 13 additions & 6 deletions python/freesurfer/freeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def copy(self):
copied.tempdir = self.tempdir
return copied

def vol(self, volume, swap_batch_dim=False, **kwargs):
def vol(self, volume, swap_batch_dim=False, lut=None, **kwargs):
'''
Loads a volume in the sessions. If the volume provided is not a filepath,
then the input will be saved as a volume in a temporary directory. Any
Expand All @@ -83,7 +83,7 @@ def vol(self, volume, swap_batch_dim=False, **kwargs):
'''

# convert the input to a proper file (if it's not one already)
filename = self._vol_to_file(volume, swap_batch_dim=swap_batch_dim)
filename = self._vol_to_file(volume, swap_batch_dim=swap_batch_dim, lut=lut)
if filename is None:
return

Expand Down Expand Up @@ -226,7 +226,7 @@ def _kwargs_to_tags(self, kwargs):

return tags + extra_tags

def _vol_to_file(self, volume, name=None, force=None, ext='mgz', swap_batch_dim=False):
def _vol_to_file(self, volume, name=None, force=None, ext='mgz', lut=None, swap_batch_dim=False):
'''
Converts an unknown volume type (whether it's a filename, array, or
other object) into a valid file.
Expand Down Expand Up @@ -285,6 +285,11 @@ def _vol_to_file(self, volume, name=None, force=None, ext='mgz', swap_batch_dim=

# check if fs array container
if isinstance(volume, (Overlay, Image, Volume)):

# set lookup table
if (lut is not None) and (volume.lut is None) and np.issubdtype(volume.data.dtype, np.integer):
volume.lut = lut

volume.write(filename)
return filename

Expand Down Expand Up @@ -403,13 +408,15 @@ def fv(*args, **kwargs):
Args:
opts: Additional string of flags to add to the command.
background: Run freeview as a background process. Defaults to True.
lut: Provides LookupTable to segmentations without one.
swap_batch_dim: Move the first axis to the last if input is a numpy image array. Default is False.
kwargs: kwargs are forwarded to the Freeview.show() call.
'''
background = kwargs.pop('background', True)
opts = kwargs.pop('opts', '')
swap_batch_dim = kwargs.pop('swap_batch_dim', False)
geom = kwargs.pop('geom', None)
lut = kwargs.pop('lut', None)

# expand any nested lists/tuples within args
def flatten(deep):
Expand All @@ -425,17 +432,17 @@ def flatten(deep):
if isinstance(arg, str):
# try to guess filetype if string
if arg.endswith(('.mgz', '.mgh', '.nii.gz', '.nii')):
fv.vol(arg)
fv.vol(arg, lut=lut)
elif arg.startswith(('lh.', 'rh.')) or arg.endswith('.stl'):
fv.surf(arg)
else:
fv.vol(arg)
fv.vol(arg, lut=lut)
elif isinstance(arg, Surface):
# surface
fv.surf(arg)
else:
# assume anything else is a volume
fv.vol(arg, swap_batch_dim=swap_batch_dim)
fv.vol(arg, swap_batch_dim=swap_batch_dim, lut=lut)

fv.show(background=background, opts=opts, **kwargs)

Expand Down
9 changes: 5 additions & 4 deletions python/freesurfer/samseg/GMM.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def tiedGaussiansFit(self, data, gaussianPosteriors):
self.hyperMeans[self.gaussNumber2Tied] = self.means[self.gaussNumber1Tied]
self.hyperVariances[self.gaussNumber2Tied] = self.rho * self.variances[self.gaussNumber1Tied]

def sampleMeansAndVariancesConditioned(self, data, posterior, gaussianNumber,constraints=None):
def sampleMeansAndVariancesConditioned(self, data, posterior, gaussianNumber, rngNumpy=np.random.default_rng(), constraints=None):
tmpGmm = GMM([1], self.numberOfContrasts, self.useDiagonalCovarianceMatrices,
initialHyperMeans=np.array([self.hyperMeans[gaussianNumber]]),
initialHyperMeansNumberOfMeasurements=np.array([self.hyperMeansNumberOfMeasurements[gaussianNumber]]),
Expand All @@ -416,7 +416,8 @@ def sampleMeansAndVariancesConditioned(self, data, posterior, gaussianNumber,con

# Murphy, page 134 with v0 = hyperVarianceNumberOfMeasurements - numberOfContrasts - 2
variance = invwishart.rvs(N + tmpGmm.hyperVariancesNumberOfMeasurements[0] - self.numberOfContrasts - 2,
tmpGmm.variances[0] * (tmpGmm.hyperVariancesNumberOfMeasurements[0] + N))
tmpGmm.variances[0] * (tmpGmm.hyperVariancesNumberOfMeasurements[0] + N),
random_state=rngNumpy)

# If numberOfContrast is 1 force variance to be a (1,1) array
if self.numberOfContrasts == 1:
Expand All @@ -425,8 +426,8 @@ def sampleMeansAndVariancesConditioned(self, data, posterior, gaussianNumber,con
if self.useDiagonalCovarianceMatrices:
variance = np.diag(np.diag(variance))

mean = np.random.multivariate_normal(tmpGmm.means[0],
variance / (tmpGmm.hyperMeansNumberOfMeasurements[0] + N)).reshape(-1, 1)
mean = rngNumpy.multivariate_normal(tmpGmm.means[0],
variance / (tmpGmm.hyperMeansNumberOfMeasurements[0] + N)).reshape(-1, 1)
if constraints is not None:
def truncsample(mean, var, lower, upper):
from scipy.stats import truncnorm
Expand Down
12 changes: 8 additions & 4 deletions python/freesurfer/samseg/SamsegLesion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, imageFileNames, atlasDir, savePath, userModelSpecifications={
numberOfSamplingSteps=50, numberOfBurnInSteps=50,
numberOfPseudoSamplesMean=500, numberOfPseudoSamplesVariance=500, rho=50,
intensityMaskingPattern=None, intensityMaskingSearchString='Cortex', gmmFileName=None, sampler=True,
ignoreUnknownPriors=False,
ignoreUnknownPriors=False, randomSeed=12345,
):
Samseg.__init__(self, imageFileNames, atlasDir, savePath, userModelSpecifications, userOptimizationOptions,
imageToImageTransformMatrix, visualizer, saveHistory, savePosteriors,
Expand All @@ -37,6 +37,10 @@ def __init__(self, imageFileNames, atlasDir, savePath, userModelSpecifications={
self.intensityMaskingClassNumber = self.getClassNumber(intensityMaskingSearchString)
self.sampler = sampler

# Set random seed
self.seed = randomSeed
self.rngNumpy = np.random.default_rng(self.seed)

if intensityMaskingPattern is None:
raise ValueError('Intensity mask pattern must be set')
if len(intensityMaskingPattern) != len(imageFileNames):
Expand Down Expand Up @@ -156,7 +160,7 @@ def computeFinalSegmentation(self):

# Initialize the VAE tensorflow model and its various settings.
# Restore from checkpoint the VAE
vae = VAE(self.atlasDir, self.transform, imageSize)
vae = VAE(self.atlasDir, self.transform, imageSize, self.seed)

# Do the actual sampling of lesion, latent variables of the VAE model, and mean/variance of the lesion intensity model.
averagePosteriors = np.zeros_like(likelihoods)
Expand All @@ -177,7 +181,7 @@ def computeFinalSegmentation(self):

# Sample from the mean and variance, conditioned on the data and the lesion segmentation
mean, variance = self.gmm.sampleMeansAndVariancesConditioned(data, lesion[self.mask].reshape(-1, 1),
self.lesionGaussianNumber)
self.lesionGaussianNumber, self.rngNumpy)

# Sample from the lesion segmentation, conditioned on the data and the VAE latent variables
# (Implementation-wise the latter is encoded in the VAE prior). At the same time we also
Expand All @@ -202,7 +206,7 @@ def computeFinalSegmentation(self):
likelihoods[:, self.lesionStructureNumber] = self.gmm.getGaussianLikelihoods(data, mean, variance)
posteriors = effectivePriors * likelihoods
posteriors /= np.expand_dims(np.sum(posteriors, axis=1) + eps, 1)
sample = np.random.rand(numberOfVoxels) <= posteriors[:, self.lesionStructureNumber]
sample = self.rngNumpy.random(numberOfVoxels) <= posteriors[:, self.lesionStructureNumber]
lesion = np.zeros(imageSize)
lesion[self.mask] = sample

Expand Down
10 changes: 10 additions & 0 deletions python/freesurfer/samseg/SamsegLongitudinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self,
userOptimizationOptions={},
visualizer=None,
saveHistory=False,
saveMesh=None,
targetIntensity=None,
targetSearchStrings=None,
numberOfIterations=5,
Expand Down Expand Up @@ -126,6 +127,7 @@ def __init__(self,
self.visualizer = visualizer

self.saveHistory = saveHistory
self.saveMesh = saveMesh
self.saveSSTResults = saveSSTResults
self.updateLatentMeans = updateLatentMeans
self.updateLatentVariances = updateLatentVariances
Expand Down Expand Up @@ -679,6 +681,14 @@ def postProcess(self, saveWarp=False):
if saveWarp:
timepointModel.saveWarpField(os.path.join(timepointDir, 'template.m3z'))

# Save the final mesh collection
if self.saveMesh:
print('Saving the final mesh in template space')
deformedAtlasFileName = os.path.join(timepointModel.savePath, 'mesh.txt')
timepointModel.probabilisticAtlas.saveDeformedAtlas(timepointModel.modelSpecifications.atlasFileName,
deformedAtlasFileName, nodePositions)


self.timepointVolumesInCubicMm.append(volumesInCubicMm)

#
Expand Down
4 changes: 3 additions & 1 deletion python/freesurfer/samseg/SamsegLongitudinalLesion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self,
userOptimizationOptions={},
visualizer=None,
saveHistory=False,
saveMesh=None,
targetIntensity=None,
targetSearchStrings=None,
numberOfIterations=5,
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(self,
userOptimizationOptions=userOptimizationOptions,
visualizer=visualizer,
saveHistory=saveHistory,
saveMesh=saveMesh,
targetIntensity=targetIntensity,
targetSearchStrings=targetSearchStrings,
numberOfIterations=numberOfIterations,
Expand Down Expand Up @@ -153,4 +155,4 @@ def setLesionLatentVariables(self):
self.latentMeans[self.sstModel.lesionGaussianNumber] = self.sstModel.gmm.means[self.sstModel.wmGaussianNumber]
self.latentVariances[self.sstModel.lesionGaussianNumber] = self.rho * self.sstModel.gmm.variances[self.sstModel.wmGaussianNumber]
self.latentMeansNumberOfMeasurements[self.sstModel.lesionGaussianNumber] = self.numberOfPseudoSamplesMean
self.latentVariancesNumberOfMeasurements[self.sstModel.lesionGaussianNumber] = self.numberOfPseudoSamplesVariance
self.latentVariancesNumberOfMeasurements[self.sstModel.lesionGaussianNumber] = self.numberOfPseudoSamplesVariance
9 changes: 6 additions & 3 deletions python/freesurfer/samseg/VAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class VAE:
def __init__(self, atlasDir, transform, imageSize):
def __init__(self, atlasDir, transform, imageSize, seed):

self.sess = tf.Session()
self.imageSize = imageSize
Expand All @@ -29,6 +29,9 @@ def __init__(self, atlasDir, transform, imageSize):
self.trainToSubjectMat = transform.as_numpy_array @ trainToTemplateMat
self.subjectToTrainMat = np.linalg.inv(self.trainToSubjectMat)

# Set tf seed
self.seed = seed

# Create tf placeholder
self.lesionPlaceholder = tf.placeholder(tf.float32, [1, self.net_shape[0], self.net_shape[1], self.net_shape[2], 1])
self.net = self.run_net(self.lesionPlaceholder, self.net_shape)
Expand Down Expand Up @@ -227,7 +230,7 @@ def pad_up_to(self, t, max_in_dims, constant_values):
# here is the net function. The input goes through the encoder, we sample from it and then it goes through the decoder
def run_net(self, lesion, imageSize):
mu, sigma = self.get_encoder(lesion)
sample_latent = tf.random.normal(mu.shape, 0, 1) * sigma + mu
sample_latent = tf.random.normal(mu.shape, 0, 1, seed=self.seed) * sigma + mu
return self.get_decoder(sample_latent, imageSize)

def sample(self, lesion):
Expand All @@ -247,4 +250,4 @@ def sample(self, lesion):
# we pass to the function the inverse of trainToSubjectMat, so subjectToTrainMat
lesionPriorVAE = affine_transform(lesionVAETrainSpace, self.subjectToTrainMat, output_shape=self.imageSize, order=1)

return lesionPriorVAE
return lesionPriorVAE
12 changes: 7 additions & 5 deletions samseg/run_samseg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ parser.add_argument('--lesion-pseudo-samples', nargs=2, type=float, default=[500
parser.add_argument('--lesion-rho', type=float, default=50, help='Lesion ratio.')
parser.add_argument('--lesion-mask-structure', default='Cortex', help='Intensity mask brain structure. Lesion segmentation must be enabled.')
parser.add_argument('--lesion-mask-pattern', type=int, nargs='+', help='Lesion mask list (set value for each input volume): -1 below lesion mask structure mean, +1 above, 0 no mask. Lesion segmentation must be enabled.')
parser.add_argument('--random-seed', type=int, default=12345, help='Random seed.')
# optional options for segmenting 3D reconstructions of photo volumes
parser.add_argument('--dissection-photo', default=None, help='Use this flag for 3D reconstructed photos, and specify hemispheres that are present in the volumes: left, right, or both')

Expand Down Expand Up @@ -132,9 +133,7 @@ samseg_kwargs = dict(
pallidumAsWM=(not args.pallidum_separate),
saveModelProbabilities=args.save_probabilities,
gmmFileName=args.gmm,
ignoreUnknownPriors=ignoreUnknownPriors,
dissectionPhoto=args.dissection_photo,
nthreads=args.threads,
ignoreUnknownPriors=ignoreUnknownPriors
)

if args.lesion:
Expand All @@ -155,11 +154,14 @@ if args.lesion:
intensityMaskingPattern=lesion_mask_pattern,
numberOfBurnInSteps=args.burnin,
numberOfSamplingSteps=args.samples,
threshold=args.threshold
threshold=args.threshold,
randomSeed=args.random_seed
)

else:
samseg = samseg.Samseg(**samseg_kwargs)
samseg = samseg.Samseg(**samseg_kwargs,
dissectionPhoto=args.dissection_photo,
nthreads=args.threads)

_, _, _, optimizationSummary = samseg.segment(
costfile=costfile,
Expand Down
2 changes: 2 additions & 0 deletions samseg/run_samseg_long
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ parser.add_argument('--lesion-mask-pattern', type=int, nargs='+', help='Lesion m
parser.add_argument('-m', '--mode', nargs='+', help='Output basenames for the input image mode.')
parser.add_argument('-a', '--atlas', metavar='DIR', help='Point to an alternative atlas directory.')
parser.add_argument('--save-warp', action='store_true', help='Save the image->template warp fields.')
parser.add_argument('--save-mesh', action='store_true', help='Save the final mesh of each timepoint in template space.')
parser.add_argument('--save-posteriors', nargs='*', help='Save posterior volumes to the "posteriors" subdirectory.')
parser.add_argument('--pallidum-separate', action='store_true', default=False, help='Move pallidum outside of global white matter class. Use this flag when T2/flair is used.')
parser.add_argument('--threads', type=int, default=default_threads, help='Number of threads to use. Defaults to current OMP_NUM_THREADS or 1.')
Expand Down Expand Up @@ -85,6 +86,7 @@ samseg_kwargs = dict(
modeNames=args.mode,
pallidumAsWM=(not args.pallidum_separate),
savePosteriors=savePosteriors,
saveMesh=args.save_mesh,
visualizer=visualizer
)

Expand Down

0 comments on commit 16e731d

Please sign in to comment.