Source code for ants.contrib.sklearn_interface.sklearn_registration



__all__ = ['RigidRegistration']

from ...registration import interface, apply_transforms


class Registration(object):
    """
    How would it work:

    # Co-registration within-visit
    reg = Registration('Rigid', fixed_image=t1_template, 
                       save_dir=save_dir, save_suffix='_coreg')
    for img in other_imgs:
        reg.fit(img)
    
    for img in [flair_img, t2_img]:
        reg = Registration('Rigid', fixed_image=t1_template,
                           save_dir=save_dir, save_suffix='_coreg')
        reg.fit(img)
    
    # Co-registration across-visit
    reg = Registration('Rigid', fixed_image=t1)
    reg.fit(moving=t1_followup)
    
    # now align all followups with first visit
    for img in [flair_follwup, t2_followup]:
        img_reg = reg.transform(img)
    
    # conversly, align all first visits with followups
    for img in [flair, t2]:
        img_reg = reg.inverse_transform(img)
    """

    def __init__(self, type_of_transform, fixed_image):
        """
        Properties:
            type_of_transform
            fixed_image (template)
            save_dir (where to save outputs)
            save_suffix (what to append to saved outputs)
            save_prefix (what to preppend to saved outputs)
        """
        self.type_of_transform = type_of_transform
        self.fixed_image = fixed_image

    def fit(self, X, y=None):
        """
        X : ANTsImage | string | list of ANTsImage types | list of strings
            images to register to fixed image

        y : string | list of strings
            labels for images
        """
        moving_images = X if isinstance(X, (list,tuple)) else [X]
        moving_labels = y if y is not None else [i for i in range(len(moving_images))]
        fixed_image = self.fixed_image

        self.fwdtransforms_ = {}
        self.invtransforms_ = {}
        self.warpedmovout_ = {}
        self.warpedfixout_ = {}
        
        for moving_image, moving_label in zip(moving_images, moving_labels):
            fit_result = interface.registration(fixed_image,
                                                moving_image,
                                                type_of_transform=self.type_of_transform,
                                                initial_transform=None,
                                                outprefix='',
                                                mask=None,
                                                grad_step=0.2,
                                                flow_sigma=3,
                                                total_sigma=0,
                                                aff_metric='mattes',
                                                aff_sampling=32,
                                                syn_metric='mattes',
                                                syn_sampling=32,
                                                reg_iterations=(40,20,0),
                                                verbose=False)

            self.fwdtransforms_[moving_label] = fit_result['fwdtransforms']
            self.invtransforms_[moving_label] = fit_result['invtransforms']
            self.warpedmovout_[moving_label]  = fit_result['warpedmovout']
            self.warpedfixout_[moving_label]  = fit_result['warpedfixout']

        return self

    def transform(self, X, y=None):
        pass


[docs]class RigidRegistration(object): """ Rigid Registration as a Scikit-Learn compatible transform class Example ------- >>> import ants >>> import ants.extra as extrants >>> fi = ants.image_read(ants.get_data('r16')) >>> mi = ants.image_read(ants.get_data('r64')) >>> regtx = extrants.RigidRegistration() >>> regtx.fit(fi, mi) >>> mi_r = regtx.transform(mi) >>> ants.plot(fi, mi_r.iMath_Canny(1, 2, 4).iMath('MD',1)) """ def __init__(self, fixed_image=None): self.type_of_transform = 'Rigid' self.fixed_image = fixed_image
[docs] def fit(self, moving_image, fixed_image=None): if fixed_image is None: if self.fixed_image is None: raise ValueError('must give fixed_image in fit() or set it in __init__') fixed_image = self.fixed_image fit_result = interface.registration(fixed_image, moving_image, type_of_transform=self.type_of_transform, initial_transform=None, outprefix='', mask=None, grad_step=0.2, flow_sigma=3, total_sigma=0, aff_metric='mattes', aff_sampling=32, syn_metric='mattes', syn_sampling=32, reg_iterations=(40,20,0), verbose=False) self._fit_result = fit_result self.fwdtransforms_ = fit_result['fwdtransforms'] self.invtransforms_ = fit_result['invtransforms'] self.warpedmovout_ = fit_result['warpedmovout'] self.warpedfiout_ = fit_result['warpedfixout']
[docs] def transform(self, moving_image, fixed_image=None): result = apply_transforms(fixed=fixed_image, moving=moving_image, transformlist=self.fwdtransforms) return result
[docs] def inverse_transform(self, moving_image, fixed_image=None): result = apply_transforms(fixed=fixed_image, moving=moving_image, transformlist=self.invtransforms) return result