Skip to content

Instantly share code, notes, and snippets.

@jaantollander
Created March 6, 2017 15:51
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jaantollander/442ba7249d53ad3ff263def5645725b9 to your computer and use it in GitHub Desktop.
Save jaantollander/442ba7249d53ad3ff263def5645725b9 to your computer and use it in GitHub Desktop.
Quadtree implementation with Python and Numba
import numpy as np
import numba
from numba import deferred_type, optional, f8

try:
    # Numba >= 0.49 provides jitclass under numba.experimental.
    from numba.experimental import jitclass
except ImportError:
    # Older Numba releases expose jitclass at the package top level.
    from numba import jitclass
# Placeholder type so the field spec can refer to the node class before the
# class itself exists; it is resolved by the node_type.define(...) call that
# follows the class definition.
node_type = deferred_type()
# Field layout passed to the jitclass decorator: (attribute name, Numba type).
spec = (
("value", optional(f8[:, :])),  # 2-D float array, or None until set_value() is called
("center", f8[:]),  # 1-D float array holding the node's center
("length", f8),  # float; set_nodes() gives each child half of this value
("nw", optional(node_type)),  # child nodes; None until set_nodes() creates them
("ne", optional(node_type)),
("sw", optional(node_type)),
("se", optional(node_type)),
)
# ``numba.jitclass`` was moved to ``numba.experimental.jitclass`` and is no
# longer available from the top-level namespace in current Numba releases, so
# the decorator is taken from the import block at the top of the file.
@jitclass(spec)
class QuadTreeNode(object):
    """
    Node of a quad-tree (Q-tree).

    A node covers a square region described by its ``center`` and side
    ``length``. It can hold an array of points in ``value`` and can be split
    into four child quadrants::

        NW | NE
        ---+----
        SW | SE

    Construction outline:

    #) Bounding Square
    #) Find Center
    #) Divide the domain into quadrants (NE, NW, SW, SE)
    #) Divide the remaining points into their quadrants

    .. [#] https://en.wikipedia.org/wiki/Quadtree
    .. [#] http://arborjs.org/docs/barnes-hut
    """

    def __init__(self, center, length):
        """Create a node with no stored points and no children.

        Parameters
        ----------
        center : 1-D float array
            Center of the square covered by this node.
        length : float
            Side length of the square covered by this node.
        """
        self.value = None
        self.center = center
        self.length = length
        self.nw = None
        self.ne = None
        self.sw = None
        self.se = None

    def set_value(self, value):
        """Store ``value`` (a 2-D float array of points) on this node."""
        self.value = value

    def set_nodes(self):
        """Create the four child quadrants and return their side length.

        Each child has half the side length of this node, and its center is
        offset from this node's center by a quarter of this node's side
        length along both axes.

        Returns
        -------
        float
            Side length of the newly created children.
        """
        child_length = self.length / 2
        offset = self.length / 4
        self.nw = QuadTreeNode(
            self.center + np.array((-offset, offset)), child_length)
        self.ne = QuadTreeNode(
            self.center + np.array((offset, offset)), child_length)
        self.sw = QuadTreeNode(
            self.center + np.array((-offset, -offset)), child_length)
        self.se = QuadTreeNode(
            self.center + np.array((offset, -offset)), child_length)
        return child_length
# Resolve the deferred placeholder now that QuadTreeNode exists, so the
# nw/ne/sw/se fields declared in spec are typed as QuadTreeNode instances.
node_type.define(QuadTreeNode.class_type.instance_type)
@numba.jit(nopython=True, nogil=True)
def partition(points, center):
    """Split ``points`` into the four quadrants around ``center``.

    A point belongs to the east half when its x coordinate is greater than or
    equal to ``center[0]`` and to the north half when its y coordinate is
    greater than or equal to ``center[1]``.

    Parameters
    ----------
    points : 2-D array
        One point per row; column 0 is x and column 1 is y.
    center : 1-D array
        The (x, y) position the quadrants are arranged around.

    Returns
    -------
    tuple of four arrays
        ``(nw, ne, sw, se)``, each with two columns. Within every quadrant
        the rows appear in the reverse of their order in ``points`` because
        the output arrays are filled from the back.
    """
    cx = center[0]
    cy = center[1]
    total = len(points)

    # First pass: tally how many points land in each quadrant so the output
    # arrays can be allocated with their exact size.
    n_nw = 0
    n_ne = 0
    n_sw = 0
    n_se = 0
    for k in range(total):
        point = points[k]
        east = point[0] >= cx
        north = point[1] >= cy
        if north:
            if east:
                n_ne += 1
            else:
                n_nw += 1
        else:
            if east:
                n_se += 1
            else:
                n_sw += 1

    nw = np.empty(shape=(n_nw, 2))
    ne = np.empty(shape=(n_ne, 2))
    sw = np.empty(shape=(n_sw, 2))
    se = np.empty(shape=(n_se, 2))

    # Second pass: the tallies double as write cursors that count down to
    # zero, so each array is filled from its last row to its first.
    for k in range(total):
        point = points[k]
        east = point[0] >= cx
        north = point[1] >= cy
        if north:
            if east:
                n_ne -= 1
                ne[n_ne] = point
            else:
                n_nw -= 1
                nw[n_nw] = point
        else:
            if east:
                n_se -= 1
                se[n_se] = point
            else:
                n_sw -= 1
                sw[n_sw] = point

    return nw, ne, sw, se
@numba.jit(nopython=True)
def add_nodes(node, points, threshold):
    """Recursively distribute ``points`` into the subtree rooted at ``node``.

    * No points: ``node`` is left untouched.
    * One point: it is stored directly on ``node``.
    * Two or more points: ``node`` is split into four children and the points
      are partitioned among them. If the children's side length is below
      ``threshold`` the partitions are stored on the children as they are;
      otherwise each child is processed recursively.

    Parameters
    ----------
    node : QuadTreeNode
        Node that receives the points.
    points : 2-D array
        One point per row.
    threshold : float
        Child side length below which subdivision stops.
    """
    n_points = len(points)

    if n_points == 0:
        return

    if n_points == 1:
        node.set_value(points)
        return

    # Two or more points: split them around the node's center.
    in_nw, in_ne, in_sw, in_se = partition(points, node.center)
    child_length = node.set_nodes()

    if child_length < threshold:
        # Children are small enough; stop subdividing and store the points.
        node.nw.set_value(in_nw)
        node.ne.set_value(in_ne)
        node.sw.set_value(in_sw)
        node.se.set_value(in_se)
        return

    add_nodes(node.nw, in_nw, threshold)
    add_nodes(node.ne, in_ne, threshold)
    add_nodes(node.sw, in_sw, threshold)
    add_nodes(node.se, in_se, threshold)
@numba.jit(nopython=True)
def bounding_square(points):
    """Return the center and side length of a square enclosing ``points``.

    Parameters
    ----------
    points : 2-D array
        One point per row; column 0 is x and column 1 is y.

    Returns
    -------
    center : 1-D float array
        Midpoint of the points' extent along x and y.
    length : float
        The larger of the x extent and the y extent.
    """
    x_min = points[:, 0].min()
    x_max = points[:, 0].max()
    y_min = points[:, 1].min()
    y_max = points[:, 1].max()

    # The center is the midpoint between the extremes on each axis. Using
    # half of the extent instead is only correct when the minimum corner
    # happens to sit at the origin.
    center = np.array(((x_min + x_max) / 2, (y_min + y_max) / 2))
    lengths = np.array((x_max - x_min, y_max - y_min))
    return center, lengths.max()
@numba.jit(nopython=True)
def barnes_hut(points, threshold):
    """Build a quad-tree over ``points`` and return its root node.

    The root is created from the result of ``bounding_square(points)`` and
    then populated by ``add_nodes``.

    Parameters
    ----------
    points : 2-D array
        One point per row.
    threshold : float
        Passed through unchanged to ``add_nodes``.

    Returns
    -------
    QuadTreeNode
        Root of the constructed tree.
    """
    root_center, root_length = bounding_square(points)
    root = QuadTreeNode(root_center, root_length)
    add_nodes(root, points, threshold)
    return root
@HerveGlz
Copy link

HerveGlz commented Feb 8, 2020

Hi Sir,
Thanks a lot for this great Python script! Good idea to use Numba with this space-partitioning script!
May I ask you to explain a bit more about how to use it? Unfortunately, I have not succeeded in making it work...
I did something a bit similar with the scipy package, thanks to its spatial.cKDTree function. I tried to add Numba on top of it to speed up my processing step, but without success, since Numba doesn't seem to accept the KDTree function from scipy.

So I will be so glad to receive some help from you on your solution !

Thanks in advance,
Warm regards,
Hervé

@jaantollander
Copy link
Author

This code snippet seems to only construct the quadtree. Initially, I created this code for fixed-radius near neighbors search, but I ended up using Cell lists instead; therefore, the code is somewhat unfinished. Also, I'm not certain you would get any speedup using Numba over scipy kd-tree.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment