Commit 57a3f660 authored by Floris Berendsen's avatar Floris Berendsen
Browse files

ENH: split off ResampleFilter from RegistrationMethodv4

parent fc3a4525
......@@ -45,8 +45,7 @@ namespace selx
itkMetricv4Interface<Dimensionality, TPixel>,
itkOptimizerv4Interface<double>
>,
Providing< itkImageInterface<Dimensionality, TPixel>,
itkTransformInterface<double, Dimensionality>,
Providing< itkTransformInterface<double, Dimensionality>,
RunRegistrationInterface
>
>
......@@ -66,7 +65,6 @@ namespace selx
typedef typename itkImageMovingInterface<Dimensionality, TPixel>::ItkImageType MovingImageType;
typedef typename itkTransformInterface<double, Dimensionality>::TransformPointer TransformPointer;
typedef typename itkOptimizerv4Interface<double>::InternalComputationValueType InternalComputationValueType;
typedef typename itkImageInterface<Dimensionality, TPixel>::ItkImageType ResultItkImageType;
// TODO for now we hard code the transform to be a stationary velocity field. See Set(*MetricInterface) for implementation
......@@ -76,8 +74,6 @@ namespace selx
typedef itk::ImageRegistrationMethodv4<FixedImageType, MovingImageType> TheItkFilterType;
typedef typename TheItkFilterType::ImageMetricType ImageMetricType;
typedef itk::RegistrationParameterScalesFromPhysicalShift<ImageMetricType> ScalesEstimatorType;
typedef itk::ResampleImageFilter<MovingImageType, ResultItkImageType> ResampleFilterType;
//typedef itk::TransformToDisplacementFieldFilter<DisplacementFieldImageType> DisplacementFieldFilterType;
//Accepting Interfaces:
virtual int Set(itkImageFixedInterface<Dimensionality, TPixel>*) override;
......@@ -86,10 +82,7 @@ namespace selx
virtual int Set(itkOptimizerv4Interface<InternalComputationValueType>*) override;
//Providing Interfaces:
virtual typename ResultItkImageType::Pointer GetItkImage() override;
virtual TransformPointer GetItkTransform() override;
//virtual typename DisplacementFieldImageType::Pointer GetDisplacementFieldItkImage() override;
virtual void RunRegistration() override;
//BaseClass methods
......@@ -98,7 +91,6 @@ namespace selx
static const char * GetDescription() { return "ItkImageRegistrationMethodv4 Component"; };
private:
typename TheItkFilterType::Pointer m_theItkFilter;
typename ResampleFilterType::Pointer m_resampler;
protected:
/* The following struct returns the string name of computation type */
/* default implementation */
......
......@@ -121,30 +121,12 @@ namespace selx
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::ItkImageRegistrationMethodv4Component()
{
m_theItkFilter = TheItkFilterType::New();
m_resampler = ResampleFilterType::New();
//m_DisplacementFieldFilter = DisplacementFieldFilterType::New();
//m_DisplacementFieldFilter->GetTransformInput()->Graft<ConstantVelocityFieldTransformType>(&(const_cast<ConstantVelocityFieldTransformType>( m_theItkFilter->GetOutput())));
//m_DisplacementFieldFilter->GetTransformInput()->Graft(m_theItkFilter->GetOutput());
typename ConstantVelocityFieldTransformType::Pointer transform = ConstantVelocityFieldTransformType::New();
m_theItkFilter->SetInitialTransform(transform);
m_theItkFilter->InPlaceOff();
typename itk::DataObjectDecorator<typename ConstantVelocityFieldTransformType::Superclass::Superclass::Superclass>::Pointer decoratedDummyTransform = itk::DataObjectDecorator<typename ConstantVelocityFieldTransformType::Superclass::Superclass::Superclass>::New();
typename ConstantVelocityFieldTransformType::Pointer dummyTranform = ConstantVelocityFieldTransformType::New();
decoratedDummyTransform->Set(dummyTranform);
//decoratedTransform->Set(m_theItkFilter->GetOutput()->Get());
//m_DisplacementFieldFilter->SetTransformInput(const_cast< itk::DataObjectDecorator<ConstantVelocityFieldTransformType::Superclass::Superclass::Superclass>*>(decoratedTransform));
//m_DisplacementFieldFilter->SetTransformInput(decoratedDummyTransform);
//m_theItkFilter->GetOutput()->Graft(m_DisplacementFieldFilter->GetTransformInput());
//m_DisplacementFieldFilter->GetTransformInput()->Graft(decoratedTransform);
//m_DisplacementFieldFilter->SetTransformInput(const_cast< itk::DataObjectDecorator<ConstantVelocityFieldTransformType::Superclass::Superclass::Superclass>*>(m_theItkFilter->GetOutput()));
//m_DisplacementFieldFilter->GetTransformInput()->Graft(const_cast< itk::DataObjectDecorator<ConstantVelocityFieldTransformType>*>(m_theItkFilter->GetOutput()));
//m_DisplacementFieldFilter->GetOutput()->SetLargestPossibleRegion()
//TODO: instantiating the filter in the constructor might be heavy for the use in component selector factory, since all components of the database are created during the selection process.
// we could choose to keep the component light weighted (for checking criteria such as names and connections) until the settings are passed to the filter, but this requires an additional initialization step.
}
......@@ -162,13 +144,6 @@ int ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>
// connect the itk pipeline
this->m_theItkFilter->SetFixedImage(fixedImage);
//this->m_resampler->SetSize(fixedImage->GetBufferedRegion().GetSize()); //should be virtual image...
this->m_resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize()); //should be virtual image...
this->m_resampler->SetOutputOrigin(fixedImage->GetOrigin());
this->m_resampler->SetOutputSpacing(fixedImage->GetSpacing());
this->m_resampler->SetOutputDirection(fixedImage->GetDirection());
this->m_resampler->SetDefaultPixelValue(0);
return 0;
}
......@@ -179,9 +154,6 @@ int ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>
auto movingImage = component->GetItkImageMoving();
// connect the itk pipeline
this->m_theItkFilter->SetMovingImage(movingImage);
this->m_resampler->SetInput(movingImage);
this->m_resampler->UpdateOutputInformation();
return 0;
}
......@@ -190,7 +162,7 @@ int ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::Set(itkMetri
{
this->m_theItkFilter->SetMetric(component->GetItkMetricv4());
return 1;
return 0;
}
template<int Dimensionality, class TPixel>
......@@ -198,7 +170,7 @@ int ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::Set(itkOptim
{
this->m_theItkFilter->SetOptimizer(component->GetItkOptimizerv4());
return 1;
return 0;
}
template<int Dimensionality, class TPixel>
......@@ -367,67 +339,8 @@ void ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::RunRegistra
// perform the actual registration
this->m_theItkFilter->Update();
// TODO get access to the inverse transform
//ConstantVelocityFieldTransformType::ConstPointer tranform = this->m_theItkFilter->GetTransformOutput()->Get();
//ConstantVelocityFieldTransformType::ConstPointer inversetranform = tranform->GetInverseTransform();
//ConstantVelocityFieldTransformType::ConstPointer tranform = this->m_theItkFilter->GetTransformOutput()->Get();
//ConstantVelocityFieldTransformType::Pointer inversetranform = ConstantVelocityFieldTransformType::New();
//typename ConstantVelocityFieldTransformType::Superclass::Pointer inversetranform = fieldTransform->GetInverseTransform();
//Temporary solution: create new displacement field transforms
typedef itk::DisplacementFieldTransform<RealType, Dimensionality> DisplacementFieldTransformType;
typename DisplacementFieldTransformType::Pointer forwardDisplacement = DisplacementFieldTransformType::New();
forwardDisplacement->SetDisplacementField(fieldTransform->GetDisplacementField());
//forwardDisplacement->SetSize(fixedImage->GetBufferedRegion().GetSize()); //should be virtual image...
//forwardDisplacement->SetOutputOrigin(fixedImage->GetOrigin());
//forwardDisplacement->SetOutputSpacing(fixedImage->GetSpacing());
//forwardDisplacement->SetOutputDirection(fixedImage->GetDirection());
typename DisplacementFieldTransformType::Pointer backwardDisplacement = DisplacementFieldTransformType::New();
backwardDisplacement->SetDisplacementField(fieldTransform->GetDisplacementField());
auto inversetranform = fieldTransform->GetInverseTransform();
//auto inversetranform = tranform->GetInverseTransform();
//fieldTransform->GetInverse(inversetranform);
//inversetranform->IntegrateVelocityField();
//inversetranform->IntegrateVelocityField();
//this->m_resampler->SetTransform(this->m_theItkFilter->GetTransform());
//this->m_resampler->SetTransform(this->m_theItkFilter->GetTransformOutput()->Get()->GetInverseTransform());
//this->m_resampler->SetTransform(this->m_theItkFilter->GetTransformOutput()->Get());
//BIG TODO: the resampler is insensitive for any of these options:
//this->m_resampler->SetTransform(inversetranform);
//this->m_resampler->SetTransform(fieldTransform);
this->m_resampler->SetTransform(forwardDisplacement);
//this->m_resampler->SetTransform(this->m_theItkFilter->GetOutput());
// TODO: is this needed?
//this->m_resampler->Update();
//this->m_DisplacementFieldFilter->SetTransformInput(this->m_theItkFilter->GetTransformOutput());
//this->m_DisplacementFieldFilter->SetTransformInput(this->m_theItkFilter->GetTransformOutput()->Get());
//BIG TODO: the DisplacementFieldFilter is insensitive for any of these options:
//this->m_DisplacementFieldFilter->SetTransform(inversetranform);
//this->m_DisplacementFieldFilter->SetTransform(fieldTransform);
}
template<int Dimensionality, class TPixel>
typename ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::ResultItkImageType::Pointer
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>
::GetItkImage()
{
return this->m_resampler->GetOutput();
}
template<int Dimensionality, class TPixel>
typename ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>::TransformPointer
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel>
......
......@@ -37,6 +37,7 @@
#include "selxItkGradientDescentOptimizerv4.h"
#include "selxItkAffineTransform.h"
#include "selxItkTransformDisplacementFilter.h"
#include "selxItkResampleFilter.h"
#include "selxItkImageSourceFixed.h"
#include "selxItkImageSourceMoving.h"
......@@ -89,7 +90,9 @@ public:
ItkGradientDescentOptimizerv4Component<double>,
ItkAffineTransformComponent<double,3>,
ItkTransformDisplacementFilterComponent<2, float, double >,
ItkTransformDisplacementFilterComponent<3, double, double >> RegisterComponents;
ItkTransformDisplacementFilterComponent<3, double, double >,
ItkResampleFilterComponent<2, float, double >,
ItkResampleFilterComponent<3, double, double >> RegisterComponents;
typedef SuperElastixFilter<RegisterComponents> SuperElastixFilterType;
......@@ -406,6 +409,12 @@ TEST_F(RegistrationItkv4Test, DisplacementField)
component7Parameters["NumberOfIterations"] = { "1" };
blueprint->AddComponent("Optimizer", component7Parameters);
ParameterMapType component8Parameters;
component8Parameters["NameOfClass"] = { "ItkResampleFilterComponent" };
component8Parameters["Dimensionality"] = { "3" }; // should be derived from the outputs
blueprint->AddComponent("ResampleFilter", component8Parameters);
ParameterMapType connection1Parameters;
connection1Parameters["NameOfInterface"] = { "itkImageFixedInterface" };
blueprint->AddConnection("FixedImageSource", "RegistrationMethod", connection1Parameters);
......@@ -416,7 +425,7 @@ TEST_F(RegistrationItkv4Test, DisplacementField)
ParameterMapType connection3Parameters;
connection3Parameters["NameOfInterface"] = { "itkImageInterface" };
blueprint->AddConnection("RegistrationMethod", "ResultImageSink", connection3Parameters);
blueprint->AddConnection("ResampleFilter", "ResultImageSink", connection3Parameters);
ParameterMapType connection4Parameters;
connection4Parameters["NameOfInterface"] = { "DisplacementFieldItkImageSourceInterface" };
......@@ -426,9 +435,12 @@ TEST_F(RegistrationItkv4Test, DisplacementField)
connection5Parameters["NameOfInterface"] = { "itkMetricv4Interface" };
blueprint->AddConnection("Metric", "RegistrationMethod", connection5Parameters);
blueprint->AddConnection("Optimizer", "RegistrationMethod", { {} });
blueprint->AddConnection("RegistrationMethod", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("FixedImageSource", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("Optimizer", "RegistrationMethod", { {} });
blueprint->AddConnection("RegistrationMethod", "ResampleFilter", { {} });
blueprint->AddConnection("FixedImageSource", "ResampleFilter", { {} });
blueprint->AddConnection("MovingImageSource", "ResampleFilter", { {} });
// Instantiate SuperElastix
SuperElastixFilterType::Pointer superElastixFilter;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment