Skip to content

Commit

Permalink
add samseg WM/cortex atlas smoothing
Browse files Browse the repository at this point in the history
  • Loading branch information
ahoopes committed Feb 1, 2022
1 parent 58b116b commit 4493c7d
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 28 deletions.
43 changes: 37 additions & 6 deletions python/freesurfer/samseg/ProbabilisticAtlas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
import numpy as np
from . import gems
import scipy.ndimage
import freesurfer as fs

from freesurfer.samseg import gems
from freesurfer.samseg.warp_mesh import kvlWarpMesh
from freesurfer.samseg.utilities import requireNumpyArray
import freesurfer as fs
import os


class ProbabilisticAtlas:
def __init__(self):
Expand All @@ -18,17 +21,45 @@ def __init__(self):
self.previousDeformationMesh = None


def getMesh(self, meshCollectionFileName,
def getMesh(self,
meshCollectionFileName,
transform=None,
K=None,
initialDeformation=None, initialDeformationMeshCollectionFileName=None,
returnInitialDeformationApplied=False):
initialDeformation=None,
initialDeformationMeshCollectionFileName=None,
returnInitialDeformationApplied=False,
competingStructures=None,
smoothingSigma=0):

# Get the mesh
mesh_collection = gems.KvlMeshCollection()
mesh_collection.read(meshCollectionFileName)
if K is not None:
mesh_collection.k = K

# Do competing structure smoothing if enabled
if smoothingSigma > 0 and competingStructures:

print(f'Smoothing competing atlas priors with sigma {smoothingSigma:.2f}')

# Get initial priors
size = np.array(mesh_collection.reference_position.max(axis=0) + 1.5, dtype=int)
priors = mesh_collection.reference_mesh.rasterize(size, -1)

# Smooth the cortex and WM alphas
for competingStructureNumbers in competingStructures:
miniPriors = priors[..., competingStructureNumbers]
weightsToReassign = np.sum(miniPriors, axis=-1, keepdims=True)
miniPriors = scipy.ndimage.gaussian_filter(miniPriors.astype(float), sigma=(smoothingSigma, smoothingSigma, smoothingSigma, 0))
miniPriors /= (np.sum(miniPriors, -1, keepdims=True) + 1e-12) # Normalize to sum to 1
miniPriors *= weightsToReassign
priors[..., competingStructureNumbers] = miniPriors

# Set the alphas
alphas = mesh_collection.reference_mesh.fit_alphas(priors)
mesh_collection.reference_mesh.alphas = alphas

# Transform
if transform:
mesh_collection.transform(transform)
else:
Expand Down
50 changes: 28 additions & 22 deletions python/freesurfer/samseg/Samseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,18 @@ def register(self, costfile=None, timer=None, reg_only=False, worldToWorldTransf
print('registration-only requested, so quiting now')
sys.exit()

def getMesh(self, *args, **kwargs):
"""
Load the atlas mesh and perform optional smoothing of the cortex and WM priors.
"""
if self.modelSpecifications.whiteMatterAndCortexSmoothingSigma > 0:
competingNames = [['Left-Cerebral-White-Matter', 'Left-Cerebral-Cortex'],
['Right-Cerebral-White-Matter', 'Right-Cerebral-Cortex']]
competingStructures = [[self.modelSpecifications.names.index(n) for n in names] for names in competingNames]
kwargs['competingStructures'] = competingStructures
kwargs['smoothingSigma'] = self.modelSpecifications.whiteMatterAndCortexSmoothingSigma
return self.probabilisticAtlas.getMesh(*args, **kwargs)

def preProcess(self):
# =======================================================================================
#
Expand Down Expand Up @@ -321,20 +333,15 @@ def preProcess(self):
self.voxelSpacing
)


# Let's prepare for the bias field correction that is part of the imaging model. It assumes
# an additive effect, whereas the MR physics indicate it's a multiplicative one - so we log
# transform the data first.
self.imageBuffers = logTransform(self.imageBuffers, self.mask)

mesh = self.probabilisticAtlas.getMesh(self.modelSpecifications.atlasFileName, self.transform)
priors = mesh.rasterize(self.imageBuffers.shape[:3], 4).astype(float)
self.writeImage(priors, os.path.join(self.savePath, 'priors-testing.mgz'))

# Visualize some stuff
if hasattr(self.visualizer, 'show_flag'):
self.visualizer.show(
mesh=self.probabilisticAtlas.getMesh(self.modelSpecifications.atlasFileName, self.transform),
mesh=self.getMesh(self.modelSpecifications.atlasFileName, self.transform),
shape=self.imageBuffers.shape,
window_id='samsegment mesh', title='Mesh',
names=self.modelSpecifications.names, legend_width=350)
Expand Down Expand Up @@ -512,7 +519,7 @@ def writeResults(self, biasFields, posteriors):

def saveWarpField(self, filename):
# extract node positions in image space
nodePositions = self.probabilisticAtlas.getMesh(
nodePositions = self.getMesh(
self.modelSpecifications.atlasFileName,
self.transform,
initialDeformation=self.deformation,
Expand All @@ -536,7 +543,7 @@ def saveWarpField(self, filename):
matrix = scipy.io.loadmat(matricesFileName)['imageToImageTransformMatrix']

# rasterize the final node coordinates (in image space) using the initial template mesh
mesh = self.probabilisticAtlas.getMesh(self.modelSpecifications.atlasFileName)
mesh = self.getMesh(self.modelSpecifications.atlasFileName)
coordmap = mesh.rasterize_values(templateGeom.shape, nodePositions)

# the rasterization is a bit buggy and some voxels are not filled - mark these as invalid
Expand All @@ -555,9 +562,9 @@ def saveGaussianProbabilities( self, probabilitiesPath ):
os.makedirs(probabilitiesPath, exist_ok=True)

# Get the class priors as dictated by the current mesh position
mesh = self.probabilisticAtlas.getMesh(self.modelSpecifications.atlasFileName, self.transform,
initialDeformation=self.deformation,
initialDeformationMeshCollectionFileName=self.deformationAtlasFileName)
mesh = self.getMesh(self.modelSpecifications.atlasFileName, self.transform,
initialDeformation=self.deformation,
initialDeformationMeshCollectionFileName=self.deformationAtlasFileName)
mergedAlphas = kvlMergeAlphas( mesh.alphas, self.classFractions )
mesh.alphas = mergedAlphas
classPriors = mesh.rasterize(self.imageBuffers.shape[0:3], -1)
Expand Down Expand Up @@ -625,15 +632,14 @@ def getDownSampledModel(self, atlasFileName, downSamplingFactors):
downSampledTransform = gems.KvlTransform(requireNumpyArray(downSamplingTransformMatrix @ self.transform.as_numpy_array))

# Get the mesh
downSampledMesh, downSampledInitialDeformationApplied = self.probabilisticAtlas.getMesh(atlasFileName,
downSampledTransform,
self.modelSpecifications.K,
self.deformation,
self.deformationAtlasFileName,
returnInitialDeformationApplied=True)
downSampledMesh, downSampledInitialDeformationApplied = self.getMesh(atlasFileName,
downSampledTransform,
self.modelSpecifications.K,
self.deformation,
self.deformationAtlasFileName,
returnInitialDeformationApplied=True)

return downSampledImageBuffers, downSampledMask, downSampledMesh, downSampledInitialDeformationApplied, \
downSampledTransform,
return downSampledImageBuffers, downSampledMask, downSampledMesh, downSampledInitialDeformationApplied, downSampledTransform

def initializeBiasField(self):

Expand Down Expand Up @@ -918,9 +924,9 @@ def estimateModelParameters(self, initialBiasFieldCoefficients=None, initialDefo

def computeFinalSegmentation(self):
# Get the final mesh
mesh = self.probabilisticAtlas.getMesh(self.modelSpecifications.atlasFileName, self.transform,
initialDeformation=self.deformation,
initialDeformationMeshCollectionFileName=self.deformationAtlasFileName)
mesh = self.getMesh(self.modelSpecifications.atlasFileName, self.transform,
initialDeformation=self.deformation,
initialDeformationMeshCollectionFileName=self.deformationAtlasFileName)

# Get the priors as dictated by the current mesh position
priors = mesh.rasterize(self.imageBuffers.shape[0:3], -1)
Expand Down
1 change: 1 addition & 0 deletions python/freesurfer/samseg/SamsegUtility.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def getModelSpecifications(atlasDir, userModelSpecifications={}, pallidumAsWM=Tr
'maskingDistance': 10.0, # distance in mm of how far into background the mask goes out
'K': 0.1, # stiffness of the mesh
'biasFieldSmoothingKernelSize': 50, # distance in mm of sinc function center to first zero crossing
'whiteMatterAndCortexSmoothingSigma': 0, # Sigma value to smooth the WM and cortex atlas priors
}

modelSpecifications.update(userModelSpecifications)
Expand Down
8 changes: 8 additions & 0 deletions samseg/run_samseg
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ parser.add_argument('--gmm', metavar='FILE', help='Point to an alternative GMM f
parser.add_argument('--ignore-unknown', action='store_true', help='Ignore final priors corresponding to unknown class.')
parser.add_argument('--options', metavar='FILE', help='Override advanced options via a json file.')
parser.add_argument('--pallidum-separate', action='store_true', default=False, help='Move pallidum outside of global white matter class. Use this flag when T2/flair is used.')
parser.add_argument('--smooth-wm-cortex-priors', type=float, help='Sigma value to smooth the WM and cortex atlas priors.')
parser.add_argument('--bias-field-smoothing-kernel', type=float, help='Distance in mm of sinc function center to first zero crossing.')
# optional lesion options
parser.add_argument('--lesion', action='store_true', default=False, help='Enable lesion segmentation (requires tensorflow).')
parser.add_argument('--threshold', type=float, default=0.3, help='Lesion threshold for final segmentation. Lesion segmentation must be enabled.')
Expand Down Expand Up @@ -106,6 +108,12 @@ if args.save_posteriors is not None and len(args.save_posteriors) == 0:
else:
savePosteriors = args.save_posteriors

if args.smooth_wm_cortex_priors is not None:
userModelSpecifications['whiteMatterAndCortexSmoothingSigma'] = args.smooth_wm_cortex_priors

if args.bias_field_smoothing_kernel is not None:
userModelSpecifications['biasFieldSmoothingKernelSize'] = args.bias_field_smoothing_kernel

# ------ Run Samseg ------

# If we are dealing with photos, we skip rescaling of intensities and also force ignoreUnknownPriors=True
Expand Down

0 comments on commit 4493c7d

Please sign in to comment.