Last active
October 13, 2020 15:43
-
-
Save maxpietsch/9aabbee5bac3634b7fc4aede25237880 to your computer and use it in GitHub Desktop.
generate screenshots for QC report
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# | |
# generate screenshots for QC report | |
# | |
# | |
# Author: Max Pietsch | |
# King's College London | |
# maximilian.pietsch@kcl.ac.uk | |
# | |
# __________ Initialisation __________ | |
import os | |
import mrtrix3 | |
import numpy as np | |
import matplotlib | |
from collections import defaultdict | |
# Largest b-value accepted for the innermost shell to be treated as b=0
# (amp2l2normspectrum warns and execute() refuses inputs above this).
B0THRESH = 50
# MRtrix3 header datatype names -> numpy dtype strings
# ('=' native byte order, '<' little-endian, '>' big-endian, '|' not applicable).
_dtdict = {
    'Bit': '|u1',  # stored packed in bytes; unpacked with np.unpackbits after reading
    'Int8': '|i1', 'UInt8': '|u1',
    'Int16': '=i2', 'UInt16': '=u2', 'Int16LE': '<i2', 'UInt16LE': '<u2', 'Int16BE': '>i2', 'UInt16BE': '>u2',
    'Int32': '=i4', 'UInt32': '=u4', 'Int32LE': '<i4', 'UInt32LE': '<u4', 'Int32BE': '>i4', 'UInt32BE': '>u4',
    'Int64': '=i8', 'UInt64': '=u8', 'Int64LE': '<i8', 'UInt64LE': '<u8', 'Int64BE': '>i8', 'UInt64BE': '>u8',
    'Float32': '=f4', 'Float32LE': '<f4', 'Float32BE': '>f4',
    'Float64': '=f8', 'Float64LE': '<f8', 'Float64BE': '>f8',
    'CFloat32': '=c8', 'CFloat32LE': '<c8', 'CFloat32BE': '>c8',
    'CFloat64': '=c16', 'CFloat64LE': '<c16', 'CFloat64BE': '>c16',
}


class Image(object):
    '''
    Minimal reader for MRtrix3 images: a .mif file ("file: . <offset>", data
    follows the header) or a header whose "file:" entry names a separate data
    file ("file: <name> [<offset>]", resolved relative to the header).

    Adapted from https://github.com/dchristiaens/mrtrix3-pyio.git
    Copyright (c) 2017 - Daan Christiaens (daan.christiaens@gmail.com)

    Attributes:
        data: numpy array view of the voxel data, with the shape of the "dim"
            entry and strides derived from the "layout" entry.
        vox: voxel sizes as the list of *strings* from the "vox" entry
            (callers convert with float()).
        transform: 4x4 array; its rows are filled, in order, from the
            "transform" header lines (the last row stays [0, 0, 0, 1] when only
            three are present).
        grad: None without "dw_scheme" lines; a 1-D array for a single line;
            otherwise the lines stacked row-wise.
        header: dict mapping every other "key: value" entry to the list of its
            values in file order.

    Raises:
        ValueError: truncated header (no END line), a header lacking dim,
            layout or datatype, or an unsupported datatype.
    '''

    def __init__(self, filename):
        self.data = None
        self.vox = ()
        self.transform = np.eye(4)
        self.grad = None
        self.header = defaultdict(list)
        imsize = layout = dt = dtstr = None
        data_path, offset = None, 0
        with open(filename, 'rt', encoding='latin-1') as f:
            tr_count = 0
            while True:
                raw = f.readline()
                if not raw:
                    # EOF before the END terminator: stop instead of looping forever.
                    raise ValueError('truncated MRtrix header (no END line): ' + filename)
                fl = raw.strip()
                if fl == 'END':
                    break
                if fl.startswith('dim'):
                    imsize = tuple(map(int, fl.split(':')[1].strip().split(',')))
                elif fl.startswith('vox'):
                    self.vox = fl.split(':')[1].strip().split(',')
                elif fl.startswith('layout'):
                    layout = fl.split(':')[1].strip().split(',')
                elif fl.startswith('datatype'):
                    dtstr = fl.split(':')[1].strip()
                    if dtstr not in _dtdict:
                        raise ValueError('unsupported MRtrix datatype "' + dtstr + '" in ' + filename)
                    dt = np.dtype(_dtdict[dtstr])
                elif fl.startswith('file'):
                    # ". <offset>": data in this file; "<name> [<offset>]": separate data file.
                    tokens = fl.split(':', 1)[1].split()
                    if tokens and tokens[0] != '.':
                        data_path = os.path.join(os.path.dirname(filename), tokens[0])
                    offset = int(tokens[1]) if len(tokens) > 1 else 0
                elif fl.startswith('transform'):
                    self.transform[tr_count, :] = np.array(fl.split(':')[1].strip().split(','), dtype=float)
                    tr_count = tr_count + 1
                elif fl.startswith('dw_scheme'):
                    gbrow = np.array(fl.split(':')[1].strip().split(','), dtype=float)
                    if self.grad is None:
                        self.grad = gbrow
                    else:
                        self.grad = np.vstack([self.grad, gbrow])
                elif ':' in fl:
                    k = fl.split(':')[0]
                    self.header[k].append(fl[len(k) + 1:].strip())
        if imsize is None or layout is None or dt is None:
            raise ValueError('incomplete MRtrix header (need dim, layout and datatype): ' + filename)
        # read image data
        with open(filename if data_path is None else data_path, 'rb') as f:
            f.seek(offset, 0)
            image = np.fromfile(file=f, dtype=dt)
        if dtstr == 'Bit':
            image = np.unpackbits(image)
        s, o = self._layout_to_strides(layout, imsize, dt)
        self.data = np.ndarray(shape=imsize, dtype=dt, buffer=image, strides=s, offset=o)

    def _layout_to_strides(self, layout, size, dtype):
        '''
        Convert an MRtrix layout (e.g. ['+0', '-1', '+2']) into numpy byte strides.

        Each entry is a sign followed by the axis' rank in memory order
        (0 = contiguous). Axes are visited from contiguous outwards, each
        stride being the previous one times the previous axis' size; a '-'
        sign makes the stride negative and moves the start offset to that
        axis' last element.

        Returns:
            (strides, offset): per-axis byte strides and the byte offset of
            element [0, ..., 0] within the data buffer.
        '''
        strides = [0] * len(layout)
        stride, offset = int(dtype.itemsize), 0
        for dim in sorted(range(len(layout)), key=lambda k: int(layout[k][1:])):
            if layout[dim][0] == '-':
                strides[dim] = -stride
                offset += (size[dim] - 1) * stride
            else:
                strides[dim] = stride
            stride *= size[dim]
        return strides, offset
def amp2l2normspectrum(amp, bs, dir='.'):
    """Write one l2-norm spectrum image per shell and return their paths.

    For the first shell, if its b-value does not exceed B0THRESH, the shell's
    volumes are extracted (dwiextract), averaged along axis 3 (mrmath mean) and
    written with a trailing singleton axis to ``<dir>/l2_spectrum_0.mif``.
    Every other shell (including a first shell above B0THRESH, after a
    warning) is fitted with amp2sh; its per-degree power (sh2power -spectrum)
    is square-rooted (mrcalc -sqrt) and written to
    ``<dir>/l2_spectrum_<b>.mif``. All outputs are float32 with strides 0,1,2,3.

    Args:
        amp: input image, inserted verbatim into the command strings.
        bs: shell b-values as strings, innermost first.
        dir: output directory (joined with '/').

    Returns:
        List of output paths, one per entry of ``bs`` and in the same order.
    """
    from mrtrix3 import app, run
    outputs = []
    for index, bval in enumerate(bs):
        if index == 0 and not float(bs[0]) > B0THRESH:
            # Innermost shell is (close to) b=0: average its volumes.
            target = dir + '/l2_spectrum_0.mif'
            run.command('dwiextract ' + amp + ' -shell ' + bval +
                        ' - | mrmath - mean -axis 3 - | mrconvert - -axes 0:2,-1 ' +
                        target + ' -datatype float32 -stride 0,1,2,3')
            outputs.append(target)
            continue
        if index == 0:
            app.warn("amp2mssh expected b=0 as lowest b-value, got %s" % bval)
        target = dir + '/l2_spectrum_' + bval + '.mif'
        run.command('amp2sh ' + amp + ' -shell ' + bval +
                    ' - | sh2power -spectrum - - | mrcalc - -sqrt - | mrconvert - ' +
                    target + ' -datatype float32 -stride 0,1,2,3')
        outputs.append(target)
    return outputs
def plot(image_paths, orientations_slices, bs_ls, vmin_max=None, col_labels=None, figsize=(16, 8), dpi=140,
         axes_pad=0.05, colbar=True, showmask=True, label_fontsize=8, maskstyles=None, mask_path='mask.mif'):
    """Draw a grid of slices through per-shell l2-norm spectrum images.

    Rows: one per (b, l) pair listed in ``bs_ls``. Columns: one per (input,
    orientation) pair, at index ``iamp * len(orientations_slices) + iorient``.
    Each cell shows the orientation's planes, each rotated with np.rot90 and
    concatenated left to right.

    Args:
        image_paths: one entry per input; each entry lists spectrum image paths
            in the same order as ``bs_ls``.
        orientations_slices: per orientation, a list of index tuples, each
            selecting a single 2D plane of a 3D volume.
        bs_ls: list of ``(b, ls)``; degree ``l`` is read from volume ``l // 2``.
        vmin_max: optional ``(vmin, vmax)`` with one entry per row. If None,
            each row's window spans the 1st percentile to 1.1 x the 98th
            percentile of in-mask (> 0.5) voxels, taking the extremes over inputs.
        col_labels: optional titles, one per column.
        figsize, dpi: passed to plt.figure.
        axes_pad: padding between grid cells.
        colbar: draw one colorbar per row on the right edge.
        showmask: overlay the mask outline as a contour.
        label_fontsize: font size for row and column labels.
        maskstyles: keyword arguments for Axes.contour; defaults to a thin,
            semi-transparent yellow outline.
        mask_path: mask image, used for the auto windows and the outline.

    Returns:
        ``(fig, (vmin, vmax))``.

    Raises:
        MRtrixError: wrong number of labels or windows, empty mask, a selected
            slice that is not 2D, or differing voxel spacings within a cell.
    """
    from mrtrix3 import app, MRtrixError
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import AxesGrid
    if maskstyles is None:
        maskstyles = {'colors': 'y', 'alpha': 0.5, 'levels': [0.5 + 1e-16, 1e16], 'linewidths': 0.5, 'cmap': None}
    orientations = len(orientations_slices)
    amps = len(image_paths)
    cols = amps * orientations
    mask_image = Image(mask_path)
    vox = mask_image.vox
    mask = mask_image.data
    inmask = mask > 0.5
    if col_labels is None:
        col_labels = [''] * cols
    if len(col_labels) != cols:
        raise MRtrixError('expected %i column labels, got %i' % (cols, len(col_labels)))

    # images[iamp][irow]: the 3D volume shown in row irow for input iamp
    images = [[] for _ in image_paths]
    row_labels = []
    for iamp, paths in enumerate(image_paths):
        for ib, (b, ls) in enumerate(bs_ls):
            app.console(str([iamp, paths[ib]]))
            volume = Image(paths[ib]).data  # read once, not once per degree
            images[iamp] += [volume[..., l // 2] for l in ls]
            if iamp == 0:
                row_labels += ['b=%s, l=%i' % (b, l) for l in ls]
    rows = len(row_labels)

    plt.close()
    fig = plt.figure(figsize=figsize, dpi=dpi)
    grid = AxesGrid(fig,
                    111,
                    nrows_ncols=(rows, cols),
                    axes_pad=axes_pad,
                    share_all=True,
                    label_mode="L",
                    cbar_location="right",
                    cbar_mode="edge" if colbar else None,
                    cbar_size="7%" if cols > 2 else "3%",
                    cbar_pad=("2%" if cols > 2 else "1%") if colbar else None)

    if vmin_max is None:
        if not np.any(inmask):
            raise MRtrixError('mask is empty: cannot estimate display windows')
        vmin = [np.inf] * rows
        vmax = [-np.inf] * rows
        for iamp in range(amps):
            for irow in range(rows):
                lo, hi = np.percentile(images[iamp][irow][inmask], (1, 98))
                vmin[irow] = min(vmin[irow], lo)
                vmax[irow] = max(vmax[irow], hi * 1.1)
    else:
        vmin, vmax = vmin_max
    if len(vmin) != rows or len(vmax) != rows:
        raise MRtrixError('expected %i display windows, got %i/%i' % (rows, len(vmin), len(vmax)))

    for iorient, slices in enumerate(orientations_slices):
        for iamp in range(amps):
            icol = iamp * orientations + iorient
            for irow in range(rows):
                app.console(str([iamp, irow, icol]))
                planes = [images[iamp][irow][sl] for sl in slices]
                # in-plane voxel sizes (header strings) of each plane
                spacing = [[vox[i] for i in np.where(np.array(pl.shape) > 1)[0]] for pl in planes]
                for sp in spacing:
                    if len(sp) != 2:
                        raise MRtrixError('slice must be 2D, got: ' + str(sp))
                first = np.array([float(s) for s in spacing[0]])
                for sp in spacing[1:]:
                    if not np.allclose(np.array([float(s) for s in sp]), first):
                        app.console('error:' + str(spacing))
                        raise MRtrixError("multiple voxel spacings in one image not supported, split them across orientations")
                # np.rot90 maps a plane of shape (a, b) to (b, a): image columns follow the
                # first in-plane axis and image rows the second, so scale the extent accordingly.
                img = np.concatenate([np.rot90(np.squeeze(pl)) for pl in planes], axis=1)
                extent = [0, img.shape[1] * first[0], 0, img.shape[0] * first[1]]
                ax = grid[irow * cols + icol]
                p = ax.imshow(img, origin='upper', cmap='gray', interpolation='none',
                              vmin=vmin[irow], vmax=vmax[irow], extent=extent)
                if showmask:
                    outline = np.concatenate([np.rot90(np.squeeze(mask[sl])) for sl in slices], axis=1)
                    ax.contour(outline, origin='upper', extent=extent, **maskstyles)
                ax.set_yticks([])
                ax.set_xticks([])
                ax.grid(False)
                if colbar and icol == cols - 1:
                    ax.cax.colorbar(p, drawedges=False)
                if col_labels and irow == 0:
                    ax.set_title(col_labels[icol], fontdict={'fontsize': label_fontsize})
                if row_labels and icol == 0:
                    ax.set_ylabel(row_labels[irow], fontdict={'fontsize': label_fontsize})
    return fig, (vmin, vmax)
def usage(cmdline):
    """Declare this script's command-line interface.

    Args:
        cmdline: parser supplied by the MRtrix3 script framework; used through
            argparse-style add_argument plus set_author, set_synopsis and
            add_description (presumably an argparse.ArgumentParser subclass).
    """
    cmdline.set_author('Max Pietsch (maximilian.pietsch@kcl.ac.uk)')
    cmdline.set_synopsis('generate screenshots of l2 norm of shell-specific SH coefficients')
    cmdline.add_description('part of dStripe')
    cmdline.add_argument('amp', nargs='+', help='diffusion MRI series')
    cmdline.add_argument('plotpath', help='output plot')
    cmdline.add_argument('-mask', help='mask image (default: computed from the first input with dwi2mask; '
                                       'the mask outline is only drawn when this option is given)')
    cmdline.add_argument('-column_labels', help='comma separated list of column titles, one for each input dMRI image')
    cmdline.add_argument('-orientation', help='slice orientation: sag (default), ax or cor')
    cmdline.add_argument('-croptomask', help='boolean (e.g. true/false): crop all inputs to the mask extent '
                                             'along the third image axis')
    cmdline.add_argument('-nslices', type=int, help='number of slices per panel, evenly spaced between 25 and '
                                                    '75 percent of the image extent (default: 5)')
    cmdline.add_argument('-lmax', type=int, help='even harmonic degree shown in addition to l=0 for shells with '
                                                 'b >= bmin_nonl0 (default: 2)')
    cmdline.add_argument('-bmin_nonl0', type=float, help='minimum b-value for which the l=lmax spectrum is shown '
                                                         '(default: 2000)')
    # app.add_dwgrad_import_options(cmdline) # TODO
def execute(): # pylint: disable=unused-variable
    """Script entry point: compute per-shell l2-norm spectra of each input and save a QC plot.

    Validates the inputs and reads the shell b-values and sizes from the first
    input. It then enters the scratch directory and prepares a float32 mask
    (user-supplied, or dwi2mask on the first input). Optionally, inputs and
    mask are cropped along axis 2 to the mask. ``nslices`` planes are picked
    between 25% and 75% of the chosen axis, and each input's spectra are
    computed with amp2l2normspectrum and plotted. Writes the figure to
    ``plotpath`` and the per-row display windows to ``plotpath + '.vmin_max'``.

    Raises:
        MRtrixError: missing inputs, inputs with differing spatial sizes, fewer
            than two shells, an innermost shell above B0THRESH, or invalid
            option values.
    """
    from mrtrix3 import MRtrixError
    from mrtrix3 import app, image, run
    import os
    import shlex
    import shutil

    def check_input_path(filename):
        """Raise MRtrixError unless ``filename`` is an existing file."""
        if not os.path.isfile(filename):
            raise MRtrixError('input path not found: ' + filename)

    def parse_bool(value):
        """Interpret a command-line value such as 'true'/'false'/'1'/'0' as a boolean."""
        text = str(value).strip().lower()
        if text in ('', '0', 'false', 'no', 'n', 'off'):
            return False
        if text in ('1', 'true', 'yes', 'y', 'on'):
            return True
        raise MRtrixError('cannot interpret boolean value: ' + str(value))

    # ________ check inputs
    # Resolve user paths before goto_scratch_dir() switches into the scratch directory,
    # after which relative paths would no longer point at the user's files.
    for amp in app.ARGS.amp:
        check_input_path(amp)
    if app.ARGS.mask is not None:
        check_input_path(app.ARGS.mask)
    amp_paths = [os.path.abspath(amp) for amp in app.ARGS.amp]
    mask_path = None if app.ARGS.mask is None else os.path.abspath(app.ARGS.mask)
    plot_path = os.path.abspath(app.ARGS.plotpath)

    dwi_sizes = image.Header(amp_paths[0]).size()
    for amp in amp_paths[1:]:
        # one mask is shared by all inputs, so their spatial grids must agree
        if image.Header(amp).size()[:3] != dwi_sizes[:3]:
            raise MRtrixError('inputs must share the same spatial dimensions: ' + amp + ' differs from ' + amp_paths[0])
    bs = image.mrinfo(amp_paths[0], '-shell_bvalues').split()
    bsizes = image.mrinfo(amp_paths[0], '-shell_sizes').split()
    nshells = len(bs)
    if nshells < 2:
        raise MRtrixError('expected at least two shells, got b-values: ' + ','.join(bs))
    # shells must stay distinct after rounding to the nearest 50
    bs_int = [int(round(float(b) / 500., 1) * 500) for b in bs]
    if len(set(bs_int)) != nshells:
        raise MRtrixError('FIXME rounding issue? b-values ' + str(bs) + ' round to ' + str(bs_int))
    if float(bs[0]) > B0THRESH:
        raise MRtrixError("expected inner shell close to b=0, got b=" + str(bs[0]))
    app.console('b values: ' + ','.join(bs))
    app.console('b sizes: ' + ','.join(bsizes))

    # ________ options
    nslices = 5 if app.ARGS.nslices is None else int(app.ARGS.nslices)
    if nslices < 1:
        raise MRtrixError('-nslices must be at least 1, got ' + str(nslices))
    lmax = 2 if app.ARGS.lmax is None else int(app.ARGS.lmax)
    if lmax < 0 or lmax % 2:
        raise MRtrixError('-lmax must be a non-negative even number, got ' + str(lmax))
    bmin_nonl0 = 2000 if app.ARGS.bmin_nonl0 is None else float(app.ARGS.bmin_nonl0)
    croptomask = app.ARGS.croptomask is not None and parse_bool(app.ARGS.croptomask)
    orientation = app.ARGS.orientation
    if orientation is None or orientation.startswith('sag'):
        axis = 0
    elif orientation.startswith('ax'):
        axis = 2
    elif orientation.startswith('cor'):
        axis = 1
    else:
        raise MRtrixError("orientation not understood: " + str(orientation))
    labels = None
    if app.ARGS.column_labels:
        labels = app.ARGS.column_labels.split(',')
        if len(labels) != len(amp_paths):
            raise MRtrixError('one column label per input dMRI image required')

    app.make_scratch_dir()
    app.goto_scratch_dir()

    # The mask outline is only drawn for a user-supplied mask.
    showmask = mask_path is not None
    if showmask:
        run.command('mrconvert ' + shlex.quote(mask_path) + ' mask.mif -stride 0,1,2 -datatype float32')
    else:
        run.command('dwi2mask ' + shlex.quote(amp_paths[0]) + ' - | mrconvert - -stride 0,1,2 -datatype float32 mask.mif')

    amp_args = [shlex.quote(amp) for amp in amp_paths]
    if croptomask:
        cropped = ['amp_%i.mif' % iamp for iamp in range(len(amp_args))]
        for source, target in zip(amp_args, cropped):
            run.command('mrgrid ' + source + ' crop -mask mask.mif -axis 2 0,0 ' + target)
        amp_args = cropped
        run.command('mrgrid mask.mif crop -mask mask.mif -axis 2 0,0 mask_cropped.mif')
        run.function(os.remove, 'mask.mif')
        run.function(shutil.move, 'mask_cropped.mif', 'mask.mif')
        dwi_sizes = image.Header(cropped[0]).size()

    # one orientation: nslices planes between 25% and 75% of the chosen axis
    positions = np.floor(dwi_sizes[axis] * np.linspace(0.25, 0.75, nslices)).astype(int)
    slices = [[tuple(slice(int(pos), int(pos) + 1) if dim == axis else slice(None) for dim in range(3))
               for pos in positions]]
    app.console('slices: ' + str(slices))

    # shells below bmin_nonl0 show only l=0; the others show l=0 and l=lmax
    bs_ls = [(b, [0] if float(b) < bmin_nonl0 else [0, lmax]) for b in bs]
    col_labels = None  # plot() substitutes empty titles when None
    if labels is not None:
        col_labels = []
        for lbl in labels:
            col_labels += [lbl] + [''] * (len(slices) - 1)

    l2_norms = []
    for iamp, amp in enumerate(amp_args):
        outdir = str(iamp)  # relative to the scratch directory (current working directory)
        os.makedirs(outdir)
        l2_norms.append(amp2l2normspectrum(amp, bs, outdir))
    fig, vmin_max = plot(l2_norms, slices, bs_ls, showmask=showmask, col_labels=col_labels)
    app.console('vmin_max:' + str(vmin_max))
    app.console('writing plot to ' + plot_path)
    fig.savefig(plot_path, dpi=200, bbox_inches='tight', pad_inches=0)
    np.savetxt(plot_path + '.vmin_max', np.array(vmin_max))


# Execute the script: hand control to the MRtrix3 script framework, which is
# expected to call usage() and execute() defined in this file.
mrtrix3.execute()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment