From b58dadd999fe0ade9b1079e995c4881733e27f55 Mon Sep 17 00:00:00 2001 From: Andrew Hoopes Date: Wed, 22 Dec 2021 22:58:36 -0500 Subject: [PATCH] nf: samseg photo features --- ...moothing2_down2_smoothingForAffine2.tar.gz | 2 +- mri_synthseg/mri_synthseg | 2 +- python/freesurfer/label.py | 6 +- python/freesurfer/samseg/BiasField.py | 59 ++++++--- python/freesurfer/samseg/Samseg.py | 115 +++++++++++++++--- samseg/run_samseg | 17 ++- 6 files changed, 166 insertions(+), 35 deletions(-) diff --git a/distribution/average/samseg/20Subjects_smoothing2_down2_smoothingForAffine2.tar.gz b/distribution/average/samseg/20Subjects_smoothing2_down2_smoothingForAffine2.tar.gz index 2ce64be88ab..e09f949ed13 120000 --- a/distribution/average/samseg/20Subjects_smoothing2_down2_smoothingForAffine2.tar.gz +++ b/distribution/average/samseg/20Subjects_smoothing2_down2_smoothingForAffine2.tar.gz @@ -1 +1 @@ -../../../.git/annex/objects/zG/F3/SHA256E-s51498227--a22429822fc36d02aa438459886f425f8bad989646c6c14479b8c6b769ee9cc5.tar.gz/SHA256E-s51498227--a22429822fc36d02aa438459886f425f8bad989646c6c14479b8c6b769ee9cc5.tar.gz \ No newline at end of file +../../../.git/annex/objects/1f/x2/SHA256E-s55538651--0dfd6c4d78b8fe9f703e5d338e852acf75bb2d7357dd2245d0cc2dbe07d95849.tar.gz/SHA256E-s55538651--0dfd6c4d78b8fe9f703e5d338e852acf75bb2d7357dd2245d0cc2dbe07d95849.tar.gz \ No newline at end of file diff --git a/mri_synthseg/mri_synthseg b/mri_synthseg/mri_synthseg index df2983e7e4f..b35825cae4a 100644 --- a/mri_synthseg/mri_synthseg +++ b/mri_synthseg/mri_synthseg @@ -22,7 +22,7 @@ import tensorflow as tf from tensorflow import keras # set tensorflow logging -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) # ================================================================================================ diff --git a/python/freesurfer/label.py b/python/freesurfer/label.py index 60032888f19..a89979c65ca 100644 --- a/python/freesurfer/label.py +++ b/python/freesurfer/label.py @@ -19,9 +19,9 @@ def recode(seg, mapping): # this is such an ugly hack - we really shouldn't include # this kind of support - if seg.__class__.__name__ in ('Tensor', 'EagerTensor'): - import neurite as ne - return ne.utils.seg.recode(seg, mapping) + #if seg.__class__.__name__ in ('Tensor', 'EagerTensor'): + # import neurite as ne + # return ne.utils.seg.recode(seg, mapping) seg_data = seg.data if isinstance(seg, ArrayContainerTemplate) else seg recoded = np.zeros_like(seg_data, dtype=np.int32) diff --git a/python/freesurfer/samseg/BiasField.py b/python/freesurfer/samseg/BiasField.py index 9c5aee0d73c..0f52e740a45 100644 --- a/python/freesurfer/samseg/BiasField.py +++ b/python/freesurfer/samseg/BiasField.py @@ -5,8 +5,8 @@ class BiasField: - def __init__(self, imageSize, smoothingKernelSize): - self.fullBasisFunctions = self.getBiasFieldBasisFunctions(imageSize, smoothingKernelSize) + def __init__(self, imageSize, smoothingKernelSize, photo_mode=False): + self.fullBasisFunctions = self.getBiasFieldBasisFunctions(imageSize, smoothingKernelSize, photo_mode) self.basisFunctions = self.fullBasisFunctions.copy() self.coefficients = None @@ -74,21 +74,49 @@ def computePrecisionOfKroneckerProductBasisFunctions(self, kroneckerProductBasis precisionMatrix = result.reshape( ( np.prod( Ms ), np.prod( Ms ) ) ) return precisionMatrix - def getBiasFieldBasisFunctions(self, imageSize, smoothingKernelSize): + def getBiasFieldBasisFunctions(self, imageSize, smoothingKernelSize, photo_mode=False): # Our bias model is a linear combination of a set of basis functions. We are using so-called # "DCT-II" basis functions, i.e., the lowest few frequency components of the Discrete Cosine # Transform. biasFieldBasisFunctions = [] - for dimensionNumber in range(3): - N = imageSize[dimensionNumber] - delta = smoothingKernelSize[dimensionNumber] - M = math.ceil(N / delta) + 1 - Nvirtual = (M - 1) * delta - js = [(index + 0.5) * math.pi / Nvirtual for index in range(N)] - scaling = [math.sqrt(2 / Nvirtual)] * M - scaling[0] /= math.sqrt(2) - A = np.array([[math.cos(freq * m) * scaling[m] for m in range(M)] for freq in js]) - biasFieldBasisFunctions.append(A) + + # when we are working with reconstructed photos, we have a 2D basis per slice + if photo_mode: + + if False: + for dimensionNumber in range(2): + N = imageSize[dimensionNumber] + A = np.ones((N, 1), dtype=np.float64) + biasFieldBasisFunctions.append(A) + + N = imageSize[2] + A = np.identity(N, dtype=np.float64) + biasFieldBasisFunctions.append(A) + + else: + for dimensionNumber in range(2): + N = imageSize[dimensionNumber] + A = np.empty((N, 3), dtype=np.float64) + A[:, 0] = np.ones((N), dtype=np.float64) + A[:, 1] = np.linspace(0, 1, N, dtype=np.float64) + A[:, 2] = np.flipud(A[:, 1]) + biasFieldBasisFunctions.append(A) + + N = imageSize[2] + A = np.identity(N, dtype=np.float64) + biasFieldBasisFunctions.append(A) + + else: + for dimensionNumber in range(3): + N = imageSize[dimensionNumber] + delta = smoothingKernelSize[dimensionNumber] + M = math.ceil(N / delta) + 1 + Nvirtual = (M - 1) * delta + js = [(index + 0.5) * math.pi / Nvirtual for index in range(N)] + scaling = [math.sqrt(2 / Nvirtual)] * M + scaling[0] /= math.sqrt(2) + A = np.array([[math.cos(freq * m) * scaling[m] for m in range(M)] for freq in js]) + biasFieldBasisFunctions.append(A) return biasFieldBasisFunctions @@ -106,7 +134,7 @@ def getBiasFields(self, mask=None): return biasFields - def fitBiasFieldParameters(self, imageBuffers, gaussianPosteriors, means, variances, mask): + def fitBiasFieldParameters(self, imageBuffers, gaussianPosteriors, means, variances, mask, photo_mode=False): # Bias field correction: implements Eq. 8 in the paper # Van Leemput, "Automated Model-based Bias Field Correction of MR Images of the Brain", IEEE TMI 1999 @@ -157,6 +185,9 @@ def fitBiasFieldParameters(self, imageBuffers, gaussianPosteriors, means, varian tmpImageBuffer).reshape(-1, 1) # Solve the linear system x = lhs \ rhs + # When working with photos, we like to regularize + if photo_mode: + lhs = lhs + 1e-5 * np.identity(lhs.shape[1], lhs.dtype) solution = np.linalg.solve(lhs, rhs) # diff --git a/python/freesurfer/samseg/Samseg.py b/python/freesurfer/samseg/Samseg.py index dfb7a6b692b..27c997338b4 100644 --- a/python/freesurfer/samseg/Samseg.py +++ b/python/freesurfer/samseg/Samseg.py @@ -1,6 +1,7 @@ import logging import pickle import scipy.io +from scipy.ndimage.morphology import binary_dilation as dilation import freesurfer as fs import sys @@ -38,10 +39,14 @@ def __init__(self, saveModelProbabilities=False, gmmFileName=None, ignoreUnknownPriors=False, + dissectionPhoto=None, + nthreads=1, ): # Store input parameters as class variables self.imageFileNames = imageFileNames + self.originalImageFileNames = imageFileNames # Keep a copy since photo version modifies self.imageFileNames + self.photo_mask = None # Useful when working with photos self.savePath = savePath self.atlasDir = atlasDir self.threshold = threshold @@ -63,7 +68,17 @@ def __init__(self, self.probabilisticAtlas = ProbabilisticAtlas() # Get full model specifications and optimization options (using default unless overridden by user) + # Note that, when processing photos, we point to a different GMM file by default! self.optimizationOptions = getOptimizationOptions(atlasDir, userOptimizationOptions) + if dissectionPhoto and (gmmFileName is None): + if dissectionPhoto == 'left': + gmmFileName = self.atlasDir + '/photo.lh.sharedGMMParameters.txt' + elif dissectionPhoto == 'right': + gmmFileName = self.atlasDir + '/photo.rh.sharedGMMParameters.txt' + elif dissectionPhoto == 'both': + gmmFileName = self.atlasDir + '/photo.both.sharedGMMParameters.txt' + else: + fs.fatal('dissection photo mode must be left, right, or both') self.modelSpecifications = getModelSpecifications( atlasDir, userModelSpecifications, @@ -105,6 +120,7 @@ def __init__(self, self.saveWarp = saveWarp self.saveMesh = saveMesh self.ignoreUnknownPriors = ignoreUnknownPriors + self.dissectionPhoto = dissectionPhoto # Make sure we can write in the target/results directory os.makedirs(savePath, exist_ok=True) @@ -122,6 +138,7 @@ def __init__(self, self.optimizationHistory = None self.deformation = None self.deformationAtlasFileName = None + self.nthreads = nthreads def validateTransform(self, trf): # ======================================================================================= @@ -174,8 +191,43 @@ def segment(self, costfile=None, timer=None, reg_only=False, transformFile=None, trf = self.validateTransform(fs.LinearTransform.read(transformFile)) worldToWorldTransformMatrix = convertRASTransformToLPS(trf.as_ras().matrix) - # Register to template + if self.dissectionPhoto is not None: + # Dissection photos are converted to grayscale + input_vol = fs.Volume.read(self.imageFileNames[0]) + while len(input_vol.data.shape) > 3: + input_vol.data = np.mean(input_vol.data, axis=-1) + # We also a small band of noise around the mask; otherwise the background/skull/etc may fit the cortex + self.photo_mask = input_vol.data > 0 + mask_dilated = dilation(self.photo_mask, iterations=5) + ring = (mask_dilated==True) & (self.photo_mask == False) + max_noise = np.max(input_vol.data) / 50.0 + rng = np.random.default_rng(2021) + input_vol.data[ring] = max_noise * rng.random(input_vol.data[ring].shape[0]) + self.imageFileNames = [] + self.imageFileNames.append(self.savePath + '/grayscale.mgz') + input_vol.write(self.imageFileNames[0]) + + # Register to template, either with SAMSEG code, or externally with FreeSurfer tools (for photos) if self.imageToImageTransformMatrix is None: + + if self.dissectionPhoto is not None: + reference = self.imageFileNames[0] + if self.dissectionPhoto=='left': + moving = self.atlasDir + '/exvivo.template.lh.suptent.nii' + elif self.dissectionPhoto=='right': + moving = self.atlasDir + '/exvivo.template.rh.suptent.nii' + elif self.dissectionPhoto=='both': + moving = self.atlasDir + '/exvivo.template.suptent.nii' + else: + fs.fatal('dissection photo mode must be left, right, or both') + transformFile = self.savePath + '/atlas2image.lta' + cmd = 'mri_coreg --mov ' + moving + ' --ref ' + reference + ' --reg ' + transformFile + \ + ' --dof 12 --threads ' + str(self.nthreads) + os.system(cmd) + trf = fs.LinearTransform.read(transformFile) + trf_val = self.validateTransform(trf) + worldToWorldTransformMatrix = convertRASTransformToLPS(trf_val.as_ras().matrix) + self.register( costfile=costfile, timer=timer, @@ -389,17 +441,29 @@ def writeResults(self, biasFields, posteriors): # Write out various images - segmentation first self.writeImage(segmentation, os.path.join(self.savePath, 'seg.mgz'), saveLabels=True) - for contrastNumber, imageFileName in enumerate(self.imageFileNames): - # Contrast-specific filename prefix - contastPrefix = os.path.join(self.savePath, self.modeNames[contrastNumber]) - - # Write bias field and bias-corrected image - self.writeImage(expBiasFields[..., contrastNumber], contastPrefix + '_bias_field.mgz') - self.writeImage(expImageBuffers[..., contrastNumber], contastPrefix + '_bias_corrected.mgz') - - # Save a note indicating the scaling factor - with open(contastPrefix + '_scaling.txt', 'w') as fid: - print(scalingFactors[contrastNumber], file=fid) + # Bias corrected images: depends on whether we're dealing with MRIs or 3D photo reconstructions + if self.dissectionPhoto is None: # MRI + for contrastNumber, imageFileName in enumerate(self.imageFileNames): + # Contrast-specific filename prefix + contastPrefix = os.path.join(self.savePath, self.modeNames[contrastNumber]) + + # Write bias field and bias-corrected image + self.writeImage(expBiasFields[..., contrastNumber], contastPrefix + '_bias_field.mgz') + self.writeImage(expImageBuffers[..., contrastNumber], contastPrefix + '_bias_corrected.mgz') + + # Save a note indicating the scaling factor + with open(contastPrefix + '_scaling.txt', 'w') as fid: + print(scalingFactors[contrastNumber], file=fid) + + else: # photos + self.writeImage(expBiasFields[..., 0], self.savePath + '/illlumination_field.mgz') + original_vol = fs.Volume.read(self.originalImageFileNames[0]) + bias_native = fs.Volume.read(self.savePath + '/illlumination_field.mgz') + if len(original_vol.data.shape) == 3: + original_vol.data = original_vol.data[..., np.newaxis] + original_vol.data = original_vol.data / (1e-6 + bias_native.data[..., np.newaxis]) + original_vol.data = np.squeeze(original_vol.data) + original_vol.write(self.savePath + '/illlumination_corrected.mgz') if self.savePosteriors: posteriorPath = os.path.join(self.savePath, 'posteriors') @@ -560,7 +624,7 @@ def initializeBiasField(self): # Our bias model is a linear combination of a set of basis functions. We are using so-called "DCT-II" basis functions, # i.e., the lowest few frequency components of the Discrete Cosine Transform. self.biasField = BiasField(self.imageBuffers.shape[0:3], - self.modelSpecifications.biasFieldSmoothingKernelSize / self.voxelSpacing) + self.modelSpecifications.biasFieldSmoothingKernelSize / self.voxelSpacing, photo_mode=(self.dissectionPhoto is not None)) # Visualize some stuff if hasattr(self.visualizer, 'show_flag'): @@ -636,6 +700,10 @@ def estimateModelParameters(self, initialBiasFieldCoefficients=None, initialDefo [ multiResolutionLevel].targetDownsampledVoxelSpacing / self.voxelSpacing)) downSamplingFactors[downSamplingFactors < 1] = 1 + # When working with 3D reconstructed photos, we don't downsample in z + if self.dissectionPhoto is not None: + downSamplingFactors[2] = 1 + downSampledImageBuffers, downSampledMask, downSampledMesh, downSampledInitialDeformationApplied, \ downSampledTransform = self.getDownSampledModel( optimizationOptions.multiResolutionSpecification[multiResolutionLevel].atlasFileName, @@ -728,7 +796,8 @@ def estimateModelParameters(self, initialBiasFieldCoefficients=None, initialDefo if (estimateBiasField and not ((iterationNumber == 0) and skipBiasFieldParameterEstimationInFirstIteration)): self.biasField.fitBiasFieldParameters(downSampledImageBuffers, downSampledGaussianPosteriors, - self.gmm.means, self.gmm.variances, downSampledMask) + self.gmm.means, self.gmm.variances, downSampledMask, + photo_mode=(self.dissectionPhoto is not None)) # End test if bias field update # End loop over EM iterations @@ -849,6 +918,24 @@ def computeFinalSegmentation(self): priors[:, unknown_label] += priors[:, label] priors[:, label] = 0 + # In dissection photos, we merge the choroid with the lateral ventricle + if self.dissectionPhoto is not None: + for n in range(len(self.modelSpecifications.names)): + if self.modelSpecifications.names[n]=='Left-Lateral-Ventricle': + llv = n + elif self.modelSpecifications.names[n]=='Left-choroid-plexus': + lcp = n + elif self.modelSpecifications.names[n]=='Right-Lateral-Ventricle': + rlv = n + elif self.modelSpecifications.names[n]=='Right-choroid-plexus': + rcp = n + if self.dissectionPhoto=='left' or self.dissectionPhoto=='both': + priors[:, llv] += priors[:, lcp] + priors[:, lcp] = 0 + if self.dissectionPhoto=='right' or self.dissectionPhoto=='both': + priors[:, rlv] += priors[:, rcp] + priors[:, rcp] = 0 + # Get bias field corrected data # Make sure that the bias field basis function are not downsampled # (this might happens if the parameters estimation is made only with one downsampled resolution) diff --git a/samseg/run_samseg b/samseg/run_samseg index db3959fab96..224201154f8 100755 --- a/samseg/run_samseg +++ b/samseg/run_samseg @@ -37,6 +37,9 @@ parser.add_argument('--lesion-pseudo-samples', nargs=2, type=float, default=[500 parser.add_argument('--lesion-rho', type=float, default=50, help='Lesion ratio.') parser.add_argument('--lesion-mask-structure', default='Cortex', help='Intensity mask brain structure. Lesion segmentation must be enabled.') parser.add_argument('--lesion-mask-pattern', type=int, nargs='+', help='Lesion mask list (set value for each input volume): -1 below lesion mask structure mean, +1 above, 0 no mask. Lesion segmentation must be enabled.') +# optional options for segmenting 3D reconstructions of photo volumes +parser.add_argument('--dissection-photo', default=None, help='Use this flag for 3D reconstructed photos, and specify hemispheres that are present in the volumes: left, right, or both') + # optional debugging options parser.add_argument('--history', action='store_true', default=False, help='Save history.') parser.add_argument('--save-posteriors', nargs='*', help='Save posterior volumes to the "posteriors" subdirectory.') @@ -104,13 +107,21 @@ else: # ------ Run Samseg ------ +# If we are dealing with photos, we skip rescaling of intensities and also force ignoreUnknownPriors=True +if args.dissection_photo is None: + intensityWM = 110 + ignoreUnknownPriors = args.ignore_unknown +else: + intensityWM = None + ignoreUnknownPriors = True + samseg_kwargs = dict( imageFileNames=args.inputFileNames, atlasDir=atlasDir, savePath=args.outputDirectory, userModelSpecifications=userModelSpecifications, userOptimizationOptions=userOptimizationOptions, - targetIntensity=110, + targetIntensity=intensityWM, targetSearchStrings=[ 'Cerebral-White-Matter' ], visualizer=visualizer, saveHistory=args.history, @@ -121,7 +132,9 @@ samseg_kwargs = dict( pallidumAsWM=(not args.pallidum_separate), saveModelProbabilities=args.save_probabilities, gmmFileName=args.gmm, - ignoreUnknownPriors=args.ignore_unknown, + ignoreUnknownPriors=ignoreUnknownPriors, + dissectionPhoto=args.dissection_photo, + nthreads=args.threads, ) if args.lesion: