Last active
March 7, 2019 10:13
-
-
Save travishen/c4cc5797f8905b2a5b90f2545c374a26 to your computer and use it in GitHub Desktop.
Binary Search Tree (BST) and self-balancing BST (AVL Tree) implementations: insert, search, remove
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Binary Search Tree" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Node(object):\n", | |
" def __init__(self, data):\n", | |
" self._left, self._right = None, None\n", | |
" self.data = int(data)\n", | |
" \n", | |
" def __repr__(self):\n", | |
" return 'Node({})'.format(self.data)\n", | |
" \n", | |
" @property\n", | |
" def left(self):\n", | |
" return self._left\n", | |
" \n", | |
" @left.setter\n", | |
" def left(self, node):\n", | |
" self._left = node\n", | |
" \n", | |
" @property\n", | |
" def right(self):\n", | |
" return self._right\n", | |
" \n", | |
" @right.setter\n", | |
" def right(self, node):\n", | |
" self._right = node\n", | |
" \n", | |
"class BinarySearchTree(object): \n", | |
" def __init__(self, root=None):\n", | |
" self.root = root\n", | |
" self.search_mode = 'in_order'\n", | |
" \n", | |
" \n", | |
" # O(logN) time complexity if balanced, it could reduce to O(N)\n", | |
" def insert(self, data, **kwargs): \n", | |
" \"\"\"Insert from root\"\"\"\n", | |
" BinarySearchTree.insert_node(self.root, data, **kwargs)\n", | |
" \n", | |
" # O(logN) time complexity if balanced, it could reduce to O(N)\n", | |
" def remove(self, data): \n", | |
" \"\"\"Insert from root\"\"\"\n", | |
" BinarySearchTree.remove_node(self.root, data)\n", | |
" \n", | |
" @staticmethod\n", | |
" def insert_node(node, data, **kwargs):\n", | |
" node_consturctor = kwargs.get('node_constructor', None) or Node\n", | |
" if node:\n", | |
" if data < node.data:\n", | |
" if node.left is None:\n", | |
" node.left = node_consturctor(data)\n", | |
" else:\n", | |
" BinarySearchTree.insert_node(node.left, data, **kwargs)\n", | |
" elif data > node.data:\n", | |
" if node.right is None:\n", | |
" node.right = node_consturctor(data)\n", | |
" else:\n", | |
" BinarySearchTree.insert_node(node.right, data, **kwargs)\n", | |
" else:\n", | |
" node.data = data\n", | |
" return node\n", | |
" \n", | |
" @staticmethod\n", | |
" def remove_node(node, data): \n", | |
"\n", | |
" if not node:\n", | |
" return None\n", | |
" \n", | |
" if data < node.data:\n", | |
" node.left = BinarySearchTree.remove_node(node.left, data)\n", | |
" elif data > node.data:\n", | |
" node.right = BinarySearchTree.remove_node(node.right, data)\n", | |
" else:\n", | |
" if not (node.left and node.right): # leaf\n", | |
" del node\n", | |
" return None\n", | |
" if not node.left:\n", | |
" tmp = node.right\n", | |
" del node\n", | |
" return tmp\n", | |
" if not node.right:\n", | |
" tmp = node.left\n", | |
" del node\n", | |
" return tmp\n", | |
" predeccessor = BinarySearchTree.get_max_node(node.left)\n", | |
" node.data = predeccessor.data\n", | |
" node.left = BinarySearchTree.remove_node(node.left, predeccessor.data)\n", | |
" return node\n", | |
" \n", | |
" def get_min(self):\n", | |
" return self.get_min_node(self.root)\n", | |
" \n", | |
" @staticmethod\n", | |
" def get_min_node(node):\n", | |
" if node.left:\n", | |
" return BinarySearchTree.get_max_node(node.left)\n", | |
" return node\n", | |
" \n", | |
" def get_max(self):\n", | |
" return self.get_max_node(self.root)\n", | |
" \n", | |
" @staticmethod\n", | |
" def get_max_node(node):\n", | |
" if node.right:\n", | |
" return BinarySearchTree.get_max_node(node.right)\n", | |
" return node\n", | |
" \n", | |
" def search_decorator(func):\n", | |
" def interface(*args, **kwargs):\n", | |
" res = func(*args, **kwargs)\n", | |
" if isinstance(res, Node):\n", | |
" return res\n", | |
" elif 'data' in kwargs:\n", | |
" for node in res:\n", | |
" if node.data == kwargs['data']:\n", | |
" return node \n", | |
" return res\n", | |
" return interface\n", | |
" \n", | |
" @staticmethod\n", | |
" @search_decorator\n", | |
" def in_order(root, **kwargs):\n", | |
" \"\"\"left -> root -> right\"\"\"\n", | |
" f = BinarySearchTree.in_order\n", | |
" res = []\n", | |
" if root:\n", | |
" left = f(root.left, **kwargs)\n", | |
" if isinstance(left, Node):\n", | |
" return left\n", | |
" right = f(root.right, **kwargs)\n", | |
" if isinstance(right, Node):\n", | |
" return right\n", | |
" res = left + [root] + right\n", | |
" return res\n", | |
"\n", | |
" @staticmethod\n", | |
" @search_decorator\n", | |
" def pre_order(root, **kwargs):\n", | |
" \"\"\"root -> left -> right\"\"\"\n", | |
" f = BinarySearchTree.pre_order\n", | |
" res = []\n", | |
" if root:\n", | |
" left = f(root.left, **kwargs)\n", | |
" if isinstance(left, Node):\n", | |
" return left\n", | |
" right = f(root.right, **kwargs)\n", | |
" if isinstance(right, Node):\n", | |
" return right\n", | |
" res = [root] + left + right \n", | |
" return res\n", | |
"\n", | |
" @staticmethod\n", | |
" @search_decorator\n", | |
" def post_order(root, **kwargs):\n", | |
" \"\"\"root -> right -> root\"\"\"\n", | |
" f = BinarySearchTree.post_order\n", | |
" res = []\n", | |
" if root:\n", | |
" left = f(root.left, **kwargs)\n", | |
" if isinstance(left, Node):\n", | |
" return left\n", | |
" right = f(root.right, **kwargs)\n", | |
" if isinstance(right, Node):\n", | |
" return right\n", | |
" res = left + right + [root]\n", | |
" return res\n", | |
" \n", | |
" def traversal(self, \n", | |
" order:\"in_order|post_order|post_order\"=None,\n", | |
" data=None):\n", | |
" order = order or self.search_mode\n", | |
" if order == 'in_order':\n", | |
" return BinarySearchTree.in_order(self.root, data=data)\n", | |
" elif order == 'pre_order':\n", | |
" return BinarySearchTree.pre_order(self.root, data=data)\n", | |
" elif order == 'post_order':\n", | |
" return BinarySearchTree.post_order(self.root, data=data)\n", | |
" else:\n", | |
" raise NotImplementedError()\n", | |
" \n", | |
" def search(self, data, *args, **kwargs):\n", | |
" return self.traversal(*args, data=data, **kwargs)\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Node(20)" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt = BinarySearchTree(Node(10))\n", | |
"bt.insert(20)\n", | |
"bt.insert(15)\n", | |
"bt.insert(2)\n", | |
"bt.insert(66)\n", | |
"bt.insert(24)\n", | |
"bt.insert(17)\n", | |
"bt.insert(21)\n", | |
"bt.root.right" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[Node(2), Node(10), Node(15), Node(17), Node(20), Node(21), Node(24), Node(66)]" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.traversal('in_order')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[Node(2), Node(17), Node(15), Node(21), Node(24), Node(66), Node(20), Node(10)]" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.traversal('post_order')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[Node(10), Node(2), Node(20), Node(15), Node(17), Node(66), Node(24), Node(21)]" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.traversal('pre_order')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Node(17)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.search(data=17)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Node(66)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.get_max()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Node(2)" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.get_min()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[Node(2), Node(10), Node(20), Node(21), Node(24), Node(66)]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.remove(15)\n", | |
"bt.traversal()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# AVL Tree" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class HNode(Node): \n", | |
" def __init__(self, *args, **kwargs):\n", | |
" super(HNode, self).__init__(*args, **kwargs)\n", | |
" self._height = 0\n", | |
" \n", | |
" def __repr__(self):\n", | |
" return 'HNode({})'.format(self.data)\n", | |
" \n", | |
" @property\n", | |
" def height(self):\n", | |
" return self._height\n", | |
" \n", | |
" def set_height(self): \n", | |
" if self.left is None and self.right is None:\n", | |
" self._height = 0\n", | |
" else:\n", | |
" self._height = max(self.left_height, self.right_height) + 1\n", | |
" return self._height\n", | |
"\n", | |
"\n", | |
" @Node.left.setter\n", | |
" def left(self, node):\n", | |
" self._left = node\n", | |
" self.set_height()\n", | |
" \n", | |
" @Node.right.setter\n", | |
" def right(self, node):\n", | |
" self._right = node\n", | |
" self.set_height()\n", | |
" \n", | |
" @property\n", | |
" def sub_diff(self):\n", | |
" return self.left_height - self.right_height \n", | |
" \n", | |
" @property\n", | |
" def left_height(self):\n", | |
" if self.left:\n", | |
" return self.left.height\n", | |
" return -1\n", | |
" \n", | |
" @property\n", | |
" def right_height(self):\n", | |
" if self.right:\n", | |
" return self.right.height\n", | |
" return -1\n", | |
" \n", | |
" @property\n", | |
" def is_balance(self):\n", | |
" return abs(self.sub_diff) <= 1 \n", | |
" \n", | |
" def balance(self, data):\n", | |
" \n", | |
" if self.sub_diff > 1:\n", | |
" if data < self.left.data: # left left heavy\n", | |
" return self.rotate('right')\n", | |
" if data > self.left.data: # left right heavy\n", | |
" self.left = self.left.rotate('left')\n", | |
" return self.rotate('right')\n", | |
" \n", | |
" if self.sub_diff < -1:\n", | |
" if data > self.right.data:\n", | |
" return self.rotate('left') # right right heavy\n", | |
" if data < self.right.data: # right left heavy\n", | |
" self.right = self.right.rotate('right')\n", | |
" return self.rotate('left')\n", | |
" \n", | |
" return self\n", | |
" \n", | |
" def rotate(self, to:\"left|right\"):\n", | |
" if to == 'right':\n", | |
" tmp = self.left\n", | |
" tmp_right = tmp.right\n", | |
" # update\n", | |
" tmp.right = self\n", | |
" self.left = tmp_right \n", | |
" print('Node {} right rotate to {}!'.format(self, tmp))\n", | |
" return tmp # return new root\n", | |
" if to == 'left':\n", | |
" tmp = self.right\n", | |
" tmp_left = tmp.left\n", | |
" # update\n", | |
" tmp.left = self\n", | |
" self.right = tmp_left\n", | |
" print('Node {} left rotate to {}!'.format(self, tmp))\n", | |
" return tmp # return new root\n", | |
" raise NotImplementedError()\n", | |
" \n", | |
"class AVLTree(BinarySearchTree): \n", | |
" def __init__(self, *args, **kwargs):\n", | |
" super(AVLTree, self).__init__(*args, **kwargs)\n", | |
" \n", | |
" def insert(self, data): \n", | |
" AVLTree.insert_node(self.root, data, tree=self) # pass self as keyword argument to update self.root\n", | |
" self.update_height()\n", | |
" \n", | |
" def remove(self, data):\n", | |
" AVLTree.remove_node(self.root, data, tree=self) # pass self as keyword argument to update self.root\n", | |
" self.update_height()\n", | |
" \n", | |
" def rotate_decorator(func):\n", | |
" def interface(*args, **kwargs):\n", | |
" node = func(*args, **kwargs)\n", | |
" \n", | |
" data = args[1]\n", | |
" tree = kwargs.get('tree')\n", | |
" \n", | |
" new_root = node.balance(data)\n", | |
" \n", | |
" if node == tree.root:\n", | |
" tree.root = new_root\n", | |
" \n", | |
" return interface\n", | |
" \n", | |
" def update_height(self):\n", | |
" for n in self.traversal(order='in_order'):\n", | |
" n.set_height()\n", | |
" \n", | |
" @property\n", | |
" def is_balance(self):\n", | |
" return self.root.is_balance\n", | |
" \n", | |
" @rotate_decorator\n", | |
" def insert_node(*args, **kwargs):\n", | |
" return BinarySearchTree.insert_node(*args, node_constructor=HNode, **kwargs)\n", | |
" \n", | |
" @rotate_decorator\n", | |
" def remove_node(*args, **kwargs):\n", | |
" return BinarySearchTree.remove_node(*args, **kwargs) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Node HNode(1) left rotate to HNode(2)!\n", | |
"Node HNode(2) left rotate to HNode(3)!\n" | |
] | |
} | |
], | |
"source": [ | |
"bt = AVLTree(HNode(1))\n", | |
"bt.insert(2)\n", | |
"bt.insert(3)\n", | |
"bt.insert(4)\n", | |
"bt.insert(5)\n", | |
"bt.insert(6)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"HNode(3)" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.root" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.root.height" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[HNode(1), HNode(2), HNode(3), HNode(4), HNode(5), HNode(6)]" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.traversal()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.is_balance" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Node HNode(5) right rotate to HNode(3)!\n", | |
"Node HNode(1) left rotate to HNode(3)!\n" | |
] | |
} | |
], | |
"source": [ | |
"bt = AVLTree(HNode(1))\n", | |
"bt.insert(5)\n", | |
"bt.insert(3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"HNode(3)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.root" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"HNode(1)" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.root.left" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"HNode(5)" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bt.root.right" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment