Skip to content

Instantly share code, notes, and snippets.

@ofx
Created March 17, 2015 12:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ofx/f3b28f5765b5a7d8edc3 to your computer and use it in GitHub Desktop.
Save ofx/f3b28f5765b5a7d8edc3 to your computer and use it in GitHub Desktop.
Build a kd-tree for 2d points, with two consecutive x-levels and one y-level.
#!/usr/bin/python
import sys
import numpy
import Queue
import numpy as np
import random
import matplotlib.pyplot as plt
# Input point set: a fixed, hard-coded list of [x, y] integer pairs.
S = [
    [11, 0],
    [15, 4],
    [10, 13],
    [5, 12],
    [4, 19],
    [0, 17],
    [8, 30],
    [10, 23],
    [9, 17],
    [6, 27],
    [8, 23],
    [11, 8],
]
# Unzip into parallel tuples of x and y coordinates.
X, Y = zip(*S)
# Render all points as red circles.
plt.plot(X, Y, 'ro')
# Largest coordinates of the data set; buildKdTree draws its split lines
# from 0 up to these values.
maxX = max(X)
maxY = max(Y)
# Echo the largest x coordinate (appears to be leftover debug output).
# Parenthesized form prints the same on Python 2 and is valid on Python 3.
print(maxX)
class Node(object):
    """Internal kd-tree node: a split value with two child subtrees.

    Attributes:
        l: the split coordinate stored for this node.
        left: first child subtree (a Node, a Leaf, or possibly None).
        right: second child subtree (a Node, a Leaf, or possibly None).
        level: depth of this node in the tree.
    """

    # Class-level defaults kept for backward compatibility; __init__ always
    # overwrites them with per-instance values.
    l = 0
    left = None
    right = None
    level = 0

    def __init__(self, l, left, right, level):
        self.l = l
        self.left = left
        self.right = right
        self.level = level

    def __repr__(self):
        # Constructor-style repr; children are rendered recursively, so the
        # repr of the root shows the whole tree.
        return 'Node(l=%r, left=%r, right=%r, level=%r)' % (
            self.l, self.left, self.right, self.level)
class Leaf(object):
    """kd-tree leaf holding a single point.

    Attributes:
        v: the stored point; displayTree reads v[0] and v[1] as its x and y.
        level: depth of this leaf in the tree.
    """

    # Class-level defaults kept for backward compatibility; __init__ always
    # overwrites them with per-instance values.
    v = 0
    level = 0

    def __init__(self, v, level):
        self.v = v
        self.level = level

    def __repr__(self):
        return 'Leaf(v=%r, level=%r)' % (self.v, self.level)
def median(lst):
    """Return the median of the numbers in *lst*, as computed by numpy.

    With an even number of values numpy averages the two middle ones, so
    the result is not necessarily an element of *lst*.
    """
    values = numpy.asarray(lst)
    return numpy.median(values)
def buildKdTree(P, depth):
    """Recursively build a 2-d kd-tree over *P*, plotting every split line.

    The split axis depends on *depth*: when ``depth % 3 == 0`` the points
    are split on y, otherwise on x. Starting from depth 1 (as the driver
    does) this yields two consecutive x-levels followed by one y-level.

    The split value is the median of the coordinates on the chosen axis.
    Points whose coordinate is <= the split go to the left subtree and the
    rest go right. If that leaves one side empty (e.g. many points share the
    maximum coordinate), the points are instead sorted on the axis and
    halved by position. In that case points equal to the split value may
    land on both sides, but the recursion is guaranteed to terminate.

    Each split also draws its line on the current matplotlib figure,
    spanning from 0 to maxX (y-splits) or 0 to maxY (x-splits). The line
    covers the whole plot, not just the cell being split.

    Args:
        P: non-empty list of [x, y] points.
        depth: level of the node being built; stored on the returned node.

    Returns:
        A Leaf if P holds exactly one point, otherwise a Node whose ``l`` is
        the split value.

    Raises:
        ValueError: if P is empty.
    """
    if not P:
        raise ValueError("cannot build a kd-tree from an empty point set")
    if len(P) == 1:
        return Leaf(P[0], depth)

    # Coordinate index of the split axis: 1 = y, 0 = x.
    axis = 1 if depth % 3 == 0 else 0

    split = median([p[axis] for p in P])
    lower = [p for p in P if p[axis] <= split]
    upper = [p for p in P if p[axis] > split]

    if not lower or not upper:
        # Degenerate split: splitting again on the same set would recurse
        # forever. Halve by position instead; len(P) >= 2 here, so both
        # halves are non-empty and strictly smaller than P.
        ordered = sorted(P, key=lambda p: p[axis])
        half = len(ordered) // 2
        lower, upper = ordered[:half], ordered[half:]

    # Draw the split before recursing so plot calls (and therefore the
    # colour cycle) follow a pre-order traversal of the tree.
    if axis == 1:
        plt.plot([0, maxX], [split, split])
    else:
        plt.plot([split, split], [0, maxY])

    left = buildKdTree(lower, depth + 1)
    right = buildKdTree(upper, depth + 1)
    return Node(split, left, right, depth)
def displayTree(tree):
    """Write *tree* to stdout breadth-first, one tree level per output line.

    Internal nodes are written as ``N(<split value>)`` and leaves as
    ``L(<x>,<y>)``. Entries on the same line are not separated, and no
    newline is written after the last line.

    Args:
        tree: root Node or Leaf. If None, nothing is written.
    """
    from collections import deque

    if tree is None:
        return

    # A plain deque suffices for a single-threaded BFS. Queue.Queue adds
    # locking and is not available under that name on Python 3.
    queue = deque([tree])
    # Start from the root's own level so line breaks are correct no matter
    # which depth the tree was built from.
    lastLevel = tree.level
    while queue:
        current = queue.popleft()
        # BFS visits levels in order, so the first node deeper than the
        # current line starts a new line.
        if current.level > lastLevel:
            sys.stdout.write("\n")
            lastLevel = current.level
        if isinstance(current, Node):
            sys.stdout.write('N(%f)' % current.l)
            for child in (current.left, current.right):
                if child is not None:
                    queue.append(child)
        elif isinstance(current, Leaf):
            sys.stdout.write('L(%i,%i)' % (current.v[0], current.v[1]))
# Build the kd-tree over S. Starting at depth 1 makes buildKdTree split on x
# at depths 1 and 2 and on y at depth 3 (depth % 3 == 0), repeating.
tree = buildKdTree(S, depth=1)
# Print the tree level by level, then show the figure with the points and
# the split lines drawn during construction.
displayTree(tree)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment