diff --git a/Code/Registration/include/sitkImageRegistrationMethod.h b/Code/Registration/include/sitkImageRegistrationMethod.h index 34a8405d4..d2ea13a2b 100644 --- a/Code/Registration/include/sitkImageRegistrationMethod.h +++ b/Code/Registration/include/sitkImageRegistrationMethod.h @@ -177,6 +177,18 @@ namespace simple { return this->m_FixedInitialTransform; } /**@}*/ + + /** \brief Set the virtual domain used for sampling + * + * @{ + */ + SITK_RETURN_SELF_TYPE_HEADER SetVirtualDomain( const std::vector &virtualSize, + const std::vector &virtualOrigin, + const std::vector &virtualSpacing, + const std::vector &virtualDirection ); + SITK_RETURN_SELF_TYPE_HEADER SetVirtualDomainFromImage( const Image &virtualImage ); + /**@}*/ + /** \brief Use normalized cross correlation using a small * neighborhood for each voxel between two images, with speed * optimizations for dense registration. @@ -601,6 +613,11 @@ namespace simple Transform m_MovingInitialTransform; Transform m_FixedInitialTransform; + std::vector m_VirtualDomainSize; + std::vector m_VirtualDomainOrigin; + std::vector m_VirtualDomainSpacing; + std::vector m_VirtualDomainDirection; + // optimizer enum OptimizerType { ConjugateGradientLineSearch, RegularStepGradientDescent, diff --git a/Code/Registration/src/sitkImageRegistrationMethod.cxx b/Code/Registration/src/sitkImageRegistrationMethod.cxx index 2f2a77f10..d84c7a6fa 100644 --- a/Code/Registration/src/sitkImageRegistrationMethod.cxx +++ b/Code/Registration/src/sitkImageRegistrationMethod.cxx @@ -131,6 +131,47 @@ ImageRegistrationMethod::SetInitialTransform ( Transform &transform, bool inPlac return *this; } +ImageRegistrationMethod::Self& +ImageRegistrationMethod::SetVirtualDomain( const std::vector &virtualSize, + const std::vector &virtualOrigin, + const std::vector &virtualSpacing, + const std::vector &virtualDirection ) +{ + const size_t dim = virtualSize.size(); + + if ( virtualOrigin.size() != dim ) + { + sitkExceptionMacro("Expected virtualOrigin to be of length " << dim << "!" ); + } + + if ( virtualSpacing.size() != dim ) + { + sitkExceptionMacro("Expected virtualSpacing to be of length " << dim << "!" ); + } + + if ( virtualDirection.size() != dim*dim ) + { + sitkExceptionMacro("Expected virtualDirection to be of length " << dim*dim << "!" ); + } + + this->m_VirtualDomainSize = virtualSize; + this->m_VirtualDomainOrigin = virtualOrigin; + this->m_VirtualDomainSpacing = virtualSpacing; + this->m_VirtualDomainDirection = virtualDirection; + return *this; +} + +ImageRegistrationMethod::Self& +ImageRegistrationMethod::SetVirtualDomainFromImage( const Image &virtualImage ) +{ + this->m_VirtualDomainSize = virtualImage.GetSize(); + this->m_VirtualDomainOrigin = virtualImage.GetOrigin(); + this->m_VirtualDomainSpacing = virtualImage.GetSpacing(); + this->m_VirtualDomainDirection = virtualImage.GetDirection(); + + return *this; +} + ImageRegistrationMethod::Self& ImageRegistrationMethod::SetMetricAsANTSNeighborhoodCorrelation( unsigned int radius ) { @@ -967,6 +1008,17 @@ ImageRegistrationMethod::SetupMetric( metric->SetUseFixedImageGradientFilter( m_MetricUseFixedImageGradientFilter ); metric->SetUseMovingImageGradientFilter( m_MetricUseMovingImageGradientFilter ); + if ( this->m_VirtualDomainSize.size() != 0 ) + { + typename FixedImageType::SpacingType itkSpacing = sitkSTLVectorToITK(this->m_VirtualDomainSpacing); + typename FixedImageType::PointType itkOrigin = sitkSTLVectorToITK(this->m_VirtualDomainOrigin); + typename FixedImageType::DirectionType itkDirection = sitkSTLToITKDirection(this->m_VirtualDomainDirection); + + typename FixedImageType::RegionType itkRegion; + itkRegion.SetSize( sitkSTLVectorToITK( this->m_VirtualDomainSize ) ); + + metric->SetVirtualDomain( itkSpacing, itkOrigin, itkDirection, itkRegion ); + } typedef itk::InterpolateImageFunction< FixedImageType, double > FixedInterpolatorType; typename FixedInterpolatorType::Pointer fixedInterpolator = CreateInterpolator(fixed, m_Interpolator); diff --git a/Testing/Unit/sitkImageRegistrationMethodTests.cxx b/Testing/Unit/sitkImageRegistrationMethodTests.cxx index de534cf4e..6f63405a4 100644 --- a/Testing/Unit/sitkImageRegistrationMethodTests.cxx +++ b/Testing/Unit/sitkImageRegistrationMethodTests.cxx @@ -118,7 +118,7 @@ class sitkRegistrationMethodTest const std::vector &pt1, const std::vector &size) { - sitk::GaussianImageSource source1; + sitk::GaussianImageSource source1; source1.SetMean(pt0); source1.SetScale(1.0); @@ -162,7 +162,6 @@ TEST_F(sitkRegistrationMethodTest, Metric_Evaluate) sitk::Image fixed = fixedBlobs; sitk::Image moving = fixedBlobs; - sitk::ImageRegistrationMethod R; R.SetInitialTransform(sitk::Transform(fixed.GetDimension(),sitk::sitkIdentity)); @@ -441,6 +440,186 @@ TEST_F(sitkRegistrationMethodTest, Mask_Test2) } +TEST_F(sitkRegistrationMethodTest, VirtualDomain_Test) +{ + // Test usage of setting virtual domain + + sitk::ImageRegistrationMethod R; + R.SetInterpolator(sitk::sitkLinear); + //R.DebugOn(); + sitk::Image virtualImage = MakeGaussianBlob( v2(32,32), std::vector(2,64) ); + + R.SetVirtualDomainFromImage(virtualImage); + + // transform to optimize + sitk::TranslationTransform tx(virtualImage.GetDimension()); + tx.SetOffset(v2(3.2,-1.2)); + R.SetInitialTransform(tx, false); + + sitk::Image fixedImage = virtualImage; + fixedImage.SetOrigin(v2(100, 0)); + + // virtual image to fixed image + sitk::TranslationTransform fixedTransform(fixedImage.GetDimension()); + fixedTransform.SetOffset(v2(100, 0)); + R.SetFixedInitialTransform(fixedTransform); + + sitk::Image movingImage = virtualImage; + movingImage.SetOrigin(v2(0, 200)); + + // transform from virtual domain to moving image with "optimizing" transform + sitk::TranslationTransform movingTransform(movingImage.GetDimension()); + movingTransform.SetOffset(v2(0, 200)); + R.SetMovingInitialTransform(movingTransform); + + R.SetMetricAsCorrelation(); + + double minStep=1e-5; + unsigned int numberOfIterations=100; + double relaxationFactor=0.75; + double gradientMagnitudeTolerance = 1e-20; + R.SetOptimizerAsRegularStepGradientDescent(.1, + minStep, + numberOfIterations, + relaxationFactor, + gradientMagnitudeTolerance); + + + IterationUpdate cmd(R); + R.AddCommand(sitk::sitkIterationEvent, cmd); + + sitk::Transform outTx = R.Execute(fixedImage, movingImage); + + + std::cout << "-------" << std::endl; + std::cout << outTx.ToString() << std::endl; + std::cout << "Optimizer stop condition: " << R.GetOptimizerStopConditionDescription() << std::endl; + std::cout << " Iteration: " << R.GetOptimizerIteration() << std::endl; + std::cout << " Metric value: " << R.GetMetricValue() << std::endl; + + EXPECT_VECTOR_DOUBLE_NEAR(v2(0.0,0.0), outTx.GetParameters(), 1e-3); + EXPECT_GT( R.GetOptimizerIteration(), 1u ); + +} + +TEST_F(sitkRegistrationMethodTest, VirtualDomain_MultiRes_Test) +{ + // Test usage of setting virtual domain + + sitk::ImageRegistrationMethod R; + R.SetInterpolator(sitk::sitkLinear); + //R.DebugOn(); + sitk::Image virtualImage = MakeGaussianBlob( v3(32,32,32), std::vector(3,64) ); + + R.SetVirtualDomainFromImage(virtualImage); + + // transform to optimize + sitk::TranslationTransform tx(virtualImage.GetDimension()); + tx.SetOffset(v3(5.21231, 3.2,-1.2)); + R.SetInitialTransform(tx, false); + + sitk::Image fixedImage = virtualImage; + fixedImage.SetOrigin(v3(1000, 100, 0)); + + // virtual image to fixed image + sitk::TranslationTransform fixedTransform(fixedImage.GetDimension()); + fixedTransform.SetOffset(v3(1000, 100, 0)); + R.SetFixedInitialTransform(fixedTransform); + + sitk::Image movingImage = virtualImage; + movingImage.SetOrigin(v3(0, 200, 512)); + + // transform from virtual domain to moving image with "optimizing" transform + sitk::TranslationTransform movingTransform(movingImage.GetDimension()); + movingTransform.SetOffset(v3(0, 200, 512)); + R.SetMovingInitialTransform(movingTransform); + + R.SetMetricAsMeanSquares(); + + double minStep=1e-3; + unsigned int numberOfIterations=10; + double relaxationFactor=0.6; + double gradientMagnitudeTolerance = 1e-10; + sitk::ImageRegistrationMethod::EstimateLearningRateType estimateLearningRate = sitk::ImageRegistrationMethod::Never; + R.SetOptimizerAsRegularStepGradientDescent(2, + minStep, + numberOfIterations, + relaxationFactor, + gradientMagnitudeTolerance, + estimateLearningRate); + + std::vector shrinkFactors(2); + shrinkFactors[0] = 8; + shrinkFactors[1] = 1; + R.SetShrinkFactorsPerLevel( shrinkFactors ); + R.SetSmoothingSigmasPerLevel( v2(0.0, 0.0) ); + + R.SetOptimizerScalesFromJacobian(); + + IterationUpdate cmd(R); + R.AddCommand(sitk::sitkIterationEvent, cmd); + + sitk::Transform outTx = R.Execute(fixedImage, movingImage); + + + std::cout << "-------" << std::endl; + std::cout << outTx.ToString() << std::endl; + std::cout << "Optimizer stop condition: " << R.GetOptimizerStopConditionDescription() << std::endl; + std::cout << " Iteration: " << R.GetOptimizerIteration() << std::endl; + std::cout << " Metric value: " << R.GetMetricValue() << std::endl; + + EXPECT_VECTOR_DOUBLE_NEAR(v3(0.0,0.0,0.0), outTx.GetParameters(), 1e-1); + EXPECT_GT( R.GetOptimizerIteration(), 1u ); + +} + + + +TEST_F(sitkRegistrationMethodTest, VirtualDomain_Parameters) +{ + + sitk::ImageRegistrationMethod R; + R.SetInterpolator(sitk::sitkLinear); + + sitk::Image virtualImage = MakeGaussianBlob( v3(32,32,32), std::vector(3,64) ); + + std::vector size(3, 64); + + EXPECT_NO_THROW( + R.SetVirtualDomain(size, + v3(0.0,0.0,0.0), + v3(1.0,1.0,1.0), + v9(1.0,0.0,0.0, + 0.0,1.0,0.0, + 0.0,0.0,1.0)) ); + + EXPECT_THROW( + R.SetVirtualDomain(size, + v2(0.0,0.0), + v3(1.0,1.0,1.0), + v9(1.0,0.0,0.0, + 0.0,1.0,0.0, + 0.0,0.0,1.0)), + sitk::GenericException); + + + EXPECT_THROW( + R.SetVirtualDomain(size, + v3(0.0,0.0,0.0), + v2(1.0,1.0), + v9(1.0,0.0,0.0, + 0.0,1.0,0.0, + 0.0,0.0,1.0)), + sitk::GenericException ); + + EXPECT_THROW( + R.SetVirtualDomain(size, + v3(0.0,0.0,0.0), + v3(1.0,1.0,1.0), + v3(1.0,0.0,0.0)), + sitk::GenericException ); +} + TEST_F(sitkRegistrationMethodTest, OptimizerWeights_Test) { // Test the usage of optimizer weights