
@dpatschke
Created April 9, 2018 02:49
Numba Ball Tree (ParallelAccelerator)

This is a modified version of Jake VanderPlas's wonderful Numba Ball Tree code from his original gist. It adds the 'nopython' parameter to the jit decorators of the original gist and parallelizes the nearest-neighbor query across the query points.

Timings

With some of the new advances in numba and the modifications to the gist, I have been able to achieve the following timings on a 2013 MacBook Pro.

(py36) Davids-MBP:numba_test duck$ python ball_tree_numba.py
-------------------------------------------------------
Numba version: 0.37.0
-------------------------------------------------------
5 neighbors of 1000 points in 3 dimensions
random seed = 1108
results match: True True

sklearn build: 0.0003 sec
numba build  : 0.00022 sec

sklearn query: 0.00303 sec
numba query  : 0.000915 sec
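
The timings above are well under a millisecond, so a single `time()` difference can be noisy. As a rough sketch (not part of the original gist), a best-of-several measurement with `time.perf_counter` can make such comparisons steadier. The helper name `best_of` and the stand-in workload are assumptions for illustration only.

```python
import time


def best_of(fn, repeats=5):
    """Return the smallest wall-clock time over `repeats` calls of fn()."""
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - t0)
    return best


# Stand-in workload; with the gist this could be e.g. lambda: bt2.query(X, K)
elapsed = best_of(lambda: sum(i * i for i in range(10000)))
print("best of 5: {0:.3g} sec".format(elapsed))
```

Taking the minimum rather than the mean is a common choice for micro-benchmarks, because it is the run least disturbed by other activity on the machine.
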
import warnings

import numpy as np
import numba

#----------------------------------------------------------------------
# Distance computations

@numba.jit(nopython=True)
def rdist(X1, i1, X2, i2):
    # squared Euclidean distance between row i1 of X1 and row i2 of X2
    d = 0
    for k in range(X1.shape[1]):
        tmp = (X1[i1, k] - X2[i2, k])
        d += tmp * tmp
    return d


@numba.jit(nopython=True)
def min_rdist(node_centroids, node_radius, i_node, X, j):
    # lower bound on the squared distance from point j of X to any point in the node
    d = rdist(node_centroids, i_node, X, j)
    return np.square(max(0, np.sqrt(d) - node_radius[i_node]))


#----------------------------------------------------------------------
# Heap for distances and neighbors

@numba.jit(nopython=True)
def heap_create(N, k):
    distances = np.full((N, k), np.finfo(np.float64).max)
    indices = np.zeros((N, k), dtype=np.int64)
    return distances, indices


def heap_sort(distances, indices):
    i = np.arange(len(distances), dtype=int)[:, None]
    j = np.argsort(distances, 1)
    return distances[i, j], indices[i, j]


@numba.jit(nopython=True)
def heap_push(row, val, i_val, distances, indices):
    size = distances.shape[1]

    # check if val should be in heap
    if val > distances[row, 0]:
        return

    # insert val at position zero
    distances[row, 0] = val
    indices[row, 0] = i_val

    # descend the heap, swapping values until the max heap criterion is met
    i = 0
    while True:
        ic1 = 2 * i + 1
        ic2 = ic1 + 1

        if ic1 >= size:
            break
        elif ic2 >= size:
            if distances[row, ic1] > val:
                i_swap = ic1
            else:
                break
        elif distances[row, ic1] >= distances[row, ic2]:
            if val < distances[row, ic1]:
                i_swap = ic1
            else:
                break
        else:
            if val < distances[row, ic2]:
                i_swap = ic2
            else:
                break

        distances[row, i] = distances[row, i_swap]
        indices[row, i] = indices[row, i_swap]

        i = i_swap

    distances[row, i] = val
    indices[row, i] = i_val


#----------------------------------------------------------------------
# Tools for building the tree

@numba.jit(nopython=True)
def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
    # Find the split dimension
    n_features = data.shape[1]

    split_dim = 0
    max_spread = 0

    for j in range(n_features):
        max_val = -np.inf
        min_val = np.inf
        for i in range(idx_start, idx_end):
            val = data[idx_array[i], j]
            max_val = max(max_val, val)
            min_val = min(min_val, val)
        if max_val - min_val > max_spread:
            max_spread = max_val - min_val
            split_dim = j

    # Partition using the split dimension
    left = idx_start
    right = idx_end - 1

    while True:
        midindex = left
        for i in range(left, right):
            d1 = data[idx_array[i], split_dim]
            d2 = data[idx_array[right], split_dim]
            if d1 < d2:
                tmp = idx_array[i]
                idx_array[i] = idx_array[midindex]
                idx_array[midindex] = tmp
                midindex += 1
        tmp = idx_array[midindex]
        idx_array[midindex] = idx_array[right]
        idx_array[right] = tmp
        if midindex == split_index:
            break
        elif midindex < split_index:
            left = midindex + 1
        else:
            right = midindex - 1


@numba.jit(nopython=True)
def _recursive_build(i_node, idx_start, idx_end,
                     data, node_centroids, node_radius, idx_array,
                     node_idx_start, node_idx_end, node_is_leaf,
                     n_nodes, leaf_size):
    # determine Node centroid
    for j in range(data.shape[1]):
        node_centroids[i_node, j] = 0
        for i in range(idx_start, idx_end):
            node_centroids[i_node, j] += data[idx_array[i], j]
        node_centroids[i_node, j] /= (idx_end - idx_start)

    # determine Node radius
    sq_radius = 0.0
    for i in range(idx_start, idx_end):
        sq_dist = rdist(node_centroids, i_node, data, idx_array[i])
        if sq_dist > sq_radius:
            sq_radius = sq_dist

    # set node properties
    node_radius[i_node] = np.sqrt(sq_radius)
    node_idx_start[i_node] = idx_start
    node_idx_end[i_node] = idx_end

    i_child = 2 * i_node + 1

    # recursively create subnodes
    # Both children (i_child and i_child + 1) must fit in the node arrays.
    # For an odd n_nodes this is the same test as `i_child >= n_nodes`; for an
    # even n_nodes it keeps the second child from indexing past the end.
    if i_child + 1 >= n_nodes:
        node_is_leaf[i_node] = True
        if idx_end - idx_start > 2 * leaf_size:
            # this shouldn't happen if our memory allocation is correct.
            # We'll proactively prevent memory errors, but raise a
            # warning saying we're doing so.
            #warnings.warn("Internal: memory layout is flawed: "
            #              "not enough nodes allocated")
            pass

    elif idx_end - idx_start < 2:
        # again, this shouldn't happen if our memory allocation is correct.
        #warnings.warn("Internal: memory layout is flawed: "
        #              "too many nodes allocated")
        node_is_leaf[i_node] = True

    else:
        # split node and recursively construct child nodes.
        node_is_leaf[i_node] = False
        n_mid = int((idx_end + idx_start) // 2)
        _partition_indices(data, idx_array, idx_start, idx_end, n_mid)
        _recursive_build(i_child, idx_start, n_mid,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
        _recursive_build(i_child + 1, n_mid, idx_end,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)


#----------------------------------------------------------------------
# Tools for querying the tree

@numba.jit(nopython=True)
def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB,
                     data, idx_array, node_centroids, node_radius,
                     node_is_leaf, node_idx_start, node_idx_end):
    #------------------------------------------------------------
    # Case 1: query point is outside node radius:
    #         trim it from the query
    if sq_dist_LB > heap_distances[i_pt, 0]:
        pass

    #------------------------------------------------------------
    # Case 2: this is a leaf node. Update set of nearby points
    elif node_is_leaf[i_node]:
        for i in range(node_idx_start[i_node],
                       node_idx_end[i_node]):
            dist_pt = rdist(data, idx_array[i], X, i_pt)
            if dist_pt < heap_distances[i_pt, 0]:
                heap_push(i_pt, dist_pt, idx_array[i],
                          heap_distances, heap_indices)

    #------------------------------------------------------------
    # Case 3: Node is not a leaf. Recursively query subnodes
    #         starting with the closest
    else:
        i1 = 2 * i_node + 1
        i2 = i1 + 1
        sq_dist_LB_1 = min_rdist(node_centroids,
                                 node_radius,
                                 i1, X, i_pt)
        sq_dist_LB_2 = min_rdist(node_centroids,
                                 node_radius,
                                 i2, X, i_pt)

        # recursively query subnodes
        if sq_dist_LB_1 <= sq_dist_LB_2:
            _query_recursive(i1, X, i_pt, heap_distances,
                             heap_indices, sq_dist_LB_1,
                             data, idx_array, node_centroids, node_radius,
                             node_is_leaf, node_idx_start, node_idx_end)
            _query_recursive(i2, X, i_pt, heap_distances,
                             heap_indices, sq_dist_LB_2,
                             data, idx_array, node_centroids, node_radius,
                             node_is_leaf, node_idx_start, node_idx_end)
        else:
            _query_recursive(i2, X, i_pt, heap_distances,
                             heap_indices, sq_dist_LB_2,
                             data, idx_array, node_centroids, node_radius,
                             node_is_leaf, node_idx_start, node_idx_end)
            _query_recursive(i1, X, i_pt, heap_distances,
                             heap_indices, sq_dist_LB_1,
                             data, idx_array, node_centroids, node_radius,
                             node_is_leaf, node_idx_start, node_idx_end)


@numba.jit(nopython=True, parallel=True)
def _query_parallel(i_node, X, heap_distances, heap_indices,
                    data, idx_array, node_centroids, node_radius,
                    node_is_leaf, node_idx_start, node_idx_end):
    # each query point writes only to its own heap row, so iterations are independent
    for i_pt in numba.prange(X.shape[0]):
        sq_dist_LB = min_rdist(node_centroids, node_radius, i_node, X, i_pt)
        _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB,
                         data, idx_array, node_centroids, node_radius, node_is_leaf,
                         node_idx_start, node_idx_end)


#----------------------------------------------------------------------
# The Ball Tree object

class BallTree(object):
    def __init__(self, data, leaf_size=40):
        self.data = data
        self.leaf_size = leaf_size

        # validate data
        if self.data.size == 0:
            raise ValueError("X is an empty array")

        if leaf_size < 1:
            raise ValueError("leaf_size must be greater than or equal to 1")

        self.n_samples = self.data.shape[0]
        self.n_features = self.data.shape[1]

        # determine number of levels in the tree, and from this
        # the number of nodes in the tree. This results in leaf nodes
        # with numbers of points between leaf_size and 2 * leaf_size
        self.n_levels = 1 + np.log2(max(1, ((self.n_samples - 1)
                                            // self.leaf_size)))
        self.n_nodes = int(2 ** self.n_levels) - 1

        # allocate arrays for storage
        self.idx_array = np.arange(self.n_samples, dtype=int)
        self.node_radius = np.zeros(self.n_nodes, dtype=float)
        self.node_idx_start = np.zeros(self.n_nodes, dtype=int)
        self.node_idx_end = np.zeros(self.n_nodes, dtype=int)
        self.node_is_leaf = np.zeros(self.n_nodes, dtype=int)
        self.node_centroids = np.zeros((self.n_nodes, self.n_features),
                                       dtype=float)

        # build the tree recursively, starting from the root node
        _recursive_build(0, 0, self.n_samples,
                         self.data, self.node_centroids,
                         self.node_radius, self.idx_array,
                         self.node_idx_start, self.node_idx_end,
                         self.node_is_leaf, self.n_nodes, self.leaf_size)

    def query(self, X, k=1, sort_results=True):
        X = np.asarray(X, dtype=float)

        if X.shape[-1] != self.n_features:
            raise ValueError("query data dimension must "
                             "match training data dimension")

        if self.data.shape[0] < k:
            raise ValueError("k must be less than or equal "
                             "to the number of training points")

        # flatten X, and save original shape information
        Xshape = X.shape
        X = X.reshape((-1, self.data.shape[1]))

        # initialize heap for neighbors
        heap_distances, heap_indices = heap_create(X.shape[0], k)

        # serial version of the query loop, kept for reference:
        #for i in range(X.shape[0]):
        #    sq_dist_LB = min_rdist(self.node_centroids,
        #                           self.node_radius,
        #                           0, X, i)
        #    _query_recursive(0, X, i, heap_distances, heap_indices, sq_dist_LB,
        #                     self.data, self.idx_array, self.node_centroids,
        #                     self.node_radius, self.node_is_leaf,
        #                     self.node_idx_start, self.node_idx_end)

        _query_parallel(0, X, heap_distances, heap_indices,
                        self.data, self.idx_array, self.node_centroids, self.node_radius,
                        self.node_is_leaf, self.node_idx_start, self.node_idx_end)

        distances, indices = heap_sort(heap_distances, heap_indices)
        distances = np.sqrt(distances)

        # deflatten results
        return (distances.reshape(Xshape[:-1] + (k,)),
                indices.reshape(Xshape[:-1] + (k,)))


#----------------------------------------------------------------------
# Testing function

def test_tree(N=1000, D=3, K=5, LS=40):
    from time import time
    from sklearn.neighbors import BallTree as skBallTree

    print("-------------------------------------------------------")
    print("Numba version: " + numba.__version__)

    rseed = np.random.randint(10000)
    print("-------------------------------------------------------")
    print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
    print("random seed = {0}".format(rseed))
    np.random.seed(rseed)
    X = np.random.random((N, D))

    # pre-run to jit compile the code
    BallTree(X, leaf_size=LS).query(X, K)

    t0 = time()
    bt1 = skBallTree(X, leaf_size=LS)
    t1 = time()
    dist1, ind1 = bt1.query(X, K)
    t2 = time()

    bt2 = BallTree(X, leaf_size=LS)
    t3 = time()
    dist2, ind2 = bt2.query(X, K)
    t4 = time()

    print("results match: {0} {1}".format(np.allclose(dist1, dist2),
                                          np.allclose(ind1, ind2)))
    print("")
    print("sklearn build: {0:.3g} sec".format(t1 - t0))
    print("numba build  : {0:.3g} sec".format(t3 - t2))
    print("")
    print("sklearn query: {0:.3g} sec".format(t2 - t1))
    print("numba query  : {0:.3g} sec".format(t4 - t3))


if __name__ == '__main__':
    test_tree()


dylanlee commented Nov 10, 2019

Hey, this code looks like just what I need. Many thanks! Just want to let you know that there seems to be a minor bug in the initial call to "_recursive_build". When I use the arguments listed here and change the value of N to 1003 (or various other sizes) in the test function, it keeps crashing. When I change the second-to-last argument to 'self.n_nodes-1' instead of 'self.n_nodes', it appears to work.
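
One way to read this report, offered as an assumption rather than a confirmed diagnosis: the node arrays have `n_nodes` slots, a node is treated as a leaf only when its first child index `2 * i_node + 1` is out of range, and the second child `2 * i_node + 2` is never checked. The pure-Python sketch below replays just that index arithmetic. The helper name `out_of_bounds_children` is hypothetical, and the sketch ignores the gist's early-leaf branch for nodes with fewer than two points.

```python
def out_of_bounds_children(n_nodes, guard_both_children):
    """Return the non-leaf nodes whose second child does not fit in n_nodes slots."""
    bad = []
    for i_node in range(n_nodes):
        i_child = 2 * i_node + 1
        if guard_both_children:
            # assumed fix: require room for both children
            is_leaf = i_child + 1 >= n_nodes
        else:
            # the check used in the gist: only the first child is tested
            is_leaf = i_child >= n_nodes
        if not is_leaf and i_child + 1 >= n_nodes:
            bad.append(i_node)
    return bad


print(out_of_bounds_children(47, False))  # odd node count  -> []
print(out_of_bounds_children(48, False))  # even node count -> [23]
print(out_of_bounds_children(48, True))   # guarded check   -> []
```

With an odd node count both checks agree, which would explain why some sizes work. Passing `self.n_nodes - 1` has a similar effect to the guarded check, because the arrays keep one more slot than the recursion is told about.
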


tgbnhy commented Jul 14, 2020

Hi, your code is super fast! Great job! But there is a problem when I apply your code to index multiple small datasets in one run: it reports "Process finished with exit code 134 (interrupted by signal 6: SIGABRT)" after indexing 30 datasets. Do you know why? Thanks very much!

@dpatschke
Author

"Process finished with exit code 134 (interrupted by signal 6: SIGABRT)" after indexing 30 datasets, do you know why?

I'm glad the gist is working for you ... well .. for the most part :-). I haven't really used this code in a while, so I'm not sure what could be causing the problem. A quick Stack Overflow search for the exit code you provided appears to point to a potential out-of-memory issue on your system. I can't say that I've ever done memory profiling on this code, but it is possible that the additional objects being created are consuming a good chunk of memory. It may be best to delete the BallTree objects once you are finished querying them and see if that helps. Apologies for not having a more definitive response.


tgbnhy commented Jul 14, 2020

Wow, thanks for your quick response! I did what you suggested by using "del bt2", but it still reported the same error. Below is my code:

==========================
def build_index_ball_tree(li, LS):
    """
    li stores a set of datasets
    :return:
    """
    counter = 0
    for xx in li:
        if len(xx) < 20:
            continue
        bt2 = ball_tree_numba.BallTree(xx, leaf_size=LS)
        counter += 1
        print("counter: " + str(counter))
        del bt2

    return index_list

==========================
Thanks again!

@dpatschke
Author

Hmmmm ... not quite sure. Just so I'm understanding the code above (due to the formatting) ... you are looking to run this only on datasets that have 20 or fewer rows? What is the dimensionality of each of these datasets, and what leaf size are you attempting to use?


tgbnhy commented Jul 14, 2020

Thanks, I just randomly generated multiple datasets using "X = np.random.random((N, D))" and it works well. So I guess it is because of my dataset format. Just to let you know, I am trying to index spatial datasets like this:

[[ 40.74221047 -73.9326508 ]
[ 40.73738217 -73.81427516]
[ 40.76983905 -73.96429129]
...
[ 40.87176915 -73.80558426]
[ 40.67450249 -73.94389451]
[ 40.69479017 -73.9924112 ]]

Do I need to normalize them? Thanks for your time!


tgbnhy commented Jul 15, 2020

[Screenshot: Screen Shot 2020-07-14 at 9 27 52 PM]

Hi, I think the main problem is that we cannot create indexes for multiple datasets with different N; please check the figure above.
When I comment out line 377, it works fine. You can also test on your side. Thanks!

for i in range(100):
    N = np.random.randint(1000)
    X = np.random.random((N, D))
    bt2 = BallTree(X, leaf_size=LS)
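
A possible explanation for crashes that depend on N, stated here as an assumption and not confirmed in this thread, is the node count computed in `BallTree.__init__`. There `n_levels` stays a float, so `int(2 ** n_levels) - 1` is not always of the form 2^k - 1 and can come out even. The sketch below compares that formula with a variant that truncates the level count to an integer. `n_nodes_float` and `n_nodes_int` are hypothetical helper names, and `math.log2` stands in for `np.log2`.

```python
import math


def n_nodes_float(n_samples, leaf_size):
    # Mirrors the gist: the level count is a float, so 2 ** n_levels
    # is generally not an exact power of two.
    n_levels = 1 + math.log2(max(1, (n_samples - 1) // leaf_size))
    return int(2 ** n_levels) - 1


def n_nodes_int(n_samples, leaf_size):
    # Assumed variant: truncate the level count to an integer, which
    # always gives a full binary tree with 2**k - 1 nodes.
    n_levels = 1 + int(math.log2(max(1, (n_samples - 1) // leaf_size)))
    return 2 ** n_levels - 1


for n in (517, 1000, 1003):
    print(n, n_nodes_float(n, 40), n_nodes_int(n, 40))
```

The integer variant always yields an odd node count, so every internal node has slots for both children. The trade-off is that it allocates fewer nodes, so it tends to produce fewer, larger leaves than the gist's layout.
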

@dpatschke
Author

ahhh yes ... you are right. Good find.

class BallTree(object):
    def __init__(self, data, leaf_size=40):

The class init does have the data hard-coded in there. I suppose you could create multiple instances of the class .... one for each size of your data. It's not the most elegant solution but it would probably work.


tgbnhy commented Jul 15, 2020 via email
