Source code for ants.viz.plot

"""
Create a static 2D image of a 2D ANTsImage
or a tile of slices from a 3D ANTsImage

TODO:
- add `plot_multichannel` function for plotting multi-channel images
    - support for quivers as well
- add `plot_gif` function for making a gif/video of 2D slices across a 3D image
"""


# Public API of this module: the names bound by `from ants.viz.plot import *`.
__all__ = [
    "plot",
    "movie",
    "plot_hist",
    "plot_grid",
    "plot_ortho",
    "plot_ortho_double",
    "plot_ortho_stack",
    "plot_directory",
]

import fnmatch
import math
import os
import warnings

from matplotlib import gridspec
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.lines as mlines
import matplotlib.patches as patches
import matplotlib.mlab as mlab
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


import numpy as np

from .. import registration as reg
from ..core import ants_image as iio
from ..core import ants_image_io as iio2
from ..core import ants_transform as tio
from ..core import ants_transform_io as tio2


def movie(image, filename=None, writer=None, fps=30):
    """
    Create and save a movie - mp4, gif, etc - of the various 2D slices of a 3D ants image

    Try this:
        conda install -c conda-forge ffmpeg

    Frames are the slices along the first array axis that contain positive
    voxels, plus a margin of 5 slices on each side (clamped to the image).

    Arguments
    ---------
    image : ANTsImage
        3D image to animate. It is padded with ``pad_image()`` first.
    filename : str, optional
        Output path (``~`` is expanded). If None, the animation is shown
        with ``plt.show()`` instead of being saved.
    writer : matplotlib.animation writer, optional
        Writer passed to ``FuncAnimation.save``. Defaults to
        ``animation.FFMpegWriter(fps=fps)``.
    fps : int
        Frame rate used for the default writer.

    Example
    -------
    >>> import ants
    >>> mni = ants.image_read(ants.get_data('mni'))
    >>> ants.movie(mni, filename='~/desktop/movie.mp4')
    """
    image = image.pad_image()
    img_arr = image.numpy()

    # Restrict the frames to slices containing foreground plus a 5-slice
    # margin, clamped to the valid index range [0, nslices). The upper bound
    # must be clamped with `min`; `max` let it run past the last slice.
    nslices = img_arr.shape[0]
    fg_slices = np.where(img_arr > 0)[0]
    if fg_slices.size > 0:
        minidx = max(0, int(fg_slices.min()) - 5)
        maxidx = min(nslices, int(fg_slices.max()) + 5)
    else:
        # no positive voxels at all: animate every slice
        minidx, maxidx = 0, nslices

    # Create the figure and the (single) animated image artist
    fig, ax = plt.subplots(1)
    im = ax.imshow(
        img_arr[minidx, :, :],
        animated=True,
        cmap="Greys_r",
        vmin=image.quantile(0.05),
        vmax=image.quantile(0.95),
    )
    ax.axis("off")

    def updatefig(frame):
        im.set_array(img_arr[frame, :, :])
        return (im,)

    ani = animation.FuncAnimation(
        fig,
        updatefig,
        frames=np.arange(minidx, maxidx),
        interval=50,
        blit=True,
    )

    if filename is not None:
        if writer is None:
            writer = animation.FFMpegWriter(fps=fps)
        filename = os.path.expanduser(filename)
        ani.save(filename, writer=writer)
        plt.close(fig)
    else:
        plt.show()
def plot_hist(
    image,
    threshold=0.0,
    fit_line=False,
    normfreq=True,
    ## plot label arguments
    title=None,
    grid=True,
    xlabel=None,
    ylabel=None,
    ## other plot arguments
    facecolor="green",
    alpha=0.75,
):
    """
    Plot a histogram from an ANTsImage

    Arguments
    ---------
    image : ANTsImage
        image from which histogram will be created
    threshold : float
        only voxels whose absolute value is strictly greater than this
        are included in the histogram
    fit_line : bool
        if True, overlay a dashed normal curve using the mean and standard
        deviation of the included voxels
    normfreq : bool
        if truthy, the histogram is normalized to a probability density
        (matplotlib ``density=True``); otherwise raw counts are plotted
    title, xlabel, ylabel : str, optional
        plot annotations (skipped when None)
    grid : bool
        whether to draw the grid
    facecolor : str
        bar color
    alpha : float
        bar opacity
    """
    values = image.numpy().flatten()
    values = values[np.abs(values) > threshold]

    # matplotlib removed the `normed` keyword; `density` replaces it. Without
    # this, `normfreq` was silently ignored.
    density = bool(normfreq)
    _, bins, _ = plt.hist(
        values, 50, density=density, facecolor=facecolor, alpha=alpha
    )

    if fit_line and values.size > 0:
        mu = values.mean()
        sigma = values.std()
        if sigma > 0:
            # matplotlib.mlab.normpdf no longer exists (removed in 3.1);
            # evaluate the Gaussian density directly.
            y = np.exp(-0.5 * ((bins - mu) / sigma) ** 2) / (
                sigma * np.sqrt(2.0 * np.pi)
            )
            if not density:
                # scale the density to expected counts per (equal-width) bin
                y = y * values.size * (bins[1] - bins[0])
            plt.plot(bins, y, "r--", linewidth=1)

    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if title is not None:
        plt.title(title)

    plt.grid(grid)
    plt.show()
def plot_grid(
    images,
    slices=None,
    axes=2,
    # general figure arguments
    figsize=1.0,
    rpad=0,
    cpad=0,
    vmin=None,
    vmax=None,
    colorbar=True,
    cmap="Greys_r",
    # title arguments
    title=None,
    tfontsize=20,
    title_dx=0,
    title_dy=0,
    # row arguments
    rlabels=None,
    rfontsize=14,
    rfontcolor="white",
    rfacecolor="black",
    # column arguments
    clabels=None,
    cfontsize=14,
    cfontcolor="white",
    cfacecolor="black",
    # save arguments
    filename=None,
    dpi=400,
    transparent=True,
    # other args
    **kwargs
):
    """
    Plot a collection of images in an arbitrarily-defined grid

    Matplotlib named colors: https://matplotlib.org/examples/color/named_colors.html

    Arguments
    ---------
    images : nested list (or 2D numpy object array) of ANTsImage types
        image(s) to plot, arranged as rows x columns. A flat list is
        treated as a single row.

    slices : None, integer, or nested list of integers
        slice indices to plot.
        if None, the middle slice along the plotting axis of each image is used.
        if one integer, this slice index will be used for all images.
        if multiple integers, they should be arranged in a nested list (or
        2D array) with the same shape as `images`.

    axes : integer or nested list of integers
        axis or axes along which to plot image slices (0, 1, 2 or -1, -2, -3).
        if one integer, this axis will be used for all images.
        if multiple integers, they should be arranged in a nested list with
        the same shape as `images`.

    figsize : float
        scale factor for the figure size.

    rpad, cpad : float
        row/column padding. If both are > 0, their maximum is used as the
        spacing between grid cells instead.

    vmin, vmax, cmap :
        either a single value used for every row, or one value per row.

    colorbar : bool
        add a colorbar to the right of each panel.

    title, tfontsize, title_dx, title_dy :
        figure title, its font size, and offsets of its position.

    rlabels, clabels : list of str, optional
        row / column labels (one per row / column).

    filename : str, optional
        if given, save the figure there (``~`` is expanded) and close it;
        otherwise show it.

    dpi, transparent :
        passed to ``plt.savefig``.

    kwargs :
        accepted but unused.

    Example
    -------
    >>> import ants
    >>> import numpy as np
    >>> mni1 = ants.image_read(ants.get_data('mni'))
    >>> mni2 = mni1.smooth_image(1.)
    >>> mni3 = mni1.smooth_image(2.)
    >>> mni4 = mni1.smooth_image(3.)
    >>> images = np.asarray([[mni1, mni2],
    ...                      [mni3, mni4]])
    >>> slices = np.asarray([[100, 100],
    ...                      [100, 100]])
    >>> #axes = np.asarray([[2,2],[2,2]])
    >>> # standard plotting
    >>> ants.plot_grid(images=images, slices=slices, title='2x2 Grid')
    >>> ants.plot_grid(images.reshape(1,4), slices.reshape(1,4), title='1x4 Grid')
    >>> ants.plot_grid(images.reshape(4,1), slices.reshape(4,1), title='4x1 Grid')

    >>> # Padding between rows and/or columns
    >>> ants.plot_grid(images, slices, cpad=0.02, title='Col Padding')
    >>> ants.plot_grid(images, slices, rpad=0.02, title='Row Padding')
    >>> ants.plot_grid(images, slices, rpad=0.02, cpad=0.02, title='Row and Col Padding')

    >>> # Adding plain row and/or column labels
    >>> ants.plot_grid(images, slices, title='Adding Row Labels', rlabels=['Row #1', 'Row #2'])
    >>> ants.plot_grid(images, slices, title='Adding Col Labels', clabels=['Col #1', 'Col #2'])
    >>> ants.plot_grid(images, slices, title='Row and Col Labels',
                       rlabels=['Row 1', 'Row 2'], clabels=['Col 1', 'Col 2'])

    >>> # Making a publication-quality image
    >>> images = np.asarray([[mni1, mni2, mni2],
    ...                      [mni3, mni4, mni4]])
    >>> slices = np.asarray([[100, 100, 100],
    ...                      [100, 100, 100]])
    >>> axes = np.asarray([[0, 1, 2],
                           [0, 1, 2]])
    >>> ants.plot_grid(images, slices, axes, title='Publication Figures with ANTsPy',
                       tfontsize=20, title_dy=0.03, title_dx=-0.04,
                       rlabels=['Row 1', 'Row 2'], clabels=['Col 1', 'Col 2', 'Col 3'],
                       rfontsize=16, cfontsize=16)
    """

    def mirror_matrix(x):
        return x[::-1, :]

    def rotate90_matrix(x):
        return mirror_matrix(x).T

    def reorient_slice(x, axis):
        # every slice is rotated by 90 degrees; slices taken along axis 1
        # are additionally flipped vertically
        x = rotate90_matrix(x)
        if axis == 1:
            x = mirror_matrix(x)
        return x

    def slice_image(img, axis, idx):
        if axis in (0, -3):
            return img[idx, :, :]
        elif axis in (1, -2):
            return img[:, idx, :]
        elif axis in (2, -1):
            return img[:, :, idx]
        else:
            raise ValueError("axis %i not valid" % axis)

    # ---- normalize `images` to a nested (rows x cols) list
    if isinstance(images, np.ndarray):
        images = images.tolist()
    if not isinstance(images, list):
        raise ValueError("images argument must be of type list")
    if len(images) == 0:
        raise ValueError("images argument must not be empty")
    if not isinstance(images[0], list):
        images = [images]

    # ---- normalize `slices`: None / a single index, or a nested list
    if isinstance(slices, np.ndarray):
        slices = slices.tolist()
    if slices is None or isinstance(slices, (int, np.integer)):
        # previously `slices=None` (the default) left `one_slice` unset and
        # crashed with a NameError; None now means "middle slice"
        one_slice = True
    elif isinstance(slices, list):
        one_slice = False
        if not isinstance(slices[0], list):
            slices = [slices]
        nslicerow = len(slices)
        nslicecol = len(slices[0])
    else:
        raise ValueError(
            "slices argument must be None, an integer, or a (nested) list/array of integers"
        )

    nrow = len(images)
    ncol = len(images[0])

    if rlabels is None:
        rlabels = [None] * nrow
    if clabels is None:
        clabels = [None] * ncol

    if not one_slice:
        if (nrow != nslicerow) or (ncol != nslicecol):
            raise ValueError(
                "`images` arg shape (%i,%i) must equal `slices` arg shape (%i,%i)!"
                % (nrow, ncol, nslicerow, nslicecol)
            )

    fig = plt.figure(figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize))

    if title is not None:
        basex = 0.5
        basey = 0.9 if clabels[0] is None else 0.95
        fig.suptitle(title, fontsize=tfontsize, x=basex + title_dx, y=basey + title_dy)

    # when both paddings are requested, use their max as uniform cell spacing
    if (cpad > 0) and (rpad > 0):
        bothgridpad = max(cpad, rpad)
        cpad = 0
        rpad = 0
    else:
        bothgridpad = 0.0

    gs = gridspec.GridSpec(
        nrow,
        ncol,
        wspace=bothgridpad,
        hspace=0.0,
        top=1.0 - 0.5 / (nrow + 1),
        bottom=0.5 / (nrow + 1) + cpad,
        left=0.5 / (ncol + 1) + rpad,
        right=1 - 0.5 / (ncol + 1),
    )

    def per_row(value):
        # a list/tuple/array supplies one value per row; anything else
        # (None, number, colormap name or object) is repeated for every row
        if isinstance(value, (list, tuple, np.ndarray)):
            return list(value)
        return [value] * nrow

    vmins = per_row(vmin)
    vmaxs = per_row(vmax)
    cmaps = per_row(cmap)

    for rowidx, rvmin, rvmax, rcmap in zip(range(nrow), vmins, vmaxs, cmaps):
        for colidx in range(ncol):
            ax = plt.subplot(gs[rowidx, colidx])

            if colidx == 0 and rlabels[rowidx] is not None:
                # add label text, vertically centered left of the first column
                ax.text(
                    -0.07,
                    0.5,
                    rlabels[rowidx],
                    horizontalalignment="right",
                    verticalalignment="center",
                    rotation="vertical",
                    transform=ax.transAxes,
                    color=rfontcolor,
                    fontsize=rfontsize,
                )
                # add label background; on the first row it is extended upward
                # so it also fills the corner beside the column-label strip
                extra = 0.3 if rowidx == 0 else 0.0
                rect = patches.Rectangle(
                    (-0.3, 0),
                    0.3,
                    1.0 + extra,
                    facecolor=rfacecolor,
                    alpha=1.0,
                    transform=ax.transAxes,
                    clip_on=False,
                )
                ax.add_patch(rect)

            if rowidx == 0 and clabels[colidx] is not None:
                # add label text, centered above the first row
                ax.text(
                    0.5,
                    1.09,
                    clabels[colidx],
                    horizontalalignment="center",
                    verticalalignment="center",
                    rotation="horizontal",
                    transform=ax.transAxes,
                    color=cfontcolor,
                    fontsize=cfontsize,
                )
                # add label background
                rect = patches.Rectangle(
                    (0, 1.0),
                    1.0,
                    0.3,
                    facecolor=cfacecolor,
                    alpha=1.0,
                    transform=ax.transAxes,
                    clip_on=False,
                )
                ax.add_patch(rect)

            tmpimg = images[rowidx][colidx]
            if isinstance(axes, (int, np.integer)):
                tmpaxis = axes
            else:
                tmpaxis = axes[rowidx][colidx]

            if not one_slice:
                sliceidx = slices[rowidx][colidx]
            elif slices is None:
                sliceidx = tmpimg.shape[tmpaxis] // 2
            else:
                sliceidx = slices

            tmpslice = slice_image(tmpimg, tmpaxis, sliceidx)
            tmpslice = reorient_slice(tmpslice, tmpaxis)
            im = ax.imshow(tmpslice, cmap=rcmap, aspect="auto", vmin=rvmin, vmax=rvmax)
            ax.axis("off")

            # An inset axis (rather than make_axes_locatable) is used so the
            # colorbar can be scaled vertically relative to the panel.
            if colorbar:
                axins = inset_axes(
                    ax,
                    width="5%",  # 5% of the parent axes width
                    height="90%",  # 90% of the parent axes height
                    loc="center left",
                    bbox_to_anchor=(1.03, 0.0, 1, 1),
                    bbox_transform=ax.transAxes,
                    borderpad=0,
                )
                fig.colorbar(im, cax=axins, orientation="vertical")

    if filename is not None:
        filename = os.path.expanduser(filename)
        plt.savefig(filename, dpi=dpi, transparent=transparent, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
def plot_ortho_stack(
    images,
    overlays=None,
    reorient=True,
    # xyz arguments
    xyz=None,
    xyz_lines=False,
    xyz_color="red",
    xyz_alpha=0.6,
    xyz_linewidth=2,
    xyz_pad=5,
    # base image arguments
    cmap="Greys_r",
    alpha=1,
    # overlay arguments
    overlay_cmap="jet",
    overlay_alpha=0.9,
    # background arguments
    black_bg=True,
    bg_thresh_quant=0.01,
    bg_val_quant=0.99,
    # scale/crop/domain arguments
    crop=False,
    scale=False,
    domain_image_map=None,
    # title arguments
    title=None,
    titlefontsize=24,
    title_dx=0,
    title_dy=0,
    # 4th panel text arguments
    text=None,
    textfontsize=24,
    textfontcolor="white",
    text_dx=0,
    text_dy=0,
    # save & size arguments
    filename=None,
    dpi=500,
    figsize=1.0,
    colpad=0,
    rowpad=0,
    transpose=False,
    transparent=True,
    orient_labels=True,
):
    """
    Create a stack of orthographic plots with optional overlays.

    Use mask_image and/or threshold_image to preprocess images to be be
    overlaid and display the overlays in a given range. See the wiki examples.

    Each image gets one row (one column if `transpose=True`) with three
    panels: the slices at xyz[0], xyz[1] and xyz[2] along array axes 0, 1, 2.

    Arguments
    ---------
    images : list of ANTsImage or str
        3D images (file paths are read). Images that are not in the same
        physical space as the first image are resampled into it. The
        caller's list is not modified.
    overlays : list of (ANTsImage or str or None), optional
        one single-component 3D overlay per image (None for no overlay).
        Zero voxels of float overlays are made transparent.
    reorient : bool or str
        orientation passed to ``reorient_image2``; True means "RPI".
    xyz : sequence of 3 ints (entries may be None), optional
        slice indices; None entries default to the middle slice.
    xyz_lines, xyz_color, xyz_alpha, xyz_linewidth, xyz_pad :
        crosshair display options.
    cmap, overlay_cmap, overlay_alpha :
        colormaps and overlay opacity.
    crop : bool
        crop each image to its nonzero mask.
    scale : bool or (lo, hi)
        if True use the (0.05, 0.95) intensity quantiles of each image as
        its display range; a 2-tuple gives custom quantiles.
    domain_image_map : ANTsImage or (ANTsImage, transform list), optional
        resample everything into this domain before plotting.
    title, titlefontsize, title_dx, title_dy :
        figure title options.
    textfontcolor :
        color of the title and orientation labels.
    filename : str, optional
        if given, save there (``~`` is expanded) and close the figure;
        otherwise show it.
    dpi, transparent : passed to ``plt.savefig``.
    figsize : float
        figure scale factor (also scales the orientation label font).
    colpad, rowpad : float
        grid padding; if both > 0 their max is used as the cell spacing.
    transpose : bool
        lay images out as columns instead of rows.
    orient_labels : bool
        draw anatomical orientation letters on each panel.

    `alpha`, `black_bg`, `bg_thresh_quant`, `bg_val_quant`, `text`,
    `textfontsize`, `text_dx` and `text_dy` are accepted but not used here.

    Example
    -------
    >>> import ants
    >>> mni = ants.image_read(ants.get_data('mni'))
    >>> ch2 = ants.image_read(ants.get_data('ch2'))
    >>> ants.plot_ortho_stack([mni,mni,mni])
    """

    def mirror_matrix(x):
        return x[::-1, :]

    def rotate270_matrix(x):
        return mirror_matrix(x.T)

    def reorient_slice(x, axis):
        # rotate270 puts the slice's second array axis on the rows (reversed)
        # and its first array axis on the columns
        return rotate270_matrix(x)

    def add_xyz_lines(ax, shape, vert_idx, horiz_idx):
        # Draw the crosshair: one vertical and one horizontal line.
        # NOTE(review): the vertical line is placed at `size - index` while
        # rotate270 keeps the column index unreversed; coordinates are kept
        # as before, but verify the position for off-centre `xyz`.
        line_kwargs = dict(color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth)
        ax.add_line(
            mlines.Line2D(
                [shape[0] - vert_idx, shape[0] - vert_idx],
                [xyz_pad, shape[0] - xyz_pad],
                **line_kwargs
            )
        )
        ax.add_line(
            mlines.Line2D(
                [xyz_pad, shape[1] - xyz_pad],
                [shape[1] - horiz_idx, shape[1] - horiz_idx],
                **line_kwargs
            )
        )

    def add_orient_labels(ax, top, bottom, right, left):
        # one letter at the middle of each panel edge
        for x, y, label, ha, va in (
            (0.5, 0.98, top, "center", "top"),
            (0.5, 0.02, bottom, "center", "bottom"),
            (0.98, 0.5, right, "right", "center"),
            (0.02, 0.5, left, "left", "center"),
        ):
            ax.text(
                x,
                y,
                label,
                horizontalalignment=ha,
                verticalalignment=va,
                fontsize=20 * figsize,
                color=textfontcolor,
                transform=ax.transAxes,
            )

    # matplotlib emits spurious NaN warnings when drawing the NaN-masked
    # overlays; silence warnings only for this call and restore the caller's
    # filters afterwards (even if an exception is raised).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # Work on copies so the caller's lists are not overwritten; their
        # entries are replaced below by resampled images and numpy arrays.
        images = list(images)
        n_images = len(images)
        if n_images == 0:
            raise ValueError("images must contain at least one image")

        # handle `image` argument
        for i in range(n_images):
            if isinstance(images[i], str):
                images[i] = iio2.image_read(images[i])
            if not isinstance(images[i], iio.ANTsImage):
                raise ValueError("image argument must be an ANTsImage")
            if images[i].dimension != 3:
                raise ValueError("Input image must have 3 dimensions!")

        overlays = [None] * n_images if overlays is None else list(overlays)
        if len(overlays) != n_images:
            raise ValueError("overlays must have the same length as images")

        # handle `overlay` argument
        for i in range(n_images):
            if overlays[i] is None:
                continue
            if isinstance(overlays[i], str):
                overlays[i] = iio2.image_read(overlays[i])
            if not isinstance(overlays[i], iio.ANTsImage):
                raise ValueError("overlay argument must be an ANTsImage")
            if overlays[i].components > 1:
                raise ValueError(
                    "overlays[%i] cannot have more than one voxel component" % i
                )
            if overlays[i].dimension != 3:
                raise ValueError("Overlay image must have 3 dimensions!")
            if not iio.image_physical_space_consistency(images[i], overlays[i]):
                overlays[i] = reg.resample_image_to_target(
                    overlays[i], images[i], interp_type="linear"
                )

        # bring every image (and its overlay) into the space of the first
        # image; resample_image_to_target takes (image, target)
        for i in range(1, n_images):
            if not iio.image_physical_space_consistency(images[0], images[i]):
                images[i] = reg.resample_image_to_target(
                    images[i], images[0], interp_type="linear"
                )
                if overlays[i] is not None:
                    overlays[i] = reg.resample_image_to_target(
                        overlays[i], images[0], interp_type="linear"
                    )

        # reorient images
        if reorient != False:
            if reorient == True:
                reorient = "RPI"
            for i in range(n_images):
                images[i] = images[i].reorient_image2(reorient)
                if overlays[i] is not None:
                    overlays[i] = overlays[i].reorient_image2(reorient)

        # handle `slices` argument (copy so tuples / caller lists are safe)
        if xyz is None:
            xyz = [int(s / 2) for s in images[0].shape]
        else:
            xyz = list(xyz)
        for i in range(3):
            if xyz[i] is None:
                xyz[i] = int(images[0].shape[i] / 2)

        # resample image if spacing is very unbalanced
        spacing = list(images[0].spacing)
        if (max(spacing) / min(spacing)) > 3.0:
            new_spacing = (1, 1, 1)
            for i in range(n_images):
                images[i] = images[i].resample_image(new_spacing)
                if overlays[i] is not None:
                    overlays[i] = overlays[i].resample_image(new_spacing)
            xyz = [
                int(sl * (sold / snew))
                for sl, sold, snew in zip(xyz, spacing, new_spacing)
            ]

        # potentially crop image
        if crop:
            for i in range(n_images):
                plotmask = images[i].get_mask(cleanup=0)
                if plotmask.max() == 0:
                    plotmask += 1
                images[i] = images[i].crop_image(plotmask)
                if overlays[i] is not None:
                    overlays[i] = overlays[i].crop_image(plotmask)

        # pad images; slice indices shift by the first image's lower padding
        for i in range(n_images):
            if i == 0:
                images[i], lowpad, uppad = images[i].pad_image(return_padvals=True)
            else:
                images[i] = images[i].pad_image()
            if overlays[i] is not None:
                overlays[i] = overlays[i].pad_image()
        xyz = [v + l for v, l in zip(xyz, lowpad)]

        # handle `domain_image_map` argument
        if domain_image_map is not None:
            if isinstance(domain_image_map, iio.ANTsImage):
                tx = tio2.new_ants_transform(
                    precision="float", transform_type="AffineTransform", dimension=3
                )
                for i in range(n_images):
                    images[i] = tio.apply_ants_transform_to_image(
                        tx, images[i], domain_image_map
                    )
                    if overlays[i] is not None:
                        overlays[i] = tio.apply_ants_transform_to_image(
                            tx, overlays[i], domain_image_map, interpolation="linear"
                        )
            elif isinstance(domain_image_map, (list, tuple)):
                # expect an image and transformation
                if len(domain_image_map) != 2:
                    raise ValueError("domain_image_map list or tuple must have length == 2")

                dimg = domain_image_map[0]
                if not isinstance(dimg, iio.ANTsImage):
                    raise ValueError("domain_image_map first entry should be ANTsImage")

                tx = domain_image_map[1]
                for i in range(n_images):
                    images[i] = reg.apply_transforms(dimg, images[i], transform_list=tx)
                    if overlays[i] is not None:
                        overlays[i] = reg.apply_transforms(
                            dimg, overlays[i], transform_list=tx, interpolator="linear"
                        )

        # potentially find dynamic range, one (vmin, vmax) pair per image
        if scale == True:
            quantiles = (0.05, 0.95)
        elif isinstance(scale, (list, tuple)):
            if len(scale) != 2:
                raise ValueError(
                    "scale argument must be boolean or list/tuple with two values"
                )
            quantiles = scale
        else:
            quantiles = None

        if quantiles is None:
            vmins = [None] * n_images
            vmaxs = [None] * n_images
        else:
            vmins = []
            vmaxs = []
            for img in images:
                lo, hi = img.quantile(quantiles)
                vmins.append(lo)
                vmaxs.append(hi)

        if not transpose:
            nrow = n_images
            ncol = 3
        else:
            nrow = 3
            ncol = n_images

        fig = plt.figure(figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize))

        if title is not None:
            basey = 0.93
            basex = 0.5
            fig.suptitle(
                title,
                fontsize=titlefontsize,
                color=textfontcolor,
                x=basex + title_dx,
                y=basey + title_dy,
            )

        # when both paddings are requested, use their max as uniform spacing
        if (colpad > 0) and (rowpad > 0):
            bothgridpad = max(colpad, rowpad)
            colpad = 0
            rowpad = 0
        else:
            bothgridpad = 0.0

        gs = gridspec.GridSpec(
            nrow,
            ncol,
            wspace=bothgridpad,
            hspace=0.0,
            top=1.0 - 0.5 / (nrow + 1),
            bottom=0.5 / (nrow + 1) + colpad,
            left=0.5 / (ncol + 1) + rowpad,
            right=1 - 0.5 / (ncol + 1),
        )

        # Convert to numpy. Overlay display ranges are stored per image index
        # (None where there is no overlay) so they line up with `overlays`.
        vminols = [None] * n_images
        vmaxols = [None] * n_images
        for i in range(n_images):
            images[i] = images[i].numpy()
            if overlays[i] is not None:
                vminols[i] = overlays[i].min()
                vmaxols[i] = overlays[i].max()
                overlays[i] = overlays[i].numpy()
                if overlays[i].dtype not in ["uint8", "uint32"]:
                    # zero voxels become NaN so they are not drawn
                    overlays[i][np.abs(overlays[i]) == 0] = np.nan

        # Panel specs: (panel position, array index, vertical-line index,
        # horizontal-line index, (top, bottom, right, left) labels).
        # After reorient_slice, z runs vertically in the yz and xz panels and
        # y in the xy panel, so yz and xz share the S/I vertical labels.
        full = slice(None)
        panels = (
            (0, (xyz[0], full, full), xyz[1], xyz[2], ("S", "I", "A", "P")),
            (1, (full, xyz[1], full), xyz[0], xyz[2], ("S", "I", "L", "R")),
            (2, (full, full, xyz[2]), xyz[0], xyz[1], ("A", "P", "L", "R")),
        )

        for i in range(n_images):
            for pos, index, vert_idx, horiz_idx, labels in panels:
                base_slice = reorient_slice(images[i][index], pos)
                ax = plt.subplot(gs[pos, i] if transpose else gs[i, pos])
                ax.imshow(base_slice, cmap=cmap, vmin=vmins[i], vmax=vmaxs[i])
                if overlays[i] is not None:
                    ax.imshow(
                        reorient_slice(overlays[i][index], pos),
                        alpha=overlay_alpha,
                        cmap=overlay_cmap,
                        vmin=vminols[i],
                        vmax=vmaxols[i],
                    )
                if xyz_lines:
                    add_xyz_lines(ax, base_slice.shape, vert_idx, horiz_idx)
                if orient_labels:
                    add_orient_labels(ax, *labels)
                ax.axis("off")

        if filename is not None:
            filename = os.path.expanduser(filename)
            plt.savefig(filename, dpi=dpi, transparent=transparent)
            plt.close(fig)
        else:
            plt.show()
[docs]def plot_ortho_double( image, image2, overlay=None, overlay2=None, reorient=True, # xyz arguments xyz=None, xyz_lines=True, xyz_color="red", xyz_alpha=0.6, xyz_linewidth=2, xyz_pad=5, # base image arguments cmap="Greys_r", alpha=1, cmap2="Greys_r", alpha2=1, # overlay arguments overlay_cmap="jet", overlay_alpha=0.9, overlay_cmap2="jet", overlay_alpha2=0.9, # background arguments black_bg=True, bg_thresh_quant=0.01, bg_val_quant=0.99, # scale/crop/domain arguments crop=False, scale=False, crop2=False, scale2=True, domain_image_map=None, # title arguments title=None, titlefontsize=24, title_dx=0, title_dy=0, # 4th panel text arguemnts text=None, textfontsize=24, textfontcolor="white", text_dx=0, text_dy=0, # save & size arguments filename=None, dpi=500, figsize=1.0, flat=True, transpose=False, transparent=True, ): """ Create a pair of orthographic plots with overlays. Use mask_image and/or threshold_image to preprocess images to be be overlaid and display the overlays in a given range. See the wiki examples. 
Example ------- >>> import ants >>> mni = ants.image_read(ants.get_data('mni')) >>> ch2 = ants.image_read(ants.get_data('ch2')) >>> ants.plot_ortho_double(mni, ch2) """ def mirror_matrix(x): return x[::-1, :] def rotate270_matrix(x): return mirror_matrix(x.T) def reorient_slice(x, axis): return rotate270_matrix(x) # need this hack because of a weird NaN warning from matplotlib with overlays warnings.simplefilter("ignore") # handle `image` argument if isinstance(image, str): image = iio2.image_read(image) if not isinstance(image, iio.ANTsImage): raise ValueError("image argument must be an ANTsImage") if image.dimension != 3: raise ValueError("Input image must have 3 dimensions!") if isinstance(image2, str): image2 = iio2.image_read(image2) if not isinstance(image2, iio.ANTsImage): raise ValueError("image2 argument must be an ANTsImage") if image2.dimension != 3: raise ValueError("Input image2 must have 3 dimensions!") # handle `overlay` argument if overlay is not None: if isinstance(overlay, str): overlay = iio2.image_read(overlay) if not isinstance(overlay, iio.ANTsImage): raise ValueError("overlay argument must be an ANTsImage") if overlay.components > 1: raise ValueError("overlay cannot have more than one voxel component") if overlay.dimension != 3: raise ValueError("Overlay image must have 3 dimensions!") if not iio.image_physical_space_consistency(image, overlay): overlay = reg.resample_image_to_target(overlay, image, interp_type="linear") if overlay2 is not None: if isinstance(overlay2, str): overlay2 = iio2.image_read(overlay2) if not isinstance(overlay2, iio.ANTsImage): raise ValueError("overlay2 argument must be an ANTsImage") if overlay2.components > 1: raise ValueError("overlay2 cannot have more than one voxel component") if overlay2.dimension != 3: raise ValueError("Overlay2 image must have 3 dimensions!") if not iio.image_physical_space_consistency(image2, overlay2): overlay2 = reg.resample_image_to_target( overlay2, image2, interp_type="linear" ) if 
not iio.image_physical_space_consistency(image, image2): image2 = reg.resample_image_to_target(image2, image, interp_type="linear") if image.pixeltype not in {"float", "double"}: scale = False # turn off scaling if image is discrete if image2.pixeltype not in {"float", "double"}: scale2 = False # turn off scaling if image is discrete # reorient images if reorient != False: if reorient == True: reorient = "RPI" image = image.reorient_image2(reorient) image2 = image2.reorient_image2(reorient) if overlay is not None: overlay = overlay.reorient_image2(reorient) if overlay2 is not None: overlay2 = overlay2.reorient_image2(reorient) # handle `slices` argument if xyz is None: xyz = [int(s / 2) for s in image.shape] for i in range(3): if xyz[i] is None: xyz[i] = int(image.shape[i] / 2) # resample image if spacing is very unbalanced spacing = [s for i, s in enumerate(image.spacing)] if (max(spacing) / min(spacing)) > 3.0: new_spacing = (1, 1, 1) image = image.resample_image(tuple(new_spacing)) image2 = image2.resample_image_to_target(tuple(new_spacing)) if overlay is not None: overlay = overlay.resample_image(tuple(new_spacing)) if overlay2 is not None: overlay2 = overlay2.resample_image(tuple(new_spacing)) xyz = [ int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing) ] # pad images image, lowpad, uppad = image.pad_image(return_padvals=True) image2, lowpad2, uppad2 = image2.pad_image(return_padvals=True) xyz = [v + l for v, l in zip(xyz, lowpad)] if overlay is not None: overlay = overlay.pad_image() if overlay2 is not None: overlay2 = overlay2.pad_image() # handle `domain_image_map` argument if domain_image_map is not None: if isinstance(domain_image_map, iio.ANTsImage): tx = tio2.new_ants_transform( precision="float", transform_type="AffineTransform", dimension=image.dimension, ) image = tio.apply_ants_transform_to_image(tx, image, domain_image_map) image2 = tio.apply_ants_transform_to_image(tx, image2, domain_image_map) if overlay is not None: 
overlay = tio.apply_ants_transform_to_image( tx, overlay, domain_image_map, interpolation="linear" ) if overlay2 is not None: overlay2 = tio.apply_ants_transform_to_image( tx, overlay2, domain_image_map, interpolation="linear" ) elif isinstance(domain_image_map, (list, tuple)): # expect an image and transformation if len(domain_image_map) != 2: raise ValueError("domain_image_map list or tuple must have length == 2") dimg = domain_image_map[0] if not isinstance(dimg, iio.ANTsImage): raise ValueError("domain_image_map first entry should be ANTsImage") tx = domain_image_map[1] image = reg.apply_transforms(dimg, image, transform_list=tx) if overlay is not None: overlay = reg.apply_transforms( dimg, overlay, transform_list=tx, interpolator="linear" ) image2 = reg.apply_transforms(dimg, image2, transform_list=tx) if overlay2 is not None: overlay2 = reg.apply_transforms( dimg, overlay2, transform_list=tx, interpolator="linear" ) ## single-channel images ## if image.components == 1: # potentially crop image if crop: plotmask = image.get_mask(cleanup=0) if plotmask.max() == 0: plotmask += 1 image = image.crop_image(plotmask) if overlay is not None: overlay = overlay.crop_image(plotmask) if crop2: plotmask2 = image2.get_mask(cleanup=0) if plotmask2.max() == 0: plotmask2 += 1 image2 = image2.crop_image(plotmask2) if overlay2 is not None: overlay2 = overlay2.crop_image(plotmask2) # potentially find dynamic range if scale == True: vmin, vmax = image.quantile((0.05, 0.95)) elif isinstance(scale, (list, tuple)): if len(scale) != 2: raise ValueError( "scale argument must be boolean or list/tuple with two values" ) vmin, vmax = image.quantile(scale) else: vmin = None vmax = None if scale2 == True: vmin2, vmax2 = image2.quantile((0.05, 0.95)) elif isinstance(scale2, (list, tuple)): if len(scale2) != 2: raise ValueError( "scale2 argument must be boolean or list/tuple with two values" ) vmin2, vmax2 = image2.quantile(scale2) else: vmin2 = None vmax2 = None if not flat: nrow = 2 ncol = 
4 else: if not transpose: nrow = 2 ncol = 3 else: nrow = 3 ncol = 2 fig = plt.figure( figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize) ) if title is not None: basey = 0.88 if not flat else 0.66 basex = 0.5 fig.suptitle( title, fontsize=titlefontsize, color=textfontcolor, x=basex + title_dx, y=basey + title_dy ) gs = gridspec.GridSpec( nrow, ncol, wspace=0.0, hspace=0.0, top=1.0 - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1), left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1), ) # pad image to have isotropic array dimensions image = image.numpy() if overlay is not None: overlay = overlay.numpy() if overlay.dtype not in ["uint8", "uint32"]: overlay[np.abs(overlay) == 0] = np.nan image2 = image2.numpy() if overlay2 is not None: overlay2 = overlay2.numpy() if overlay2.dtype not in ["uint8", "uint32"]: overlay2[np.abs(overlay2) == 0] = np.nan #################### #################### yz_slice = reorient_slice(image[xyz[0], :, :], 0) ax = plt.subplot(gs[0, 0]) ax.imshow(yz_slice, cmap=cmap, vmin=vmin, vmax=vmax) if overlay is not None: yz_overlay = reorient_slice(overlay[xyz[0], :, :], 0) ax.imshow(yz_overlay, alpha=overlay_alpha, cmap=overlay_cmap) if xyz_lines: # add lines l = mlines.Line2D( [yz_slice.shape[0] - xyz[1], yz_slice.shape[0] - xyz[1]], [xyz_pad, yz_slice.shape[0] - xyz_pad], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) l = mlines.Line2D( [xyz_pad, yz_slice.shape[1] - xyz_pad], [yz_slice.shape[1] - xyz[2], yz_slice.shape[1] - xyz[2]], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) ax.axis("off") ####### yz_slice2 = reorient_slice(image2[xyz[0], :, :], 0) if not flat: ax = plt.subplot(gs[0, 1]) else: if not transpose: ax = plt.subplot(gs[1, 0]) else: ax = plt.subplot(gs[0, 1]) ax.imshow(yz_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) if overlay2 is not None: yz_overlay2 = reorient_slice(overlay2[xyz[0], :, :], 0) ax.imshow(yz_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) if 
xyz_lines: # add lines l = mlines.Line2D( [yz_slice2.shape[0] - xyz[1], yz_slice2.shape[0] - xyz[1]], [xyz_pad, yz_slice2.shape[0] - xyz_pad], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) l = mlines.Line2D( [xyz_pad, yz_slice2.shape[1] - xyz_pad], [yz_slice2.shape[1] - xyz[2], yz_slice2.shape[1] - xyz[2]], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) ax.axis("off") #################### #################### xz_slice = reorient_slice(image[:, xyz[1], :], 1) if not flat: ax = plt.subplot(gs[0, 2]) else: if not transpose: ax = plt.subplot(gs[0, 1]) else: ax = plt.subplot(gs[1, 0]) ax.imshow(xz_slice, cmap=cmap, vmin=vmin, vmax=vmax) if overlay is not None: xz_overlay = reorient_slice(overlay[:, xyz[1], :], 1) ax.imshow(xz_overlay, alpha=overlay_alpha, cmap=overlay_cmap) if xyz_lines: # add lines l = mlines.Line2D( [xz_slice.shape[0] - xyz[0], xz_slice.shape[0] - xyz[0]], [xyz_pad, xz_slice.shape[0] - xyz_pad], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) l = mlines.Line2D( [xyz_pad, xz_slice.shape[1] - xyz_pad], [xz_slice.shape[1] - xyz[2], xz_slice.shape[1] - xyz[2]], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) ax.axis("off") ####### xz_slice2 = reorient_slice(image2[:, xyz[1], :], 1) if not flat: ax = plt.subplot(gs[0, 3]) else: ax = plt.subplot(gs[1, 1]) ax.imshow(xz_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) if overlay is not None: xz_overlay2 = reorient_slice(overlay2[:, xyz[1], :], 1) ax.imshow(xz_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) if xyz_lines: # add lines l = mlines.Line2D( [xz_slice2.shape[0] - xyz[0], xz_slice2.shape[0] - xyz[0]], [xyz_pad, xz_slice2.shape[0] - xyz_pad], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) l = mlines.Line2D( [xyz_pad, xz_slice2.shape[1] - xyz_pad], [xz_slice2.shape[1] - xyz[2], xz_slice2.shape[1] - xyz[2]], color=xyz_color, alpha=xyz_alpha, 
linewidth=xyz_linewidth, ) ax.add_line(l) ax.axis("off") #################### #################### xy_slice = reorient_slice(image[:, :, xyz[2]], 2) if not flat: ax = plt.subplot(gs[1, 2]) else: if not transpose: ax = plt.subplot(gs[0, 2]) else: ax = plt.subplot(gs[2, 0]) ax.imshow(xy_slice, cmap=cmap, vmin=vmin, vmax=vmax) if overlay is not None: xy_overlay = reorient_slice(overlay[:, :, xyz[2]], 2) ax.imshow(xy_overlay, alpha=overlay_alpha, cmap=overlay_cmap) if xyz_lines: # add lines l = mlines.Line2D( [xy_slice.shape[0] - xyz[0], xy_slice.shape[0] - xyz[0]], [xyz_pad, xy_slice.shape[0] - xyz_pad], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) l = mlines.Line2D( [xyz_pad, xy_slice.shape[1] - xyz_pad], [xy_slice.shape[1] - xyz[1], xy_slice.shape[1] - xyz[1]], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) ax.axis("off") ####### xy_slice2 = reorient_slice(image2[:, :, xyz[2]], 2) if not flat: ax = plt.subplot(gs[1, 3]) else: if not transpose: ax = plt.subplot(gs[1, 2]) else: ax = plt.subplot(gs[2, 1]) ax.imshow(xy_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) if overlay is not None: xy_overlay2 = reorient_slice(overlay2[:, :, xyz[2]], 2) ax.imshow(xy_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) if xyz_lines: # add lines l = mlines.Line2D( [xy_slice2.shape[0] - xyz[0], xy_slice2.shape[0] - xyz[0]], [xyz_pad, xy_slice2.shape[0] - xyz_pad], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) l = mlines.Line2D( [xyz_pad, xy_slice2.shape[1] - xyz_pad], [xy_slice2.shape[1] - xyz[1], xy_slice2.shape[1] - xyz[1]], color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth, ) ax.add_line(l) ax.axis("off") #################### #################### if not flat: # empty corner ax = plt.subplot(gs[1, :2]) if text is not None: # add text left, width = 0.25, 0.5 bottom, height = 0.25, 0.5 right = left + width top = bottom + height ax.text( 0.5 * (left + right) + text_dx, 0.5 * (bottom + top) + 
text_dy, text, horizontalalignment="center", verticalalignment="center", fontsize=textfontsize, color=textfontcolor, transform=ax.transAxes, ) # ax.text(0.5, 0.5) img_shape = list(image.shape[:-1]) img_shape[1] *= 2 ax.imshow(np.zeros(img_shape), cmap="Greys_r") ax.axis("off") ## multi-channel images ## elif image.components > 1: raise ValueError("Multi-channel images not currently supported!") if filename is not None: plt.savefig(filename, dpi=dpi, transparent=transparent) plt.close(fig) else: plt.show() # turn warnings back to default warnings.simplefilter("default")
def plot_ortho(
    image,
    overlay=None,
    reorient=True,
    blend=False,
    # xyz arguments
    xyz=None,
    xyz_lines=True,
    xyz_color="red",
    xyz_alpha=0.6,
    xyz_linewidth=2,
    xyz_pad=5,
    orient_labels=True,
    # base image arguments
    alpha=1,
    cmap="Greys_r",
    # overlay arguments
    overlay_cmap="jet",
    overlay_alpha=0.9,
    cbar=False,
    cbar_length=0.8,
    cbar_dx=0.0,
    cbar_vertical=True,
    # background arguments
    black_bg=True,
    bg_thresh_quant=0.01,
    bg_val_quant=0.99,
    # scale/crop/domain arguments
    crop=False,
    scale=False,
    domain_image_map=None,
    # title arguments
    title=None,
    titlefontsize=24,
    title_dx=0,
    title_dy=0,
    # 4th panel text arguments
    text=None,
    textfontsize=24,
    textfontcolor="white",
    text_dx=0,
    text_dy=0,
    # save & size arguments
    filename=None,
    dpi=500,
    figsize=1.0,
    flat=False,
    transparent=True,
    resample=False,
):
    """
    Plot an orthographic view of a 3D image

    Use mask_image and/or threshold_image to preprocess images to be
    be overlaid and display the overlays in a given range. See the wiki examples.

    ANTsR function: N/A

    Arguments
    ---------
    image : ANTsImage or string
        3D image to plot. A string is read with `ants.image_read`.

    overlay : ANTsImage or string
        3D single-component image to overlay on the base image. A string is
        read with `ants.image_read`. If its physical space differs from
        `image`, it is resampled onto `image` with linear interpolation.
        For float overlays, zero-valued voxels are not drawn.

    reorient : boolean
        if not False, image and overlay are reoriented to 'RPI' before
        plotting; the orientation labels assume this orientation.

    blend : boolean
        if true (and an overlay is given), the overlay is blended into the base
        image as image * alpha + overlay * (1 - alpha) instead of being drawn
        on top of it; an `alpha` of 1 is treated as 0.5.

    xyz : list or tuple of 3 integers
        array indices at which the three orthogonal planes are cut and, if
        `xyz_lines` is true, where the crosshair lines converge. Entries that
        are None (or xyz=None) default to the middle slice of that axis.
        The caller's sequence is never modified.

    xyz_lines : boolean
        if true, draw crosshair lines through `xyz` in every panel

    xyz_color, xyz_alpha, xyz_linewidth : matplotlib line properties
        appearance of the crosshair lines

    xyz_pad : integer
        number of pixels left blank at each end of the crosshair lines

    orient_labels : boolean
        if true, anatomical direction letters are written at panel edges

    alpha : float
        base-image weight used when `blend` is true

    cmap : string
        colormap to use for base image. See matplotlib.

    overlay_cmap : string
        colormap to use for overlay images, if applicable. See matplotlib.

    overlay_alpha : float
        level of transparency for any overlays. Smaller value means
        the overlay is more transparent. See matplotlib.

    cbar: boolean
        if true, a colorbar will be added to the plot. It describes the
        image drawn last in the axial (x-y) panel (the overlay if present).

    cbar_length: float
        length of the colorbar relative to the image

    cbar_dx: float
        horizontal shift of the colorbar relative to the image

    cbar_vertical: boolean
        if true, the colorbar will be vertical, if false, it will be horizontal
        underneath the image

    black_bg, bg_thresh_quant, bg_val_quant :
        accepted for signature compatibility with `ants.plot`; they are not
        currently used by this function.

    crop : boolean
        if true, the image(s) will be cropped to their bounding boxes, resulting
        in a potentially smaller image size.
        if false, the image(s) will not be cropped

    scale : boolean or 2-tuple
        if True (and the image is float/double), the display range is set to
        the 0.05 and 0.95 intensity quantiles of the image;
        if a 2-tuple, the display range is set to those two quantiles;
        if False, matplotlib's default (full data range) is used.

    domain_image_map : ANTsImage or (ANTsImage, transform list)
        if an ANTsImage, the image(s) are mapped into its space with an
        identity affine transform before plotting. If a 2-element list/tuple,
        the first entry is the reference image and the second is passed as
        `transform_list` to `ants.apply_transforms`. Useful for non-standard
        image orientations.

    title : string
        add a title to the plot

    titlefontsize, title_dx, title_dy :
        font size and offsets of the title

    text : string
        text written in the empty bottom-left panel (ignored when `flat`)

    textfontsize, textfontcolor, text_dx, text_dy :
        font size, color and offsets of `text`; the color is also used for the
        title and the orientation labels

    filename : string
        if given, the resulting image will be saved to this file
        ('~' is expanded)

    dpi : integer
        determines resolution of image if saved to file. Higher values
        result in higher resolution images, but at a cost of having a
        larger file size

    figsize : float
        scale factor for the figure size and orientation-label font size

    flat : boolean
        if true, the ortho image will be plot in one row
        if false, the ortho image will be a 2x2 grid with the bottom
            left corner blank

    transparent : boolean
        passed to `matplotlib.pyplot.savefig`

    resample : boolean
        if true, resample to 1x1x1 spacing when the spacing is very unbalanced
        (max/min > 3)

    Example
    -------
    >>> import ants
    >>> mni = ants.image_read(ants.get_data('mni'))
    >>> ants.plot_ortho(mni, xyz=(100,100,100))
    >>> mni2 = mni.threshold_image(7000, mni.max())
    >>> ants.plot_ortho(mni, overlay=mni2)
    >>> ants.plot_ortho(mni, overlay=mni2, flat=True)
    >>> ants.plot_ortho(mni, overlay=mni2, xyz=(110,110,110), xyz_lines=False,
                        text='Lines Turned Off', textfontsize=22)
    >>> ants.plot_ortho(mni, mni2, xyz=(120,100,100),
                        text=' Example \nOrtho Text', textfontsize=26,
                        title='Example Ortho Title', titlefontsize=26)
    """

    def mirror_matrix(x):
        return x[::-1, :]

    def rotate270_matrix(x):
        return mirror_matrix(x.T)

    def reorient_slice(x, axis):
        # every plane uses the same transform; `axis` kept for readability
        return rotate270_matrix(x)

    def add_crosshair(ax, slc, col, row):
        """Draw a vertical line at `col` and a horizontal line at `row` on `ax`."""
        line_kwargs = dict(color=xyz_color, alpha=xyz_alpha, linewidth=xyz_linewidth)
        ax.add_line(
            mlines.Line2D(
                [slc.shape[0] - col, slc.shape[0] - col],
                [xyz_pad, slc.shape[0] - xyz_pad],
                **line_kwargs,
            )
        )
        ax.add_line(
            mlines.Line2D(
                [xyz_pad, slc.shape[1] - xyz_pad],
                [slc.shape[1] - row, slc.shape[1] - row],
                **line_kwargs,
            )
        )

    def add_orient_labels(ax, top, bottom, right, left):
        """Write direction letters at the top/bottom/right/left edges of `ax`."""
        for x, y, label, ha, va in (
            (0.5, 0.98, top, "center", "top"),
            (0.5, 0.02, bottom, "center", "bottom"),
            (0.98, 0.5, right, "right", "center"),
            (0.02, 0.5, left, "left", "center"),
        ):
            ax.text(
                x,
                y,
                label,
                horizontalalignment=ha,
                verticalalignment=va,
                fontsize=20 * figsize,
                color=textfontcolor,
                transform=ax.transAxes,
            )

    def draw_panel(cell, img_slice, ov_slice, col, row, labels):
        """Draw one orthogonal plane in grid `cell`; return the last AxesImage."""
        ax = plt.subplot(cell)
        im = ax.imshow(img_slice, cmap=cmap, vmin=vmin, vmax=vmax)
        if ov_slice is not None:
            im = ax.imshow(
                ov_slice, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol
            )
        if xyz_lines:
            add_crosshair(ax, img_slice, col, row)
        if orient_labels:
            add_orient_labels(ax, *labels)
        ax.axis("off")
        return im

    # Matplotlib emits a spurious NaN warning with overlays. Silence warnings
    # only for this call; catch_warnings restores the caller's filters even if
    # plotting raises.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # handle `image` argument
        if isinstance(image, str):
            image = iio2.image_read(image)
        if not isinstance(image, iio.ANTsImage):
            raise ValueError("image argument must be an ANTsImage")
        if image.dimension != 3:
            raise ValueError("Input image must have 3 dimensions!")

        # handle `overlay` argument
        vminol = vmaxol = None
        if overlay is not None:
            # read/validate first: a filename has no .min()/.max()
            if isinstance(overlay, str):
                overlay = iio2.image_read(overlay)
            if not isinstance(overlay, iio.ANTsImage):
                raise ValueError("overlay argument must be an ANTsImage")
            if overlay.components > 1:
                raise ValueError("overlay cannot have more than one voxel component")
            if overlay.dimension != 3:
                raise ValueError("Overlay image must have 3 dimensions!")

            # overlay colour range comes from the un-resampled intensities
            vminol = overlay.min()
            vmaxol = overlay.max()

            if not iio.image_physical_space_consistency(image, overlay):
                overlay = reg.resample_image_to_target(overlay, image, interp_type="linear")

        # blending needs an overlay; without one there is nothing to mix in
        if blend and overlay is not None:
            if alpha == 1:
                alpha = 0.5
            image = image * alpha + overlay * (1 - alpha)
            overlay = None
            alpha = 1.0

        if image.pixeltype not in {"float", "double"}:
            scale = False  # turn off scaling if image is discrete

        # reorient images; the orientation labels below assume RPI.
        # `!= False` (not truthiness) is deliberate: only False/0 disable it.
        if reorient != False:  # noqa: E712
            image = image.reorient_image2("RPI")
            if overlay is not None:
                overlay = overlay.reorient_image2("RPI")

        # handle `xyz` argument on a private copy so the caller's sequence is
        # not mutated and tuples containing None are accepted
        xyz = [int(s / 2) for s in image.shape] if xyz is None else list(xyz)
        for i in range(3):
            if xyz[i] is None:
                xyz[i] = int(image.shape[i] / 2)

        # resample image if spacing is very unbalanced
        spacing = list(image.spacing)
        if resample and (max(spacing) / min(spacing)) > 3.0:
            new_spacing = (1, 1, 1)
            image = image.resample_image(new_spacing)
            if overlay is not None:
                overlay = overlay.resample_image(new_spacing)
            # convert slice indices to the new voxel grid
            xyz = [
                int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing)
            ]

        # potentially crop image
        # NOTE(review): xyz is not shifted by the crop offset - confirm intended
        if crop:
            plotmask = image.get_mask(cleanup=0)
            if plotmask.max() == 0:
                plotmask += 1
            image = image.crop_image(plotmask)
            if overlay is not None:
                overlay = overlay.crop_image(plotmask)

        # pad images (to isotropic array dimensions) and shift xyz by the lower pad
        image, lowpad, uppad = image.pad_image(return_padvals=True)
        xyz = [v + l for v, l in zip(xyz, lowpad)]
        if overlay is not None:
            overlay = overlay.pad_image()

        # handle `domain_image_map` argument
        if domain_image_map is not None:
            if isinstance(domain_image_map, iio.ANTsImage):
                tx = tio2.new_ants_transform(
                    precision="float",
                    transform_type="AffineTransform",
                    dimension=image.dimension,
                )
                image = tio.apply_ants_transform_to_image(tx, image, domain_image_map)
                if overlay is not None:
                    overlay = tio.apply_ants_transform_to_image(
                        tx, overlay, domain_image_map, interpolation="linear"
                    )
            elif isinstance(domain_image_map, (list, tuple)):
                # expect an image and transformation
                if len(domain_image_map) != 2:
                    raise ValueError("domain_image_map list or tuple must have length == 2")

                dimg = domain_image_map[0]
                if not isinstance(dimg, iio.ANTsImage):
                    raise ValueError("domain_image_map first entry should be ANTsImage")

                tx = domain_image_map[1]
                image = reg.apply_transforms(dimg, image, transform_list=tx)
                if overlay is not None:
                    overlay = reg.apply_transforms(
                        dimg, overlay, transform_list=tx, interpolator="linear"
                    )

        ## single-channel images ##
        if image.components == 1:

            # potentially find dynamic range
            if scale == True:  # noqa: E712 -- scale may also be a tuple
                vmin, vmax = image.quantile((0.05, 0.95))
            elif isinstance(scale, (list, tuple)):
                if len(scale) != 2:
                    raise ValueError(
                        "scale argument must be boolean or list/tuple with two values"
                    )
                vmin, vmax = image.quantile(scale)
            else:
                vmin = None
                vmax = None

            nrow, ncol = (2, 2) if not flat else (1, 3)

            fig = plt.figure(figsize=(9 * figsize, 9 * figsize))

            if title is not None:
                basey = 0.88 if not flat else 0.66
                basex = 0.5
                fig.suptitle(
                    title,
                    fontsize=titlefontsize,
                    color=textfontcolor,
                    x=basex + title_dx,
                    y=basey + title_dy,
                )

            gs = gridspec.GridSpec(
                nrow,
                ncol,
                wspace=0.0,
                hspace=0.0,
                top=1.0 - 0.5 / (nrow + 1),
                bottom=0.5 / (nrow + 1),
                left=0.5 / (ncol + 1),
                right=1 - 0.5 / (ncol + 1),
            )

            image = image.numpy()
            if overlay is not None:
                overlay = overlay.numpy()
                # NaN voxels are not drawn, so zero overlay voxels stay transparent
                if overlay.dtype not in ["uint8", "uint32"]:
                    overlay[np.abs(overlay) == 0] = np.nan

            # plane through x = xyz[0] (labels S/I/A/P)
            draw_panel(
                gs[0, 0],
                reorient_slice(image[xyz[0], :, :], 0),
                None if overlay is None else reorient_slice(overlay[xyz[0], :, :], 0),
                xyz[1],
                xyz[2],
                ("S", "I", "A", "P"),
            )

            # plane through y = xyz[1] (labels S/I/L/R)
            draw_panel(
                gs[0, 1],
                reorient_slice(image[:, xyz[1], :], 1),
                None if overlay is None else reorient_slice(overlay[:, xyz[1], :], 1),
                xyz[0],
                xyz[2],
                ("S", "I", "L", "R"),
            )

            # plane through z = xyz[2] (labels A/P/L/R); it feeds the colorbar
            im = draw_panel(
                gs[1, 1] if not flat else gs[0, 2],
                reorient_slice(image[:, :, xyz[2]], 2),
                None if overlay is None else reorient_slice(overlay[:, :, xyz[2]], 2),
                xyz[0],
                xyz[1],
                ("A", "P", "L", "R"),
            )

            if not flat:
                # empty bottom-left corner: black background plus optional text
                ax = plt.subplot(gs[1, 0])
                if text is not None:
                    ax.text(
                        0.5 + text_dx,
                        0.5 + text_dy,
                        text,
                        horizontalalignment="center",
                        verticalalignment="center",
                        fontsize=textfontsize,
                        color=textfontcolor,
                        transform=ax.transAxes,
                    )
                ax.imshow(np.zeros(image.shape[:-1]), cmap="Greys_r")
                ax.axis("off")

            if cbar:
                cbar_start = (1 - cbar_length) / 2
                if cbar_vertical:
                    cax = fig.add_axes([0.9 + cbar_dx, cbar_start, 0.03, cbar_length])
                    cbar_orient = "vertical"
                else:
                    cax = fig.add_axes([cbar_start, 0.08 + cbar_dx, cbar_length, 0.03])
                    cbar_orient = "horizontal"
                fig.colorbar(im, cax=cax, orientation=cbar_orient)

        ## multi-channel images ##
        elif image.components > 1:
            raise ValueError("Multi-channel images not currently supported!")

        if filename is not None:
            filename = os.path.expanduser(filename)
            plt.savefig(filename, dpi=dpi, transparent=transparent)
            plt.close(fig)
        else:
            plt.show()
def plot(
    image,
    overlay=None,
    blend=False,
    alpha=1,
    cmap="Greys_r",
    overlay_cmap="turbo",
    overlay_alpha=0.9,
    vminol=None,
    vmaxol=None,
    cbar=False,
    cbar_length=0.8,
    cbar_dx=0.0,
    cbar_vertical=True,
    axis=0,
    nslices=12,
    slices=None,
    ncol=None,
    slice_buffer=None,
    black_bg=True,
    bg_thresh_quant=0.01,
    bg_val_quant=0.99,
    domain_image_map=None,
    crop=False,
    scale=False,
    reverse=False,
    title=None,
    title_fontsize=20,
    title_dx=0.0,
    title_dy=0.0,
    filename=None,
    dpi=500,
    figsize=1.5,
    reorient=True,
    resample=True,
):
    """
    Plot an ANTsImage.

    Use mask_image and/or threshold_image to preprocess images to be
    be overlaid and display the overlays in a given range. See the wiki examples.

    By default, images will be reoriented to 'LAI' orientation before plotting.
    So, if axis == 0, the images will be ordered from the
    left side of the brain to the right side of the brain. If axis == 1,
    the images will be ordered from the anterior (front) of the brain to
    the posterior (back) of the brain. And if axis == 2, the images will
    be ordered from the inferior (bottom) of the brain to the superior (top)
    of the brain.

    ANTsR function: `plot.antsImage`

    Arguments
    ---------
    image : ANTsImage or string
        image to plot. A string is read with `ants.image_read`.
        Must have a positive voxel sum.

    overlay : ANTsImage or string
        single-component image to overlay on base image. A string is read with
        `ants.image_read`. It is resampled with nearest-neighbour interpolation
        if its physical space differs from `image`.

    blend : boolean
        if true (and an overlay is given), the overlay is blended into the base
        image as image * alpha + overlay * (1 - alpha) instead of being drawn on
        top; an `alpha` of 1 is treated as 0.5.

    alpha : float
        transparency of the base image (2D and multi-channel plots) and the
        blend weight when `blend` is true

    cmap : string
        colormap to use for base image. See matplotlib.

    overlay_cmap : string
        colormap to use for overlay images, if applicable. See matplotlib.

    overlay_alpha : float
        level of transparency for any overlays. Smaller value means
        the overlay is more transparent. See matplotlib.

    vminol, vmaxol : float
        display range of the overlay; default to the overlay's min/max

    cbar, cbar_length, cbar_dx, cbar_vertical :
        whether to add a colorbar, and its length, shift and orientation.
        The colorbar describes the last real slice drawn (its overlay if present).

    axis : integer or string
        which axis to plot along if image is 3D. Also accepts "x"/"sagittal"
        (0), "y"/"coronal" (1) and "z"/"axial" (2); the historical spelling
        "saggittal" is still accepted.

    nslices : integer
        number of slices to plot if image is 3D

    slices : list or tuple of integers
        specific slice indices to plot if image is 3D.
        If given, this will override `nslices`.
        This can be absolute array indices (e.g. (80,100,120)), or
        this can be relative array indices (e.g. (0.4,0.5,0.6)); the latter is
        inferred when no entry is greater than 1.

    ncol : integer
        Number of columns to have on the plot if image is 3D.

    slice_buffer : integer or 2-tuple
        how many slices to buffer when finding the non-zero slices of
        a 3D images. So, if slice_buffer = 10, then the first slice
        in a 3D image will be the first non-zero slice index plus 10 more
        slices. Defaults to 10% of the in-plane array dimensions.

    black_bg : boolean
        if True, the background of the image(s) will be black.
        if False, the background of the image(s) will be determined by the
            values `bg_thresh_quant` and `bg_val_quant`.

    bg_thresh_quant : float
        if black_bg=False, voxels below the `bg_thresh_quant` intensity
        quantile are treated as background and set to the `bg_val_quant`
        quantile value. This value should be in [0, 1]
        - somewhere around 0.01 is recommended.
            - equal to 1 will threshold the entire image
            - equal to 0 will threshold none of the image

    bg_val_quant : float
        if black_bg=False, the quantile whose intensity is assigned to
        background voxels. This value should be in [0, 1]
            - equal to 1 is pure white
            - equal to 0 is pure black
            - somewhere in between is gray

    domain_image_map : ANTsImage or (ANTsImage, transform list)
        if an ANTsImage, the image(s) are mapped into its space with an
        identity affine transform before plotting. If a 2-element list/tuple,
        the first entry is the reference image and the second is passed as
        `transform_list` to `ants.apply_transforms`. Useful for non-standard
        image orientations.

    crop : boolean
        if true, the image(s) will be cropped to their bounding boxes, resulting
        in a potentially smaller image size.
        if false, the image(s) will not be cropped

    scale : boolean or 2-tuple
        if True (and the image is float/double and not RGB), the display range
        is set to the 0.05 and 0.95 intensity quantiles of the image;
        if a 2-tuple, the display range is set to those two quantiles;
        if False, matplotlib's default (full data range) is used.

    reverse : boolean
        if true, the order in which the slices are plotted will be reversed.
        This is useful if you want to plot from the front of the brain first
        to the back of the brain, or vice-versa

    title : string
        add a title to the plot

    title_fontsize, title_dx, title_dy :
        font size and offsets of the title

    filename : string
        if given, the resulting image will be saved to this file
        ('~' is expanded)

    dpi : integer
        determines resolution of image if saved to file. Higher values
        result in higher resolution images, but at a cost of having a
        larger file size

    figsize : float
        per-cell size factor of the 3D slice grid

    reorient : boolean
        if true, 3D images are reoriented to 'LAI' before plotting

    resample : bool
        if true, resample image if spacing is very unbalanced.

    Example
    -------
    >>> import ants
    >>> import numpy as np
    >>> img = ants.image_read(ants.get_data('r16'))
    >>> segs = img.kmeans_segmentation(k=3)['segmentation']
    >>> ants.plot(img, segs*(segs==1), crop=True)
    >>> ants.plot(img, segs*(segs==1), crop=False)
    >>> mni = ants.image_read(ants.get_data('mni'))
    >>> segs = mni.kmeans_segmentation(k=3)['segmentation']
    >>> ants.plot(mni, segs*(segs==1), crop=False)
    """
    # Accept axis names as well as integers. Unknown values pass through
    # unchanged. "saggittal" is kept for backward compatibility.
    axis_aliases = {
        "x": 0,
        "sagittal": 0,
        "saggittal": 0,
        "y": 1,
        "coronal": 1,
        "z": 2,
        "axial": 2,
    }
    if isinstance(axis, str):
        axis = axis_aliases.get(axis, axis)

    def mirror_matrix(x):
        return x[::-1, :]

    def rotate90_matrix(x):
        return x.T

    def rotate270_matrix(x):
        return mirror_matrix(x.T)

    def reorient_slice(x, axis):
        # transpose (axes 0/1) or rotate by 270 (axis 2), then flip vertically
        x = rotate90_matrix(x) if axis != 2 else rotate270_matrix(x)
        return mirror_matrix(x)

    # Matplotlib emits a spurious NaN warning with overlays. Silence warnings
    # only for this call; catch_warnings restores the caller's filters even if
    # plotting raises.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # handle `image` argument
        if isinstance(image, str):
            image = iio2.image_read(image)
        if not isinstance(image, iio.ANTsImage):
            raise ValueError("image argument must be an ANTsImage")
        # explicit raise (not `assert`) so the check survives `python -O`
        if not image.sum() > 0:
            raise AssertionError("Image must be non-zero")

        if (image.pixeltype not in {"float", "double"}) or (image.is_rgb):
            scale = False  # turn off scaling if image is discrete

        # handle `overlay` argument
        if overlay is not None:
            # read/validate first: a filename has no .min()/.max()
            if isinstance(overlay, str):
                overlay = iio2.image_read(overlay)
            if not isinstance(overlay, iio.ANTsImage):
                raise ValueError("overlay argument must be an ANTsImage")
            if overlay.components > 1:
                raise ValueError("overlay cannot have more than one voxel component")

            if vminol is None:
                vminol = overlay.min()
            if vmaxol is None:
                vmaxol = overlay.max()

            if not iio.image_physical_space_consistency(image, overlay):
                overlay = reg.resample_image_to_target(
                    overlay, image, interp_type="nearestNeighbor"
                )

        # blending needs an overlay; without one there is nothing to mix in
        if blend and overlay is not None:
            if alpha == 1:
                alpha = 0.5
            image = image * alpha + overlay * (1 - alpha)
            overlay = None
            alpha = 1.0

        # handle `domain_image_map` argument
        if domain_image_map is not None:
            if isinstance(domain_image_map, iio.ANTsImage):
                tx = tio2.new_ants_transform(
                    precision="float",
                    transform_type="AffineTransform",
                    dimension=image.dimension,
                )
                image = tio.apply_ants_transform_to_image(tx, image, domain_image_map)
                if overlay is not None:
                    overlay = tio.apply_ants_transform_to_image(
                        tx, overlay, domain_image_map, interpolation="nearestNeighbor"
                    )
            elif isinstance(domain_image_map, (list, tuple)):
                # expect an image and transformation
                if len(domain_image_map) != 2:
                    raise ValueError("domain_image_map list or tuple must have length == 2")

                dimg = domain_image_map[0]
                if not isinstance(dimg, iio.ANTsImage):
                    raise ValueError("domain_image_map first entry should be ANTsImage")

                tx = domain_image_map[1]
                image = reg.apply_transforms(dimg, image, transform_list=tx)
                if overlay is not None:
                    # nearest neighbour, like every other overlay resampling here,
                    # so label overlays keep integer values
                    overlay = reg.apply_transforms(
                        dimg, overlay, transform_list=tx, interpolator="nearestNeighbor"
                    )

        ## single-channel images ##
        if image.components == 1:

            # potentially crop image
            if crop:
                plotmask = image.get_mask(cleanup=0)
                if plotmask.max() == 0:
                    plotmask += 1
                image = image.crop_image(plotmask)
                if overlay is not None:
                    overlay = overlay.crop_image(plotmask)

            # potentially find dynamic range
            if scale == True:  # noqa: E712 -- scale may also be a tuple
                vmin, vmax = image.quantile((0.05, 0.95))
            elif isinstance(scale, (list, tuple)):
                if len(scale) != 2:
                    raise ValueError(
                        "scale argument must be boolean or list/tuple with two values"
                    )
                vmin, vmax = image.quantile(scale)
            else:
                vmin = None
                vmax = None

            # Plot 2D image
            if image.dimension == 2:

                img_arr = rotate90_matrix(image.numpy())

                if not black_bg:
                    img_arr[img_arr < image.quantile(bg_thresh_quant)] = image.quantile(
                        bg_val_quant
                    )

                if overlay is not None:
                    ov_arr = rotate90_matrix(overlay.numpy())
                    if ov_arr.dtype not in ["uint8", "uint32"]:
                        ov_arr = np.ma.masked_where(ov_arr == 0, ov_arr)

                fig = plt.figure()
                if title is not None:
                    fig.suptitle(
                        title, fontsize=title_fontsize, x=0.5 + title_dx, y=0.95 + title_dy
                    )

                ax = plt.subplot(111)

                # plot main image
                im = ax.imshow(img_arr, cmap=cmap, alpha=alpha, vmin=vmin, vmax=vmax)

                if overlay is not None:
                    im = ax.imshow(
                        ov_arr,
                        alpha=overlay_alpha,
                        cmap=overlay_cmap,
                        vmin=vminol,
                        vmax=vmaxol,
                    )

                if cbar:
                    cbar_orient = "vertical" if cbar_vertical else "horizontal"
                    fig.colorbar(im, orientation=cbar_orient)

                plt.axis("off")

            # Plot 3D image
            elif image.dimension == 3:
                # resample image if in-plane spacing is very unbalanced
                spacing = [s for i, s in enumerate(image.spacing) if i != axis]
                was_resampled = False
                if (max(spacing) / min(spacing)) > 3.0 and resample:
                    was_resampled = True
                    new_spacing = (1, 1, 1)
                    image = image.resample_image(new_spacing)
                    if overlay is not None:
                        overlay = overlay.resample_image(new_spacing)

                if reorient:
                    image = image.reorient_image2("LAI")
                # reorder dims so that chosen axis is first
                img_arr = np.rollaxis(image.numpy(), axis)

                if overlay is not None:
                    if reorient:
                        overlay = overlay.reorient_image2("LAI")
                    ov_arr = overlay.numpy()
                    if ov_arr.dtype not in ["uint8", "uint32"]:
                        ov_arr = np.ma.masked_where(ov_arr == 0, ov_arr)
                    ov_arr = np.rollaxis(ov_arr, axis)

                if slices is None:
                    if not isinstance(slice_buffer, (list, tuple)):
                        if slice_buffer is None:
                            slice_buffer = (
                                int(img_arr.shape[1] * 0.1),
                                int(img_arr.shape[2] * 0.1),
                            )
                        else:
                            slice_buffer = (slice_buffer, slice_buffer)
                    # spread `nslices` evenly over the non-empty slab, minus buffers
                    nonzero = np.where(img_arr.sum(axis=(1, 2)) > 0.01)[0]
                    min_idx = nonzero[0] + slice_buffer[0]
                    max_idx = nonzero[-1] - slice_buffer[1]
                    if min_idx > max_idx:
                        min_idx, max_idx = max_idx, min_idx
                    max_idx = min(max_idx, nonzero.max())
                    min_idx = max(min_idx, 0)
                    slice_idxs = np.linspace(min_idx, max_idx, nslices).astype("int")
                    if reverse:
                        slice_idxs = slice_idxs[::-1]
                else:
                    if isinstance(slices, (int, float)):
                        slices = [slices]
                    # if no slice exceeds 1, infer that they are relative slices
                    if not any(s > 1 for s in slices):
                        slices = [int(s * img_arr.shape[0]) for s in slices]
                    slice_idxs = slices
                    nslices = len(slices)

                    if was_resampled:
                        # NOTE(review): img_arr is already the resampled array, so
                        # this ratio is 1; the net effect is sorting and
                        # de-duplicating the indices.
                        slice_idxs = np.unique(
                            np.array(
                                [
                                    int(s * (image.shape[axis] / img_arr.shape[0]))
                                    for s in slice_idxs
                                ]
                            )
                        )
                        # de-duplication may drop slices; don't leave blank cells
                        nslices = len(slice_idxs)

                # only have one row if nslices <= 6 and user didn't specify ncol
                if ncol is None:
                    ncol = nslices if nslices <= 6 else int(round(math.sqrt(nslices)))

                # calculate grid size
                nrow = math.ceil(nslices / ncol)
                xdim = img_arr.shape[2]
                ydim = img_arr.shape[1]

                dim_ratio = ydim / xdim
                fig = plt.figure(
                    figsize=((ncol + 1) * figsize * dim_ratio, (nrow + 1) * figsize)
                )
                if title is not None:
                    fig.suptitle(
                        title, fontsize=title_fontsize, x=0.5 + title_dx, y=0.95 + title_dy
                    )

                gs = gridspec.GridSpec(
                    nrow,
                    ncol,
                    wspace=0.0,
                    hspace=0.0,
                    top=1.0 - 0.5 / (nrow + 1),
                    bottom=0.5 / (nrow + 1),
                    left=0.5 / (ncol + 1),
                    right=1 - 0.5 / (ncol + 1),
                )

                # background replacement values are loop-invariant
                if not black_bg:
                    bg_thresh = image.quantile(bg_thresh_quant)
                    bg_val = image.quantile(bg_val_quant)

                # last image drawn for a REAL slice; padding cells never feed the
                # colorbar (they would give it an all-zero range)
                cbar_source = None
                n_real = len(slice_idxs)
                cell = 0
                for i in range(nrow):
                    for j in range(ncol):
                        ax = plt.subplot(gs[i, j])
                        if cell < n_real:
                            idx = slice_idxs[cell]
                            imslice = reorient_slice(img_arr[idx], axis)
                            if not black_bg:
                                imslice[imslice < bg_thresh] = bg_val
                            cbar_source = ax.imshow(imslice, cmap=cmap, vmin=vmin, vmax=vmax)
                            if overlay is not None:
                                ovslice = reorient_slice(ov_arr[idx], axis)
                                cbar_source = ax.imshow(
                                    ovslice,
                                    alpha=overlay_alpha,
                                    cmap=overlay_cmap,
                                    vmin=vminol,
                                    vmax=vmaxol,
                                )
                        else:
                            # padding cell past the last requested slice
                            blank = reorient_slice(np.zeros_like(img_arr[0]), axis)
                            ax.imshow(blank, cmap=cmap, vmin=vmin, vmax=vmax)
                        ax.axis("off")
                        cell += 1

                if cbar and cbar_source is not None:
                    cbar_start = (1 - cbar_length) / 2
                    if cbar_vertical:
                        cax = fig.add_axes([0.9 + cbar_dx, cbar_start, 0.03, cbar_length])
                        cbar_orient = "vertical"
                    else:
                        cax = fig.add_axes([cbar_start, 0.08 + cbar_dx, cbar_length, 0.03])
                        cbar_orient = "horizontal"
                    fig.colorbar(cbar_source, cax=cax, orientation=cbar_orient)

        ## multi-channel images ##
        elif image.components > 1:
            if not image.is_rgb and image.components != 3:
                raise ValueError(
                    "Multi-component images only supported if they have 3 components"
                )

            # scale by the maximum and rotate each channel like the 2D case.
            # NOTE(review): indexing assumes a 2D image with channels last.
            img_arr = image.numpy()
            img_arr = img_arr / img_arr.max()
            img_arr = np.stack(
                [rotate90_matrix(img_arr[:, :, i]) for i in range(3)], axis=-1
            )

            fig = plt.figure()
            ax = plt.subplot(111)

            # plot main image
            ax.imshow(img_arr, alpha=alpha)

            plt.axis("off")

        if filename is not None:
            filename = os.path.expanduser(filename)
            plt.savefig(filename, dpi=dpi, transparent=True, bbox_inches="tight")
            plt.close(fig)
        else:
            plt.show()
# File suffixes that `plot_directory` treats as plottable images.
_PLOT_DIRECTORY_SUFFIXES = (".nii.gz",)


def _plot_save_name(fname, save_prefix="", save_suffix=""):
    """Return the PNG basename used for the plot of image file `fname`.

    The recognised image suffix is replaced by '.png' and `save_prefix` /
    `save_suffix` are wrapped around the stem, so dotted stems survive intact:
    'sub.01.nii.gz' -> 'sub.01.png' (not 'sub.png').
    """
    for suffix in _PLOT_DIRECTORY_SUFFIXES:
        if fname.endswith(suffix):
            fname = fname[: -len(suffix)]
            break
    return "%s%s%s.png" % (save_prefix, fname, save_suffix)


def plot_directory(
    directory,
    recursive=False,
    regex="*",
    save_prefix="",
    save_suffix="",
    axis=None,
    **kwargs
):
    """
    Create and save an ANTsPy plot for every image matching a given
    filename pattern in a directory, optionally recursively. This is a
    good function for quick visual exploration of all the images in
    a directory.

    ANTsR function: N/A

    Arguments
    ---------
    directory : string
        directory in which to search for images and plot them ('~' is expanded)

    recursive : boolean
        If true, this function will search through all directories under
        the given directory recursively to make plots.
        If false, this function will only create plots for images in the
        given directory

    regex : string
        shell-style wildcard pattern (matched with `fnmatch`, not a regular
        expression) used to filter filenames. Only '.nii.gz' files are plotted.

    save_prefix : string
        sub-string that will be appended to the beginning of all saved plot
        filenames. Default is to add nothing.

    save_suffix : string
        sub-string that will be appended to the end of all saved plot
        filenames (before '.png'). Default is add nothing.

    axis : integer, list/tuple of integers, or None
        axes to plot for images with more than 2 dimensions; one file
        '<name>_axis<k>.png' is written per axis. Default is every axis.

    kwargs : keyword arguments
        any additional arguments to pass onto the `ants.plot` function.
        e.g. overlay, alpha, cmap, etc. See `ants.plot` for more options.

    Raises
    ------
    ValueError
        if `directory` does not exist

    Example
    -------
    >>> import ants
    >>> ants.plot_directory(directory='~/desktop/testdir',
                            recursive=False, regex='*')
    """
    directory = os.path.expanduser(directory)
    if not os.path.isdir(directory):
        raise ValueError("directory %s does not exist!" % directory)

    for root, _, fnames in os.walk(directory):
        for fname in fnames:
            if not (
                fnmatch.fnmatch(fname, regex) and fname.endswith(_PLOT_DIRECTORY_SUFFIXES)
            ):
                continue

            load_fname = os.path.join(root, fname)
            save_fname = os.path.join(root, _plot_save_name(fname, save_prefix, save_suffix))

            img = iio2.image_read(load_fname)

            if axis is None:
                axis_range = range(img.dimension)
            else:
                axis_range = axis if isinstance(axis, (list, tuple)) else [axis]

            if img.dimension > 2:
                base, ext = os.path.splitext(save_fname)
                for axis_idx in axis_range:
                    filename = "%s_axis%i%s" % (base, axis_idx, ext)
                    # show every slice along this axis in a roughly square grid
                    ncol = int(math.sqrt(img.shape[axis_idx]))
                    plot(
                        img,
                        axis=axis_idx,
                        nslices=img.shape[axis_idx],
                        ncol=ncol,
                        filename=filename,
                        **kwargs
                    )
            else:
                plot(img, filename=save_fname, **kwargs)

        # os.walk yields the top directory first; stop there unless recursing
        if not recursive:
            break