Skip to content

Commit

Permalink
nf: samseg photo features
Browse files Browse the repository at this point in the history
  • Loading branch information
ahoopes committed Dec 23, 2021
1 parent 5403612 commit b58dadd
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 35 deletions.
2 changes: 1 addition & 1 deletion mri_synthseg/mri_synthseg
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import tensorflow as tf
from tensorflow import keras

# set tensorflow logging
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


# ================================================================================================
Expand Down
6 changes: 3 additions & 3 deletions python/freesurfer/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def recode(seg, mapping):

# this is such an ugly hack - we really shouldn't include
# this kind of support
if seg.__class__.__name__ in ('Tensor', 'EagerTensor'):
import neurite as ne
return ne.utils.seg.recode(seg, mapping)
#if seg.__class__.__name__ in ('Tensor', 'EagerTensor'):
# import neurite as ne
# return ne.utils.seg.recode(seg, mapping)

seg_data = seg.data if isinstance(seg, ArrayContainerTemplate) else seg
recoded = np.zeros_like(seg_data, dtype=np.int32)
Expand Down
59 changes: 45 additions & 14 deletions python/freesurfer/samseg/BiasField.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


class BiasField:
def __init__(self, imageSize, smoothingKernelSize):
self.fullBasisFunctions = self.getBiasFieldBasisFunctions(imageSize, smoothingKernelSize)
def __init__(self, imageSize, smoothingKernelSize, photo_mode=False):
self.fullBasisFunctions = self.getBiasFieldBasisFunctions(imageSize, smoothingKernelSize, photo_mode)
self.basisFunctions = self.fullBasisFunctions.copy()
self.coefficients = None

Expand Down Expand Up @@ -74,21 +74,49 @@ def computePrecisionOfKroneckerProductBasisFunctions(self, kroneckerProductBasis
precisionMatrix = result.reshape( ( np.prod( Ms ), np.prod( Ms ) ) )
return precisionMatrix

def getBiasFieldBasisFunctions(self, imageSize, smoothingKernelSize):
def getBiasFieldBasisFunctions(self, imageSize, smoothingKernelSize, photo_mode=False):
# Our bias model is a linear combination of a set of basis functions. We are using so-called
# "DCT-II" basis functions, i.e., the lowest few frequency components of the Discrete Cosine
# Transform.
biasFieldBasisFunctions = []
for dimensionNumber in range(3):
N = imageSize[dimensionNumber]
delta = smoothingKernelSize[dimensionNumber]
M = math.ceil(N / delta) + 1
Nvirtual = (M - 1) * delta
js = [(index + 0.5) * math.pi / Nvirtual for index in range(N)]
scaling = [math.sqrt(2 / Nvirtual)] * M
scaling[0] /= math.sqrt(2)
A = np.array([[math.cos(freq * m) * scaling[m] for m in range(M)] for freq in js])
biasFieldBasisFunctions.append(A)

# when we are working with reconstructed photos, we have a 2D basis per slice
if photo_mode:

if False:
for dimensionNumber in range(2):
N = imageSize[dimensionNumber]
A = np.ones((N, 1), dtype=np.float64)
biasFieldBasisFunctions.append(A)

N = imageSize[2]
A = np.identity(N, dtype=np.float64)
biasFieldBasisFunctions.append(A)

else:
for dimensionNumber in range(2):
N = imageSize[dimensionNumber]
A = np.empty((N, 3), dtype=np.float64)
A[:, 0] = np.ones((N), dtype=np.float64)
A[:, 1] = np.linspace(0, 1, N, dtype=np.float64)
A[:, 2] = np.flipud(A[:, 1])
biasFieldBasisFunctions.append(A)

N = imageSize[2]
A = np.identity(N, dtype=np.float64)
biasFieldBasisFunctions.append(A)

else:
for dimensionNumber in range(3):
N = imageSize[dimensionNumber]
delta = smoothingKernelSize[dimensionNumber]
M = math.ceil(N / delta) + 1
Nvirtual = (M - 1) * delta
js = [(index + 0.5) * math.pi / Nvirtual for index in range(N)]
scaling = [math.sqrt(2 / Nvirtual)] * M
scaling[0] /= math.sqrt(2)
A = np.array([[math.cos(freq * m) * scaling[m] for m in range(M)] for freq in js])
biasFieldBasisFunctions.append(A)

return biasFieldBasisFunctions

Expand All @@ -106,7 +134,7 @@ def getBiasFields(self, mask=None):

return biasFields

def fitBiasFieldParameters(self, imageBuffers, gaussianPosteriors, means, variances, mask):
def fitBiasFieldParameters(self, imageBuffers, gaussianPosteriors, means, variances, mask, photo_mode=False):

# Bias field correction: implements Eq. 8 in the paper
# Van Leemput, "Automated Model-based Bias Field Correction of MR Images of the Brain", IEEE TMI 1999
Expand Down Expand Up @@ -157,6 +185,9 @@ def fitBiasFieldParameters(self, imageBuffers, gaussianPosteriors, means, varian
tmpImageBuffer).reshape(-1, 1)

# Solve the linear system x = lhs \ rhs
# When working with photos, we like to regularize
if photo_mode:
lhs = lhs + 1e-5 * np.identity(lhs.shape[1], lhs.dtype)
solution = np.linalg.solve(lhs, rhs)

#
Expand Down
115 changes: 101 additions & 14 deletions python/freesurfer/samseg/Samseg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import pickle
import scipy.io
from scipy.ndimage.morphology import binary_dilation as dilation
import freesurfer as fs
import sys

Expand Down Expand Up @@ -38,10 +39,14 @@ def __init__(self,
saveModelProbabilities=False,
gmmFileName=None,
ignoreUnknownPriors=False,
dissectionPhoto=None,
nthreads=1,
):

# Store input parameters as class variables
self.imageFileNames = imageFileNames
self.originalImageFileNames = imageFileNames # Keep a copy since photo version modifies self.imageFileNames
self.photo_mask = None # Useful when working with photos
self.savePath = savePath
self.atlasDir = atlasDir
self.threshold = threshold
Expand All @@ -63,7 +68,17 @@ def __init__(self,
self.probabilisticAtlas = ProbabilisticAtlas()

# Get full model specifications and optimization options (using default unless overridden by user)
# Note that, when processing photos, we point to a different GMM file by default!
self.optimizationOptions = getOptimizationOptions(atlasDir, userOptimizationOptions)
if dissectionPhoto and (gmmFileName is None):
if dissectionPhoto == 'left':
gmmFileName = self.atlasDir + '/photo.lh.sharedGMMParameters.txt'
elif dissectionPhoto == 'right':
gmmFileName = self.atlasDir + '/photo.rh.sharedGMMParameters.txt'
elif dissectionPhoto == 'both':
gmmFileName = self.atlasDir + '/photo.both.sharedGMMParameters.txt'
else:
fs.fatal('dissection photo mode must be left, right, or both')
self.modelSpecifications = getModelSpecifications(
atlasDir,
userModelSpecifications,
Expand Down Expand Up @@ -105,6 +120,7 @@ def __init__(self,
self.saveWarp = saveWarp
self.saveMesh = saveMesh
self.ignoreUnknownPriors = ignoreUnknownPriors
self.dissectionPhoto = dissectionPhoto

# Make sure we can write in the target/results directory
os.makedirs(savePath, exist_ok=True)
Expand All @@ -122,6 +138,7 @@ def __init__(self,
self.optimizationHistory = None
self.deformation = None
self.deformationAtlasFileName = None
self.nthreads = nthreads

def validateTransform(self, trf):
# =======================================================================================
Expand Down Expand Up @@ -174,8 +191,43 @@ def segment(self, costfile=None, timer=None, reg_only=False, transformFile=None,
trf = self.validateTransform(fs.LinearTransform.read(transformFile))
worldToWorldTransformMatrix = convertRASTransformToLPS(trf.as_ras().matrix)

# Register to template
if self.dissectionPhoto is not None:
# Dissection photos are converted to grayscale
input_vol = fs.Volume.read(self.imageFileNames[0])
while len(input_vol.data.shape) > 3:
input_vol.data = np.mean(input_vol.data, axis=-1)
# We also a small band of noise around the mask; otherwise the background/skull/etc may fit the cortex
self.photo_mask = input_vol.data > 0
mask_dilated = dilation(self.photo_mask, iterations=5)
ring = (mask_dilated==True) & (self.photo_mask == False)
max_noise = np.max(input_vol.data) / 50.0
rng = np.random.default_rng(2021)
input_vol.data[ring] = max_noise * rng.random(input_vol.data[ring].shape[0])
self.imageFileNames = []
self.imageFileNames.append(self.savePath + '/grayscale.mgz')
input_vol.write(self.imageFileNames[0])

# Register to template, either with SAMSEG code, or externally with FreeSurfer tools (for photos)
if self.imageToImageTransformMatrix is None:

if self.dissectionPhoto is not None:
reference = self.imageFileNames[0]
if self.dissectionPhoto=='left':
moving = self.atlasDir + '/exvivo.template.lh.suptent.nii'
elif self.dissectionPhoto=='right':
moving = self.atlasDir + '/exvivo.template.rh.suptent.nii'
elif self.dissectionPhoto=='both':
moving = self.atlasDir + '/exvivo.template.suptent.nii'
else:
fs.fatal('dissection photo mode must be left, right, or both')
transformFile = self.savePath + '/atlas2image.lta'
cmd = 'mri_coreg --mov ' + moving + ' --ref ' + reference + ' --reg ' + transformFile + \
' --dof 12 --threads ' + str(self.nthreads)
os.system(cmd)
trf = fs.LinearTransform.read(transformFile)
trf_val = self.validateTransform(trf)
worldToWorldTransformMatrix = convertRASTransformToLPS(trf_val.as_ras().matrix)

self.register(
costfile=costfile,
timer=timer,
Expand Down Expand Up @@ -389,17 +441,29 @@ def writeResults(self, biasFields, posteriors):
# Write out various images - segmentation first
self.writeImage(segmentation, os.path.join(self.savePath, 'seg.mgz'), saveLabels=True)

for contrastNumber, imageFileName in enumerate(self.imageFileNames):
# Contrast-specific filename prefix
contastPrefix = os.path.join(self.savePath, self.modeNames[contrastNumber])

# Write bias field and bias-corrected image
self.writeImage(expBiasFields[..., contrastNumber], contastPrefix + '_bias_field.mgz')
self.writeImage(expImageBuffers[..., contrastNumber], contastPrefix + '_bias_corrected.mgz')

# Save a note indicating the scaling factor
with open(contastPrefix + '_scaling.txt', 'w') as fid:
print(scalingFactors[contrastNumber], file=fid)
# Bias corrected images: depends on whether we're dealing with MRIs or 3D photo reconstructions
if self.dissectionPhoto is None: # MRI
for contrastNumber, imageFileName in enumerate(self.imageFileNames):
# Contrast-specific filename prefix
contastPrefix = os.path.join(self.savePath, self.modeNames[contrastNumber])

# Write bias field and bias-corrected image
self.writeImage(expBiasFields[..., contrastNumber], contastPrefix + '_bias_field.mgz')
self.writeImage(expImageBuffers[..., contrastNumber], contastPrefix + '_bias_corrected.mgz')

# Save a note indicating the scaling factor
with open(contastPrefix + '_scaling.txt', 'w') as fid:
print(scalingFactors[contrastNumber], file=fid)

else: # photos
self.writeImage(expBiasFields[..., 0], self.savePath + '/illlumination_field.mgz')
original_vol = fs.Volume.read(self.originalImageFileNames[0])
bias_native = fs.Volume.read(self.savePath + '/illlumination_field.mgz')
if len(original_vol.data.shape) == 3:
original_vol.data = original_vol.data[..., np.newaxis]
original_vol.data = original_vol.data / (1e-6 + bias_native.data[..., np.newaxis])
original_vol.data = np.squeeze(original_vol.data)
original_vol.write(self.savePath + '/illlumination_corrected.mgz')

if self.savePosteriors:
posteriorPath = os.path.join(self.savePath, 'posteriors')
Expand Down Expand Up @@ -560,7 +624,7 @@ def initializeBiasField(self):
# Our bias model is a linear combination of a set of basis functions. We are using so-called "DCT-II" basis functions,
# i.e., the lowest few frequency components of the Discrete Cosine Transform.
self.biasField = BiasField(self.imageBuffers.shape[0:3],
self.modelSpecifications.biasFieldSmoothingKernelSize / self.voxelSpacing)
self.modelSpecifications.biasFieldSmoothingKernelSize / self.voxelSpacing, photo_mode=(self.dissectionPhoto is not None))

# Visualize some stuff
if hasattr(self.visualizer, 'show_flag'):
Expand Down Expand Up @@ -636,6 +700,10 @@ def estimateModelParameters(self, initialBiasFieldCoefficients=None, initialDefo
[
multiResolutionLevel].targetDownsampledVoxelSpacing / self.voxelSpacing))
downSamplingFactors[downSamplingFactors < 1] = 1
# When working with 3D reconstructed photos, we don't downsample in z
if self.dissectionPhoto is not None:
downSamplingFactors[2] = 1

downSampledImageBuffers, downSampledMask, downSampledMesh, downSampledInitialDeformationApplied, \
downSampledTransform = self.getDownSampledModel(
optimizationOptions.multiResolutionSpecification[multiResolutionLevel].atlasFileName,
Expand Down Expand Up @@ -728,7 +796,8 @@ def estimateModelParameters(self, initialBiasFieldCoefficients=None, initialDefo
if (estimateBiasField and not ((iterationNumber == 0)
and skipBiasFieldParameterEstimationInFirstIteration)):
self.biasField.fitBiasFieldParameters(downSampledImageBuffers, downSampledGaussianPosteriors,
self.gmm.means, self.gmm.variances, downSampledMask)
self.gmm.means, self.gmm.variances, downSampledMask,
photo_mode=(self.dissectionPhoto is not None))
# End test if bias field update

# End loop over EM iterations
Expand Down Expand Up @@ -849,6 +918,24 @@ def computeFinalSegmentation(self):
priors[:, unknown_label] += priors[:, label]
priors[:, label] = 0

# In dissection photos, we merge the choroid with the lateral ventricle
if self.dissectionPhoto is not None:
for n in range(len(self.modelSpecifications.names)):
if self.modelSpecifications.names[n]=='Left-Lateral-Ventricle':
llv = n
elif self.modelSpecifications.names[n]=='Left-choroid-plexus':
lcp = n
elif self.modelSpecifications.names[n]=='Right-Lateral-Ventricle':
rlv = n
elif self.modelSpecifications.names[n]=='Right-choroid-plexus':
rcp = n
if self.dissectionPhoto=='left' or self.dissectionPhoto=='both':
priors[:, llv] += priors[:, lcp]
priors[:, lcp] = 0
if self.dissectionPhoto=='right' or self.dissectionPhoto=='both':
priors[:, rlv] += priors[:, rcp]
priors[:, rcp] = 0

# Get bias field corrected data
# Make sure that the bias field basis function are not downsampled
# (this might happens if the parameters estimation is made only with one downsampled resolution)
Expand Down
17 changes: 15 additions & 2 deletions samseg/run_samseg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ parser.add_argument('--lesion-pseudo-samples', nargs=2, type=float, default=[500
parser.add_argument('--lesion-rho', type=float, default=50, help='Lesion ratio.')
parser.add_argument('--lesion-mask-structure', default='Cortex', help='Intensity mask brain structure. Lesion segmentation must be enabled.')
parser.add_argument('--lesion-mask-pattern', type=int, nargs='+', help='Lesion mask list (set value for each input volume): -1 below lesion mask structure mean, +1 above, 0 no mask. Lesion segmentation must be enabled.')
# optional options for segmenting 3D reconstructions of photo volumes
parser.add_argument('--dissection-photo', default=None, help='Use this flag for 3D reconstructed photos, and specify hemispheres that are present in the volumes: left, right, or both')

# optional debugging options
parser.add_argument('--history', action='store_true', default=False, help='Save history.')
parser.add_argument('--save-posteriors', nargs='*', help='Save posterior volumes to the "posteriors" subdirectory.')
Expand Down Expand Up @@ -104,13 +107,21 @@ else:

# ------ Run Samseg ------

# If we are dealing with photos, we skip rescaling of intensities and also force ignoreUnknownPriors=True
if args.dissection_photo is None:
intensityWM = 110
ignoreUnknownPriors = args.ignore_unknown
else:
intensityWM = None
ignoreUnknownPriors = True

samseg_kwargs = dict(
imageFileNames=args.inputFileNames,
atlasDir=atlasDir,
savePath=args.outputDirectory,
userModelSpecifications=userModelSpecifications,
userOptimizationOptions=userOptimizationOptions,
targetIntensity=110,
targetIntensity=intensityWM,
targetSearchStrings=[ 'Cerebral-White-Matter' ],
visualizer=visualizer,
saveHistory=args.history,
Expand All @@ -121,7 +132,9 @@ samseg_kwargs = dict(
pallidumAsWM=(not args.pallidum_separate),
saveModelProbabilities=args.save_probabilities,
gmmFileName=args.gmm,
ignoreUnknownPriors=args.ignore_unknown,
ignoreUnknownPriors=ignoreUnknownPriors,
dissectionPhoto=args.dissection_photo,
nthreads=args.threads,
)

if args.lesion:
Expand Down

0 comments on commit b58dadd

Please sign in to comment.