selxItkImageRegistrationMethodv4Component.hxx 14.9 KB
Newer Older
Floris Berendsen's avatar
Floris Berendsen committed
1
2
/*=========================================================================
 *
 *  Copyright Leiden University Medical Center, Erasmus University Medical
 *  Center and contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/

20
21
#include "selxItkImageRegistrationMethodv4Component.h"

FBerendsen's avatar
FBerendsen committed
22
//TODO: get rid of these
Floris Berendsen's avatar
Floris Berendsen committed
23
24
25
26
27
#include "itkMeanSquaresImageToImageMetricv4.h"
#include "itkANTSNeighborhoodCorrelationImageToImageMetricv4.h"
#include "itkGradientDescentOptimizerv4.h"
#include "itkImageFileWriter.h"

28
29
namespace selx
{
30
31
32
33
/** Observer that prints registration progress to std::cout.
 *
 * On an itk::MultiResolutionIterationEvent invoked by a TFilter it prints the current level,
 * the shrink factors, the smoothing sigma and a summary of the GradientDescentOptimizerv4 state.
 * On any other itk::IterationEvent invoked by a GradientDescentOptimizerv4 it prints the
 * iteration number and the current metric value. Events from other invokers are ignored.
 */
template< typename TFilter >
class CommandIterationUpdate : public itk::Command
{
public:

  typedef CommandIterationUpdate    Self;
  typedef itk::Command              Superclass;
  typedef itk::SmartPointer< Self > Pointer;
  itkNewMacro( Self );

  typedef itk::GradientDescentOptimizerv4 OptimizerType;
  typedef   const OptimizerType *         OptimizerPointer;

protected:

  CommandIterationUpdate() {}

public:

  /** Non-const entry point; forwards to the const overload. */
  virtual void Execute( itk::Object * caller, const itk::EventObject & event ) ITK_OVERRIDE
  {
    const itk::Object * constCaller = caller; // adding const is an implicit conversion, no cast needed
    Execute( constCaller, event );
  }

  /** Reports progress for the events described in the class comment. */
  virtual void Execute( const itk::Object * object, const itk::EventObject & event ) ITK_OVERRIDE
  {
    if( typeid( event ) == typeid( itk::MultiResolutionIterationEvent ) )
    {
      // Checked downcast: only a TFilter has the multi-resolution accessors used below.
      const TFilter * filter = dynamic_cast< const TFilter * >( object );
      if( !filter )
      {
        return;
      }

      unsigned int currentLevel = filter->GetCurrentLevel();
      typename TFilter::ShrinkFactorsPerDimensionContainerType shrinkFactors = filter->GetShrinkFactorsPerDimension( currentLevel );
      typename TFilter::SmoothingSigmasArrayType smoothingSigmas             = filter->GetSmoothingSigmasPerLevel();

      const itk::ObjectToObjectOptimizerBase * optimizerBase = filter->GetOptimizer();
      typename OptimizerType::ConstPointer optimizer = dynamic_cast< const OptimizerType * >( optimizerBase );
      if( !optimizer )
      {
        itkGenericExceptionMacro( "Error dynamic_cast failed" );
      }
      typename OptimizerType::DerivativeType gradient = optimizer->GetGradient();

      std::cout << "  CL Current level:           " << currentLevel << std::endl;
      std::cout << "   SF Shrink factor:          " << shrinkFactors << std::endl;
      std::cout << "   SS Smoothing sigma:        " << smoothingSigmas[ currentLevel ] << std::endl;
      std::cout << "   LR Final learning rate:    " << optimizer->GetLearningRate() << std::endl;
      std::cout << "   FM Final metric value:     " << optimizer->GetCurrentMetricValue() << std::endl;
      std::cout << "   SC Optimizer scales:       " << optimizer->GetScales() << std::endl;
      std::cout << "   FG Final metric gradient (sample of values): ";

      const itk::SizeValueType numberOfValues = gradient.GetSize();
      if( numberOfValues < 10 )
      {
        std::cout << gradient;
      }
      else
      {
        // Print roughly 16 evenly spaced samples. The stride must be at least 1: for 10..15
        // values numberOfValues / 16 is 0, which previously made this loop never terminate.
        const itk::SizeValueType stride = ( numberOfValues / 16 > 0 ) ? ( numberOfValues / 16 ) : 1;
        for( itk::SizeValueType i = 0; i < numberOfValues; i += stride )
        {
          std::cout << gradient[ i ] << " ";
        }
      }
      std::cout << std::endl;
    }
    else if( itk::IterationEvent().CheckEvent( &event ) )
    {
      // The observer may be attached to the registration filter rather than to an optimizer,
      // so the invoker's type must be checked instead of assumed (a static_cast here was UB).
      OptimizerPointer optimizer = dynamic_cast< OptimizerPointer >( object );
      if( optimizer )
      {
        std::cout << optimizer->GetCurrentIteration() << ": ";
        std::cout << optimizer->GetCurrentMetricValue() << std::endl;
      }
    }
  }
};
Floris Berendsen's avatar
Floris Berendsen committed
116

117
template< int Dimensionality, class TPixel >
118
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::ItkImageRegistrationMethodv4Component() : m_TransformAdaptorsContainerInterface(nullptr)
119
120
{
  m_theItkFilter = TheItkFilterType::New();
121
  m_theItkFilter->InPlaceOn();
122

123
124
  //TODO: instantiating the filter in the constructor might be heavy for the use in component selector factory, since all components of the database are created during the selection process.
  // we could choose to keep the component light weighted (for checking criteria such as names and connections) until the settings are passed to the filter, but this requires an additional initialization step.
125
}
126

127
128
129

template< int Dimensionality, class TPixel >
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::~ItkImageRegistrationMethodv4Component() = default; // members release themselves

133
134
135
136
137

template< int Dimensionality, class TPixel >
int
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >
::Set( itkImageFixedInterface< Dimensionality, TPixel > * component )
{
  // Hook the upstream component's fixed image into the registration filter's pipeline; always returns 0.
  this->m_theItkFilter->SetFixedImage( component->GetItkImageFixed() );
  return 0;
}

146
147
148
149
150

template< int Dimensionality, class TPixel >
int
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >
::Set( itkImageMovingInterface< Dimensionality, TPixel > * component )
{
  // Hook the upstream component's moving image into the registration filter's pipeline; always returns 0.
  this->m_theItkFilter->SetMovingImage( component->GetItkImageMoving() );
  return 0;
}
157

158

159
160
161
162
template< int Dimensionality, class TPixel >
int
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::Set( itkTransformInterface< TransformInternalComputationValueType,
  Dimensionality > * component )
{
  // The connected transform becomes the filter's initial transform; always returns 0.
  auto initialTransform = component->GetItkTransform();
  this->m_theItkFilter->SetInitialTransform( initialTransform );
  return 0;
}

169
170
template< int Dimensionality, class TPixel >
int
Floris Berendsen's avatar
Floris Berendsen committed
171
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::Set(TransformParametersAdaptorsContainerInterfaceType * component)
172
{
173
  // store the interface to the ParametersAdaptorsContainer since during the setup of the connections the TransformParametersAdaptorComponent might not be fully connected and thus does not have the adaptors ready.
Floris Berendsen's avatar
Floris Berendsen committed
174
  this->m_TransformAdaptorsContainerInterface = component;
175
176
177
  return 0;
}

178

179
180
181
template< int Dimensionality, class TPixel >
int
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::Set( itkMetricv4Interface< Dimensionality, TPixel > * component )
{
  // Install the connected v4 metric on the registration filter; always returns 0.
  //TODO: The optimizer must be set explicitly, since this is a work-around for a bug in itkRegistrationMethodv4.
  //TODO: report bug to itk: when setting a metric, the optimizer must be set explicitly as well, since default optimizer setup breaks.
  auto metric = component->GetItkMetricv4();
  this->m_theItkFilter->SetMetric( metric );
  return 0;
}
189

190

191
192
193
194
195
template< int Dimensionality, class TPixel >
int
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::Set( itkOptimizerv4Interface< OptimizerInternalComputationValueType > * component )
{
  // Install the connected v4 optimizer on the registration filter; always returns 0.
  //TODO: The optimizer must be set explicitly, since this is a work-around for a bug in itkRegistrationMethodv4.
  //TODO: report bug to itk: when setting a metric, the optimizer must be set explicitly as well, since default optimizer setup breaks.
  auto optimizer = component->GetItkOptimizerv4();
  this->m_theItkFilter->SetOptimizer( optimizer );
  return 0;
}

202

203
204
205
206
207
/** Configures the wrapped registration filter and runs it.
 *
 * Requires the filter's metric to be an ImageMetricType and its optimizer to be an
 * itk::GradientDescentOptimizerv4; otherwise an itk::ExceptionObject is thrown before any
 * registration work starts. If a transform-parameters-adaptor interface was connected,
 * its adaptors are installed per level. A progress observer is attached, then Update() runs.
 */
template< int Dimensionality, class TPixel >
void
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::RunRegistration( void )
{
  // Scale estimator is not used in current implementation yet; it is still configured so that
  // enabling it later is a one-line change.
  typename ScalesEstimatorType::Pointer scalesEstimator = ScalesEstimatorType::New();

  ImageMetricType * theMetric = dynamic_cast< ImageMetricType * >( this->m_theItkFilter->GetModifiableMetric() );
  if( !theMetric )
  {
    itkExceptionMacro( "Error casting to ImageMetricv4Type failed" );
  }
  scalesEstimator->SetMetric( theMetric );
  scalesEstimator->SetTransformForward( true );
  scalesEstimator->SetSmallParameterVariation( 1.0 );

  // The settings below are specific to gradient descent. Previously a different optimizer type
  // made this cast return null, which was then dereferenced.
  auto optimizer = dynamic_cast< itk::GradientDescentOptimizerv4 * >( this->m_theItkFilter->GetModifiableOptimizer() );
  if( !optimizer )
  {
    itkExceptionMacro( "Error casting to GradientDescentOptimizerv4 failed" );
  }

  optimizer->SetScalesEstimator( ITK_NULLPTR );
  //optimizer->SetScalesEstimator(scalesEstimator);
  optimizer->SetDoEstimateLearningRateOnce( false ); //true by default
  optimizer->SetDoEstimateLearningRateAtEachIteration( false );

  this->m_theItkFilter->SetOptimizer( optimizer );

  // The adaptors are fetched only now because the adaptor component may not have been fully
  // connected when its interface was handed to Set().
  if( this->m_TransformAdaptorsContainerInterface != nullptr )
  {
    this->m_theItkFilter->SetTransformParametersAdaptorsPerLevel( this->m_TransformAdaptorsContainerInterface->GetItkTransformParametersAdaptorsContainer() );
  }

  // Progress reporting; presumably the per-level MultiResolutionIterationEvent also matches the
  // IterationEvent filter used here (it is derived from it in ITK) -- verify against the ITK version in use.
  typedef CommandIterationUpdate< TheItkFilterType > RegistrationCommandType;
  typename RegistrationCommandType::Pointer registrationObserver = RegistrationCommandType::New();
  this->m_theItkFilter->AddObserver( itk::IterationEvent(), registrationObserver );

  // perform the actual registration
  this->m_theItkFilter->Update();
}
252

253
254
255
256

template< int Dimensionality, class TPixel >
typename ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::TransformPointer
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >::GetItkTransform()
{
  // Expose the registration filter's modifiable transform to downstream components.
  TransformPointer registrationTransform = this->m_theItkFilter->GetModifiableTransform();
  return registrationTransform;
}


263
template< int Dimensionality, class TPixel >
264
bool
265
266
ItkImageRegistrationMethodv4Component< Dimensionality, TPixel >
::MeetsCriterion( const ComponentBase::CriterionType & criterion )
267
{
268
269
270
  bool hasUndefinedCriteria( false );
  bool meetsCriteria( false );
  if( criterion.first == "ComponentProperty" )
271
272
  {
    meetsCriteria = true;
273
    for( auto const & criterionValue : criterion.second ) // auto&& preferred?
274
    {
275
      if( criterionValue != "SomeProperty" )  // e.g. "GradientDescent", "SupportsSparseSamples
276
277
278
279
280
      {
        meetsCriteria = false;
      }
    }
  }
281
  else if( criterion.first == "Dimensionality" ) //Supports this?
282
283
  {
    meetsCriteria = true;
284
    for( auto const & criterionValue : criterion.second ) // auto&& preferred?
285
    {
286
      if( std::stoi( criterionValue ) != Dimensionality )
287
288
289
290
291
      {
        meetsCriteria = false;
      }
    }
  }
292
  else if( criterion.first == "PixelType" ) //Supports this?
293
294
  {
    meetsCriteria = true;
295
    for( auto const & criterionValue : criterion.second ) // auto&& preferred?
296
    {
297
      if( criterionValue != Self::GetPixelTypeNameString() )
298
299
300
301
302
      {
        meetsCriteria = false;
      }
    }
  }
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
  else if (criterion.first == "NumberOfLevels") //Supports this?
  {
    meetsCriteria = true;
    if (criterion.second.size() == 1)
    {
      if (this->m_NumberOfLevelsLastSetBy == "") // check if some other settings set the NumberOfLevels
      {
        // try catch?
        this->m_theItkFilter->SetNumberOfLevels(std::stoi(criterion.second[0]));
        this->m_NumberOfLevelsLastSetBy = criterion.first;
      }
      else
      {
        if (this->m_theItkFilter->GetNumberOfLevels() != std::stoi(criterion.second[0]))
        {
          // TODO log error?
          std::cout << "A conflicting NumberOfLevels was set by " << this->m_NumberOfLevelsLastSetBy << std::endl;
          meetsCriteria = false;
          return meetsCriteria;
        }
      }
    }
    else
    {
      // TODO log error?
      std::cout << "NumberOfLevels accepts one number only" << std::endl;
      meetsCriteria = false;
      return meetsCriteria;
    }
    
  }
334
335
336
337
  else if (criterion.first == "ShrinkFactorsPerLevel") //Supports this?
  {
    meetsCriteria = true;

338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
    const int impliedNumberOfResolutions = criterion.second.size();

    if (this->m_NumberOfLevelsLastSetBy == "") // check if some other settings set the NumberOfLevels
    {
      // try catch?
      this->m_theItkFilter->SetNumberOfLevels(impliedNumberOfResolutions);
      this->m_NumberOfLevelsLastSetBy = criterion.first;
    }
    else
    {
      if (this->m_theItkFilter->GetNumberOfLevels() != impliedNumberOfResolutions)
      {
        // TODO log error?
        std::cout << "A conflicting NumberOfLevels was set by " << this->m_NumberOfLevelsLastSetBy << std::endl;
        meetsCriteria = false;
        return meetsCriteria;;
      }
    }

357
    itk::Array<itk::SizeValueType>  shrinkFactorsPerLevel;
358
    shrinkFactorsPerLevel.SetSize(impliedNumberOfResolutions);
359
360
361
362
363
364
365
366
367
368

    unsigned int resolutionIndex = 0;
    for (auto const & criterionValue : criterion.second) // auto&& preferred?
    {
      shrinkFactorsPerLevel[resolutionIndex] = std::stoi(criterionValue);
      ++resolutionIndex;
    }
    // try catch?
    this->m_theItkFilter->SetShrinkFactorsPerLevel(shrinkFactorsPerLevel);
  }
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
  else if (criterion.first == "SmoothingSigmasPerLevel") //Supports this?
  {
    meetsCriteria = true;

    const int impliedNumberOfResolutions = criterion.second.size();

    if (this->m_NumberOfLevelsLastSetBy == "") // check if some other settings set the NumberOfLevels
    {
      // try catch?
      this->m_theItkFilter->SetNumberOfLevels(impliedNumberOfResolutions);
      this->m_NumberOfLevelsLastSetBy = criterion.first;
    }
    else
    {
      if (this->m_theItkFilter->GetNumberOfLevels() != impliedNumberOfResolutions)
      {
        // TODO log error?
        std::cout << "A conflicting NumberOfLevels was set by " << this->m_NumberOfLevelsLastSetBy << std::endl;
        meetsCriteria = false;
        return meetsCriteria;;
      }
    }

    itk::Array<TransformInternalComputationValueType>  smoothingSigmasPerLevel;

    smoothingSigmasPerLevel.SetSize(impliedNumberOfResolutions);

    unsigned int resolutionIndex = 0;
    for (auto const & criterionValue : criterion.second) // auto&& preferred?
    {
      smoothingSigmasPerLevel[resolutionIndex] = std::stoi(criterionValue);
      ++resolutionIndex;
    }
    // try catch?
    // Smooth by specified gaussian sigmas for each level.  These values are specified in
    // physical units.
    this->m_theItkFilter->SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel);
  }

408
409
410
  return meetsCriteria;
}
} //end namespace selx