Skip to content

Instantly share code, notes, and snippets.

@Lestropie
Last active April 7, 2023 16:26
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Lestropie/16342f43e93db042d2693d4191802b65 to your computer and use it in GitHub Desktop.
MRtrix3 stand-alone Python script for generating connectivity fingerprints within a ROI
#!/usr/bin/env python
# Works using MRtrix version 3.0.0 or later
# For linking against MRtrix3 Python API:
# https://mrtrix.readthedocs.io/en/3.0.0/tips_and_tricks/external_modules.html
# Typical pre-processing required before running this script:
# 1. FOD: Standard DWI pre-processing
# 2. 5TT: e.g. 5ttgen fsl
# 3. Parcellation:
# 3.1. Run FreeSurfer
# 3.2. labelconvert - Map large indices provided in FreeSurfer image output to indices that increase from 1
# 3.3. labelsgmfix - To improve parcellation of sub-cortical grey matter structures
# 4. Mask:
# 4.1. FSL FIRST (run_first_all) - segment structure(s) of interest only
# 4.2. meshconvert - Convert .vtk mesh file from FIRST convention to realspace coordinates
# 4.3. mesh2pve - Partial volume fractions for each voxel in a template image that lie within the mesh
# 4.4. mrthreshold -abs 0.5 - Convert partial volume fraction image to a mask
import json, math, os
from mrtrix3 import COMMAND_HISTORY_STRING, MRtrixError
from mrtrix3 import app, image, matrix, path, run
def usage(cmdline): #pylint: disable=unused-variable
    """Register the script's author, synopsis and command-line interface."""
    cmdline.set_author('Robert E. Smith (robert.smith@florey.edu.au)')
    cmdline.set_synopsis('Generate, for every voxel in a mask, a connectivity fingerprint to a parcellation image')
    # Arguments registered in the exact order they should appear in the help page:
    # four compulsory inputs, two compulsory outputs, then the two options
    arguments = (
        ('input_fod', 'The input FODs'),
        ('input_5tt', 'The input 5TT image for ACT'),
        ('input_parc', 'The input parcellation image'),
        ('input_mask', 'The mask of the structure to be parcellated'),
        ('output_voxels', 'The list of voxels for which fingerprints were successfully generated'),
        ('output_data', 'The connectivity fingerprint of each successfully-processed voxel'),
        ('-tckgen_options', 'Options to pass to the tckgen command (remember to wrap in quotation marks)'),
        ('-grad_image', 'Export an image representing spatial gradients in connectivity fingerprints'))
    for argument_name, description in arguments:
        cmdline.add_argument(argument_name, help=description)
def execute(): #pylint: disable=unused-variable
    """Generate a connectivity fingerprint for every voxel within a mask.

    For each voxel in the input mask, streamlines are seeded from that voxel
    alone (tckgen with ACT), then mapped against the parcellation image
    (tck2connectome -vector), yielding one vector of per-node streamline
    counts (the "fingerprint") per voxel. Voxels for which the requested
    number of streamlines could not be generated are counted, reported, and
    omitted from the outputs. Optionally, a 13-volume image of spatial
    gradients in connectivity fingerprint (1 - cosine similarity, per mm,
    along 13 unique voxel offsets) is also produced.
    """
    # Sanity-check the inputs before doing any processing
    stderr = run.command(['5ttcheck', path.from_user(app.ARGS.input_5tt, False)]).stderr
    if '[WARNING]' in stderr:
        app.warn('Generated image does not perfectly conform to 5TT format:')
        for line in stderr.splitlines():
            app.warn(line)
    if len(image.Header(path.from_user(app.ARGS.input_fod, False)).size()) != 4:
        raise MRtrixError('Input FOD image must be a 4D image')
    app.check_output_path(path.from_user(app.ARGS.output_voxels, False))
    app.check_output_path(path.from_user(app.ARGS.output_data, False))
    if app.ARGS.grad_image:
        app.check_output_path(path.from_user(app.ARGS.grad_image, False))

    # Streamline generation options; a user-supplied replacement must still
    # include -select, since the target streamline count is needed later to
    # detect voxels in which tracking aborted early
    tckgen_options = '-backtrack -crop_at_gmwmi -select 5k -seeds 10M'
    if app.ARGS.tckgen_options:
        tckgen_options = app.ARGS.tckgen_options
    tckgen_options_split = tckgen_options.split(' ')
    if not '-select' in tckgen_options_split:
        raise MRtrixError('Contents of -tckgen_options must at least include the -select option')
    # The -select value may include a postfix multiplier (k / m / b)
    target_count = tckgen_options_split[tckgen_options_split.index('-select')+1]
    if target_count[-1].isalpha():
        multiplier = {'k': 1000, 'm': 1000000, 'b': 1000000000}.get(target_count[-1].lower())
        if multiplier is None:
            raise MRtrixError('Could not convert -select field \'' + target_count + '\' to an integer')
        target_count = int(target_count[:-1]) * multiplier
    else:
        target_count = int(target_count)

    app.make_scratch_dir()
    # Only the mask image is copied to the scratch directory for convenience;
    # the other inputs can be used in-place
    run.command('mrconvert ' + path.from_user(app.ARGS.input_mask) + ' ' + path.to_scratch('mask.mif') + ' -datatype bit')
    app.goto_scratch_dir()
    mask_size = image.Header('mask.mif').size()
    if not (len(mask_size) == 3 or (len(mask_size) == 4 and mask_size[3] == 1)):
        raise MRtrixError('Input mask image must be a 3D image')
    # maskdump produces one [x y z] voxel-index triplet per line;
    # keep the entries as strings for later command-line construction
    run.command('maskdump mask.mif input_voxels.txt')
    with open('input_voxels.txt', 'r') as voxels_file:
        input_voxels = [ _.split() for _ in voxels_file.read().splitlines() if _ and not _.lstrip().startswith('#') ]

    # Images quantifying, per voxel:
    # - The number of streamline attempts
    # - The number of streamlines generated
    # - The number of streamlines successfully reaching a target node
    # Thresholding above the maximal mask intensity yields an empty image of
    # matching dimensions with which to initialise these counts
    run.command('mrthreshold mask.mif empty_mask.mif -abs 1.5')
    run.command('mrconvert empty_mask.mif num_attempts.mif -datatype uint32')
    run.command('mrconvert empty_mask.mif num_streamlines.mif -datatype uint32')
    run.command('mrconvert empty_mask.mif num_assigned.mif -datatype uint32')

    def track_counts(filepath):
        """Return (count, total_count) parsed from the tckinfo header of a track file.

        Either element may be None if the corresponding key is absent.
        """
        count = None
        total_count = None
        tckinfo_output = run.command('tckinfo ' + filepath).stdout
        for line in tckinfo_output.splitlines():
            key_value = [ entry.strip() for entry in line.split(':') ]
            if len(key_value) != 2:
                continue
            if key_value[0] == 'count':
                count = int(key_value[1])
            elif key_value[0] == 'total_count':
                total_count = int(key_value[1])
        return (count, total_count)

    progress = app.ProgressBar('Tracking: 0 of ' + str(len(input_voxels)) + ' voxels processed', len(input_voxels))
    voxel_counter = 0
    failure_counter = 0
    output_voxels = [ ]
    output_data = [ ]
    for voxel in input_voxels:
        seed_path = 'seed_' + voxel[0] + '_' + voxel[1] + '_' + voxel[2] + '.mif'
        tracks_path = 'tracks_' + voxel[0] + '_' + voxel[1] + '_' + voxel[2] + '.tck'
        connectome_path = 'connectome_' + voxel[0] + '_' + voxel[1] + '_' + voxel[2] + '.csv'
        # Single-voxel seed image for this voxel only
        run.command('mrconvert mask.mif ' + seed_path + ' -coord 0 ' + voxel[0] + ' -coord 1 ' + voxel[1] + ' -coord 2 ' + voxel[2])
        # BUGFIX: a space is required before the appended options; previously
        # 'true' and the first tckgen option were fused into a single token
        run.command('tckgen ' + path.from_user(app.ARGS.input_fod) + ' ' + tracks_path + ' -act ' + path.from_user(app.ARGS.input_5tt) + ' -seed_image ' + seed_path + ' -seed_unidirectional -config TckgenEarlyExit true ' + tckgen_options)
        # Reject the voxel if the requested number of streamlines was not generated
        if os.path.isfile(tracks_path):
            counts = track_counts(tracks_path)
            run.command('mredit num_streamlines.mif -voxel ' + ','.join(voxel) + ' ' + str(counts[0]))
            run.command('mredit num_attempts.mif -voxel ' + ','.join(voxel) + ' ' + str(counts[1]))
            if counts[0] == target_count:
                run.command('tck2connectome ' + tracks_path + ' ' + path.from_user(app.ARGS.input_parc) + ' ' + connectome_path + ' -vector -keep_unassigned')
                connectome = matrix.load_vector(connectome_path)
                # First element of the vector is the unassigned-streamlines count
                run.command('mredit num_assigned.mif -voxel ' + ','.join(voxel) + ' ' + str(sum(connectome[1:])))
                output_voxels.append([int(i) for i in voxel])
                output_data.append(connectome)
            else:
                failure_counter += 1
                app.debug('Tracking aborted for voxel ' + ','.join(voxel) + ' (' + str(counts[0]) + ' of ' + str(counts[1]) + ' streamlines accepted)')
        # Seed / track files are never kept, even with -nocleanup;
        # their existence must however be checked in case -continue is in use
        if os.path.isfile(seed_path):
            os.remove(seed_path)
        if os.path.isfile(tracks_path):
            os.remove(tracks_path)
        voxel_counter += 1
        progress.increment('Tracking: ' + str(voxel_counter) + ' of ' + str(len(input_voxels)) + ' voxels processed')
    progress.done()
    if failure_counter:
        app.warn(str(failure_counter) + ' of ' + str(len(input_voxels)) + ' voxels not successfully tracked')
    matrix.save_matrix(path.from_user(app.ARGS.output_voxels, False), output_voxels, fmt='%i', force=app.FORCE_OVERWRITE)
    matrix.save_matrix(path.from_user(app.ARGS.output_data, False), output_data, fmt='%i', force=app.FORCE_OVERWRITE)
    if not app.ARGS.grad_image:
        return

    # Generate an image of 'connectivity gradient': difference in connectivity
    # profiles between adjacent voxels, sampled along 13 unique offsets
    # (the opposite of each offset is also tested below)
    spacings = image.Header('mask.mif').spacing()
    offsets = [ [0,0,1], [0,1,0], [1,0,0], [0,1,1], [1,0,1], [1, 1, 0], [0,1,-1], [1,0,-1], [1,-1,0], [1,1,1], [1,1,-1], [1,-1,1], [-1,1,1] ]
    offset_length_multipliers = [ (1.0/length) for length in [ math.sqrt(float(o[0]*o[0]+o[1]*o[1]+o[2]*o[2])) for o in offsets ] ]
    unit_offsets = [ [ float(f)*m for f in o ] for o, m in zip(offsets, offset_length_multipliers) ]
    # Physical length in mm of each offset, so that the output gradient is in
    # units of (CosSim.mm^-1); voxel sizes may be anisotropic
    scale_factors = [ (1.0/length) for length in [ math.sqrt(math.pow(spacings[0]*o[0],2.0)+math.pow(spacings[1]*o[1],2.0)+math.pow(spacings[2]*o[2],2.0)) for o in offsets ] ]
    # Output image is 4D with 13 volumes (one per offset direction); embed the
    # gradient directions in the image header and also write them to a file
    direction_string = ''
    with open('directions.txt', 'w') as dirfile:
        for direction in unit_offsets:
            text = ','.join(str(d) for d in direction) + '\n'
            dirfile.write(text)
            direction_string += text
    direction_string = direction_string[:-1] # Remove the trailing newline character
    run.command('mrcat ' + ' '.join( ['empty_mask.mif']*13) + ' -axis 3 - | mrconvert - result.mif -stride 0,0,0,1 -datatype float32 -set_property directions \"' + direction_string + '\"')

    # Load all connectivity vectors into memory here rather than as they were
    # generated, so that RAM isn't consumed unnecessarily during tracking, and
    # so that they are loaded correctly if -continue is used
    connectomes = { }
    progress = app.ProgressBar('Loading connectomes into memory', len(input_voxels))
    for voxel in input_voxels:
        connectome_path = 'connectome_' + voxel[0] + '_' + voxel[1] + '_' + voxel[2] + '.csv'
        if os.path.exists(connectome_path):
            connectomes[tuple(voxel)] = matrix.load_vector(connectome_path, dtype=int)
        progress.increment()
    progress.done()

    def cos_sim(one, two):
        """Cosine similarity of two connectome vectors, ignoring the unassigned count in element 0."""
        if not isinstance(one, list) or not isinstance(two, list):
            raise MRtrixError('Internal error: cos_sim() function intended to work on lists')
        if not isinstance(one[0], int) or not isinstance(two[0], int):
            raise MRtrixError('Internal error: cos_sim() function intended to work on lists of integers')
        if len(one) != len(two):
            raise MRtrixError('Internal error: Mismatched connectome vector lengths')
        one_norm = math.sqrt(float(sum(i*i for i in one[1:])))
        two_norm = math.sqrt(float(sum(i*i for i in two[1:])))
        # A zero-magnitude fingerprint is treated as perfectly similar
        # (contributes zero gradient); shouldn't really occur anymore, since
        # voxels in which tracking fails are omitted from the analysis
        if not one_norm or not two_norm:
            return 1.0
        one_normalised = [ float(j) / one_norm for j in one[1:] ]
        two_normalised = [ float(j) / two_norm for j in two[1:] ]
        return sum( [ one_normalised[_] * two_normalised[_] for _ in range(0, len(one_normalised)) ] )

    # For each voxel, calculate 'gradient' in connectivity profile along the 13
    # offsets; one or both adjacent voxels along any offset may be absent
    # (outside the mask, or failed tracking), contributing zero to the gradient
    # TODO This will run exceptionally slowly if the target is on a network
    # file system; could concatenate values into fewer mredit calls
    progress = app.ProgressBar('Gradient calculation: 0 of ' + str(len(input_voxels)) + ' voxels processed', len(input_voxels))
    voxel_counter = 0
    for voxel in input_voxels:
        if tuple(voxel) in connectomes: # Some voxels may have failed tracking
            for index, (offset, scale_factor) in enumerate(zip(offsets, scale_factors)):
                vneg = [ str(int(voxel[axis]) - offset[axis]) for axis in range(0,3) ]
                vpos = [ str(int(voxel[axis]) + offset[axis]) for axis in range(0,3) ]
                grad = 0.0
                # BUGFIX: membership must be tested against connectomes rather
                # than input_voxels: a neighbour lying inside the mask but for
                # which tracking failed would otherwise raise KeyError
                # (connectomes keys are a subset of input_voxels)
                if tuple(vneg) in connectomes:
                    grad += (1.0 - cos_sim(connectomes[tuple(voxel)], connectomes[tuple(vneg)])) * scale_factor
                if tuple(vpos) in connectomes:
                    grad += (1.0 - cos_sim(connectomes[tuple(voxel)], connectomes[tuple(vpos)])) * scale_factor
                run.command('mredit result.mif -voxel ' + ','.join(voxel) + ',' + str(index) + ' ' + str(grad))
        voxel_counter += 1
        progress.increment('Gradient calculation: ' + str(voxel_counter) + ' of ' + str(len(input_voxels)) + ' voxels processed')
    progress.done()

    app.var(unit_offsets)
    # Record command history and gradient directions in the output image header
    output_keyval = { 'command_history': COMMAND_HISTORY_STRING,
                      'directions': unit_offsets }
    app.var(output_keyval)
    with open('output.json', 'w') as jsonfile:
        json.dump(output_keyval, jsonfile)
    run.command('mrconvert result.mif ' + path.from_user(app.ARGS.grad_image), mrconvert_keyval='output.json', force=app.FORCE_OVERWRITE)
# Execute the script
# Hand control to the MRtrix3 Python API, which parses the command line
# (populating app.ARGS via usage() above) and then invokes execute()
import mrtrix3
mrtrix3.execute() #pylint: disable=no-member
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment