Skip to content

Commit

Permalink
ENH: Switched to gradient descent for exponential transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
ntustison authored and hjmjohnson committed Jan 18, 2013
1 parent 470a2cb commit 5bedd97
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
15 changes: 10 additions & 5 deletions Examples/itkantsRegistrationHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@
#ifndef __antsRegistrationHelper_h
#define __antsRegistrationHelper_h

#include <string>
#include <iostream>
#include <deque>

#include "antsCommandLineParser.h"
#include "itkANTSAffine3DTransform.h"
#include "itkANTSCenteredAffine2DTransform.h"
#include "itkANTSNeighborhoodCorrelationImageToImageMetricv4.h"
Expand All @@ -36,6 +31,7 @@
#include "itkBSplineTransformParametersAdaptor.h"
#include "itkCommand.h"
#include "itkCompositeTransform.h"
#include "itkConjugateGradientLineSearchOptimizerv4.h"
#include "itkCorrelationImageToImageMetricv4.h"
#include "itkDemonsImageToImageMetricv4.h"
#include "itkDisplacementFieldTransform.h"
Expand All @@ -45,6 +41,7 @@
#include "itkGaussianExponentialDiffeomorphicTransformParametersAdaptor.h"
#include "itkGaussianSmoothingOnUpdateDisplacementFieldTransform.h"
#include "itkGaussianSmoothingOnUpdateDisplacementFieldTransformParametersAdaptor.h"
#include "itkGradientDescentOptimizerv4.h"
#include "itkHistogramMatchingImageFilter.h"
#include "itkIdentityTransform.h"
#include "itkImage.h"
Expand All @@ -60,6 +57,7 @@
#include "itkMeanSquaresImageToImageMetricv4.h"
#include "itkObject.h"
#include "itkQuaternionRigidTransform.h"
#include "itkRegistrationParameterScalesFromPhysicalShift.h"
#include "itkSimilarity2DTransform.h"
#include "itkSimilarity3DTransform.h"
#include "itkSyNImageRegistrationMethod.h"
Expand All @@ -77,6 +75,13 @@

#include "itkantsReadWriteTransform.h"

#include "antsAllocImage.h"
#include "antsCommandLineParser.h"

#include <string>
#include <iostream>
#include <deque>

namespace ants
{
typedef itk::ants::CommandLineParser ParserType;
Expand Down
33 changes: 22 additions & 11 deletions Examples/itkantsRegistrationHelper.hxx
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
#ifndef __itkantsRegistrationHelper_hxx
#define __itkantsRegistrationHelper_hxx
#include "itkRegistrationParameterScalesFromPhysicalShift.h"
#include "itkDemonsImageToImageMetricv4.h"
#include "itkConjugateGradientLineSearchOptimizerv4.h"
#include "antsAllocImage.h"

namespace ants
{
Expand Down Expand Up @@ -52,7 +48,7 @@ public:
this->Logger() << " required fixed parameters = " << adaptors[currentLevel]->GetRequiredFixedParameters()
<< std::endl;

typedef itk::ConjugateGradientLineSearchOptimizerv4 GradientDescentOptimizerType;
typedef itk::GradientDescentOptimizerv4 GradientDescentOptimizerType;
GradientDescentOptimizerType * optimizer = reinterpret_cast<GradientDescentOptimizerType *>(
const_cast<typename TFilter::OptimizerType *>( const_cast<TFilter *>( filter )->GetOptimizer() ) );

Expand Down Expand Up @@ -1115,8 +1111,8 @@ RegistrationHelper<VImageDimension>
scalesEstimator->SetMetric( metric );
scalesEstimator->SetTransformForward( true );

typedef itk::ConjugateGradientLineSearchOptimizerv4 GradientDescentOptimizerType;
typename GradientDescentOptimizerType::Pointer optimizer = GradientDescentOptimizerType::New();
typedef itk::ConjugateGradientLineSearchOptimizerv4 ConjugateGradientDescentOptimizerType;
typename ConjugateGradientDescentOptimizerType::Pointer optimizer = ConjugateGradientDescentOptimizerType::New();
optimizer->SetLowerLimit( 0 );
optimizer->SetUpperLimit( 2 );
optimizer->SetEpsilon( 0.2 );
Expand All @@ -1129,11 +1125,26 @@ RegistrationHelper<VImageDimension>
optimizer->SetConvergenceWindowSize( convergenceWindowSize );
optimizer->SetDoEstimateLearningRateAtEachIteration( this->m_DoEstimateLearningRateAtEachIteration );
optimizer->SetDoEstimateLearningRateOnce( !this->m_DoEstimateLearningRateAtEachIteration );
typedef antsRegistrationOptimizerCommandIterationUpdate<GradientDescentOptimizerType> OptimizerCommandType;
typedef antsRegistrationOptimizerCommandIterationUpdate<ConjugateGradientDescentOptimizerType> OptimizerCommandType;
typename OptimizerCommandType::Pointer optimizerObserver = OptimizerCommandType::New();
optimizerObserver->SetLogStream( *this->m_LogStream );
optimizerObserver->SetOptimizer( optimizer );

typedef itk::GradientDescentOptimizerv4 GradientDescentOptimizerType;
typename GradientDescentOptimizerType::Pointer optimizer2 = GradientDescentOptimizerType::New();
optimizer2->SetLearningRate( learningRate );
optimizer2->SetMaximumStepSizeInPhysicalUnits( learningRate );
optimizer2->SetNumberOfIterations( currentStageIterations[0] );
optimizer2->SetScalesEstimator( scalesEstimator );
optimizer2->SetMinimumConvergenceValue( convergenceThreshold );
optimizer2->SetConvergenceWindowSize( convergenceWindowSize );
optimizer2->SetDoEstimateLearningRateAtEachIteration( this->m_DoEstimateLearningRateAtEachIteration );
optimizer2->SetDoEstimateLearningRateOnce( !this->m_DoEstimateLearningRateAtEachIteration );
typedef antsRegistrationOptimizerCommandIterationUpdate<GradientDescentOptimizerType> OptimizerCommandType2;
typename OptimizerCommandType2::Pointer optimizerObserver2 = OptimizerCommandType2::New();
optimizerObserver2->SetLogStream( *this->m_LogStream );
optimizerObserver2->SetOptimizer( optimizer2 );

// Set up the image registration methods along with the transforms
XfrmMethod whichTransform = this->m_TransformMethods[currentStage].m_XfrmMethod;

Expand Down Expand Up @@ -2529,7 +2540,7 @@ RegistrationHelper<VImageDimension>
displacementFieldRegistration->SetMetricSamplingStrategy(
static_cast<typename DisplacementFieldRegistrationType::MetricSamplingStrategyType>( metricSamplingStrategy ) );
displacementFieldRegistration->SetMetricSamplingPercentage( samplingPercentage );
displacementFieldRegistration->SetOptimizer( optimizer );
displacementFieldRegistration->SetOptimizer( optimizer2 );
displacementFieldRegistration->SetTransformParametersAdaptorsPerLevel( adaptors );
if( this->m_CompositeTransform->GetNumberOfTransforms() > 0 )
{
Expand All @@ -2543,7 +2554,7 @@ RegistrationHelper<VImageDimension>
typedef antsRegistrationCommandIterationUpdate<DisplacementFieldRegistrationType> DisplacementFieldCommandType;
typename DisplacementFieldCommandType::Pointer displacementFieldRegistrationObserver =
DisplacementFieldCommandType::New();
displacementFieldRegistrationObserver->SetLogStream(*this->m_LogStream);
displacementFieldRegistrationObserver->SetLogStream(*this->m_LogStream );
displacementFieldRegistrationObserver->SetNumberOfIterations( currentStageIterations );

displacementFieldRegistration->AddObserver( itk::IterationEvent(), displacementFieldRegistrationObserver );
Expand Down Expand Up @@ -2688,7 +2699,7 @@ RegistrationHelper<VImageDimension>
displacementFieldRegistration->SetMetricSamplingStrategy(
static_cast<typename DisplacementFieldRegistrationType::MetricSamplingStrategyType>( metricSamplingStrategy ) );
displacementFieldRegistration->SetMetricSamplingPercentage( samplingPercentage );
displacementFieldRegistration->SetOptimizer( optimizer );
displacementFieldRegistration->SetOptimizer( optimizer2 );
displacementFieldRegistration->SetTransformParametersAdaptorsPerLevel( adaptors );

typedef antsRegistrationCommandIterationUpdate<DisplacementFieldRegistrationType> DisplacementFieldCommandType;
Expand Down

0 comments on commit 5bedd97

Please sign in to comment.