Commit d69f08a8 authored by Floris Berendsen's avatar Floris Berendsen
Browse files

ENH: WIP started itkSyN Component

parent 17510368
......@@ -388,14 +388,14 @@ TEST_F(WBIRDemoTest, elastix_BS_NCC)
blueprint->AddConnection("ResultImageSink", "Controller", { {} }); // { { "NameOfInterface", { "AfterRegistrationInterface" } } } ;
blueprint->WriteBlueprint("elastix_BS_NCC.dot");
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
blueprint->WriteBlueprint(dataManager->GetOutputFile("elastix_BS_NCC.dot"));
// Instantiate SuperElastix
EXPECT_NO_THROW(superElastixFilter = SuperElastixFilterType::New());
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
// Set up the readers and writers
ImageReader2DType::Pointer fixedImageReader = ImageReader2DType::New();
fixedImageReader->SetFileName(dataManager->GetInputFile("coneA2d64.mhd"));
......@@ -484,14 +484,14 @@ TEST_F(WBIRDemoTest, elastix_BS_MSD)
blueprint->AddConnection("ResultImageSink", "Controller", { {} }); // { { "NameOfInterface", { "AfterRegistrationInterface" } } } ;
blueprint->WriteBlueprint("elastix_BS_MSD.dot");
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
blueprint->WriteBlueprint(dataManager->GetOutputFile("elastix_BS_MSD.dot"));
// Instantiate SuperElastix
EXPECT_NO_THROW(superElastixFilter = SuperElastixFilterType::New());
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
// Set up the readers and writers
ImageReader2DType::Pointer fixedImageReader = ImageReader2DType::New();
fixedImageReader->SetFileName(dataManager->GetInputFile("coneA2d64.mhd"));
......
#=========================================================================
#
# Copyright Leiden University Medical Center, Erasmus University Medical
# Center and contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#=========================================================================
set( MODULE ModuleItkSyNImageRegistrationMethod )
# Export include files
set( ${MODULE}_INCLUDE_DIRS
${${MODULE}_SOURCE_DIR}/include
)
# Collect header files for Visual Studio Project
file(GLOB ${MODULE}_HEADER_FILES "${${MODULE}_SOURCE_DIR}/include/*.*")
# Export libraries
set( ${MODULE}_LIBRARIES
${MODULE}
)
# Export tests
set( ${MODULE}_TESTS
${${MODULE}_SOURCE_DIR}/test/selxSyNRegistrationItkv4Test.cxx
)
# Module source files
set( ${MODULE}_SOURCE_FILES
)
# Compile library
add_library( ${MODULE} STATIC ${${MODULE}_SOURCE_FILES} ${${MODULE}_HEADER_FILES})
target_link_libraries( ${MODULE} ${SUPERELASTIX_LIBRARIES} )
/*=========================================================================
*
* Copyright Leiden University Medical Center, Erasmus University Medical
* Center and contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/
#ifndef selxItkSyNImageRegistrationMethodComponent_h
#define selxItkSyNImageRegistrationMethodComponent_h
#include "selxSuperElastixComponent.h"
#include "selxInterfaces.h"
#include "itkSyNImageRegistrationMethod.h"
#include "itkGradientDescentOptimizerv4.h"
#include "itkImageSource.h"
#include <itkTransformToDisplacementFieldFilter.h>
#include <string.h>
#include "selxMacro.h"
#include "itkComposeDisplacementFieldsImageFilter.h"
#include "itkGaussianExponentialDiffeomorphicTransform.h"
#include "itkGaussianExponentialDiffeomorphicTransformParametersAdaptor.h"
namespace selx
{
template <int Dimensionality, class TPixel>
class ItkSyNImageRegistrationMethodComponent :
public SuperElastixComponent<
Accepting< itkImageFixedInterface<Dimensionality, TPixel>,
itkImageMovingInterface<Dimensionality, TPixel>,
itkMetricv4Interface<Dimensionality, TPixel>,
>,
Providing< itkTransformInterface<double, Dimensionality>,
RunRegistrationInterface
>
>
{
public:
selxNewMacro(ItkSyNImageRegistrationMethodComponent, ComponentBase);
//itkStaticConstMacro(Dimensionality, unsigned int, Dimensionality);
ItkSyNImageRegistrationMethodComponent();
virtual ~ItkSyNImageRegistrationMethodComponent();
typedef TPixel PixelType;
// Get the type definitions from the interfaces
typedef typename itkImageFixedInterface<Dimensionality, TPixel>::ItkImageType FixedImageType;
typedef typename itkImageMovingInterface<Dimensionality, TPixel>::ItkImageType MovingImageType;
typedef typename itkTransformInterface<double, Dimensionality>::TransformType TransformType;
typedef typename itkTransformInterface<double, Dimensionality>::TransformPointer TransformPointer;
typedef typename itkTransformInterface<double, Dimensionality>::InternalComputationValueType TransformInternalComputationValueType; //should be from class template
typedef itk::SyNImageRegistrationMethod<FixedImageType, MovingImageType> TheItkFilterType;
typedef typename TheItkFilterType::ImageMetricType ImageMetricType;
typedef itk::RegistrationParameterScalesFromPhysicalShift<ImageMetricType> ScalesEstimatorType;
//Accepting Interfaces:
virtual int Set(itkImageFixedInterface<Dimensionality, TPixel>*) override;
virtual int Set(itkImageMovingInterface<Dimensionality, TPixel>*) override;
virtual int Set(itkMetricv4Interface<Dimensionality, TPixel>*) override;
//Providing Interfaces:
virtual TransformPointer GetItkTransform() override;
virtual void RunRegistration() override;
//BaseClass methods
virtual bool MeetsCriterion(const ComponentBase::CriterionType &criterion) override;
//static const char * GetName() { return "ItkSyNImageRegistrationMethod"; } ;
static const char * GetDescription() { return "ItkSyNImageRegistrationMethod Component"; };
private:
typename TheItkFilterType::Pointer m_theItkFilter;
protected:
/* The following struct returns the string name of computation type */
/* default implementation */
static inline const std::string GetTypeNameString()
{
itkGenericExceptionMacro(<< "Unknown ScalarType" << typeid(TPixel).name());
// TODO: provide the user instructions how to enable the compilation of the component with the required template types (if desired)
// We might define an exception object that can communicate various error messages: for simple user, for developer user, etc
}
static inline const std::string GetPixelTypeNameString()
{
itkGenericExceptionMacro(<< "Unknown PixelType" << typeid(TPixel).name());
// TODO: provide the user instructions how to enable the compilation of the component with the required template types (if desired)
// We might define an exception object that can communicate various error messages: for simple user, for developer user, etc
}
};
template <>
inline const std::string
ItkSyNImageRegistrationMethodComponent<2, float>
::GetPixelTypeNameString()
{
return std::string("float");
}
template <>
inline const std::string
ItkSyNImageRegistrationMethodComponent<2, double>
::GetPixelTypeNameString()
{
return std::string("double");
}
template <>
inline const std::string
ItkSyNImageRegistrationMethodComponent<3, float>
::GetPixelTypeNameString()
{
return std::string("float");
}
template <>
inline const std::string
ItkSyNImageRegistrationMethodComponent<3, double>
::GetPixelTypeNameString()
{
return std::string("double");
}
template <>
inline const std::string
ItkSyNImageRegistrationMethodComponent<2, float>
::GetTypeNameString()
{
return std::string("2_float");
}
template <>
inline const std::string
ItkSyNImageRegistrationMethodComponent<2, double>
::GetTypeNameString()
{
return std::string("2_double");
}
template <>
inline const std::string
ItkSyNImageRegistrationMethodComponent<3,float>
::GetTypeNameString()
{
return std::string("3_float");
}
template <>
inline const std::string
ItkSyNImageRegistrationMethodComponent<3,double>
::GetTypeNameString()
{
return std::string("3_double");
}
} //end namespace selx
#ifndef ITK_MANUAL_INSTANTIATION
#include "selxItkSyNImageRegistrationMethodComponent.hxx"
#endif
#endif // #define selxItkSyNImageRegistrationMethodComponent_h
\ No newline at end of file
/*=========================================================================
*
* Copyright Leiden University Medical Center, Erasmus University Medical
* Center and contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/
#include "selxItkSyNImageRegistrationMethodComponent.h"
//TODO: get rid of these
#include "itkGradientDescentOptimizerv4.h"
namespace selx
{
template<typename TFilter>
class CommandIterationUpdate : public itk::Command
{
public:
typedef CommandIterationUpdate Self;
typedef itk::Command Superclass;
typedef itk::SmartPointer<Self> Pointer;
itkNewMacro(Self);
typedef itk::GradientDescentOptimizerv4 OptimizerType;
typedef const OptimizerType * OptimizerPointer;
protected:
CommandIterationUpdate() {};
public:
virtual void Execute(itk::Object *caller, const itk::EventObject & event) ITK_OVERRIDE
{
Execute((const itk::Object *) caller, event);
}
virtual void Execute(const itk::Object * object, const itk::EventObject & event) ITK_OVERRIDE
{
const TFilter * filter = static_cast< const TFilter * >(object);
if (typeid(event) == typeid(itk::MultiResolutionIterationEvent))
{
unsigned int currentLevel = filter->GetCurrentLevel();
typename TFilter::ShrinkFactorsPerDimensionContainerType shrinkFactors = filter->GetShrinkFactorsPerDimension(currentLevel);
typename TFilter::SmoothingSigmasArrayType smoothingSigmas = filter->GetSmoothingSigmasPerLevel();
typename TFilter::TransformParametersAdaptorsContainerType adaptors = filter->GetTransformParametersAdaptorsPerLevel();
const itk::ObjectToObjectOptimizerBase * optimizerBase = filter->GetOptimizer();
typedef itk::GradientDescentOptimizerv4 GradientDescentOptimizerv4Type;
typename GradientDescentOptimizerv4Type::ConstPointer optimizer = dynamic_cast<const GradientDescentOptimizerv4Type *>(optimizerBase);
if (!optimizer)
{
itkGenericExceptionMacro("Error dynamic_cast failed");
}
typename GradientDescentOptimizerv4Type::DerivativeType gradient = optimizer->GetGradient();
/* orig
std::cout << " Current level = " << currentLevel << std::endl;
std::cout << " shrink factor = " << shrinkFactors[currentLevel] << std::endl;
std::cout << " smoothing sigma = " << smoothingSigmas[currentLevel] << std::endl;
std::cout << " required fixed parameters = " << adaptors[currentLevel]->GetRequiredFixedParameters() << std::endl;
*/
//debug:
std::cout << " CL Current level: " << currentLevel << std::endl;
std::cout << " SF Shrink factor: " << shrinkFactors << std::endl;
std::cout << " SS Smoothing sigma: " << smoothingSigmas[currentLevel] << std::endl;
std::cout << " RFP Required fixed params: " << adaptors[currentLevel]->GetRequiredFixedParameters() << std::endl;
std::cout << " LR Final learning rate: " << optimizer->GetLearningRate() << std::endl;
std::cout << " FM Final metric value: " << optimizer->GetCurrentMetricValue() << std::endl;
std::cout << " SC Optimizer scales: " << optimizer->GetScales() << std::endl;
std::cout << " FG Final metric gradient (sample of values): ";
if (gradient.GetSize() < 10)
{
std::cout << gradient;
}
else
{
for (itk::SizeValueType i = 0; i < gradient.GetSize(); i += (gradient.GetSize() / 16))
{
std::cout << gradient[i] << " ";
}
}
std::cout << std::endl;
}
else if (!(itk::IterationEvent().CheckEvent(&event)))
{
return;
}
else
{
OptimizerPointer optimizer = static_cast<OptimizerPointer>(object);
std::cout << optimizer->GetCurrentIteration() << ": " ;
std::cout << optimizer->GetCurrentMetricValue() << std::endl;
//std::cout << optimizer->GetInfinityNormOfProjectedGradient() << std::endl;
}
}
};
template<int Dimensionality, class TPixel>
ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::ItkSyNImageRegistrationMethodComponent()
{
m_theItkFilter = TheItkFilterType::New();
m_theItkFilter->InPlaceOn();
//TODO: instantiating the filter in the constructor might be heavy for the use in component selector factory, since all components of the database are created during the selection process.
// we could choose to keep the component light weighted (for checking criteria such as names and connections) until the settings are passed to the filter, but this requires an additional initialization step.
}
template<int Dimensionality, class TPixel>
ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::~ItkSyNImageRegistrationMethodComponent()
{
}
template<int Dimensionality, class TPixel>
int ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>
::Set(itkImageFixedInterface<Dimensionality, TPixel>* component)
{
auto fixedImage = component->GetItkImageFixed();
// connect the itk pipeline
this->m_theItkFilter->SetFixedImage(fixedImage);
return 0;
}
template<int Dimensionality, class TPixel>
int ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>
::Set(itkImageMovingInterface<Dimensionality, TPixel>* component)
{
auto movingImage = component->GetItkImageMoving();
// connect the itk pipeline
this->m_theItkFilter->SetMovingImage(movingImage);
return 0;
}
template<int Dimensionality, class TPixel>
int ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::Set(itkMetricv4Interface<Dimensionality, TPixel>* component)
{
this->m_theItkFilter->SetMetric(component->GetItkMetricv4());
return 0;
}
template<int Dimensionality, class TPixel>
void ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::RunRegistration(void)
{
typename FixedImageType::ConstPointer fixedImage = this->m_theItkFilter->GetFixedImage();
typename MovingImageType::ConstPointer movingImage = this->m_theItkFilter->GetMovingImage();
// Scale estimator is not used in current implementation yet
typename ScalesEstimatorType::Pointer scalesEstimator = ScalesEstimatorType::New();
ImageMetricType* theMetric = dynamic_cast<ImageMetricType*>(this->m_theItkFilter->GetModifiableMetric());;
auto optimizer = dynamic_cast<itk::GradientDescentOptimizerv4 *>(this->m_theItkFilter->GetModifiableOptimizer());
//auto optimizer = dynamic_cast<itk::ObjectToObjectOptimizerBaseTemplate< InternalComputationValueType > *>(this->m_theItkFilter->GetModifiableOptimizer());
auto transform = this->m_theItkFilter->GetModifiableTransform();
if (theMetric)
{
scalesEstimator->SetMetric(theMetric);
}
else
{
itkExceptionMacro("Error casting to ImageMetricv4Type failed");
}
//std::cout << "estimated step scale: " << scalesEstimator->EstimateStepScale(1.0);
scalesEstimator->SetTransformForward(true);
scalesEstimator->SetSmallParameterVariation(1.0);
optimizer->SetScalesEstimator(ITK_NULLPTR);
//optimizer->SetScalesEstimator(scalesEstimator);
optimizer->SetDoEstimateLearningRateOnce(false); //true by default
optimizer->SetDoEstimateLearningRateAtEachIteration(false);
this->m_theItkFilter->SetOptimizer(optimizer);
// Below some hard coded options. Eventually, these should be part of new components.
this->m_theItkFilter->SetNumberOfLevels(3);
// Shrink the virtual domain by specified factors for each level. See documentation
// for the itkShrinkImageFilter for more detailed behavior.
typename TheItkFilterType::ShrinkFactorsArrayType shrinkFactorsPerLevel;
shrinkFactorsPerLevel.SetSize(3);
shrinkFactorsPerLevel[0] = 4;
shrinkFactorsPerLevel[1] = 2;
shrinkFactorsPerLevel[2] = 1;
this->m_theItkFilter->SetShrinkFactorsPerLevel(shrinkFactorsPerLevel);
// Smooth by specified gaussian sigmas for each level. These values are specified in
// physical units.
typename TheItkFilterType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
smoothingSigmasPerLevel.SetSize(3);
smoothingSigmasPerLevel[0] = 4;
smoothingSigmasPerLevel[1] = 2;
smoothingSigmasPerLevel[2] = 1;
this->m_theItkFilter->SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel);
// TODO for now we hard code the TransformAdaptors for stationary velocity fields.
typedef double RealType;
typedef itk::GaussianExponentialDiffeomorphicTransform<RealType, Dimensionality> ConstantVelocityFieldTransformType;
typedef typename ConstantVelocityFieldTransformType::ConstantVelocityFieldType ConstantVelocityFieldType;
typedef itk::GaussianExponentialDiffeomorphicTransformParametersAdaptor<ConstantVelocityFieldTransformType> VelocityFieldTransformAdaptorType;
typename TheItkFilterType::TransformParametersAdaptorsContainerType adaptors;
for (unsigned int level = 0; level < shrinkFactorsPerLevel.Size(); level++)
{
// We use the shrink image filter to calculate the fixed parameters of the virtual
// domain at each level. To speed up calculation and avoid unnecessary memory
// usage, we could calculate these fixed parameters directly.
typedef itk::ShrinkImageFilter<FixedImageType, FixedImageType> ShrinkFilterType;
typename ShrinkFilterType::Pointer shrinkFilter = ShrinkFilterType::New();
shrinkFilter->SetShrinkFactors(shrinkFactorsPerLevel[level]);
shrinkFilter->SetInput(fixedImage);
shrinkFilter->Update();
typename VelocityFieldTransformAdaptorType::Pointer fieldTransformAdaptor = VelocityFieldTransformAdaptorType::New();
fieldTransformAdaptor->SetRequiredSpacing(shrinkFilter->GetOutput()->GetSpacing());
fieldTransformAdaptor->SetRequiredSize(shrinkFilter->GetOutput()->GetBufferedRegion().GetSize());
fieldTransformAdaptor->SetRequiredDirection(shrinkFilter->GetOutput()->GetDirection());
fieldTransformAdaptor->SetRequiredOrigin(shrinkFilter->GetOutput()->GetOrigin());
adaptors.push_back(fieldTransformAdaptor.GetPointer());
}
/*
typename VelocityFieldTransformAdaptorType::Pointer fieldTransformAdaptor = VelocityFieldTransformAdaptorType::New();
fieldTransformAdaptor->SetRequiredSpacing(fixedImage->GetSpacing());
fieldTransformAdaptor->SetRequiredSize(fixedImage->GetBufferedRegion().GetSize());
fieldTransformAdaptor->SetRequiredDirection(fixedImage->GetDirection());
fieldTransformAdaptor->SetRequiredOrigin(fixedImage->GetOrigin());
adaptors.push_back(fieldTransformAdaptor.GetPointer());
*/
this->m_theItkFilter->SetTransformParametersAdaptorsPerLevel(adaptors);
typedef CommandIterationUpdate<TheItkFilterType> RegistrationCommandType;
typename RegistrationCommandType::Pointer registrationObserver = RegistrationCommandType::New();
this->m_theItkFilter->AddObserver(itk::IterationEvent(), registrationObserver);
// perform the actual registration
this->m_theItkFilter->Update();
}
template<int Dimensionality, class TPixel>
typename ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::TransformPointer
ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>
::GetItkTransform()
{
return this->m_theItkFilter->GetModifiableTransform();
}
template<int Dimensionality, class TPixel>
bool
ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>
::MeetsCriterion(const ComponentBase::CriterionType &criterion)
{
bool hasUndefinedCriteria(false);
bool meetsCriteria(false);
if (criterion.first == "ComponentProperty")
{
meetsCriteria = true;
for (auto const & criterionValue : criterion.second) // auto&& preferred?
{
if (criterionValue != "SomeProperty") // e.g. "GradientDescent", "SupportsSparseSamples
{
meetsCriteria = false;
}
}
}