Last active
April 7, 2023 16:26
-
-
Save Lestropie/16342f43e93db042d2693d4191802b65 to your computer and use it in GitHub Desktop.
MRtrix3 stand-alone Python script for generating connectivity fingerprints within a ROI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# Requires MRtrix3 version 3.0.0 or later
# For linking against the MRtrix3 Python API, see:
# https://mrtrix.readthedocs.io/en/3.0.0/tips_and_tricks/external_modules.html
# Typical pre-processing required before running this script:
# 1. FOD: Standard DWI pre-processing
# 2. 5TT: e.g. 5ttgen fsl
# 3. Parcellation:
#   3.1. Run FreeSurfer
#   3.2. labelconvert - Map the large indices in the FreeSurfer image output to indices that increase from 1
#   3.3. labelsgmfix - Improve the parcellation of sub-cortical grey matter structures
# 4. Mask:
#   4.1. FSL FIRST (run_first_all) - Segment only the structure(s) of interest
#   4.2. meshconvert - Convert the .vtk mesh file from the FIRST convention to realspace coordinates
#   4.3. mesh2pve - Compute, for each voxel in a template image, the partial volume fraction lying within the mesh
#   4.4. mrthreshold -abs 0.5 - Convert the partial volume fraction image to a mask
import json, math, os | |
from mrtrix3 import COMMAND_HISTORY_STRING, MRtrixError | |
from mrtrix3 import app, image, matrix, path, run | |
def usage(cmdline): #pylint: disable=unused-variable
  """Describe this script's author, synopsis and command-line interface to MRtrix3.

  Arguments are registered in command-line order: the six positional inputs and
  outputs first, followed by the optional -tckgen_options and -grad_image settings.
  """
  cmdline.set_author('Robert E. Smith (robert.smith@florey.edu.au)')
  cmdline.set_synopsis('Generate, for every voxel in a mask, a connectivity fingerprint to a parcellation image')
  # (name, help text) pairs; registration order defines positional argument order
  argument_table = (
    ('input_fod', 'The input FODs'),
    ('input_5tt', 'The input 5TT image for ACT'),
    ('input_parc', 'The input parcellation image'),
    ('input_mask', 'The mask of the structure to be parcellated'),
    ('output_voxels', 'The list of voxels for which fingerprints were successfully generated'),
    ('output_data', 'The connectivity fingerprint of each successfully-processed voxel'),
    ('-tckgen_options', 'Options to pass to the tckgen command (remember to wrap in quotation marks)'),
    ('-grad_image', 'Export an image representing spatial gradients in connectivity fingerprints'),
  )
  for arg_name, arg_help in argument_table:
    cmdline.add_argument(arg_name, help=arg_help)
def execute(): #pylint: disable=unused-variable | |
stderr = run.command(['5ttcheck', path.from_user(app.ARGS.input_5tt, False)]).stderr | |
if '[WARNING]' in stderr: | |
app.warn('Generated image does not perfectly conform to 5TT format:') | |
for line in stderr.splitlines(): | |
app.warn(line) | |
if len(image.Header(path.from_user(app.ARGS.input_fod, False)).size()) != 4: | |
raise MRtrixError('Input FOD image must be a 4D image') | |
app.check_output_path(path.from_user(app.ARGS.output_voxels, False)) | |
app.check_output_path(path.from_user(app.ARGS.output_data, False)) | |
if app.ARGS.grad_image: | |
app.check_output_path(path.from_user(app.ARGS.grad_image, False)) | |
tckgen_options = '-backtrack -crop_at_gmwmi -select 5k -seeds 10M' | |
if app.ARGS.tckgen_options: | |
tckgen_options = app.ARGS.tckgen_options | |
# Need to extract the target number of streamlines from tckgen_options | |
# Beware that this may include a postfix multiplier | |
tckgen_options_split = tckgen_options.split(' ') | |
if not '-select' in tckgen_options_split: | |
raise MRtrixError('Contents of -tckgen_options must at least include the -select option') | |
target_count = tckgen_options_split[tckgen_options_split.index('-select')+1] | |
if target_count[-1].isalpha(): | |
multiplier = 1 | |
if target_count[-1].lower() == 'k': | |
multiplier = 1000 | |
elif target_count[-1].lower() == 'm': | |
multiplier = 1000000 | |
elif target_count[-1].lower() == 'b': | |
multiplier = 1000000000 | |
else: | |
raise MRtrixError('Could not convert -select field \'' + target_count + '\' to an integer') | |
target_count = int(target_count[:-1]) * multiplier | |
else: | |
target_count = int(target_count) | |
app.make_scratch_dir() | |
# Only the mask image is copied to the scratch directory for convenience; the others can be used in-place | |
run.command('mrconvert ' + path.from_user(app.ARGS.input_mask) + ' ' + path.to_scratch('mask.mif') + ' -datatype bit') | |
app.goto_scratch_dir() | |
mask_size = image.Header('mask.mif').size() | |
if not (len(mask_size) == 3 or (len(mask_size) == 4 and mask_size[3] == 1)): | |
raise MRtrixError('Input mask image must be a 3D image') | |
run.command('maskdump mask.mif input_voxels.txt') | |
input_voxels = [ _.split() for _ in open('input_voxels.txt', 'r').read().splitlines() if _ and not _.lstrip().startswith('#') ] | |
# Images of: | |
# - The number of streamline attempts from each voxel | |
# - The number of streamlines generated from each voxel | |
# - The number of streamlines successfully reaching a target node from each voxel | |
# This now needs to be done at the tracking step instead: | |
# some voxels may abort early | |
run.command('mrthreshold mask.mif empty_mask.mif -abs 1.5') | |
run.command('mrconvert empty_mask.mif num_attempts.mif -datatype uint32') | |
run.command('mrconvert empty_mask.mif num_streamlines.mif -datatype uint32') | |
run.command('mrconvert empty_mask.mif num_assigned.mif -datatype uint32') | |
def track_counts(filepath): | |
count = None | |
total_count = None | |
tckinfo_output = run.command('tckinfo ' + filepath).stdout | |
for line in tckinfo_output.splitlines(): | |
key_value = [ entry.strip() for entry in line.split(':') ] | |
if len(key_value) != 2: | |
continue | |
if key_value[0] == 'count': | |
count = int(key_value[1]) | |
elif key_value[0] == 'total_count': | |
total_count = int(key_value[1]) | |
return (count, total_count) | |
progress = app.ProgressBar('Tracking: 0 of ' + str(len(input_voxels)) + ' voxels processed', len(input_voxels)) | |
voxel_counter = 0 | |
failure_counter = 0 | |
output_voxels = [ ] | |
output_data = [ ] | |
for voxel in input_voxels: | |
seed_path = 'seed_' + voxel[0] + '_' + voxel[1] + '_' + voxel[2] + '.mif' | |
tracks_path = 'tracks_' + voxel[0] + '_' + voxel[1] + '_' + voxel[2] + '.tck' | |
connectome_path = 'connectome_' + voxel[0] + '_' + voxel[1] + '_' + voxel[2] + '.csv' | |
run.command('mrconvert mask.mif ' + seed_path + ' -coord 0 ' + voxel[0] + ' -coord 1 ' + voxel[1] + ' -coord 2 ' + voxel[2]) | |
run.command('tckgen ' + path.from_user(app.ARGS.input_fod) + ' ' + tracks_path + ' -act ' + path.from_user(app.ARGS.input_5tt) + ' -seed_image ' + seed_path + ' -seed_unidirectional -config TckgenEarlyExit true' + tckgen_options) | |
# Capture the number of tracks that needed to be generated; | |
# see if there's any useful contrast | |
# (actually, get the ratio of generated vs. accepted) | |
# Will lose this ability if changing to a fixed-number-of-streamlines-per-voxel seeding mechanism... | |
# Reject voxel if the requested number of streamlines was not generated | |
if os.path.isfile(tracks_path): | |
counts = track_counts(tracks_path) | |
run.command('mredit num_streamlines.mif -voxel ' + ','.join(voxel) + ' ' + str(counts[0])) | |
run.command('mredit num_attempts.mif -voxel ' + ','.join(voxel) + ' ' + str(counts[1])) | |
if counts[0] == target_count: | |
run.command('tck2connectome ' + tracks_path + ' ' + path.from_user(app.ARGS.input_parc) + ' ' + connectome_path + ' -vector -keep_unassigned') | |
connectome = matrix.load_vector(connectome_path) | |
run.command('mredit num_assigned.mif -voxel ' + ','.join(voxel) + ' ' + str(sum(connectome[1:]))) | |
output_voxels.append([int(i) for i in voxel]) | |
output_data.append(connectome) | |
else: | |
failure_counter += 1 | |
app.debug('Tracking aborted for voxel ' + ','.join(voxel) + ' (' + str(counts[0]) + ' of ' + str(counts[1]) + ' streamlines accepted)') | |
# Never keep, even with -nocleanup | |
# Do however need to check their existence, in case we're using -continue | |
if os.path.isfile(seed_path): | |
os.remove(seed_path) | |
if os.path.isfile(tracks_path): | |
os.remove(tracks_path) | |
voxel_counter += 1 | |
progress.increment('Tracking: ' + str(voxel_counter) + ' of ' + str(len(input_voxels)) + ' voxels processed') | |
progress.done() | |
if failure_counter: | |
app.warn(str(failure_counter) + ' of ' + str(len(input_voxels)) + ' voxels not successfully tracked') | |
matrix.save_matrix(path.from_user(app.ARGS.output_voxels, False), output_voxels, fmt='%i', force=app.FORCE_OVERWRITE) | |
matrix.save_matrix(path.from_user(app.ARGS.output_data, False), output_data, fmt='%i', force=app.FORCE_OVERWRITE) | |
if not app.ARGS.grad_image: | |
return | |
# Rather than going straight to clustering, generate an image of 'connectivity gradient' | |
# Difference in connectivity profiles between adjacent voxels in 3 dimensions | |
# Get the voxel sizes of the mask image; these must be used to scale the gradient calculations | |
spacings = image.Header('mask.mif').spacing() | |
# Construct the list of possible voxel offsets, remembering that the opposite offset will also be tested | |
offsets = [ [0,0,1], [0,1,0], [1,0,0], [0,1,1], [1,0,1], [1, 1, 0], [0,1,-1], [1,0,-1], [1,-1,0], [1,1,1], [1,1,-1], [1,-1,1], [-1,1,1] ] | |
offset_length_multipliers = [ (1.0/length) for length in [ math.sqrt(float(o[0]*o[0]+o[1]*o[1]+o[2]*o[2])) for o in offsets ] ] | |
unit_offsets = [ [ float(f)*m for f in o ] for o, m in zip(offsets, offset_length_multipliers) ] | |
# Want to know the length in mm of each of these offsets, such that the output gradient is in units of (CosSim.mm^-1) | |
scale_factors = [ (1.0/length) for length in [ math.sqrt(math.pow(spacings[0]*o[0],2.0)+math.pow(spacings[1]*o[1],2.0)+math.pow(spacings[2]*o[2],2.0)) for o in offsets ] ] | |
# First, generate a bogus file here: That will allow the -continue option to be used | |
# Actually, instead, delay creation of the output image to this point | |
# Output image must now be 4D with 3 volumes | |
# Scratch that: _13_ volumes | |
# Need to embed the gradient directions within the image header | |
# Write them to a file also | |
direction_string = '' | |
with open('directions.txt', 'w') as dirfile: | |
for direction in unit_offsets: | |
text = ','.join(str(d) for d in direction) + '\n' | |
dirfile.write (text) | |
direction_string += text | |
direction_string = direction_string[:-1] # Remove the trailing newline character | |
run.command('mrcat ' + ' '.join( ['empty_mask.mif']*13) + ' -axis 3 - | mrconvert - result.mif -stride 0,0,0,1 -datatype float32 -set_property directions \"' + direction_string + '\"') | |
# Store all connectivity vectors in memory | |
# Load them here rather than as they are generated so that the RAM isn't used up unnecessarily - | |
# also just to make sure that they're loaded correctly if -continue is used | |
connectomes = { } | |
progress = app.ProgressBar('Loading connectomes into memory', len(input_voxels)) | |
for voxel in input_voxels: | |
connectome_path = 'connectome_' + voxel[0] + '_' + voxel[1] + '_' + voxel[2] + '.csv' | |
if os.path.exists(connectome_path): | |
connectomes[tuple(voxel)] = matrix.load_vector(connectome_path, dtype=int) | |
progress.increment() | |
progress.done() | |
def cos_sim(one, two): # Cosine Similarity | |
if not isinstance(one, list) or not isinstance(two, list): | |
raise MRtrixError('Internal error: cos_sim() function intended to work on lists') | |
if not isinstance(one[0], int) or not isinstance(two[0], int): | |
raise MRtrixError('Internal error: cos_sim() function intended to work on lists of integers') | |
if len(one) != len(two): | |
raise MRtrixError('Internal error: Mismatched connectome vector lengths') | |
# The following check shouldn't really be required anymore, since voxels in which tracking | |
# fails are omitted from the analysis, but let's leave it here anyway | |
one_norm = math.sqrt(float(sum(i*i for i in one[1:]))) | |
two_norm = math.sqrt(float(sum(i*i for i in two[1:]))) | |
if not one_norm or not two_norm: | |
return 1.0 | |
one_normalised = [ float(j) / one_norm for j in one[1:] ] | |
two_normalised = [ float(j) / two_norm for j in two[1:] ] | |
return sum( [ one_normalised[_] * two_normalised[_] for _ in range(0, len(one_normalised)) ] ) | |
# Is it possible to get something that's directional? | |
# Maybe, rather than just the three axes, test diagonals as well | |
# Total of 13 directions, so can't get an lmax=4 fit, but can use dixel overlay plot, and | |
# rotate directions based on rigid transformation | |
# - Could potentially get lmax=4 with spatial regularisation... | |
# Maybe then do an lmax=2 fit and get peak direction / components in scanner XYZ? | |
# Also: Scale appropriately depending on (anisotropic) voxel size | |
# TODO This will run exceptionally slowly if the target is on a network file system | |
# Alternatively, could concatenate the whole lot into a single mredit call? | |
# For each voxel, calculate 'gradient' in connectivity profile in 13 directions | |
# One or both of the adjacent voxels in any particular direction may be absent from the mask - handle appropriately | |
progress = app.ProgressBar('Gradient calculation: 0 of ' + str(len(input_voxels)) + ' voxels processed', len(input_voxels)) | |
voxel_counter = 0 | |
for voxel in input_voxels: | |
if tuple(voxel) in connectomes: # Some voxels may have failed tracking | |
# TODO Could at least concatenate the 13 values for a single voxel into a single mredit call | |
for index, (offset, scale_factor) in enumerate(zip(offsets, scale_factors)): | |
vneg = [ str(int(voxel[axis]) - offset[axis]) for axis in range(0,3) ] | |
vpos = [ str(int(voxel[axis]) + offset[axis]) for axis in range(0,3) ] | |
grad = 0.0 | |
if vneg in input_voxels: | |
grad += (1.0 - cos_sim(connectomes[tuple(voxel)], connectomes[tuple(vneg)])) * scale_factor | |
if vpos in input_voxels: | |
grad += (1.0 - cos_sim(connectomes[tuple(voxel)], connectomes[tuple(vpos)])) * scale_factor | |
run.command('mredit result.mif -voxel ' + ','.join(voxel) + ',' + str(index) + ' ' + str(grad)) | |
voxel_counter += 1 | |
progress.increment('Gradient calculation: ' + str(voxel_counter) + ' of ' + str(len(input_voxels)) + ' voxels processed') | |
progress.done() | |
app.var(unit_offsets) | |
output_keyval = { 'command_history': COMMAND_HISTORY_STRING, | |
'directions': unit_offsets } | |
app.var(output_keyval) | |
with open('output.json', 'w') as jsonfile: | |
json.dump(output_keyval, jsonfile) | |
run.command('mrconvert result.mif ' + path.from_user(app.ARGS.grad_image), mrconvert_keyval='output.json', force=app.FORCE_OVERWRITE) | |
# Execute the script. The guard ensures that merely importing this file
# (e.g. for linting or testing) does not launch the command-line interface.
if __name__ == '__main__':
  import mrtrix3
  mrtrix3.execute() #pylint: disable=no-member
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment