Skip to content

Instantly share code, notes, and snippets.

@robotsorcerer
Last active May 7, 2018 05:14
Show Gist options
  • Save robotsorcerer/97256009137a88314a23890ef87c4b4f to your computer and use it in GitHub Desktop.
Save robotsorcerer/97256009137a88314a23890ef87c4b4f to your computer and use it in GitHub Desktop.
import os
import copy
import time
import random
import h5py
import h5sparse
import logging
import numpy as np
from scipy.sparse import vstack
import matplotlib as mpl
mpl.use('qt5agg')
# mpl.use('agg')
import matplotlib.pyplot as plt
# plt.switch_backend('agg')
logger = logging.getLogger(__name__)
# Directories where the mask (.h5) and Dij (.hdf5) files live, e.g.
#   /path_to_tdrive/Users/.../Lekan/f_dijs
#   /path_to_tdrive/Users/.../Lekan/f_masks
from config import f_masks_dir, f_dijs_dir, save_dir
def get_mask(mask, f_masks_dir=None, disp=False):
    """Load the structure masks of one patient from ``<f_masks_dir>/<mask>.h5``.

    Args:
        mask: Patient/case identifier; also the HDF5 file stem.
        f_masks_dir: Directory containing the mask files.
        disp: When truthy, print every object name in the file and the list
            of organ names that were found.

    Returns:
        dict: For "augmented" files (an ``oar_ptvs/structs`` entry exists)
        only ``{'struct': ...}``.  Otherwise the individual masks
        (``bladder``, ``body``, ``dose_mask``, ``fem_head_lt``,
        ``fem_head_rt``, ``ptv``, ``rectum``), the ``fluence`` array,
        ``organ_names`` and ``struct`` -- the six organ masks stacked on a
        new trailing axis and transposed with ``(-1, -2, 0, 1)``.
    """
    def see(name):
        # Visitor used to list the fields of the HDF5 file.
        print(name)

    # Everything is read inside the ``with`` block: datasets can only be
    # materialised while the file is still open.
    with h5py.File('{}/{}.h5'.format(f_masks_dir, mask), 'r') as mf:
        if disp:
            mf.visit(see)
        masks = mf['oar_ptvs']
        logger.debug(' loading patient {}'.format(mask))

        # ``dataset[()]`` replaces the ``.value`` property, which newer h5py
        # releases no longer provide; it reads the whole dataset either way.
        if "structs" in masks:  # augmented data
            return dict(struct=masks['structs'][()])

        fluence = mf['fluence_grp']['fluence'][()]
        organ_names = [name for name in masks if name != "dose"]
        if disp:
            print(organ_names)

        def read(name, fallback=None):
            # Some files spell the femoral heads with spaces, not underscores.
            if fallback is not None and name not in organ_names:
                name = fallback
            return masks[name][()]

        bladder = read('bladder')
        body = read('body')
        dose_mask = read('dose')
        fem_head_lt = read('fem_head_lt', 'fem head lt')
        fem_head_rt = read('fem_head_rt', 'fem head rt')
        ptv = read('ptv')
        rectum = read('rectum')

    # Stack the organs on a new last axis, then move it to the front.
    struct = np.stack(
        (bladder, fem_head_lt, fem_head_rt, body, ptv, rectum), axis=-1)
    struct = struct.transpose(-1, -2, 0, 1)  # Organ x Slice x H x W

    return dict(bladder=bladder, dose_mask=dose_mask,
                fem_head_lt=fem_head_lt, fem_head_rt=fem_head_rt,
                ptv=ptv, rectum=rectum,
                organ_names=organ_names, struct=struct,
                fluence=fluence, body=body)
class BundleType(object):
    """
    Record-like container: a fixed set of named fields whose values stay
    mutable (comparable to a mutable namedtuple).

    The field names are taken from the mapping given to the constructor;
    assigning to any other name afterwards raises ``AttributeError``.
    """

    def __init__(self, variables):
        # Bypass our own __setattr__ here: it would reject every name,
        # because no field exists yet.
        for name in variables:
            super(BundleType, self).__setattr__(name, variables[name])

    def __setattr__(self, key, value):
        # Fields are frozen: only names that already exist may be rebound.
        if hasattr(self, key):
            super(BundleType, self).__setattr__(key, value)
            return
        raise AttributeError("%r has no attribute %s" % (self, key))
# Default optimiser configuration/state.  It has to exist at module level
# because ``NumpyBFGS.optimize`` reads the module-global ``state``; without
# this definition the assignment below raised NameError at import time.
# ``main`` replaces ``state`` with its own settings before optimising.
OptimState = {
    'maxIter': 1500,
    'maxEval': 10000,
    'tolFun': 1e-4,             # tolerance on the first-order optimality
    'tolX': 1e-9,               # tolerance on function/parameter changes
    'lineSearch': None,         # optional line-search function
    'lineSearchOptions': None,  # options handed to the line search
    'learningRate': 0.01,
    'nCorrection': 100,
    'verbose': True,
    'tmp1': np.array(()),
    'funcEval': None,
    'nIter': None,
    'dir_bufs': [],
    'stp_bufs': [],
    'd': None,
    't': None,
    'al': None,
    'old_dirs': [],
    'old_stps': [],
    'Hdiag': 1,
    'g_old': None,
    'ro': None,
    'f_old': None,
}
state = BundleType(OptimState)
class NumpyBFGS(object):
    """L-BFGS fluence optimiser for a single case, written with NumPy/SciPy.

    The constructor turns the structure masks in ``bundle`` into per-voxel
    target doses and over-/under-dose penalty weights restricted to the dose
    grid.  :meth:`get_dij` loads the per-beam ``Dij`` matrices,
    :meth:`obj_eval` evaluates the penalised objective with its gradient and
    :meth:`optimize` minimises it with a limited-memory BFGS iteration whose
    structure follows Torch's ``optim.lbfgs``.
    """

    # Same logger object as the module-level ``logger`` (identical name).
    _log = logging.getLogger(__name__)

    # Dose level painted into each structure.
    MAX_DOSE = dict(bladder=79.0, fem_head_lt=48.0, fem_head_rt=48.0,
                    ptv=100.0, rectum=79.0, body=76.0)

    # Painting order; later structures overwrite earlier ones where masks
    # overlap.  NOTE(review): ``body`` is painted after ``ptv``, so a body
    # mask that covers the PTV replaces the PTV level -- confirm intent.
    _PAINT_ORDER = ('ptv', 'body', 'rectum', 'bladder',
                    'fem_head_lt', 'fem_head_rt')

    # (slice offset from the central slice, title template), one per panel
    # of the 2x3 dose figure in row-major order.
    _DOSE_PANELS = (
        (-6, 'Case {} and slice {}\n{} '),
        (-4, 'Case {}; and slice {}\n{}'),
        (-1, 'Case {}; and slice {}\n{}'),
        (0, 'Case {}; and slice {}\n{}'),
        (2, 'Case {}; and slice {}\n{}'),
        (4, 'Case {}; and slice {}\n{}'),
    )

    def __init__(self, case=None, dijs_dir=None, bundle=None):
        """Prepare targets and penalties for ``case``.

        Args:
            case: Case identifier; also the stem of the Dij HDF5 file.
            dijs_dir: Directory holding ``<case>.hdf5``.
            bundle: Mapping with the masks ``ptv``, ``body``, ``rectum``,
                ``bladder``, ``fem_head_lt``, ``fem_head_rt`` and
                ``dose_mask`` (the non-augmented result of ``get_mask``).
        """
        super(NumpyBFGS, self).__init__()
        self.Dij = None
        self.case = case
        self.beams = None
        self.f_dijs_dir = dijs_dir
        self.fontdict = {'fontsize': 14, 'fontweight': 'bold'}

        dose_mask = np.asarray(bundle['dose_mask'])
        # Boolean view of the dose grid: indexing with an integer 0/1 mask
        # would be fancy indexing instead of voxel selection.
        in_grid = dose_mask.astype(bool)

        # Explicit float arrays: ``zeros_like(dose_mask)`` would inherit the
        # mask dtype and truncate the dose values for bool/int masks.
        rx_dose = np.zeros(dose_mask.shape, dtype=np.float32)
        over_penalty = np.ones(dose_mask.shape, dtype=np.float32)
        for name in self._PAINT_ORDER:
            voxels = np.asarray(bundle[name]).astype(bool)
            rx_dose[voxels] = self.MAX_DOSE[name]
            over_penalty[voxels] = self.MAX_DOSE[name]

        # Under-dosing is penalised inside the PTV only.
        under_penalty = np.zeros(dose_mask.shape, dtype=np.float32)
        under_penalty[np.asarray(bundle['ptv']).astype(bool)] = 100.0

        self.obj_hist = list()
        self.dose = None
        self.bundle = bundle
        self.overPenalty = over_penalty[in_grid].ravel()
        self.underPenalty = under_penalty[in_grid].ravel()
        self.target_dose = rx_dose[in_grid].ravel()
        self.options = {'disp': True,
                        'maxiter': 1500,
                        'maxfun': 10000,   # max number of objective evals
                        'maxls': 20,       # max line-search steps
                        'maxfev': None}
        self.state = None

    @staticmethod
    def _beam_labels(beams):
        """Return the HDF5 lookup key of every beam as a new list of str."""
        labels = []
        for beam in beams:
            label = beam if isinstance(beam, str) else '{:0>3}'.format(beam)
            # The file stores the 360 degree beam under '359.9'.
            labels.append('359.9' if label == '360' else str(label))
        return labels

    def get_dij(self, beams, disp=None):
        """Load ``Dij/<beam>`` for every beam and store the stack in ``self.Dij``.

        The per-beam matrices are stacked row-wise and transposed, which the
        original author describes as "voxels x bixels".  ``self.beams`` is
        set to the lookup keys; the caller's sequence is left untouched.

        Args:
            beams: Beam angles as numbers or already formatted strings.
            disp: When truthy, print the shape of every block.

        Raises:
            ValueError: If ``beams`` is empty.
        """
        labels = self._beam_labels(beams)
        if not labels:
            raise ValueError('get_dij needs at least one beam')
        self.beams = labels

        blocks = []
        path = '{}/{}.hdf5'.format(self.f_dijs_dir, self.case)
        # Open read-only so a wrong path can never create an empty file.
        with h5sparse.File(path, 'r') as df:
            for beam in labels:
                dij = df['Dij/{}'.format(beam)].value
                if disp:
                    print('dij: ', dij.shape)
                blocks.append(dij)

        # One stack at the end instead of re-stacking (and re-copying) the
        # growing matrix for every beam, starting from a ``None`` block.
        self.Dij = vstack(blocks).T
        del blocks  # large objects

    def obj_eval(self, x):
        """Evaluate the objective and its gradient at ``x``.

        ``f = 0.5 * over . max(r, 0)^2 + 0.5 * under . min(r, 0)^2
        + lambder * ||x||_1`` with ``r = Dij x - target_dose``.  The
        computed dose is kept in ``self.dose``.

        Returns:
            tuple: ``(f, g)`` -- ``f`` a float, ``g`` the gradient w.r.t. ``x``.
        """
        lambder = 1e-10
        # TODO: Fix nonnegativity of x
        self.dose = self.Dij.dot(x)
        residual = self.dose - self.target_dose
        over = np.maximum(residual, 0)    # dose above the target
        under = np.minimum(residual, 0)   # dose below the target

        f = float(0.5 * self.overPenalty.dot(over ** 2) +
                  0.5 * self.underPenalty.dot(under ** 2) +
                  lambder * np.linalg.norm(x, ord=1))

        dose_grad = self.overPenalty * over + self.underPenalty * under
        # The L1 term lives in x-space: its (sub)gradient is
        # lambder * sign(x) and must not be pushed through Dij.T.
        g = self.Dij.T.dot(dose_grad) + lambder * np.sign(x)
        return f, g

    @staticmethod
    def _two_loop(g, old_dirs, old_stps, h_diag, n_correction):
        """Return ``(d, ro, al)``: the L-BFGS direction ``-H g`` and scratch.

        ``old_dirs`` holds the steps ``s_i`` and ``old_stps`` the gradient
        differences ``y_i``, oldest first.
        """
        k = len(old_dirs)
        ro = np.zeros(max(n_correction, k))
        al = np.zeros(max(n_correction, k))
        for i in range(k):
            ro[i] = 1.0 / old_stps[i].dot(old_dirs[i])

        q = -g
        # Backward pass, newest pair first, down to and including index 0.
        for i in range(k - 1, -1, -1):
            al[i] = old_dirs[i].dot(q) * ro[i]
            q = q - al[i] * old_stps[i]

        # Multiply by the initial inverse-Hessian scale, then forward pass.
        r = q * h_diag
        for i in range(k):
            be_i = old_stps[i].dot(r) * ro[i]
            r = r + (al[i] - be_i) * old_dirs[i]
        return r, ro, al

    def _lbfgs(self, x, state):
        """Run the L-BFGS iteration from ``x``; return the final iterate.

        Every objective value is appended to ``self.obj_hist``; counters and
        the curvature history are written back to ``state``.
        """
        log = self._log
        max_iter = state.maxIter or 20
        max_eval = state.maxEval or max_iter
        tol_fun = state.tolFun or 1e-5
        tol_x = state.tolX or 1e-9
        n_correction = state.nCorrection or 100
        line_search = state.lineSearch
        line_search_opts = state.lineSearchOptions
        learning_rate = state.learningRate or 1

        state.funcEval = state.funcEval or 0
        state.nIter = state.nIter or 0

        # evaluate initial f(x) and df/dx
        f, g = self.obj_eval(x)
        self.obj_hist.append(f)
        current_evals = 1
        state.funcEval += 1

        # check optimality of initial point
        if np.abs(g).sum() <= tol_fun:
            log.debug('optimality condition below tolFun')
            return x

        # variables cached in state (for tracing)
        d = state.d
        t = state.t
        old_dirs = list(state.old_dirs or [])
        old_stps = list(state.old_stps or [])
        h_diag = state.Hdiag
        g_old = state.g_old
        f_old = state.f_old

        n_iter = 0
        while n_iter < max_iter:
            n_iter += 1
            state.nIter += 1
            log.debug('current lbfgs iterate %s out of %s',
                      state.nIter, max_iter)

            # ---- search direction ------------------------------------
            if state.nIter == 1 or d is None or g_old is None:
                # No usable history: start with steepest descent.
                d = -g
                old_dirs, old_stps = [], []
                h_diag = 1
            else:
                y = g - g_old
                s = d * t
                ys = y.dot(s)
                if ys > 1e-10:
                    log.debug('<updating lbfgs-memory>')
                    if len(old_dirs) == n_correction:
                        # limited memory: forget the oldest pair
                        old_dirs.pop(0)
                        old_stps.pop(0)
                    old_dirs.append(s)
                    old_stps.append(y)
                    # update scale of initial Hessian approximation
                    h_diag = ys / y.dot(y)
                d, state.ro, state.al = self._two_loop(
                    g, old_dirs, old_stps, h_diag, n_correction)
            g_old = g
            f_old = f

            # ---- step length -----------------------------------------
            gtd = g.dot(d)  # directional derivative
            if gtd > -tol_x:
                # no progress possible along that direction
                log.debug('gtd > -tolX; breaking')
                break

            if state.nIter == 1:
                t = min(1, 1 / np.abs(g).sum()) * learning_rate
            else:
                t = learning_rate

            ls_evals = 0
            if line_search:
                # user-supplied line search
                f, g, x, t, ls_evals = line_search(
                    self.obj_eval, x, t, d, f, g, gtd, line_search_opts)
                self.obj_hist.append(f)
            else:
                # fixed step along the direction
                x = x + t * d
                if n_iter != max_iter:
                    # no use re-evaluating after the very last step
                    f, g = self.obj_eval(x)
                    ls_evals = 1
                    self.obj_hist.append(f)

            current_evals += ls_evals
            state.funcEval += ls_evals

            # ---- termination tests -----------------------------------
            if n_iter == max_iter:
                log.debug('reached max number of iterations')
                break
            if current_evals >= max_eval:
                log.debug('max nb of function evals')
                break
            if np.abs(g).sum() <= tol_fun:
                log.debug('optimality condition below tolFun')
                break
            if np.abs(d * t).sum() <= tol_x:
                log.debug('step size below tolX')
                break
            if abs(f - f_old) < tol_x:
                log.debug('function value changing less than tolX')
                break

        state.nIter = 0  # reset it: the next run starts from scratch
        # save state
        state.old_dirs = old_dirs
        state.old_stps = old_stps
        state.Hdiag = h_diag
        state.g_old = g_old
        state.f_old = f_old
        state.t = t
        state.d = d
        return x

    def optimize(self, beams, dvh=False, state=None):
        """Load the Dij matrices for ``beams`` and minimise the objective.

        Args:
            beams: Beam angles, forwarded to :meth:`get_dij`.
            dvh: When truthy, save the dose and DVH figures afterwards.
            state: Optimiser configuration/state with the fields of
                ``OptimState``.  Defaults to ``self.state`` and then to the
                module-level ``state`` global.

        Returns:
            tuple: ``(x, obj_hist)`` -- the final iterate and the objective
            values recorded so far on this instance.

        Raises:
            ValueError: If no optimiser state can be found.
        """
        if state is None:
            state = self.state
        if state is None:
            state = globals().get('state')
        if state is None:
            raise ValueError('optimize() needs an optimiser state')

        if state.verbose:
            self._log.debug('<optim.lbfgs with numpy on cpu > ...')

        tic = time.time()
        self.get_dij(beams)
        self._log.debug(' Dij from {} took {}'
                        .format(self.f_dijs_dir, time.time() - tic))
        self.beams = beams

        x0 = np.zeros((self.Dij.shape[-1]), dtype=np.float32)
        x = self._lbfgs(x0, state)

        self.dose = self.Dij.dot(x)
        if dvh:
            self.plot_dvh()
        # return optimal x, and history of f(x)
        return x, self.obj_hist

    def plot_dvh(self):
        """Save six dose slices and a dose-volume histogram as PNG files.

        The files go to ``<save_dir>/lbfgs/dose/case_<case>`` and
        ``<save_dir>/lbfgs/dvh/case_<case>``.  Requires ``self.dose``.
        """
        dose_mask = self.bundle['dose_mask']
        dosearray = np.zeros_like(dose_mask, dtype=np.float32)
        dosearray[np.nonzero(dose_mask)] = self.dose
        n_slices = dosearray.shape[-1]
        slc = n_slices // 2

        plt.close('all')
        f, ax = plt.subplots(2, 3, figsize=(16, 12))
        for panel, (offset, title) in zip(ax.flat, self._DOSE_PANELS):
            # Clamp so thin volumes neither wrap around nor raise.
            idx = min(max(slc + offset, 0), n_slices - 1)
            panel.imshow(dosearray[:, :, idx], cmap='jet')
            panel.set_title(title.format(self.case, idx, self.beams),
                            fontdict=self.fontdict)
            panel.xaxis.set_tick_params(labelsize=14)
            panel.yaxis.set_tick_params(labelsize=14)

        xmin, _ = ax[0, 1].get_xlim()
        _, ymax = ax[0, 1].get_ylim()
        # NOTE(review): the objective is shown negated -- confirm intent.
        ax[0, 1].text(xmin+5, ymax-50, 'Cost: %.4f' % (-self.obj_hist[-1]),
                      self.fontdict, backgroundcolor='pink', animated=True,
                      color='k', rasterized=True, size='x-large',
                      style='oblique', ma='center')
        plt.tight_layout()

        lbfgs_dir = '{}/lbfgs/dose/case_{}'.format(save_dir, self.case)
        if not os.path.exists(lbfgs_dir):
            os.makedirs(lbfgs_dir)
        # ``fontdict`` is not a savefig option, so it is no longer passed.
        f.savefig('{}/dose_{}_{}.png'.format(lbfgs_dir, slc,
                                             random.randint(0, 99)))

        masks = ['fem_head_rt', 'body', 'bladder', 'rectum', 'ptv',
                 'fem_head_lt']
        f2, ax2 = plt.subplots(1, 1, figsize=(10, 7))
        lbfgs_dir = '{}/lbfgs/dvh/case_{}'.format(save_dir, self.case)
        if not os.path.exists(lbfgs_dir):
            os.makedirs(lbfgs_dir)

        # A non-positive upper edge would make the histogram range invalid.
        top = max(float(dosearray.max()), 0.0)
        for sId in masks:
            structure = np.asarray(self.bundle[sId])
            tot_vox = np.count_nonzero(structure)
            if not tot_vox:
                continue  # empty structure: nothing to normalise by
            hist, bins = np.histogram(dosearray[np.nonzero(structure)],
                                      bins=500, range=(0, top))
            temp = 100. - hist.cumsum() * 100.0 / tot_vox
            ax2.plot(bins[:-1], temp, label=sId, linewidth=2)

        ax2.legend(fancybox=True, framealpha=0.5, bbox_to_anchor=(0.6, 0.95),
                   borderaxespad=0., scatterpoints=1, ncol=1,
                   loc=2, fontsize=16)
        ax2.set_title('LBFGS DVH | Case {}'.format(self.case),
                      fontdict=self.fontdict)
        ax2.set_xlabel('Fractional Dose [Gy]', fontdict=self.fontdict)
        ax2.set_ylabel('Fractional Volume [%]', fontdict=self.fontdict)
        f2.savefig('{}/dvh_{}_{}.png'.format(lbfgs_dir, slc,
                                            random.randint(0, 99)))
def main(case=None, f_masks_dir=None, f_dijs_dir=None):
    """Optimise the fluence of one case with L-BFGS and print the result.

    Args:
        case: Case identifier (stem of the mask and Dij files).
        f_masks_dir: Directory containing ``<case>.h5`` with the masks.
        f_dijs_dir: Directory containing ``<case>.hdf5`` with the Dijs.
    """
    # ``NumpyBFGS.optimize`` reads the optimiser state from the module-level
    # ``state`` global, so the settings built here must be published there.
    global state

    #----------------------------------------------------------------------#
    # Initializations
    #----------------------------------------------------------------------#
    # The third positional parameter of get_mask is ``disp``; the Dij
    # directory does not belong there.
    bundle = get_mask(case, f_masks_dir)
    lbfgs_obj = NumpyBFGS(case=case, dijs_dir=f_dijs_dir, bundle=bundle)

    OptimState = {
        'maxIter': 1500,
        'maxEval': 10000,
        'tolFun': 1e-4,             # tolerance on the first-order optimality
        'tolX': 1e-9,               # tolerance on function/parameter changes
        'lineSearch': None,         # optional line-search function
        'lineSearchOptions': None,  # options handed to the line search
        'learningRate': 0.01,
        'nCorrection': 100,
        'verbose': True,
        'tmp1': np.array(()),
        'funcEval': None,
        'nIter': None,
        'dir_bufs': [],
        'stp_bufs': [],
        'd': None,
        't': None,
        'al': None,
        'old_dirs': [],
        'old_stps': [],
        'Hdiag': 1,
        'g_old': None,
        'ro': None,
        'f_old': None,
    }
    state = BundleType(OptimState)

    #----------------------------------------------------------------------#
    beams = np.array([10, 60, 150, 260, 330])
    # optimize() takes the beams first and returns (x, objective history).
    x, fhist = lbfgs_obj.optimize(beams, dvh=True)
    #----------------------------------------------------------------------#

    print('optimized x: ', x)
    print('fhist: ', fhist)
    # Number of objective evaluations, as counted on the optimiser state.
    print('current: ', state.funcEval)
if __name__ == '__main__':
    # ``args`` was used without ever being parsed; build it from the command
    # line.  The previously hard-coded directories stay as the defaults.
    import argparse

    parser = argparse.ArgumentParser(
        description='L-BFGS fluence optimisation for a single case')
    parser.add_argument('--case', type=str, required=True,
                        help='case identifier (stem of the HDF5 files)')
    parser.add_argument('--f_masks_dir', type=str,
                        default='/home/lex/tdrive/f_masks',
                        help='directory containing the mask files')
    parser.add_argument('--f_dijs_dir', type=str,
                        default='/home/lex/tdrive/f_dijs',
                        help='directory containing the Dij files')
    args = parser.parse_args()

    tikko = time.time()
    main(case=args.case, f_masks_dir=args.f_masks_dir,
         f_dijs_dir=args.f_dijs_dir)
    tikka = time.time()
    print('total time taken ', tikka - tikko)
@robotsorcerer
Copy link
Author

robotsorcerer commented May 7, 2018

Check out the fluence after just two iterations!

figure_1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment