Skip to content

Instantly share code, notes, and snippets.

@sang4lv
Last active January 23, 2017 11:19
Show Gist options
  • Save sang4lv/b1974d71fea6ffeb379f7f2eeeb306d2 to your computer and use it in GitHub Desktop.
Save sang4lv/b1974d71fea6ffeb379f7f2eeeb306d2 to your computer and use it in GitHub Desktop.
#Uses python3
import sys
import math
class Node:
    """A single point stored in a KDTree, with links to its two children."""

    __slots__ = ("left", "right", "coords")

    def __init__(self, coords):
        self.left = None    # points with coords[dim] <= self.coords[dim]
        self.right = None   # points with coords[dim] > self.coords[dim]
        self.coords = coords


class KDTree:
    """A k-dimensional tree supporting insertion and nearest-neighbour queries.

    The splitting axis cycles through ``0 .. k-1`` with depth. A point whose
    coordinate on the splitting axis is ``<=`` the node's goes to the left
    subtree; otherwise it goes to the right subtree.
    """

    def __init__(self, k):
        self.k = k
        self.root = None

    def insert(self, coords):
        """Insert the point *coords* into the tree (duplicates are kept)."""
        node = Node(coords)
        if self.root is None:
            self.root = node
            return
        dim = 0
        curr = self.root
        while True:
            if node.coords[dim] <= curr.coords[dim]:
                if curr.left is None:
                    curr.left = node
                    break
                curr = curr.left
            else:
                if curr.right is None:
                    curr.right = node
                    break
                curr = curr.right
            dim = (dim + 1) % self.k

    def show(self, curr=None):
        """Print the subtree rooted at *curr* (default: the whole tree) in pre-order.

        Both children of every node are printed; an empty tree prints nothing.
        """
        if curr is None:
            curr = self.root
        if curr is None:
            return
        print(curr.coords)
        if curr.left is not None:
            print("Left subtree:")
            self.show(curr.left)
        # Not `elif`: a node may have both children and both must be shown.
        if curr.right is not None:
            print("Right subtree:")
            self.show(curr.right)

    def get_sum_square(self, pair):
        """Return the squared difference ``(pair[1] - pair[0]) ** 2``."""
        return (pair[1] - pair[0]) ** 2

    def get_axes_square_sum(self, coord1, coord2):
        """Return the squared Euclidean distance between two points.

        Axes beyond the shorter of the two sequences are ignored (``zip``).
        """
        return sum(self.get_sum_square(axis) for axis in zip(coord1, coord2))

    def get_minimum_distance(self, coords, node, dim, best, visited=None):
        """Return the smaller of *best* and the squared distance from *coords*
        to the nearest point in the subtree rooted at *node*.

        :param coords: the query point.
        :param node: root of the subtree to search (``None`` = empty subtree).
        :param dim: the splitting axis used at *node*.
        :param best: best squared distance found so far; subtrees that cannot
            contain a strictly closer point are pruned.
        :param visited: optional list; every node whose distance is measured
            is appended to it.
        """
        # Explicit stack instead of recursion so degenerate (e.g. sorted-input)
        # trees deeper than the recursion limit can still be searched. Each
        # entry carries a lower bound on the squared distance from *coords* to
        # any point in that subtree.
        pending = [(node, dim, 0)]
        while pending:
            current, axis, bound = pending.pop()
            if current is None or bound >= best:
                continue  # empty, or cannot hold a strictly closer point
            if visited is not None:
                visited.append(current)
            best = min(best, self.get_axes_square_sum(current.coords, coords))
            diff = coords[axis] - current.coords[axis]
            # Same side rule as insert(): ties go left.
            if diff <= 0:
                near, far = current.left, current.right
            else:
                near, far = current.right, current.left
            next_axis = (axis + 1) % self.k
            # Every point across the splitting plane is at least diff**2 away.
            pending.append((far, next_axis, max(bound, diff * diff)))
            # Pushed last so the near side is searched first and tightens
            # `best` before the far side's bound is checked.
            pending.append((near, next_axis, bound))
        return best

    def findNN(self, coords):
        """Return the squared Euclidean distance from *coords* to the nearest
        stored point.

        :raises ValueError: if the tree is empty.
        """
        if self.root is None:
            raise ValueError("findNN() called on an empty KDTree")
        return self.get_minimum_distance(coords, self.root, 0, float("inf"))
def minimum_distance(x, y, query=(6, 7)):
    """Build a 2-d KDTree from the points ``zip(x, y)`` and report how far
    *query* is from its nearest point.

    The Euclidean distance is printed and also returned.

    :param x: x-coordinates of the points.
    :param y: y-coordinates of the points, paired with *x* via ``zip``, so
        extra elements of the longer sequence are ignored.
    :param query: the point to search for; defaults to ``(6, 7)``, the point
        this function previously hard-coded.
    :returns: the Euclidean distance from *query* to its nearest point.
    """
    tree = KDTree(2)
    for pair in zip(x, y):
        tree.insert(list(pair))
    # findNN returns a squared distance, hence the square root.
    result = math.sqrt(tree.findNN(list(query)))
    print(result)
    return result
x = [7, 1, 4, 7, 8, 5]
y = [7, 100, 8, 7, 6, 101]

# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    minimum_distance(x, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment