Last active
December 13, 2018 22:51
-
-
Save suspectpart/98b0b5c4fcabe367d695d901c9d5cebd to your computer and use it in GitHub Desktop.
Binary Search Tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from abc import ABC, abstractmethod | |
class DuplicateKeyException(Exception):
    """Raised by the ``Error`` duplicate policy when an inserted key already exists.

    Attributes:
        key: the key whose insertion was rejected.
    """

    def __init__(self, key):
        super().__init__(f"Key {key} is already taken")
        # Exposed so handlers can inspect the offending key without parsing
        # the message text.
        self.key = key
class DuplicatePolicy(ABC):
    """Strategy deciding what happens when an inserted key already exists.

    ``handle`` receives the node already stored in the tree (``node``) and the
    newly built node carrying the same key (``other``). Its return value is
    passed back through ``Node.insert`` and added to the tree's size, so it
    answers "did the tree grow?" (``Overwrite`` returns ``False``,
    ``Buckets`` returns ``True``).
    """

    @abstractmethod
    def handle(self, node, other):
        """Resolve a key collision between ``node`` and ``other``."""
        ...
class Error(DuplicatePolicy):
    """Duplicate policy that rejects every repeated key."""

    def handle(self, node, other):
        """Refuse to insert ``other``.

        Raises:
            DuplicateKeyException: always, carrying ``other.key``.
        """
        rejected_key = other.key
        raise DuplicateKeyException(rejected_key)
class Overwrite(DuplicatePolicy):
    """Duplicate policy where the newest value replaces the stored one."""

    def handle(self, node, other):
        """Copy ``other.value`` onto ``node``; the tree does not grow.

        Returns:
            ``False``, so the tree's size is left unchanged.
        """
        node.value = other.value
        tree_grew = False
        return tree_grew
class _Bucket(list):
    """List subclass marking a value accumulated by ``Buckets``.

    A dedicated type lets ``Buckets`` tell its own accumulator apart from a
    user value that merely happens to be a ``list``.
    """

    __slots__ = ()


class Buckets(DuplicatePolicy):
    """Duplicate policy that keeps every value inserted under the same key.

    After the first collision the node's value becomes a list of all values in
    insertion order; each further duplicate is appended to that list.
    """

    def handle(self, node, other):
        """Store ``other.value`` alongside the value(s) already on ``node``.

        Returns:
            ``True``, so the tree's size grows with every stored value.
        """
        if isinstance(node.value, _Bucket):
            node.value.append(other.value)
        else:
            # Always wrap, even when the stored value is itself a list: the
            # user's list must not be mutated or merged with later values.
            node.value = _Bucket([node.value, other.value])
        return True
class Plot:
    """Render a tree as a Mermaid flowchart definition.

    The wrapped object only needs a ``traverse()`` method yielding nodes that
    expose ``key``, ``left`` and ``right`` attributes.
    """

    def __init__(self, tree):
        self._tree = tree

    def mermaid(self):
        """Return a top-down Mermaid graph with one ``parent-->child`` edge per line.

        Edges follow the order in which ``traverse()`` yields nodes, with the
        left child listed before the right child. If no node has children,
        only the ``graph TD;`` header is returned.
        """
        lines = ["graph TD;"]
        for node in self._tree.traverse():
            for child in (node.left, node.right):
                if child:
                    lines.append(f"{node.key}-->{child.key}")
        # join avoids quadratic string concatenation; strip() keeps the output
        # identical to trimming a newline-terminated accumulation.
        return "\n".join(lines).strip()
class Node:
    """A binary-search-tree node holding one key/value pair.

    Nodes whose key compares less than this node's key go into the ``left``
    subtree; all others go into the ``right`` subtree. Inserting a node with
    an equal key is delegated to ``policy`` (an object with a
    ``handle(node, other)`` method), which defaults to ``Overwrite()``.

    ``insert``, ``find`` and ``traverse`` walk the tree iteratively, so
    degenerate (list-shaped) trees built from sorted input do not hit
    Python's recursion limit.
    """

    __slots__ = ["key", "value", "left", "right", "_policy"]

    # Defining __eq__ already sets __hash__ to None; stated explicitly so it is
    # obvious that nodes (mutable, compared by key only) are unhashable.
    __hash__ = None

    def __init__(self, key, value, policy=None):
        self.key, self.value = key, value
        self.left = self.right = None
        self._policy = policy or Overwrite()

    def insert(self, node):
        """Insert ``node`` into the subtree rooted at this node.

        Returns:
            ``True`` when ``node`` was linked in as a new leaf. When a node
            with an equal key already exists anywhere in the subtree, the
            result of that existing node's policy is returned unchanged, even
            when the duplicate sits below this node.

        Raises:
            Whatever the duplicate policy raises (e.g. ``DuplicateKeyException``
            from ``Error``).
        """
        current = self
        while True:
            if node == current:
                return current._policy.handle(current, node)
            if node < current:
                if current.left is None:
                    current.left = node
                    return True
                current = current.left
            else:
                if current.right is None:
                    current.right = node
                    return True
                current = current.right

    def traverse(self):
        """Yield every node of this subtree in pre-order (node, left, right)."""
        stack = [self]
        while stack:
            current = stack.pop()
            yield current
            # Push right first so the left subtree is visited first.
            if current.right:
                stack.append(current.right)
            if current.left:
                stack.append(current.left)

    def find(self, key):
        """Return the value stored under ``key`` in this subtree.

        Returns ``None`` when the key is absent, which is indistinguishable
        from a stored value of ``None``.
        """
        current = self
        while current is not None:
            if key > current.key:
                current = current.right
            elif key < current.key:
                current = current.left
            else:
                return current.value
        return None

    def _insert_left(self, node):
        """Attach ``node`` in the left subtree; return the insertion result."""
        if self.left is None:
            self.left = node
            return True
        return self.left.insert(node)

    def _insert_right(self, node):
        """Attach ``node`` in the right subtree; return the insertion result."""
        if self.right is None:
            self.right = node
            return True
        return self.right.insert(node)

    # Comparisons look only at keys. Operands without a ``key`` attribute get
    # NotImplemented so Python can fall back (identity for ==, TypeError for
    # ordering) instead of crashing with AttributeError.
    def __lt__(self, other):
        if not hasattr(other, "key"):
            return NotImplemented
        return self.key < other.key

    def __gt__(self, other):
        if not hasattr(other, "key"):
            return NotImplemented
        return self.key > other.key

    def __eq__(self, other):
        if not hasattr(other, "key"):
            return NotImplemented
        return self.key == other.key

    def __repr__(self):
        return f"<Node {self.key}:{self.value}, " \
               f"l={self.left.key if self.left else ''}, " \
               f"r={self.right.key if self.right else ''}>"
class BinarySearchTree:
    """Key/value store backed by an unbalanced binary search tree.

    ``policy`` decides what happens when a key is inserted twice; it is
    handed to every ``Node`` created by ``insert`` (``None`` lets ``Node``
    pick its default).
    """

    __slots__ = ["_root", "_size", "_policy"]

    def __init__(self, policy=None):
        self._root = None
        self._size = 0
        self._policy = policy

    def find(self, key):
        """Return the value stored under ``key``, or ``None`` if absent or the tree is empty."""
        if self._root is None:
            return None
        return self._root.find(key)

    def traverse(self):
        """Yield all nodes in the root's traversal order; yields nothing for an empty tree."""
        if self._root is not None:
            yield from self._root.traverse()

    def insert(self, key, value):
        """Insert ``value`` under ``key``, resolving duplicates via the policy.

        The size grows by one when the tree was empty, or when the root's
        ``insert`` reports a truthy result (a falsy or ``None`` result, e.g.
        from an overwriting policy, leaves the size unchanged).
        """
        node = Node(key, value, self._policy)
        if self._root is None:
            self._root = node
            self._size += 1
        elif self._root.insert(node):
            self._size += 1

    def __len__(self):
        return self._size
class BinarySearchTreeTests(unittest.TestCase):
    """Unit tests for BinarySearchTree, Node and the duplicate-key policies."""

    def test_init(self):
        # sut
        tree = BinarySearchTree()
        # assert
        self.assertEqual(len(tree), 0)

    def test_get_single(self):
        # arrange
        key, value = 1, "hallo"
        # sut
        tree = BinarySearchTree()
        # act
        tree.insert(key, value)
        # assert
        self.assertEqual(len(tree), 1)
        self.assertEqual(tree.find(key), value)

    def test_get_many(self):
        r"""Inserting 5, 3, 6, 7, 2, 8, 4 must produce this shape:

                5
               / \
              3   6
             / \   \
            2   4   7
                     \
                      8
        """
        # arrange
        five = 5
        three = 3
        six = 6
        seven = 7
        two = 2
        eight = 8
        four = 4
        # sut
        tree = BinarySearchTree()
        # act
        tree.insert(five, five)
        tree.insert(three, three)
        tree.insert(six, six)
        tree.insert(seven, seven)
        tree.insert(two, two)
        tree.insert(eight, eight)
        tree.insert(four, four)
        # assert
        self.assertEqual(len(tree), 7)
        self.assertEqual(tree._root.left.key, three)
        self.assertEqual(tree._root.left.left.key, two)
        self.assertEqual(tree._root.left.right.key, four)
        self.assertEqual(tree._root.right.key, six)
        self.assertEqual(tree._root.right.right.key, seven)
        self.assertEqual(tree._root.right.right.right.key, eight)
        self.assertEqual(tree._root.right.right.right.value, eight)

    def test_lookup(self):
        # arrange
        root = Node(5, "five")
        root.left = Node(3, "three")
        root.left.right = Node(4, "four")
        root.right = Node(6, "six")
        root.right.right = Node(7, "seven")
        # system under test
        tree = BinarySearchTree()
        tree._root = root
        # act + assert
        self.assertEqual(tree.find(5), "five")
        self.assertEqual(tree.find(3), "three")
        self.assertEqual(tree.find(4), "four")
        self.assertEqual(tree.find(6), "six")
        self.assertEqual(tree.find(7), "seven")

    def test_traverse(self):
        # arrange
        five = Node(5, "five")
        three = Node(3, "three")
        four = Node(4, "four")
        six = Node(6, "six")
        seven = Node(7, "seven")
        root = five
        root.left = three
        root.left.right = four
        root.right = six
        root.right.right = seven
        # system under test
        tree = BinarySearchTree()
        tree._root = root
        # act
        nodes = tree.traverse()
        # assert
        nodes_list = list(nodes)
        self.assertEqual(len(nodes_list), 5)
        self.assertIn(three, nodes_list)
        self.assertIn(four, nodes_list)
        self.assertIn(five, nodes_list)
        self.assertIn(six, nodes_list)
        self.assertIn(seven, nodes_list)

    def test_overwrite_policy(self):
        # system under test
        tree = BinarySearchTree(policy=Overwrite())
        # act
        tree.insert(5, "hi")
        tree.insert(5, "bye")
        # assert
        self.assertEqual(len(tree), 1)
        self.assertEqual(tree.find(5), "bye")

    def test_error_policy(self):
        # system under test
        tree = BinarySearchTree(policy=Error())
        # act
        tree.insert(5, "hi")
        # assert
        with self.assertRaises(DuplicateKeyException):
            tree.insert(5, "bye")

    def test_bucket_policy(self):
        # system under test
        tree = BinarySearchTree(policy=Buckets())
        # act / assert
        tree.insert(5, "hi")
        self.assertEqual(tree.find(5), "hi")
        tree.insert(5, "bye")
        self.assertEqual(tree.find(5), ["hi", "bye"])

    def test_default_policy(self):
        """Without an explicit policy, a duplicate key overwrites the value."""
        # system under test
        tree = BinarySearchTree()
        # act / assert
        tree.insert(5, "hi")
        self.assertEqual(tree.find(5), "hi")
        tree.insert(5, "bye")
        self.assertEqual(tree.find(5), "bye")
if __name__ == "__main__":
    # Run the test suite above when this file is executed as a script.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment