Skip to content

Commit

Permalink
samseg refactor, removed duplicate code (freesurfer#854)
Browse files Browse the repository at this point in the history
  • Loading branch information
ste93ste authored May 3, 2021
1 parent 12d7fe6 commit 96dacda
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 87 deletions.
74 changes: 0 additions & 74 deletions python/bindings/gems/pyKvlMesh.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -443,80 +443,6 @@ py::array KvlMesh::RasterizeValues(std::vector<size_t> size, py::array_t<double,
return createNumpyArrayFStyle(size, buffer);
}

py::array_t<double> KvlMesh::CollectLabelStatisticsInMeshNodes(const py::array_t<uint16_t, py::array::f_style | py::array::forcecast> multiAlphaImageBuffer) {

// Determine the size of the image to be created
typedef kvl::AtlasMeshProbabilityImageStatisticsCollector::ProbabilityImageType ProbabilityImageType;
typedef ProbabilityImageType::SizeType SizeType;
SizeType imageSize;
for ( int i = 0; i < 3; i++ )
{
imageSize[ i ] = multiAlphaImageBuffer.shape( i );
std::cout << "imageSize[ i ]: " << imageSize[ i ] << std::endl;
}

// Allocate an image of that size
const unsigned int numberOfClasses = multiAlphaImageBuffer.shape( 3 );
std::cout << "numberOfClasses: " << numberOfClasses << std::endl;
ProbabilityImageType::Pointer probabilityImage = ProbabilityImageType::New();
probabilityImage->SetRegions( imageSize );
probabilityImage->Allocate();
ProbabilityImageType::PixelType emptyEntry( numberOfClasses );
emptyEntry.Fill( 0.0f );
probabilityImage->FillBuffer( emptyEntry );

// Fill in -- relying on the fact that we've guaranteed a F-style Numpy array input
auto *data = multiAlphaImageBuffer.data();
for ( int classNumber = 0; classNumber < numberOfClasses; classNumber++ )
{
// Loop over all voxels
itk::ImageRegionIterator< ProbabilityImageType > it( probabilityImage,
probabilityImage->GetBufferedRegion() );
for ( ;!it.IsAtEnd(); ++it, ++data )
{
it.Value()[ classNumber ] = static_cast< float >( *data ) / 65535.0;
}

}

std::cout << "Created and filled probabilityImage" << std::endl;


// Retrieve input mesh
kvl::AtlasMesh::ConstPointer constMesh = static_cast< const kvl::AtlasMesh* >( mesh );
std::cout << "Got mesh" << std::endl;

// Collect statistics
kvl::AtlasMeshProbabilityImageStatisticsCollector::Pointer statisticsCollector =
kvl::AtlasMeshProbabilityImageStatisticsCollector::New();
statisticsCollector->SetProbabilityImage( probabilityImage );
statisticsCollector->Rasterize( constMesh );


// Copy the computed statistics in the mesh nodes into a numpy array
const unsigned int numberOfNodes = constMesh->GetPoints()->Size();
auto *outData = new double[numberOfNodes * numberOfClasses];
auto dataIterator = outData;

for ( kvl::AtlasMeshProbabilityImageStatisticsCollector::StatisticsContainerType::ConstIterator
statIt = statisticsCollector->GetLabelStatistics()->Begin();
statIt != statisticsCollector->GetLabelStatistics()->End(); ++statIt)
{

for ( int classNumber = 0; classNumber < numberOfClasses; classNumber++ )
{
*dataIterator++ = statIt.Value()[ classNumber ];
} // End loop over classes

} // End loop over all points


// Also return the minLogLikelihood
const double minLogLikelihood = statisticsCollector->GetMinLogLikelihood();
std::cout << minLogLikelihood << std::endl; // TODO: return minLogLikelihood

return createNumpyArrayCStyle({numberOfNodes, numberOfClasses}, outData);
}

py::array_t<double> KvlMesh::FitAlphas( const py::array_t< uint16_t,
py::array::f_style | py::array::forcecast >&
Expand Down
3 changes: 1 addition & 2 deletions python/bindings/gems/pyKvlMesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ class KvlMesh {
void Scale(const SCALE_3D &scaling);
py::array_t<uint16_t> RasterizeMesh(std::vector<size_t> size, int classNumber=-1);
py::array RasterizeValues(std::vector<size_t> size, py::array_t<double, py::array::c_style | py::array::forcecast> values);
py::array_t<double> CollectLabelStatisticsInMeshNodes(const py::array_t<uint16_t, py::array::f_style | py::array::forcecast > multiAlphaImageBuffer);
py::array_t<double> FitAlphas( const py::array_t< uint16_t,
py::array_t<double> FitAlphas( const py::array_t< uint16_t,
py::array::f_style | py::array::forcecast >&
probabilityImageBuffer ) const;

Expand Down
1 change: 0 additions & 1 deletion python/bindings/gems/python.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ PYBIND11_MODULE(gemsbindings, m) {
.def_property_readonly("point_count", &KvlMesh::PointCount, py::return_value_policy::take_ownership)
.def_property("points", &KvlMesh::GetPointSet, &KvlMesh::SetPointSet, py::return_value_policy::take_ownership)
.def_property("alphas", &KvlMesh::GetAlphas, &KvlMesh::SetAlphas, py::return_value_policy::take_ownership)
.def("collect_label_statistics_nodes", &KvlMesh::CollectLabelStatisticsInMeshNodes, py::return_value_policy::take_ownership)
.def( "fit_alphas", &KvlMesh::FitAlphas, py::return_value_policy::take_ownership )
.def("scale", &KvlMesh::Scale, py::return_value_policy::take_ownership)
.def("rasterize_values", &KvlMesh::RasterizeValues, py::arg("shape"), py::arg("values"), py::return_value_policy::take_ownership)
Expand Down
12 changes: 2 additions & 10 deletions samseg/gems_compute_binary_atlas_probs
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,8 @@ for level, meshCollectionFile in enumerate(args.mesh_collections):
segmentationMap[:, :, :, 1] = segmentationImage * 65535
alphas = np.zeros([numberOfNodes, 2]) + 0.5
mesh = meshCollection.reference_mesh
mesh.alphas = alphas
mesh.points = nodePositions

for EMIterationNumber in range(20):
# E-step
# TODO: return also minLogLikelihood, now we are only printing it in the gems function
labelStatistics = mesh.collect_label_statistics_nodes(segmentationMap)
# M-step
alphas = labelStatistics / (np.expand_dims(np.sum(labelStatistics, axis=1) + eps, 1))
mesh.alphas = alphas
mesh.alphas = mesh.fit_alphas(segmentationMap)

# Show rasterized prior with updated alphas
if args.showfigs:
Expand All @@ -111,7 +103,7 @@ for level, meshCollectionFile in enumerate(args.mesh_collections):
print('====================================================================')

# Save label statistics of subject
labelStatisticsInMeshNodes[:, :, subjectNumber] = labelStatistics
labelStatisticsInMeshNodes[:, :, subjectNumber] = mesh.alphas.copy()

# Save label statistics in a npy file
np.save(os.path.join(args.out_dir, 'labelStatistics_atlas_%d' % level), labelStatisticsInMeshNodes)

0 comments on commit 96dacda

Please sign in to comment.