@lhuemill
Created February 16, 2017 20:18
Ordered Dictionary
"""An ordered/sequential mapping container.
This module provides the implementation of a sequence mapping container,
where the entries are kept in an order based on the key of each entry.
The initial implementation maintains the entries through the use of a
binary-search-tree. Beyond a minimal implementation of a binary-search-tree,
this implementation also provides:
+ Semi-Balanced Tree
A red-black tree is used, so that even in the worst case
the depth of the deepest leaf is no more than two times
the depth of the shallowest leaf. Although the tree is typically
not fully balanced, the cost of maintaining the semi-balanced tree is
significantly lower than that of a fully-balanced tree.
+ Calculated Parent Pointer
To save space, the nodes that comprise the binary-search-tree
do not contain a direct reference to their parent node. Instead,
the location of a parent node is calculated on demand, usually by
traversing down from the root node.
+ Augmented Value
Optionally, a mechanism can be enabled where each entry has an
augmented value, which must be > 0. When enabled there is
an efficient mechanism for finding the last entry where the
sum of the augmented values of all earlier entries is less than
or equal to a specified value.
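As a rough illustration of the calculated parent pointer described
above, the sketch below finds the parent of a node by walking down from
the root of a plain, unbalanced binary-search-tree. The names Node,
insert, find, and find_parent are hypothetical and are not part of this
module, and the node passed in is assumed to be present in the tree.

```python
class Node(object):
    # Minimal node: note that there is no parent attribute.
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None


def insert(root, key):
    # Plain, unbalanced insert. Returns the (possibly new) root.
    if root is None:
        return Node(key)
    cur = root
    while cur.key != key:
        if key < cur.key:
            if cur.left is None:
                cur.left = Node(key)
            cur = cur.left
        else:
            if cur.right is None:
                cur.right = Node(key)
            cur = cur.right
    return root


def find(root, key):
    # Standard search; returns the node or None.
    cur = root
    while cur is not None and cur.key != key:
        cur = cur.left if key < cur.key else cur.right
    return cur


def find_parent(root, node):
    # Walk down from the root, guided by the key of the node, until
    # the node itself is reached. Assumes node is in the tree.
    if node is root:
        return None
    parent = root
    while True:
        child = parent.left if node.key < parent.key else parent.right
        if child is node:
            return parent
        parent = child


root = None
for k in [50, 30, 70, 20, 40]:
    root = insert(root, k)
```

Each such lookup costs time proportional to the depth of the node, which
is the price paid for storing one less reference per node.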
Classes:
OrderedDict
Basic Mapping Container Functions:
odict = OrderedDict() # Create a new empty ordered dictionary.
odict[key] = value # Add or update value of entry specified by key.
del odict[key] # Remove entry specified by key.
val = odict[key] # Obtain value of entry for specified key.
odict.keys() # Iterator over keys.
odict.values() # Iterator over each entry's value.
odict.items() # Iterator over each entry's (key, value) tuple.
Sequential Access Functions:
first() # Return (key, value) tuple of entry with
# lowest key.
last() # Return (key, value) tuple of entry for
# highest key.
next(key) # Return (key, value) tuple of next entry
# after entry for specified key.
prev(key) # Return (key, value) tuple of previous entry
# before entry for specified key.
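The intended behaviour of the sequential access functions can be
sketched with a much simpler reference model that keeps the keys in a
sorted list. SortedKeysModel is a hypothetical name used only for this
illustration; it is not how this module stores its entries. The
exceptions mirror those raised by the OrderedDict methods.

```python
import bisect


class SortedKeysModel(object):
    # Hypothetical reference model, not the tree used by this module.
    def __init__(self, items):
        self._data = dict(items)
        self._keys = sorted(self._data)

    def _index(self, key):
        # Position of key in the sorted key list, or KeyError.
        i = bisect.bisect_left(self._keys, key)
        if i == len(self._keys) or self._keys[i] != key:
            raise KeyError('key of %s not found' % key)
        return i

    def _entry(self, i):
        key = self._keys[i]
        return (key, self._data[key])

    def first(self):
        if not self._keys:
            raise RuntimeError('Empty Dictionary')
        return self._entry(0)

    def last(self):
        if not self._keys:
            raise RuntimeError('Empty Dictionary')
        return self._entry(len(self._keys) - 1)

    def next(self, key):
        i = self._index(key)
        if i + 1 == len(self._keys):
            raise RuntimeError('Key specifies last entry')
        return self._entry(i + 1)

    def prev(self, key):
        i = self._index(key)
        if i == 0:
            raise RuntimeError('Key specifies first entry')
        return self._entry(i - 1)


model = SortedKeysModel({30: 'c', 10: 'a', 20: 'b'})
```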
Augmented Value:
By default an OrderedDict is created with support for entry augmented
values disabled. Once enabled, by calling aug_enable(), each entry
supports an augmented value, which must be > 0.0. By default each
entry has an augmented value of 1. A non-default augmented value
is provided by having the value portion of an entry extended with
a method named augmented_val(). A default value of 1 is used when
the value portion of an entry does not have a method named
augmented_val(). Note that when all entries use the default
augmented value of 1, aug_find(n) finds the entry at zero-based rank n.
Augmented values of both integral and float are supported. An
ordered dictionary may even have a mixture of entries where some
entries have an integral augmented value and others have an augmented
value of type float. Note: when any of the augmented values are of
type float, it is the responsibility of the caller to ensure, on insertion
of each new entry or when aug_enable() is called, that the sum of the
augmented values in any sequential subset of entries is < the sum of
all those values plus the smallest augmented value of any entry. In
general it is best for all the entries to use either the default augmented
value or an augmented value of type int. Use of augmented values of type
float is discouraged, in that it requires the caller to use a range of
augmented values that ensures the uniqueness of augmented ranges, even
when rounding errors may accumulate due to tree rebalance
operations.
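A minimal sketch of how a value can supply a non-default augmented
value is shown below. Weighted and val_aug_amt are hypothetical names
used only for this illustration; the lookup follows the rule described
above, where a value without an augmented_val() method gets the default
augmented value of 1.

```python
class Weighted(int):
    # Hypothetical value type: an int whose augmented value is its
    # own magnitude, or 1 when that magnitude would be 0.
    def augmented_val(self):
        return abs(int(self)) or 1


def val_aug_amt(val):
    # Use the augmented_val() method of the value when it has one,
    # otherwise fall back to the default augmented value of 1.
    if hasattr(val, 'augmented_val'):
        return val.augmented_val()
    return 1
```

With these definitions a plain string or int contributes 1 to the
augmented sums, while Weighted(7) contributes 7.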
Augmented Value Functions:
aug_enable() # Enable support for augmented values.
aug_disable() # Disable support for augmented values.
aug_enabled() # Query as to whether support for augmented
# values is currently enabled.
aug_sum_before(key) # Returns the sum of all the augmented values
# before the entry specified by key.
aug_find(amount) # Returns the (key, value) tuple of the entry
# whose augmented range contains amount, i.e.
# the first entry where the sum of its own
# augmented value and those of all entries
# before it is > amount. A ValueError
# is raised if amount is < 0 or >= the sum of
# the augmented values in all the entries.
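The results expected from aug_sum_before() and aug_find() can be
described with a simple linear reference model, sketched below. The two
functions here are hypothetical stand-ins that scan a list of
(key, augmented value) pairs in key order; the tree in this module is
meant to produce the same answers without scanning every entry.

```python
def aug_sum_before(entries, key):
    # entries: list of (key, augmented value) pairs in key order.
    total = 0
    for k, amt in entries:
        if k == key:
            return total
        total += amt
    raise KeyError('key of %s not found' % key)


def aug_find(entries, amount):
    # Return the key whose range [start, start + amt) holds amount.
    if amount < 0:
        raise ValueError('Invalid aug_find amount of: %s' % amount)
    start = 0
    for k, amt in entries:
        if start <= amount < start + amt:
            return k
        start += amt
    raise ValueError('Amount of %s >= sum of augmented values '
                     'of all entries.' % amount)


# 'a' covers [0, 3), 'b' covers [3, 4), and 'c' covers [4, 9).
entries = [('a', 3), ('b', 1), ('c', 5)]
```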
"""
from __future__ import print_function
from enum import Enum
# Marker for a temporary diagnostic statement
def diag(): pass
class OrderedDict(object):
"""
OrderedDict() -> new empty Ordered Dictionary.
"""
def __init__(self):
"""OrderedDict Constructor"""
self._root = None
self._aug_enabled = False
def __str__(self):
"""d.__str__() <==> str(d)"""
rv = 'OrderedDict:\n'
rv += ' _aug_enabled: %s\n' % self._aug_enabled
if self._root:
rv += (' _root:\n%s' % self._root.dump_subtree(self._root, 4))
else:
rv += ' _root: %s' % self._root
return(rv)
def __getitem__(self, key):
"""d.__getitem__(key) <==> d[key]
Finds and returns the value of the entry specified by key.
Raises:
KeyError - Entry for the specified key does not exist.
"""
node = self._root.find(key) if self._root else None
if not node:
raise KeyError('key of %s not found' % key)
return node.val
def __setitem__(self, key, val):
"""d.__setitem__(key, val) <==> d[key]=val
Creates a new entry with the given key and value, when an
entry with the given key doesn't already exist. When an
entry of the given key already exists, the value portion of
that entry is updated with the value given by val.
Raises:
ValueError - Augmented value support is enabled and
the new value has an invalid augmented value
(e.g. augmented value <= 0).
"""
red = OrderedDict.__Node.NodeColor.red
black = OrderedDict.__Node.NodeColor.black
# Does a node already exist for the given key?
if self._root:
node = self._root.find(key)
if node:
if self._aug_enabled:
# orig_aug_amt = node.aug_amt
node.aug_amt = OrderedDict._val_aug_amt(node.val)
node.aug_total_adjust(self._root, -node.aug_amt)
node.val = val
if self._aug_enabled:
node.aug_amt = OrderedDict._val_aug_amt(node.val)
node.aug_total_adjust(self._root,
OrderedDict._val_aug_amt(node.val))
node.aug_subtotal = node.aug_amt
node.aug_subtotal += (node.left.aug_subtotal
if node.left else 0)
node.aug_subtotal += (node.right.aug_subtotal
if node.right else 0)
return
# Create new node.
node = OrderedDict.__Node(key, val)
if self._aug_enabled:
node.aug_amt = OrderedDict._val_aug_amt(node.val)
node.aug_subtotal = node.aug_amt
# Add node to the leaf location specified by its key.
if not self._root:
node.color = black
self._root = node
return
else:
parent = self._root
while True:
assert parent.key != key
if key < parent.key:
if not parent.left:
break
else:
parent = parent.left
else:
if not parent.right:
break
else:
parent = parent.right
if key < parent.key:
parent.left = node
else:
parent.right = node
if self._aug_enabled:
node.aug_total_adjust(self._root, node.aug_amt)
node.color = red
# As needed re-balance tree.
while parent and parent.color == red:
grandparent = node.grandparent(self._root)
# A grandparent should exist, because there is a
# red parent and at a minimum there should be a
# black root node above the parent.
assert grandparent, ('Unexpected no grandparent')
if parent is grandparent.left:
uncle = grandparent.right
if uncle and uncle.color == red:
# Case 1: Both parent and uncle are red.
assert parent.color == red
parent.color = black
uncle.color = black
grandparent.color = red
node = grandparent
parent = node.parent(self._root)
else:
if node == parent.right:
node = parent
y = node.left_rotate()
parent = y
grandparent.left = y
parent.color = black
grandparent.color = red
tmp = grandparent.parent(self._root)
y = grandparent.right_rotate()
if tmp:
if grandparent.key < tmp.key:
tmp.left = y
else:
tmp.right = y
else:
self._root = y
else:
uncle = grandparent.left
if uncle and uncle.color == red:
# Case 1: Both parent and uncle are red.
assert parent.color == red
parent.color = black
uncle.color = black
grandparent.color = red
node = grandparent
parent = node.parent(self._root)
else:
if node == parent.left:
node = parent
y = node.right_rotate()
parent = y
grandparent.right = y
parent.color = black
grandparent.color = red
tmp = grandparent.parent(self._root)
y = grandparent.left_rotate()
if tmp:
if grandparent.key < tmp.key:
tmp.left = y
else:
tmp.right = y
else:
self._root = y
parent = node.parent(self._root)
if not parent or not parent.parent(self._root):
break
self._root.color = black
return
def __delitem__(self, key):
"""d.__delitem__(y) <==> del d[y]
Deletes the entry with the specified key.
Raises:
KeyError - Entry for the specified key does not exist.
"""
red = OrderedDict.__Node.NodeColor.red
black = OrderedDict.__Node.NodeColor.black
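# Guard against an empty dictionary, so that a KeyError is raised
# instead of an AttributeError on None.
node = self._root.find(key) if self._root else None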
if not node:
raise KeyError('key of %s not found' % key)
if not node.left or not node.right:
y = node
else:
y = node.prev(self._root)
x = y.left if y.left else y.right
y_parent = y.parent(self._root)
x_parent = y_parent
if self._aug_enabled:
node.aug_total_adjust(self._root, -node.aug_amt)
if not y_parent:
self._root = x
else:
if y is y_parent.left:
y_parent.left = x
else:
y_parent.right = x
if self._aug_enabled:
y_parent.aug_subtotal = y_parent.aug_amt
y_parent.aug_subtotal += (y_parent.left.aug_subtotal
if y_parent.left else 0)
y_parent.aug_subtotal += (y_parent.right.aug_subtotal
if y_parent.right else 0)
if y is not node:
node.key = y.key
node.val = y.val
if self._aug_enabled:
# y_parent.aug_total_adjust(self._root, -y.aug_amt)
if y_parent is not node:
parents = y_parent.parents(self._root)
while parents:
entry = parents.pop()
if entry is node:
break
entry.aug_subtotal -= y.aug_amt
node.aug_amt = y.aug_amt
node.aug_subtotal = node.aug_amt
node.aug_subtotal += node.left.aug_subtotal if node.left else 0
node.aug_subtotal += (node.right.aug_subtotal
if node.right else 0)
y_parent.aug_subtotal = y_parent.aug_amt
y_parent.aug_subtotal += (y_parent.left.aug_subtotal
if y_parent.left else 0)
y_parent.aug_subtotal += (y_parent.right.aug_subtotal
if y_parent.right else 0)
# If needed rebalance tree
if y.color == black:
while ((x is not self._root) and (x == None or x.color == black)):
if x is x_parent.left:
w = x_parent.right
if w.color == red:
w.color = black
x_parent.color = red
tmp1 = x_parent.parent(self._root)
tmp2 = x_parent.left_rotate()
if tmp1:
if tmp1.key > tmp2.key:
tmp1.left = tmp2
else:
tmp1.right = tmp2
else:
self._root = tmp2
w = x_parent.right
if ((w.right == None or w.right.color == black)
and (w.left == None or w.left.color == black)):
w.color = red
x = x_parent
x_parent = x.parent(self._root)
else:
if w.right == None or w.right.color == black:
w.left.color = black
w.color = red
tmp1 = w.parent(self._root)
tmp2 = w.right_rotate()
if tmp1:
if tmp1.key > tmp2.key:
tmp1.left = tmp2
else:
tmp1.right = tmp2
else:
self._root = tmp2
w = x_parent.right
w.color = x_parent.color
x_parent.color = black
w.right.color = black
tmp1 = x_parent.parent(self._root)
tmp2 = x_parent.left_rotate()
if tmp1:
if tmp1.key > tmp2.key:
tmp1.left = tmp2
else:
tmp1.right = tmp2
else:
self._root = tmp2
x = self._root
x_parent = None
else:
w = x_parent.left
if w.color == red:
w.color = black
x_parent.color = red
tmp1 = x_parent.parent(self._root)
tmp2 = x_parent.right_rotate()
if tmp1:
if tmp1.key > tmp2.key:
tmp1.left = tmp2
else:
tmp1.right = tmp2
else:
self._root = tmp2
w = x_parent.left
if ((w.right == None or w.right.color == black)
and (w.left == None or w.left.color == black)):
w.color = red
x = x_parent
x_parent = x.parent(self._root)
else:
if w.left == None or w.left.color == black:
w.right.color = black
w.color = red
tmp1 = w.parent(self._root)
tmp2 = w.left_rotate()
if tmp1:
if tmp1.key > tmp2.key:
tmp1.left = tmp2
else:
tmp1.right = tmp2
else:
self._root = tmp2
w = x_parent.left
w.color = x_parent.color
x_parent.color = black
w.left.color = black
tmp1 = x_parent.parent(self._root)
tmp2 = x_parent.right_rotate()
if tmp1:
if tmp1.key > tmp2.key:
tmp1.left = tmp2
else:
tmp1.right = tmp2
else:
self._root = tmp2
x = self._root
x_parent = None
if x:
x.color = black
return
def keys(self):
"""d.keys() -> list of d\'s keys
Keys are provided in the order determined by the <
operator on the key portion of each entry.
"""
if self._root == None:
return
current = self._root.leftmost()
yield current.key
current = current.next(self._root)
while current != None:
yield current.key
current = current.next(self._root)
return
def values(self):
"""d.values() -> list of d\'s values
Values are provided from entries in an order that is
determined by the < operator on the key portion of each entry.
"""
if self._root == None:
return
current = self._root.leftmost()
yield current.val
current = current.next(self._root)
while current != None:
yield current.val
current = current.next(self._root)
return
def items(self):
"""d.items() -> list of d\'s (key, value) tuples
The (key, value) tuples are returned in an order that is determined
by the < operator on the key portion of the entries.
"""
if self._root == None:
return
current = self._root.leftmost()
yield (current.key, current.val)
current = current.next(self._root)
while current != None:
yield (current.key, current.val)
current = current.next(self._root)
return
def find(self, key):
"""Find entry with specified key. Returns corresponding
(key, value) tuple of entry specified by key.
Raises:
KeyError - Entry for the specified key does not exist.
"""
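# Guard against an empty dictionary, so that a KeyError is raised
# instead of an AttributeError on None.
node = self._root.find(key) if self._root else None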
if not node:
raise KeyError('key of %s not found' % key)
return (node.key, node.val)
def first(self):
"""Returns (key, val) tuple of first entry (entry with lowest key).
Raises:
RuntimeError - First entry does not exist, because the
dictionary is empty.
"""
if not self._root:
raise RuntimeError('Empty Dictionary')
node = self._root
while node.left:
node = node.left
return (node.key, node.val)
def last(self):
"""Returns (key, val) tuple of last entry (entry with highest key).
Raises:
RuntimeError - Last entry does not exist, because the
dictionary is empty.
"""
if not self._root:
raise RuntimeError('Empty Dictionary')
node = self._root
while node.right:
node = node.right
return (node.key, node.val)
def next(self, key):
"""
Returns (key, val) tuple of next entry after the entry specified
by key.
Raises:
KeyError - An entry for the specified key does not exist.
RuntimeError - Next entry does not exist, because the entry
specified by key is currently the last entry.
"""
node = self._root.find(key) if self._root else None
if not node:
raise KeyError('key of %s not found' % key)
node = node.next(self._root)
if node:
return (node.key, node.val)
else:
raise RuntimeError('Key specifies last entry')
def prev(self, key):
"""
Returns (key, val) tuple of entry immediately before the
entry specified by key.
Raises:
KeyError - An entry for the specified key does not exist.
RuntimeError - Previous entry does not exist, because the entry
specified by key is currently the first entry.
"""
node = self._root.find(key) if self._root else None
if not node:
raise KeyError('key of %s not found' % key)
node = node.prev(self._root)
if node:
return (node.key, node.val)
else:
raise RuntimeError('Key specifies first entry')
def aug_enable(self):
"""Enables support for augmented values. On enable the augmented
value is obtained and stored from the value portion of each of the
existing entries.
Raises:
RuntimeError - if already enabled.
ValueError - if any of the existing entries produce an
invalid augmented value (e.g. augmented value
that is <= 0).
Time Complexity:
O(n) - where n is the current number of entries.
"""
if self._aug_enabled:
raise RuntimeError('Augmented value already enabled.')
if self._root:
self._root.aug_recalculate()
self._aug_enabled = True
def aug_disable(self):
"""Disables support for augmented values. On disable the augmented
value stored in each entry, plus any supporting values, are deleted.
Raises:
RuntimeError - Support for augmented values is already disabled.
Time Complexity:
O(n) - where n is the current number of entries.
"""
if not self._aug_enabled:
raise RuntimeError('Augmented value already disabled.')
if self._root:
self._root.aug_delete_info()
self._aug_enabled = False
def aug_enabled(self):
"""Returns True if support for augmented values is currently enabled,
else returns False.
"""
return self._aug_enabled
def aug_sum_before(self, key):
"""Returns sum of augmented values of all entries before the
entry specified by key.
Raises:
KeyError - Entry for the specified key does not exist.
RuntimeError - Support for augmented values is currently disabled.
"""
if not self._aug_enabled:
raise RuntimeError('Augmented value not currently enabled')
node = self._root.find(key) if self._root else None
if not node:
raise KeyError('key of %s not found' % key)
parents = node.parents(self._root)
r = node.left.aug_subtotal if node.left else 0
y = node
while y:
parent = parents.pop() if parents else None
if parent and y == parent.right:
r += (parent.left.aug_subtotal + parent.aug_amt
if parent.left else parent.aug_amt)
y = parent
return r
def aug_find(self, amt):
"""Returns the (key, value) tuple of the first entry where
the sum of the augmented values of that entry and all the entries
before it is > amt.
Raises:
RuntimeError - Support for augmented values is currently disabled.
ValueError - amt is < 0
ValueError - amt is >= sum of augmented values of all the entries.
"""
if not self._aug_enabled:
raise RuntimeError('Augmented value not currently enabled')
if amt < 0.0:
raise ValueError('Invalid aug_find amount of: %s' % amt)
node = self._root
if not node:
raise ValueError('All aug_find amount are invalid for an empty '
'dictionary')
val = self._root.aug_find(amt)
if val == (None, None):
raise ValueError('Amount of %s >= sum of augmented values '
'of all entries.' % amt)
return val
def diag_validate(self, diag_str=''):
"""Validates the overall state of the ordered dictionary.
Although an error should never be found, when an error is
detected an assertion failure is produced with a string
describing the error, plus the contents of diag_str. Callers
typically either use the default empty diag_str or provide
a diag_str whose contents may be useful in determining the
sequence that got the ordered dictionary into an invalid state.
"""
# Nothing to validate when there are no nodes.
if not self._root:
return
# Use passed in diagnostic string or create one from scratch.
# This string is only used by error messages.
if len(diag_str):
diag_str = '\n' + diag_str.rstrip() + '\n'
else:
diag_str += 'actual:\n'
diag_str += OrderedDict._str_indent(str(self), 2)
# Validate root node is colored black.
assert self._root.color == OrderedDict.__Node.NodeColor.black, (
'Unexpected root node color\n'
' root_node_color: %s\n'
' expected: %s\n'
' %s'
% (self._root.color, OrderedDict.__Node.NodeColor.black,
diag_str))
# Recursively validate each of the nodes.
self.__node_validate(self._root, diag_str)
# Validate black height is consistent from the root node.
# If consistent from root node, then it is also consistent
# in all subtrees.
black_height = self._root.black_height()
assert black_height != None, (
'Inconsistent black height\n'
' %s' % diag_str)
# -------- Implementation Private --------
@staticmethod
def _val_aug_amt(val):
if hasattr(val, 'augmented_val'):
return val.augmented_val()
else:
return 1
@staticmethod
def _str_indent(s, indent):
return('%s%s' % (' ' * indent, s.replace('\n', '\n' + ' ' * indent)))
class __Node:
class NodeColor(Enum):
red = 0
black = 1
def __init__(self, key, val):
self.key = key
self.val = val
self.left = None
self.right = None
self.color = None
def __str__(self):
rv = 'id: %#x\n' % id(self)
rv += 'key: %s val: %s\n' % (
(self.key.__str__(), self.val.__str__()))
rv += 'color: %s\n' % self.color
if hasattr(self, 'aug_amt'):
rv += 'aug_amt: %s\n' % self.aug_amt
if hasattr(self, 'aug_subtotal'):
rv += 'aug_subtotal: %s\n' % self.aug_subtotal
rv += 'left: %s\n' % (hex(id(self.left))
if self.left else self.left)
rv += 'right: %s' % (hex(id(self.right))
if self.right else self.right)
return rv
# @staticmethod
# def val_aug_amt(val):
# if hasattr(val, 'augmented_val'):
# return val.augmented_val()
# else:
# return 1
def parent(self, root):
if self is root:
return None
parent = root
while True:
if self.key < parent.key:
assert parent.left != None
if parent.left.key == self.key:
return parent
parent = parent.left
else:
assert parent.right != None
if (parent.right.key == self.key):
return parent
parent = parent.right
def parents(self, root):
assert root != None
rv = []
parent = root
while parent.key != self.key:
if (self.key < parent.key):
assert parent.left != None
rv.append(parent)
parent = parent.left
else:
assert parent.right != None
rv.append(parent)
parent = parent.right
return rv
def grandparent(self, root):
parent = self.parent(root)
if parent == None:
return None
return parent.parent(root)
def leftmost(self):
node = self
while (node.left):
node = node.left
return node
def last(self):
node = self
while (node.right):
node = node.right
return node
def next(self, root):
node = self
if node.right:
node = node.right
if node.left:
node = node.leftmost()
else:
parents = node.parents(root)
while parents:
parent = parents.pop()
if parent.left is node:
node = parent
break
node = parent
else:
node = None
return node
def prev(self, root):
node = self
if node.left:
node = node.left
if node.right:
node = node.last()
else:
parents = node.parents(root)
while parents:
parent = parents.pop()
if parent.right is node:
node = parent
break
node = parent
else:
node = None
return node
def find(self, key):
node = self
while node:
if (node.key == key):
return node
if (key < node.key):
node = node.left
else:
node = node.right
return None
def aug_sum_before(self, root):
total = (self.aug_subtotal
- (self.right.aug_subtotal if self.right else 0))
parents = self.parents(root)
prev_parent = self
while parents:
parent = parents.pop()
if parent.right is prev_parent:
total += parent.aug_subtotal - prev_parent.aug_subtotal
prev_parent = parent
return total
def aug_recalculate(self):
self.aug_amt = OrderedDict._val_aug_amt(self.val)
if self.left:
self.left.aug_recalculate()
if self.right:
self.right.aug_recalculate()
self.aug_subtotal = self.aug_amt
self.aug_subtotal += self.left.aug_subtotal if self.left else 0
self.aug_subtotal += self.right.aug_subtotal if self.right else 0
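def aug_delete_info(self):
# Remove the per-node augmented attributes, so that the hasattr()
# checks in left_rotate() and right_rotate() see support as disabled.
del self.aug_amt
del self.aug_subtotal
if self.left:
self.left.aug_delete_info()
if self.right:
self.right.aug_delete_info()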
# +----+ +---+
# |self| | y |
# +----+ +---+
# / \ / \
# / \ ----> / \
# a +---+ +----+ c
# | y | |self|
# +---+ +----+
# / \ / \
# / \ / \
# b c a b
#
# Returns: y
#
def left_rotate(self):
y = self.right
self.right = y.left
y.left = self
if hasattr(y, 'aug_amt'):
self.aug_subtotal = self.aug_amt
self.aug_subtotal += (self.left.aug_subtotal
if self.left else 0)
self.aug_subtotal += (self.right.aug_subtotal
if self.right else 0)
y.aug_subtotal = y.aug_amt
y.aug_subtotal += y.left.aug_subtotal if y.left else 0
y.aug_subtotal += y.right.aug_subtotal if y.right else 0
return y
# +----+ +---+
# |self| | y |
# +----+ +---+
# / \ / \
# / \ ----> / \
# +---+ c a +----+
# | y | |self|
# +---+ +----+
# / \ / \
# / \ / \
# a b b c
#
# Returns: y
#
def right_rotate(self):
y = self.left
self.left = y.right
y.right = self
if hasattr(self, 'aug_amt'):
self.aug_subtotal = self.aug_amt
self.aug_subtotal += (self.left.aug_subtotal
if self.left else 0)
self.aug_subtotal += (self.right.aug_subtotal
if self.right else 0)
y.aug_subtotal = y.aug_amt
y.aug_subtotal += y.left.aug_subtotal if y.left else 0
y.aug_subtotal += y.right.aug_subtotal if y.right else 0
return y
def aug_total_adjust(self, root, amt):
node = self
initial_aug_total = node.aug_sum_before(root)
prev = node.prev(root)
prev_node_aug_total = prev.aug_sum_before(root) if prev else 0
assert -amt <= (initial_aug_total - prev_node_aug_total), (
'amt: %s prev_node_aug_total: %s initial_aug_total: %s'
% (amt, prev_node_aug_total, initial_aug_total))
parents = self.parents(root)
while parents:
parent = parents.pop()
parent.aug_subtotal += amt
node = parent
def aug_find(self, amt):
r = self.left.aug_subtotal if self.left else 0
if (amt >= r) and (amt < r + self.aug_amt):
return (self.key, self.val)
elif amt < r:
return self.left.aug_find(amt)
else:
return (self.right.aug_find(amt - (r + self.aug_amt))
if self.right else (None, None))
def black_height(self):
red = self.NodeColor.red
black = self.NodeColor.black
# If both children are leaves return 1, because leaves are
# always considered black, plus 1 more if node is also black.
if self.left == None and self.right == None:
return 1 + (1 if self.color == black else 0)
black_height_left = (self.left.black_height()
if self.left else 1)
black_height_right = (self.right.black_height()
if self.right else 1)
# Return None if left and right black heights are not
# the same, as an indication of inconsistent black
# height from this node.
if black_height_left != black_height_right:
return None
# Return black height of either child plus 1 more if this
# node is also black.
return black_height_left + (1 if self.color == black else 0)
def dump_subtree(self, root, indent):
node = self
prefix = ' ' * indent
rv = '%s%s\n' % (prefix, node.__str__().replace('\n', '\n'
+ prefix))
parent = node.parent(root)
if parent:
rv += '%sparent: %#x\n' % (prefix, id(parent))
else:
rv += '%sparent: %s\n' % (prefix, parent)
if node.left:
rv += '%sleft:\n%s\n' % (prefix,
node.left.dump_subtree(root, indent + 2))
else:
rv += '%sleft: %s\n' % (prefix, node.left)
if node.right:
rv += '%sright:\n%s\n' % (prefix,
node.right.dump_subtree(root, indent + 2))
else:
rv += '%sright: %s\n' % (prefix, node.right)
rv = rv.rstrip()
return rv
def __node_validate(self, node, diag_str):
red = OrderedDict.__Node.NodeColor.red
black = OrderedDict.__Node.NodeColor.black
if node.color == red:
if node.left:
assert node.left.color == black, (
'node.key: %s node.color: %s\n'
'node.left.key: %s node.left.color: %s\n'
'%s'
% (node.key, node.color, node.left.key,
node.left.color, diag_str))
if node.right:
assert node.right.color == black, (
'node.key: %s node.color: %s\n'
'node.right.key: %s node.right.color: %s\n'
'%s'
% (node.key, node.color, node.right.key,
node.right.color, diag_str))
if node.left:
assert node.left.key < node.key, ('node.key: %s\n%s'
% (node.key, diag_str))
self.__node_validate(node.left, diag_str)
if node.right:
assert node.key < node.right.key, ('node.key: %s\n%s'
% (node.key, diag_str))
self.__node_validate(node.right, diag_str)
# -------- Tests --------
if __name__ == '__main__':
import itertools
import math
import numbers
import numpy
import random
import string
import sys
import time
import unittest
from collections import namedtuple
# Make assertRaisesRegex available in Python versions 2.7 through 3.1
# by aliasing it to assertRaisesRegexp.
assert bool(sys.version_info.major > 2
or ((sys.version_info.major == 2)
and (sys.version_info.minor >= 7))), (
'Need at least Python version 2.7, which introduced\n'
'unittest.TestCase.assertRaisesRegexp\n'
'sys.version_info.major: %s sys.version_info.minor: %s'
% (sys.version_info.major, sys.version_info.minor))
if ((sys.version_info.major < 3)
or (sys.version_info.major == 3 and sys.version_info.minor <= 1)):
unittest.TestCase.assertRaisesRegex = (
unittest.TestCase.assertRaisesRegexp)
class TestSupport(object):
# Type that describes a single expected value.
ExpectedEntry = namedtuple('ExpectedEntry', 'key val aug_amt')
# Extended integral and float types, with an augmented value
# determined from the subclass value. These types are mostly
# used in various tests to create a value with a non-default
# augmented amount. For example, AugIntMod100(1234) creates
# an integral value of 1234 and an augmented amount of 34.
class AugIntDiv10(int):
def augmented_val(self):
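# Floor division keeps the result integral on Python 3 as well,
# where / would yield a float such as 0.5 that truncates to 0.
return int(self // 10 if self // 10 != 0 else 3)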
class AugIntMod100(int):
def augmented_val(self):
if self >= 0:
return int(self % 100 if self % 100 != 0 else 12)
else:
return int(-self % 100 if -self % 100 != 0 else 12)
class AugFloatDiv100(float):
def augmented_val(self):
return float(self / 100 if self / 100 > 0.0 else 12.34)
@staticmethod
def expected_key_aug_start(key, expected):
expected_start = 0
for entry in expected:
if entry.key < key:
expected_start += entry.aug_amt
return expected_start
@staticmethod
def expected_aug_total(expected):
aug_total = 0
for entry in expected:
aug_total += entry.aug_amt
return aug_total
@staticmethod
def check_expected(actual, expected, diag_str=''):
# Only validate augmented values if actual currently
# has augmented values enabled.
validate_augmented = actual.aug_enabled()
# Produce a diagnostic string to be printed in cases
# where an unexpected condition is detected.
if len(diag_str):
diag_str = '\n' + diag_str.rstrip() + '\n'
diag_str += 'expected:\n'
for entry in expected:
diag_str += (' key: %s val: %s aug_amt: %s aug_start: %s\n'
% (entry.key, entry.val, entry.aug_amt,
TestSupport.expected_key_aug_start(
entry.key, expected)))
if validate_augmented:
aug_total = TestSupport.expected_aug_total(expected)
diag_str += 'aug_total: %s\n' % aug_total
diag_str += 'actual:\n'
diag_str += OrderedDict._str_indent(actual.__str__(), 2)
# Validate internal state is valid.
actual.diag_validate(diag_str)
# Check that actual correctly contains each of the
# expected entries.
for entry in expected:
key, val = actual.find(entry.key)
assert key == entry.key, ('Unexpected key of: %s\n'
'expected key of: %s\n%s' % (key, entry.key, diag_str))
assert val == entry.val, ('Unexpected val of: %s for key: %s\n'
'expected val of: %s\n%s'
% (val, entry.key, entry.val, diag_str))
if validate_augmented:
aug_start = actual.aug_sum_before(entry.key)
expected_aug_start = TestSupport.expected_key_aug_start(
entry.key, expected)
if isinstance(expected_aug_start, numbers.Integral):
assert aug_start == expected_aug_start, (
'Unexpected aug_start of: %s for key: %s\n'
'expected aug_start of: %s\n%s'
% (aug_start, entry.key, expected_aug_start,
diag_str))
else:
assert numpy.isclose(aug_start, expected_aug_start), (
'Unexpected aug_start of: %s for key: %s\n'
'expected aug_start of: %s\n%s'
% (aug_start, entry.key, expected_aug_start,
diag_str))
# Validate that expected key is at each aug value
expected_sorted_keys = sorted([entry.key for entry in expected])
if validate_augmented:
if isinstance(aug_total, numbers.Integral):
# All augmented values are integral
expected_aug_start = 0
for expected_key in expected_sorted_keys:
expected_entry = [val for val in expected if
val.key == expected_key][0]
# Validate key obtained at start of range.
aug_try = expected_aug_start
key = actual.aug_find(aug_try)[0]
assert key == expected_key, ('Unexpected key of: %s '
'at aug_try: %s\n'
'expected key of: %s\n%s'
% (key, aug_try, expected_key, diag_str))
# Validate key obtained at aug_start + 1.
if expected_entry.aug_amt >= 2:
aug_try = expected_aug_start + 1
key = actual.aug_find(aug_try)[0]
assert key == expected_key, ('Unexpected key of: '
'%s at aug_try: %s\n'
'expected key of: %s\n%s'
% (key, aug_try, expected_key, diag_str))
# Validate key obtained at aug_end.
if expected_entry.aug_amt >= 3:
aug_try = (expected_aug_start
+ expected_entry.aug_amt - 1)
key = actual.aug_find(aug_try)[0]
assert key == expected_key, ('Unexpected key of: '
'%s at aug_try: %s\n'
'expected key of: %s\n%s'
% (key, aug_try, expected_key, diag_str))
expected_aug_start += expected_entry.aug_amt
else:
# One or more augmented values are of type float
previous_key = None
expected_aug_start = 0.0
for expected_key in expected_sorted_keys:
expected_entry = [val for val in expected if
val.key == expected_key][0]
# Validate key obtained at start of range.
# Due to floating point rounding, key from previous
# entry is also valid.
aug_try = expected_aug_start
key = actual.aug_find(aug_try)[0]
expected_keys = [expected_key]
if previous_key is not None:
expected_keys.append(previous_key)
assert key in expected_keys, (
'Unexpected key of: %s '
'at aug_try: %s\n'
'expected key of: %s\n%s'
% (key, aug_try, expected_keys, diag_str))
# Validate key obtained at 10% of range.
aug_try = (expected_aug_start
+ expected_entry.aug_amt * 0.1)
key = actual.aug_find(aug_try)[0]
assert key == expected_key, (
'Unexpected key of: %s '
'at aug_try: %s\n'
'expected key of: %s\n%s'
% (key, aug_try, expected_key, diag_str))
previous_key = expected_key
expected_aug_start += float(expected_entry.aug_amt)
# Validate results of first, last, next, and prev.
expected_sorted_keys = sorted([entry.key for entry in expected])
expected_key = (expected_sorted_keys[0]
if len(expected_sorted_keys) > 0 else None)
try:
(key, val) = actual.first()
except RuntimeError:
key = None
assert key == expected_key, ('Unexpected first key\n'
'key: %s expected_key: %s\n%s'
% (key, expected_key, diag_str))
expected_key = (expected_sorted_keys[-1]
if len(expected_sorted_keys) > 0 else None)
try:
(key, val) = actual.last()
except RuntimeError:
key = None
assert key == expected_key, ('Unexpected last key\n'
'key: %s expected_key: %s\n%s'
% (key, expected_key, diag_str))
for i in range(len(expected_sorted_keys)):
expected_next = (expected_sorted_keys[i + 1] if
i < (len(expected_sorted_keys) - 1) else None)
try:
key_next = actual.next(expected_sorted_keys[i])[0]
except RuntimeError:
key_next = None
assert key_next == expected_next, ('Unexpected next key\n'
'i: %s expected_sorted_keys[i]: %s\n'
'key_next: %s\n'
'expected_next: %s\n'
'%s'
% (i, expected_sorted_keys[i], key_next, expected_next,
diag_str))
expected_prev = (expected_sorted_keys[i - 1] if
i > 0 else None)
try:
key_prev = actual.prev(expected_sorted_keys[i])[0]
except RuntimeError:
key_prev = None
assert key_prev == expected_prev, ('Unexpected prev key\n'
'i: %s expected_sorted_keys[i]: %s\n'
'key_prev: %s\n'
'expected_prev: %s\n'
'%s'
% (i, expected_sorted_keys[i], key_prev, expected_prev,
diag_str))
# Validate OrderedDict.keys(), OrderedDict.values(),
# and OrderedDict.items()
expected_dict = {x.key: x.val for x in expected}
expected_keys = expected_sorted_keys
expected_values = [expected_dict[x] for x in expected_keys]
expected_items = [(x, expected_dict[x]) for x in expected_keys]
keys_list = list(actual.keys())
assert keys_list == expected_keys, ('Unexpected result from '
'actual.keys()\n'
' actual.keys(): %s\n'
' expected_keys: %s\n'
' %s' % (keys_list, expected_keys, diag_str))
values_list = list(actual.values())
assert values_list == expected_values, ('Unexpected result from '
'actual.values()\n'
' actual.values(): %s\n'
' expected_values: %s\n'
' %s' % (values_list, expected_values, diag_str))
items_list = list(actual.items())
assert items_list == expected_items, ('Unexpected result from '
'actual.items()\n'
' actual.items(): %s\n'
' expected_items: %s\n'
' %s' % (items_list, expected_items, diag_str))
@staticmethod
def create_test_tree(num_nodes, seed=0):
node_creation_order = [
'node008_',
'node004_l',
'node012_r',
'node002_ll',
'node006_lr',
'node010_rl',
'node014_rr',
'node001_lll',
'node003_llr',
'node005_lrl',
'node007_lrr',
'node009_rll',
'node011_rlr',
'node013_rrl',
'node015_rrr',
]
assert num_nodes <= len(node_creation_order), (
'num_nodes: %s max_supported: %s'
% (num_nodes, len(node_creation_order)))
tree = OrderedDict()
expected = []
random.seed(seed)
for i, node_name in enumerate(node_creation_order):
if i >= num_nodes:
break
val = TestSupport.AugIntMod100(random.randint(0, 1000000))
tree[node_name] = val
expected_entry = TestSupport.ExpectedEntry(node_name, val,
val.augmented_val())
expected.append(expected_entry)
return tree, expected
@staticmethod
def perform_npairs_add(num_elements, num_add):
if num_add > num_elements:
raise ValueError('Num add > Num elements,\n'
' num_elements: %s num_add: %s'
% (num_elements, num_add))
# Create the list of elements from which the key and value of
# each add are taken.
elements = []
for i in range(num_elements):
elements.append(('key%03i' % i,
TestSupport.AugIntMod100(1234 + i)))
for p in itertools.permutations(range(num_elements), num_add):
tree = OrderedDict()
tree.aug_enable()
expected = []
for j in p:
key = elements[j][0]
val = elements[j][1]
diag_str = ('permutation: %s\n'
'add key: %s val: %s aug_amt: %s'
% (p, key, val, val.augmented_val()))
tree[key] = val
expected.append(TestSupport.ExpectedEntry(key, val,
val.augmented_val()))
TestSupport.check_expected(tree, expected, diag_str)
@staticmethod
def perform_npairs_del(tree_size, num_del):
for p in itertools.permutations(range(tree_size), num_del):
tree, orig_expected = TestSupport.create_test_tree(
tree_size)
tree.aug_enable()
expected = orig_expected
for j in p:
if j < 0 or j >= tree_size:
raise ValueError('Invalid index j: %s for\n'
' tree_size: %s'
% (j, tree_size))
# Delete from the tree and the expected list
# the jth element in the original tree.
diag_str = ('permutation: %s\n'
'delete: %s' % (p, j))
element_key = orig_expected[j].key
del tree[element_key]
expected = [x for x in expected if x.key != element_key]
TestSupport.check_expected(tree, expected, diag_str)
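# Illustrative sketch (not part of the original test suite): a linear-time
# reference for the lookup that check_expected() exercises through
# OrderedDict.aug_find().  It assumes, based only on the checks above, that
# an entry with augmented amount aug_amt covers the half-open range
# [aug_start, aug_start + aug_amt), where aug_start is the sum of the
# augmented amounts of all entries with smaller keys.  The helper name
# reference_aug_find and its list of (key, aug_amt) tuples are hypothetical.
def reference_aug_find(entries, aug):
    """Return the key whose augmented range contains aug, or None."""
    aug_start = 0
    for key, aug_amt in sorted(entries):
        if aug_start <= aug < aug_start + aug_amt:
            return key
        aug_start += aug_amt
    # aug lies outside the total augmented range.
    return None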
# ---- Unit Tests ----
class UnitTests(unittest.TestCase, TestSupport):
# Maintain a list that provides the preferred order of test
# execution. The UnitTests class itself does not use this list,
# and the unit tests should pass in any order, but other test
# suites use it to execute these tests in the specified order.
# Although the order is arbitrary, the simpler tests are mostly
# listed first.
preferred_execution_order = []
# Test empty augmented binary search tree.
preferred_execution_order.append('test_empty')
def test_empty(self):
tree = OrderedDict()
TestSupport.check_expected(tree, ( ))
# Validate KeyError for non-existent key
with self.assertRaisesRegex(KeyError, r'key of boat not found'):
val = tree['boat']
# Test augmented binary search tree with a single entry.
preferred_execution_order.append('test_single_entry')
def test_single_entry(self):
tree = OrderedDict()
tree[1653] = 35
TestSupport.check_expected(tree, (
self.ExpectedEntry(1653, 35, 1), ))
# Validate KeyError for non-existent key
with self.assertRaisesRegex(KeyError, r'key of 5 not found'):
val = tree[5]
def test_foo(self):
print()
tree = OrderedDict()
tree.aug_enable()
tree['a'] = self.AugIntMod100(103)
tree['b'] = self.AugIntMod100(105)
tree['c'] = self.AugIntMod100(102)
tree['d'] = self.AugIntMod100(104)
tree['e'] = self.AugIntMod100(103)
tree['f'] = self.AugIntMod100(106)
tree['g'] = self.AugIntMod100(102)
for aug in range(28):
print('---- %s: ' % aug, tree.aug_find(aug))
# Test string as key
preferred_execution_order.append('test_key_str')
def test_key_str(self):
tree = OrderedDict()
tree['boat'] = 6592
TestSupport.check_expected(tree, (
self.ExpectedEntry('boat', 6592, 1), ))
# Test string as key and value
preferred_execution_order.append('test_key_val_str')
def test_key_val_str(self):
tree = OrderedDict()
tree['car'] = 'blue'
TestSupport.check_expected(tree, (
self.ExpectedEntry('car', 'blue', 1), ))
# Test left child
preferred_execution_order.append('test_left_child')
def test_left_child(self):
tree = OrderedDict()
tree.aug_enable()
tree['house'] = self.AugIntDiv10(52) # Root node
tree['car'] = self.AugIntDiv10(78) # Left child of root node
TestSupport.check_expected(tree, (
self.ExpectedEntry('house', 52, 5),
self.ExpectedEntry('car', 78, 7),
))
# Test right child
preferred_execution_order.append('test_right_child')
def test_right_child(self):
tree = OrderedDict()
tree.aug_enable()
tree['berry'] = self.AugIntMod100(745) # Root node
tree['car'] = self.AugIntMod100(3548) # Right child of root node
TestSupport.check_expected(tree, (
self.ExpectedEntry('berry', 745, 45),
self.ExpectedEntry('car', 3548, 48),
))
# Test 3 nodes balanced - root with two children
preferred_execution_order.append('test_three_nodes_balanced')
def test_three_nodes_balanced(self):
tree = OrderedDict()
tree.aug_enable()
tree['green'] = self.AugIntMod100(8830) # Root node
tree['blue'] = self.AugIntMod100(4812) # Left child of root node
tree['red'] = self.AugIntMod100(382) # Right child of root node
TestSupport.check_expected(tree, (
self.ExpectedEntry('blue', 4812, 12),
self.ExpectedEntry('green', 8830, 30),
self.ExpectedEntry('red', 382, 82),
))
# Test Left Rotate
#
# Addition of node with key of 'tree' initially creates the
# following binary search tree:
#
# 'flower'
# \
# 'rock'
# \
# 'tree'
#
# A left rotate is used to change this to:
#
# 'rock'
# / \
# 'flower' 'tree'
#
preferred_execution_order.append('test_three_nodes_left_rotate')
def test_three_nodes_left_rotate(self):
tree = OrderedDict()
tree.aug_enable()
tree['flower'] = self.AugIntDiv10(6234)
tree['rock'] = self.AugIntDiv10(5921)
tree['tree'] = self.AugIntDiv10(9123)
TestSupport.check_expected(tree, (
self.ExpectedEntry('flower', 6234, 623),
self.ExpectedEntry('rock', 5921, 592),
self.ExpectedEntry('tree', 9123, 912),
))
# Test Right Rotate
#
# Addition of node with key of 'ball' initially creates the
# following binary search tree:
#
# 'jack'
# /
# 'block'
# /
# 'ball'
#
# A right rotate is used to change this to:
#
# 'block'
# / \
# 'ball' 'jack'
#
preferred_execution_order.append('test_three_nodes_right_rotate')
def test_three_nodes_right_rotate(self):
tree = OrderedDict()
tree.aug_enable()
tree['jack'] = self.AugIntMod100(2984)
tree['block'] = self.AugIntMod100(1523)
tree['ball'] = self.AugIntMod100(5421)
TestSupport.check_expected(tree, (
self.ExpectedEntry('ball', 5421, 21),
self.ExpectedEntry('block', 1523, 23),
self.ExpectedEntry('jack', 2984, 84),
))
# Test add 4 nodes
# A set of 4 tests, each of which starts by creating
# the following tree:
#
# 'node004_root'
# / \
# 'node002_left' 'node006_right'
#
# Each of the 4 tests then adds one of the following elements,
# each of which is attached to one of the 4 empty left or
# right links of the 2 initial leaf nodes:
#
# 'node001_left_left' # Added to left of 'node002_left'
# 'node003_left_right' # Added to right of 'node002_left'
# 'node005_right_left' # Added to left of 'node006_right'
# 'node007_right_right' # Added to right of 'node006_right'
#
preferred_execution_order.append('test_add_four_nodes001')
def test_add_four_nodes001(self):
tree = OrderedDict()
tree.aug_enable()
tree['node004_root'] = self.AugIntMod100(6582)
tree['node002_left'] = self.AugIntMod100(8723)
tree['node006_right'] = self.AugIntMod100(8938)
tree['node001_left_left'] = self.AugIntMod100(783)
TestSupport.check_expected(tree, (
self.ExpectedEntry('node004_root', 6582, 82),
self.ExpectedEntry('node002_left', 8723, 23),
self.ExpectedEntry('node006_right', 8938, 38),
self.ExpectedEntry('node001_left_left', 783, 83),
))
preferred_execution_order.append('test_add_four_nodes002')
def test_add_four_nodes002(self):
tree = OrderedDict()
tree.aug_enable()
tree['node004_root'] = self.AugIntMod100(6582)
tree['node002_left'] = self.AugIntMod100(8723)
tree['node006_right'] = self.AugIntMod100(8938)
tree['node003_left_right'] = self.AugIntMod100(386)
TestSupport.check_expected(tree, (
self.ExpectedEntry('node004_root', 6582, 82),
self.ExpectedEntry('node002_left', 8723, 23),
self.ExpectedEntry('node006_right', 8938, 38),
self.ExpectedEntry('node003_left_right', 386, 86),
))
preferred_execution_order.append('test_add_four_nodes003')
def test_add_four_nodes003(self):
tree = OrderedDict()
tree.aug_enable()
tree['node004_root'] = self.AugIntMod100(6582)
tree['node002_left'] = self.AugIntMod100(8723)
tree['node006_right'] = self.AugIntMod100(8938)
tree['node005_right_left'] = self.AugIntMod100(821)
TestSupport.check_expected(tree, (
self.ExpectedEntry('node004_root', 6582, 82),
self.ExpectedEntry('node002_left', 8723, 23),
self.ExpectedEntry('node006_right', 8938, 38),
self.ExpectedEntry('node005_right_left', 821, 21),
))
preferred_execution_order.append('test_add_four_nodes004')
def test_add_four_nodes004(self):
tree = OrderedDict()
tree.aug_enable()
tree['node004_root'] = self.AugIntMod100(6582)
tree['node002_left'] = self.AugIntMod100(8723)
tree['node006_right'] = self.AugIntMod100(8938)
tree['node007_right_right'] = self.AugIntMod100(275)
TestSupport.check_expected(tree, (
self.ExpectedEntry('node004_root', 6582, 82),
self.ExpectedEntry('node002_left', 8723, 23),
self.ExpectedEntry('node006_right', 8938, 38),
self.ExpectedEntry('node007_right_right', 275, 75),
))
# Test add tree size == 5, npairs == 5
#
# For all permutations of 5 elements taken from a set of
# 5 elements (5! == 120 permutations), creates an empty tree
# and then adds the 5 elements in the order specified by the
# permutation, validating the tree after each add.
preferred_execution_order.append('test_add_tree5_npairs5')
def test_add_tree5_npairs5(self):
TestSupport.perform_npairs_add(5, 5)
# Test Mixed Types
#
# Validates a tree where not all of the element values
# are of the same type.
preferred_execution_order.append('test_val_mixed_types')
def test_val_mixed_types(self):
tree = OrderedDict()
tree.aug_enable()
tree['Marigold'] = 6582
tree['Lily'] = self.AugIntMod100(6458)
tree['Poinsettia'] = 65.23
tree['Daisy'] = self.AugIntDiv10(63487)
tree['Violet'] = 'Blue'
TestSupport.check_expected(tree, (
self.ExpectedEntry('Marigold', 6582, 1),
self.ExpectedEntry('Lily', 6458, 58),
self.ExpectedEntry('Poinsettia', 65.23, 1),
self.ExpectedEntry('Daisy', 63487, 6348),
self.ExpectedEntry('Violet', 'Blue', 1),
))
# Test Augmented Float
#
# Basic tree with entries that have augmented values of
# type float.
preferred_execution_order.append('test_augmented_float')
def test_augmented_float(self):
tree = OrderedDict()
tree.aug_enable()
tree['dog'] = self.AugFloatDiv100(753.926)
tree['cat'] = self.AugFloatDiv100(2.383)
tree['fish'] = self.AugFloatDiv100(1842.74)
tree['goat'] = self.AugFloatDiv100(83.832)
TestSupport.check_expected(tree, (
self.ExpectedEntry('dog', 753.926, 7.53926),
self.ExpectedEntry('cat', 2.383, 0.02383),
self.ExpectedEntry('fish', 1842.74, 18.4274),
self.ExpectedEntry('goat', 83.832, 0.83832),
))
# Test Augmented Float and Int
#
# Basic tree with entries that have augmented values of
# type float and int.
preferred_execution_order.append('test_augmented_float_and_int')
def test_augmented_float_and_int(self):
tree = OrderedDict()
tree.aug_enable()
tree['Blue'] = self.AugIntMod100(753)
tree['Green'] = self.AugFloatDiv100(3.482)
tree['Brown'] = self.AugIntDiv10(803)
tree['Yellow'] = self.AugFloatDiv100(2830.83)
TestSupport.check_expected(tree, (
self.ExpectedEntry('Blue', 753, 53),
self.ExpectedEntry('Green', 3.482, 0.03482),
self.ExpectedEntry('Brown', 803, 80),
self.ExpectedEntry('Yellow', 2830.83, 28.3083),
))
# Test Replace Existing
#
# Creates a 5 node tree and then 1-by-1 changes the value
# of each node and validates the contents of the entire
# tree.
preferred_execution_order.append('test_replace_existing')
def test_replace_existing(self):
tree = OrderedDict()
tree.aug_enable()
tree['Kazoo'] = self.AugIntMod100(3417)
tree['Balloon'] = self.AugIntMod100(2512)
tree['Tiddlywinks'] = self.AugIntMod100(8412)
tree['Boat'] = self.AugIntMod100(3712)
tree['Top'] = self.AugIntMod100(8732)
TestSupport.check_expected(tree, (
self.ExpectedEntry('Balloon', 2512, 12),
self.ExpectedEntry('Boat', 3712, 12),
self.ExpectedEntry('Kazoo', 3417, 17),
self.ExpectedEntry('Tiddlywinks', 8412, 12),
self.ExpectedEntry('Top', 8732, 32),
))
tree['Top'] = self.AugIntMod100(523)
TestSupport.check_expected(tree, (
self.ExpectedEntry('Balloon', 2512, 12),
self.ExpectedEntry('Boat', 3712, 12),
self.ExpectedEntry('Kazoo', 3417, 17),
self.ExpectedEntry('Tiddlywinks', 8412, 12),
self.ExpectedEntry('Top', 523, 23),
))
tree['Boat'] = self.AugIntDiv10(756)
TestSupport.check_expected(tree, (
self.ExpectedEntry('Balloon', 2512, 12),
self.ExpectedEntry('Boat', 756, 75),
self.ExpectedEntry('Kazoo', 3417, 17),
self.ExpectedEntry('Tiddlywinks', 8412, 12),
self.ExpectedEntry('Top', 523, 23),
))
tree['Tiddlywinks'] = 845
TestSupport.check_expected(tree, (
self.ExpectedEntry('Balloon', 2512, 12),
self.ExpectedEntry('Boat', 756, 75),
self.ExpectedEntry('Kazoo', 3417, 17),
self.ExpectedEntry('Tiddlywinks', 845, 1),
self.ExpectedEntry('Top', 523, 23),
))
tree['Balloon'] = 756.3
TestSupport.check_expected(tree, (
self.ExpectedEntry('Balloon', 756.3, 1),
self.ExpectedEntry('Boat', 756, 75),
self.ExpectedEntry('Kazoo', 3417, 17),
self.ExpectedEntry('Tiddlywinks', 845, 1),
self.ExpectedEntry('Top', 523, 23),
))
tree['Kazoo'] = self.AugIntMod100(5300)
TestSupport.check_expected(tree, (
self.ExpectedEntry('Balloon', 756.3, 1),
self.ExpectedEntry('Boat', 756, 75),
self.ExpectedEntry('Kazoo', 5300, 12),
self.ExpectedEntry('Tiddlywinks', 845, 1),
self.ExpectedEntry('Top', 523, 23),
))
# Test delete single node
#
# Creates a single node tree and then deletes that
# single node.
preferred_execution_order.append('test_del_single_entry')
def test_del_single_entry(self):
tree = OrderedDict()
tree['house'] = 'vacation'
TestSupport.check_expected(tree, (
self.ExpectedEntry('house', 'vacation', 1), ))
del tree['house']
TestSupport.check_expected(tree, ( ))
preferred_execution_order.append('test_del_right_leaf')
def test_del_right_leaf(self):
tree = OrderedDict()
tree.aug_enable()
tree['Delaware'] = self.AugIntMod100(73658)
tree['Ohio'] = self.AugIntMod100(87238)
del tree['Ohio']
expected = (
self.ExpectedEntry('Delaware', 73658, 58),
)
TestSupport.check_expected(tree, expected)
preferred_execution_order.append('test_del_left_leaf')
def test_del_left_leaf(self):
tree = OrderedDict()
tree.aug_enable()
tree['Michigan'] = self.AugIntMod100(8627)
tree['California'] = self.AugIntMod100(7835)
del tree['California']
expected = (
self.ExpectedEntry('Michigan', 8627, 27),
)
TestSupport.check_expected(tree, expected)
# Test delete tree size == 5, npairs == 5
# For all permutations of 5 elements from a tree size of
# 5 elements (5! == 120 permutations), creates a tree
# of 5 elements and then removes the 5 elements in the
# order specified by the permutation.
preferred_execution_order.append('test_del_tree5_npairs5')
def test_del_tree5_npairs5(self):
TestSupport.perform_npairs_del(5, 5)
# Test delete tree size == 7, npairs == 4
# For all permutations of 4 elements from a tree size of
# 7 elements (7!/(7 - 4)! == 840 permutations), creates a
# tree of 7 elements and then removes the 4 elements in the
# order specified by the permutation.
preferred_execution_order.append('test_del_tree7_npairs4')
def test_del_tree7_npairs4(self):
TestSupport.perform_npairs_del(7, 4)
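# Illustrative sketch (not part of the original suite): the restructuring
# described in the "Test Left Rotate" and "Test Right Rotate" comments of
# UnitTests, shown on a bare binary-search-tree node.  SketchNode,
# sketch_rotate_left, and sketch_rotate_right are hypothetical names.  The
# OrderedDict implementation is not shown in this part of the file, and a
# red-black tree with augmented values would also have to update node
# colors and augmented sums, which this sketch leaves out.
class SketchNode:
    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right


def sketch_rotate_left(node):
    """Lift node's right child into node's place; return the new root."""
    pivot = node.right
    node.right = pivot.left  # Pivot's left subtree stays between the two keys.
    pivot.left = node
    return pivot


def sketch_rotate_right(node):
    """Lift node's left child into node's place; return the new root."""
    pivot = node.left
    node.left = pivot.right  # Pivot's right subtree stays between the two keys.
    pivot.right = node
    return pivot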
class FunctionalTests(unittest.TestCase):
def test_del_tree7_npairs7(self):
TestSupport.perform_npairs_del(7, 7)
def test_del_tree15_npairs3(self):
TestSupport.perform_npairs_del(15, 3)
def test_prandom_small(self):
PRandomTests.perform_prandom(pass_start=0, num_passes=10000,
banner_every=1000, elements_max=20)
def test_prandom_medium(self):
PRandomTests.perform_prandom(pass_start=0, num_passes=10000,
banner_every=1000, elements_max=100,
add_ratio=0.6, del_ratio=0.2)
def test_prandom_large(self):
PRandomTests.perform_prandom(pass_start=0, num_passes=1000,
banner_every=100, elements_max=1000, ops_max=10000,
add_ratio=0.8, del_ratio=0.1)
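# Illustrative sketch (not part of the original suite): the number of
# orderings that the npairs tests walk through can be confirmed with
# itertools.permutations, the same iterator that perform_npairs_add() and
# perform_npairs_del() loop over.  The helper name
# count_npairs_permutations is hypothetical.
import itertools


def count_npairs_permutations(tree_size, num_ops):
    """Count the ordered selections of num_ops elements out of tree_size."""
    return sum(1 for _ in itertools.permutations(range(tree_size), num_ops))


# count_npairs_permutations(5, 5) -> 120    (test_del_tree5_npairs5)
# count_npairs_permutations(7, 4) -> 840    (test_del_tree7_npairs4)
# count_npairs_permutations(7, 7) -> 5040   (test_del_tree7_npairs7)
# count_npairs_permutations(15, 3) -> 2730  (test_del_tree15_npairs3)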
# ---- Pseudo-Random Tests ----
class PRandomTests(unittest.TestCase):
key_length = 8
def test_prandom_0_1000(self):
PRandomTests.perform_prandom(pass_start=0, num_passes=1000)
@staticmethod
def perform_prandom(pass_start=0, num_passes=1, verbose=False,
elements_max=100, ops_max=500, add_ratio=0.4, del_ratio=0.4,
aug_enable_ratio=0.7, banner_every=0):
if verbose:
print()
print('pass_start: ', pass_start)
print('num_passes: ', num_passes)
print('elements_max: ', elements_max)
print('ops_max: ', ops_max)
print('add_ratio: ', add_ratio)
print('del_ratio: ', del_ratio)
if (add_ratio + del_ratio) > 1.0:
raise ValueError('Add plus del ratio > 1.0\n'
' add_ratio: %s del_ratio: %s'
% (add_ratio, del_ratio))
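# Keep the caller-supplied limit so that each pass draws its own
# maximum from the full range, rather than from the value already
# reduced by the previous pass.
elements_limit = elements_max
for pass_num in range(pass_start, pass_start + num_passes):
if verbose or ((banner_every > 0)
and (pass_num != pass_start)
and (((pass_num - pass_start) % banner_every) == 0)):
if pass_num - pass_start == banner_every:
print()
print(' pass: %s of %s' % (pass_num, num_passes))
random.seed(pass_num)
elements_max = random.randint(5, elements_limit)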
operations_num = random.randint(3, ops_max)
if verbose:
print(' elements_max: ', elements_max)
print(' operations_num: ', operations_num)
tree = OrderedDict()
expected = []
# Initially populate half of the maximum number of elements.
for i in range(elements_max // 2):
key = PRandomTests.random_key(
[_.key for _ in expected])
val = PRandomTests.random_val()
aug_amt = (val.augmented_val() if
hasattr(val, 'augmented_val') else 1)
if verbose:
print('initial add entry key: %s val: %s aug_amt: %s'
% (key, val, aug_amt))
tree[key] = val
expected.append(TestSupport.ExpectedEntry(
key, val, aug_amt))
# Determine if and when support for augmented values should
# be enabled.
aug_enable_at_op = (random.randint(0, operations_num - 1)
if random.random() < aug_enable_ratio else None)
# Perform the operations.
for op_num in range(operations_num):
# Enable augmented value support if at operation
# number where it should be enabled.
if op_num == aug_enable_at_op:
tree.aug_enable()
if verbose:
print(' aug_enable')
# When the tree is empty always perform an add operation.
force_add = not expected
if verbose:
print(' op_num: %s ' % op_num, end="")
op_choice = random.random()
if force_add or op_choice < add_ratio:
# Add Entry
# Nop if max entries already exist.
if len(expected) < elements_max:
key = PRandomTests.random_key(
[_.key for _ in expected])
val = PRandomTests.random_val()
aug_amt = (val.augmented_val() if
hasattr(val, 'augmented_val') else 1)
if verbose:
print('add entry key: %s val: %s aug_amt: %s'
% (key, val, aug_amt))
tree[key] = val
expected.append(TestSupport.ExpectedEntry(
key, val, aug_amt))
else:
if verbose:
print('NOP')
elif op_choice < add_ratio + del_ratio:
# Delete an entry
idx = random.randint(0, len(expected) - 1)
if verbose:
print('Delete entry at idx: %s key: %s'
% (idx, expected[idx].key))
del tree[expected[idx].key]
del expected[idx]
else:
# Update value in existing entry
idx = random.randint(0, len(expected) - 1)
key = expected[idx].key
val = PRandomTests.random_val()
aug_amt = (val.augmented_val() if
hasattr(val, 'augmented_val') else 1)
if verbose:
print('update entry at idx: ', idx)
tree[key] = val
expected[idx] = TestSupport.ExpectedEntry(
key, val, aug_amt)
TestSupport.check_expected(tree, expected,
'pass: %s' % pass_num)
@staticmethod
def random_key(exclude=()):
key = ''.join(random.choice(string.digits
+ string.ascii_letters)
for _ in range(PRandomTests.key_length))
while key in exclude:
key = ''.join(random.choice(string.digits
+ string.ascii_letters)
for _ in range(PRandomTests.key_length))
return key
@staticmethod
def random_val():
val = TestSupport.AugIntMod100(random.randint(-1000000, 1000000))
return val
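# Illustrative sketch (not part of the original suite): how a single draw
# from random.random() is mapped to an operation in perform_prandom().
# Draws in [0, add_ratio) add an entry, draws in
# [add_ratio, add_ratio + del_ratio) delete one, and the remainder update
# an existing entry; an empty tree always forces an add.  The helper name
# sketch_choose_op and its string return values are hypothetical.
def sketch_choose_op(op_choice, add_ratio=0.4, del_ratio=0.4,
                     force_add=False):
    """Return 'add', 'delete', or 'update' for a draw in [0.0, 1.0)."""
    if force_add or op_choice < add_ratio:
        return 'add'
    if op_choice < add_ratio + del_ratio:
        return 'delete'
    return 'update'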
class PRandomContinousTest(unittest.TestCase):
# Defaults of PRandomTests.perform_prandom() that the traits below
# override: elements_max=100, ops_max=500, add_ratio=0.4, del_ratio=0.4
def test_continous(self):
traits = [
{'elements_max':20},
{'elements_max':100},
{'elements_max':1000, 'add_ratio':0.7, 'del_ratio':0.2},
]
pass_num = 0
passes_each_batch = 1000
batches_per_trait = 10
batches_this_trait = 0
trait_idx = 0
print()
while True:
if batches_this_trait == 0:
print('traits: ', traits[trait_idx])
print(' passes: %s' % pass_num, end="")
batch_start_time = time.perf_counter()
PRandomTests.perform_prandom(pass_start=pass_num,
num_passes=passes_each_batch,
**traits[trait_idx])
batch_end_time = time.perf_counter()
batch_time = batch_end_time - batch_start_time
if (batch_time > 0.0):
print(' - %s passes/sec: %.2f'
% (pass_num + passes_each_batch - 1,
passes_each_batch / batch_time))
else:
print()
batches_this_trait += 1
if batches_this_trait == batches_per_trait:
trait_idx = (trait_idx + 1) % len(traits)
batches_this_trait = 0
pass_num += passes_each_batch
# ---- Smoke Test Suite ----
SmokeTestSuite = unittest.TestSuite()
SmokeTestSuite.addTests(map(UnitTests, UnitTests.preferred_execution_order))
SmokeTestSuite.addTests(unittest.TestLoader().loadTestsFromTestCase(
PRandomTests))
# ---- PreSubmit Test Suite ----
PreSubmitTestSuite = unittest.TestSuite()
PreSubmitTestSuite.addTests(SmokeTestSuite)
PreSubmitTestSuite.addTests(unittest.TestLoader().loadTestsFromTestCase(
FunctionalTests))
# Run the command-line specified tests. If no tests specified
# on the command-line, then by default the tests within the
# smoke test suite are executed.
unittest.main(defaultTest='SmokeTestSuite')