"""
ANTsPy Registration
"""
__all__ = ["registration", "motion_correction"]
import os
import numpy as np
from tempfile import mktemp
import glob
import re
import pandas as pd
import itertools
from . import apply_transforms
from . import apply_transforms_to_points
from .. import utils
from ..core import ants_image as iio
from .. import core
# Time-varying velocity-field transform names accepted by ``registration``.
_TV_TYPES = tuple("TV[%d]" % n for n in range(1, 9))

# Every ``type_of_transform`` value accepted verbatim.  Names containing
# "antsRegistrationSyN" are accepted as well and parsed separately.
_ALLOWABLE_TX = frozenset(
    (
        "SyNBold",
        "SyNBoldAff",
        "ElasticSyN",
        "Elastic",
        "SyN",
        "SyNRA",
        "SyNOnly",
        "SyNAggro",
        "SyNCC",
        "TRSAA",
        "SyNabp",
        "SyNLessAggro",
        "TVMSQ",
        "TVMSQC",
        "Rigid",
        "Similarity",
        "Translation",
        "Affine",
        "AffineFast",
        "BOLDAffine",
        "QuickRigid",
        "DenseRigid",
        "BOLDRigid",
    )
    + _TV_TYPES
)

# Linear stages that run before the deformable stage for each
# "linear + deformable" preset.  Every entry is
# (transform, convergence, smoothing sigmas, shrink factors), in run order.
_LINEAR_PRESTAGES = {
    "SyNBold": (
        ("Rigid[0.25]", "[1200x1200x100,1e-6,5]", "2x1x0", "4x2x1"),
    ),
    "SyNBoldAff": (
        ("Rigid[0.25]", "[1200x1200x100,1e-6,5]", "2x1x0", "4x2x1"),
        ("Affine[0.25]", "[200x20,1e-6,5]", "1x0", "2x1"),
    ),
    "ElasticSyN": (
        ("Affine[0.25]", "2100x1200x200x0", "3x2x1x0", "4x2x2x1"),
    ),
    "SyN": (
        ("Affine[0.25]", "2100x1200x1200x0", "3x2x1x0", "4x2x2x1"),
    ),
    "Elastic": (
        ("Affine[0.25]", "2100x1200x1200x0", "3x2x1x0", "4x2x2x1"),
    ),
    "SyNRA": (
        ("Rigid[0.25]", "2100x1200x1200x0", "3x2x1x0", "4x2x2x1"),
        ("Affine[0.25]", "2100x1200x1200x0", "3x2x1x0", "4x2x2x1"),
    ),
    "SyNAggro": (
        ("Affine[0.25]", "2100x1200x1200x100", "3x2x1x0", "4x2x2x1"),
    ),
    "SyNLessAggro": (
        ("Affine[0.25]", "2100x1200x1200x100", "3x2x1x0", "4x2x2x1"),
    ),
    "SyNCC": (
        ("Rigid[1]", "2100x1200x1200x0", "3x2x1x0", "4x4x2x1"),
        ("Affine[1]", "1200x1200x100", "2x1x0", "4x2x1"),
    ),
}


def _join_x(values):
    """Join values with "x", the per-level separator used in the option strings."""
    return "x".join(str(value) for value in values)


def _short_stage(metrics, transform, convergence, sigmas, shrinks, mask=None):
    """Build one stage using short flags, ordered -m.. -t -c -s -f [-x]."""
    if isinstance(metrics, str):
        metrics = [metrics]
    stage = []
    for metric in metrics:
        stage += ["-m", metric]
    stage += ["-t", transform, "-c", convergence, "-s", sigmas, "-f", shrinks]
    if mask is not None:
        stage += ["-x", mask]
    return stage


def _long_stage(transform, metrics, convergence, shrinks, sigmas):
    """Build one stage using long flags (antsRegistrationSyN-style ordering)."""
    stage = ["--transform", transform]
    for metric in metrics:
        stage += ["--metric", metric]
    stage += [
        "--convergence",
        convergence,
        "--shrink-factors",
        shrinks,
        "--smoothing-sigmas",
        sigmas,
    ]
    return stage


def _extra_metric_strings(multivariate_extras):
    """Format each (name, fixed, moving, weight, sampling) entry as a metric string."""
    if multivariate_extras is None:
        return []
    return [
        "%s[%s,%s,%s,%s]"
        % (
            entry[0],
            utils.get_pointer_string(entry[1]),
            utils.get_pointer_string(entry[2]),
            entry[3],
            entry[4],
        )
        for entry in multivariate_extras
    ]


def registration(
    fixed,
    moving,
    type_of_transform="SyN",
    initial_transform=None,
    outprefix="",
    mask=None,
    moving_mask=None,
    mask_all_stages=False,
    grad_step=0.2,
    flow_sigma=3,
    total_sigma=0,
    aff_metric="mattes",
    aff_sampling=32,
    aff_random_sampling_rate=0.2,
    syn_metric="mattes",
    syn_sampling=32,
    reg_iterations=(40, 20, 0),
    aff_iterations=(2100, 1200, 1200, 10),
    aff_shrink_factors=(6, 4, 2, 1),
    aff_smoothing_sigmas=(3, 2, 1, 0),
    write_composite_transform=False,
    random_seed=None,
    verbose=False,
    multivariate_extras=None,
    restrict_transformation=None,
    smoothing_in_mm=False,
    **kwargs
):
    """
    Register a pair of images either through the full or simplified
    interface to the ANTs registration method.

    ANTsR function: `antsRegistration`

    Arguments
    ---------
    fixed : ANTsImage
        fixed image to which we register the moving image. If `fixed` is a
        list and `moving` is None, the list is handed to antsRegistration
        as a raw argument list and 0 is returned on success.

    moving : ANTsImage
        moving image to be mapped to fixed space.

    type_of_transform : string
        A linear or non-linear registration type. Mutual information metric
        by default. See Notes below for more.

    initial_transform : string or list of strings (optional)
        transforms to prepend. When omitted, "[fixed,moving,1]" is used.
        Ignored by "TVMSQ" and "TVMSQC".

    outprefix : string
        output will be named with this prefix. A temporary name is
        generated when empty.

    mask : ANTsImage (optional)
        Registration metric mask in the fixed image space. Binarized with
        `mask != 0` before use.

    moving_mask : ANTsImage (optional)
        Registration metric mask in the moving image space. Binarized with
        `moving_mask != 0` before use.

    mask_all_stages : boolean
        If true, apply metric mask(s) to all registration stages, instead
        of just the final stage.

    grad_step : scalar
        gradient step size (not for all tx). None is accepted for the
        time-varying transforms only, where it selects 1.0 ("TV[n]",
        "TVMSQ") or 2.0 ("TVMSQC").

    flow_sigma : scalar
        smoothing for update field.
        At each iteration, the similarity metric and gradient is calculated.
        That gradient field is also called the update field and is smoothed
        before composing with the total field (i.e., the estimate of the
        total transform at that iteration). This total field can also be
        smoothed after each iteration.

    total_sigma : scalar
        smoothing for total field

    aff_metric : string
        the metric for the affine part (GC, mattes, meansquares)

    aff_sampling : scalar
        number of bins for the mutual information metric

    aff_random_sampling_rate : scalar
        the fraction of points used to estimate the metric. this can impact
        speed but also reproducibility and/or accuracy.

    syn_metric : string
        the metric for the syn part (CC, mattes, meansquares, demons)

    syn_sampling : scalar
        the nbins or radius parameter for the syn metric

    reg_iterations : list/tuple of integers
        vector of iterations for syn. we will set the smoothing and
        multi-resolution parameters based on the length of this vector.

    aff_iterations : list/tuple of integers (or a single integer)
        vector of iterations for low-dimensional (translation, rigid,
        affine) registration.

    aff_shrink_factors : list/tuple of integers (or a single integer)
        vector of multi-resolution shrink factors for low-dimensional
        (translation, rigid, affine) registration.

    aff_smoothing_sigmas : list/tuple of integers (or a single integer)
        vector of multi-resolution smoothing factors for low-dimensional
        (translation, rigid, affine) registration. The three `aff_*`
        schedules must either all be single integers or all have the same
        length.

    write_composite_transform : boolean
        Boolean specifying whether or not the composite transform (and its
        inverse, if it exists) should be written to an hdf5 composite file.
        This is false by default so that only the transform for each stage
        is written to file.

    random_seed : integer
        random seed to improve reproducibility. note that the number of
        ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS should be 1 if you want perfect
        reproducibility.

    verbose : boolean
        request verbose output (useful for debugging)

    multivariate_extras : additional metrics for multi-metric registration
        list of additional images and metrics which will trigger the use of
        multiple metrics in the registration process in the deformable
        stage. Each multivariate metric needs 5 entries: name of metric,
        fixed, moving, weight, samplingParam. the list of lists should be
        of the form ( ( "nameOfMetric2", img, img, weight, metricParam ) ).
        Another example would be
        ( ( "MeanSquares", f2, m2, 0.5, 0 ), ( "CC", f2, m2, 0.5, 2 ) ).
        This is only compatible with the SyNOnly or antsRegistrationSyN*
        transformations.

    restrict_transformation : tuple or list of weights (optional)
        This option allows the user to restrict the optimization of the
        displacement field, translation, rigid or affine transform on a
        per-component basis. For example, if one wants to limit the
        deformation or rotation of 3-D volume to the first two dimensions,
        this is possible by specifying a weight vector of '(1,1,0)' for a
        3D deformation field or '(1,1,0,1,1,0)' for a rigid transformation.
        Restriction currently only works if there are no preceding
        transformations. Any other value is passed through as `str(value)`.

    smoothing_in_mm : boolean
        currently only impacts low dimensional registration

    kwargs : keyword args
        extra arguments (accepted but not used by this function)

    Returns
    -------
    dict containing follow key/value pairs:
        `warpedmovout`: Moving image warped to space of fixed image.
        `warpedfixout`: Fixed image warped to space of moving image.
        `fwdtransforms`: Transforms to move from moving to fixed image.
        `invtransforms`: Transforms to move from fixed to moving image.
        `velocityfield`: velocity field files; only present when such
            files were written.

    Raises
    ------
    ValueError
        for non-ANTsImage inputs, NaNs, mismatched dimensions, an unknown
        `type_of_transform`, or inconsistent `aff_*` schedules.
    RuntimeError
        when antsRegistration exits with a non-zero code.

    Notes
    -----
    type_of_transform can be one of:
        - "Translation": Translation transformation.
        - "Rigid": Rigid transformation: Only rotation and translation.
        - "Similarity": Similarity transformation: scaling, rotation and
            translation.
        - "QuickRigid": Rigid transformation: Only rotation and translation.
            May be useful for quick visualization fixes.
        - "DenseRigid": Rigid transformation: Only rotation and translation.
            Employs dense sampling during metric estimation.
        - "BOLDRigid": Rigid transformation: Parameters typical for BOLD to
            BOLD intrasubject registration.
        - "Affine": Affine transformation: Rigid + scaling.
        - "AffineFast": Fast version of Affine.
        - "BOLDAffine": Affine transformation: Parameters typical for BOLD
            to BOLD intrasubject registration.
        - "TRSAA": translation, rigid, similarity, affine (twice). please
            set reg_iterations if using this option. this would be used in
            cases where you want a really high quality affine mapping
            (perhaps with mask).
        - "Elastic": Elastic deformation: Affine + deformable.
        - "ElasticSyN": Symmetric normalization: Affine + deformable
            transformation, with mutual information as optimization
            metric and elastic regularization.
        - "SyN": Symmetric normalization: Affine + deformable
            transformation, with mutual information as optimization metric.
        - "SyNRA": Symmetric normalization: Rigid + Affine + deformable
            transformation, with mutual information as optimization metric.
        - "SyNOnly": Symmetric normalization: no initial transformation,
            with mutual information as optimization metric. Assumes
            images are aligned by an initial transformation. Can be
            useful if you want to run an unmasked affine followed by
            masked deformable registration.
        - "SyNCC": SyN, but with cross-correlation as the metric.
        - "SyNabp": SyN optimized for abpBrainExtraction.
        - "SyNBold": SyN, but optimized for registrations between BOLD and
            T1 images.
        - "SyNBoldAff": SyN, but optimized for registrations between BOLD
            and T1 images, with additional affine step.
        - "SyNAggro": SyN, but with more aggressive registration
            (fine-scale matching and more deformation).
            Takes more time than SyN.
        - "SyNLessAggro": accepted; currently builds the same command as
            "SyNAggro".
        - "TV[n]": time-varying diffeomorphism where 'n' (1 to 8) indicates
            the number of time points in velocity field discretization.
            The initial transform should be computed, if needed, in a
            separate call to ants.registration.
        - "TVMSQ": time-varying diffeomorphism with mean square metric
        - "TVMSQC": time-varying diffeomorphism with mean square metric for
            very large deformation
        - "antsRegistrationSyN[x]": recreation of the antsRegistrationSyN.sh
            script in ANTs where 'x' is one of the transforms available
            (e.g., 't', 'b', 's'). For the 's*' and 'b*' variants extra
            values may follow: "[x,metricParameter,splineDistance]".
        - "antsRegistrationSyNQuick[x]": recreation of the
            antsRegistrationSyNQuick.sh script in ANTs where 'x' is one of
            the transforms available (e.g., 't', 'b', 's')
        - "antsRegistrationSyNRepro[x]": reproducible registration. x
            options as above.
        - "antsRegistrationSyNQuickRepro[x]": quick reproducible
            registration. x options as above.

    Example
    -------
    >>> import ants
    >>> fi = ants.image_read(ants.get_ants_data('r16'))
    >>> mi = ants.image_read(ants.get_ants_data('r64'))
    >>> fi = ants.resample_image(fi, (60,60), 1, 0)
    >>> mi = ants.resample_image(mi, (60,60), 1, 0)
    >>> mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = 'SyN' )
    """
    # Raw pass-through: `fixed` already is a complete argument list.
    if isinstance(fixed, list) and (moving is None):
        processed_args = utils._int_antsProcessArguments(fixed)
        libfn = utils.get_lib_fn("antsRegistration")
        reg_exit = libfn(processed_args)
        if reg_exit != 0:
            raise RuntimeError(f"Registration failed with error code {reg_exit}")
        return 0

    if not (isinstance(fixed, iio.ANTsImage) and isinstance(moving, iio.ANTsImage)):
        raise ValueError("Fixed and moving images must be ANTsImage objects")

    if type_of_transform == "":
        type_of_transform = "SyN"
    if isinstance(type_of_transform, (tuple, list)) and (len(type_of_transform) == 1):
        type_of_transform = type_of_transform[0]

    if not outprefix:
        # NOTE(review): mktemp only picks an unused name; it is kept because
        # outprefix is a filename *prefix*, not a file to be created here.
        outprefix = mktemp()

    if np.isnan(fixed.numpy()).any():
        raise ValueError("fixed image has NaNs - replace these")
    if np.isnan(moving.numpy()).any():
        raise ValueError("moving image has NaNs - replace these")
    if fixed.dimension != moving.dimension:
        raise ValueError("Fixed and moving image dimensions are not the same.")

    # ------------------------------------------------------------
    # Schedule of the low-dimensional (linear) stage.
    myiterations = aff_iterations
    myf_aff = "6x4x2x1"  # old fixed params
    mys_aff = "3x2x1x0"  # old fixed params
    # Exact type test: a bool is not treated as a single-level schedule.
    any_scalar = (
        type(aff_shrink_factors) is int
        or type(aff_smoothing_sigmas) is int
        or type(aff_iterations) is int
    )
    if any_scalar:
        # One scalar means "single level", so all three must be scalars.
        if type(aff_smoothing_sigmas) is not int:
            raise ValueError("aff_smoothing_sigmas should be a single integer.")
        if type(aff_iterations) is not int:
            raise ValueError("aff_iterations should be a single integer.")
        if type(aff_shrink_factors) is not int:
            raise ValueError("aff_shrink_factors should be a single integer.")
        myf_aff = aff_shrink_factors
        mys_aff = aff_smoothing_sigmas
        myiterations = aff_iterations
    elif isinstance(aff_shrink_factors, (tuple, list)):
        # Lists are documented as valid; previously only tuples were
        # honoured and lists silently fell back to the old fixed params.
        if len(aff_iterations) != len(aff_smoothing_sigmas):
            raise ValueError(
                "aff_iterations length should equal aff_smoothing_sigmas length."
            )
        if len(aff_iterations) != len(aff_shrink_factors):
            raise ValueError(
                "aff_iterations length should equal aff_shrink_factors length."
            )
        if len(aff_shrink_factors) != len(aff_smoothing_sigmas):
            raise ValueError(
                "aff_shrink_factors length should equal aff_smoothing_sigmas length."
            )
        myf_aff = _join_x(aff_shrink_factors)
        mys_aff = _join_x(aff_smoothing_sigmas)
        myiterations = _join_x(aff_iterations)

    # Always defined when a restriction is requested (a list or string used
    # to end in a NameError when the "-g" option was appended).
    restrict_transformationchar = None
    if restrict_transformation is not None:
        if isinstance(restrict_transformation, (tuple, list)):
            restrict_transformationchar = _join_x(restrict_transformation)
        else:
            restrict_transformationchar = str(restrict_transformation)

    # Convenience names: map onto the base linear transform and override
    # the corresponding parts of the schedule.
    if type_of_transform == "AffineFast":
        type_of_transform = "Affine"
        myiterations = "2100x1200x0x0"
    elif type_of_transform == "BOLDAffine":
        type_of_transform = "Affine"
        myf_aff = "2x1"
        mys_aff = "1x0"
        myiterations = "100x20"
    elif type_of_transform == "QuickRigid":
        type_of_transform = "Rigid"
        myiterations = "20x20x0x0"
    elif type_of_transform == "DenseRigid":
        type_of_transform = "Rigid"
        aff_random_sampling_rate = 1.0
    elif type_of_transform == "BOLDRigid":
        type_of_transform = "Rigid"
        myf_aff = "2x1"
        mys_aff = "1x0"
        myiterations = "100x20"

    if smoothing_in_mm:
        # str() so that a single-integer sigma can carry the unit as well.
        mys_aff = str(mys_aff) + "mm"

    # antsRegistrationSyN* names are parsed in their own branch below.
    if (
        "antsRegistrationSyN" not in type_of_transform
        and type_of_transform not in _ALLOWABLE_TX
    ):
        raise ValueError(f'{type_of_transform} does not exist')

    # Resolve the time-varying defaults *before* grad_step is formatted with
    # "%f" below; otherwise grad_step=None can never reach those branches.
    if grad_step is None:
        if type_of_transform == "TVMSQC":
            grad_step = 2.0
        elif type_of_transform == "TVMSQ" or type_of_transform in _TV_TYPES:
            grad_step = 1.0

    mysyn = "SyN[%f,%f,%f]" % (grad_step, flow_sigma, total_sigma)
    if type_of_transform == "Elastic":
        mysyn = "GaussianDisplacementField[%f,%f,%f]" % (
            grad_step,
            flow_sigma,
            total_sigma,
        )

    # Deformable schedule: one level per entry of reg_iterations, with
    # sigma n-1..0 and shrink factor 2**sigma.
    itlen = len(reg_iterations)
    if itlen == 0:
        smoothingsigmas = 0
        shrinkfactors = 1
        synits = reg_iterations
    else:
        levels = range(itlen - 1, -1, -1)
        smoothingsigmas = _join_x(levels)
        shrinkfactors = _join_x(2 ** level for level in levels)
        synits = _join_x(reg_iterations)

    inpixeltype = fixed.pixeltype

    initx = initial_transform
    if isinstance(initx, str):
        initx = [initx]

    moving = moving.clone("float")
    fixed = fixed.clone("float")
    # NOTE: this may be better for general purpose applications: TBD
    # moving = utils.iMath( moving.clone("float"), "Normalize" )
    # fixed = utils.iMath( fixed.clone("float"), "Normalize" )
    warpedfixout = moving.clone()
    warpedmovout = fixed.clone()
    f = utils.get_pointer_string(fixed)
    m = utils.get_pointer_string(moving)
    wfo = utils.get_pointer_string(warpedfixout)
    wmo = utils.get_pointer_string(warpedmovout)

    # The binarized masks are kept in locals of this function on purpose.
    # NOTE(review): presumably the pointer strings refer to the images'
    # memory, so the images must stay alive until libfn returns - confirm.
    if mask is not None:
        mask_binary = mask != 0
        f_mask_str = utils.get_pointer_string(mask_binary)
    else:
        f_mask_str = "NA"
    if moving_mask is not None:
        moving_mask_binary = moving_mask != 0
        m_mask_str = utils.get_pointer_string(moving_mask_binary)
    else:
        m_mask_str = "NA"
    maskopt = "[%s,%s]" % (f_mask_str, m_mask_str)
    earlymaskopt = maskopt if mask_all_stages else "[NA,NA]"

    if initx is None:
        initx = ["[%s,%s,1]" % (f, m)]

    if type_of_transform == "SyNCC":
        # SyNCC pins the deformable stage regardless of the caller's values.
        syn_metric = "CC"
        syn_sampling = 4
        synits = "2100x1200x1200x20"
        smoothingsigmas = "3x2x1x0"
        shrinkfactors = "4x3x2x1"
        mysyn = "SyN[0.15,3,0]"

    # Pieces shared by (almost) every command line.
    head = ["-d", str(fixed.dimension), "-r"] + list(initx)
    output = "[%s,%s,%s]" % (outprefix, wmo, wfo)
    tail = ["-u", "1", "-z", "1", "-o", output, "-x", maskopt]
    tv_tail = ["-u", "1", "-z", "0", "-o", output, "-x", maskopt]
    aff_metric_str = "%s[%s,%s,1,%s,regular,%s]" % (
        aff_metric,
        f,
        m,
        aff_sampling,
        aff_random_sampling_rate,
    )
    syn_metric_str = "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling)

    def deformable_stage(transform, extra_metrics=()):
        # The convergence string is formatted only when a deformable stage
        # is actually requested, as in the per-branch code this replaces.
        return _short_stage(
            [syn_metric_str] + list(extra_metrics),
            transform,
            "[%s,1e-7,8]" % synits,
            smoothingsigmas,
            shrinkfactors,
        )

    # ------------------------------------------------------------
    if type_of_transform in _LINEAR_PRESTAGES:
        args = list(head)
        for transform, convergence, sigmas, shrinks in _LINEAR_PRESTAGES[
            type_of_transform
        ]:
            args += _short_stage(
                aff_metric_str, transform, convergence, sigmas, shrinks, earlymaskopt
            )
        args += deformable_stage(mysyn) + tail
    # ------------------------------------------------------------
    elif type_of_transform == "SyNOnly":
        args = (
            head
            + deformable_stage(mysyn, _extra_metric_strings(multivariate_extras))
            + tail
        )
    # ------------------------------------------------------------
    elif type_of_transform == "TRSAA":
        n_low = round(itlen / 2 + 0.0001)
        # The first three stages only iterate on the coarser levels.
        low_convergence = _join_x([2000] * n_low + [0] * (itlen - n_low))
        high_convergence = "[%s,1.e-7,10]" % _join_x(reg_iterations)
        args = list(head)
        for transform, convergence in (
            ("Translation[1]", low_convergence),
            ("Rigid[1]", low_convergence),
            ("Similarity[1]", low_convergence),
            ("Affine[1]", high_convergence),
        ):
            args += _short_stage(
                aff_metric_str,
                transform,
                convergence,
                smoothingsigmas,
                shrinkfactors,
                earlymaskopt,
            )
        # Final affine pass: masked by the closing "-x maskopt" only.
        args += _short_stage(
            aff_metric_str,
            "Affine[1]",
            high_convergence,
            smoothingsigmas,
            shrinkfactors,
        )
        args += tail
    # ------------------------------------------------------------
    elif type_of_transform == "SyNabp":
        abp_linear_metric = "mattes[%s,%s,1,32,regular,0.25]" % (f, m)
        args = list(head)
        for transform in ("Rigid[0.1]", "Affine[0.1]"):
            args += _short_stage(
                abp_linear_metric,
                transform,
                "1000x500x250x100",
                "4x2x1x0",
                "8x4x2x1",
                earlymaskopt,
            )
        args += _short_stage(
            "CC[%s,%s,0.5,4]" % (f, m), "SyN[0.1,3,0]", "50x10x0", "2x1x0", "4x2x1"
        )
        args += tail
    # ------------------------------------------------------------
    elif type_of_transform in _TV_TYPES:
        nTimePoints = type_of_transform.split("[")[1].split("]")[0]
        tvtx = (
            "TimeVaryingVelocityField["
            + str(grad_step)
            + ","
            + nTimePoints
            + ","
            + str(flow_sigma)
            + ",0.0,"
            + str(total_sigma)
            + ",0]"
        )
        args = head + deformable_stage(tvtx) + tv_tail
    elif type_of_transform == "TVMSQ":
        tvtx = "TimeVaryingVelocityField[%s, 4, 0.0,0.0, 0.5,0 ]" % str(grad_step)
        # No "-r": the initial transform is not used by this preset.
        args = ["-d", str(fixed.dimension)] + deformable_stage(tvtx) + tv_tail
    # ------------------------------------------------------------
    elif type_of_transform == "TVMSQC":
        tvtx = "TimeVaryingVelocityField[%s, 8, 1.0,0.0, 0.05,0 ]" % str(grad_step)
        # No "-r": the initial transform is not used by this preset.
        args = (
            ["-d", str(fixed.dimension)]
            + _short_stage(
                [
                    "demons[%s,%s,0.5,0]" % (f, m),
                    "meansquares[%s,%s,1,0]" % (f, m),
                ],
                tvtx,
                "[1200x1200x100x20x0,0,5]",
                "8x6x4x2x1vox",
                "8x6x4x2x1",
            )
            + tv_tail
        )
    # ------------------------------------------------------------
    elif type_of_transform in ("Rigid", "Similarity", "Translation", "Affine"):
        args = (
            head
            + _short_stage(
                aff_metric_str,
                "%s[0.25]" % type_of_transform,
                myiterations,
                mys_aff,
                myf_aff,
            )
            + tail
        )
    # ------------------------------------------------------------
    elif "antsRegistrationSyN" in type_of_transform:
        do_quick = "Quick" in type_of_transform
        do_repro = "Repro" in type_of_transform

        subtype_of_transform = "s"
        spline_distance = 26
        requested_metric_parameter = None
        if "[" in type_of_transform and "]" in type_of_transform:
            subtype_of_transform = type_of_transform.split("[")[1].split("]")[0]
            if "," in subtype_of_transform:
                subtype_of_transform_args = subtype_of_transform.split(",")
                subtype_of_transform = subtype_of_transform_args[0]
                if subtype_of_transform not in ("b", "br", "bo", "s", "sr", "so"):
                    raise ValueError(
                        "Extra parameters are only valid for 's' or 'b' SyN transforms."
                    )
                # An empty value (e.g. "[s,,30]") keeps the preset default.
                requested_metric_parameter = (
                    subtype_of_transform_args[1].strip() or None
                )
                if len(subtype_of_transform_args) > 2:
                    spline_distance = subtype_of_transform_args[2]

        if do_quick:
            linear_convergence = "[1000x500x250x0,1e-6,10]"
            syn_convergence = "[100x70x50x0,1e-6,10]"
        else:
            linear_convergence = "[1000x500x250x100,1e-6,10]"
            syn_convergence = "[100x70x50x20,1e-6,10]"
        linear_shrink_factors = "8x4x2x1"
        linear_smoothing_sigmas = "3x2x1x0vox"
        if do_repro:
            linear_metric = "GC[%s,%s,1,1,Regular,0.25]" % (f, m)
        else:
            linear_metric = "MI[%s,%s,1,32,Regular,0.25]" % (f, m)

        # Quick (non-repro) uses MI with 32 bins; everything else uses CC
        # with radius 2.  A parameter given in the transform name now takes
        # precedence; previously it was parsed and then always overwritten.
        if do_quick and not do_repro:
            deformable_metric_name, metric_parameter = "MI", 32
        else:
            deformable_metric_name, metric_parameter = "CC", 2
        if requested_metric_parameter is not None:
            metric_parameter = requested_metric_parameter
        deformable_metric = "%s[%s,%s,1,%s]" % (
            deformable_metric_name,
            f,
            m,
            metric_parameter,
        )
        syn_shrink_factors = "8x4x2x1"
        syn_smoothing_sigmas = "3x2x1x0vox"
        if subtype_of_transform in ("sr", "br"):
            # Without an affine stage the deformable part runs two levels.
            syn_convergence = "[50x0,1e-6,10]" if do_quick else "[50x20,1e-6,10]"
            syn_shrink_factors = "2x1"
            syn_smoothing_sigmas = "1x0vox"

        if random_seed is None and do_repro:
            random_seed = str(1)

        rigid_transform = "Translation" if subtype_of_transform == "t" else "Rigid"
        rigid_stage = _long_stage(
            rigid_transform + "[0.1]",
            [linear_metric],
            linear_convergence,
            linear_shrink_factors,
            linear_smoothing_sigmas,
        )
        affine_stage = _long_stage(
            "Affine[0.1]",
            [linear_metric],
            linear_convergence,
            linear_shrink_factors,
            linear_smoothing_sigmas,
        )
        if subtype_of_transform in ("b", "br", "bo"):
            deformable_transform = "BSplineSyN[0.1," + str(spline_distance) + ",0,3]"
        elif subtype_of_transform in ("s", "sr", "so"):
            deformable_transform = "SyN[0.1,3,0]"
        else:
            deformable_transform = None
        syn_stage = []
        if deformable_transform is not None:
            syn_stage = _long_stage(
                deformable_transform,
                [deformable_metric] + _extra_metric_strings(multivariate_extras),
                syn_convergence,
                syn_shrink_factors,
                syn_smoothing_sigmas,
            )

        stage_plan = {
            "r": (rigid_stage,),
            "t": (rigid_stage,),
            "a": (rigid_stage, affine_stage),
            "b": (rigid_stage, affine_stage, syn_stage),
            "s": (rigid_stage, affine_stage, syn_stage),
            "br": (rigid_stage, syn_stage),
            "sr": (rigid_stage, syn_stage),
            "bo": (syn_stage,),
            "so": (syn_stage,),
        }
        args = head + ["-o", output]
        # An unrecognized subtype contributes no stage, as before.
        for stage in stage_plan.get(subtype_of_transform, ()):
            args += stage
        args += ["-x", maskopt]

    # ------------------------------------------------------------
    if random_seed is not None:
        args += ["--random-seed", random_seed]
    if restrict_transformationchar is not None:
        args += ["-g", restrict_transformationchar]
    args += [
        "--float",
        "1",
        "--write-composite-transform",
        write_composite_transform * 1,
    ]
    if verbose:
        args += ["-v", "1"]

    processed_args = utils._int_antsProcessArguments(args)
    libfn = utils.get_lib_fn("antsRegistration")
    if verbose:
        print("antsRegistration " + ' '.join(processed_args))
    reg_exit = libfn(processed_args)
    if reg_exit != 0:
        raise RuntimeError(f"Registration failed with error code {reg_exit}")

    # Collect the written transforms.  The prefix is escaped so that glob
    # metacharacters inside outprefix are matched literally.
    prefix_pattern = glob.escape(outprefix)
    vfieldfns = glob.glob(prefix_pattern + "*" + "[0-9]VelocityField.nii.gz")
    alltx = sorted(
        set(glob.glob(prefix_pattern + "*" + "[0-9]*"))
        - set(glob.glob(prefix_pattern + "*VelocityField*"))
    )
    inverse_hits = [
        idx
        for idx, ff in enumerate(alltx)
        if re.search("[0-9]InverseWarp.nii.gz", ff)
    ]
    forward_hits = [
        idx for idx, ff in enumerate(alltx) if re.search("[0-9]Warp.nii.gz", ff)
    ]
    if inverse_hits:
        # Forward list: everything but the inverse warp, last stage first.
        # Inverse list: everything but the forward warp, in sorted order.
        fwdtransforms = [
            ff for idx, ff in enumerate(alltx) if idx != inverse_hits[0]
        ][::-1]
        skip_forward = forward_hits[0] if forward_hits else None
        invtransforms = [
            ff for idx, ff in enumerate(alltx) if idx != skip_forward
        ]
    else:
        fwdtransforms = list(reversed(alltx))
        invtransforms = alltx

    if write_composite_transform:
        fwdtransforms = outprefix + "Composite.h5"
        invtransforms = outprefix + "InverseComposite.h5"

    result = {
        "warpedmovout": warpedmovout.clone(inpixeltype),
        "warpedfixout": warpedfixout.clone(inpixeltype),
        "fwdtransforms": fwdtransforms,
        "invtransforms": invtransforms,
    }
    if vfieldfns:
        result["velocityfield"] = vfieldfns
    return result
def motion_correction(
    image,
    fixed=None,
    type_of_transform="BOLDRigid",
    mask=None,
    fdOffset=50,
    outprefix="",
    verbose=False,
    **kwargs
):
    """
    Correct time-series data for motion.

    Every frame along the last axis of `image` is intensity-normalized and
    registered to `fixed`. The resulting forward transforms are then applied
    to the original (un-normalized) frame, and a framewise displacement (FD)
    value is computed from a small set of probe points placed around the
    mask's center of mass.

    ANTsR function: `antsrMotionCalculation`

    Arguments
    ---------
    image : ANTsImage
        time series, usually ND where D=4. The last axis is treated as time.

    fixed : ANTsImage (optional)
        fixed image to register all timepoints to. If not provided, the
        mean of the normalized frames is used.

    type_of_transform : string
        A linear or non-linear registration type. Mutual information metric
        and rigid transformation by default.
        See ants registration for details.

    mask : ANTsImage (optional)
        mask for image (ND-1). If not provided, one is estimated from
        `fixed`, but only to place the FD probe points.
        2023-02-05: a performance change - previously, we estimated a mask
        when None is provided and would pass this to the registration. this
        impairs performance if the mask estimate is bad. in such a case, we
        prefer no mask at all. As such, we no longer pass the mask to the
        registration when None is provided.

    fdOffset : scalar
        offset value to use in framewise displacement calculation. Unit
        neighborhood offsets are scaled by `fdOffset / 2` before being added
        to the mask's center of mass.

    outprefix : string
        output will be named with this prefix plus a numeric extension
        (`<outprefix>_<5-digit frame index>_`). When empty, the default
        output prefix of `registration` is used.

    verbose : boolean
        print coarse progress to stdout.

    kwargs : keyword args
        extra arguments - these extra arguments will control the details of
        registration that is performed. see ants registration for more.

    Returns
    -------
    dict containing follow key/value pairs:
        `motion_corrected`: Moving image warped to space of fixed image.
        `motion_parameters`: transforms for each image in the time series;
            the string "NA" for frames that were not registered because they
            had zero variance after normalization.
        `FD`: Framewise displacement generalized for arbitrary transformations
            (numpy array with one value per frame; frames that were not
            registered keep the value 0).

    Notes
    -----
    Control extra arguments via kwargs. see ants.registration for details.

    Example
    -------
    >>> import ants
    >>> fi = ants.image_read(ants.get_ants_data('ch2'))
    >>> mytx = ants.motion_correction( fi )
    """
    idim = image.dimension
    time_axis = idim - 1
    spatial_dim = idim - 1
    nTimePoints = image.shape[time_axis]

    if fixed is None:
        # Reference image: mean of the intensity-normalized frames.
        wt = 1.0 / nTimePoints
        fixed = utils.slice_image(image, axis=time_axis, idx=0) * 0
        for k in range(nTimePoints):
            frame = utils.slice_image(image, axis=time_axis, idx=k)
            fixed = fixed + utils.iMath(frame, "Normalize") * wt

    # An estimated mask is only used to position the FD probe points; it is
    # deliberately not forwarded to the registration (see docstring).
    if mask is None:
        mask = utils.get_mask(fixed)
        useMask = None
    else:
        useMask = mask

    FD = np.zeros(nTimePoints)
    motion_parameters = []
    motion_corrected = []

    # Build the FD probe points: take the radius-1 neighborhood offsets of a
    # single voxel inside the mask, keep the offsets whose absolute
    # components sum to (idim - 2), then scale them and center them on the
    # mask's center of mass.
    centerOfMass = mask.get_center_of_mass()
    myrad = [1] * spatial_dim
    mask1vals = np.zeros(int(mask.sum()))
    mask1vals[round(len(mask1vals) / 2)] = 1
    mask1 = core.make_image(mask, mask1vals)
    myoffsets = utils.get_neighborhood_in_mask(
        mask1, mask1, radius=myrad, spatial_info=True
    )["offsets"]
    mycols = list("xyz") if spatial_dim == 3 else list("xy")
    # Selection is done on the raw offsets, before they are rescaled.
    useinds = [
        k
        for k in range(myoffsets.shape[0])
        if abs(myoffsets[k, :]).sum() == (idim - 2)
    ]
    # Assign in place so the offsets array keeps its original dtype.
    myoffsets[:] = myoffsets * fdOffset / 2.0 + centerOfMass
    fdpts = pd.DataFrame(data=myoffsets[useinds, :], columns=mycols)

    if verbose:
        print("Progress:")
    next_report = 0
    for k in range(nTimePoints):
        if verbose:
            percent = round(k / nTimePoints * 100)
            # Report once the next 10% tier is reached *or passed*; an
            # equality test stalls forever when the rounded percentage
            # never lands exactly on a multiple of 10.
            if percent >= next_report:
                print(percent, end="%.", flush=True)
                next_report = (percent // 10 + 1) * 10

        temp = utils.slice_image(image, axis=time_axis, idx=k)
        temp = utils.iMath(temp, "Normalize")

        if temp.numpy().var() > 0:
            # Per-frame prefix when the caller asked for named outputs;
            # otherwise "" (the default of `registration`).
            if outprefix != "":
                outprefixloc = outprefix + "_" + str(k).zfill(5) + "_"
            else:
                outprefixloc = ""
            myreg = registration(
                fixed,
                temp,
                type_of_transform=type_of_transform,
                mask=useMask,
                outprefix=outprefixloc,
                **kwargs
            )
            fdptsTxI = apply_transforms_to_points(
                spatial_dim, fdpts, myreg["fwdtransforms"]
            )
            if k > 0 and motion_parameters[k - 1] != "NA":
                fdptsTxIminus1 = apply_transforms_to_points(
                    spatial_dim, fdpts, motion_parameters[k - 1]
                )
            else:
                # No usable previous transform: displacement is zero.
                fdptsTxIminus1 = fdptsTxI
            # take the absolute value, then the mean across columns, then the sum
            FD[k] = (fdptsTxIminus1 - fdptsTxI).abs().mean().sum()
            motion_parameters.append(myreg["fwdtransforms"])
            # The transforms are applied to the original frame, not to the
            # normalized copy that drove the registration.
            mywarped = apply_transforms(
                fixed,
                utils.slice_image(image, axis=time_axis, idx=k),
                myreg["fwdtransforms"],
            )
            motion_corrected.append(mywarped)
        else:
            # Zero-variance frame: nothing to register; the normalized frame
            # is passed through and FD[k] stays 0.
            motion_parameters.append("NA")
            motion_corrected.append(temp)

    if verbose:
        print("Done")
    return {
        "motion_corrected": utils.list_to_ndimage(image, motion_corrected),
        "motion_parameters": motion_parameters,
        "FD": FD,
    }