diff --git a/Modules/Components/itkImageRegistrationMethodv4/include/selxItkImageRegistrationMethodv4Component.h b/Modules/Components/itkImageRegistrationMethodv4/include/selxItkImageRegistrationMethodv4Component.h index 7a284d6bfcee7fb7e22c058a9bab11b0961172fd..a781da1d9912ac56a831ec9f9a4ce661514755e0 100644 --- a/Modules/Components/itkImageRegistrationMethodv4/include/selxItkImageRegistrationMethodv4Component.h +++ b/Modules/Components/itkImageRegistrationMethodv4/include/selxItkImageRegistrationMethodv4Component.h @@ -9,6 +9,13 @@ #include <itkTransformToDisplacementFieldFilter.h> #include <string.h> #include "elxMacro.h" + + +#include "itkComposeDisplacementFieldsImageFilter.h" +#include "itkGaussianExponentialDiffeomorphicTransform.h" +#include "itkGaussianExponentialDiffeomorphicTransformParametersAdaptor.h" + + namespace selx { template <int Dimensionality, class TPixel> @@ -33,7 +40,7 @@ namespace selx virtual ~ItkImageRegistrationMethodv4Component(); typedef TPixel PixelType; - + // the in and output image type of the component are chosen to be the same typedef itk::Image<PixelType, Dimensionality> ConnectionImageType; @@ -48,7 +55,11 @@ namespace selx typedef itk::ImageSource<DisplacementFieldImageType>DisplacementFieldItkImageSourceType; typedef typename DisplacementFieldItkImageSourceType::Pointer DisplacementFieldItkImageSourcePointer; - typedef itk::ImageRegistrationMethodv4<FixedImageType, MovingImageType> TheItkFilterType; + // TODO for now we hard code the transform to be a stationary velocity field. See Set(*MetricInterface) for implementation + typedef double RealType; + typedef itk::GaussianExponentialDiffeomorphicTransform<RealType, Dimensionality> ConstantVelocityFieldTransformType; + + typedef itk::ImageRegistrationMethodv4<FixedImageType, MovingImageType, ConstantVelocityFieldTransformType> TheItkFilterType; typedef itk::ResampleImageFilter<MovingImageType, ConnectionImageType> ResampleFilterType; typedef itk::TransformToDisplacementFieldFilter<DisplacementFieldImageType> DisplacementFieldFilterType; diff --git a/Modules/Components/itkImageRegistrationMethodv4/include/selxItkImageRegistrationMethodv4Component.hxx b/Modules/Components/itkImageRegistrationMethodv4/include/selxItkImageRegistrationMethodv4Component.hxx index 8b0380be4b9f78bf95fe45510a735a2906e12a5e..a9406570006c3def6731b6f8d44c5522b0a8b242 100644 --- a/Modules/Components/itkImageRegistrationMethodv4/include/selxItkImageRegistrationMethodv4Component.hxx +++ b/Modules/Components/itkImageRegistrationMethodv4/include/selxItkImageRegistrationMethodv4Component.hxx @@ -39,22 +39,119 @@ int ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::Set(itkMetri { this->m_theItkFilter->SetMetric(component->GetItkMetricv4()); - //TODO: this is a bug in itkv4 - typedef itk::GradientDescentOptimizerv4 OptimizerType; - OptimizerType::Pointer optimizer = OptimizerType::New(); - this->m_theItkFilter->SetOptimizer(optimizer); - return 1; } template<int Dimensionality, class TPixel> void ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::RunRegistration(void) { - this->m_theItkFilter->Update(); + FixedImageType::ConstPointer fixedImage = this->m_theItkFilter->GetFixedImage(); MovingImageType::ConstPointer movingImage = this->m_theItkFilter->GetMovingImage(); - this->m_resampler->SetTransform(this->m_theItkFilter->GetTransform()); + // Below some hard coded options. Eventually, these should be part of new components. + + //TODO: Setting the optimizer explicitly is a work around for a bug in itkv4. + //TODO: report bug to itk. + typedef itk::GradientDescentOptimizerv4 OptimizerType; + OptimizerType::Pointer optimizer = OptimizerType::New(); + this->m_theItkFilter->SetOptimizer(optimizer); + + // TODO: for now we hard code the transform to be a stationary velocity field. See template declaration. + + typedef itk::CompositeTransform<RealType, Dimensionality> CompositeTransformType; + typename CompositeTransformType::Pointer compositeTransform = CompositeTransformType::New(); + + typedef itk::IdentityTransform < RealType, Dimensionality> IdentityTransformType; + typename IdentityTransformType::Pointer idTransform = IdentityTransformType::New(); + compositeTransform->AddTransform(idTransform); + + + typedef itk::Vector<RealType, Dimensionality> VectorType; + VectorType zeroVector(0.0); + typedef itk::Image<VectorType, Dimensionality> DisplacementFieldType; + typedef itk::Image<VectorType, Dimensionality> ConstantVelocityFieldType; + typename ConstantVelocityFieldType::Pointer displacementField = ConstantVelocityFieldType::New(); + displacementField->CopyInformation(fixedImage); + displacementField->SetRegions(fixedImage->GetBufferedRegion()); + displacementField->Allocate(); + displacementField->FillBuffer(zeroVector); + + typename ConstantVelocityFieldTransformType::Pointer fieldTransform = ConstantVelocityFieldTransformType::New(); + fieldTransform->SetGaussianSmoothingVarianceForTheUpdateField(0.75); + fieldTransform->SetGaussianSmoothingVarianceForTheConstantVelocityField(1.5); + fieldTransform->SetConstantVelocityField(displacementField); + fieldTransform->SetCalculateNumberOfIntegrationStepsAutomatically(true); + fieldTransform->IntegrateVelocityField(); + + //this->m_theItkFilter->SetMovingInitialTransform(compositeTransform); + //this->m_theItkFilter->SetMovingInitialTransform(idTransform); + + this->m_theItkFilter->SetNumberOfLevels(1); + + // Shrink the virtual domain by specified factors for each level. See documentation + // for the itkShrinkImageFilter for more detailed behavior. + typename TheItkFilterType::ShrinkFactorsArrayType shrinkFactorsPerLevel; + shrinkFactorsPerLevel.SetSize(1); + //shrinkFactorsPerLevel[0] = 3; + //shrinkFactorsPerLevel[1] = 2; + shrinkFactorsPerLevel[0] = 1; + this->m_theItkFilter->SetShrinkFactorsPerLevel(shrinkFactorsPerLevel); + + // Smooth by specified gaussian sigmas for each level. These values are specified in + // physical units. + typename TheItkFilterType::SmoothingSigmasArrayType smoothingSigmasPerLevel; + smoothingSigmasPerLevel.SetSize(1); + //smoothingSigmasPerLevel[0] = 2; + //smoothingSigmasPerLevel[1] = 1; + smoothingSigmasPerLevel[0] = 1; + this->m_theItkFilter->SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel); + + typedef itk::GaussianExponentialDiffeomorphicTransformParametersAdaptor<ConstantVelocityFieldTransformType> VelocityFieldTransformAdaptorType; + + typename TheItkFilterType::TransformParametersAdaptorsContainerType adaptors; + /* + for (unsigned int level = 0; level < shrinkFactorsPerLevel.Size(); level++) + { + // We use the shrink image filter to calculate the fixed parameters of the virtual + // domain at each level. To speed up calculation and avoid unnecessary memory + // usage, we could calculate these fixed parameters directly. + + typedef itk::ShrinkImageFilter<ConstantVelocityFieldType, ConstantVelocityFieldType> ShrinkFilterType; + typename ShrinkFilterType::Pointer shrinkFilter = ShrinkFilterType::New(); + shrinkFilter->SetShrinkFactors(shrinkFactorsPerLevel[level]); + shrinkFilter->SetInput(fieldTransform->GetConstantVelocityField()); + + typename VelocityFieldTransformAdaptorType::Pointer fieldTransformAdaptor = VelocityFieldTransformAdaptorType::New(); + fieldTransformAdaptor->SetRequiredSpacing(shrinkFilter->GetOutput()->GetSpacing()); + fieldTransformAdaptor->SetRequiredSize(shrinkFilter->GetOutput()->GetBufferedRegion().GetSize()); + fieldTransformAdaptor->SetRequiredDirection(shrinkFilter->GetOutput()->GetDirection()); + fieldTransformAdaptor->SetRequiredOrigin(shrinkFilter->GetOutput()->GetOrigin()); + + adaptors.push_back(fieldTransformAdaptor.GetPointer()); + } + */ + + typename VelocityFieldTransformAdaptorType::Pointer fieldTransformAdaptor = VelocityFieldTransformAdaptorType::New(); + fieldTransformAdaptor->SetRequiredSpacing(fixedImage->GetSpacing()); + fieldTransformAdaptor->SetRequiredSize(fixedImage->GetBufferedRegion().GetSize()); + fieldTransformAdaptor->SetRequiredDirection(fixedImage->GetDirection()); + fieldTransformAdaptor->SetRequiredOrigin(fixedImage->GetOrigin()); + + adaptors.push_back(fieldTransformAdaptor.GetPointer()); + + //this->m_theItkFilter->SetTransformParametersAdaptorsPerLevel(adaptors); + + this->m_theItkFilter->SetInitialTransform(fieldTransform); + this->m_theItkFilter->InPlaceOn(); + + + // + this->m_theItkFilter->Update(); + + //this->m_resampler->SetTransform(this->m_theItkFilter->GetTransform()); + this->m_resampler->SetTransform(this->m_theItkFilter->GetTransformOutput()->Get()); + //this->m_resampler->SetTransform(this->m_theItkFilter->GetOutput()); this->m_resampler->SetInput(movingImage); this->m_resampler->SetSize(fixedImage->GetBufferedRegion().GetSize()); //should be virtual image... this->m_resampler->SetOutputOrigin(fixedImage->GetOrigin()); @@ -62,7 +159,9 @@ void ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::RunRegistra this->m_resampler->SetOutputDirection(fixedImage->GetDirection()); this->m_resampler->SetDefaultPixelValue(0); - this->m_DisplacementFieldFilter->SetTransformInput(this->m_theItkFilter->GetTransformOutput()); + //this->m_DisplacementFieldFilter->SetTransformInput(this->m_theItkFilter->GetTransformOutput()); + //this->m_DisplacementFieldFilter->SetTransformInput(this->m_theItkFilter->GetTransformOutput()->Get()); + this->m_DisplacementFieldFilter->SetTransform(this->m_theItkFilter->GetTransformOutput()->Get()); this->m_DisplacementFieldFilter->SetSize(fixedImage->GetBufferedRegion().GetSize()); //should be virtual image... this->m_DisplacementFieldFilter->SetOutputOrigin(fixedImage->GetOrigin()); this->m_DisplacementFieldFilter->SetOutputSpacing(fixedImage->GetSpacing()); diff --git a/Testing/Unit/elxRegistrationItkv4Test.cxx b/Testing/Unit/elxRegistrationItkv4Test.cxx index a6ba63988f1485b1f0dfde642d45e16a5ffbbf3c..930c71250749ea9fb45f15dbc873e10f871abb3e 100644 --- a/Testing/Unit/elxRegistrationItkv4Test.cxx +++ b/Testing/Unit/elxRegistrationItkv4Test.cxx @@ -315,7 +315,8 @@ TEST_F(RegistrationItkv4Test, DisplacementField2D) bool allUniqueComponents; EXPECT_NO_THROW(allUniqueComponents = overlord->Configure()); EXPECT_TRUE(allUniqueComponents); - EXPECT_NO_THROW(overlord->Execute()); - + //EXPECT_NO_THROW(overlord->Execute()); + overlord->Execute(); } } // namespace elx +