diff --git a/gems/Executables/kvlAtlasMeshBuilder.cxx b/gems/Executables/kvlAtlasMeshBuilder.cxx index 8a04eab901e..ff113ee07c7 100755 --- a/gems/Executables/kvlAtlasMeshBuilder.cxx +++ b/gems/Executables/kvlAtlasMeshBuilder.cxx @@ -88,14 +88,15 @@ AtlasMeshBuilder ::SetUp( const std::vector< LabelImageType::ConstPointer >& labelImages, const CompressionLookupTable* compressionLookupTable, const itk::Size< 3>& initialSize, - const std::vector< double >& initialStiffnesses ) + const std::vector< double >& initialStiffnesses, + const unsigned int maximumNumberOfIterations) { m_LabelImages = labelImages; m_CompressionLookupTable = compressionLookupTable; m_InitialSize = initialSize; m_InitialStiffnesses = initialStiffnesses; m_Mesher->SetUp( m_LabelImages, m_CompressionLookupTable, m_InitialSize, m_InitialStiffnesses ); - + m_MaximumNumberOfIterations = maximumNumberOfIterations; } diff --git a/gems/Executables/kvlAtlasMeshBuilder.h b/gems/Executables/kvlAtlasMeshBuilder.h index c681767d996..d044dc2d1ab 100755 --- a/gems/Executables/kvlAtlasMeshBuilder.h +++ b/gems/Executables/kvlAtlasMeshBuilder.h @@ -171,7 +171,8 @@ public : void SetUp( const std::vector< LabelImageType::ConstPointer >& labelImages, const CompressionLookupTable* compressionLookupTable, const itk::Size< 3>& initialSize, - const std::vector< double >& initialStiffnesses ); + const std::vector< double >& initialStiffnesses, + const unsigned int maximumNumberOfIterations ); // Get label images const std::vector< LabelImageType::ConstPointer >& GetLabelImages() const diff --git a/gems/Executables/kvlBuildAtlasMesh.cxx b/gems/Executables/kvlBuildAtlasMesh.cxx index 266280a8c54..b0d8ec0ac8c 100755 --- a/gems/Executables/kvlBuildAtlasMesh.cxx +++ b/gems/Executables/kvlBuildAtlasMesh.cxx @@ -115,7 +115,7 @@ int main( int argc, char** argv ) // Sanity check on input if ( argc < 8 ) { - std::cerr << "Usage: " << argv[ 0 ] << " numberOfUpsamplingSteps meshSizeX meshSizeY meshSizeZ stiffness logDirectory fileName1 [ fileName2 ... ]" << std::endl; + std::cerr << "Usage: " << argv[ 0 ] << " numberOfUpsamplingSteps meshSizeX meshSizeY meshSizeZ stiffness numberOfIterations edgeCollapseFactor logDirectory fileName1 [ fileName2 ... ]" << std::endl; return -1; } @@ -126,7 +126,7 @@ int main( int argc, char** argv ) // Retrieve the input parameters std::ostringstream inputParserStream; - for ( int argumentNumber = 1; argumentNumber < 7; argumentNumber++ ) + for ( int argumentNumber = 1; argumentNumber < 9; argumentNumber++ ) { inputParserStream << argv[ argumentNumber ] << " "; } @@ -136,16 +136,33 @@ int main( int argc, char** argv ) unsigned int meshSizeY; unsigned int meshSizeZ; double stiffness; + unsigned int numberOfIterations; + double edgeCollapseEncouragmentFactor; std::string logDirectory; - inputStream >> numberOfUpsamplingSteps >> meshSizeX >> meshSizeY >> meshSizeZ >> stiffness >> logDirectory; - + inputStream >> \ + numberOfUpsamplingSteps >> \ + meshSizeX >> meshSizeY >> meshSizeZ >> \ + stiffness >> \ + numberOfIterations >> \ + edgeCollapseEncouragmentFactor >> \ + logDirectory; + + std::cout << "kvlBuildAtlasMesh Command line params:" << std::endl; + std::cout << " numberOfUpsamplingSteps: " << numberOfUpsamplingSteps << std::endl; + std::cout << " meshSizeX: " << meshSizeX << std::endl; + std::cout << " meshSizeY: " << meshSizeY << std::endl; + std::cout << " meshSizeZ: " << meshSizeZ << std::endl; + std::cout << " stiffness: " << stiffness << std::endl; + std::cout << " numberOfIterations: " << numberOfIterations << std::endl; + std::cout << " edgeCollapseEncouragmentFactor: " << edgeCollapseEncouragmentFactor << std::endl; + std::cout << " logDirectory: " << logDirectory << std::endl; // Read the input images typedef kvl::CompressionLookupTable::ImageType LabelImageType; std::vector< LabelImageType::ConstPointer > labelImages; - for ( int argumentNumber = 7; argumentNumber < argc; argumentNumber++ ) + for ( int argumentNumber = 9; argumentNumber < argc; argumentNumber++ ) { - std::cout << "Reading input image: " << argv[ argumentNumber ] << std::endl + std::cout << "Reading input image: " << argv[ argumentNumber ] << std::endl; // Read the input image typedef itk::ImageFileReader< LabelImageType > ReaderType; ReaderType::Pointer reader = ReaderType::New(); @@ -174,7 +191,7 @@ int main( int argc, char** argv ) kvl::AtlasMeshBuilder::Pointer builder = kvl::AtlasMeshBuilder::New(); const itk::Size< 3 > initialSize = { meshSizeX, meshSizeY, meshSizeZ }; std::vector< double > initialStiffnesses( numberOfUpsamplingSteps+1, stiffness ); - builder->SetUp( labelImages, lookupTable, initialSize, initialStiffnesses ); + builder->SetUp( labelImages, lookupTable, initialSize, initialStiffnesses, numberOfIterations); builder->SetVerbose( false ); // Add some observers/callbacks @@ -252,11 +269,15 @@ int main( int argc, char** argv ) { std::cerr << "Couldn't read mesh from file " << explicitStartCollectionFileName << std::endl; return -1; + } + else + { + std::cout << "explicitStartCollection found; reading from: " << explicitStartCollectionFileName << std::endl; } } // If edgeCollapseEncouragmentFactor.txt exists in the current directory, read it's content - double edgeCollapseEncouragmentFactor = 1.0; + //double edgeCollapseEncouragmentFactor = 1.0; const std::string edgeCollapseEncouragmentFactorFileName = "edgeCollapseEncouragmentFactor.txt"; //if ( itksys::SystemTools::FileExists( edgeCollapseEncouragmentFactorFileName.c_str(), true ) ) // { diff --git a/infant/CMakeLists.txt b/infant/CMakeLists.txt index 18a0f63a791..48d67eb3dac 100644 --- a/infant/CMakeLists.txt +++ b/infant/CMakeLists.txt @@ -32,6 +32,7 @@ add_subdirectory(labelfusion) # Entrypoint for containers install(PROGRAMS docker/infant-container-entrypoint.bash DESTINATION bin) +install(PROGRAMS docker/infant-container-entrypoint-aws.bash DESTINATION bin) # install external niftyreg binaries if(MARTINOS_BUILD) diff --git a/infant/docker/infant-container-entrypoint-aws.bash b/infant/docker/infant-container-entrypoint-aws.bash new file mode 100755 index 00000000000..be4f57f0e8b --- /dev/null +++ b/infant/docker/infant-container-entrypoint-aws.bash @@ -0,0 +1,55 @@ +#!/bin/bash + +# TODO: remove `FS_INFANT_MODEL` (no longer needed?) +echo "===============================================================" +echo "ENVIRONMENT VARIABLES" +echo "" +echo "AWS_BATCH_JOB_ID: $AWS_BATCH_JOB_ID" +echo "AWS_BATCH_JQ_NAME: $AWS_BATCH_JQ_NAME" +echo "AWS_BATCH_CE_NAME: $AWS_BATCH_CE_NAME" +echo "---------------------------------------------------------------" +echo "FREESURFER_HOME: $FREESURFER_HOME" +echo "FS_INFANT_MODEL: $FS_INFANT_MODEL" +echo "SUBJECTS_DIR: $SUBJECTS_DIR" +echo "FS_SUB_NAME: $FS_SUB_NAME" +echo "SSCNN_MODEL_DIR: $SSCNN_MODEL_DIR" +echo "---------------------------------------------------------------" +echo "FS_NIFTI_INPUT_S3_FILEPATH: $FS_NIFTI_INPUT_S3_FILEPATH" +echo "FS_NIFTI_INPUT_LOCAL_FILEPATH: $FS_NIFTI_INPUT_LOCAL_FILEPATH" +echo "FS_OUTPUT_S3_FILEPATH: $FS_OUTPUT_S3_FILEPATH" +echo "===============================================================" + +# infant pipeline input must be ${SUBJECTS_DIR}:${FS_SUB_NAME}/mprage.nii.gz +if [ -n "${SUBJECTS_DIR}" ] && [ -n "${FS_SUB_NAME}" ]; then + echo "---------------------------------------------------------------" + echo "SUBJECTS_DIR and FS_SUB_NAME detected. Attempting to make dir" + echo "mkdir -p ${SUBJECTS_DIR}/${FS_SUB_NAME}" + mkdir -p ${SUBJECTS_DIR}/${FS_SUB_NAME} + echo "---------------------------------------------------------------" +fi + +if [ -n "${FS_NIFTI_INPUT_S3_FILEPATH}" ] && [ -n "${FS_NIFTI_INPUT_LOCAL_FILEPATH}" ]; then + echo "---------------------------------------------------------------" + echo "FS_NIFTI_INPUT_S3_FILEPATH and FS_NIFTI_INPUT_LOCAL_FILEPATH detected. Attempting to copy file locally" + echo "aws s3 cp ${FS_NIFTI_INPUT_S3_FILEPATH} ${FS_NIFTI_INPUT_LOCAL_FILEPATH}" + aws s3 cp $FS_NIFTI_INPUT_S3_FILEPATH $FS_NIFTI_INPUT_LOCAL_FILEPATH + echo "---------------------------------------------------------------" +fi + +# Symlink the volume_mounted model files to where FreeSurfer expects them +# PW 2021/11/18 No longer needed, since `SSCNN_MODEL_DIR` can now be used in `sscnn_skullstrip` +# ----------------------------------------------------------------------- +#mkdir -p $FREESURFER_HOME/average/sscnn_skullstripping +#ln -s $FS_INFANT_MODEL/sscnn_skullstrip/cor_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/cor_sscnn.h5 +#ln -s $FS_INFANT_MODEL/sscnn_skullstrip/ax_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/ax_sscnn.h5 +#ln -s $FS_INFANT_MODEL/sscnn_skullstrip/sag_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/sag_sscnn.h5 + +eval "$@" + +if [ -n "${FS_OUTPUT_S3_FILEPATH}" ]; then + echo "---------------------------------------------------------------" + echo "FS_OUTPUT_S3_FILEPATH detected. Attempting to copy the subjects_dir to s3:" + echo "aws s3 cp --recursive ${SUBJECTS_DIR}/${FS_SUB_NAME} ${FS_OUTPUT_S3_FILEPATH}" + aws s3 cp --recursive ${SUBJECTS_DIR}/${FS_SUB_NAME} ${FS_OUTPUT_S3_FILEPATH} + echo "---------------------------------------------------------------" +fi diff --git a/infant/docker/infant-container-entrypoint.bash b/infant/docker/infant-container-entrypoint.bash index d906c539be4..e46026dcbb0 100755 --- a/infant/docker/infant-container-entrypoint.bash +++ b/infant/docker/infant-container-entrypoint.bash @@ -1,9 +1,22 @@ #!/bin/bash +# TODO: remove `FS_INFANT_MODEL` (no longer needed?) +echo "===============================================================" +echo "ENVIRONMENT VARIABLES" +echo "" +echo "FREESURFER_HOME: $FREESURFER_HOME" +echo "FS_INFANT_MODEL: $FS_INFANT_MODEL" +echo "SUBJECTS_DIR: $SUBJECTS_DIR" +echo "FS_SUB_NAME: $FS_SUB_NAME" +echo "SSCNN_MODEL_DIR: $SSCNN_MODEL_DIR" +echo "===============================================================" + # Symlink the volume_mounted model files to where FreeSurfer expects them -mkdir -p $FREESURFER_HOME/average/sscnn_skullstripping -ln -s $FS_INFANT_MODEL/sscnn_skullstrip/cor_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/cor_sscnn.h5 -ln -s $FS_INFANT_MODEL/sscnn_skullstrip/ax_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/ax_sscnn.h5 -ln -s $FS_INFANT_MODEL/sscnn_skullstrip/sag_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/sag_sscnn.h5 +# PW 2021/11/18 No longer needed, since `SSCNN_MODEL_DIR` can now be used in `sscnn_skullstrip` +# ----------------------------------------------------------------------- +#mkdir -p $FREESURFER_HOME/average/sscnn_skullstripping +#ln -s $FS_INFANT_MODEL/sscnn_skullstrip/cor_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/cor_sscnn.h5 +#ln -s $FS_INFANT_MODEL/sscnn_skullstrip/ax_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/ax_sscnn.h5 +#ln -s $FS_INFANT_MODEL/sscnn_skullstrip/sag_sscnn.h5 $FREESURFER_HOME/average/sscnn_skullstripping/sag_sscnn.h5 eval "$@" diff --git a/python/freesurfer/samseg/prepareAtlasDirectory.py b/python/freesurfer/samseg/prepareAtlasDirectory.py old mode 100644 new mode 100755 index 71f66de2e85..8530f52b595 --- a/python/freesurfer/samseg/prepareAtlasDirectory.py +++ b/python/freesurfer/samseg/prepareAtlasDirectory.py @@ -1,12 +1,16 @@ +#!/usr/bin/env python3 + from freesurfer.samseg import initVisualizer, requireNumpyArray, gems from freesurfer.samseg.io import GMMparameter, kvlReadCompressionLookupTable, \ - kvlWriteCompressionLookupTable, kvlWriteSharedGMMParameters + kvlWriteCompressionLookupTable, kvlWriteSharedGMMParameters, \ + kvlReadSharedGMMParameters from freesurfer.samseg.merge_alphas import kvlGetMergingFractionsTable, kvlMergeAlphas import numpy as np from scipy import ndimage import matplotlib.pyplot as plt import os, shutil - +import sys +import argparse def readAndSimplifyCompressionLookupTable( compressionLookupTableFileName, uninterestingStructureSearchStrings=None ): @@ -254,5 +258,39 @@ def prepareAtlasDirectory( directoryName, return - - +def parse_args(args): + parser = argparse.ArgumentParser() + parser.add_argument("-a", "--atlasdir", required=True, + help="The atlas directory to create") + parser.add_argument("-m", "--mesh", required=True, + help="The filename of the mesh collection to use (output of kvlBuildAtlasMesh)") + parser.add_argument("-c", "--compression_lut", required=True, + help="The compression lookup table to use (output of kvlBuildAtlasMesh)") + parser.add_argument("-g", "--shared_gmm_params", required=True, + help="The filename of the shared GMM parameters to use") + parser.add_argument("-t", "--template", required=True, + help="The filename of the template nifti to use") + parser.add_argument("--show_figs", default=False) + return parser.parse_args() + +def main(argv): + args = parse_args(argv) + if not os.path.exists(args.mesh): + print("ERROR: Can't find the mesh file " + args.mesh) + if not os.path.exists(args.compression_lut): + print("ERROR: Can't find the compression LUT file " + args.compression_lut) + if not os.path.exists(args.shared_gmm_params): + print("ERROR: Can't find the shared GMM params file " + args.shared_gmm_params) + if not os.path.exists(args.template): + print("ERROR: Can't find the template file " + args.template) + + shared_gmm_params = kvlReadSharedGMMParameters(args.shared_gmm_params) + prepareAtlasDirectory(args.atlasdir, + args.mesh, + args.compression_lut, + shared_gmm_params, + args.template, + showFigures=args.show_figs) + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/python/requirements-extra.txt b/python/requirements-extra.txt index aa9d76e3858..159dccc576d 100644 --- a/python/requirements-extra.txt +++ b/python/requirements-extra.txt @@ -16,3 +16,7 @@ pandas matplotlib transforms3d scikit-image==0.16.2 + +## Samseg vis +pyqtgraph +PyQt5 diff --git a/samseg/recompute_atlas_probs b/samseg/recompute_atlas_probs deleted file mode 100755 index a18c59127a7..00000000000 --- a/samseg/recompute_atlas_probs +++ /dev/null @@ -1,241 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys -import numpy as np -import nibabel as nib -import argparse -import freesurfer as fs -import tempfile - -from freesurfer import samseg - -eps = np.finfo(float).eps - -parser = fs.utils.ArgumentParser() -parser.add_argument( - '--subjects-dir', - metavar='DIR', - help='Directory with saved SAMSEG runs with --history and --save-posteriors flags (defaults to $SUBJECTS_DIR).', - required=False) -# --atlas should be one of: -# - 20Subjects_smoothing2_down2_smoothingForAffine2 -# - 20Subjects_smoothing2_down2_smoothingForAffine2_lesion -# - 20Subjects_smoothing2_down2_smoothingForAffine2_lesion_wm_prior -# Or whatever is returned when running -# `find $FREESURFER_HOME/average/samseg/ -type d -maxdepth 1|sed 's/.*samseg\///'` -# --atlas also needs to match the atlas used when running the subjects in --subjects-dir; -parser.add_argument( - '--atlas', - help='The atlas used when samseg was run (default: 20Subjects_smoothing2_down2_smoothingForAffine2)', - default='20Subjects_smoothing2_down2_smoothingForAffine2') -parser.add_argument( - '--out-dir', - metavar='DIR', - help='Output directory (will create a temp dir if not specified)', - default='.') -# After running a dataset through samseg with `--save-posteriors`, find a list of -# valid lables using: -# `ls -1 $SUBJECTS_DIR/$SUB/mri/samseg/posteriors |sed 's/.mgz//' -parser.add_argument( - '--label-set', - nargs='+', - help='The label set to consider when recomputing (defaults to the full label set).' -) -parser.add_argument('--showfigs', action='store_true', default=False, help='Show figures during run.') -args = parser.parse_args() - -# Sanity Check Args -if args.subjects_dir is not None: - subjects_dir = args.subjects_dir -else: - subjects_dir = os.environ['SUBJECTS_DIR'] -if subjects_dir is None: - sys.exit("The flag --subjects-dir was not specified and the env var $SUBJECTS_DIR is not set") -if not os.path.isdir(subjects_dir): - sys.exit("Invalid subjects_dir: "+subjects_dir) - -fshome_dir = os.environ['FREESURFER_HOME'] -if fshome_dir is None: - sys.exit("The env var $FREESURFER_HOME is not set") -if not os.path.isdir(fshome_dir): - sys.exit("Invalid $FREESURFER_HOME env var: "+fshome_dir) - -atlas_dir = os.path.join(fshome_dir, 'average/samseg', args.atlas) -if not os.path.isdir(atlas_dir): - sys.exit("Can't find samseg atlas dir: "+atlas_dir) - -mesh_collections = [os.path.join(atlas_dir, 'atlas_level1.txt.gz'), - os.path.join(atlas_dir, 'atlas_level2.txt.gz')] -for mesh_file in mesh_collections: - if not os.path.isfile(mesh_file): - sys.exit("Can't find samseg mesh file: "+mesh_file) - -out_dir = args.out_dir -if not os.path.exists(out_dir): - print("Creating Output Directory: "+ out_dir) - os.makedirs(out_dir) - -if args.label_set is not None: - label_set = args.label_set -else: - # TODO: this should match the label set and order of the altas if re-writing alphas - label_set = [ - '3rd-Ventricle', - '4th-Ventricle', - '5th-Ventricle', - 'Brain-Stem', - 'CSF', - 'Fluid_Inside_Eyes', - 'Left-Accumbens-area', - 'Left-Amygdala', - 'Left-Caudate', - 'Left-Cerebellum-Cortex', - 'Left-Cerebellum-White-Matter', - 'Left-Cerebral-Cortex', - 'Left-Cerebral-White-Matter', - 'Left-choroid-plexus', - 'Left-Hippocampus', - 'Left-Inf-Lat-Vent', - 'Left-Lateral-Ventricle', - 'Left-Pallidum', - 'Left-Putamen', - 'Left-Thalamus', - 'Left-VentralDC', - 'Left-vessel', - 'non-WM-hypointensities', - 'Optic-Chiasm', - 'Right-Accumbens-area', - 'Right-Amygdala', - 'Right-Caudate', - 'Right-Cerebellum-Cortex', - 'Right-Cerebellum-White-Matter', - 'Right-Cerebral-Cortex', - 'Right-Cerebral-White-Matter', - 'Right-choroid-plexus', - 'Right-Hippocampus', - 'Right-Inf-Lat-Vent', - 'Right-Lateral-Ventricle', - 'Right-Pallidum', - 'Right-Putamen', - 'Right-Thalamus', - 'Right-VentralDC', - 'Right-vessel', - 'Skull', - 'Soft_Nonbrain_Tissue', - 'Unknown', - 'WM-hypointensities'] - -showfigs = args.showfigs - - -# From Doug: For the cortex project, we'll need at least cortex (lh and rh) and wm (lh and rh), and maybe lateral ventricles (lh and rh). -#label_set = ['Left-Cerebral-Cortex', 'Left-Cerebral-White-Matter', 'Left-Lateral-Ventricle', 'Right-Cerebral-Cortex', 'Right-Cerebral-White-Matter', 'Right-Lateral-Ventricle'] - -subjectList = [ pathname for pathname in os.listdir(subjects_dir) \ - if os.path.isdir(os.path.join(subjects_dir, pathname)) ] -numberOfSubjects = len(subjectList) -numberOfLabels = len(label_set) -print('Number of subjects: ' + str(numberOfSubjects)) -print('Number of labels: ' + str(numberOfLabels)) - -# Sweep through all subjects and labels and ensure all the files we need can -# be found -for subject in subjectList: - for label in label_set: - filename = os.path.join(subjects_dir, subject, 'mri/samseg/posteriors', label+'.mgz') - if not os.path.isfile(filename): - sys.exit("Can't find the file: "+filename) - -if showfigs: - visualizer = samseg.initVisualizer(True, True) -else: - visualizer = samseg.initVisualizer(False, False) - -# We need an init of the probabilistic segmentation class -# to call instance methods -atlas = samseg.ProbabilisticAtlas() - -for level, meshCollectionFile in enumerate(mesh_collections): - print("Working on mesh collection at level " + str(level + 1)) - - # Read mesh collection - print("Loading mesh collection at: " + str(meshCollectionFile)) - meshCollection = samseg.gems.KvlMeshCollection() - meshCollection.read(meshCollectionFile) - - # We are interested only on the reference mesh - mesh = meshCollection.reference_mesh - numberOfNodes = mesh.point_count - - # Define what we are interested in, i.e., the label statistics of lesion - labelStatisticsInMeshNodes = np.zeros([numberOfNodes, numberOfLabels, numberOfSubjects]) - - for subjectNumber, subjectDir in enumerate(subjectList): - print('Working on Subject:', subjectDir) - # Load the history file and model params - history_filename = os.path.join(subjects_dir, subjectDir, 'mri/samseg/history.p') - history = np.load(history_filename, allow_pickle=True) - modelSpecifications = history['input']['modelSpecifications'] - transform_matrix = history['transform'] - transform = samseg.gems.KvlTransform(samseg.requireNumpyArray(transform_matrix)) - deformations = history['historyWithinEachMultiResolutionLevel'][level]['deformation'] - deformationAtlasFileName = history['historyWithinEachMultiResolutionLevel'][level]['deformationAtlasFileName'] - cropping = history['cropping'] - - nodePositions = atlas.getMesh( - meshCollectionFile, - transform, - K=modelSpecifications.K, - initialDeformation=deformations, - initialDeformationMeshCollectionFileName=meshCollectionFile - ).points - # The image is cropped as well so the voxel coordinates - # do not exactly match with the original image, - # i.e., there's a shift. Let's undo that. - nodePositions += [slc.start for slc in cropping] - - # Loop through the mri/samseg/posteriors and load each volume in label_set - # into the numpy array `subject_posteriorMap_float` - for label_num, label_name in enumerate(label_set): - print('Loading labal:', label_name) - posterior_filename = os.path.join( - subjects_dir, subjectDir, - 'mri/samseg/posteriors', - label_name+'.mgz') - probImage = nib.load(posterior_filename).get_fdata() - # init np array - if label_num == 0: - subject_posteriorMap_float = np.zeros( - [probImage.shape[0], probImage.shape[1], probImage.shape[2], numberOfLabels], - np.float) - subject_posteriorMap_float[:,:,:,label_num] = probImage - # subject_posteriorMap_float is now fully loaded, noramlize so it sums to 1 across the label dimension - # If it's not already (TODO) - - # Estimate alphas representing the new posterior map, initialized with a flat prior - subject_posteriorMap_uint16 = np.uint16(subject_posteriorMap_float * 65535) - #alphas = np.zeros([numberOfNodes, numberOfLabels]) + 0.5 - mesh = meshCollection.reference_mesh - mesh.points = nodePositions - mesh.alphas = mesh.fit_alphas(subject_posteriorMap_uint16) - - # Show rasterized prior with updated alphas - #if showfigs: - # rasterizedPrior = mesh.rasterize(segmentationImage.shape, 1) / 65535 - # visualizer.show(images=rasterizedPrior) - - # Show progress to anyone who's watching - print('====================================================================') - print('') - print('subjectNumber: ' + str(subjectNumber + 1)) - print('') - print('====================================================================') - - # Save label statistics of subject - labelStatisticsInMeshNodes[:, :, subjectNumber] = mesh.alphas.copy() - - # Save label statistics in a npy file - np.save(os.path.join(out_dir, 'labelStatistics_atlas_%d' % level), labelStatisticsInMeshNodes) - - new_altlas_priors = np.linalg.norm(labelStatisticsInMeshNodes, axis=2) diff --git a/sscnn_skullstripping/sscnn_skullstrip b/sscnn_skullstripping/sscnn_skullstrip index c94fde9122c..be68241cfcb 100755 --- a/sscnn_skullstripping/sscnn_skullstrip +++ b/sscnn_skullstripping/sscnn_skullstrip @@ -43,7 +43,9 @@ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # get model files -model_dir = os.path.join(fs.fshome(), 'average', 'sscnn_skullstripping') +model_dir = os.environ.get('SSCNN_MODEL_DIR') +if model_dir is None: + model_dir = os.path.join(fs.fshome(), 'average', 'sscnn_skullstripping') if not os.path.exists(model_dir): model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'model_files') if not os.path.exists(model_dir):