Skip to content

Instantly share code, notes, and snippets.

@suspectpart
Last active December 13, 2018 22:51
Show Gist options
  • Save suspectpart/98b0b5c4fcabe367d695d901c9d5cebd to your computer and use it in GitHub Desktop.
Save suspectpart/98b0b5c4fcabe367d695d901c9d5cebd to your computer and use it in GitHub Desktop.
Binary Search Tree
import unittest
from abc import ABC, abstractmethod
class DuplicateKeyException(Exception):
    """Raised when a key that already exists in the tree is inserted again.

    The offending key is kept on the ``key`` attribute so callers can inspect
    it without parsing the message text.
    """

    def __init__(self, key):
        msg = f"Key {key} is already taken"
        super().__init__(msg)
        self.key = key
class DuplicatePolicy(ABC):
    """Strategy deciding what happens when an inserted key already exists.

    Implementations receive the node already stored in the tree (``node``)
    and the freshly built node carrying the same key (``other``). The value
    they return is handed back by ``Node.insert`` and added to the tree's
    size by ``BinarySearchTree.insert``. Return ``True`` when the insert
    should count towards the size and ``False`` when it should not.
    """

    @abstractmethod
    def handle(self, node, other):
        """Resolve a key collision between ``node`` and ``other``."""
class Error(DuplicatePolicy):
    """Policy that refuses duplicate keys outright."""

    def handle(self, node, other):
        """Reject ``other`` by raising; the stored ``node`` is left untouched.

        Raises:
            DuplicateKeyException: always, carrying ``other``'s key.
        """
        duplicate_key = other.key
        raise DuplicateKeyException(duplicate_key)
class Overwrite(DuplicatePolicy):
    """Policy where the most recently inserted value for a key wins."""

    def handle(self, node, other):
        """Replace ``node``'s value with ``other``'s value.

        Returns:
            False, because no node is added and the size stays the same.
        """
        replacement = other.value
        node.value = replacement
        grew = False
        return grew
class Buckets(DuplicatePolicy):
    """Policy that keeps every value inserted under the same key.

    The first duplicate turns the stored value into a bucket (a list) holding
    the old and the new value. Later duplicates are appended to that bucket.
    """

    class _Bucket(list):
        """List subtype marking accumulators created by this policy.

        A dedicated type lets the policy tell its own bucket apart from a list
        the caller stored as a value. Such a user list must be wrapped, not
        appended to (and thereby mutated) in place.
        """

        __slots__ = ()

    def handle(self, node, other):
        """Add ``other.value`` to ``node``'s bucket, creating the bucket if needed.

        Returns:
            True, so every inserted value counts towards the tree's size.
        """
        if isinstance(node.value, self._Bucket):
            node.value.append(other.value)
        else:
            # Also covers a plain list stored by the caller: it becomes one
            # element of a fresh bucket instead of being extended in place.
            node.value = self._Bucket([node.value, other.value])
        return True
class Plot:
    """Render a tree as a Mermaid top-down flowchart definition."""

    def __init__(self, tree):
        self._tree = tree

    def mermaid(self):
        """Return Mermaid source with one ``parent-->child`` edge per line.

        Edges follow the order in which ``tree.traverse()`` yields nodes, with
        the left child listed before the right one. A tree without edges
        renders as just the ``graph TD;`` header.
        """
        lines = ["graph TD;"]
        for parent in self._tree.traverse():
            for child in (parent.left, parent.right):
                if child:
                    lines.append(f"{parent.key}-->{child.key}")
        return "\n".join(lines).strip()
class Node:
    """A single node of an unbalanced binary search tree.

    Nodes are ordered and compared by ``key`` only; ``value`` is the payload.
    Each node carries the duplicate-key policy that is consulted when a node
    with an equal key is inserted into its position.

    Insertion, lookup and traversal are iterative. A degenerate tree (e.g.
    one built from sorted keys) deeper than Python's recursion limit
    therefore still works.

    Because ``__eq__`` is defined without ``__hash__``, nodes are unhashable.
    """

    __slots__ = ["key", "value", "left", "right", "_policy"]

    def __init__(self, key, value, policy=None):
        self.key, self.value = key, value
        self.left = self.right = None
        # Without an explicit policy, duplicates overwrite the stored value.
        self._policy = policy or Overwrite()

    def insert(self, node):
        """Insert ``node`` into the subtree rooted at this node.

        Walks down from this node. If a node with an equal key is met, that
        existing node's policy resolves the collision and its result is
        returned as-is, no matter how deep the collision happened. Otherwise
        ``node`` is attached as a new leaf.

        Returns:
            True when ``node`` was attached as a new leaf, otherwise the
            return value of the existing node's policy ``handle``.

        Raises:
            Whatever the policy raises (e.g. DuplicateKeyException).
        """
        current = self
        while True:
            if node == current:
                return current._policy.handle(current, node)
            if node < current:
                if current.left is None:
                    current.left = node
                    return True
                current = current.left
            else:
                if current.right is None:
                    current.right = node
                    return True
                current = current.right

    def traverse(self):
        """Yield the nodes of this subtree in pre-order (node, left, right)."""
        pending = [self]
        while pending:
            current = pending.pop()
            yield current
            # Push the right child first so the left subtree is visited first.
            if current.right is not None:
                pending.append(current.right)
            if current.left is not None:
                pending.append(current.left)

    def find(self, key):
        """Return the value stored under ``key`` in this subtree, or None."""
        current = self
        while current is not None:
            if key > current.key:
                current = current.right
            elif key < current.key:
                current = current.left
            else:
                return current.value
        return None

    def __lt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.key < other.key

    def __gt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.key > other.key

    def __eq__(self, other):
        # Returning NotImplemented lets `node == None` evaluate to False
        # instead of raising AttributeError.
        if not isinstance(other, Node):
            return NotImplemented
        return self.key == other.key

    def __repr__(self):
        return f"<Node {self.key}:{self.value}, " \
               f"l={self.left.key if self.left else ''}, " \
               f"r={self.right.key if self.right else ''}>"
class BinarySearchTree:
    """Unbalanced binary search tree mapping keys to values.

    Collisions on an existing key are delegated to ``policy`` (a
    DuplicatePolicy instance). When it is None, each Node falls back to its
    own default policy.
    """

    __slots__ = ["_root", "_size", "_policy"]

    def __init__(self, policy=None):
        self._root = None
        self._size = 0
        self._policy = policy

    def find(self, key):
        """Return the value stored under ``key``, or None if it is absent.

        An empty tree has no keys, so it also returns None instead of failing.
        """
        if self._root is None:
            return None
        return self._root.find(key)

    def traverse(self):
        """Yield every node of the tree; yields nothing for an empty tree."""
        if self._root is not None:
            yield from self._root.traverse()

    def insert(self, key, value):
        """Store ``value`` under ``key``.

        The first insert creates the root. Later inserts go through
        ``Node.insert``, whose result (True for a new node, or the duplicate
        policy's return value) is added to the size. Any exception raised by
        the policy propagates to the caller.
        """
        node = Node(key, value, self._policy)
        if self._root is None:
            self._root = node
            self._size += 1
        else:
            self._size += self._root.insert(node)

    def __len__(self):
        """Return the size accumulated from all inserts."""
        return self._size
class BinarySearchTreeTests(unittest.TestCase):
    """Tests for BinarySearchTree, Node lookup/traversal and duplicate policies."""

    @staticmethod
    def _tree_with_root(root):
        """Return a fresh BinarySearchTree whose root is set to ``root``."""
        tree = BinarySearchTree()
        tree._root = root
        return tree

    @staticmethod
    def _hand_built_nodes():
        r"""Wire up the fixture shared by the lookup and traversal tests.

                5
              /   \
            3       6
              \       \
                4       7

        Returns a dict mapping each key to its Node; key 5 is the root.
        """
        nodes = {
            5: Node(5, "five"),
            3: Node(3, "three"),
            4: Node(4, "four"),
            6: Node(6, "six"),
            7: Node(7, "seven"),
        }
        nodes[5].left = nodes[3]
        nodes[3].right = nodes[4]
        nodes[5].right = nodes[6]
        nodes[6].right = nodes[7]
        return nodes

    def test_init(self):
        self.assertEqual(len(BinarySearchTree()), 0)

    def test_get_single(self):
        tree = BinarySearchTree()
        tree.insert(1, "hallo")
        self.assertEqual(len(tree), 1)
        self.assertEqual(tree.find(1), "hallo")

    def test_get_many(self):
        r"""Inserting 5, 3, 6, 7, 2, 8, 4 must produce this shape:

                5
              /   \
            3       6
          /   \       \
        2       4       7
                          \
                            8
        """
        tree = BinarySearchTree()
        for key in (5, 3, 6, 7, 2, 8, 4):
            tree.insert(key, key)
        root = tree._root
        self.assertEqual(len(tree), 7)
        self.assertEqual(root.left.key, 3)
        self.assertEqual(root.left.left.key, 2)
        self.assertEqual(root.left.right.key, 4)
        self.assertEqual(root.right.key, 6)
        self.assertEqual(root.right.right.key, 7)
        self.assertEqual(root.right.right.right.key, 8)
        self.assertEqual(root.right.right.right.value, 8)

    def test_lookup(self):
        tree = self._tree_with_root(self._hand_built_nodes()[5])
        expected = ((5, "five"), (3, "three"), (4, "four"),
                    (6, "six"), (7, "seven"))
        for key, value in expected:
            self.assertEqual(tree.find(key), value)

    def test_traverse(self):
        nodes = self._hand_built_nodes()
        visited = list(self._tree_with_root(nodes[5]).traverse())
        self.assertEqual(len(visited), 5)
        for key in (3, 4, 5, 6, 7):
            self.assertIn(nodes[key], visited)

    def test_overwrite_policy(self):
        tree = BinarySearchTree(policy=Overwrite())
        for value in ("hi", "bye"):
            tree.insert(5, value)
        self.assertEqual(len(tree), 1)
        self.assertEqual(tree.find(5), "bye")

    def test_error_policy(self):
        tree = BinarySearchTree(policy=Error())
        tree.insert(5, "hi")
        with self.assertRaises(DuplicateKeyException):
            tree.insert(5, "bye")

    def test_bucket_policy(self):
        tree = BinarySearchTree(policy=Buckets())
        tree.insert(5, "hi")
        self.assertEqual(tree.find(5), "hi")
        tree.insert(5, "bye")
        self.assertEqual(tree.find(5), ["hi", "bye"])

    def test_default_policy(self):
        """Without an explicit policy, a repeated key overwrites the value."""
        tree = BinarySearchTree()
        tree.insert(5, "hi")
        self.assertEqual(tree.find(5), "hi")
        tree.insert(5, "bye")
        self.assertEqual(tree.find(5), "bye")
if __name__ == "__main__":
    # Executing this file directly runs its unittest test cases.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment