Skip to content

Instantly share code, notes, and snippets.

@tiqwab
Created October 30, 2016 08:25
Embed
What would you like to do?
'''
Implementation of splay tree: a self-adjusting binary search tree
(classes Node and SplayTree).
'''

# Timing tags yielded together with each node by traverse_through():
# a node is reported before its left subtree (PRE), between its two
# subtrees (IN), and after its right subtree (POST).
PRE = 1
IN = 0
POST = -1
class Node:
    '''
    Binary search tree node: a key plus optional left/right children.

    The traversal generators use an explicit stack instead of recursion.
    A splay tree can degenerate into a chain (for example after inserting
    keys in ascending order). On such a chain, recursive generators would
    raise RecursionError once the depth exceeds the interpreter's recursion
    limit, and each yielded node would also be passed up through every
    enclosing generator.
    '''

    def __init__(self, x):
        self.key = x
        self.left = None   # child subtree of smaller keys
        self.right = None  # child subtree of larger keys

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.key)

    def rotate_right(self):
        '''
        Rotate right around this node and return the new subtree root.

        The left child is lifted to the top, and its former right subtree
        becomes this node's left subtree. Requires self.left to be non-None.
        The caller must re-attach the returned node where this node used
        to hang.
        '''
        pivot = self.left
        self.left = pivot.right
        pivot.right = self
        return pivot

    def rotate_left(self):
        '''
        Rotate left around this node and return the new subtree root.

        The right child is lifted to the top, and its former left subtree
        becomes this node's right subtree. Requires self.right to be
        non-None. The caller must re-attach the returned node where this
        node used to hang.
        '''
        pivot = self.right
        self.right = pivot.left
        pivot.left = self
        return pivot

    def traverse_pre(self):
        '''
        Yield the nodes of this subtree in pre-order
        (node, then left subtree, then right subtree).
        '''
        stack = [self]
        while stack:
            node = stack.pop()
            yield node
            # Push right before left so the left subtree is popped first.
            if node.right is not None:
                stack.append(node.right)
            if node.left is not None:
                stack.append(node.left)

    def traverse_in(self):
        '''
        Yield the nodes of this subtree in in-order
        (left subtree, then node, then right subtree), i.e. by ascending key
        when the subtree is a valid binary search tree.
        '''
        stack = []
        node = self
        while stack or node is not None:
            # Walk down the left spine, remembering the nodes to revisit.
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def traverse_through(self):
        '''
        Yield (node, timing) three times for every node of this subtree:
        (node, PRE) before its left subtree, (node, IN) between its
        subtrees, and (node, POST) after its right subtree.
        '''
        stack = [(self, PRE)]
        while stack:
            node, timing = stack.pop()
            yield node, timing
            if timing == PRE:
                # Schedule the node's remaining events in reverse (LIFO)
                # order so they come out as: left subtree, IN,
                # right subtree, POST.
                stack.append((node, POST))
                if node.right is not None:
                    stack.append((node.right, PRE))
                stack.append((node, IN))
                if node.left is not None:
                    stack.append((node.left, PRE))
class SplayTree:
    '''
    Splay tree: a self-adjusting binary search tree.

    search(), insert() and delete() use top-down splaying. The accessed key
    (or, when it is absent, the last node on its search path) is moved to
    the root. Keys are compared with < and >. Inserting a key that is
    already present is a no-op.

    The traversal helpers, count() and max_level() are iterative so that
    they also work on degenerate, chain-shaped trees deeper than Python's
    recursion limit. Such trees arise, for example, from inserting keys in
    ascending order.
    '''

    def __init__(self):
        self.root = None

    def splay(self, root, x):
        '''
        Top-down splay of the subtree rooted at `root` for key `x`.

        Return the node with key x if found, otherwise the node accessed last
        on the search path for x, or None if `root` is None. The returned node
        becomes the root of the rebuilt subtree. It is also stored in
        self.root, even when `root` is only a subtree; delete() relies on this
        and then re-assigns self.root itself.
        '''
        if root is None:
            return None
        node = root
        # Temporary header holding the two side trees built while descending.
        # Nodes with keys < x hang off header.right, chained through .right.
        # Nodes with keys > x hang off header.left, chained through .left.
        header = Node(None)
        less_tail = header     # largest node so far in the "< x" side tree
        greater_tail = header  # smallest node so far in the "> x" side tree
        while True:
            if x > node.key:
                if node.right is None:
                    break
                if x > node.right.key:
                    # zig-zig: rotate the right child up, then link it left.
                    node = node.rotate_left()
                    if node.right is None:
                        break
                    less_tail.right = node
                    less_tail = node
                    node = node.right
                    less_tail.right = None
                elif x < node.right.key:
                    # zig-zag: link node into the "< x" tree and step right,
                    # then link that child into the "> x" tree and step left.
                    less_tail.right = node
                    less_tail = node
                    node = node.right
                    less_tail.right = None
                    if node.left is None:
                        break
                    greater_tail.left = node
                    greater_tail = node
                    node = node.left
                    greater_tail.left = None
                else:
                    # zig: the right child is the target; step onto it.
                    less_tail.right = node
                    less_tail = node
                    node = node.right
                    less_tail.right = None
            elif x < node.key:
                if node.left is None:
                    break
                if x > node.left.key:
                    # zig-zag: link node into the "> x" tree and step left,
                    # then link that child into the "< x" tree and step right.
                    greater_tail.left = node
                    greater_tail = node
                    node = node.left
                    greater_tail.left = None
                    if node.right is None:
                        break
                    less_tail.right = node
                    less_tail = node
                    node = node.right
                    less_tail.right = None
                elif x < node.left.key:
                    # zig-zig: rotate the left child up, then link it right.
                    node = node.rotate_right()
                    if node.left is None:
                        break
                    greater_tail.left = node
                    greater_tail = node
                    node = node.left
                    greater_tail.left = None
                else:
                    # zig: the left child is the target; step onto it.
                    greater_tail.left = node
                    greater_tail = node
                    node = node.left
                    greater_tail.left = None
            else:
                # Neither greater nor less: node holds x.
                break
        # Reassemble. node's own subtrees go to the inner ends of the side
        # trees, and the side trees become node's new subtrees.
        greater_tail.left = node.right
        less_tail.right = node.left
        node.left = header.right
        node.right = header.left
        self.root = node
        return node

    def insert(self, x):
        '''
        Insert key x (no-op if already present) and return self for chaining.
        '''
        if self.root is None:
            self.root = Node(x)
            return self
        node, found = self.search(x)  # node is now self.root
        if found:
            return self
        new_node = Node(x)
        if x < node.key:
            # After splaying, node.left holds only keys < x.
            new_node.left = node.left
            new_node.right = node
            node.left = None
        else:
            # After splaying, node.right holds only keys > x.
            new_node.left = node
            new_node.right = node.right
            node.right = None
        self.root = new_node
        return self

    def delete(self, x):
        '''
        Remove key x. Return True if it was present, False otherwise.
        '''
        target, found = self.search(x)  # target is now self.root (or None)
        if not found:
            return False
        if target.left is None:
            self.root = target.right
        elif target.right is None:
            self.root = target.left
        else:
            # x exceeds every key in the left subtree, so splaying that subtree
            # for x lifts its maximum to the top with an empty right side.
            new_root = self.splay(target.left, x)
            new_root.right = target.right
            self.root = new_root
        return True

    def search(self, x):
        '''
        Splay for x and return a (node, found) pair.

        node is the new root (None for an empty tree); found tells whether
        node.key == x.
        '''
        node = self.splay(self.root, x)
        if node is None:
            return None, False
        return node, (x == node.key)

    def find(self, x):
        '''
        Alias of search().
        '''
        return self.search(x)

    def traverse_pre(self):
        '''
        Yield every node in pre-order (node, left subtree, right subtree).
        '''
        stack = [] if self.root is None else [self.root]
        while stack:
            node = stack.pop()
            yield node
            # Push right before left so the left subtree is popped first.
            if node.right is not None:
                stack.append(node.right)
            if node.left is not None:
                stack.append(node.left)

    def traverse_in(self):
        '''
        Yield every node in in-order, i.e. by ascending key.
        '''
        stack = []
        node = self.root
        while stack or node is not None:
            # Walk down the left spine, remembering the nodes to revisit.
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node
            node = node.right

    def traverse_through(self):
        '''
        Yield (node, timing) three times per node: (node, PRE) before its left
        subtree, (node, IN) between its subtrees, (node, POST) after its right
        subtree.
        '''
        stack = [] if self.root is None else [(self.root, PRE)]
        while stack:
            node, timing = stack.pop()
            yield node, timing
            if timing == PRE:
                # LIFO scheduling: left subtree, IN, right subtree, POST.
                stack.append((node, POST))
                if node.right is not None:
                    stack.append((node.right, PRE))
                stack.append((node, IN))
                if node.left is not None:
                    stack.append((node.left, PRE))

    def count(self):
        '''
        Return the number of nodes in the tree.
        '''
        return sum(1 for _ in self.traverse_pre())

    def max_level(self):
        '''
        Return the height of the tree measured in nodes.

        The result is 0 for an empty tree and 1 for a lone root; in general it
        is the number of nodes on the longest root-to-leaf path.
        '''
        height = 0
        stack = [] if self.root is None else [(self.root, 1)]
        while stack:
            node, depth = stack.pop()
            if depth > height:
                height = depth
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, depth + 1))
        return height
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment