Created
April 4, 2013 16:27
-
-
Save jo32/5311867 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class KdTree(object):
    """A k-d tree over arbitrary point objects.

    The tree is configured with ``domains``, a sequence of attribute names.
    Every point added through ``addPoint`` must expose each of those
    attributes. A node at depth ``d`` (the root is depth 0) splits on
    ``domains[d % len(domains)]``: a point whose value on that attribute is
    ``<=`` the node's value belongs to the left subtree, otherwise to the
    right subtree.

    All walks are iterative, so a degenerate (list-shaped) tree, e.g. one
    built from sorted input, does not hit Python's recursion limit.
    """

    def __init__(self, domains, root=None):
        """Create a tree over ``domains``, optionally seeded with ``root``.

        :param domains: sequence of attribute names, cycled by depth.
        :param root: an existing node (with ``point``, ``leftChild`` and
            ``rightChild`` attributes) to use as the root, or None for an
            empty tree.
        """
        self.domains = domains
        self.root = root
        if root is not None:
            # Same as addPoint does for the first point: the root point's
            # domain values are also exposed as attributes of the tree.
            # NOTE(review): the purpose of this copy is not evident here; a
            # domain named e.g. "root" or "size" would overwrite/shadow a
            # tree attribute or method.
            for domain in domains:
                setattr(self, domain, getattr(root.point, domain))

    def __len__(self):
        """Return the number of nodes in the tree."""
        return self.size()

    def size(self):
        """Return the number of nodes in the tree (0 when empty)."""
        return self.countNoOfChildren(self.root)

    def countNoOfChildren(self, root):
        """Return the number of nodes in the subtree rooted at ``root``.

        Despite the name, ``root`` itself is included in the count; a
        ``None`` root yields 0.
        """
        count = 0
        pending = [root]
        while pending:
            node = pending.pop()
            if node is None:
                continue
            count += 1
            pending.append(node.leftChild)
            pending.append(node.rightChild)
        return count

    def findCloestLeaf(self, point):
        """Return the node under which ``point`` would be inserted.

        Returns None when the tree is empty.
        """
        return self._findCloestLeaf(point, self.root)

    def _findCloestLeaf(self, point, root, depth=0):
        """Descend from ``root`` (located at ``depth``) following ``point``.

        At each node the walk uses the same rule as ``_insertPoint``. It
        goes left when the point's value on the split domain is ``<=`` the
        node's value, and right otherwise. It stops and returns the current
        node when that node is a leaf or the chosen child is missing.
        Returns None if ``root`` is None.
        """
        if root is None:
            return None
        numDomains = len(self.domains)
        node = root
        while True:
            if node.leftChild is None and node.rightChild is None:
                return node
            domain = self.domains[depth % numDomains]
            if getattr(point, domain) <= getattr(node.point, domain):
                child = node.leftChild
            else:
                # Strictly greater values live in the right subtree.
                child = node.rightChild
            if child is None:
                return node
            node = child
            depth += 1

    def addPoint(self, point):
        """Insert ``point`` into the tree.

        The first point becomes the root, split on ``domains[0]``. Its
        domain values are also copied onto the tree as attributes.

        :raises ValueError: if ``point`` lacks any attribute in ``domains``.
        """
        for domain in self.domains:
            if not hasattr(point, domain):
                raise ValueError(
                    "point doesn't have required domains (missing %r)"
                    % (domain,))
        if self.root is None:
            self.root = Node(self.domains[0], point)
            for domain in self.domains:
                setattr(self, domain, getattr(point, domain))
            return
        self._insertPoint(self.root, point, 0)

    def _insertPoint(self, root, point, depth):
        """Attach ``point`` as a new leaf below ``root`` (located at ``depth``).

        The split domain is chosen cyclically from ``domains`` by depth. The
        new node is tagged with the split domain for its own depth.
        """
        numDomains = len(self.domains)
        node = root
        while True:
            domain = self.domains[depth % numDomains]
            goLeft = getattr(point, domain) <= getattr(node.point, domain)
            child = node.leftChild if goLeft else node.rightChild
            if child is None:
                newNode = Node(self.domains[(depth + 1) % numDomains], point)
                if goLeft:
                    node.leftChild = newNode
                else:
                    node.rightChild = newNode
                return
            node = child
            depth += 1

    def traverse(self):
        """Yield every node in breadth-first (level) order; nothing if empty."""
        from collections import deque  # O(1) removal from the front

        if self.root is None:
            return
        queue = deque([self.root])
        while queue:
            node = queue.popleft()
            yield node
            if node.leftChild is not None:
                queue.append(node.leftChild)
            if node.rightChild is not None:
                queue.append(node.rightChild)

    def __str__(self):
        """Return ``"<domain>: <point>"`` per node in BFS order, comma-joined."""
        return ', '.join(node.domain + ": " + str(node.point)
                         for node in self.traverse())
class Node(object):
    """One k-d tree node: a stored point plus the domain it is tagged with.

    ``domain`` is whatever attribute name the creator assigns. In this file,
    KdTree passes the split domain for the node's depth. It is only used
    for display here. Both child links start out empty (None).
    """

    def __init__(self, domain, point):
        self.domain = domain
        self.point = point
        # Children are linked in later by whoever builds the tree.
        self.leftChild = self.rightChild = None

    def __str__(self):
        pointText = str(self.point)
        domainText = str(self.domain)
        return "{0}, domain: {1}".format(pointText, domainText)

    def __repr__(self):
        # Same text as str(): "<point>, domain: <domain>".
        return self.__str__()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment