Skip to content

Instantly share code, notes, and snippets.

@maxpietsch
Last active October 13, 2020 15:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxpietsch/9aabbee5bac3634b7fc4aede25237880 to your computer and use it in GitHub Desktop.
Save maxpietsch/9aabbee5bac3634b7fc4aede25237880 to your computer and use it in GitHub Desktop.
generate screenshots for QC report
#!/usr/bin/env python
#
# generate screenshots for QC report
#
#
# Author: Max Pietsch
# King's College London
# maximilian.pietsch@kcl.ac.uk
#
# __________ Initialisation __________
import os
import mrtrix3
import numpy as np
import matplotlib
from collections import defaultdict
B0THRESH=50
_dtdict = {'Int8': '|i1', 'UInt8': '|u1', 'Int16': '=i2', 'UInt16': '=u2', 'Int16LE': '<i2', 'UInt16LE': '<u2', 'Int16BE': '>i2', 'UInt16BE': '>u2', 'Int32': '=i4', 'UInt32': '=u4', 'Int32LE': '<i4', 'UInt32LE': '<u4', 'Int32BE': '>i4', 'UInt32BE': '>u4', 'Float32': '=f4', 'Float32LE': '<f4', 'Float32BE': '>f4', 'Float64': '=f8', 'Float64LE': '<f8', 'Float64BE': '>f8', 'CFloat32': '=c8', 'CFloat32LE': '<c8', 'CFloat32BE': '>c8', 'CFloat64': '=c16', 'CFloat64LE': '<c16', 'CFloat64BE': '>c16'}
class Image(object):
    '''
    Minimal reader for MRtrix3 .mif images whose voxel data are stored in the
    same file as the header ('file: . <offset>').

    Adapted from https://github.com/dchristiaens/mrtrix3-pyio.git
    Copyright (c) 2017 - Daan Christiaens (daan.christiaens@gmail.com)

    Attributes set by the constructor:
        data: numpy.ndarray view of the voxel data, with the shape given by the
            'dim' entry and strides derived from the 'layout' entry.
        vox: voxel sizes as the list of strings from the 'vox' entry
            (empty tuple if the header has none).
        transform: 4x4 float array; identity except for the rows given by the
            'transform' entries, in order.
        grad: 'dw_scheme' rows stacked into an array (1D if there is only one
            row), or None if the header has none.
        header: every other 'key: value' line, as key -> list of value strings.
    '''

    def __init__(self, filename):
        '''
        Parse the header of `filename` and map its voxel data.

        Raises ValueError if the header ends without an 'END' line, lacks a
        'dim', 'layout' or 'datatype' entry, or refers to an external data file.
        '''
        self.data = None
        self.vox = ()
        self.transform = np.eye(4)
        self.grad = None
        self.header = defaultdict(list)
        imsize = layout = dtstr = None
        offset = 0
        with open(filename, 'rt', encoding='latin-1') as f:
            tr_count = 0
            while True:
                line = f.readline()
                if not line:
                    # readline() returns '' only at end of file: stop instead of
                    # looping forever on a header that has no 'END' line.
                    raise ValueError('no END line in MRtrix image header: ' + filename)
                fl = line.strip()
                if fl == 'END':
                    break
                if fl.startswith('dim'):
                    imsize = tuple(map(int, fl.split(':')[1].strip().split(',')))
                elif fl.startswith('vox'):
                    self.vox = fl.split(':')[1].strip().split(',')
                elif fl.startswith('layout'):
                    layout = fl.split(':')[1].strip().split(',')
                elif fl.startswith('datatype'):
                    dtstr = fl.split(':')[1].strip()
                    # names missing from _dtdict (e.g. 'Bit') are read as unsigned bytes
                    dt = np.dtype(_dtdict.get(dtstr, 'u1'))
                elif fl.startswith('file'):
                    # expected form: 'file: . <offset>' ('.' = data follow in this file)
                    tokens = fl.split(':', 1)[1].split()
                    if not tokens or tokens[0] != '.':
                        raise ValueError('only image data stored in the header file ("file: .") '
                                         'is supported, got: ' + fl)
                    offset = int(tokens[1]) if len(tokens) > 1 else 0
                elif fl.startswith('transform'):
                    self.transform[tr_count, :] = np.array(fl.split(':')[1].strip().split(','), dtype=float)
                    tr_count = tr_count + 1
                elif fl.startswith('dw_scheme'):
                    gbrow = np.array(fl.split(':')[1].strip().split(','), dtype=float)
                    if self.grad is None:
                        self.grad = gbrow
                    else:
                        self.grad = np.vstack([self.grad, gbrow])
                elif ':' in fl:
                    k = fl.split(':')[0]
                    self.header[k].append(fl[len(k) + 1:].strip())
        if imsize is None or layout is None or dtstr is None:
            raise ValueError('MRtrix image header lacks a dim, layout or datatype entry: ' + filename)
        # read image data
        with open(filename, 'rb') as f:
            f.seek(offset, 0)
            image = np.fromfile(file=f, dtype=dt)
        if dtstr == 'Bit':
            image = np.unpackbits(image)
        s, o = self._layout_to_strides(layout, imsize, dt)
        self.data = np.ndarray(shape=imsize, dtype=dt, buffer=image, strides=s, offset=o)

    def _layout_to_strides(self, layout, size, dtype):
        '''
        Convert an MRtrix 'layout' entry into numpy byte strides.

        layout: strings such as ['+0', '-1', '+2']; the number is the rank of the
            axis in memory order (0 = contiguous) and the sign its direction.
        size: image dimensions, one per axis.
        dtype: numpy dtype of one voxel.

        Returns (strides, offset): per-axis byte strides, and the byte offset of
        element [0, ..., 0] in the buffer (non-zero when an axis is reversed).
        '''
        strides = [0] * len(layout)
        stride, offset = int(dtype.itemsize), 0
        for dim in sorted(range(len(layout)), key=lambda k: int(layout[k][1:])):
            # compare by value: `is` on string literals depends on interning
            if layout[dim][0] == '-':
                strides[dim] = -stride
                offset += (size[dim] - 1) * stride
            else:
                strides[dim] = stride
            stride *= size[dim]
        return strides, offset
def amp2l2normspectrum(amp, bs, dir='.'):
    '''
    Write one l2-norm spectrum image per shell of `amp`; return their paths.

    First shell, b-value not above B0THRESH: its volumes are averaged
    (dwiextract | mrmath mean -axis 3), reshaped with `-axes 0:2,-1` and
    written to <dir>/l2_spectrum_0.mif.
    First shell above B0THRESH: a warning is issued and it is handled like
    any other shell.
    Other shells: amp2sh -> sh2power -spectrum -> square root, written to
    <dir>/l2_spectrum_<b>.mif.

    amp: input dMRI series, inserted verbatim into the command strings.
    bs: shell b-values as strings, lowest first.
    dir: output directory.
    Returns the written paths, one per entry of `bs`, in the same order.
    '''
    from mrtrix3 import app, run
    written = []
    for position, bval in enumerate(bs):
        lowest_is_b0 = position == 0 and not float(bs[0]) > B0THRESH
        if position == 0 and not lowest_is_b0:
            app.warn("amp2mssh expected b=0 as lowest b-value, got %s" % bval)
        if lowest_is_b0:
            target = dir + '/l2_spectrum_0.mif'
            run.command('dwiextract ' + amp + ' -shell ' + bval +
                        ' - | mrmath - mean -axis 3 - | mrconvert - -axes 0:2,-1 ' +
                        target + ' -datatype float32 -stride 0,1,2,3')
        else:
            target = dir + '/l2_spectrum_' + bval + '.mif'
            run.command('amp2sh ' + amp + ' -shell ' + bval +
                        ' - | sh2power -spectrum - - | mrcalc - -sqrt - | mrconvert - ' +
                        target + ' -datatype float32 -stride 0,1,2,3')
        written.append(target)
    return written
def plot(image_paths, orientations_slices, bs_ls, vmin_max=None, col_labels=None, figsize=(16, 8), dpi=140,
         axes_pad=0.05, colbar=True, showmask=True, label_fontsize=8,
         maskstyles=None, mask_path='mask.mif'):
    '''
    Draw a grid of slices of per-shell l2 spectrum images, with mask contours.

    Rows: one per (b, l) pair from `bs_ls`, in order. Columns: one per
    (input image, orientation) pair, column = image index * number of
    orientations + orientation index. In each cell the slices of one
    orientation are rotated with np.rot90 and concatenated side by side.

    image_paths: one entry per input image; each entry lists one file per shell,
        in the order of `bs_ls` (as returned by amp2l2normspectrum).
    orientations_slices: list of orientations; each is a list of index tuples
        selecting a one-voxel-thick slice of a 3D volume.
    bs_ls: list of (b, ls) pairs; volume l // 2 of the shell image is shown for
        each l in ls.
    vmin_max: optional (vmin, vmax) lists with one entry per row. If None, each
        row uses the lowest 1st percentile and highest 1.1 * 98th percentile of
        in-mask values over all input images.
    col_labels: optional list of column titles, one per column.
    colbar: draw one colorbar per row on the right edge.
    showmask: overlay mask contours.
    maskstyles: keyword arguments for Axes.contour; default: yellow,
        semi-transparent contour at mask level 0.5.
    mask_path: mask image read with Image(); indexed with the same slices and
        boolean-indexed against the input volumes, so it must match their shape.
        Its voxel sizes set the physical extent of each panel.

    Returns (fig, (vmin, vmax)).
    Raises MRtrixError if a selected slice is not 2D or the slices of one
    orientation have different in-plane voxel spacings.
    '''
    from mrtrix3 import app, MRtrixError
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import AxesGrid
    if maskstyles is None:
        # built per call: a dict default argument would be shared between calls
        maskstyles = {'colors': 'y', 'alpha': 0.5, 'levels': [0.5 + 1e-16, 1e16], 'linewidths': 0.5, 'cmap': None}
    orientations = len(orientations_slices)
    amps = len(image_paths)
    cols = amps * orientations
    mask_image = Image(mask_path)
    vox = mask_image.vox
    mask = mask_image.data
    if col_labels is None:
        col_labels = [''] * cols
    assert len(col_labels) == cols, (len(col_labels), cols)

    # images[i] holds, per row, the 3D volume shown for input image i
    images = [[] for _ in range(amps)]
    row_labels = []
    rows = 0
    for icol, col in enumerate(image_paths):
        for ib, (b, ls) in enumerate(bs_ls):
            app.console(str([icol, col[ib]]))
            shell_data = Image(col[ib]).data  # read each file once, not once per l
            images[icol] += [shell_data[..., l // 2] for l in ls]
            if icol == 0:
                row_labels += ['b=%s, l=%i' % (b, l) for l in ls]
                rows += len(ls)

    plt.close()
    fig = plt.figure(figsize=figsize, dpi=dpi)
    grid = AxesGrid(fig,
                    111,
                    nrows_ncols=(rows, cols),
                    axes_pad=axes_pad,
                    share_all=True,
                    label_mode="L",
                    cbar_location="right",
                    cbar_mode="edge" if colbar else None,
                    cbar_size="7%" if cols > 2 else "3%",
                    cbar_pad=("2%" if cols > 2 else "1%") if colbar else None)

    if vmin_max is None:
        vmin = [np.inf] * rows
        vmax = [-np.inf] * rows
        inside = mask > 0.5
        if not np.any(inside):
            # np.percentile fails on an empty selection
            app.warn('mask is empty; intensity limits computed over all voxels')
            inside = np.ones(mask.shape, dtype=bool)
        for iamp in range(amps):
            for irow in range(rows):
                lo, hi = np.percentile(images[iamp][irow][inside], (1, 98))
                hi *= 1.1
                vmin[irow] = min(vmin[irow], lo)
                vmax[irow] = max(vmax[irow], hi)
    else:
        vmin, vmax = vmin_max
    assert len(vmin) == rows, (len(vmin), rows)
    assert len(vmax) == rows, (len(vmax), rows)

    for islices, slices in enumerate(orientations_slices):
        for iamp in range(amps):
            icol = iamp * orientations + islices
            for irow in range(rows):
                app.console(str([iamp, irow, icol]))
                ims = [images[iamp][irow][sl] for sl in slices]
                # voxel sizes of the non-singleton axes of each slice
                spacing = [[vox[i] for i in np.where(np.array(im.shape) > 1)[0]] for im in ims]
                for sp in spacing:
                    if len(sp) != 2:
                        raise MRtrixError('slice must be 2D, got: ' + str(sp))
                    if not np.allclose(np.array([float(s) for s in sp]), np.array([float(s) for s in spacing[0]])):
                        app.console('error:' + str(spacing))
                        raise MRtrixError("multiple voxel spacings in one image not supported, split them across orientations")
                # every orientation gets the same 90-degree rotation
                img = np.squeeze(np.concatenate([np.rot90(np.squeeze(im)) for im in ims], axis=1))
                # np.rot90 maps an (n0, n1) slice to (n1, n0): columns run along the
                # slice's first axis and rows along its second, so the horizontal
                # extent uses spacing[0][0] and the vertical one spacing[0][1].
                extent = [0, img.shape[1] * float(spacing[0][0]), 0, img.shape[0] * float(spacing[0][1])]
                ax = grid[irow * cols + icol]
                p = ax.imshow(img, origin='upper', cmap='gray', interpolation='none',
                              vmin=vmin[irow], vmax=vmax[irow], extent=extent)
                if showmask:
                    mask_img = np.concatenate([np.rot90(np.squeeze(mask[sl])) for sl in slices], axis=1)
                    ax.contour(mask_img, origin='upper', extent=extent, **maskstyles)
                ax.set_yticks([])
                ax.set_xticks([])
                ax.grid(False)
                if colbar and icol == cols - 1:
                    ax.cax.colorbar(p, drawedges=False)
                if col_labels and irow == 0:
                    ax.set_title(col_labels[icol], fontdict={'fontsize': label_fontsize})
                if row_labels and icol == 0:
                    ax.set_ylabel(row_labels[irow], fontdict={'fontsize': label_fontsize})
    return fig, (vmin, vmax)
def usage(cmdline):
    '''
    Register author, synopsis, description and arguments of this script with
    the MRtrix3 command-line parser `cmdline`.

    Defaults quoted in the help texts are the ones applied by execute().
    '''
    cmdline.set_author('Max Pietsch (maximilian.pietsch@kcl.ac.uk)')
    cmdline.set_synopsis('generate screenshots of l2 norm of shell-specific SH coefficients')
    cmdline.add_description('part of dStripe')
    cmdline.add_argument('amp', nargs='+', help='diffusion MRI series')
    cmdline.add_argument('plotpath', help='output plot')
    cmdline.add_argument('-mask', help='mask image (default: computed with dwi2mask from the first input)')
    cmdline.add_argument('-column_labels', help='comma separated list of column titles, one for each input dMRI image')
    cmdline.add_argument('-orientation', help='slice orientation: sag, cor or ax (default: sag)')
    cmdline.add_argument('-croptomask', help='boolean: crop the inputs to the mask (mrgrid crop) before plotting')
    cmdline.add_argument('-nslices', help='int: number of slices shown per input image (default: 5)')
    cmdline.add_argument('-lmax', help='int, even: harmonic order shown next to l=0 for shells with b >= bmin_nonl0 (default: 2)')
    cmdline.add_argument('-bmin_nonl0', help='float: shells with a lower b-value show only l=0 (default: 2000)')
    # app.add_dwgrad_import_options(cmdline) # TODO
def execute(): # pylint: disable=unused-variable
from mrtrix3 import MRtrixError
from mrtrix3 import app, image, path, run
import shutil, os
# ________ check inputs
def check_input_path(path):
if not os.path.isfile(path):
raise MRtrixError('input path not found: ' + path)
for amp in app.ARGS.amp:
check_input_path(amp)
dwi_sizes = None
for amp in app.ARGS.amp:
header = image.Header(amp)
if dwi_sizes is None:
dwi_sizes = header.size()
bs = image.mrinfo(amp, '-shell_bvalues').split()
else:
pass # todo check input
nshells = len(bs)
bsizes = image.mrinfo(amp, '-shell_sizes').split()
assert len(bs) > 1 and len(bs) == nshells, (bs, nshells)
bs_int = [int(round(float(b) / 500., 1) * 500) for b in bs]
assert len(set(bs_int)) == nshells, ("FIXME rounding issue?", bs, bs_int)
if float(bs[0]) > B0THRESH:
raise MRtrixError("expected inner shell close to b=0, got b=" + str(bs[0]))
app.console('b values: ' + ','.join(bs))
app.console('b sizes: ' + ','.join(bsizes))
app.make_scratch_dir()
app.goto_scratch_dir()
nslices = 5
if app.ARGS.nslices is not None:
nslices = int(app.ARGS.nslices)
lmax = 2
if app.ARGS.lmax is not None:
lmax = int(app.ARGS.lmax)
assert lmax % 2 == 0, lmax
bmin_nonl0 = 2000
if app.ARGS.bmin_nonl0 is not None:
bmin_nonl0 = float(app.ARGS.bmin_nonl0)
mask = app.ARGS.mask
showmask = False
if mask is not None:
showmask=True
check_input_path(mask)
run.command('mrconvert ' + mask + ' mask.mif -stride 0,1,2 -datatype float32')
else:
run.command('dwi2mask ' + app.ARGS.amp[0] + ' - | mrconvert - -stride 0,1,2 -datatype float32 mask.mif')
if app.ARGS.croptomask:
amps = []
for iamp, amp in enumerate(app.ARGS.amp):
run.command('mrgrid ' + amp + ' crop -mask mask.mif -axis 2 0,0 amp_%i.mif' % iamp)
amps.append('amp_%i.mif' % iamp)
app.ARGS.amp = amps
run.command('mrgrid mask.mif crop -mask mask.mif -axis 2 0,0 mask_cropped.mif')
run.function(os.remove, 'mask.mif')
run.function(shutil.move, 'mask_cropped.mif', 'mask.mif')
header = image.Header(amps[0])
dwi_sizes = header.size()
orientation = app.ARGS.orientation
if orientation is None or orientation.startswith('sag'):
slices = [[(slice(sag, sag + 1), slice(None), slice(None)) for sag in
np.floor(dwi_sizes[0] * np.linspace(0.25, 0.75, nslices)).astype(np.int)]]
elif orientation.startswith('ax'):
slices = [[(slice(None), slice(None), slice(ax, ax + 1)) for ax in
np.floor(dwi_sizes[2] * np.linspace(0.25, 0.75, nslices)).astype(np.int)]]
elif orientation.startswith('cor'):
slices = [[(slice(None), slice(cor, cor + 1), slice(None)) for cor in
np.floor(dwi_sizes[1] * np.linspace(0.25, 0.75, nslices)).astype(np.int)]]
else:
raise MRtrixError("orientation not understood: " + str(orientation))
app.console('slices: '+ str(slices))
bs_ls = [(b, [0] if float(b) < bmin_nonl0 else [0, lmax]) for b in bs] # list(range(0, lmax+2, 2))
column_labels = app.ARGS.column_labels
if column_labels:
column_labels = column_labels.split(',')
if len(column_labels) != len(app.ARGS.amp):
raise MRtrixError('one column label per input dMRI image required')
col_labels = []
for lbl in column_labels:
col_labels += [lbl] + ['' for _ in range(len(slices) - 1)]
l2_norms = []
for iamp, amp in enumerate(app.ARGS.amp):
os.makedirs(os.path.join(run.shared.get_scratch_dir(), str(iamp)))
l2_norms.append(amp2l2normspectrum(amp, bs, os.path.join(run.shared.get_scratch_dir(), str(iamp))))
fig, vmin_max = plot(l2_norms, slices, bs_ls, showmask=showmask, col_labels=col_labels)
app.console('vmin_max:' + str(vmin_max))
app.console('writing plot to ' + path.from_user(app.ARGS.plotpath))
fig.savefig(path.from_user(app.ARGS.plotpath), dpi=200, bbox_inches='tight', pad_inches=0)
np.savetxt(path.from_user(app.ARGS.plotpath)+'.vmin_max', np.array(vmin_max))
# Execute the script: hand control to the MRtrix3 Python library.
# NOTE(review): by MRtrix3 convention the library is expected to build the
# command line via usage() and then call execute() defined above; confirm
# against the mrtrix3 version in use.
mrtrix3.execute()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment