Commit 8a94bd42 authored by Floris Berendsen's avatar Floris Berendsen
Browse files

ENH: added itkSyNImageRegistrationMethodComponent with tests

parent d093ce02
......@@ -92,6 +92,7 @@ selxmodule_enable( ModuleExamples )
selxmodule_enable( ModuleSinksAndSources )
selxmodule_enable( ModuleItkSmoothingRecursiveGaussianImageFilter )
selxmodule_enable( ModuleItkImageRegistrationMethodv4 )
selxmodule_enable( ModuleItkSyNImageRegistrationMethod )
selxmodule_enable( ModuleElastix )
selxmodule_enable( ModuleController )
#selxmodule_enable( ModuleCommandLine )
......
......@@ -40,7 +40,7 @@ set( ${MODULE}_TESTS
# Module source files
set( ${MODULE}_SOURCE_FILES
${${MODULE}_SOURCE_DIR}/src/selxItkImageRegistrationMethodv4.cxx
${${MODULE}_SOURCE_DIR}/src/selxItkImageRegistrationMethodv4Component.cxx
)
# Compile library
......
......@@ -39,6 +39,7 @@ set( ${MODULE}_TESTS
# Module source files
set( ${MODULE}_SOURCE_FILES
${${MODULE}_SOURCE_DIR}/src/selxItkSyNImageRegistrationMethodComponent.cxx
)
# Compile library
......
......@@ -42,7 +42,7 @@ namespace selx
public SuperElastixComponent<
Accepting< itkImageFixedInterface<Dimensionality, TPixel>,
itkImageMovingInterface<Dimensionality, TPixel>,
itkMetricv4Interface<Dimensionality, TPixel>,
itkMetricv4Interface<Dimensionality, TPixel>
>,
Providing< itkTransformInterface<double, Dimensionality>,
RunRegistrationInterface
......
......@@ -18,102 +18,17 @@
*=========================================================================*/
#include "selxItkSyNImageRegistrationMethodComponent.h"
#include "selxItkImageRegistrationMethodv4Component.h"
#include "itkDisplacementFieldTransformParametersAdaptor.h"
//TODO: get rid of these
#include "itkGradientDescentOptimizerv4.h"
namespace selx
{
template<typename TFilter>
class CommandIterationUpdate : public itk::Command
{
public:
typedef CommandIterationUpdate Self;
typedef itk::Command Superclass;
typedef itk::SmartPointer<Self> Pointer;
itkNewMacro(Self);
typedef itk::GradientDescentOptimizerv4 OptimizerType;
typedef const OptimizerType * OptimizerPointer;
protected:
CommandIterationUpdate() {};
public:
virtual void Execute(itk::Object *caller, const itk::EventObject & event) ITK_OVERRIDE
{
Execute((const itk::Object *) caller, event);
}
virtual void Execute(const itk::Object * object, const itk::EventObject & event) ITK_OVERRIDE
{
const TFilter * filter = static_cast< const TFilter * >(object);
if (typeid(event) == typeid(itk::MultiResolutionIterationEvent))
{
unsigned int currentLevel = filter->GetCurrentLevel();
typename TFilter::ShrinkFactorsPerDimensionContainerType shrinkFactors = filter->GetShrinkFactorsPerDimension(currentLevel);
typename TFilter::SmoothingSigmasArrayType smoothingSigmas = filter->GetSmoothingSigmasPerLevel();
typename TFilter::TransformParametersAdaptorsContainerType adaptors = filter->GetTransformParametersAdaptorsPerLevel();
const itk::ObjectToObjectOptimizerBase * optimizerBase = filter->GetOptimizer();
typedef itk::GradientDescentOptimizerv4 GradientDescentOptimizerv4Type;
typename GradientDescentOptimizerv4Type::ConstPointer optimizer = dynamic_cast<const GradientDescentOptimizerv4Type *>(optimizerBase);
if (!optimizer)
{
itkGenericExceptionMacro("Error dynamic_cast failed");
}
typename GradientDescentOptimizerv4Type::DerivativeType gradient = optimizer->GetGradient();
/* orig
std::cout << " Current level = " << currentLevel << std::endl;
std::cout << " shrink factor = " << shrinkFactors[currentLevel] << std::endl;
std::cout << " smoothing sigma = " << smoothingSigmas[currentLevel] << std::endl;
std::cout << " required fixed parameters = " << adaptors[currentLevel]->GetRequiredFixedParameters() << std::endl;
*/
//debug:
std::cout << " CL Current level: " << currentLevel << std::endl;
std::cout << " SF Shrink factor: " << shrinkFactors << std::endl;
std::cout << " SS Smoothing sigma: " << smoothingSigmas[currentLevel] << std::endl;
std::cout << " RFP Required fixed params: " << adaptors[currentLevel]->GetRequiredFixedParameters() << std::endl;
std::cout << " LR Final learning rate: " << optimizer->GetLearningRate() << std::endl;
std::cout << " FM Final metric value: " << optimizer->GetCurrentMetricValue() << std::endl;
std::cout << " SC Optimizer scales: " << optimizer->GetScales() << std::endl;
std::cout << " FG Final metric gradient (sample of values): ";
if (gradient.GetSize() < 10)
{
std::cout << gradient;
}
else
{
for (itk::SizeValueType i = 0; i < gradient.GetSize(); i += (gradient.GetSize() / 16))
{
std::cout << gradient[i] << " ";
}
}
std::cout << std::endl;
}
else if (!(itk::IterationEvent().CheckEvent(&event)))
{
return;
}
else
{
OptimizerPointer optimizer = static_cast<OptimizerPointer>(object);
std::cout << optimizer->GetCurrentIteration() << ": " ;
std::cout << optimizer->GetCurrentMetricValue() << std::endl;
//std::cout << optimizer->GetInfinityNormOfProjectedGradient() << std::endl;
}
}
};
template<int Dimensionality, class TPixel>
ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::ItkSyNImageRegistrationMethodComponent()
{
......@@ -175,6 +90,30 @@ void ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::RunRegistr
auto optimizer = dynamic_cast<itk::GradientDescentOptimizerv4 *>(this->m_theItkFilter->GetModifiableOptimizer());
//auto optimizer = dynamic_cast<itk::ObjectToObjectOptimizerBaseTemplate< InternalComputationValueType > *>(this->m_theItkFilter->GetModifiableOptimizer());
typedef itk::Vector<TransformInternalComputationValueType, Dimensionality> VectorType;
VectorType zeroVector(0.0);
typedef itk::Image<VectorType, Dimensionality> DisplacementFieldType;
typename DisplacementFieldType::Pointer displacementField = DisplacementFieldType::New();
displacementField->CopyInformation(fixedImage);
displacementField->SetRegions(fixedImage->GetBufferedRegion());
displacementField->Allocate();
displacementField->FillBuffer(zeroVector);
typename DisplacementFieldType::Pointer inverseDisplacementField = DisplacementFieldType::New();
inverseDisplacementField->CopyInformation(fixedImage);
inverseDisplacementField->SetRegions(fixedImage->GetBufferedRegion());
inverseDisplacementField->Allocate();
inverseDisplacementField->FillBuffer(zeroVector);
typedef typename TheItkFilterType::OutputTransformType OutputTransformType;
typename OutputTransformType::Pointer outputTransform = OutputTransformType::New();
outputTransform->SetDisplacementField(displacementField);
outputTransform->SetInverseDisplacementField(inverseDisplacementField);
this->m_theItkFilter->SetInitialTransform(outputTransform);
this->m_theItkFilter->InPlaceOn();
auto transform = this->m_theItkFilter->GetModifiableTransform();
if (theMetric)
......@@ -197,8 +136,10 @@ void ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::RunRegistr
optimizer->SetDoEstimateLearningRateAtEachIteration(false);
this->m_theItkFilter->SetOptimizer(optimizer);
//this->m_theItkFilter->SetOptimizer(optimizer);
// Below some hard coded options. Eventually, these should be part of new components.
this->m_theItkFilter->SetNumberOfLevels(3);
......@@ -221,11 +162,10 @@ void ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::RunRegistr
this->m_theItkFilter->SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel);
// TODO for now we hard code the TransformAdaptors for stationary velocity fields.
typedef double RealType;
typedef itk::GaussianExponentialDiffeomorphicTransform<RealType, Dimensionality> ConstantVelocityFieldTransformType;
typedef typename ConstantVelocityFieldTransformType::ConstantVelocityFieldType ConstantVelocityFieldType;
typedef itk::GaussianExponentialDiffeomorphicTransformParametersAdaptor<ConstantVelocityFieldTransformType> VelocityFieldTransformAdaptorType;
// TODO for now we hard code the TransformAdaptors for DisplacementFieldTransform.
typedef itk::DisplacementFieldTransformParametersAdaptor<OutputTransformType> DisplacementFieldTransformAdaptorType;
typename TheItkFilterType::TransformParametersAdaptorsContainerType adaptors;
......@@ -241,7 +181,7 @@ void ItkSyNImageRegistrationMethodComponent< Dimensionality, TPixel>::RunRegistr
shrinkFilter->SetInput(fixedImage);
shrinkFilter->Update();
typename VelocityFieldTransformAdaptorType::Pointer fieldTransformAdaptor = VelocityFieldTransformAdaptorType::New();
typename DisplacementFieldTransformAdaptorType::Pointer fieldTransformAdaptor = DisplacementFieldTransformAdaptorType::New();
fieldTransformAdaptor->SetRequiredSpacing(shrinkFilter->GetOutput()->GetSpacing());
fieldTransformAdaptor->SetRequiredSize(shrinkFilter->GetOutput()->GetBufferedRegion().GetSize());
fieldTransformAdaptor->SetRequiredDirection(shrinkFilter->GetOutput()->GetDirection());
......
/*=========================================================================
*
* Copyright Leiden University Medical Center, Erasmus University Medical
* Center and contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/
#include "selxItkSyNImageRegistrationMethodComponent.h"
\ No newline at end of file
......@@ -108,6 +108,9 @@ public:
typedef itk::ImageFileReader<Image3DType> ImageReader3DType;
typedef itk::ImageFileWriter<Image3DType> ImageWriter3DType;
typedef itk::Image<itk::Vector<float, 2>, 2> DisplacementImage2DType;
typedef itk::ImageFileWriter<DisplacementImage2DType> DisplacementImageWriter2DType;
typedef itk::Image<itk::Vector<double,3>, 3> DisplacementImage3DType;
typedef itk::ImageFileWriter<DisplacementImage3DType> DisplacementImageWriter3DType;
......@@ -131,7 +134,7 @@ TEST_F(SyNRegistrationItkv4Test, FullyConfigured3d)
blueprint = Blueprint::New();
ParameterMapType component0Parameters;
component0Parameters["NameOfClass"] = { "ItkImageRegistrationMethodv4Component" };
component0Parameters["NameOfClass"] = { "ItkSyNImageRegistrationMethodComponent" };
component0Parameters["Dimensionality"] = { "3" }; // should be derived from the inputs
blueprint->AddComponent("RegistrationMethod", component0Parameters);
......@@ -160,6 +163,10 @@ TEST_F(SyNRegistrationItkv4Test, FullyConfigured3d)
component5Parameters["Dimensionality"] = { "3" }; // should be derived from the inputs
blueprint->AddComponent("Metric", component5Parameters);
ParameterMapType component6Parameters;
component6Parameters["NameOfClass"] = { "ItkTransformDisplacementFilterComponent" };
component6Parameters["Dimensionality"] = { "3" }; // should be derived from the inputs
blueprint->AddComponent("TransformDisplacementFilter", component6Parameters);
blueprint->AddComponent("ResampleFilter", { { "NameOfClass", { "ItkResampleFilterComponent" } },
{ "Dimensionality", { "3" } } });
......@@ -187,7 +194,6 @@ TEST_F(SyNRegistrationItkv4Test, FullyConfigured3d)
connection5Parameters["NameOfInterface"] = { "itkMetricv4Interface" };
blueprint->AddConnection("Metric", "RegistrationMethod", connection5Parameters);
blueprint->AddConnection("FixedImageSource", "Transform", { {} });
blueprint->AddConnection("RegistrationMethod", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("FixedImageSource", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("RegistrationMethod", "ResampleFilter", { {} });
......@@ -235,11 +241,123 @@ TEST_F(SyNRegistrationItkv4Test, FullyConfigured3d)
// Update call on the writers triggers SuperElastix to configure and execute
EXPECT_NO_THROW(resultImageWriter->Update());
EXPECT_NO_THROW(resultDisplacementWriter->Update());
blueprint->WriteBlueprint(dataManager->GetOutputFile("SyNRegistrationItkv4Test_DisplacementField_network.dot"));
}
TEST_F(SyNRegistrationItkv4Test, WBIRDemo)
{
/** make example blueprint configuration */
blueprint = Blueprint::New();
ParameterMapType component0Parameters;
component0Parameters["NameOfClass"] = { "ItkSyNImageRegistrationMethodComponent" };
component0Parameters["Dimensionality"] = { "2" }; // should be derived from the inputs
blueprint->AddComponent("RegistrationMethod", component0Parameters);
ParameterMapType component1Parameters;
component1Parameters["NameOfClass"] = { "ItkImageSourceFixedComponent" };
component1Parameters["Dimensionality"] = { "2" }; // should be derived from the inputs
blueprint->AddComponent("FixedImageSource", component1Parameters);
ParameterMapType component2Parameters;
component2Parameters["NameOfClass"] = { "ItkImageSourceMovingComponent" };
component2Parameters["Dimensionality"] = { "2" }; // should be derived from the inputs
blueprint->AddComponent("MovingImageSource", component2Parameters);
ParameterMapType component3Parameters;
component3Parameters["NameOfClass"] = { "ItkImageSinkComponent" };
component3Parameters["Dimensionality"] = { "2" }; // should be derived from the outputs
blueprint->AddComponent("ResultImageSink", component3Parameters);
ParameterMapType component4Parameters;
component4Parameters["NameOfClass"] = { "DisplacementFieldItkImageFilterSinkComponent" };
component4Parameters["Dimensionality"] = { "2" }; // should be derived from the outputs
blueprint->AddComponent("ResultDisplacementFieldSink", component4Parameters);
ParameterMapType component5Parameters;
component5Parameters["NameOfClass"] = { "ItkANTSNeighborhoodCorrelationImageToImageMetricv4Component" };
component5Parameters["Dimensionality"] = { "2" }; // should be derived from the inputs
blueprint->AddComponent("Metric", component5Parameters);
ParameterMapType component6Parameters;
component6Parameters["NameOfClass"] = { "ItkTransformDisplacementFilterComponent" };
component6Parameters["Dimensionality"] = { "2" }; // should be derived from the inputs
blueprint->AddComponent("TransformDisplacementFilter", component6Parameters);
blueprint->AddComponent("ResampleFilter", { { "NameOfClass", { "ItkResampleFilterComponent" } },
{ "Dimensionality", { "2" } } });
blueprint->AddComponent("Controller", { { "NameOfClass", { "RegistrationControllerComponent" } } });
ParameterMapType connection1Parameters;
connection1Parameters["NameOfInterface"] = { "itkImageFixedInterface" };
blueprint->AddConnection("FixedImageSource", "RegistrationMethod", connection1Parameters);
ParameterMapType connection2Parameters;
connection2Parameters["NameOfInterface"] = { "itkImageMovingInterface" };
blueprint->AddConnection("MovingImageSource", "RegistrationMethod", connection2Parameters);
ParameterMapType connection3Parameters;
connection3Parameters["NameOfInterface"] = { "itkImageInterface" };
blueprint->AddConnection("ResampleFilter", "ResultImageSink", connection3Parameters);
ParameterMapType connection4Parameters;
connection4Parameters["NameOfInterface"] = { "DisplacementFieldItkImageSourceInterface" };
blueprint->AddConnection("TransformDisplacementFilter", "ResultDisplacementFieldSink", connection4Parameters);
ParameterMapType connection5Parameters;
connection5Parameters["NameOfInterface"] = { "itkMetricv4Interface" };
blueprint->AddConnection("Metric", "RegistrationMethod", connection5Parameters);
blueprint->AddConnection("RegistrationMethod", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("FixedImageSource", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("RegistrationMethod", "ResampleFilter", { {} });
blueprint->AddConnection("FixedImageSource", "ResampleFilter", { {} });
blueprint->AddConnection("MovingImageSource", "ResampleFilter", { {} });
blueprint->AddConnection("RegistrationMethod", "Controller", { {} }); //RunRegistrationInterface
blueprint->AddConnection("ResampleFilter", "Controller", { {} }); //ReconnectTransformInterface
blueprint->AddConnection("TransformDisplacementFilter", "Controller", { {} }); //ReconnectTransformInterface
blueprint->AddConnection("ResultImageSink", "Controller", { {} }); //AfterRegistrationInterface
blueprint->AddConnection("ResultDisplacementFieldSink", "Controller", { {} }); //AfterRegistrationInterface
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
blueprint->WriteBlueprint(dataManager->GetOutputFile("SyN_ANTSCC.dot"));
// Instantiate SuperElastix
SuperElastixFilterType::Pointer superElastixFilter;
EXPECT_NO_THROW(superElastixFilter = SuperElastixFilterType::New());
// Set up the readers and writers
ImageReader2DType::Pointer fixedImageReader = ImageReader2DType::New();
fixedImageReader->SetFileName(dataManager->GetInputFile("coneA2d64.mhd"));
ImageReader2DType::Pointer movingImageReader = ImageReader2DType::New();
movingImageReader->SetFileName(dataManager->GetInputFile("coneB2d64.mhd"));
ImageWriter2DType::Pointer resultImageWriter = ImageWriter2DType::New();
resultImageWriter->SetFileName(dataManager->GetOutputFile("SyN_ANTSCC_Image.mhd"));
DisplacementImageWriter2DType::Pointer resultDisplacementWriter = DisplacementImageWriter2DType::New();
resultDisplacementWriter->SetFileName(dataManager->GetOutputFile("SyN_ANTSCC_Displacement.mhd"));
// Connect SuperElastix in an itk pipeline
superElastixFilter->SetInput("FixedImageSource", fixedImageReader->GetOutput());
superElastixFilter->SetInput("MovingImageSource", movingImageReader->GetOutput());
resultImageWriter->SetInput(superElastixFilter->GetOutput<Image2DType>("ResultImageSink"));
resultDisplacementWriter->SetInput(superElastixFilter->GetOutput<DisplacementImage2DType>("ResultDisplacementFieldSink"));
EXPECT_NO_THROW(superElastixFilter->SetBlueprint(blueprint));
//Optional Update call
//superElastixFilter->Update();
// Update call on the writers triggers SuperElastix to configure and execute
EXPECT_NO_THROW(resultImageWriter->Update());
EXPECT_NO_THROW(resultDisplacementWriter->Update());
}
} // namespace selx
......@@ -143,6 +143,13 @@ namespace selx
std::cout << " Blueprint Node: " << name << std::endl << " HasProvidingInterface " << connectionProperties[keys::NameOfInterface][0] << std::endl;
std::cout << " Blueprint Node: " << outgoingName << std::endl << " HasAcceptingInterface " << connectionProperties[keys::NameOfInterface][0] << std::endl;
}
if ((this->m_ComponentSelectorContainer[outgoingName]->HasMultipleComponents() == false) && (this->m_ComponentSelectorContainer[outgoingName]->GetComponent().IsNull()))
{
std::stringstream msg;
msg << "Too many criteria for Component " << outgoingName << std::endl;
throw std::runtime_error(msg.str());
}
}
if ((this->m_ComponentSelectorContainer[name]->HasMultipleComponents() == false) && (this->m_ComponentSelectorContainer[name]->GetComponent().IsNull()))
{
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment