@jakevdp
Last active September 30, 2023
Numba Ball Tree example

Numba Ball Tree

This is a quick attempt at writing a ball tree for nearest neighbor searches using numba. I've included a pure Python version and a version with numba jit decorators. Because class support in numba is not yet complete, all the code in the numba version is factored out into stand-alone functions. The resulting numba-compiled code is roughly 10 times slower than the cython ball tree in scikit-learn; my guess is that part of this stems from the lack of inlining in numba, while the rest is due to some sort of overhead that isn't obvious to me.
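
For orientation, here is a minimal usage sketch. It mirrors the test function at the bottom of each script; the array shape and leaf_size are example values, and it assumes the script is importable as a module named ball_tree_numba:

import numpy as np
from ball_tree_numba import BallTree  # or from ball_tree_python

X = np.random.random((1000, 3))   # 1000 points in 3 dimensions
bt = BallTree(X, leaf_size=40)    # build the tree
dist, ind = bt.query(X, k=5)      # 5 nearest neighbors of each point
# dist and ind both have shape (1000, 5)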

Timings

These are the results on my MacBook Pro, running Python 3.4:

jakesmac $ python ball_tree_python.py 
-------------------------------------------------------
5 neighbors of 1000 points in 3 dimensions
random seed = 9742
results match: True True

sklearn build: 0.00033 sec
python build  : 0.053 sec

sklearn query: 0.004 sec
python query  : 1 sec


jakesmac $ python ball_tree_numba.py 
-------------------------------------------------------
5 neighbors of 1000 points in 3 dimensions
random seed = 2772
results match: True True

sklearn build: 0.0003 sec
numba build  : 0.00045 sec

sklearn query: 0.0041 sec
numba query  : 0.041 sec
ball_tree_numba.py

import warnings
import numpy as np


class FakeJit(object):
    """A do-nothing stand-in for the numba jit decorator.

    Swap it in below to time the pure-Python code paths.
    """
    def __call__(self, *args, **kwargs):
        if kwargs:
            if args:
                raise ValueError("pass either arguments or keyword "
                                 "arguments, not both")
            return self
        return args[0]


from numba import jit as numba_jit
# numba_jit = FakeJit()
#----------------------------------------------------------------------
# Distance computations

@numba_jit
def rdist(X1, i1, X2, i2):
    d = 0
    for k in range(X1.shape[1]):
        tmp = (X1[i1, k] - X2[i2, k])
        d += tmp * tmp
    return d


@numba_jit
def min_rdist(node_centroids, node_radius, i_node, X, j):
    d = rdist(node_centroids, i_node, X, j)
    return max(0, np.sqrt(d) - node_radius[i_node]) ** 2
#----------------------------------------------------------------------
# Heap for distances and neighbors

def heap_create(N, k):
    distances = np.full((N, k), np.inf, dtype=float)
    indices = np.zeros((N, k), dtype=int)
    return distances, indices


def heap_sort(distances, indices):
    i = np.arange(len(distances), dtype=int)[:, None]
    j = np.argsort(distances, 1)
    return distances[i, j], indices[i, j]


@numba_jit
def heap_push(row, val, i_val, distances, indices):
    size = distances.shape[1]

    # check if val should be in heap
    if val > distances[row, 0]:
        return

    # insert val at position zero
    distances[row, 0] = val
    indices[row, 0] = i_val

    # descend the heap, swapping values until the max heap criterion is met
    i = 0
    while True:
        ic1 = 2 * i + 1
        ic2 = ic1 + 1

        if ic1 >= size:
            break
        elif ic2 >= size:
            if distances[row, ic1] > val:
                i_swap = ic1
            else:
                break
        elif distances[row, ic1] >= distances[row, ic2]:
            if val < distances[row, ic1]:
                i_swap = ic1
            else:
                break
        else:
            if val < distances[row, ic2]:
                i_swap = ic2
            else:
                break

        distances[row, i] = distances[row, i_swap]
        indices[row, i] = indices[row, i_swap]
        i = i_swap

    distances[row, i] = val
    indices[row, i] = i_val
#----------------------------------------------------------------------
# Tools for building the tree

@numba_jit
def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
    # Find the split dimension: the feature with maximum spread
    n_features = data.shape[1]
    split_dim = 0
    max_spread = 0
    for j in range(n_features):
        max_val = -np.inf
        min_val = np.inf
        for i in range(idx_start, idx_end):
            val = data[idx_array[i], j]
            max_val = max(max_val, val)
            min_val = min(min_val, val)
        if max_val - min_val > max_spread:
            max_spread = max_val - min_val
            split_dim = j

    # Partition using the split dimension
    left = idx_start
    right = idx_end - 1
    while True:
        midindex = left
        for i in range(left, right):
            d1 = data[idx_array[i], split_dim]
            d2 = data[idx_array[right], split_dim]
            if d1 < d2:
                tmp = idx_array[i]
                idx_array[i] = idx_array[midindex]
                idx_array[midindex] = tmp
                midindex += 1
        tmp = idx_array[midindex]
        idx_array[midindex] = idx_array[right]
        idx_array[right] = tmp
        if midindex == split_index:
            break
        elif midindex < split_index:
            left = midindex + 1
        else:
            right = midindex - 1
@numba_jit
def _recursive_build(i_node, idx_start, idx_end,
                     data, node_centroids, node_radius, idx_array,
                     node_idx_start, node_idx_end, node_is_leaf,
                     n_nodes, leaf_size):
    # determine Node centroid
    for j in range(data.shape[1]):
        node_centroids[i_node, j] = 0
        for i in range(idx_start, idx_end):
            node_centroids[i_node, j] += data[idx_array[i], j]
        node_centroids[i_node, j] /= (idx_end - idx_start)

    # determine Node radius
    sq_radius = 0.0
    for i in range(idx_start, idx_end):
        sq_dist = rdist(node_centroids, i_node, data, idx_array[i])
        if sq_dist > sq_radius:
            sq_radius = sq_dist

    # set node properties
    node_radius[i_node] = np.sqrt(sq_radius)
    node_idx_start[i_node] = idx_start
    node_idx_end[i_node] = idx_end

    i_child = 2 * i_node + 1

    # recursively create subnodes
    if i_child >= n_nodes:
        node_is_leaf[i_node] = True
        if idx_end - idx_start > 2 * leaf_size:
            # this shouldn't happen if our memory allocation is correct.
            # We'll proactively prevent memory errors, but raise a
            # warning saying we're doing so.
            warnings.warn("Internal: memory layout is flawed: "
                          "not enough nodes allocated")
    elif idx_end - idx_start < 2:
        # again, this shouldn't happen if our memory allocation is correct.
        warnings.warn("Internal: memory layout is flawed: "
                      "too many nodes allocated")
        node_is_leaf[i_node] = True
    else:
        # split node and recursively construct child nodes.
        node_is_leaf[i_node] = False
        n_mid = (idx_end + idx_start) // 2
        _partition_indices(data, idx_array, idx_start, idx_end, n_mid)
        _recursive_build(i_child, idx_start, n_mid,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
        _recursive_build(i_child + 1, n_mid, idx_end,
                         data, node_centroids, node_radius, idx_array,
                         node_idx_start, node_idx_end, node_is_leaf,
                         n_nodes, leaf_size)
#----------------------------------------------------------------------
# Tools for querying the tree

@numba_jit
def _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB,
                     data, idx_array, node_centroids, node_radius,
                     node_is_leaf, node_idx_start, node_idx_end):
    #------------------------------------------------------------
    # Case 1: query point is outside node radius:
    #         trim it from the query
    if sq_dist_LB > heap_distances[i_pt, 0]:
        pass

    #------------------------------------------------------------
    # Case 2: this is a leaf node.  Update set of nearby points
    elif node_is_leaf[i_node]:
        for i in range(node_idx_start[i_node], node_idx_end[i_node]):
            dist_pt = rdist(data, idx_array[i], X, i_pt)
            if dist_pt < heap_distances[i_pt, 0]:
                heap_push(i_pt, dist_pt, idx_array[i],
                          heap_distances, heap_indices)

    #------------------------------------------------------------
    # Case 3: Node is not a leaf.  Recursively query subnodes
    #         starting with the closest
    else:
        i1 = 2 * i_node + 1
        i2 = i1 + 1
        sq_dist_LB_1 = min_rdist(node_centroids, node_radius, i1, X, i_pt)
        sq_dist_LB_2 = min_rdist(node_centroids, node_radius, i2, X, i_pt)

        # recursively query subnodes, nearer node first
        if sq_dist_LB_1 <= sq_dist_LB_2:
            _query_recursive(i1, X, i_pt, heap_distances, heap_indices,
                             sq_dist_LB_1, data, idx_array, node_centroids,
                             node_radius, node_is_leaf,
                             node_idx_start, node_idx_end)
            _query_recursive(i2, X, i_pt, heap_distances, heap_indices,
                             sq_dist_LB_2, data, idx_array, node_centroids,
                             node_radius, node_is_leaf,
                             node_idx_start, node_idx_end)
        else:
            _query_recursive(i2, X, i_pt, heap_distances, heap_indices,
                             sq_dist_LB_2, data, idx_array, node_centroids,
                             node_radius, node_is_leaf,
                             node_idx_start, node_idx_end)
            _query_recursive(i1, X, i_pt, heap_distances, heap_indices,
                             sq_dist_LB_1, data, idx_array, node_centroids,
                             node_radius, node_is_leaf,
                             node_idx_start, node_idx_end)
#----------------------------------------------------------------------
# The Ball Tree object

class BallTree(object):
    def __init__(self, data, leaf_size=40):
        self.data = data
        self.leaf_size = leaf_size

        # validate data
        if self.data.size == 0:
            raise ValueError("X is an empty array")
        if leaf_size < 1:
            raise ValueError("leaf_size must be greater than or equal to 1")

        self.n_samples = self.data.shape[0]
        self.n_features = self.data.shape[1]

        # determine the number of levels in the tree, and from this
        # the number of nodes in the tree.  This results in leaf nodes
        # with numbers of points between leaf_size and 2 * leaf_size
        self.n_levels = 1 + np.log2(max(1, (self.n_samples - 1)
                                        // self.leaf_size))
        self.n_nodes = int(2 ** self.n_levels) - 1

        # allocate arrays for storage
        self.idx_array = np.arange(self.n_samples, dtype=int)
        self.node_radius = np.zeros(self.n_nodes, dtype=float)
        self.node_idx_start = np.zeros(self.n_nodes, dtype=int)
        self.node_idx_end = np.zeros(self.n_nodes, dtype=int)
        self.node_is_leaf = np.zeros(self.n_nodes, dtype=int)
        self.node_centroids = np.zeros((self.n_nodes, self.n_features),
                                       dtype=float)

        # build the tree
        _recursive_build(0, 0, self.n_samples,
                         self.data, self.node_centroids,
                         self.node_radius, self.idx_array,
                         self.node_idx_start, self.node_idx_end,
                         self.node_is_leaf, self.n_nodes, self.leaf_size)

    def query(self, X, k=1, sort_results=True):
        X = np.asarray(X, dtype=float)

        if X.shape[-1] != self.n_features:
            raise ValueError("query data dimension must "
                             "match training data dimension")
        if self.data.shape[0] < k:
            raise ValueError("k must be less than or equal "
                             "to the number of training points")

        # flatten X, and save original shape information
        Xshape = X.shape
        X = X.reshape((-1, self.data.shape[1]))

        # initialize heap for neighbors
        heap_distances, heap_indices = heap_create(X.shape[0], k)

        for i in range(X.shape[0]):
            sq_dist_LB = min_rdist(self.node_centroids, self.node_radius,
                                   0, X, i)
            _query_recursive(0, X, i, heap_distances, heap_indices,
                             sq_dist_LB, self.data, self.idx_array,
                             self.node_centroids, self.node_radius,
                             self.node_is_leaf, self.node_idx_start,
                             self.node_idx_end)

        distances, indices = heap_sort(heap_distances, heap_indices)
        distances = np.sqrt(distances)

        # deflatten results
        return (distances.reshape(Xshape[:-1] + (k,)),
                indices.reshape(Xshape[:-1] + (k,)))
#----------------------------------------------------------------------
# Testing function

def test_tree(N=1000, D=3, K=5, LS=40):
    from time import time
    from sklearn.neighbors import BallTree as skBallTree

    rseed = np.random.randint(10000)
    print("-------------------------------------------------------")
    print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
    print("random seed = {0}".format(rseed))
    np.random.seed(rseed)
    X = np.random.random((N, D))

    # pre-run to jit-compile the code
    BallTree(X, leaf_size=LS).query(X, K)

    t0 = time()
    bt1 = skBallTree(X, leaf_size=LS)
    t1 = time()
    dist1, ind1 = bt1.query(X, K)
    t2 = time()

    bt2 = BallTree(X, leaf_size=LS)
    t3 = time()
    dist2, ind2 = bt2.query(X, K)
    t4 = time()

    print("results match: {0} {1}".format(np.allclose(dist1, dist2),
                                          np.allclose(ind1, ind2)))
    print("")
    print("sklearn build: {0:.2g} sec".format(t1 - t0))
    print("numba build  : {0:.2g} sec".format(t3 - t2))
    print("")
    print("sklearn query: {0:.2g} sec".format(t2 - t1))
    print("numba query  : {0:.2g} sec".format(t4 - t3))


if __name__ == '__main__':
    test_tree()
ball_tree_python.py

from __future__ import division, print_function

import warnings

import numpy as np

class BallTree(object):
    def __init__(self, data, leaf_size=40):
        self.data = data
        self.leaf_size = leaf_size

        # validate data
        if self.data.size == 0:
            raise ValueError("X is an empty array")
        if leaf_size < 1:
            raise ValueError("leaf_size must be greater than or equal to 1")

        self.n_samples = self.data.shape[0]
        self.n_features = self.data.shape[1]

        # determine the number of levels in the tree, and from this
        # the number of nodes in the tree.  This results in leaf nodes
        # with numbers of points between leaf_size and 2 * leaf_size
        self.n_levels = 1 + np.log2(max(1, (self.n_samples - 1)
                                        // self.leaf_size))
        self.n_nodes = int(2 ** self.n_levels) - 1

        # allocate arrays for storage
        self.idx_array = np.arange(self.n_samples, dtype=int)
        self.node_radius = np.zeros(self.n_nodes, dtype=float)
        self.node_idx_start = np.zeros(self.n_nodes, dtype=int)
        self.node_idx_end = np.zeros(self.n_nodes, dtype=int)
        self.node_is_leaf = np.zeros(self.n_nodes, dtype=int)
        self.node_centroids = np.zeros((self.n_nodes, self.n_features),
                                       dtype=float)

        # build the tree
        self._recursive_build(0, 0, self.n_samples)
    def _recursive_build(self, i_node, idx_start, idx_end):
        # initialize node data
        self.init_node(i_node, idx_start, idx_end)

        if 2 * i_node + 1 >= self.n_nodes:
            self.node_is_leaf[i_node] = True
            if idx_end - idx_start > 2 * self.leaf_size:
                # this shouldn't happen if our memory allocation is correct:
                # we'll proactively prevent memory errors, but raise a
                # warning saying we're doing so.
                warnings.warn("Internal: memory layout is flawed: "
                              "not enough nodes allocated")
        elif idx_end - idx_start < 2:
            # again, this shouldn't happen if our memory allocation
            # is correct.  Raise a warning.
            warnings.warn("Internal: memory layout is flawed: "
                          "too many nodes allocated")
            self.node_is_leaf[i_node] = True
        else:
            # split node and recursively construct child nodes.
            self.node_is_leaf[i_node] = False
            n_mid = (idx_end + idx_start) // 2
            _partition_indices(self.data, self.idx_array,
                               idx_start, idx_end, n_mid)
            self._recursive_build(2 * i_node + 1, idx_start, n_mid)
            self._recursive_build(2 * i_node + 2, n_mid, idx_end)
    def init_node(self, i_node, idx_start, idx_end):
        # determine Node centroid
        for j in range(self.n_features):
            self.node_centroids[i_node, j] = 0
            for i in range(idx_start, idx_end):
                self.node_centroids[i_node, j] += self.data[self.idx_array[i],
                                                            j]
            self.node_centroids[i_node, j] /= (idx_end - idx_start)

        # determine Node radius
        sq_radius = 0
        for i in range(idx_start, idx_end):
            sq_dist = self.rdist(self.node_centroids, i_node,
                                 self.data, self.idx_array[i])
            sq_radius = max(sq_radius, sq_dist)

        # set node properties
        self.node_radius[i_node] = np.sqrt(sq_radius)
        self.node_idx_start[i_node] = idx_start
        self.node_idx_end[i_node] = idx_end
    def rdist(self, X1, i1, X2, i2):
        d = 0
        for k in range(self.n_features):
            tmp = (X1[i1, k] - X2[i2, k])
            d += tmp * tmp
        return d

    def min_rdist(self, i_node, X, j):
        d = self.rdist(self.node_centroids, i_node, X, j)
        return max(0, np.sqrt(d) - self.node_radius[i_node]) ** 2
    def query(self, X, k=1, sort_results=True):
        X = np.asarray(X, dtype=float)

        if X.shape[-1] != self.n_features:
            raise ValueError("query data dimension must "
                             "match training data dimension")
        if self.data.shape[0] < k:
            raise ValueError("k must be less than or equal "
                             "to the number of training points")

        # flatten X, and save original shape information
        Xshape = X.shape
        X = X.reshape((-1, self.data.shape[1]))

        # initialize heap for neighbors
        heap = NeighborsHeap(X.shape[0], k)

        for i in range(X.shape[0]):
            sq_dist_LB = self.min_rdist(0, X, i)
            self._query_recursive(0, X, i, heap, sq_dist_LB)

        distances, indices = heap.get_arrays(sort=sort_results)
        distances = np.sqrt(distances)

        # deflatten results
        return (distances.reshape(Xshape[:-1] + (k,)),
                indices.reshape(Xshape[:-1] + (k,)))
    def _query_recursive(self, i_node, X, i_pt, heap, sq_dist_LB):
        #------------------------------------------------------------
        # Case 1: query point is outside node radius:
        #         trim it from the query
        if sq_dist_LB > heap.largest(i_pt):
            pass

        #------------------------------------------------------------
        # Case 2: this is a leaf node.  Update set of nearby points
        elif self.node_is_leaf[i_node]:
            for i in range(self.node_idx_start[i_node],
                           self.node_idx_end[i_node]):
                dist_pt = self.rdist(self.data, self.idx_array[i], X, i_pt)
                if dist_pt < heap.largest(i_pt):
                    heap.push(i_pt, dist_pt, self.idx_array[i])

        #------------------------------------------------------------
        # Case 3: Node is not a leaf.  Recursively query subnodes
        #         starting with the closest
        else:
            i1 = 2 * i_node + 1
            i2 = i1 + 1
            sq_dist_LB_1 = self.min_rdist(i1, X, i_pt)
            sq_dist_LB_2 = self.min_rdist(i2, X, i_pt)

            # recursively query subnodes, nearer node first
            if sq_dist_LB_1 <= sq_dist_LB_2:
                self._query_recursive(i1, X, i_pt, heap, sq_dist_LB_1)
                self._query_recursive(i2, X, i_pt, heap, sq_dist_LB_2)
            else:
                self._query_recursive(i2, X, i_pt, heap, sq_dist_LB_2)
                self._query_recursive(i1, X, i_pt, heap, sq_dist_LB_1)

def _partition_indices(data, idx_array, idx_start, idx_end, split_index):
    # Find the split dimension: the feature with maximum spread
    n_features = data.shape[1]
    split_dim = 0
    max_spread = 0
    for j in range(n_features):
        max_val = -np.inf
        min_val = np.inf
        for i in range(idx_start, idx_end):
            val = data[idx_array[i], j]
            max_val = max(max_val, val)
            min_val = min(min_val, val)
        if max_val - min_val > max_spread:
            max_spread = max_val - min_val
            split_dim = j

    # Partition using the split dimension
    left = idx_start
    right = idx_end - 1
    while True:
        midindex = left
        for i in range(left, right):
            d1 = data[idx_array[i], split_dim]
            d2 = data[idx_array[right], split_dim]
            if d1 < d2:
                tmp = idx_array[i]
                idx_array[i] = idx_array[midindex]
                idx_array[midindex] = tmp
                midindex += 1
        tmp = idx_array[midindex]
        idx_array[midindex] = idx_array[right]
        idx_array[right] = tmp
        if midindex == split_index:
            break
        elif midindex < split_index:
            left = midindex + 1
        else:
            right = midindex - 1

class NeighborsHeap(object):
    def __init__(self, n_pts, n_nbrs):
        self.distances = np.full((n_pts, n_nbrs), np.inf, dtype=float)
        self.indices = np.zeros((n_pts, n_nbrs), dtype=int)

    def get_arrays(self, sort=True):
        if sort:
            i = np.arange(len(self.distances), dtype=int)[:, None]
            j = np.argsort(self.distances, 1)
            return self.distances[i, j], self.indices[i, j]
        else:
            return self.distances, self.indices

    def largest(self, row):
        return self.distances[row, 0]

    def push(self, row, val, i_val):
        size = self.distances.shape[1]

        # check if val should be in heap
        if val > self.distances[row, 0]:
            return

        # insert val at position zero
        self.distances[row, 0] = val
        self.indices[row, 0] = i_val

        # descend the heap, swapping values until the max heap criterion is met
        i = 0
        while True:
            ic1 = 2 * i + 1
            ic2 = ic1 + 1

            if ic1 >= size:
                break
            elif ic2 >= size:
                if self.distances[row, ic1] > val:
                    i_swap = ic1
                else:
                    break
            elif self.distances[row, ic1] >= self.distances[row, ic2]:
                if val < self.distances[row, ic1]:
                    i_swap = ic1
                else:
                    break
            else:
                if val < self.distances[row, ic2]:
                    i_swap = ic2
                else:
                    break

            self.distances[row, i] = self.distances[row, i_swap]
            self.indices[row, i] = self.indices[row, i_swap]
            i = i_swap

        self.distances[row, i] = val
        self.indices[row, i] = i_val

def test_tree(N=1000, D=3, K=5, LS=40):
    from time import time
    from sklearn.neighbors import BallTree as skBallTree

    rseed = np.random.randint(10000)
    print("-------------------------------------------------------")
    print("{0} neighbors of {1} points in {2} dimensions".format(K, N, D))
    print("random seed = {0}".format(rseed))
    np.random.seed(rseed)
    X = np.random.random((N, D))

    t0 = time()
    bt1 = skBallTree(X, leaf_size=LS)
    t1 = time()
    dist1, ind1 = bt1.query(X, K)
    t2 = time()

    bt2 = BallTree(X, leaf_size=LS)
    t3 = time()
    dist2, ind2 = bt2.query(X, K)
    t4 = time()

    print("results match: {0} {1}".format(np.allclose(dist1, dist2),
                                          np.allclose(ind1, ind2)))
    print("")
    print("sklearn build: {0:.2g} sec".format(t1 - t0))
    print("python build  : {0:.2g} sec".format(t3 - t2))
    print("")
    print("sklearn query: {0:.2g} sec".format(t2 - t1))
    print("python query  : {0:.2g} sec".format(t4 - t3))


if __name__ == '__main__':
    test_tree()
@dpatschke

Hi Jake,

I realize this gist is almost 2 years old now, but I was looking at implementing the Ball Tree algorithm using numba and found this wonderful piece of code. At the time, you mentioned that numba was about 10x slower than the cython code in scikit-learn. A lot has happened with numba in the last two years. It looks like inlining has improved for recursive calls, but more importantly, the ParallelAccelerator functionality now gives you free multi-threading by modifying lines 316-323 in ball_tree_numba.py. As a result, here are the results of your code (with my modifications/enhancements) on my 2013 MacBook Pro with Python 3.6:

(py36) Davids-MBP:numba_test duck$ python ball_tree_numba.py
-------------------------------------------------------
Numba version: 0.37.0
-------------------------------------------------------
5 neighbors of 1000 points in 3 dimensions
random seed = 1108
results match: True True

sklearn build: 0.0003 sec
numba build  : 0.00022 sec

sklearn query: 0.00303 sec
numba query  : 0.000915 sec

New function to take advantage of parallelized query:

@numba.jit(nopython=True, parallel=True)
def _query_parallel(i_node, X, heap_distances, heap_indices,
                    data, idx_array, node_centroids, node_radius,
                    node_is_leaf, node_idx_start, node_idx_end):
    for i_pt in numba.prange(X.shape[0]):
        sq_dist_LB = min_rdist(node_centroids, node_radius, i_node, X, i_pt)
        _query_recursive(i_node, X, i_pt, heap_distances, heap_indices, sq_dist_LB,
                         data, idx_array, node_centroids, node_radius, node_is_leaf,
                         node_idx_start, node_idx_end)
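
With this function in place, the per-point loop in BallTree.query collapses to a single call. A sketch of such a query method, assuming the same array layout and helper functions as ball_tree_numba.py (untested here):

def query(self, X, k=1):
    X = np.asarray(X, dtype=float).reshape((-1, self.n_features))
    heap_distances, heap_indices = heap_create(X.shape[0], k)
    # one parallel call replaces the Python-level loop over query points
    _query_parallel(0, X, heap_distances, heap_indices,
                    self.data, self.idx_array, self.node_centroids,
                    self.node_radius, self.node_is_leaf,
                    self.node_idx_start, self.node_idx_end)
    distances, indices = heap_sort(heap_distances, heap_indices)
    return np.sqrt(distances), indices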

Please let me know of a good way to provide you with my updated script and I will be happy to do so. I tried to attach it to this comment, but it looks like code file attachments are not supported.

Thanks again for providing this awesome gist!!

David

@epifanio

epifanio commented Apr 8, 2018

@dpatschke I'm also in need of a fast point-from-points distance computation. Google landed me here. Do you have a working example of the parallel implementation? Perhaps a gist linked here would be a great start! Thanks.

@dpatschke

@epifanio Here is a link to a gist with my code. Hope this helps!

@dylanlee

dylanlee commented Nov 10, 2019

Hey, this code looks like just what I need. Many thanks! Just to let you know, there seems to be a minor bug in the initial call to "_recursive_build". When I used the arguments listed here and changed the value of N to 1003 (or various other sizes of N) in the test function, it kept crashing. When I changed the second-to-last argument to 'self.n_nodes - 1' instead of 'self.n_nodes', it appears to work (at least according to this test).
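
For reference, that change would make the initial call in __init__ read as follows (a sketch of the fix described above, not independently verified):

_recursive_build(0, 0, self.n_samples,
                 self.data, self.node_centroids,
                 self.node_radius, self.idx_array,
                 self.node_idx_start, self.node_idx_end,
                 self.node_is_leaf, self.n_nodes - 1, self.leaf_size)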
