Skip to content

Instantly share code, notes, and snippets.

Last active March 31, 2023 10:08
Show Gist options
  • Save dgobbi/bddbbabb1a9c86e8d8373752859dac9f to your computer and use it in GitHub Desktop.
Save dgobbi/bddbbabb1a9c86e8d8373752859dac9f to your computer and use it in GitHub Desktop.
Generate many nifti files corresponding to transformations of one nifti file.
Sep 25, 2018
David Gobbi
import vtk
import sys
import os.path
import argparse
import math
# One-line summary, passed to argparse as the program description.
brief = "Apply multiple transformations to a NIfTI file."

# Long-form usage notes, passed to argparse as the help epilog.
helptext = """
This program will take an image as input and will generate multiple
output files, each with a different transformation applied.
You must supply an existing output directory for the files to be
written to. Each output file will be named "<file>_NNNN.nii.gz"
where <file> is the name of the input, and "NNNN" is four digits.
You can specify the number of distinct transformations as
-R N (will produce 3*NxNxN transforms, for the 3 angles)
-S M (will produce M transforms, since only 1 scale param is used)
-T L (will produce LxL transforms, since 2 translation params are used)
The total number of outputs is therefore 3*N*N*N*M*L*L.
Translations are only done in the row and column directions, not the slice
direction. The largest applied translation will be 15% of the image size.
The largest rotation that will be applied is 45 degrees. There will not
be any outputs that correspond to a rotation of zero degrees, unless you
specify N=0.
The scale factor will be a minimum of 1.0 and a maximum of 1.5. No scales
less than 1.0 are used since shrinking an image can lead to aliasing.
Note that there are two ways to specify input images: -i and -l.
For "-i", the image is assumed to be greyscale and linear interpolation
is used. For "-l", the image is assumed to be binary or indexed, and
nearest-neighbor interpolation is used.
"""
def linscale(scale_range, n):
    """Generate a linear scale of 'n' values within the given range.

    scale_range is indexed as (low, high).  For n >= 2 the result has
    'n' evenly spaced values that start at 'low' and end at 'high'.
    For n <= 1 there is nothing to spread out, so the result is a
    one-element list that holds the midpoint of the range.
    """
    low = scale_range[0]
    high = scale_range[1]
    if n <= 1:
        return [0.5*(low + high)]
    # float divisor, so the weights are fractional even for integer input
    d = float(n - 1)
    return [(n - i - 1)/d*low + i/d*high for i in range(n)]
def combine(*scales):
    """Create a grid from the input scales (inefficiently).

    Each argument is a sequence of values for one parameter.  The result
    is a list with one entry per combination, where each entry is a new
    list that holds one value from every scale, in argument order.  The
    first scale varies fastest and the last scale varies slowest.  With
    no arguments the result is an empty list.
    """
    if not scales:
        return []
    result = [[y] for y in scales[0]]
    for scale in scales[1:]:
        # the outer loop is over the new scale, so that the earlier scales
        # keep varying fastest; 'x + [y]' makes a fresh list every time
        result = [x + [y] for y in scale for x in result]
    return result
def generate_rotations(angle_range, n):
    """Generate approximately regularly spaced rotations by using
    cube geometry instead of sphere geometry.

    Each rotation is a list [angle, x, y, z], where (x, y, z) is the axis
    of rotation, normalized to unit length.  For n >= 1 the result has
    3*n*n*n entries: 'n' angles for each of n*n axes through each of three
    faces of a cube.  For n <= 0 the result is a single rotation with an
    angle of zero.
    """
    if n <= 0:
        return [[0.0, 1.0, 0.0, 0.0]]

    # use half a cube to generate axes of rotation
    # (because a cube is the easiest platonic solid to work from)

    # here are the angles (except the first, assumed to be zero)
    scale0 = linscale(angle_range, n + 1)[1:]
    # the scale along one edge of a cube
    scale1 = linscale([-1.0, 1.0], n)

    # rotations corresponding to axes intersecting 3 of the cube faces
    # (for n >= 2 the scale includes -1 and +1, so axes that lie on an
    # edge shared by two faces are generated once for each face)
    result = combine(scale0, [1.0], scale1, scale1)  # cube face 1
    result += combine(scale0, scale1, [1.0], scale1)  # cube face 2
    result += combine(scale0, scale1, scale1, [1.0])  # cube face 3

    # normalize the axis of rotation (one component is always 1.0,
    # so the length is never zero)
    for w in result:
        r = math.sqrt(w[1]**2 + w[2]**2 + w[3]**2)
        w[1] /= r
        w[2] /= r
        w[3] /= r

    return result
def generate_scales(scale_range, n):
    """Use a linear scale for scale parameter.

    Returns the list of 'n' scale factors produced by linscale() for
    the given range.
    """
    return linscale(scale_range, n)
def generate_translations(trans_ranges, n):
    """Translate in all three directions.

    trans_ranges holds one (low, high) range per translated axis, and 'n'
    is the number of steps along each of those axes.  The result is the
    grid of all offset combinations, as produced by combine().  If fewer
    than three ranges are given, the remaining axes get a single offset
    of zero, so that every entry has three components.
    """
    scales = [linscale(rng, n) for rng in trans_ranges]
    # pad the axes that have no range with a fixed offset of zero
    while len(scales) < 3:
        scales.append([0.0])
    return combine(*scales)
def build_transform(rotation, scale, translation, center):
    """Build a vtkTransform from a set of parameters.

    rotation is [angle, x, y, z] (angle in degrees about the axis x, y, z),
    scale is a single factor that is applied to all three axes,
    translation is a three-component offset, and center is the point that
    the rotation and the scale are applied about.
    """
    transform = vtk.vtkTransform()
    # apply the operations in the order in which they are listed below
    transform.PostMultiply()
    # move the center to the origin, rotate and scale about it, move back
    transform.Translate([-x for x in center])
    transform.RotateWXYZ(*rotation)
    transform.Scale(scale, scale, scale)
    transform.Translate(center)
    # finally, apply the translation
    transform.Translate(translation)
    return transform
def process_one_output(output_file, header, sform, qform, image, transform,
                       is_label):
    """Write one transformed image file.

    The image is resampled through the transform and is written to
    output_file as NIfTI, with the given header, sform, and qform.
    If is_label is false, linear interpolation is used.  If is_label
    is true, the interpolation mode of the reslice filter is left at
    its default, so that label values are not blended.
    """
    # resample the image through a transform (vtkImageSincInterpolator
    # is a slower, higher-quality alternative to the built-in modes)
    reslice = vtk.vtkImageReslice()
    reslice.SetInputData(image)
    # the reslice transform is applied to the output sampling grid, so
    # the inverse is needed in order to move the image by 'transform'
    reslice.SetResliceTransform(transform.GetInverse())
    if not is_label:
        reslice.SetInterpolationModeToLinear()
    reslice.Update()

    # write the image with the same sform, qform, and header as the input
    writer = vtk.vtkNIFTIImageWriter()
    writer.SetInputData(reslice.GetOutput())
    writer.SetFileName(output_file)
    writer.SetNIFTIHeader(header)
    writer.SetSFormMatrix(sform)
    writer.SetQFormMatrix(qform)
    writer.Write()
def process_one_input(input_file, args, is_label):
    """Augment one input image file.

    Reads input_file, generates every combination of rotation, scale,
    and translation requested by args.rotations, args.scales, and
    args.translations, and writes one numbered output file per
    combination into the directory args.output.  Progress is written
    to stdout unless args.silent is set.  is_label is passed through
    to process_one_output().
    """
    # read the input file
    reader = vtk.vtkNIFTIImageReader()
    reader.SetFileName(input_file)
    reader.Update()

    # get the header, the sform, the qform, and the image data
    header = reader.GetNIFTIHeader()
    sform = reader.GetSFormMatrix()
    qform = reader.GetQFormMatrix()
    image = reader.GetOutput()

    center = image.GetCenter()
    spacing = image.GetSpacing()
    shape = image.GetDimensions()

    f = 0.15  # translation is a fraction of the image size
    trans_ranges = [[-a*b*f, a*b*f] for a, b in zip(shape, spacing)]
    rotation_range = [0.0, 45.0]
    scale_range = [1.0, 1.5]

    # only translate in x, y (assume the algorithms that use this data
    # will be processing it slice-by-slice, not volumetrically)
    trans_ranges = trans_ranges[0:-1]

    # generate all of the transformational parameters
    params = combine(generate_rotations(rotation_range, args.rotations),
                     generate_scales(scale_range, args.scales),
                     generate_translations(trans_ranges, args.translations))

    if not args.silent:
        sys.stdout.write("Generating %d outputs per input " % len(params))
        sys.stdout.write("(each \".\" is one output).\n")

    # get prefix from input file
    base = os.path.basename(input_file)
    ext = '.nii.gz'
    for e in ['.nii.gz', '.nii']:
        if base.endswith(e):
            base = base[0:-len(e)]
            ext = e

    # go through all the parameters
    counter = 0
    for (rotation, scale, translation) in params:
        if not args.silent:
            sys.stdout.write(".")
            sys.stdout.flush()

        # build the transform from the parameters
        transform = build_transform(rotation, scale, translation, center)

        # create one output
        counter += 1
        basename = base + ('_%04d' % counter) + ext
        output_file = os.path.join(args.output, basename)
        process_one_output(output_file, header, sform, qform, image, transform,
                           is_label)

        # start a new line of progress dots after every 50 outputs
        if not args.silent and counter % 50 == 0:
            sys.stdout.write("\n")

    # finish the last line of progress dots, if it is incomplete
    if not args.silent and counter % 50 != 0:
        sys.stdout.write("\n")
def main(argv):
    """The main program.

    argv is the full argument list, including the program name as its
    first element.  Usage errors (bad options, no input image, or an
    output directory that does not exist) are reported through argparse,
    which prints a message and raises SystemExit.
    """
    # parse the command line
    parser = argparse.ArgumentParser(
        description=brief, epilog=helptext)
    parser.add_argument('-i', '--input', required=False,
                        help="Input greyscale image.")
    parser.add_argument('-l', '--label', required=False,
                        help="Input label image.")
    parser.add_argument('-o', '--output', required=True,
                        help="Output directory.")
    parser.add_argument('-R', '--rotations', type=int, default=3,
                        help="Steps per rotation degree of freedom.")
    parser.add_argument('-S', '--scales', type=int, default=2,
                        help="Steps per scale degree of freedom.")
    parser.add_argument('-T', '--translations', type=int, default=1,
                        help="Steps per translation degree of freedom.")
    parser.add_argument('-s', '--silent', action='count',
                        help="Do not print progress information.")
    args = parser.parse_args(argv[1:])

    # validate some parameters
    args.rotations = max(0, args.rotations)
    args.scales = max(1, args.scales)
    args.translations = max(1, args.translations)

    # without an input there is nothing to do, so say so instead of
    # exiting silently
    if not args.input and not args.label:
        parser.error("at least one of -i/--input or -l/--label is required")

    # the output directory is never created by this program
    if not os.path.isdir(args.output):
        parser.error("output directory does not exist: %s" % args.output)

    # augment all the input files
    if args.input:
        process_one_input(args.input, args, is_label=False)
    if args.label:
        process_one_input(args.label, args, is_label=True)
# run the program only when this file is executed as a script
if __name__ == '__main__':
    sys.exit(main(sys.argv))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment