import os
import copy
import time
import random
import h5py
import h5sparse
import logging
import numpy as np
from scipy.sparse import vstack
import matplotlib as mpl
# mpl.use('agg')
import matplotlib.pyplot as plt
# plt.switch_backend('agg')
logger = logging.getLogger(__name__)
# directories where the masks and dijs are
# e.g /path_to_tdrive/Users/.. .. /Lekan/f_dijs
# /path_to_tdrive/Users/.. .. /Lekan/f_masks
from config import f_masks_dir, f_dijs_dir, save_dir
def get_mask(mask, f_masks_dir=None, disp=False):
# Open '<f_masks_dir>/<mask>.h5' read-only, read the structure masks stored
# under the 'oar_ptvs' group plus the fluence stored under 'fluence_grp',
# and return them in a dict.
# NOTE(review): this copy of the function has lost its indentation and
# several statement bodies (flagged below), so the exact nesting of the
# statements cannot be confirmed from here.
#global f_masks_dir
# used to check fields in hdf5 files
# NOTE(review): no body is visible for `see`; presumably a small debugging
# callback for listing HDF5 fields -- confirm against the original file.
def see(n):
# Helper: insert a new axis of length one into `x` at position `axis`.
def exp(x, axis):
return np.expand_dims(x, axis=axis)
with h5py.File('{}/{}.h5'.format(f_masks_dir, mask), 'r') as mf:
# NOTE(review): the statement guarded by `if disp:` is not visible here.
if disp:
masks = mf['oar_ptvs']
logger.debug(' loading patient {}'.format(mask))
keys = [k for k, v in masks.items()]
# When a 'structs' entry exists, the bundle holds only that array.
# NOTE(review): `.value` was removed from h5py datasets in h5py 3.0
# (`dataset[()]` is the replacement) -- confirm the pinned h5py version.
# NOTE(review): an `else:` separating this layout from the per-organ layout
# below appears to be missing in this copy -- confirm.
if "structs" in keys: # augmented data
struct = masks['structs'].value
bundle = dict(struct = struct)
flu_grp = mf['fluence_grp']
fluence = flu_grp['fluence'].value
organ_names = []
# NOTE(review): the body of this loop/condition is not visible; presumably
# it collects every key except "dose" into `organ_names` -- confirm.
for k, _ in masks.items():
if k != "dose":
print(organ_names) if disp else None
bladder = masks['bladder'].value
body = masks['body'].value
dose_mask = masks['dose'].value
# The femoral-head datasets are looked up under either the underscore
# spelling or the space-separated spelling, depending on `organ_names`.
fem_head_lt = masks['fem_head_lt'].value if 'fem_head_lt' in organ_names else masks['fem head lt'].value
fem_head_rt = masks['fem_head_rt'].value if 'fem_head_rt' in organ_names else masks['fem head rt'].value
ptv = masks['ptv'].value
rectum = masks['rectum'].value
# Stack the six masks along a new trailing axis, then move that axis to the
# front with transpose(-1, -2, 0, 1).
struct = np.concatenate((exp(bladder, -1), exp(fem_head_lt, -1), \
exp(fem_head_rt, -1), exp(body, -1),
exp(ptv, -1), exp(rectum, -1)), axis=-1)
struct = struct.transpose(-1, -2, 0, 1) # Organ x Slice x H x W
bundle = dict( bladder = bladder, dose_mask = dose_mask,
fem_head_lt = fem_head_lt, fem_head_rt = fem_head_rt,
ptv = ptv, rectum = rectum,
organ_names = organ_names, struct = struct,
fluence = fluence, body = body)
return bundle
class BundleType(object):
    """
    Bundle a fixed set of named fields, similar to a record or a mutable
    namedtuple.

    The fields are created once from ``variables``. Afterwards an existing
    field may be reassigned, but assigning to a name that was not present
    in ``variables`` raises ``AttributeError``.

    Args:
        variables: mapping of field name -> initial value.

    Raises:
        AttributeError: when assigning to a field that does not exist.
    """

    def __init__(self, variables):
        # Go through object.__setattr__ directly: our own __setattr__
        # rejects names that do not exist yet, which would make it
        # impossible to create the initial fields.
        for var, val in variables.items():
            object.__setattr__(self, var, val)

    # Freeze fields so new ones cannot be set.
    def __setattr__(self, key, value):
        if not hasattr(self, key):
            raise AttributeError("%r has no attribute %s" % (self, key))
        object.__setattr__(self, key, value)
# NOTE(review): `OptimState` is neither defined nor imported before this
# statement in the visible source (its only assignment appears further down,
# next to `main`), so as written this line looks like it raises NameError at
# import time -- confirm where `OptimState` is meant to be declared.
state = BundleType(OptimState)
# L-BFGS optimiser written with NumPy: builds per-voxel target dose and
# penalty vectors from a mask bundle (__init__), loads a dose-influence
# matrix for a set of beams (get_dij), evaluates an objective (obj_eval),
# runs the L-BFGS loop (optimize) and plots dose slices and a DVH (plot_dvh).
# NOTE(review): this copy of the class has lost its indentation and a number
# of expressions (flagged inline), so it is not runnable as shown.
class NumpyBFGS(object):
"""docstring for TorchBFGS"""
def __init__(self, case=None, dijs_dir=None, bundle=None):
# Store the Dij directory and plotting font, then derive the flattened
# target-dose / over-penalty / under-penalty vectors from `bundle`.
super(NumpyBFGS, self).__init__()
# NOTE(review): `self.Dij = None = case` is not valid Python; it looks like
# two assignments were merged (presumably `self.Dij = None` and an attribute
# receiving `case`) -- confirm against the original file.
self.Dij = None = case
self.beams = None
self.f_dijs_dir = dijs_dir
self.fontdict = {'fontsize':14, 'fontweight':'bold'}
max_dose = dict(bladder = 79.0, fem_head_lt = 48.0, fem_head_rt = 48.0,
ptv = 100.0, rectum = 79.0, body=76.0)
ptv = bundle['ptv']
body = bundle['body']
rectum = bundle['rectum']
bladder = bundle['bladder']
dose_mask = bundle['dose_mask']
fem_head_lt = bundle['fem_head_lt']
fem_head_rt = bundle['fem_head_rt']
# Per-voxel target dose. Assignments are applied in the order written, so
# where masks overlap the later structure overwrites the earlier one.
# NOTE(review): zeros_like/ones_like inherit dose_mask's dtype; if the mask
# is boolean or integer the float values assigned below are cast to that
# dtype -- confirm the dtype of bundle['dose_mask'].
rx_dose = np.zeros_like(dose_mask)
rx_dose[np.where(ptv)] = max_dose['ptv']
rx_dose[np.where(body)] = max_dose['body']
rx_dose[np.where(rectum)] = max_dose['rectum']
rx_dose[np.where(bladder)] = max_dose['bladder']
rx_dose[np.where(fem_head_lt)] = max_dose['fem_head_lt']
rx_dose[np.where(fem_head_rt)] = max_dose['fem_head_rt']
# Over-dose weights start at 1 everywhere and reuse the max_dose values
# inside each structure.
overPenalty = np.ones_like(dose_mask)
overPenalty[np.where(ptv)] = max_dose['ptv']
overPenalty[np.where(body)] = max_dose['body']
overPenalty[np.where(rectum)] = max_dose['rectum']
overPenalty[np.where(bladder)] = max_dose['bladder']
overPenalty[np.where(fem_head_lt)] = max_dose['fem_head_lt']
overPenalty[np.where(fem_head_rt)] = max_dose['fem_head_rt']
# Under-dose weights are zero except inside the ptv mask.
underPenalty = np.zeros_like(dose_mask)
underPenalty[np.where(ptv)] = 100.0
self.obj_hist = list()
self.dose = None
self.bundle = bundle
# Keep only the entries selected by dose_mask, as flat vectors.
# assumes dose_mask is a boolean array so that indexing selects voxels -- TODO confirm
self.overPenalty = overPenalty[dose_mask].flatten()
self.underPenalty = underPenalty[dose_mask].flatten()
self.target_dose = rx_dose[dose_mask].flatten()
self.options = {'disp': True,
'maxiter': 1500, #Change this to say 15000 during testing
'maxfun': 10000, # max # of obj function evals; # change to 10000 during testing
'maxls': 20, # max line search param
'maxfev': None}
self.state = None #BundleType(OptimState)
def get_dij(self, beams, disp=None):
# Load the 'Dij/<beam>' entry for every beam from an h5sparse file, stack
# them with scipy's vstack and store the transpose in self.Dij.
# Non-string beams are formatted as zero-padded, width-3 strings.
if not isinstance(beams[0], str):
beams = ['{:0>3}'.format(x) for x in beams] # so we can do a proper lookup of Dijs
if '360' in beams:
beams[beams.index('360')] = '359.9' # fix lookup error in h5 dictionary
self.beams = beams
# returns voxels x bixels for dij and sij
# NOTE(review): the next line is truncated -- the second format argument,
# the closing parentheses and any file mode are missing before `as df:`.
with h5sparse.File('{}/{}.hdf5'.format(self.f_dijs_dir, as df:
Dij, Sij = None, None
for beam in beams:
# print('dij for beam ', beam)
dij = df['Dij/{}'.format(beam)].value
#sij = df['/Sij/{}'.format(beam)].value
# NOTE(review): `Dij` is None on the first pass and is handed to vstack
# as-is; a first-iteration guard appears to be missing here -- confirm.
Dij = vstack((Dij, dij))
print('dij: ', dij.shape) if disp else None
#Sij = vstack((Sij, sij))
del dij#, sij # large objects
self.Dij = Dij.T
del Dij
def obj_eval(self, x):
# Evaluate the objective value `f` and gradient `g` at `x`, returned as (f, g).
# NOTE(review): the right-hand sides of `self.dose`, `f` and `g` are missing
# or incomplete in this copy, so the exact objective cannot be read from here.
# evaluate initial f(x) and df/dx
lambder = 1e-10
# TODO: Fix nonnegativity of x
self.dose =
# truncate dose to 0 wherever it's negative
# Both differences are computed identically; they are clipped differently
# below (>= 0 for over-dose, <= 0 for under-dose).
oDose = self.dose - self.target_dose
uDose = self.dose - self.target_dose
f = float(0.5 * ** 2) +
0.5 *, 0) ** 2) +
lambder* np.linalg.norm(x, ord=1)
# NOTE(review): the g_temp expression below ends with a dangling `+ \` and
# the assignment to `g` has no right-hand side.
g_temp = np.multiply(self.overPenalty, oDose.clip(0)) + \
np.multiply(self.underPenalty, uDose.clip(-1e10, 0)) + \
g =
return f, g
def optimize(self, beams, dvh=False):
# Run the L-BFGS loop and return (x, self.obj_hist).
# Hyper-parameters and cached buffers are read from the module-level
# `state` object (declared global below), not from self.state.
global state
# state = self.state
maxIter = state.maxIter or 20
maxEval = state.maxEval or maxIter#*1.25
tolFun = state.tolFun or 1e-5
tolX = state.tolX or 1e-9
nCorrection = state.nCorrection or 100
lineSearch = state.lineSearch
lineSearchOpts = state.lineSearchOptions
learningRate = state.learningRate or 1
# NOTE(review): `<anything> or True` is always truthy, so verbose output
# cannot be switched off through state.verbose.
isverbose = state.verbose or True
state.funcEval = state.funcEval or 0
state.nIter = state.nIter or 0
if isverbose:
logger.debug('<optim.lbfgs with numpy on cpu > ...')
# evaluate initial f(x) and df/dx
# NOTE(review): nothing is executed between `tic` and the timing log in
# this copy; presumably the Dij load (get_dij) belongs here -- confirm.
# Also, str.rsplit() without arguments splits on whitespace, not on '/'.
tic = time.time()
logger.debug(' Dij from {} took {}'
.format(self.f_dijs_dir.rsplit()[-1], time.time() - tic))
self.beams = beams
# Start from an all-zero vector with one entry per column of self.Dij.
x = np.zeros((self.Dij.shape[-1]), dtype=np.float32)
f,g = self.obj_eval(x)
#logger.debug('obj_val, {} '.format(f))
currentFuncEval = 1
state.funcEval += 1
p = g.shape[0]
# check optimality of initial point
state.tmp1 = np.zeros_like(g, dtype=np.float32)
tmp1 = state.tmp1
tmp1 = abs(copy.deepcopy(g)) #tmp1.copy_(g).abs()
if tmp1.sum() <= tolFun:
# optimality condition below tolFun
logger.debug('optimality condition below tolFun')
return x, self.obj_hist
if not state.dir_bufs:
#reusable buffers for y's and s's, and their histories
logger.debug('<creating recyclable direction/step/history buffers>')
state.dir_bufs = np.split(np.zeros((nCorrection+1, p), dtype=np.float32), nCorrection+1) #np.zeros((nCorrection+1, 1, p), dtype=np.float32)
state.stp_bufs = np.split(np.zeros((nCorrection+1, p), dtype=np.float32), nCorrection+1) #np.zeros((nCorrection+1, 1, p), dtype=np.float32)
for i in range(len(state.dir_bufs)):
state.dir_bufs[i] = state.dir_bufs[i].squeeze()
state.stp_bufs[i] = state.stp_bufs[i].squeeze()
# variables cached in state (for tracing)
d = state.d
t = state.t
old_dirs = state.old_dirs
old_stps = state.old_stps
Hdiag = state.Hdiag
g_old = np.zeros_like(g) #state.g_old
f_old = state.f_old
#optimize for a max of maxIter iterations
nIter = 0
while nIter < maxIter:
#keep track of nb of iterations
nIter = nIter + 1
state.nIter = state.nIter + 1
logger.debug('current lbfgs iterate {} out of {}'.format(state.nIter, maxIter))
#- compute gradient descent direction
# NOTE(review): the `else:` that separates the first-iteration setup from
# the memory update below is not visible in this copy.
if state.nIter == 1:
d = g * -1 # -g
old_dirs = []
old_stps = []
Hdiag = 1
#do lbfgs update (update memory)
# The popped buffers are immediately rebound to freshly computed arrays.
y = state.dir_bufs.pop(0) # pop
s = state.stp_bufs.pop(0)
y = g - g_old
s = np.multiply(d, t)
# NOTE(review): the right-hand side of `ys` is missing (comment says y*s).
ys = # y*s
if ys > 1e-10:
# updating memory
logger.debug('<updating lbfgs-memory>')
if len(old_dirs) == nCorrection:
#shift history by one (limited-memory)
removed1 = old_dirs.pop(0)
removed2 = old_stps.pop(0)
# NOTE(review): the statements that store the new direction/step and that
# return y and s to the buffer pool are not visible; the Hdiag expression
# below is also truncated.
# store new direction/step
# update scale of initial Hessian approximation
Hdiag = np.divide(ys,
# put y and s back into the buffer pool
# compute the approximate (L-BFGS) inverse Hessian
# multiplied by the gradient
k = len(old_dirs)
#print('len lbfgs buffers: ', k)
# need to be accessed element-by-element, so don't re-type tensor: = np.zeros((nCorrection),dtype=np.float32)
# NOTE(review): the initial values of `ro` and `al`, and part of the
# `ro[i]` expression, are missing in this copy.
ro =
for i in range(k):
ro[i] = 1 /[i], old_dirs[i].T)
# iteration in L-BFGS loop collapsed to use just one buffer
q = tmp1 # reuse tmp1 for the q buffer
# need to be accessed element-by-element, so don't re-type tensor: = np.zeros((nCorrection))
al =
q = np.multiply(g, -1)
# print('old_stps: ', len(old_stps))
# print('old_dirs: ', len(old_dirs))
# print('ro: ', len(ro))
# NOTE(review): range(k-1, 1, -1) stops before index 1, so history entries
# 1 and 0 are never visited by this backward pass -- confirm intent.
for i in range(k-1,1,-1) :
al[i] = old_dirs[i].dot(q) * ro[i]
q = -al[i] + old_stps[i]
# multiply by initial Hessian
r = d # share the same buffer, since we don't need the old d
# print('q: ', q)
# print('Hdiag: ', Hdiag)
r = q * Hdiag # q[1] * Hdiag
for i in range(k):
be_i = old_stps[i].dot(r) * ro[i]
r = al[i] - be_i + old_dirs[i]
# final direction is in r/d (same object)
g_old = g
f_old = f
#- compute step length
#- directional derivative
# NOTE(review): the right-hand side of `gtd` is missing (comment says g * d).
gtd = # g * d
#- check that progress can be made along that direction
# NOTE(review): no `break` is visible after this or the later termination
# log messages, although the messages say the loop is breaking.
if gtd > -tolX:
logger.debug('gtd > -tolX; breaking')
# reset initial guess for step size
# NOTE(review): as shown, `t = learningRate` always overrides the
# first-iteration value; an `else:` appears to be missing -- confirm.
if state.nIter == 1:
tmp1 = abs(g)
# tmp1.copy_(g).abs()
t = min(1,1/tmp1.sum()) * learningRate
t = learningRate
# optional line search: user function
lsFuncEval = 0
if lineSearch: # and type(lineSearch) == 'function' then
# perform line search, using user function
f,g,x,t,lsFuncEval = lineSearch(self.obj_eval,x,t,d,f,g,gtd,lineSearchOpts)
# no line search, simply move with fixed-step
# NOTE(review): as written this replaces x with `t + d`; a fixed-step move
# would normally add t*d to the current x -- confirm against the original.
x = t + d
if nIter != maxIter:
# re-evaluate function only if not in last iteration
# the reason we do this: in a stochastic setting,
# no use to re-evaluate that function here
f,g = self.obj_eval(x)
lsFuncEval = 1
# update func eval
currentFuncEval += lsFuncEval
state.funcEval += lsFuncEval
#- check conditions
if nIter == maxIter:
# no use to run tests
logger.debug('reached max number of iterations')
if currentFuncEval >= maxEval:
# max nb of function evals
logger.debug('max nb of function evals')
tmp1 = abs(g)
if tmp1.sum() <= tolFun:
# check optimality
logger.debug('optimality condition below tolFun')
tmp1 = abs(d * t)
# print('tmp1: ', tmp1)
# tmp1.copy_(d).mul(t).abs()
if tmp1.sum() <= tolX:
# step size below tolX
logger.debug('step size below tolX')
if abs(f-f_old) < tolX:
# function value changing less than tolX
logger.debug('function value changing less than tolX')
state.nIter = 0 # reset it
# save state
state.old_dirs = old_dirs
state.old_stps = old_stps
state.Hdiag = Hdiag
state.g_old = g_old
state.f_old = f_old
state.t = t
state.d = d
# NOTE(review): the right-hand side of `self.dose` is missing in this copy.
self.dose =
self.plot_dvh() if dvh else None
# return optimal x, and history of f(x)
return x, self.obj_hist#, currentFuncEval
def plot_dvh(self):
# Scatter self.dose back into a volume, save a 2x3 figure of dose slices
# around the middle slice, then save a cumulative dose-volume histogram
# with one curve per structure mask.
# NOTE(review): `dose_mask` is used as a bare name but is not assigned in
# this method; presumably self.bundle['dose_mask'] -- confirm.
dosearray = np.zeros_like(dose_mask, dtype=np.float32)
dosearray[np.nonzero(dose_mask)]= self.dose
slc = dosearray.shape[-1]//2
f, (ax) = plt.subplots(2, 3, figsize=(16, 12))
# NOTE(review): every set_title(...).format(...) call below is missing its
# first argument (the value for the 'Case {}' placeholder).
ax[0, 0].imshow(dosearray[:,:,slc-6], cmap='jet')
ax[0, 0].set_title('Case {} and slice {}\n{} '.format(, slc-6, self.beams), fontdict=self.fontdict)
ax[0, 0].xaxis.set_tick_params(labelsize=14)
ax[0, 0].yaxis.set_tick_params(labelsize=14)
ax[0, 1].imshow(dosearray[:,:,slc-4], cmap='jet')
ax[0, 1].set_title('Case {}; and slice {}\n{}'.format(, slc-4, self.beams), fontdict=self.fontdict)
ax[0, 1].xaxis.set_tick_params(labelsize=14)
ax[0, 1].yaxis.set_tick_params(labelsize=14)
xmin, xmax = ax[0, 1].get_xlim()
ymin, ymax = ax[0, 1].get_ylim()
# Annotate with the negated last entry of self.obj_hist.
# NOTE(review): no code visible in this class appends to self.obj_hist, so
# obj_hist[-1] would fail on an empty list -- confirm where it is filled.
ax[0, 1].text(xmin+5, ymax-50, 'Cost: %.4f'%(-self.obj_hist[-1]),
self.fontdict, backgroundcolor='pink', animated=True, color='k', rasterized=True,
size='x-large', style ='oblique', ma='center')
ax[0, 2].imshow(dosearray[:,:,slc-1], cmap='jet')
ax[0, 2].set_title('Case {}; and slice {}\n{}'.format(, slc-1, self.beams), fontdict=self.fontdict)
ax[0, 2].xaxis.set_tick_params(labelsize=14)
ax[0, 2].yaxis.set_tick_params(labelsize=14)
ax[1, 0].imshow(dosearray[:,:,slc], cmap='jet')
ax[1, 0].set_title('Case {}; and slice {}\n{}'.format(, slc, self.beams), fontdict=self.fontdict)
ax[1, 0].xaxis.set_tick_params(labelsize=14)
ax[1, 0].yaxis.set_tick_params(labelsize=14)
ax[1, 1].imshow(dosearray[:,:,slc+2], cmap='jet')
ax[1, 1].set_title('Case {}; and slice {}\n{}'.format(, slc+2, self.beams), fontdict=self.fontdict)
ax[1, 1].xaxis.set_tick_params(labelsize=14)
ax[1, 1].yaxis.set_tick_params(labelsize=14)
ax[1, 2].imshow(dosearray[:,:,slc+4], cmap='jet')
ax[1, 2].set_title('Case {}; and slice {}\n{}'.format(, slc+4, self.beams), fontdict=self.fontdict)
ax[1, 2].xaxis.set_tick_params(labelsize=14)
ax[1, 2].yaxis.set_tick_params(labelsize=14)
# NOTE(review): both `lbfgs_dir = ...format(save_dir,` lines are truncated
# (second format argument and closing parenthesis missing).
lbfgs_dir = '{}/lbfgs/dose/case_{}'.format(save_dir,
os.makedirs(lbfgs_dir) if not os.path.exists(lbfgs_dir) else None
# The file name carries a random 0-99 suffix, so repeated runs can collide.
f.savefig('{}/dose_{}_{}.png'.format(lbfgs_dir, slc, random.randint(0, 99)), fontdict=self.fontdict)
masks = ['fem_head_rt', 'body', 'bladder', 'rectum', 'ptv', 'fem_head_lt']
f2, ax2 = plt.subplots(1, 1, figsize=(10, 7))
lbfgs_dir = '{}/lbfgs/dvh/case_{}'.format(save_dir,
os.makedirs(lbfgs_dir) if not os.path.exists(lbfgs_dir) else None
# For each structure: histogram the dose inside its mask (500 bins from 0 to
# the volume maximum) and plot 100 minus the cumulative percentage of voxels.
for sId in masks:
mask_idx = np.where(self.bundle[sId])
tot_vox = self.bundle[sId].sum()
hist,bins = np.histogram(dosearray[mask_idx].flatten(),bins=500,range=(0,dosearray.max()))
temp = (100.-hist.cumsum()*100.0/tot_vox)
# temp[0] = 100 if temp[0] < 100 else temp[0]
ax2.plot(bins[:-1],temp,label=sId, linewidth=2)
ax2.legend(fancybox=True, framealpha=0.5, bbox_to_anchor=(0.6, 0.95),
borderaxespad=0., scatterpoints=1, ncol=1,
loc=2, fontsize=16)
# NOTE(review): the operand after `+` in the title below is missing.
ax2.set_title("LBFGS DVH | Case " +, fontdict=self.fontdict)
ax2.set_xlabel('Fractional Dose [Gy]', fontdict=self.fontdict)
ax2.set_ylabel('Fractional Volume [%]', fontdict=self.fontdict)
f2.savefig('{}/dvh_{}_{}.png'.format(lbfgs_dir, slc, random.randint(0, 99)), fontdict=self.fontdict)
def main(case=None, f_masks_dir=None, f_dijs_dir=None):
# Load the mask bundle for `case`, build a NumpyBFGS optimiser, set up the
# optimiser state and run it for a fixed list of beam angles, printing the
# results.
# NOTE(review): this copy has lost its indentation, and both the NumpyBFGS(
# call and the OptimState dict literal below are left unclosed.
# Initializations
# NOTE(review): get_mask's third positional parameter is `disp`, so
# `f_dijs_dir` is being passed as the display flag here -- confirm intent.
bundle = get_mask(case, f_masks_dir, f_dijs_dir)
lbfgs_obj = NumpyBFGS(case=case, dijs_dir=f_dijs_dir, bundle=bundle,
# NOTE(review): 'nIter' appears twice in this dict literal (0, then None);
# in a Python dict literal the later entry wins.
OptimState = {
'maxIter': 1500,
'maxEval': 10000,
'tolFun': 1e-4, #Termination tolerance on the first-order optimality
'tolX': 1e-9, #Termination tol on progress in terms of func/param changes
'lineSearch': None, #A line search function
'lineSearchOptions': None, #A line search function
'learningRate': 0.01,
'nCorrection': 100,
'verbose': True,
'tmp1': np.array(()),
'funcEval': None,
'nIter': 0,
'dir_bufs': [],
'stp_bufs': [],
'd': None,
't': None,
'al': None,
'old_dirs': [],
'old_stps': [],
'Hdiag': 1,
'g_old': None,
'ro': None, #np.zeros((nCorrection),dtype=cp.float32)),
'f_old': None,
'nIter': None,
# NOTE(review): this assigns a local `state`, while NumpyBFGS.optimize reads
# the module-level `state` via `global state` -- confirm which is intended.
state = BundleType(OptimState)
lambder = 1.0
beams = np.array([10, 60, 150, 260, 330])
# NOTE(review): optimize is declared as optimize(self, beams, dvh=False) and
# returns two values; this call passes `state` in the `beams` position and
# unpacks three values -- confirm against the original file.
x, fhist, currentf = lbfgs_obj.optimize(state, beams)
print('optimized x: ', x)
print('fhist: ', fhist)
print('current: ', currentf)
# Script entry point: time a single call to main() and print the elapsed
# wall-clock seconds.
if __name__ == '__main__':
tikko =time.time()
# NOTE(review): the main( call below is missing its first argument and its
# closing parenthesis (and possibly further keyword arguments) in this copy.
main(, f_masks_dir='/home/lex/tdrive/f_masks',
tikka =time.time()
print('total time taken ', tikka - tikko)
