Saturday, February 8, 2014

Rigid Image Registration (a quick & dirty intensity-based approach)




Fig. 1. The original image, the manually transformed image, and the image automatically corrected by the registration algorithm.

Image registration is useful when you want to compare images of the same object that were taken from different angles. To make the images comparable, the transformation of each target image has to be corrected so that it aligns with the reference image. The technique has a broad range of applications. One of them is medical imaging, where images of the same anatomical structures are acquired with different modalities and at different time points.

In this example I load good old Lena to demonstrate how the code works. The important part of this implementation is finding initial values (a guess) for the fmin optimization function. Without that step, fmin is very likely unable to find the correct transformation parameters. To provide an initial guess, I vary x and y (image shift, in pixels) and r (image rotation, in degrees) one at a time from -10 to 10 and calculate the intensity-based image error, here the sum of squared pixel differences, for each value. The value with the smallest error for each parameter is then used as the initial guess. (Maybe not the best method, but quite simple and robust.) Then I pretty much let fmin do the magic and search for the parameters that yield the smallest value returned from the intensity-based image error function.
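
To illustrate the initial-guess step in isolation, here is a minimal sketch. It assumes a synthetic Gaussian blob as the test image and a known displacement of 3 rows; the names sum_squared_error, reference, and target are placeholders chosen for this example and do not appear in the full listing.

```python
import numpy as np
from scipy import ndimage

def sum_squared_error(a, b):
    """Intensity-based error: sum of squared pixel differences."""
    return float(np.sum((a - b) ** 2))

# Hypothetical test image: a smooth Gaussian blob on a 64x64 grid
yy, xx = np.mgrid[0:64, 0:64]
reference = np.exp(-((yy - 30.0) ** 2 + (xx - 34.0) ** 2) / (2 * 8.0 ** 2))

# Displace the target by a known amount along axis 0 (rows)
target = ndimage.shift(reference, (3, 0))

# Sweep candidate corrections from -10 to 10 (inclusive) and
# keep the one that gives the smallest error against the reference
candidates = np.arange(-10, 11)
errors = np.array([
    sum_squared_error(reference, ndimage.shift(target, (v, 0)))
    for v in candidates
])
best_row_shift = int(candidates[errors.argmin()])
print(best_row_shift)
```

The sweep should recover the correction that undoes the displacement, i.e. -3 for a target that was shifted by +3 rows. The same one-parameter sweep can be repeated for the column shift and the rotation angle before handing all three values to the optimizer.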

The Python code:



       
'''
@author: Christian Rossmann, PhD
@license:  Public Domain
@blog: http://scientificcomputingco.blogspot.com/
'''

import numpy as np
from scipy import ndimage
from scipy import optimize
from scipy import misc

from pylab import *

def MeasureErr(img1,img2):
    # Sum of squared intensity differences over all pixels
    diff = (img1-img2)
    return np.sum(diff**2)

def RigidRegistration(img,ximg):
    
    # Initial guess: sweep each parameter on its own over -10..10
    # (pixels for the shifts, degrees for the rotation) and keep the minimum
    v_range = np.arange(-10, 11)
    
    err = np.array([MeasureErr(img,ndimage.shift(ximg,(v,0))) for v in v_range])
    x = v_range[err.argmin()]
    
    err = np.array([MeasureErr(img,ndimage.shift(ximg,(0,v))) for v in v_range])
    y = v_range[err.argmin()]

    err = np.array([MeasureErr(img,ndimage.rotate(ximg,v,reshape=0)) for v in v_range])
    r = v_range[err.argmin()]

    # List contains the displacement along axis 0 (x) and axis 1 (y)
    # and the rotation angle
    
    param = [x,y,r]
    
    def ErrFunc(param,img=img,ximg=ximg):
        
        # Perform rotational and translational transformation
        
        _img = ximg.copy()
        _img = ndimage.rotate(_img,param[2],reshape=0)
        _img = ndimage.shift(_img,param[:2])
        
        return MeasureErr(img,_img)

    
    param = optimize.fmin(ErrFunc,param)
    
    #Final transformation
    _img = ximg.copy()
    _img = ndimage.rotate(_img,param[2],reshape=0)
    _img = ndimage.shift(_img,param[:2])
    
    return (_img,param)

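# misc.lena() has been removed from SciPy; face(gray=True) returns a 2-D
# grayscale test image instead (on SciPy >= 1.10 it lives in scipy.datasets)
img = misc.face(gray=True).astype('float32')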

# Normalize image (0-1)
img -= img.min() 
img /= img.max()

# Generate transformed image
ximg = img.copy()
ximg = ndimage.shift(ximg,(5,-1))
ximg = ndimage.rotate(ximg,-4,reshape=0)

(rimg,param) =  RigidRegistration(img,ximg)

 
figure(1)
clf()
subplot(1,3,1)
title('Original')
imshow(img)
subplot(1,3,2)
title('Transformed')
imshow(ximg)
subplot(1,3,3)
title('Registered')
imshow(rimg)
show()
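
The listing above passes two-element shift tuples to ndimage.shift, so it expects a 2-D (grayscale) array. As a minimal sketch of two ways to deal with an RGB input, assuming a channel-last array of shape (rows, columns, 3): the random array rgb below is a hypothetical stand-in for a real image, and the plain channel mean is only one simple way to obtain a grayscale image.

```python
import numpy as np
from scipy import ndimage

# Hypothetical RGB image in channel-last layout: (rows, columns, channels)
rgb = np.random.default_rng(0).random((32, 48, 3))

# Option 1: collapse to a 2-D grayscale array first
# (plain channel mean, an assumption made for this sketch)
gray = rgb.mean(axis=2)
shifted_gray = ndimage.shift(gray, (5, -1))

# Option 2: keep the colour channels and give one shift entry per axis,
# with 0 for the channel axis so the channels are not mixed
shifted_rgb = ndimage.shift(rgb, (5, -1, 0))

# ndimage.rotate turns the plane of the first two axes by default,
# so it accepts the channel-last array directly
rotated_rgb = ndimage.rotate(rgb, -4, reshape=False)

print(shifted_gray.shape, shifted_rgb.shape, rotated_rgb.shape)
```

Option 1 keeps the registration code unchanged and is usually the simpler route, because the intensity-based error only needs a single channel.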

Comments:

  1. I get an error:
    "RuntimeError: sequence argument must have length equal to input rank"

    The only thing I've changed is lena() to face(), since lena is not supported anymore :(

    Replies
    1. In case you're still facing the error:
      This error comes up when the shift tuple has fewer elements than the input image has dimensions. An RGB image with 3 channels is a 3-dimensional array, so with a two-element shift tuple the function does not know how to handle the third dimension.
      Fix: ndimage.shift(rgb_image,(h_shift,w_shift,0)) for a channel-last array such as the one face() returns. Use (0,h_shift,w_shift) only if the channel axis comes first.