Commit d7229b42 authored by Floris Berendsen's avatar Floris Berendsen
Browse files

ENH: several: Optimizer SetLearningRate, added interface ImageDomainFixed,

Updated WBIRDemo to new RegistrationMethodv4 Component
parent b44ab6b2
......@@ -33,7 +33,9 @@ namespace selx
class ItkImageSourceFixedComponent :
public Implements<
Accepting<>,
Providing< SourceInterface, itkImageFixedInterface<Dimensionality, TPixel > >
Providing< SourceInterface,
itkImageFixedInterface<Dimensionality, TPixel >,
itkImageDomainFixedInterface<Dimensionality>>
>
{
public:
......@@ -45,14 +47,19 @@ namespace selx
virtual ~ItkImageSourceFixedComponent();
typedef typename itkImageFixedInterface<Dimensionality, TPixel >::ItkImageType ItkImageType;
typedef typename itkImageDomainFixedInterface<Dimensionality>::ItkImageDomainType ItkImageDomainType;
typedef typename itk::ImageFileReader<ItkImageType> ItkImageReaderType;
typedef FileReaderDecorator<ItkImageReaderType> DecoratedReaderType;
// providing interfaces
virtual typename ItkImageType::Pointer GetItkImageFixed() override;
virtual void SetMiniPipelineInput(itk::DataObject::Pointer) override;
virtual AnyFileReader::Pointer GetInputFileReader(void) override;
virtual typename ItkImageDomainType::Pointer GetItkImageDomainFixed() override;
virtual bool MeetsCriterion(const ComponentBase::CriterionType &criterion) override;
static const char * GetDescription() { return "ItkImageSourceFixed Component"; };
private:
......
......@@ -65,6 +65,18 @@ namespace selx
return DecoratedReaderType::New().GetPointer();
}
template<int Dimensionality, class TPixel>
typename ItkImageSourceFixedComponent< Dimensionality, TPixel>::ItkImageDomainType::Pointer
ItkImageSourceFixedComponent< Dimensionality, TPixel>
::GetItkImageDomainFixed()
{
if (this->m_Image == nullptr)
{
itkExceptionMacro("SourceComponent needs to be initialized by SetMiniPipelineInput()");
}
return this->m_Image.GetPointer();
}
template<int Dimensionality, class TPixel>
bool
ItkImageSourceFixedComponent< Dimensionality, TPixel>
......
......@@ -30,11 +30,10 @@ namespace selx
template <class InternalComputationValueType, int Dimensionality>
class ItkGaussianExponentialDiffeomorphicTransformComponent :
public Implements<
Accepting< itkImageFixedInterface<Dimensionality, double> >,
Accepting< itkImageDomainFixedInterface<Dimensionality> >,
Providing< itkTransformInterface<InternalComputationValueType,Dimensionality>,
RunRegistrationInterface
>>
//Should be fixed domain only
RunRegistrationInterface>
>
{
public:
selxNewMacro(ItkGaussianExponentialDiffeomorphicTransformComponent, ComponentBase);
......@@ -43,17 +42,16 @@ namespace selx
ItkGaussianExponentialDiffeomorphicTransformComponent();
virtual ~ItkGaussianExponentialDiffeomorphicTransformComponent();
//typedef double InternalComputationValueType;
/** Type of the optimizer. */
/** Get types from interfaces */
using TransformType = typename itkTransformInterface<InternalComputationValueType,Dimensionality>::TransformType;
using TransformPointer = typename itkTransformInterface<InternalComputationValueType,Dimensionality>::TransformPointer;
typedef typename itk::GaussianExponentialDiffeomorphicTransform<InternalComputationValueType, Dimensionality> GaussianExponentialDiffeomorphicTransformType;
using ItkImageDomainType = typename itkImageDomainFixedInterface<Dimensionality>::ItkImageDomainType;
using GaussianExponentialDiffeomorphicTransformType = typename itk::GaussianExponentialDiffeomorphicTransform<InternalComputationValueType, Dimensionality>;
//Accepting Interfaces:
virtual int Set(itkImageFixedInterface<Dimensionality, double>*) override;
virtual int Set(itkImageDomainFixedInterface<Dimensionality>*) override;
//Providing Interfaces:
virtual TransformPointer GetItkTransform() override;
......@@ -65,7 +63,7 @@ namespace selx
static const char * GetDescription() { return "ItkGaussianExponentialDiffeomorphicTransform Component"; };
private:
typename GaussianExponentialDiffeomorphicTransformType::Pointer m_Transform;
typename itk::Image<double, Dimensionality>::Pointer m_FixedImage;
typename ItkImageDomainType::Pointer m_FixedImageDomain;
protected:
/* The following struct returns the string name of computation type */
/* default implementation */
......
......@@ -40,16 +40,16 @@ ItkGaussianExponentialDiffeomorphicTransformComponent< InternalComputationValueT
template<class InternalComputationValueType, int Dimensionality>
int ItkGaussianExponentialDiffeomorphicTransformComponent< InternalComputationValueType, Dimensionality>
::Set(itkImageFixedInterface<Dimensionality, double>* component)
::Set(itkImageDomainFixedInterface<Dimensionality>* component)
{
this->m_FixedImage = component->GetItkImageFixed();
this->m_FixedImageDomain = component->GetItkImageDomainFixed();
auto displacementField = GaussianExponentialDiffeomorphicTransformType::DisplacementFieldType::New();
//auto zeroVector = itk::NumericTraits<GaussianExponentialDiffeomorphicTransformType::DisplacementFieldType::PixelType>::Zero();
auto zeroVector = typename GaussianExponentialDiffeomorphicTransformType::DisplacementFieldType::PixelType(0.0);
displacementField->CopyInformation(this->m_FixedImage);
displacementField->SetRegions(this->m_FixedImage->GetBufferedRegion());
displacementField->CopyInformation(this->m_FixedImageDomain);
displacementField->SetRegions(this->m_FixedImageDomain->GetBufferedRegion());
displacementField->Allocate();
displacementField->FillBuffer(zeroVector);
......
......@@ -18,6 +18,7 @@
*=========================================================================*/
#include "selxItkGradientDescentOptimizerv4.h"
#include <boost/lexical_cast.hpp>
namespace selx
{
......@@ -83,6 +84,29 @@ ItkGradientDescentOptimizerv4Component< InternalComputationValueType>
}
}
}
else if (criterion.first == "LearningRate") //Supports this?
{
if (criterion.second.size() != 1)
{
meetsCriteria = false;
//itkExceptionMacro("The criterion Sigma may have only 1 value");
}
else
{
auto const & criterionValue = *criterion.second.begin();
try
{
this->m_Optimizer->SetNumberOfIterations(boost::lexical_cast<InternalComputationValueType>(criterionValue));
//this->m_Optimizer->SetLearningRate(std::stod(criterionValue));
meetsCriteria = true;
}
catch (itk::ExceptionObject & err) // TODO: should catch(const bad_lexical_cast &) too
{
//TODO log the error message?
meetsCriteria = false;
}
}
}
return meetsCriteria;
}
......
......@@ -29,6 +29,11 @@
#include "selxItkImageRegistrationMethodv4Component.h"
#include "selxItkANTSNeighborhoodCorrelationImageToImageMetricv4.h"
#include "selxItkMeanSquaresImageToImageMetricv4.h"
#include "selxItkGradientDescentOptimizerv4.h"
#include "selxItkGaussianExponentialDiffeomorphicTransform.h"
#include "selxItkTransformDisplacementFilter.h"
#include "selxItkResampleFilter.h"
#include "selxItkImageSourceFixed.h"
#include "selxItkImageSourceMoving.h"
......@@ -61,7 +66,24 @@ class WBIRDemoTest : public ::testing::Test {
public:
typedef Overlord::Pointer OverlordPointerType;
typedef SuperElastixFilter<TypeList<>> SuperElastixFilterType;
/** Fill SUPERelastix' component data base by registering various components */
typedef TypeList <
DisplacementFieldItkImageFilterSinkComponent<2, float>,
ItkImageSinkComponent<2, float>,
ItkImageSourceFixedComponent<2, float>,
ItkImageSourceMovingComponent<2, float>,
ElastixComponent<2, float>,
ItkImageRegistrationMethodv4Component<2, float>,
ItkANTSNeighborhoodCorrelationImageToImageMetricv4Component<2, float>,
ItkMeanSquaresImageToImageMetricv4Component < 2, float >,
ItkGradientDescentOptimizerv4Component<double>,
ItkGaussianExponentialDiffeomorphicTransformComponent<double, 2>,
ItkTransformDisplacementFilterComponent<2, float, double >,
ItkResampleFilterComponent<2, float, double > > RegisterComponents;
typedef SuperElastixFilter<RegisterComponents> SuperElastixFilterType;
typedef Blueprint::Pointer BlueprintPointerType;
typedef Blueprint::ConstPointer BlueprintConstPointerType;
typedef Blueprint::ParameterMapType ParameterMapType;
......@@ -75,20 +97,9 @@ public:
typedef itk::Image<itk::Vector<float, 2>, 2> VectorImage2DType;
typedef itk::ImageFileWriter<VectorImage2DType> VectorImageWriter2DType;
/** Fill SUPERelastix' component data base by registering various components */
virtual void SetUp() {
ComponentFactory<DisplacementFieldItkImageFilterSinkComponent<2, float>>::RegisterOneFactory();
ComponentFactory<ItkImageSourceFixedComponent<2, float>>::RegisterOneFactory();
ComponentFactory<ItkImageSourceMovingComponent<2, float>>::RegisterOneFactory();
ComponentFactory<ItkSmoothingRecursiveGaussianImageFilterComponent<2, float>>::RegisterOneFactory();
ComponentFactory<ItkImageRegistrationMethodv4Component<2, float>>::RegisterOneFactory();
ComponentFactory<ItkANTSNeighborhoodCorrelationImageToImageMetricv4Component<2, float>>::RegisterOneFactory();
ComponentFactory<ItkMeanSquaresImageToImageMetricv4Component<2, float>>::RegisterOneFactory();
ComponentFactory<ElastixComponent<2, float>>::RegisterOneFactory();
ComponentFactory<ItkImageSinkComponent<2, float>>::RegisterOneFactory();
}
virtual void TearDown() {
......@@ -106,64 +117,58 @@ TEST_F(WBIRDemoTest, itkv4_SVF_ANTSCC)
/** make example blueprint configuration */
blueprint = Blueprint::New();
ParameterMapType component0Parameters;
component0Parameters["NameOfClass"] = { "ItkImageRegistrationMethodv4Component" };
blueprint->AddComponent("RegistrationMethod", component0Parameters);
ParameterMapType component1Parameters;
component1Parameters["NameOfClass"] = { "ItkImageSourceFixedComponent" };
blueprint->AddComponent("FixedImageSource", component1Parameters);
ParameterMapType component2Parameters;
component2Parameters["NameOfClass"] = { "ItkImageSourceMovingComponent" };
blueprint->AddComponent("MovingImageSource", component2Parameters);
ParameterMapType component3Parameters;
component3Parameters["NameOfClass"] = { "ItkImageSinkComponent" };
blueprint->AddComponent("ResultImageSink", component3Parameters);
blueprint->AddComponent("RegistrationMethod", { { "NameOfClass", { "ItkImageRegistrationMethodv4Component" } } });
blueprint->AddComponent("Metric", { { "NameOfClass", { "ItkANTSNeighborhoodCorrelationImageToImageMetricv4Component" } } });
blueprint->AddComponent("Optimizer", { { "NameOfClass", { "ItkGradientDescentOptimizerv4Component" } },
{ "NumberOfIterations", { "100" } },
{ "LearningRate", { "0.001" } } });
blueprint->AddComponent("Transform", { { "NameOfClass", { "ItkGaussianExponentialDiffeomorphicTransformComponent" } } });
ParameterMapType component4Parameters;
component4Parameters["NameOfClass"] = { "DisplacementFieldItkImageFilterSinkComponent" };
blueprint->AddComponent("ResultDisplacementFieldSink", component4Parameters);
blueprint->AddComponent("ResampleFilter", { { "NameOfClass", { "ItkResampleFilterComponent" } } });
blueprint->AddComponent("TransformDisplacementFilter", { { "NameOfClass", { "ItkTransformDisplacementFilterComponent" } } });
ParameterMapType component5Parameters;
component5Parameters["NameOfClass"] = { "ItkANTSNeighborhoodCorrelationImageToImageMetricv4Component" };
blueprint->AddComponent("Metric", component5Parameters);
blueprint->AddComponent("FixedImageSource", { { "NameOfClass", { "ItkImageSourceFixedComponent" } } });
blueprint->AddComponent("MovingImageSource", { { "NameOfClass", { "ItkImageSourceMovingComponent" } } });
blueprint->AddComponent("ResultImageSink", { { "NameOfClass", { "ItkImageSinkComponent" } } });
blueprint->AddComponent("ResultDisplacementFieldSink", { { "NameOfClass", { "DisplacementFieldItkImageFilterSinkComponent" } } });
ParameterMapType connection1Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection1Parameters["NameOfInterface"] = { "itkImageFixedInterface" };
blueprint->AddConnection("FixedImageSource", "RegistrationMethod", connection1Parameters);
//blueprint->AddConnection("FixedImageSource", "RegistrationMethod", { { "NameOfInterface", { "itkImageFixedInterface" } } });
blueprint->AddConnection("FixedImageSource", "RegistrationMethod", { {} });
ParameterMapType connection2Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection2Parameters["NameOfInterface"] = { "itkImageMovingInterface" };
blueprint->AddConnection("MovingImageSource", "RegistrationMethod", connection2Parameters);
//blueprint->AddConnection("MovingImageSource", "RegistrationMethod", { { "NameOfInterface", { "itkImageMovingInterface" } } });
blueprint->AddConnection("MovingImageSource", "RegistrationMethod", { {} });
ParameterMapType connection3Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection3Parameters["NameOfInterface"] = { "itkImageSourceInterface" };
blueprint->AddConnection("RegistrationMethod", "ResultImageSink", connection3Parameters);
//blueprint->AddConnection("RegistrationMethod", "ResultImageSink", { { "NameOfInterface", { "itkImageSourceInterface" } } });
blueprint->AddConnection("ResampleFilter", "ResultImageSink", { {} });
ParameterMapType connection4Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection4Parameters["NameOfInterface"] = { "DisplacementFieldItkImageSourceInterface" };
blueprint->AddConnection("RegistrationMethod", "ResultDisplacementFieldSink", connection4Parameters);
//blueprint->AddConnection("RegistrationMethod", "ResultDisplacementFieldSink", { { "NameOfInterface", { "DisplacementFieldItkImageSourceInterface" } } });
blueprint->AddConnection("TransformDisplacementFilter", "ResultDisplacementFieldSink", { {} });
ParameterMapType connection5Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection5Parameters["NameOfInterface"] = { "itkMetricv4Interface" };
blueprint->AddConnection("Metric", "RegistrationMethod", connection5Parameters);
//blueprint->AddConnection("Metric", "RegistrationMethod", { { "NameOfInterface", { "itkMetricv4Interface" } } });
blueprint->AddConnection("Metric", "RegistrationMethod", { {} });
blueprint->AddConnection("FixedImageSource", "Transform", { {} });
blueprint->AddConnection("Transform", "RegistrationMethod", { {} });
blueprint->AddConnection("Optimizer", "RegistrationMethod", { {} });
blueprint->AddConnection("RegistrationMethod", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("FixedImageSource", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("RegistrationMethod", "ResampleFilter", { {} });
blueprint->AddConnection("FixedImageSource", "ResampleFilter", { {} });
blueprint->AddConnection("MovingImageSource", "ResampleFilter", { {} });
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
blueprint->WriteBlueprint("itkv4_SVF_ANTSCC.dot");
blueprint->WriteBlueprint(dataManager->GetOutputFile("itkv4_SVF_ANTSCC.dot"));
// Instantiate SuperElastix
EXPECT_NO_THROW(superElastixFilter = SuperElastixFilterType::New());
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
// Set up the readers and writers
ImageReader2DType::Pointer fixedImageReader = ImageReader2DType::New();
fixedImageReader->SetFileName(dataManager->GetInputFile("coneA2d64.mhd"));
......@@ -201,64 +206,59 @@ TEST_F(WBIRDemoTest, itkv4_SVF_MSD)
/** make example blueprint configuration */
blueprint = Blueprint::New();
ParameterMapType component0Parameters;
component0Parameters["NameOfClass"] = { "ItkImageRegistrationMethodv4Component" };
blueprint->AddComponent("RegistrationMethod", component0Parameters);
ParameterMapType component1Parameters;
component1Parameters["NameOfClass"] = { "ItkImageSourceFixedComponent" };
blueprint->AddComponent("FixedImageSource", component1Parameters);
ParameterMapType component2Parameters;
component2Parameters["NameOfClass"] = { "ItkImageSourceMovingComponent" };
blueprint->AddComponent("MovingImageSource", component2Parameters);
ParameterMapType component3Parameters;
component3Parameters["NameOfClass"] = { "ItkImageSinkComponent" };
blueprint->AddComponent("ResultImageSink", component3Parameters);
blueprint->AddComponent("RegistrationMethod", { { "NameOfClass", { "ItkImageRegistrationMethodv4Component" } } });
blueprint->AddComponent("Metric", { { "NameOfClass", { "ItkMeanSquaresImageToImageMetricv4Component" } } });
blueprint->AddComponent("Optimizer", { { "NameOfClass", { "ItkGradientDescentOptimizerv4Component" } },
{ "NumberOfIterations", { "100" } },
{ "LearningRate", { "0.001" } } });
blueprint->AddComponent("Transform", { { "NameOfClass", { "ItkGaussianExponentialDiffeomorphicTransformComponent" } } });
ParameterMapType component4Parameters;
component4Parameters["NameOfClass"] = { "DisplacementFieldItkImageFilterSinkComponent" };
blueprint->AddComponent("ResultDisplacementFieldSink", component4Parameters);
blueprint->AddComponent("ResampleFilter", { { "NameOfClass", { "ItkResampleFilterComponent" } } });
blueprint->AddComponent("TransformDisplacementFilter", { { "NameOfClass", { "ItkTransformDisplacementFilterComponent" } } });
ParameterMapType component5Parameters;
component5Parameters["NameOfClass"] = { "ItkMeanSquaresImageToImageMetricv4Component" };
blueprint->AddComponent("Metric", component5Parameters);
blueprint->AddComponent("FixedImageSource", { { "NameOfClass", { "ItkImageSourceFixedComponent" } } });
blueprint->AddComponent("MovingImageSource", { { "NameOfClass", { "ItkImageSourceMovingComponent" } } });
blueprint->AddComponent("ResultImageSink", { { "NameOfClass", { "ItkImageSinkComponent" } } });
blueprint->AddComponent("ResultDisplacementFieldSink", { { "NameOfClass", { "DisplacementFieldItkImageFilterSinkComponent" } } });
ParameterMapType connection1Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection1Parameters["NameOfInterface"] = { "itkImageFixedInterface" };
blueprint->AddConnection("FixedImageSource", "RegistrationMethod", connection1Parameters);
//blueprint->AddConnection("FixedImageSource", "RegistrationMethod", { { "NameOfInterface", { "itkImageFixedInterface" } } });
blueprint->AddConnection("FixedImageSource", "RegistrationMethod", { {} });
ParameterMapType connection2Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection2Parameters["NameOfInterface"] = { "itkImageMovingInterface" };
blueprint->AddConnection("MovingImageSource", "RegistrationMethod", connection2Parameters);
//blueprint->AddConnection("MovingImageSource", "RegistrationMethod", { { "NameOfInterface", { "itkImageMovingInterface" } } });
blueprint->AddConnection("MovingImageSource", "RegistrationMethod", { {} });
ParameterMapType connection3Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection3Parameters["NameOfInterface"] = { "itkImageSourceInterface" };
blueprint->AddConnection("RegistrationMethod", "ResultImageSink", connection3Parameters);
//blueprint->AddConnection("RegistrationMethod", "ResultImageSink", { { "NameOfInterface", { "itkImageSourceInterface" } } });
blueprint->AddConnection("ResampleFilter", "ResultImageSink", { {} });
ParameterMapType connection4Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection4Parameters["NameOfInterface"] = { "DisplacementFieldItkImageSourceInterface" };
blueprint->AddConnection("RegistrationMethod", "ResultDisplacementFieldSink", connection4Parameters);
//blueprint->AddConnection("RegistrationMethod", "ResultDisplacementFieldSink", { { "NameOfInterface", { "DisplacementFieldItkImageSourceInterface" } } });
blueprint->AddConnection("TransformDisplacementFilter", "ResultDisplacementFieldSink", { {} });
ParameterMapType connection5Parameters;
//optionally, tie properties to connection to avoid ambiguities
//connection5Parameters["NameOfInterface"] = { "itkMetricv4Interface" };
blueprint->AddConnection("Metric", "RegistrationMethod", connection5Parameters);
//blueprint->AddConnection("Metric", "RegistrationMethod", { { "NameOfInterface", { "itkMetricv4Interface" } } });
blueprint->AddConnection("Metric", "RegistrationMethod", { {} });
blueprint->AddConnection("FixedImageSource", "Transform", { {} });
blueprint->AddConnection("Transform", "RegistrationMethod", { {} });
blueprint->AddConnection("Optimizer", "RegistrationMethod", { {} });
blueprint->AddConnection("RegistrationMethod", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("FixedImageSource", "TransformDisplacementFilter", { {} });
blueprint->AddConnection("RegistrationMethod", "ResampleFilter", { {} });
blueprint->AddConnection("FixedImageSource", "ResampleFilter", { {} });
blueprint->AddConnection("MovingImageSource", "ResampleFilter", { {} });
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
blueprint->WriteBlueprint("itkv4_SVF_MSD.dot");
blueprint->WriteBlueprint(dataManager->GetOutputFile("itkv4_SVF_MSD.dot"));
// Instantiate SuperElastix
EXPECT_NO_THROW(superElastixFilter = SuperElastixFilterType::New());
// Data manager provides the paths to the input and output data for unit tests
DataManagerType::Pointer dataManager = DataManagerType::New();
// Set up the readers and writers
ImageReader2DType::Pointer fixedImageReader = ImageReader2DType::New();
fixedImageReader->SetFileName(dataManager->GetInputFile("coneA2d64.mhd"));
......
......@@ -114,6 +114,15 @@ struct InterfaceName < itkImageFixedInterface <D, TPixel> >
}
};
template <int D>
struct InterfaceName < itkImageDomainFixedInterface <D> >
{
static const char* Get()
{
return "itkImageDomainFixedInterface";
}
};
template <int D, class TPixel>
struct InterfaceName < itkImageMovingInterface <D, TPixel> >
{
......
......@@ -86,6 +86,14 @@ namespace selx
virtual typename ItkImageType::Pointer GetItkImageFixed() = 0;
};
template<int Dimensionality>
class itkImageDomainFixedInterface {
// An interface that passes the pointer of an output image
public:
typedef typename itk::ImageBase<Dimensionality> ItkImageDomainType;
virtual typename ItkImageDomainType::Pointer GetItkImageDomainFixed() = 0;
};
template<int Dimensionality, class TPixel>
class itkImageMovingInterface {
// An interface that passes the pointer of an output image
......
......@@ -170,7 +170,7 @@ namespace selx
if (numberOfConnections == 0)
{
isAllSuccess = false;
std::cout << "Warning: a connection was specified, but no compatible interfaces were found.";
std::cout << "Warning: a connection from " << name << " to " << outgoingName << " was specified, but no compatible interfaces were found." << std::endl;
}
}
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment