Last active
May 7, 2018 05:14
-
-
Save robotsorcerer/97256009137a88314a23890ef87c4b4f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import copy | |
import time | |
import random | |
import h5py | |
import h5sparse | |
import logging | |
import numpy as np | |
from scipy.sparse import vstack | |
import matplotlib as mpl | |
mpl.use('qt5agg') | |
# mpl.use('agg') | |
import matplotlib.pyplot as plt | |
# plt.switch_backend('agg') | |
logger = logging.getLogger(__name__) | |
# directories where the masks and dijs are | |
# e.g /path_to_tdrive/Users/.. .. /Lekan/f_dijs | |
# /path_to_tdrive/Users/.. .. /Lekan/f_masks | |
from config import f_masks_dir, f_dijs_dir, save_dir | |
def get_mask(mask, f_masks_dir=None, disp=False):
    """Load the structure masks of one case from ``<f_masks_dir>/<mask>.h5``.

    Args:
        mask: case identifier; the file stem of the HDF5 mask file.
        f_masks_dir: directory holding the mask files.
        disp: if true, print every object path in the file and the organ names.

    Returns:
        dict. If the ``oar_ptvs`` group has a ``structs`` dataset (augmented
        data) the dict is only ``{'struct': ...}``. Otherwise it holds the
        individual masks ('bladder', 'body', 'dose_mask', 'fem_head_lt',
        'fem_head_rt', 'ptv', 'rectum'), 'fluence' (from
        ``fluence_grp/fluence``), 'organ_names' (every ``oar_ptvs`` key except
        'dose') and 'struct' (the six masks stacked on a new last axis, then
        transposed with axes (-1, -2, 0, 1)).
    """
    # used to check fields in hdf5 files
    def see(n):
        print(n)

    def exp(x, axis):
        return np.expand_dims(x, axis=axis)

    with h5py.File('{}/{}.h5'.format(f_masks_dir, mask), 'r') as mf:
        if disp:
            mf.visit(see)
        masks = mf['oar_ptvs']
        logger.debug(' loading patient {}'.format(mask))

        # ``ds[()]`` reads a whole dataset into memory; ``Dataset.value`` was
        # removed in h5py 3.0.
        if 'structs' in masks:  # augmented data
            return dict(struct=masks['structs'][()])

        fluence = mf['fluence_grp']['fluence'][()]
        organ_names = [k for k in masks if k != 'dose']
        if disp:
            print(organ_names)

        bladder = masks['bladder'][()]
        body = masks['body'][()]
        dose_mask = masks['dose'][()]
        # some files spell the femoral heads with spaces instead of underscores
        fem_head_lt = (masks['fem_head_lt'][()] if 'fem_head_lt' in organ_names
                       else masks['fem head lt'][()])
        fem_head_rt = (masks['fem_head_rt'][()] if 'fem_head_rt' in organ_names
                       else masks['fem head rt'][()])
        ptv = masks['ptv'][()]
        rectum = masks['rectum'][()]

    struct = np.concatenate((exp(bladder, -1), exp(fem_head_lt, -1),
                             exp(fem_head_rt, -1), exp(body, -1),
                             exp(ptv, -1), exp(rectum, -1)), axis=-1)
    struct = struct.transpose(-1, -2, 0, 1)  # Organ x Slice x H x W
    return dict(bladder=bladder, dose_mask=dose_mask,
                fem_head_lt=fem_head_lt, fem_head_rt=fem_head_rt,
                ptv=ptv, rectum=rectum,
                organ_names=organ_names, struct=struct,
                fluence=fluence, body=body)
class BundleType(object):
    """Record-like container whose set of fields is fixed at construction.

    Every key of the mapping given to ``__init__`` becomes an attribute.
    Those attributes may be reassigned later, but assigning any name that
    does not already exist raises ``AttributeError``.
    """

    def __init__(self, variables):
        # Go through object.__setattr__ directly: our own __setattr__ would
        # reject these names because they do not exist yet.
        for field_name, field_value in variables.items():
            object.__setattr__(self, field_name, field_value)

    def __setattr__(self, key, value):
        """Rebind an existing attribute; refuse to create new ones."""
        if hasattr(self, key):
            object.__setattr__(self, key, value)
            return
        raise AttributeError("%r has no attribute %s" % (self, key))
# Module-level default optimiser settings/state. NumpyBFGS.optimize reads the
# module-level ``state``, so it must exist at import time; building it from a
# name that is only defined inside ``main`` raised NameError on import.
# BundleType freezes its field set, so every field the optimiser writes back
# must be listed here.
OptimState = {
    'maxIter': 1500,
    'maxEval': 10000,
    'tolFun': 1e-4,  # Termination tolerance on the first-order optimality
    'tolX': 1e-9,  # Termination tol on progress in terms of func/param changes
    'lineSearch': None,  # A line search function
    'lineSearchOptions': None,  # Options passed to the line search function
    'learningRate': 0.01,
    'nCorrection': 100,
    'verbose': True,
    'tmp1': np.array(()),
    'funcEval': None,
    'nIter': 0,
    'dir_bufs': [],
    'stp_bufs': [],
    'd': None,
    't': None,
    'al': None,
    'old_dirs': [],
    'old_stps': [],
    'Hdiag': 1,
    'g_old': None,
    'ro': None,
    'f_old': None,
}
state = BundleType(OptimState)
class NumpyBFGS(object):
    """Fluence-map optimiser using L-BFGS on NumPy/SciPy arrays.

    The routine follows the structure of torch's ``optim.lbfgs`` (see the log
    message in ``optimize``). The objective is a one-sided quadratic dose
    penalty (separate per-voxel over- and under-dose weights) plus a small L1
    term on the fluence vector ``x``. The dose-influence matrix ``Dij`` is
    loaded per case by :meth:`get_dij`.
    """

    # Weight of the L1 penalty on the fluence in ``obj_eval``.
    l1_weight = 1e-10

    def __init__(self, case=None, dijs_dir=None, bundle=None):
        """
        Args:
            case: case identifier; locates ``<dijs_dir>/<case>.hdf5`` and is
                used in plot titles.
            dijs_dir: directory holding the per-case Dij files.
            bundle: dict as returned by ``get_mask`` (non-augmented branch);
                must contain 'ptv', 'body', 'rectum', 'bladder', 'dose_mask',
                'fem_head_lt' and 'fem_head_rt'.
        """
        super(NumpyBFGS, self).__init__()
        self.Dij = None
        self.case = case
        self.beams = None
        self.f_dijs_dir = dijs_dir
        self.fontdict = {'fontsize': 14, 'fontweight': 'bold'}
        max_dose = dict(bladder=79.0, fem_head_lt=48.0, fem_head_rt=48.0,
                        ptv=100.0, rectum=79.0, body=76.0)

        dose_mask = np.asarray(bundle['dose_mask']).astype(bool)
        # Explicit float arrays: zeros_like/ones_like(dose_mask) would inherit
        # the mask's bool/int dtype and truncate the dose values (a bool mask
        # turned every target into True). Indexing with a *boolean* mask keeps
        # the selection correct even when the stored mask is 0/1 integers.
        rx_dose = np.zeros(dose_mask.shape, dtype=np.float64)
        overPenalty = np.ones(dose_mask.shape, dtype=np.float64)
        # 'body' is written first so it acts as background: if the body
        # contour encloses the other structures, writing it later would
        # overwrite their (e.g. PTV) values; if it does not, order is moot.
        for name in ('body', 'ptv', 'rectum', 'bladder', 'fem_head_lt', 'fem_head_rt'):
            voxels = np.where(bundle[name])
            rx_dose[voxels] = max_dose[name]
            overPenalty[voxels] = max_dose[name]
        underPenalty = np.zeros(dose_mask.shape, dtype=np.float64)
        underPenalty[np.where(bundle['ptv'])] = 100.0

        self.obj_hist = list()
        self.dose = None
        self.bundle = bundle
        # per-voxel vectors restricted to (and ordered like) the dose mask
        self.overPenalty = overPenalty[dose_mask]
        self.underPenalty = underPenalty[dose_mask]
        self.target_dose = rx_dose[dose_mask]
        # NOTE: not read anywhere in this class.
        self.options = {'disp': True,
                        'maxiter': 1500,  # Change this to say 15000 during testing
                        'maxfun': 10000,  # max # of obj function evals; change to 10000 during testing
                        'maxls': 20,  # max line search param
                        'maxfev': None}
        self.state = None

    def get_dij(self, beams, disp=None):
        """Load and stack the per-beam dose-influence matrices for ``self.case``.

        Non-string beams are formatted as zero-padded 3-character keys
        (10 -> '010'); '360' is looked up as '359.9'. The stacked matrix is
        stored transposed in ``self.Dij`` and the beam keys in ``self.beams``.
        """
        if not isinstance(beams[0], str):
            beams = ['{:0>3}'.format(b) for b in beams]  # so we can do a proper lookup of Dijs
        else:
            beams = list(beams)  # copy: don't mutate the caller's sequence below
        if '360' in beams:
            beams[beams.index('360')] = '359.9'  # fix lookup error in h5 dictionary
        self.beams = beams

        path = '{}/{}.hdf5'.format(self.f_dijs_dir, self.case)
        dijs = []
        # read-only: older h5py versions default to append mode
        with h5sparse.File(path, 'r') as df:
            for beam in beams:
                dij = df['Dij/{}'.format(beam)].value
                if disp:
                    print('dij: ', dij.shape)
                dijs.append(dij)
        # one vstack over all beams instead of re-stacking the growing matrix
        # (starting from ``None``) once per beam
        self.Dij = vstack(dijs).T

    def obj_eval(self, x):
        """Return ``(f, g)``: objective value and gradient at fluence ``x``.

        f(x) = 0.5 * sum(over  * max(Dij x - target, 0)**2)
             + 0.5 * sum(under * min(Dij x - target, 0)**2)
             + l1_weight * ||x||_1
        The voxel dose ``Dij x`` is also cached in ``self.dose``.
        """
        l1 = self.l1_weight
        # TODO: Fix nonnegativity of x
        self.dose = self.Dij.dot(x)
        diff = self.dose - self.target_dose
        over = np.maximum(diff, 0)   # over-dose residual
        under = np.minimum(diff, 0)  # under-dose residual
        f = float(0.5 * self.overPenalty.dot(over ** 2)
                  + 0.5 * self.underPenalty.dot(under ** 2)
                  + l1 * np.linalg.norm(x, ord=1))
        # The L1 (sub)gradient sign(x) lives in fluence space, so it is added
        # after projecting the voxel residuals back through Dij^T.
        g = (self.Dij.T.dot(self.overPenalty * over + self.underPenalty * under)
             + l1 * np.sign(x))
        return f, g

    @staticmethod
    def _lbfgs_direction(g, old_dirs, old_stps, Hdiag):
        """L-BFGS two-loop recursion: return ``(d, ro, al)`` with d = -H g.

        ``old_dirs`` holds the steps s_i and ``old_stps`` the gradient changes
        y_i (names kept from the torch routine), oldest first.
        """
        k = len(old_dirs)
        ro = np.array([1.0 / np.dot(old_stps[i], old_dirs[i]) for i in range(k)])
        al = np.zeros(k)
        q = -g
        for i in range(k - 1, -1, -1):  # newest to oldest
            al[i] = old_dirs[i].dot(q) * ro[i]
            q = q - al[i] * old_stps[i]
        r = q * Hdiag  # multiply by initial Hessian
        for i in range(k):  # oldest to newest
            be_i = old_stps[i].dot(r) * ro[i]
            r = r + (al[i] - be_i) * old_dirs[i]
        return r, ro, al

    def optimize(self, beams, dvh=False, optim_state=None):
        """Minimise ``obj_eval`` over the fluence of ``beams`` with L-BFGS.

        Args:
            beams: beam angles or keys, forwarded to :meth:`get_dij`.
            dvh: if true, call :meth:`plot_dvh` after optimisation.
            optim_state: ``BundleType`` holding settings and optimiser state;
                defaults to the module-level ``state``. Mutated in place.

        Returns:
            (x, obj_hist): the final fluence vector and the list of objective
            values recorded on this instance so far.
        """
        st = state if optim_state is None else optim_state
        maxIter = st.maxIter or 20
        maxEval = st.maxEval or maxIter
        tolFun = st.tolFun or 1e-5
        tolX = st.tolX or 1e-9
        nCorrection = st.nCorrection or 100
        lineSearch = st.lineSearch
        lineSearchOpts = st.lineSearchOptions
        learningRate = st.learningRate or 1
        # default to verbose only when unset (``verbose or True`` was always True)
        isverbose = True if st.verbose is None else st.verbose
        st.funcEval = st.funcEval or 0
        st.nIter = st.nIter or 0

        if isverbose:
            logger.debug('<optim.lbfgs with numpy on cpu > ...')

        tic = time.time()
        self.get_dij(beams)
        logger.debug(' Dij from {} took {}'
                     .format(self.f_dijs_dir.rsplit()[-1], time.time() - tic))
        self.beams = beams

        # evaluate initial f(x) and df/dx
        x = np.zeros((self.Dij.shape[-1]), dtype=np.float32)
        f, g = self.obj_eval(x)
        self.obj_hist.append(f)
        currentFuncEval = 1
        st.funcEval += 1

        # check optimality of initial point
        if np.abs(g).sum() <= tolFun:
            logger.debug('optimality condition below tolFun')
            return x, self.obj_hist

        # variables cached in state (for tracing)
        d = st.d
        t = st.t
        old_dirs = st.old_dirs
        old_stps = st.old_stps
        Hdiag = st.Hdiag
        g_old = np.zeros_like(g)
        f_old = st.f_old

        nIter = 0
        while nIter < maxIter:
            nIter += 1
            st.nIter += 1
            logger.debug('current lbfgs iterate {} out of {}'.format(st.nIter, maxIter))

            # ---- descent direction ----
            if st.nIter == 1:
                d = -g
                old_dirs = []
                old_stps = []
                Hdiag = 1
            else:
                y = g - g_old
                s = d * t
                ys = y.dot(s)
                if ys > 1e-10:  # curvature condition: update memory
                    logger.debug('<updating lbfgs-memory>')
                    if len(old_dirs) == nCorrection:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                    old_dirs.append(s)
                    old_stps.append(y)
                    # update scale of initial Hessian approximation
                    Hdiag = ys / y.dot(y)
                d, st.ro, st.al = self._lbfgs_direction(g, old_dirs, old_stps, Hdiag)
            g_old = g
            f_old = f

            # ---- step length ----
            gtd = g.dot(d)  # directional derivative
            if gtd > -tolX:  # no progress possible along d
                logger.debug('gtd > -tolX; breaking')
                break
            if st.nIter == 1:
                t = min(1, 1 / np.abs(g).sum()) * learningRate
            else:
                t = learningRate

            lsFuncEval = 0
            if callable(lineSearch):
                f, g, x, t, lsFuncEval = lineSearch(self.obj_eval, x, t, d, f, g, gtd, lineSearchOpts)
                self.obj_hist.append(f)
            else:
                # no line search, simply move with fixed step
                x = x + t * d
                if nIter != maxIter:
                    # re-evaluate only if this is not the last iteration
                    f, g = self.obj_eval(x)
                    lsFuncEval = 1
                    self.obj_hist.append(f)
            currentFuncEval += lsFuncEval
            st.funcEval += lsFuncEval

            # ---- stopping conditions ----
            if nIter == maxIter:
                logger.debug('reached max number of iterations')
                break
            if currentFuncEval >= maxEval:
                logger.debug('max nb of function evals')
                break
            if np.abs(g).sum() <= tolFun:
                logger.debug('optimality condition below tolFun')
                break
            if np.abs(d * t).sum() <= tolX:
                logger.debug('step size below tolX')
                break
            if abs(f - f_old) < tolX:
                logger.debug('function value changing less than tolX')
                break

        # reset so the next call starts again with a steepest-descent step
        st.nIter = 0
        st.old_dirs = old_dirs
        st.old_stps = old_stps
        st.Hdiag = Hdiag
        st.g_old = g_old
        st.f_old = f_old
        st.t = t
        st.d = d

        self.dose = self.Dij.dot(x)
        if dvh:
            self.plot_dvh()
        return x, self.obj_hist

    def plot_dvh(self):
        """Save six dose slices and a cumulative DVH plot for ``self.dose``.

        PNGs go to ``<save_dir>/lbfgs/dose/case_<case>`` and
        ``<save_dir>/lbfgs/dvh/case_<case>``.
        """
        dose_mask = self.bundle['dose_mask']
        dosearray = np.zeros_like(dose_mask, dtype=np.float32)
        dosearray[np.nonzero(dose_mask)] = self.dose
        slc = dosearray.shape[-1] // 2

        plt.close('all')
        f, ax = plt.subplots(2, 3, figsize=(16, 12))
        # (row, col, slice offset from the middle slice, title format)
        panels = [
            (0, 0, -6, 'Case {} and slice {}\n{} '),
            (0, 1, -4, 'Case {}; and slice {}\n{}'),
            (0, 2, -1, 'Case {}; and slice {}\n{}'),
            (1, 0, 0, 'Case {}; and slice {}\n{}'),
            (1, 1, 2, 'Case {}; and slice {}\n{}'),
            (1, 2, 4, 'Case {}; and slice {}\n{}'),
        ]
        for row, col, offset, title in panels:
            panel = ax[row, col]
            panel.imshow(dosearray[:, :, slc + offset], cmap='jet')
            panel.set_title(title.format(self.case, slc + offset, self.beams), fontdict=self.fontdict)
            panel.xaxis.set_tick_params(labelsize=14)
            panel.yaxis.set_tick_params(labelsize=14)
        xmin, xmax = ax[0, 1].get_xlim()
        ymin, ymax = ax[0, 1].get_ylim()
        ax[0, 1].text(xmin + 5, ymax - 50, 'Cost: %.4f' % (-self.obj_hist[-1]),
                      self.fontdict, backgroundcolor='pink', animated=True, color='k', rasterized=True,
                      size='x-large', style='oblique', ma='center')
        plt.tight_layout()
        lbfgs_dir = '{}/lbfgs/dose/case_{}'.format(save_dir, self.case)
        if not os.path.exists(lbfgs_dir):
            os.makedirs(lbfgs_dir)
        # savefig takes no ``fontdict``; recent matplotlib rejects unknown kwargs
        f.savefig('{}/dose_{}_{}.png'.format(lbfgs_dir, slc, random.randint(0, 99)))

        masks = ['fem_head_rt', 'body', 'bladder', 'rectum', 'ptv', 'fem_head_lt']
        f2, ax2 = plt.subplots(1, 1, figsize=(10, 7))
        lbfgs_dir = '{}/lbfgs/dvh/case_{}'.format(save_dir, self.case)
        if not os.path.exists(lbfgs_dir):
            os.makedirs(lbfgs_dir)
        for sId in masks:
            mask_idx = np.where(self.bundle[sId])
            tot_vox = self.bundle[sId].sum()
            hist, bins = np.histogram(dosearray[mask_idx].flatten(), bins=500, range=(0, dosearray.max()))
            # percentage of the structure's voxels above each dose bin
            temp = (100. - hist.cumsum() * 100.0 / tot_vox)
            ax2.plot(bins[:-1], temp, label=sId, linewidth=2)
        ax2.legend(fancybox=True, framealpha=0.5, bbox_to_anchor=(0.6, 0.95),
                   borderaxespad=0., scatterpoints=1, ncol=1,
                   loc=2, fontsize=16)
        ax2.set_title("LBFGS DVH | Case {}".format(self.case), fontdict=self.fontdict)
        ax2.set_xlabel('Fractional Dose [Gy]', fontdict=self.fontdict)
        ax2.set_ylabel('Fractional Volume [%]', fontdict=self.fontdict)
        f2.savefig('{}/dvh_{}_{}.png'.format(lbfgs_dir, slc, random.randint(0, 99)))
def main(case=None, f_masks_dir=None, f_dijs_dir=None):
    """Optimise a fixed 5-beam plan for ``case`` and print the results.

    Args:
        case: case identifier (file stem of the mask and Dij files).
        f_masks_dir: directory containing ``<case>.h5`` mask files.
        f_dijs_dir: directory containing ``<case>.hdf5`` Dij files.
    """
    # NumpyBFGS.optimize reads its settings from the module-level ``state``.
    global state
    #----------------------------------------------------------------------#
    # Initializations
    #----------------------------------------------------------------------#
    # By keyword: get_mask's third positional parameter is ``disp``, so the
    # Dij directory must not be passed there.
    bundle = get_mask(case, f_masks_dir=f_masks_dir)
    # NumpyBFGS.__init__ takes no ``dvh`` argument; plotting is requested in optimize().
    lbfgs_obj = NumpyBFGS(case=case, dijs_dir=f_dijs_dir, bundle=bundle)
    OptimState = {
        'maxIter': 1500,
        'maxEval': 10000,
        'tolFun': 1e-4,  # Termination tolerance on the first-order optimality
        'tolX': 1e-9,  # Termination tol on progress in terms of func/param changes
        'lineSearch': None,  # A line search function
        'lineSearchOptions': None,  # Options passed to the line search function
        'learningRate': 0.01,
        'nCorrection': 100,
        'verbose': True,
        'tmp1': np.array(()),
        'funcEval': None,
        'nIter': 0,
        'dir_bufs': [],
        'stp_bufs': [],
        'd': None,
        't': None,
        'al': None,
        'old_dirs': [],
        'old_stps': [],
        'Hdiag': 1,
        'g_old': None,
        'ro': None,
        'f_old': None,
    }
    state = BundleType(OptimState)
    #----------------------------------------------------------------------#
    beams = np.array([10, 60, 150, 260, 330])
    # optimize(beams, dvh) returns (x, obj_hist)
    x, fhist = lbfgs_obj.optimize(beams, dvh=True)
    #----------------------------------------------------------------------#
    print('optimized x: ', x)
    print('fhist: ', fhist)
    print('current: ', state.funcEval)
if __name__ == '__main__':
    # ``args`` was referenced without ever being parsed; build it here.
    import argparse
    parser = argparse.ArgumentParser(description='L-BFGS fluence optimisation for one case.')
    parser.add_argument('--case', required=True,
                        help='case identifier (file stem of the mask/Dij files)')
    # defaults come from the directories imported from config at the top of the file
    parser.add_argument('--masks_dir', default=f_masks_dir,
                        help='directory containing <case>.h5 mask files')
    parser.add_argument('--dijs_dir', default=f_dijs_dir,
                        help='directory containing <case>.hdf5 Dij files')
    args = parser.parse_args()

    tikko = time.time()
    main(case=args.case, f_masks_dir=args.masks_dir, f_dijs_dir=args.dijs_dir)
    tikka = time.time()
    print('total time taken ', tikka - tikko)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Check out the fluence after just two iterations!