Created
October 30, 2016 08:25
-
-
Save tiqwab/dbad2ba205ff74f750d037194918d442 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Implementation of splay tree

The module-level constants PRE, IN and POST tag the three visits that
Node.traverse_through() makes to every node.
'''
# timing of visiting nodes
PRE = 1    # before the node's left subtree is visited
IN = 0     # after the left subtree, before the right subtree
POST = -1  # after the right subtree has been visited
class Node:
    '''
    A binary-tree node holding a key and optional left/right children.

    The traversal generators are iterative (explicit stack) rather than
    recursive. They therefore work on arbitrarily deep, chain-shaped trees,
    which a splay tree can produce, without hitting Python's recursion limit.
    They also cost O(1) per yielded item instead of O(depth) through nested
    generators.
    '''
    def __init__(self, x):
        self.key = x
        self.left = None
        self.right = None

    def rotate_right(self):
        '''
        rotate right based on this node and return new center

        The left child becomes the subtree root and this node becomes its
        right child. Requires self.left to be non-None. The caller must
        re-attach the returned node wherever this node used to hang.
        '''
        node = self.left
        self.left = node.right
        node.right = self
        return node

    def rotate_left(self):
        '''
        rotate left based on this node and return new center

        The right child becomes the subtree root and this node becomes its
        left child. Requires self.right to be non-None. The caller must
        re-attach the returned node wherever this node used to hang.
        '''
        node = self.right
        self.right = node.left
        node.left = self
        return node

    def traverse_pre(self):
        '''
        traverse nodes pre-order (node, then left subtree, then right subtree)
        '''
        stack = [self]
        while stack:
            node = stack.pop()
            yield node
            # Push right first so the left subtree is popped (visited) first.
            if node.right is not None:
                stack.append(node.right)
            if node.left is not None:
                stack.append(node.left)

    def traverse_in(self):
        '''
        traverse nodes in-order (left subtree, then node, then right subtree)
        '''
        stack = []
        node = self
        while stack or node is not None:
            # Descend as far left as possible, remembering the path.
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def traverse_through(self):
        '''
        yield (node, timing) three times per node: (node, PRE) before its
        left subtree, (node, IN) between its subtrees, and (node, POST)
        after its right subtree
        '''
        stack = [(self, PRE)]
        while stack:
            node, timing = stack.pop()
            yield node, timing
            if timing == PRE:
                # Schedule the IN visit, to run once the left subtree is done.
                stack.append((node, IN))
                if node.left is not None:
                    stack.append((node.left, PRE))
            elif timing == IN:
                # Schedule the POST visit, to run once the right subtree is done.
                stack.append((node, POST))
                if node.right is not None:
                    stack.append((node.right, PRE))
class SplayTree:
    '''
    Self-adjusting binary search tree.

    Every lookup (search/find, plus the lookups performed by insert and
    delete) splays the accessed key to the root. When the key is absent,
    the last node visited is splayed instead.
    '''
    def __init__(self):
        self.root = None

    @staticmethod
    def _hang_right(tail, node):
        '''
        Append node as tail.right and cut node off from its right child.
        Return (node, former right child): the new tail and the next node
        to visit.
        '''
        tail.right = node
        following = node.right
        node.right = None
        return node, following

    @staticmethod
    def _hang_left(tail, node):
        '''
        Mirror of _hang_right: append node as tail.left and cut node off from
        its left child. Return (node, former left child).
        '''
        tail.left = node
        following = node.left
        node.left = None
        return node, following

    def splay(self, root, x):
        '''
        splay the tree.
        return node with key x if found, or node accessed last

        Top-down splay of the subtree rooted at `root` toward key x. The
        returned node is also stored in self.root. When root is None,
        returns None and leaves self.root untouched.
        '''
        if root is None:
            return None
        # Scratch header. header.right collects nodes with keys below x,
        # chained through .right. header.left collects nodes with keys above
        # x, chained through .left. `lesser`/`greater` are the chain tails.
        header = Node(None)
        lesser = greater = header
        cur = root
        while True:
            if x > cur.key:
                if cur.right is None:
                    break
                if x > cur.right.key:            # zig-zig: rotate, then link
                    cur = cur.rotate_left()
                    if cur.right is not None:
                        lesser, cur = self._hang_right(lesser, cur)
                elif x < cur.right.key:          # zig-zag: link right, then left
                    lesser, cur = self._hang_right(lesser, cur)
                    if cur.left is not None:
                        greater, cur = self._hang_left(greater, cur)
                else:                            # zig: right child holds x
                    lesser, cur = self._hang_right(lesser, cur)
            elif x < cur.key:
                if cur.left is None:
                    break
                if x > cur.left.key:             # zig-zag: link left, then right
                    greater, cur = self._hang_left(greater, cur)
                    if cur.right is not None:
                        lesser, cur = self._hang_right(lesser, cur)
                elif x < cur.left.key:           # zig-zig: rotate, then link
                    cur = cur.rotate_right()
                    if cur.left is not None:
                        greater, cur = self._hang_left(greater, cur)
                else:                            # zig: left child holds x
                    greater, cur = self._hang_left(greater, cur)
            else:
                break                            # cur holds x
        # Reassemble. cur's own subtrees go to the inner ends of the two
        # chains, and the chains become cur's new subtrees.
        greater.left = cur.right
        lesser.right = cur.left
        self.root = cur
        cur.left = header.right
        cur.right = header.left
        return cur

    def insert(self, x):
        '''
        Add key x; duplicates are ignored. Returns self so calls can be chained.
        '''
        if self.root is None:
            self.root = Node(x)
            return self
        top, found = self.search(x)   # after splaying, top is the root
        if found:
            return self
        fresh = Node(x)
        if x < top.key:
            fresh.left, fresh.right = top.left, top
            top.left = None
        else:
            fresh.left, fresh.right = top, top.right
            top.right = None
        self.root = fresh
        return self

    def delete(self, x):
        '''
        Remove key x. Returns True if it was present, False otherwise.
        '''
        top, found = self.search(x)
        if not found:
            return False
        if top.left is None:
            self.root = top.right
        elif top.right is None:
            self.root = top.left
        else:
            # Every key in the left subtree is below x, so splaying it toward x
            # lifts its maximum, which then has a free right slot.
            joined = self.splay(top.left, x)
            joined.right = top.right
            self.root = joined
        return True

    def search(self, x):
        '''
        Splay toward x. Return (root node, whether its key equals x), or
        (None, False) for an empty tree.
        '''
        top = self.splay(self.root, x)
        return (None, False) if top is None else (top, x == top.key)

    def find(self, x):
        '''
        alias of search
        '''
        return self.search(x)

    def traverse_pre(self):
        '''Yield every node in pre-order; yields nothing for an empty tree.'''
        if self.root is None:
            return
        for visited in self.root.traverse_pre():
            yield visited

    def traverse_in(self):
        '''Yield every node in key order; yields nothing for an empty tree.'''
        if self.root is None:
            return
        for visited in self.root.traverse_in():
            yield visited

    def traverse_through(self):
        '''Yield (node, PRE/IN/POST) pairs; yields nothing for an empty tree.'''
        if self.root is None:
            return
        for visited, when in self.root.traverse_through():
            yield visited, when

    def count(self):
        '''
        count nodes in tree
        '''
        return sum(1 for _ in self.traverse_pre())

    def max_level(self):
        '''
        Return the number of levels: 0 for an empty tree, 1 for a lone root.
        '''
        depth = deepest = 0
        for _, timing in self.traverse_through():
            if timing == PRE:
                depth += 1
            elif timing == POST:
                depth -= 1
            elif depth > deepest:
                deepest = depth
        return deepest
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment