Skip to content

Instantly share code, notes, and snippets.

@irachex

irachex/Treap.py

Created Oct 20, 2012
Embed
What would you like to do?
Treap (tree + heap) in Python
import random
class TreapNode(object):
    """A single treap node: a key/data pair plus a random heap priority.

    Attributes:
        key:   ordering key of the node.
        ran:   random priority drawn from ``random.random()`` at creation.
        size:  number of entries in the subtree rooted at this node.
        cnt:   number of entries this node itself contributes (starts at 1).
        data:  payload stored alongside the key.
        left, right: child nodes, or ``None``.
    """

    def __init__(self, key, data):
        self.key = key
        self.data = data
        self.ran = random.random()
        self.cnt = 1
        self.size = 1
        self.left = None
        self.right = None

    def _refresh_size(self):
        """Recompute ``size`` from both children plus this node's own count."""
        self.size = self.left_size() + self.right_size() + self.cnt

    def left_rotate(self):
        """Rotate left around this node and return the new subtree root.

        The right child is lifted to the top, this node becomes its left
        child, and the lifted node's former left subtree becomes this node's
        right subtree.
        """
        pivot = self.right
        # The right-hand side is evaluated first, so pivot.left is read
        # before either link is rewired.
        self.right, pivot.left = pivot.left, self
        # The demoted node must be refreshed first: the new root's size
        # depends on it.
        self._refresh_size()
        pivot._refresh_size()
        return pivot

    def right_rotate(self):
        """Rotate right around this node and return the new subtree root.

        Mirror image of ``left_rotate``: the left child is lifted to the top
        and this node becomes its right child.
        """
        pivot = self.left
        self.left, pivot.right = pivot.right, self
        self._refresh_size()
        pivot._refresh_size()
        return pivot

    def left_size(self):
        """Return the size of the left subtree (0 when there is none)."""
        child = self.left
        if child is None:
            return 0
        return child.size

    def right_size(self):
        """Return the size of the right subtree (0 when there is none)."""
        child = self.right
        if child is None:
            return 0
        return child.size

    def __repr__(self):
        # Children are rendered recursively through str(), so the output
        # shows the whole subtree.
        fields = (
            str(self.key),
            self.ran,
            self.size,
            str(self.left),
            str(self.right),
        )
        return '<node key:%s ran:%f size:%d left:%s right:%s>' % fields
class Treap(object):
    """Ordered collection backed by a treap of ``TreapNode`` objects.

    Keys are kept in binary-search-tree order, while each node's random
    ``ran`` value is kept in min-heap order: a child whose ``ran`` is smaller
    than its parent's is rotated above it. Every node tracks the size of its
    subtree, which is what ``find_kth`` and ``median`` rely on. Inserting a
    key that is already present creates an additional node.
    """

    def __init__(self):
        """Create an empty treap."""
        self.root = None

    def _insert(self, node, key, data=None):
        """Insert ``key`` into the subtree at ``node``; return its new root."""
        if node is None:
            return TreapNode(key, data)
        # The new entry always ends up somewhere below this node.
        node.size += 1
        if key < node.key:
            node.left = self._insert(node.left, key, data)
            # Restore heap order on ``ran`` if the child now outranks us.
            if node.left.ran < node.ran:
                node = node.right_rotate()
        else:
            # Everything that is not strictly smaller (including equal keys)
            # goes right, so the size bump above is always matched by a real
            # insertion.
            node.right = self._insert(node.right, key, data)
            if node.right.ran < node.ran:
                node = node.left_rotate()
        return node

    def insert(self, key, data=None):
        """Add ``key`` (with optional ``data``) to the treap."""
        self.root = self._insert(self.root, key, data)

    def _find(self, node, key):
        """Return the first node with ``key`` on the search path, or None."""
        while node is not None:
            if node.key == key:
                return node
            node = node.left if key < node.key else node.right
        return None

    def find(self, key):
        """Return a node whose key equals ``key``, or ``None`` if absent."""
        return self._find(self.root, key)

    def _delete(self, node, key):
        """Remove one node with ``key`` below ``node``; return the new root.

        An empty subtree yields ``None`` so the result can always be stored
        directly as a child link.
        """
        if node is None:
            return None
        if node.key == key:
            # With at most one child, the node is replaced by that child
            # (or by None when it is a leaf).
            if node.left is None:
                return node.right
            if node.right is None:
                return node.left
            # Two children: rotate the child with the smaller ``ran`` up,
            # which pushes the doomed node one level down, then keep going.
            if node.left.ran < node.right.ran:
                node = node.right_rotate()
                node.right = self._delete(node.right, key)
            else:
                node = node.left_rotate()
                node.left = self._delete(node.left, key)
        elif key < node.key:
            node.left = self._delete(node.left, key)
        else:
            node.right = self._delete(node.right, key)
        node.size = node.left_size() + node.right_size() + node.cnt
        return node

    def delete(self, key):
        """Remove one entry with ``key``.

        Returns:
            ``True`` if an entry was removed, ``False`` if ``key`` was not
            found (the treap is left untouched in that case).
        """
        if self.find(key) is None:
            return False
        self.root = self._delete(self.root, key)
        return True

    def _find_kth(self, node, k):
        """Return the node holding the ``k``-th smallest entry below ``node``."""
        if node is None:
            return None
        smaller = node.left_size()
        if k <= smaller:
            return self._find_kth(node.left, k)
        if k > smaller + node.cnt:
            return self._find_kth(node.right, k - smaller - node.cnt)
        return node

    def find_kth(self, k):
        """Return the node with the ``k``-th smallest key (1-based).

        Returns ``None`` when ``k`` is outside ``1..size()``.
        """
        if k <= 0 or k > self.size():
            return None
        return self._find_kth(self.root, k)

    def size(self):
        """Return the number of entries stored in the treap."""
        return 0 if self.root is None else self.root.size

    def median(self):
        """Return the median key, or 0 for an empty treap.

        For an even number of entries the mean of the two middle keys is
        returned. A result that equals its integer value is returned as an
        ``int``.
        """
        s = self.size()
        if s == 0:
            return 0
        # Floor division keeps the rank an integer on both Python 2 and 3;
        # a fractional rank would make find_kth pick the wrong node.
        half = s // 2
        if s % 2 == 1:
            result = self.find_kth(half + 1).key
        else:
            lower = self.find_kth(half).key
            upper = self.find_kth(half + 1).key
            result = (lower + upper) / 2.0
        if result == int(result):
            result = int(result)
        return result

    def _traverse(self, node):
        """Write the keys below ``node`` to stdout in ascending order."""
        if node is None:
            return
        import sys  # local import: the only place this class needs sys
        self._traverse(node.left)
        # Version-neutral replacement for the Python 2 ``print x,`` statement.
        sys.stdout.write('%s ' % (node.key,))
        self._traverse(node.right)

    def traverse(self):
        """Write all keys to stdout in ascending order, space separated."""
        self._traverse(self.root)

    def __repr__(self):
        return str(self.root)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.