Skip to content

Instantly share code, notes, and snippets.

@jo32
Created April 4, 2013 16:27
Show Gist options
Save jo32/5311867 to your computer and use it in GitHub Desktop.
class KdTree(object):
    """A k-d tree over points that expose the attributes named in ``domains``.

    Split domains are chosen cyclically by depth: the root compares on
    ``domains[0]``, its children on ``domains[1]``, and so on, wrapping
    around.  A point whose value on the split domain is ``<=`` the node's
    value goes to the left subtree, otherwise to the right subtree.
    """

    # The tree must know the domains (attribute names) so it can make sure
    # every point added to it provides all of them.
    def __init__(self, domains, root=None):
        """Create a tree.

        :param domains: sequence of attribute names used as split axes.
        :param root: optional pre-built root node; it must have ``point``,
            ``leftChild`` and ``rightChild`` attributes.
        """
        self.domains = domains
        self.root = root
        if root is not None:
            # Mirror the root point's coordinates as attributes of the tree.
            for domain in domains:
                setattr(self, domain, getattr(root.point, domain))

    def __len__(self):
        return self.size()

    def size(self):
        """Return the number of nodes in the tree (0 when empty)."""
        return self.countNoOfChildren(self.root)

    def countNoOfChildren(self, root):
        """Return the number of nodes in the subtree rooted at ``root``.

        ``root`` itself is counted; ``None`` counts as an empty subtree.
        An explicit stack is used instead of recursion so that degenerate
        (list-shaped) trees deeper than the recursion limit can be counted.
        """
        count = 0
        stack = [root]
        while stack:
            node = stack.pop()
            if node is None:
                continue
            count += 1
            stack.append(node.leftChild)
            stack.append(node.rightChild)
        return count

    def findCloestLeaf(self, point):
        """Return the node at which a descent guided by ``point`` stops.

        Returns ``None`` when the tree is empty.  (The misspelled name is
        kept for backward compatibility.)
        """
        if self.root is None:
            return None
        return self._findCloestLeaf(point, self.root)

    def _findCloestLeaf(self, point, root, depth=0):
        """Descend from ``root`` (which sits at ``depth``) following ``point``.

        At each node the split domain is ``domains[depth % len(domains)]``;
        the descent goes left when the point's value is ``<=`` the node's
        value and right otherwise.  Returns the first node that is a leaf or
        that lacks a child on the chosen side.  Iterative, so deep trees do
        not hit the recursion limit.
        """
        node = root
        while True:
            domain = self.domains[depth % len(self.domains)]
            # Read both values before the leaf check, as the original did, so
            # a point missing the attribute still raises AttributeError.
            nodeAttr = getattr(node.point, domain)
            pointAttr = getattr(point, domain)
            if node.leftChild is None and node.rightChild is None:
                return node
            if pointAttr <= nodeAttr:
                child = node.leftChild
            else:
                # Fixed: this branch previously descended into leftChild.
                child = node.rightChild
            if child is None:
                return node
            node = child
            depth += 1

    def addPoint(self, point):
        """Insert ``point`` into the tree.

        The first point becomes the root (split on ``domains[0]``), and its
        coordinates are mirrored as attributes of the tree.

        :raises ValueError: if ``point`` lacks any of the tree's domains.
        """
        missing = [domain for domain in self.domains if not hasattr(point, domain)]
        if missing:
            # Previously printed a message and did a bare ``raise`` with no
            # active exception, which only produced an accidental error.
            raise ValueError(
                "point doesn't have required domains: %s" % ", ".join(missing))
        if self.root is None:
            self.root = Node(self.domains[0], point)
            for domain in self.domains:
                setattr(self, domain, getattr(point, domain))
            return
        self._insertPoint(self.root, point, 0)

    def _insertPoint(self, root, point, depth):
        """Attach ``point`` as a new leaf below ``root`` (which sits at ``depth``).

        The split domain is chosen cyclically by depth; the new node's domain
        is the one for the next depth.  Iterative, so inserting sorted input
        (which builds a list-shaped tree) does not hit the recursion limit.
        """
        node = root
        dimension = len(self.domains)
        while True:
            domain = self.domains[depth % dimension]
            childDomain = self.domains[(depth + 1) % dimension]
            if getattr(point, domain) <= getattr(node.point, domain):
                if node.leftChild is None:
                    node.leftChild = Node(childDomain, point)
                    return
                node = node.leftChild
            else:
                if node.rightChild is None:
                    node.rightChild = Node(childDomain, point)
                    return
                node = node.rightChild
            depth += 1

    def traverse(self):
        """Yield every node in breadth-first (level) order; nothing if empty."""
        from collections import deque

        queue = deque()
        if self.root is not None:
            queue.append(self.root)
        while queue:
            node = queue.popleft()  # O(1), unlike list.pop(0)
            yield node
            if node.leftChild is not None:
                queue.append(node.leftChild)
            if node.rightChild is not None:
                queue.append(node.rightChild)

    def __str__(self):
        """Return "domain: point" entries in breadth-first order, comma-joined."""
        points = [node.domain + ": " + str(node.point) for node in self.traverse()]
        return ', '.join(points)
class Node(object):
    """A single k-d tree node: a stored point plus the domain it splits on."""

    def __init__(self, domain, point):
        """Create a childless node holding ``point`` that splits on ``domain``."""
        self.domain = domain
        self.point = point
        # Children are attached later by the tree during insertion.
        self.leftChild = self.rightChild = None

    def __str__(self):
        """Render as "<point>, domain: <domain>"."""
        pointText = str(self.point)
        return "%s, domain: %s" % (pointText, self.domain)

    def __repr__(self):
        # Same text as __str__, so nodes print readably inside containers.
        return self.__str__()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment