This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Given a binary tree, this routine computes new x coordinates for visualization purposes
# (it reduces overlap between nodes).
# ix and paix should be np.arrays holding the integer index of each node and of its
# parent node, respectively. The paix entry of the root should be -1.
# depth is a np.array with the integer depth of each node, and
# coords is a 2-D array with the initial coordinates of each node.
# It works like a physics simulation:
# - y coordinates are never changed (they are presumably already set from the depth).
# - Every node is a charged particle that pushes nearby particles at the same depth away.
# - The connection between a parent and each child acts like a spring: the parent pulls its children.
# - The system is simulated until it reaches equilibrium: forces turn into acceleration, which turns
#   into velocity, which turns into translation.
# - When the system is in equilibrium (i.e. velocities are small enough), it usually gives a decent tree.
# Sometimes particles move so fast that they pass each other, which can produce weird trees with parents
# on the right and children on the left. To avoid that, the left-to-right order within each depth is
# explicitly preserved.
def compute_coords(ix, paix, depth, coords, rangex=100., cpush=2e-1, cspring=1e-4,
                   tol=1e-2, max_iter=None):
    """Spread out the x coordinates of tree nodes, in place, with a 1-D force simulation.

    Nodes at the same depth repel each other, and every child/parent pair is
    joined by a spring. The left-to-right order of the nodes within each depth,
    as given by the initial coordinates, is preserved throughout. Only
    ``coords[:, 0]`` is modified; the function returns None. Progress (largest
    absolute velocity and acceleration) is written to stdout while running.

    Args:
        ix: array of node indices; only its length (the number of nodes) is used.
        paix: for each node, the position of its parent in ``coords``; -1 for the root.
        depth: depth of each node; only nodes with equal depth repel each other.
        coords: 2-D floating-point array with one row per node; column 0 is
            updated in place.
        rangex: same-depth nodes at least this far apart do not repel.
        cpush: repulsion strength (force = cpush / dx**2, capped at 200).
        cspring: spring strength (force = cspring * dx**2).
        tol: stop once the largest absolute velocity (after friction) is <= tol.
        max_iter: optional cap on the number of simulation steps (None: no cap).

    Raises:
        ValueError: if two nodes of the same depth, closer than ``rangex``, share
            exactly the same x coordinate (the push direction is undefined).
    """
    import sys
    import numpy as np

    n = np.shape(ix)[0]
    if n == 0:
        return  # nothing to lay out (np.max below would fail on empty arrays)
    acc = np.zeros(n)
    vel = np.zeros(n)
    # Same-depth groups in increasing index order (used for the push forces) ...
    groups = _nodes_by_depth(depth, n)
    # ... and the same groups sorted left-to-right by the initial x coordinate.
    # sorted() is stable, so ties keep index order. This order is enforced after
    # every step so that fast particles cannot tunnel past each other.
    order = {d: sorted(members, key=lambda node: coords[node, 0])
             for d, members in groups.items()}

    print('Starting to compute coords...')
    step = 0
    while max_iter is None or step < max_iter:
        step += 1
        _add_push_forces(groups, coords, acc, rangex, cpush)
        _add_spring_forces(paix, coords, acc, cspring, n)
        # forces -> velocity, with the speed clamped to +/-200 per step
        vel = np.clip(vel + acc, -200, 200)
        sys.stdout.write('%.4f' % np.max(np.abs(vel)) + ' ' + '%.4f' % np.max(np.abs(acc)) + ' \r')
        sys.stdout.flush()
        coords[:, 0] += vel  # velocity -> translation
        vel *= 0.9  # lose velocity to friction
        acc[:] = 0.
        _restore_depth_order(order, coords, vel)
        # Stop on the largest speed in either direction, not only on the most
        # positive velocity (fast leftward motion must keep the simulation going).
        if np.max(np.abs(vel)) <= tol:
            break
    print('')
    print('Done.')


def _nodes_by_depth(depth, n):
    """Map each depth value to the list of node positions at that depth, in increasing order."""
    groups = {}
    for i in range(n):
        groups.setdefault(depth[i], []).append(i)
    return groups


def _add_push_forces(groups, coords, acc, rangex, cpush):
    """Add to acc the repulsion between every pair of same-depth nodes closer than rangex."""
    for members in groups.values():
        for i in members:
            for j in members:
                dx = coords[i, 0] - coords[j, 0]
                if i != j and abs(dx) < rangex:
                    if dx == 0:
                        raise ValueError(
                            'nodes %d and %d have the same depth and the same x coordinate; '
                            'the repulsion between them is undefined' % (i, j))
                    # push i away from j; the force magnitude is capped at 200
                    k = -1. if coords[j, 0] > coords[i, 0] else 1.
                    acc[i] += k * min(cpush / dx ** 2, 200)


def _add_spring_forces(paix, coords, acc, cspring, n):
    """Add to acc the spring force pulling each child and its parent toward each other."""
    for i in range(n):
        j = paix[i]
        if j == -1:
            continue  # the root has no parent
        dx = abs(coords[i, 0] - coords[j, 0])
        # dx == 0 (child directly below its parent) is valid: the force is simply 0.
        k = 1. if coords[j, 0] > coords[i, 0] else -1.
        force = k * cspring * dx ** 2
        acc[i] += force
        acc[j] -= force


def _restore_depth_order(order, coords, vel):
    """Swap x coordinates until each depth's nodes are ordered left-to-right as in ``order``.

    Every swap reverses and damps the velocities of both nodes involved, as if
    they had collided and bounced back.
    """
    for members in order.values():
        swapped = True
        while swapped:
            swapped = False
            for a, b in zip(members[:-1], members[1:]):
                if coords[a, 0] > coords[b, 0]:
                    coords[a, 0], coords[b, 0] = coords[b, 0], coords[a, 0]
                    vel[a] *= -0.1
                    vel[b] *= -0.1
                    swapped = True
                    break  # rescan from the left after every swap
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment