Skip to content

Commit

Permalink
ENH: Add group-wise registration class skeleton with test
Browse files Browse the repository at this point in the history
  • Loading branch information
dzenanz committed Feb 8, 2024
1 parent f6ff16b commit 9fbfe00
Show file tree
Hide file tree
Showing 5 changed files with 391 additions and 0 deletions.
189 changes: 189 additions & 0 deletions include/itkANTsGroupwiseRegistration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*=========================================================================
*
* Copyright NumFOCUS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/
#ifndef itkANTsGroupwiseRegistration_h
#define itkANTsGroupwiseRegistration_h

#include "itkImageToImageFilter.h"
#include "itkANTSRegistration.h"

namespace itk
{

/** \class ANTsGroupwiseRegistration
*
* \brief Group-wise image registration method parameterized according to ANTsPy.
*
* Inputs are images to be registered, and optionally initial template.
* Outputs are the computed template image, and forward and inverse transforms
* for each of the input images when registered to the template.
*
* This is similar to ANTsPy build_template function:
* https://github.com/ANTsX/ANTsPy/blob/master/ants/registration/build_template.py
*
* \ingroup ANTsWasm
* \ingroup Registration
*
*/
template <typename TImage,
typename TTemplateImage = ::itk::Image<float, TImage::ImageDimension>,
typename TParametersValueType = double>
class ANTsGroupwiseRegistration : public ImageToImageFilter<TTemplateImage, TTemplateImage>
{
public:
ITK_DISALLOW_COPY_AND_MOVE(ANTsGroupwiseRegistration);

static constexpr unsigned int ImageDimension = TImage::ImageDimension;

using ImageType = TImage;
using PixelType = typename ImageType::PixelType;
using TemplateImageType = TTemplateImage;

using ParametersValueType = TParametersValueType;
using TransformType = Transform<TParametersValueType, ImageDimension, ImageDimension>;
using CompositeTransformType = CompositeTransform<ParametersValueType, ImageDimension>;
using OutputTransformType = CompositeTransformType;
using DecoratedOutputTransformType = DataObjectDecorator<OutputTransformType>;

/** Standard class aliases. */
using Self = ANTsGroupwiseRegistration<ImageType, TTemplateImage, ParametersValueType>;
using Superclass = ImageToImageFilter<TTemplateImage, TTemplateImage>;
using Pointer = SmartPointer<Self>;
using ConstPointer = SmartPointer<const Self>;

/** Run-time type information. */
itkTypeMacro(ANTsGroupwiseRegistration, ImageToImageFilter);

/** Standard New macro. */
itkNewMacro(Self);


/** Set/Get the initial template image. */
void
SetInitialTemplateImage(const TemplateImageType * initialTemplate)
{
this->SetNthInput(0, const_cast<TemplateImageType *>(initialTemplate)); // the primary input
}
const TemplateImageType *
GetInitialTemplateImage()
{
return this->GetNthInput(0); // the primary input
}

/** Get the optimal template image. */
TemplateImageType *
GetTemplateImage()
{
return GetOutput(0); // this is just the primary output
}

/** Returns the transform resulting from the registration process */
const OutputTransformType *
GetTransform(unsigned imageIndex) const
{
return nullptr;
// return this->GetOutput(imageIndex)->Get();
}

/** Set/Get the gradient step size for transform optimizers that use it. */
itkSetMacro(GradientStep, ParametersValueType);
itkGetMacro(GradientStep, ParametersValueType);

/** Set/Get smoothing for update field.
* This only affects transform which use a deformation field. */
itkSetMacro(BlendingWeight, ParametersValueType);
itkGetMacro(BlendingWeight, ParametersValueType);

/** Set/Get whether template update step uses the rigid component. */
itkSetMacro(UseNoRigid, bool);
itkGetMacro(UseNoRigid, bool);

/** Set/Get number of iterations for each pyramid level for SyN transforms.
* Shrink factors and smoothing sigmas for SyN are determined based on iterations. */
itkSetMacro(Iterations, unsigned int);
itkGetMacro(Iterations, unsigned int);

/** Set/Get the weight for each image. */
itkSetMacro(Weights, std::vector<ParametersValueType>);
itkGetConstReferenceMacro(Weights, std::vector<ParametersValueType>);

/** Set/Get the images to register. */
itkSetMacro(ImageList, std::vector<ImageType *>);
itkGetConstReferenceMacro(ImageList, std::vector<ImageType *>);

using ProcessObject::AddInput;
using ProcessObject::RemoveInput;
using ProcessObject::GetInput;

protected:
ANTsGroupwiseRegistration();
~ANTsGroupwiseRegistration() override = default;

/** Make a DataObject of the correct type to be used as the specified output. */
using DataObjectPointerArraySizeType = ProcessObject::DataObjectPointerArraySizeType;
using ProcessObject::DataObjectPointer;
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType) override;

using PairwiseType = ANTSRegistration<ImageType, ImageType, ParametersValueType>;

void
PrintSelf(std::ostream & os, Indent indent) const override;

void
GenerateData() override;

void
VerifyInputInformation() const override
{}

// helper function to create the right kind of concrete transform
template <typename TTransform>
static void
MakeOutputTransform(SmartPointer<TTransform> & ptr)
{
ptr = TTransform::New();
}

/** Sets the output to the provided forward transform. */
void
SetTransform(unsigned index, const OutputTransformType * transform)
{
return this->GetOutput(index)->Set(transform);
}

ParametersValueType m_GradientStep{ 0.2 };
ParametersValueType m_BlendingWeight{ 0.75 };
bool m_UseNoRigid{ true };
unsigned int m_Iterations{ 3 };

std::vector<ParametersValueType> m_Weights;
std::vector<ImageType *> m_ImageList;
typename PairwiseType::Pointer m_PairwiseRegistration{ PairwiseType::New() };

#ifdef ITK_USE_CONCEPT_CHECKING
static_assert(TImage::ImageDimension == TTemplateImage::ImageDimension,
"Template imagemust have the same dimension as the input images.");
static_assert(ImageDimension >= 2, "Images must be at least two-dimensional.");
#endif
};
} // namespace itk

#ifndef ITK_MANUAL_INSTANTIATION
# include "itkANTsGroupwiseRegistration.hxx"
#endif

#endif // itkANTsGroupwiseRegistration
106 changes: 106 additions & 0 deletions include/itkANTsGroupwiseRegistration.hxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*=========================================================================
*
* Copyright NumFOCUS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/
#ifndef itkANTsGroupwiseRegistration_hxx
#define itkANTsGroupwiseRegistration_hxx

#include <sstream>

#include "itkPrintHelper.h"

namespace itk
{

template <typename TImage, typename TTemplateImage, typename TParametersValueType>
ANTsGroupwiseRegistration<TImage, TTemplateImage, TParametersValueType>::ANTsGroupwiseRegistration()
{
SetPrimaryInputName("InitialTemplate");
this->SetPrimaryOutputName("OptimizedImage");

this->SetInput(0, TemplateImageType::New()); // empty initial template

this->GetMultiThreader()->SetMaximumNumberOfThreads(1); // registrations are already multi-threaded
}


template <typename TImage, typename TTemplateImage, typename TParametersValueType>
void
ANTsGroupwiseRegistration<TImage, TTemplateImage, TParametersValueType>::PrintSelf(std::ostream & os, Indent indent) const
{
using namespace print_helper;
Superclass::PrintSelf(os, indent);

os << indent << "GradientStep: " << this->m_GradientStep << std::endl;
os << indent << "BlendingWeight: " << this->m_BlendingWeight << std::endl;
os << indent << "UseNoRigid: " << (this->m_UseNoRigid ? "On" : "Off") << std::endl;
os << indent << "Iterations: " << this->m_Iterations << std::endl;
os << indent << "Weights: " << this->m_Weights << std::endl;

os << indent << "ImageList: " << std::endl;
unsigned i = 0;
for (const auto & image : this->m_ImageList)
{
os << indent.GetNextIndent() << "Image" << i++ << ": ";
image->Print(os, indent.GetNextIndent());
}

this->m_PairwiseRegistration->Print(os, indent);
}


template <typename TImage, typename TTemplateImage, typename TParametersValueType>
auto
ANTsGroupwiseRegistration<TImage, TTemplateImage, TParametersValueType>::MakeOutput(DataObjectPointerArraySizeType) -> DataObjectPointer
{
typename OutputTransformType::Pointer ptr;
Self::MakeOutputTransform(ptr);
typename DecoratedOutputTransformType::Pointer decoratedOutputTransform = DecoratedOutputTransformType::New();
decoratedOutputTransform->Set(ptr);
return decoratedOutputTransform;
}


template <typename TImage, typename TTemplateImage, typename TParametersValueType>
void
ANTsGroupwiseRegistration<TImage, TTemplateImage, TParametersValueType>::GenerateData()
{
this->AllocateOutputs();

this->UpdateProgress(0.01);

// TODO: reimplement stuff from:
// https://github.com/ANTsX/ANTsPy/blob/master/ants/registration/build_template.py

this->UpdateProgress(0.95);


// typename OutputTransformType::Pointer inverseTransform = OutputTransformType::New();
// if (forwardTransform->GetInverse(inverseTransform))
// {
// this->SetInverseTransform(inverseTransform);
// }
// else
// {
// this->SetInverseTransform(nullptr);
// }

this->UpdateProgress(1.0);
}

} // end namespace itk

#endif // itkANTsGroupwiseRegistration_hxx
7 changes: 7 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ itk_module_test()
set(ANTsWasmTests
itkANTSRegistrationTest.cxx
itkANTSRegistrationBasicTests.cxx
itkANTsGroupwiseRegistrationTest.cxx
)

CreateTestDriver(ANTsWasm "${ANTsWasm-Test_LIBRARIES}" "${ANTsWasmTests}")
Expand All @@ -13,6 +14,12 @@ itk_add_test(NAME itkANTSRegistrationBasicTests
itkANTSRegistrationBasicTests ${ITK_TEST_OUTPUT_DIR}
)

itk_add_test(NAME itkANTsGroupwiseRegistrationTest
COMMAND ANTsWasmTestDriver
itkANTsGroupwiseRegistrationTest ${CMAKE_CURRENT_LIST_DIR}/Input ${ITK_TEST_OUTPUT_DIR}
)


itk_add_test(NAME antsRegistrationTest_AffineScaleMasks
COMMAND ANTsWasmTestDriver
--compare
Expand Down
Loading

0 comments on commit 9fbfe00

Please sign in to comment.