Skip to content

Instantly share code, notes, and snippets.

@dgobbi
Last active March 31, 2023 10:08
Show Gist options
  • Save dgobbi/bddbbabb1a9c86e8d8373752859dac9f to your computer and use it in GitHub Desktop.
Save dgobbi/bddbbabb1a9c86e8d8373752859dac9f to your computer and use it in GitHub Desktop.
Generate many nifti files corresponding to transformations of one nifti file.
"""
Program vtk_augment_nifti.py
Sep 25, 2018
David Gobbi
dgobbi@ucalgary.ca
"""
import vtk
import sys
import os.path
import argparse
import math
# one-line description shown at the top of the --help output
brief = "Apply multiple transformations to a NIfTI file."
# long description shown (verbatim) at the bottom of the --help output
helptext = """
This program will take an image as input and will generate multiple
output files, each with a different transformation applied.
You must supply an existing output directory for the files to be
written to. Each output file will be named "<file>_NNNN.nii.gz"
where <file> is the name of the input, and "NNNN" is four digits.
You can specify the number of distinct transformations as
follows:
-R N (will produce 3*NxNxN transforms, for the 3 angles)
-S M (will produce M transforms, since only 1 scale param is used)
-T L (will produce LxL transforms, since 2 translation params are used)
The total number of outputs is therefore 3*N*N*N*M*L*L.
Translations are only done in the row and column directions, not the slice
direction. The largest applied translation will be 15% of the image size.
The largest rotation that will be applied is 45 degrees. There will not
be any outputs that correspond to a rotation of zero degrees, unless you
specify N=0.
The scale factor will be a minimum of 1.0 and a maximum of 1.5. No scales
less than 1.0 are used since shrinking an image can lead to aliasing
artifacts.
Note that there are two ways to specify input images: -i and -l.
For "-i", the image is assumed to be greyscale and linear interpolation
is used. For "-l", the image is assumed to be binary or indexed, and
nearest-neighbor interpolation is used.
"""
def linscale(scale_range, n):
    """Return 'n' evenly spaced values that span 'scale_range'.

    The first value is scale_range[0] and the last is scale_range[1].
    If 'n' is one or less, a single value (the midpoint of the range)
    is returned instead.
    """
    lo, hi = scale_range[0], scale_range[1]
    if n <= 1:
        return [0.5*(lo + hi)]
    # number of intervals between the 'n' values
    steps = float(n - 1)
    # blend the two end points, with the weight moving from 'lo' to 'hi'
    return [(n-k-1)/steps*lo + k/steps*hi for k in range(n)]
def combine(*scales):
    """Build the grid (cartesian product) of the given scales.

    Each grid point is a new list that holds one value from every scale,
    in argument order.  The first scale varies fastest and the last scale
    varies slowest.  An empty list is returned if no scales are given.
    """
    if not scales:
        return []
    grid = [[value] for value in scales[0]]
    for axis in scales[1:]:
        # extend every partial point with each value of the next scale
        grid = [point + [value] for value in axis for point in grid]
    return grid
def generate_rotations(angle_range, n):
    """Generate roughly evenly spaced rotations from cube geometry.

    Every rotation is a list [angle, x, y, z]: an angle taken from
    'angle_range' followed by a unit-length axis of rotation.  The axes
    pass through an n-by-n grid of points on each of three faces of a
    cube, and 'n' angles are used per axis, for 3*n*n*n rotations.
    If 'n' is zero or less, a single zero-angle rotation is returned.
    """
    if n <= 0:
        return [[0.0, 1.0, 0.0, 0.0]]

    # A cube is the easiest platonic solid to work from, and half of it
    # (three faces) is used to supply the axes of rotation.

    # the angles: n + 1 steps over the range, minus the first step
    # (the start of the range, assumed to be zero)
    angles = linscale(angle_range, n + 1)[1:]
    # the positions along one edge of the cube
    edge = linscale([-1.0, 1.0], n)

    # one grid of axes for each of the three cube faces
    rotations = (combine(angles, [1.0], edge, edge) +
                 combine(angles, edge, [1.0], edge) +
                 combine(angles, edge, edge, [1.0]))

    # rescale every axis of rotation to unit length
    for rot in rotations:
        norm = math.sqrt(rot[1]**2 + rot[2]**2 + rot[3]**2)
        for k in (1, 2, 3):
            rot[k] /= norm
    return rotations
def generate_scales(scale_range, n):
    """Return 'n' scale factors spaced linearly over 'scale_range'.

    This is the same as linscale(), so one or fewer steps gives just
    the midpoint of the range.
    """
    factors = linscale(scale_range, n)
    return factors
def generate_translations(trans_ranges, n):
    """Generate a grid of [x, y, z] translations.

    'trans_ranges' holds one [min, max] range for each direction that is
    to be translated, and 'n' is the number of steps per direction.  If
    fewer than three ranges are given, the remaining directions get a
    translation of zero, i.e. the image is not moved in those directions.
    """
    scales = [linscale(rng, n) for rng in trans_ranges]
    # pad with zero (no translation); any other value would shift every
    # output by a constant amount along the missing directions
    while len(scales) < 3:
        scales.append([0.0])
    return combine(*scales)
def build_transform(rotation, scale, translation, center):
    """Create the vtkTransform described by a set of parameters.

    'rotation' is [angle, x, y, z] as passed to RotateWXYZ(), 'scale' is
    a single factor that is used for all three axes, and 'translation'
    and 'center' are 3-vectors.  The rotation and the scale are done
    about 'center', and the translation is added at the end.
    """
    to_origin = [-c for c in center]

    xform = vtk.vtkTransform()
    # with PostMultiply, the operations are applied in the order listed
    xform.PostMultiply()
    # move the center to the origin, rotate and scale, then move it back
    xform.Translate(to_origin)
    xform.RotateWXYZ(*rotation)
    xform.Scale(scale, scale, scale)
    xform.Translate(center)
    # finally, apply the requested shift
    xform.Translate(translation)
    return xform
def process_one_output(output_file, header, sform, qform, image, transform,
                       is_label):
    """Resample 'image' through 'transform' and write it to 'output_file'.

    The file is written with the given NIfTI header, sform and qform.
    Linear interpolation is requested unless 'is_label' is set, in which
    case the interpolation mode of the reslice filter is left unchanged.
    """
    # (a vtkImageSincInterpolator with a Blackman window would be a
    # slower, higher-quality alternative to the interpolation used here)

    # set up the resampling of the image
    resampler = vtk.vtkImageReslice()
    resampler.SetNumberOfThreads(1)
    resampler.SetInputData(image)
    if not is_label:
        resampler.SetInterpolationModeToLinear()
    resampler.TransformInputSamplingOff()
    # the filter is given the inverse of the transform to be applied
    resampler.SetResliceTransform(transform.GetInverse())
    resampler.Update()
    resampled = resampler.GetOutput()

    # the output gets the same header, sform, and qform as the input
    writer = vtk.vtkNIFTIImageWriter()
    writer.SetInputData(resampled)
    writer.SetFileName(output_file)
    writer.SetNIFTIHeader(header)
    writer.SetSFormMatrix(sform)
    writer.SetQFormMatrix(qform)
    writer.Write()
def process_one_input(input_file, args, is_label):
    """Write all of the augmented copies of one input image file.

    The image is read from 'input_file', and one output is written into
    the directory 'args.output' for every combination of the rotations,
    scales and translations requested by 'args'.  The outputs are named
    after the input, with a four-digit counter (starting at 0001) placed
    before the extension.  Progress is printed unless 'args.silent' is set.
    """
    verbose = not args.silent

    # load the image, and keep its header, sform and qform for the outputs
    reader = vtk.vtkNIFTIImageReader()
    reader.SetFileName(input_file)
    reader.Update()
    header = reader.GetNIFTIHeader()
    sform = reader.GetSFormMatrix()
    qform = reader.GetQFormMatrix()
    image = reader.GetOutput()
    center = image.GetCenter()

    # the translation range is a fraction of the image size
    fraction = 0.15
    dims = image.GetDimensions()
    spacing = image.GetSpacing()
    trans_ranges = [[-n*d*fraction, n*d*fraction]
                    for n, d in zip(dims, spacing)]
    # drop the last direction so that the translations are in x, y only
    # (assume the algorithms that use this data will be processing it
    # slice-by-slice, not volumetrically)
    trans_ranges = trans_ranges[:-1]
    rotation_range = [0.0, 45.0]
    scale_range = [1.0, 1.5]

    # every combination of rotation, scale and translation
    params = combine(
        generate_rotations(rotation_range, args.rotations),
        generate_scales(scale_range, args.scales),
        generate_translations(trans_ranges, args.translations))

    if verbose:
        sys.stdout.write("Generating %d outputs per input " % len(params))
        sys.stdout.write('(each "." is one output).\n')
        sys.stdout.flush()

    # split the input name into a prefix and a NIfTI extension
    stem = os.path.basename(input_file)
    suffix = '.nii.gz'
    for candidate in ('.nii.gz', '.nii'):
        if stem.endswith(candidate):
            stem = stem[:-len(candidate)]
            suffix = candidate
            break

    # write one output per set of parameters
    counter = 0
    for counter, (rotation, scale, translation) in enumerate(params, 1):
        if verbose:
            sys.stdout.write(".")
            sys.stdout.flush()
        transform = build_transform(rotation, scale, translation, center)
        output_file = os.path.join(args.output,
                                   '%s_%04d%s' % (stem, counter, suffix))
        process_one_output(output_file, header, sform, qform, image,
                           transform, is_label)
        # start a new line of progress dots after every 50 outputs
        if verbose and counter % 50 == 0:
            sys.stdout.write("\n")

    if verbose:
        # finish the last line of dots, unless it was just finished
        if counter % 50 != 0:
            sys.stdout.write("\n")
        sys.stdout.flush()
def main(argv):
    """The main program.

    'argv' is the full argument list, where argv[0] is the program name.
    The greyscale input (-i) and the label input (-l) are augmented in
    turn, if given.  The program exits with a usage error if neither
    input is given, or if the output directory does not exist.
    """
    # parse the command line
    parser = argparse.ArgumentParser(
        prog=argv[0],
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=brief, epilog=helptext)
    parser.add_argument('-i', '--input', required=False,
                        help="Input greyscale image.")
    parser.add_argument('-l', '--label', required=False,
                        help="Input label image.")
    parser.add_argument('-o', '--output', required=True,
                        help="Output directory.")
    parser.add_argument('-R', '--rotations', type=int, default=3,
                        help="Steps per rotation degree of freedom.")
    parser.add_argument('-S', '--scales', type=int, default=2,
                        help="Steps per scale degree of freedom.")
    parser.add_argument('-T', '--translations', type=int, default=1,
                        help="Steps per translation degree of freedom.")
    parser.add_argument('-s', '--silent', action='count',
                        help="Do not print progress information.")
    args = parser.parse_args(argv[1:])

    # without an input there is nothing to do, so report it instead of
    # silently producing no output
    if not (args.input or args.label):
        parser.error("at least one of -i/--input or -l/--label is required")
    # the output directory is not created by this program, so check for
    # it now rather than failing on the first output file
    if not os.path.isdir(args.output):
        parser.error("output directory does not exist: %s" % args.output)

    # clamp the step counts to their smallest meaningful values
    args.rotations = max(0, args.rotations)
    args.scales = max(1, args.scales)
    args.translations = max(1, args.translations)

    # augment all the input files
    if args.input:
        process_one_input(args.input, args, is_label=False)
    if args.label:
        process_one_input(args.label, args, is_label=True)
# run the program when this file is executed as a script
if __name__ == "__main__":
    main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment