Skip to content

Instantly share code, notes, and snippets.

@schocco
Created February 8, 2015 16:14
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save schocco/1d519e6554fbb0066c5a to your computer and use it in GitHub Desktop.
Save schocco/1d519e6554fbb0066c5a to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# coding=utf-8
import sys, getopt
def update_progress(progress, barLength = 12, status = ""):
    """Display a progress bar in the console and update it in place.

    The bar is redrawn on the current line: the output starts with a
    carriage return and is flushed immediately.

    Args:
        progress: completion fraction. An int is converted to float. Values
            below 0 (and NaN) are drawn as an empty bar; values of 1 or more
            are drawn as a full bar.
        barLength: number of characters between the brackets.
        status: text appended after the percentage. Once progress reaches 1
            it is replaced by "Done..." followed by a newline.

    Raises:
        ValueError: if progress is neither an int nor a float.
    """
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        raise ValueError("progress bar must be float\n")
    if not progress >= 0:
        # Negative values would give a negative block count (the bar would
        # grow past barLength), and NaN fails every comparison and cannot be
        # rounded; draw both as an empty bar.
        progress = 0.0
    elif progress >= 1:
        progress = 1
        status = "Done...\n"
    block = int(round(barLength * progress))
    bar = "#" * block + "-" * (barLength - block)
    text = "\rProgress: [{0}] {1:.1f}% {2}".format(bar, progress * 100, status)
    sys.stdout.write(text)
    sys.stdout.flush()
def fill_p(nodes):
    """Return the table of summed probabilities for every contiguous key range.

    Cell [j][l] with j <= l holds nodes[j] + nodes[j+1] + ... + nodes[l].
    Cells below the diagonal stay 0. Only the initial node values are
    needed, so the whole table can be filled before the tree search starts.
    """
    update_progress(0, status="pre-filling probabilities table...")
    size = len(nodes)
    table = [[0] * size for _ in range(size)]
    # Walk the rows from the bottom up so that row start+1 is complete before
    # row start reads it. Each cell adds its own key to the range sum that
    # begins one key later.
    for start in reversed(range(size)):
        row = table[start]
        row[start] = nodes[start]
        below = table[start + 1] if start + 1 < size else None
        for end in range(start + 1, size):
            row[end] = nodes[start] + below[end]
    return table
def obst(nodes, mode="qubic"):
    '''
    Compute the cost and root tables of an optimal binary search tree.

    Uses the dynamic-programming recurrence

        obst(i, j) = min_{i <= r <= j} { obst(i, r-1) + obst(r+1, j) } + P(i, j)

    where P(i, j) = nodes[i] + ... + nodes[j] (taken from fill_p) and an
    empty key range costs 0.

    With mode "qubic" (or the correctly spelled "cubic") every root r in
    [i, j] is tried, which is O(n^3). Any other mode applies Knuth's speed-up
    and only tries roots[i][j-1] <= r <= roots[i+1][j], which makes the total
    work O(n^2).

    Args:
        nodes: sequence of access weights, one per key, in key order.
        mode: "qubic"/"cubic" for the exhaustive search; anything else
            selects Knuth's restricted search.

    Returns:
        (weights, roots): two n x n tables.
        - weights[i][j] is the minimal sum over keys i..j of weight times
          level, with the subtree root at level 1.
        - roots[i][j] is the index of that subtree's root. On ties the
          smallest root examined is kept.
        - Cells with i > j are None.
    '''
    n = len(nodes)
    roots = [[None] * n for _ in range(n)]
    weights = [[None] * n for _ in range(n)]
    probabilities = fill_p(nodes)

    def subtree_cost(lo, hi):
        # Optimal cost for keys lo..hi; an empty range (lo > hi) costs 0.
        return weights[lo][hi] if lo <= hi else 0

    # Single-key subtrees: the key is its own root.
    for i in range(n):
        roots[i][i] = i
        weights[i][i] = nodes[i]

    exhaustive = mode in ("qubic", "cubic")
    # Redraw the progress bar roughly once per thousandth of the rows. An
    # integer step keeps the modulo test exact (the old float `n / 1000`
    # made it fire almost never or on every row) and is never zero.
    step = max(1, n // 1000)
    for span in range(1, n):
        if span % step == 0:
            update_progress(float(span / n))
        for j in range(n - span):
            l = j + span
            if exhaustive:
                candidates = range(j, l + 1)  # +1: root r = l is allowed
            else:
                # Knuth: the optimal root is bounded by the roots of the two
                # overlapping sub-ranges one key shorter (upper bound inclusive).
                candidates = range(roots[j][l - 1], roots[j + 1][l] + 1)
            best = None
            best_root = None
            for r in candidates:
                cost = subtree_cost(j, r - 1) + subtree_cost(r + 1, l)
                if best is None or cost < best:
                    best = cost
                    best_root = r
            roots[j][l] = best_root
            weights[j][l] = best + probabilities[j][l]
    update_progress(1.0)
    return weights, roots
def read_nodes_from_file(path):
    """Read comma-separated integers from a text file.

    Whitespace (including newlines) around each number is ignored. Empty
    fields, such as the one left by a trailing comma, are skipped.

    Args:
        path: path of the file to read.

    Returns:
        tuple of ints in file order (empty if the file contains no numbers).

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
        ValueError: if a non-empty field is not an integer literal.
    """
    # A context manager closes the handle deterministically instead of
    # leaving it open until garbage collection.
    with open(path, "r", encoding="utf-8") as handle:
        text = handle.read()
    return tuple(int(field) for field in text.split(",") if field.strip())
def main(argv=None):
    """Command-line entry point: read key weights and print the optimal BST.

    Args:
        argv: option list without the program name (e.g. sys.argv[1:]).
            None means no options.

    Options:
        -h: print usage and exit.
        -i/--ifile <path>: file of comma-separated integer weights.

    Exits with status 2 on:
        - a bad option;
        - an input file that cannot be read or parsed;
        - an input file with no numbers.

    On success, prints the cost table followed by the overall root and cost.
    """
    if argv is None:
        # Avoid a shared mutable default; behaves like the old `argv=[]`.
        argv = []
    inputfile = ""
    helptext = 'obst.py -i <inputfile>'
    try:
        opts, _ = getopt.getopt(argv, "hi:", ["ifile="])
    except getopt.GetoptError:
        print(helptext)
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            print(helptext)
            sys.exit()
        elif opt in ("-i", "--ifile"):
            inputfile = arg
    try:
        nodes = read_nodes_from_file(path=inputfile)
    except (OSError, ValueError) as err:
        # OSError covers a missing file, a directory, or missing permissions.
        # ValueError covers a field that is not an integer.
        print(err)
        print(helptext)
        sys.exit(2)
    if not nodes:
        # With no keys the result tables are empty and there is no root to report.
        print("input file contains no numbers")
        print(helptext)
        sys.exit(2)
    n = len(nodes)
    print("calculating optimal binary search tree with n=%d" % n)
    weights, roots = obst(nodes, mode="quadratic")
    root = roots[0][n - 1]
    weight = weights[0][n - 1]
    for row in weights:
        for char in row:
            print(char, end="\t")
        print("")
    print("\nRoot node is {0} (p({0})={1}, p_max = {2}), weighted inner path length is {3}".format(root+1, nodes[root], max(nodes), weight))
if __name__ == '__main__':
    # Forward the command-line options without the program name. With no
    # options this is an empty list, the same as calling main() with no arguments.
    main(sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment