Skip to content

Instantly share code, notes, and snippets.

@travishen
Last active March 7, 2019 10:13
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save travishen/c4cc5797f8905b2a5b90f2545c374a26 to your computer and use it in GitHub Desktop.
Save travishen/c4cc5797f8905b2a5b90f2545c374a26 to your computer and use it in GitHub Desktop.
Binary Search Tree (BST) and self-balancing BST (AVL tree) implementations: insert, search, remove
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Binary Search Tree"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"class Node(object):\n",
" def __init__(self, data):\n",
" self._left, self._right = None, None\n",
" self.data = int(data)\n",
" \n",
" def __repr__(self):\n",
" return 'Node({})'.format(self.data)\n",
" \n",
" @property\n",
" def left(self):\n",
" return self._left\n",
" \n",
" @left.setter\n",
" def left(self, node):\n",
" self._left = node\n",
" \n",
" @property\n",
" def right(self):\n",
" return self._right\n",
" \n",
" @right.setter\n",
" def right(self, node):\n",
" self._right = node\n",
" \n",
"class BinarySearchTree(object): \n",
" def __init__(self, root=None):\n",
" self.root = root\n",
" self.search_mode = 'in_order'\n",
" \n",
" \n",
" # O(logN) time complexity if balanced, it could reduce to O(N)\n",
" def insert(self, data, **kwargs): \n",
" \"\"\"Insert from root\"\"\"\n",
" BinarySearchTree.insert_node(self.root, data, **kwargs)\n",
" \n",
" # O(logN) time complexity if balanced, it could reduce to O(N)\n",
" def remove(self, data): \n",
" \"\"\"Insert from root\"\"\"\n",
" BinarySearchTree.remove_node(self.root, data)\n",
" \n",
" @staticmethod\n",
" def insert_node(node, data, **kwargs):\n",
" node_consturctor = kwargs.get('node_constructor', None) or Node\n",
" if node:\n",
" if data < node.data:\n",
" if node.left is None:\n",
" node.left = node_consturctor(data)\n",
" else:\n",
" BinarySearchTree.insert_node(node.left, data, **kwargs)\n",
" elif data > node.data:\n",
" if node.right is None:\n",
" node.right = node_consturctor(data)\n",
" else:\n",
" BinarySearchTree.insert_node(node.right, data, **kwargs)\n",
" else:\n",
" node.data = data\n",
" return node\n",
" \n",
" @staticmethod\n",
" def remove_node(node, data): \n",
"\n",
" if not node:\n",
" return None\n",
" \n",
" if data < node.data:\n",
" node.left = BinarySearchTree.remove_node(node.left, data)\n",
" elif data > node.data:\n",
" node.right = BinarySearchTree.remove_node(node.right, data)\n",
" else:\n",
" if not (node.left and node.right): # leaf\n",
" del node\n",
" return None\n",
" if not node.left:\n",
" tmp = node.right\n",
" del node\n",
" return tmp\n",
" if not node.right:\n",
" tmp = node.left\n",
" del node\n",
" return tmp\n",
" predeccessor = BinarySearchTree.get_max_node(node.left)\n",
" node.data = predeccessor.data\n",
" node.left = BinarySearchTree.remove_node(node.left, predeccessor.data)\n",
" return node\n",
" \n",
" def get_min(self):\n",
" return self.get_min_node(self.root)\n",
" \n",
" @staticmethod\n",
" def get_min_node(node):\n",
" if node.left:\n",
" return BinarySearchTree.get_max_node(node.left)\n",
" return node\n",
" \n",
" def get_max(self):\n",
" return self.get_max_node(self.root)\n",
" \n",
" @staticmethod\n",
" def get_max_node(node):\n",
" if node.right:\n",
" return BinarySearchTree.get_max_node(node.right)\n",
" return node\n",
" \n",
" def search_decorator(func):\n",
" def interface(*args, **kwargs):\n",
" res = func(*args, **kwargs)\n",
" if isinstance(res, Node):\n",
" return res\n",
" elif 'data' in kwargs:\n",
" for node in res:\n",
" if node.data == kwargs['data']:\n",
" return node \n",
" return res\n",
" return interface\n",
" \n",
" @staticmethod\n",
" @search_decorator\n",
" def in_order(root, **kwargs):\n",
" \"\"\"left -> root -> right\"\"\"\n",
" f = BinarySearchTree.in_order\n",
" res = []\n",
" if root:\n",
" left = f(root.left, **kwargs)\n",
" if isinstance(left, Node):\n",
" return left\n",
" right = f(root.right, **kwargs)\n",
" if isinstance(right, Node):\n",
" return right\n",
" res = left + [root] + right\n",
" return res\n",
"\n",
" @staticmethod\n",
" @search_decorator\n",
" def pre_order(root, **kwargs):\n",
" \"\"\"root -> left -> right\"\"\"\n",
" f = BinarySearchTree.pre_order\n",
" res = []\n",
" if root:\n",
" left = f(root.left, **kwargs)\n",
" if isinstance(left, Node):\n",
" return left\n",
" right = f(root.right, **kwargs)\n",
" if isinstance(right, Node):\n",
" return right\n",
" res = [root] + left + right \n",
" return res\n",
"\n",
" @staticmethod\n",
" @search_decorator\n",
" def post_order(root, **kwargs):\n",
" \"\"\"root -> right -> root\"\"\"\n",
" f = BinarySearchTree.post_order\n",
" res = []\n",
" if root:\n",
" left = f(root.left, **kwargs)\n",
" if isinstance(left, Node):\n",
" return left\n",
" right = f(root.right, **kwargs)\n",
" if isinstance(right, Node):\n",
" return right\n",
" res = left + right + [root]\n",
" return res\n",
" \n",
" def traversal(self, \n",
" order:\"in_order|post_order|post_order\"=None,\n",
" data=None):\n",
" order = order or self.search_mode\n",
" if order == 'in_order':\n",
" return BinarySearchTree.in_order(self.root, data=data)\n",
" elif order == 'pre_order':\n",
" return BinarySearchTree.pre_order(self.root, data=data)\n",
" elif order == 'post_order':\n",
" return BinarySearchTree.post_order(self.root, data=data)\n",
" else:\n",
" raise NotImplementedError()\n",
" \n",
" def search(self, data, *args, **kwargs):\n",
" return self.traversal(*args, data=data, **kwargs)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Node(20)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt = BinarySearchTree(Node(10))\n",
"bt.insert(20)\n",
"bt.insert(15)\n",
"bt.insert(2)\n",
"bt.insert(66)\n",
"bt.insert(24)\n",
"bt.insert(17)\n",
"bt.insert(21)\n",
"bt.root.right"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Node(2), Node(10), Node(15), Node(17), Node(20), Node(21), Node(24), Node(66)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.traversal('in_order')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Node(2), Node(17), Node(15), Node(21), Node(24), Node(66), Node(20), Node(10)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.traversal('post_order')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Node(10), Node(2), Node(20), Node(15), Node(17), Node(66), Node(24), Node(21)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.traversal('pre_order')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Node(17)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.search(data=17)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Node(66)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.get_max()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Node(2)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.get_min()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Node(2), Node(10), Node(20), Node(21), Node(24), Node(66)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.remove(15)\n",
"bt.traversal()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AVL Tree"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class HNode(Node): \n",
" def __init__(self, *args, **kwargs):\n",
" super(HNode, self).__init__(*args, **kwargs)\n",
" self._height = 0\n",
" \n",
" def __repr__(self):\n",
" return 'HNode({})'.format(self.data)\n",
" \n",
" @property\n",
" def height(self):\n",
" return self._height\n",
" \n",
" def set_height(self): \n",
" if self.left is None and self.right is None:\n",
" self._height = 0\n",
" else:\n",
" self._height = max(self.left_height, self.right_height) + 1\n",
" return self._height\n",
"\n",
"\n",
" @Node.left.setter\n",
" def left(self, node):\n",
" self._left = node\n",
" self.set_height()\n",
" \n",
" @Node.right.setter\n",
" def right(self, node):\n",
" self._right = node\n",
" self.set_height()\n",
" \n",
" @property\n",
" def sub_diff(self):\n",
" return self.left_height - self.right_height \n",
" \n",
" @property\n",
" def left_height(self):\n",
" if self.left:\n",
" return self.left.height\n",
" return -1\n",
" \n",
" @property\n",
" def right_height(self):\n",
" if self.right:\n",
" return self.right.height\n",
" return -1\n",
" \n",
" @property\n",
" def is_balance(self):\n",
" return abs(self.sub_diff) <= 1 \n",
" \n",
" def balance(self, data):\n",
" \n",
" if self.sub_diff > 1:\n",
" if data < self.left.data: # left left heavy\n",
" return self.rotate('right')\n",
" if data > self.left.data: # left right heavy\n",
" self.left = self.left.rotate('left')\n",
" return self.rotate('right')\n",
" \n",
" if self.sub_diff < -1:\n",
" if data > self.right.data:\n",
" return self.rotate('left') # right right heavy\n",
" if data < self.right.data: # right left heavy\n",
" self.right = self.right.rotate('right')\n",
" return self.rotate('left')\n",
" \n",
" return self\n",
" \n",
" def rotate(self, to:\"left|right\"):\n",
" if to == 'right':\n",
" tmp = self.left\n",
" tmp_right = tmp.right\n",
" # update\n",
" tmp.right = self\n",
" self.left = tmp_right \n",
" print('Node {} right rotate to {}!'.format(self, tmp))\n",
" return tmp # return new root\n",
" if to == 'left':\n",
" tmp = self.right\n",
" tmp_left = tmp.left\n",
" # update\n",
" tmp.left = self\n",
" self.right = tmp_left\n",
" print('Node {} left rotate to {}!'.format(self, tmp))\n",
" return tmp # return new root\n",
" raise NotImplementedError()\n",
" \n",
"class AVLTree(BinarySearchTree): \n",
" def __init__(self, *args, **kwargs):\n",
" super(AVLTree, self).__init__(*args, **kwargs)\n",
" \n",
" def insert(self, data): \n",
" AVLTree.insert_node(self.root, data, tree=self) # pass self as keyword argument to update self.root\n",
" self.update_height()\n",
" \n",
" def remove(self, data):\n",
" AVLTree.remove_node(self.root, data, tree=self) # pass self as keyword argument to update self.root\n",
" self.update_height()\n",
" \n",
" def rotate_decorator(func):\n",
" def interface(*args, **kwargs):\n",
" node = func(*args, **kwargs)\n",
" \n",
" data = args[1]\n",
" tree = kwargs.get('tree')\n",
" \n",
" new_root = node.balance(data)\n",
" \n",
" if node == tree.root:\n",
" tree.root = new_root\n",
" \n",
" return interface\n",
" \n",
" def update_height(self):\n",
" for n in self.traversal(order='in_order'):\n",
" n.set_height()\n",
" \n",
" @property\n",
" def is_balance(self):\n",
" return self.root.is_balance\n",
" \n",
" @rotate_decorator\n",
" def insert_node(*args, **kwargs):\n",
" return BinarySearchTree.insert_node(*args, node_constructor=HNode, **kwargs)\n",
" \n",
" @rotate_decorator\n",
" def remove_node(*args, **kwargs):\n",
" return BinarySearchTree.remove_node(*args, **kwargs) "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Node HNode(1) left rotate to HNode(2)!\n",
"Node HNode(2) left rotate to HNode(3)!\n"
]
}
],
"source": [
"bt = AVLTree(HNode(1))\n",
"bt.insert(2)\n",
"bt.insert(3)\n",
"bt.insert(4)\n",
"bt.insert(5)\n",
"bt.insert(6)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"HNode(3)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.root"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.root.height"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HNode(1), HNode(2), HNode(3), HNode(4), HNode(5), HNode(6)]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.traversal()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.is_balance"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Node HNode(5) right rotate to HNode(3)!\n",
"Node HNode(1) left rotate to HNode(3)!\n"
]
}
],
"source": [
"bt = AVLTree(HNode(1))\n",
"bt.insert(5)\n",
"bt.insert(3)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"HNode(3)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.root"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"HNode(1)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.root.left"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"HNode(5)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bt.root.right"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment