# Source: GitHub gist urban-1/aa8b9b66f20d6d53850985ac01b42243
# (forked from yiakwy/btree.py; last active February 25, 2017).
"""
Full credit to the original authors of the gist:
https://gist.github.com/yiakwy/8380ee05a0bdbf6c291e
"""
import bisect
import itertools
import operator
import random
import unittest
import re
import time
import string
def rand_str(N=None):
    """Return a random string of uppercase ASCII letters and digits.

    :param N: length of the result; when omitted, a length between 15 and
        20 (inclusive) is drawn from ``random``.
    """
    alphabet = string.ascii_uppercase + string.digits
    length = random.randint(15, 20) if N is None else N
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
class _BNode(object):
    """Node of a classic B-tree.

    ``contents`` holds the node's sorted items. A leaf has no children; a
    branch has exactly ``len(contents) + 1`` children, with ``contents[i]``
    separating ``children[i]`` from ``children[i + 1]``. ``tree`` is the
    owning tree, whose ``order``, ``BRANCH`` and ``_root`` attributes the
    rebalancing code reads and updates.
    """
    __slots__ = ["tree", "contents", "children"]

    def __init__(self, tree, contents=None, children=None):
        self.tree = tree
        self.contents = contents or []
        self.children = children or []
        if self.children:
            assert len(self.contents) + 1 == len(self.children), \
                "one more child than data item required"

    def __repr__(self):
        # getattr default: subclasses may leave the "children" slot unset.
        name = "Branch" if getattr(self, "children", 0) else "Leaf"
        return "<%s %s>" % (name, ", ".join(map(str, self.contents)))

    def lateral(self, parent, parent_index, dest, dest_index):
        """Move one item from this node into its adjacent sibling *dest*.

        The item rotates through *parent*: the separator between the two
        siblings moves down into *dest* and this node's nearest item moves up
        to replace it; for branches the nearest child follows along.
        *parent_index* and *dest_index* are the positions of this node and of
        *dest* in ``parent.children``.
        """
        if parent_index > dest_index:
            # dest is the left neighbour: hand over our first item/child
            dest.contents.append(parent.contents[dest_index])
            parent.contents[dest_index] = self.contents.pop(0)
            if self.children:
                dest.children.append(self.children.pop(0))
        else:
            # dest is the right neighbour: hand over our last item/child
            dest.contents.insert(0, parent.contents[parent_index])
            parent.contents[parent_index] = self.contents.pop()
            if self.children:
                dest.children.insert(0, self.children.pop())

    def shrink(self, ancestors):
        """Fix a node holding more than ``tree.order`` items.

        First tries to hand one item to a sibling with room (left, then
        right). Otherwise splits the node and pushes the median into the
        parent, creating a new root when this node was the root, and
        recurses if the parent overflows in turn.

        :param ancestors: ``(branch, child_index)`` pairs from the root down
            to this node's parent; consumed (popped) on the way up.
        """
        parent = None
        if ancestors:
            parent, parent_index = ancestors.pop()
            # try to lend to the left neighbouring sibling
            if parent_index:
                left_sib = parent.children[parent_index - 1]
                if len(left_sib.contents) < self.tree.order:
                    self.lateral(
                        parent, parent_index, left_sib, parent_index - 1)
                    return
            # try the right neighbour
            if parent_index + 1 < len(parent.children):
                right_sib = parent.children[parent_index + 1]
                if len(right_sib.contents) < self.tree.order:
                    self.lateral(
                        parent, parent_index, right_sib, parent_index + 1)
                    return

        sibling, push = self.split()
        if parent is None:
            # splitting the root: grow the tree by one level
            parent, parent_index = self.tree.BRANCH(
                self.tree, children=[self]), 0
            self.tree._root = parent

        # pass the median up to the parent
        parent.contents.insert(parent_index, push)
        parent.children.insert(parent_index + 1, sibling)
        if len(parent.contents) > parent.tree.order:
            parent.shrink(ancestors)

    def grow(self, ancestors):
        """Fix a node holding fewer than ``tree.order // 2`` items.

        Borrows one item from a sibling that can spare it (right, then
        left). Otherwise merges with a sibling (left first) together with
        their separator and recurses to the parent, which lost an item. An
        emptied root is replaced by its only remaining child. If this node
        is its parent's only child (a separator-less parent), nothing is
        merged and the parent is handled the same way.

        :param ancestors: ``(branch, child_index)`` pairs from the root down
            to this node's parent; consumed (popped) on the way up.
        """
        parent, parent_index = ancestors.pop()
        minimum = self.tree.order // 2
        left_sib = right_sib = None

        # try to borrow from the right sibling
        if parent_index + 1 < len(parent.children):
            right_sib = parent.children[parent_index + 1]
            if len(right_sib.contents) > minimum:
                right_sib.lateral(parent, parent_index + 1, self, parent_index)
                return

        # try to borrow from the left sibling
        if parent_index:
            left_sib = parent.children[parent_index - 1]
            if len(left_sib.contents) > minimum:
                left_sib.lateral(parent, parent_index - 1, self, parent_index)
                return

        # consolidate with a sibling - try left first
        if left_sib is not None:
            left_sib.contents.append(parent.contents[parent_index - 1])
            left_sib.contents.extend(self.contents)
            if self.children:
                left_sib.children.extend(self.children)
            parent.contents.pop(parent_index - 1)
            parent.children.pop(parent_index)
        elif right_sib is not None:
            self.contents.append(parent.contents[parent_index])
            self.contents.extend(right_sib.contents)
            if self.children:
                self.children.extend(right_sib.children)
            parent.contents.pop(parent_index)
            parent.children.pop(parent_index + 1)
        # else: only child of a separator-less parent -- nothing to merge

        if len(parent.contents) < minimum:
            if ancestors:
                # parent is not the root
                parent.grow(ancestors)
            elif not parent.contents:
                # parent is the root and is now empty: its sole child takes over
                self.tree._root = parent.children[0]

    def split(self):
        """Split this node in two around its middle item.

        This node keeps the lower half; returns ``(sibling, median)`` where
        *sibling* (same type) holds the upper half and *median* is the item
        that must move up into the parent.
        """
        center = len(self.contents) // 2
        median = self.contents[center]
        sibling = type(self)(
            self.tree,
            self.contents[center + 1:],
            self.children[center + 1:])
        self.contents = self.contents[:center]
        self.children = self.children[:center + 1]
        return sibling, median

    def insert(self, index, item, ancestors):
        """Insert *item* at *index* of this leaf, splitting on overflow."""
        self.contents.insert(index, item)
        if len(self.contents) > self.tree.order:
            self.shrink(ancestors)

    def remove(self, index, ancestors):
        """Remove ``contents[index]`` from this node, rebalancing on underflow.

        In a branch the item is overwritten by its in-order successor when the
        successor's leaf can spare an item, otherwise by its in-order
        predecessor; that leaf entry is then removed instead.

        :param ancestors: ``(branch, child_index)`` pairs from the root down
            to this node's parent; extended and consumed during the call.
        """
        minimum = self.tree.order // 2

        if self.children:
            # try promoting from the right subtree first,
            # but only if it won't have to resize
            additional_ancestors = [(self, index + 1)]
            descendent = self.children[index + 1]
            while descendent.children:
                additional_ancestors.append((descendent, 0))
                descendent = descendent.children[0]
            if len(descendent.contents) > minimum:
                ancestors.extend(additional_ancestors)
                self.contents[index] = descendent.contents[0]
                descendent.remove(0, ancestors)
                return

            # fall back to the left child
            additional_ancestors = [(self, index)]
            descendent = self.children[index]
            while descendent.children:
                additional_ancestors.append(
                    (descendent, len(descendent.children) - 1))
                descendent = descendent.children[-1]
            ancestors.extend(additional_ancestors)
            self.contents[index] = descendent.contents[-1]
            # remove the predecessor from its leaf: the last *item* slot
            descendent.remove(len(descendent.contents) - 1, ancestors)
        else:
            self.contents.pop(index)
            if len(self.contents) < minimum and ancestors:
                self.grow(ancestors)
class _BPlusLeaf(_BNode):
    """Leaf of a B+ tree: parallel lists ``contents`` (keys) and ``data``.

    Leaves are chained left to right through ``next``. The inherited
    ``children`` slot is never assigned, so ``hasattr(leaf, "children")`` is
    False -- tree code relies on that to tell leaves from branches.
    """
    # "tree" and "contents" are inherited slots of _BNode; re-declaring them
    # would only add shadowing descriptors.
    __slots__ = ["data", "next"]

    def __init__(self, tree, contents=None, data=None, next=None):
        self.tree = tree
        self.contents = contents or []
        self.data = data or []
        self.next = next
        assert len(self.contents) == len(self.data), "one data per key"

    def insert(self, index, key, data, ancestors):
        """Insert *key* -> *data* at *index*, splitting on overflow."""
        self.contents.insert(index, key)
        self.data.insert(index, data)
        if len(self.contents) > self.tree.order:
            self.shrink(ancestors)

    def lateral(self, parent, parent_index, dest, dest_index):
        """Move one entry into the adjacent sibling leaf *dest*.

        Unlike branches, the entry moves directly; the parent's separator is
        then reset to the first key of whichever leaf is on the right.
        """
        if parent_index > dest_index:
            # dest is the left neighbour: give away our first entry
            dest.contents.append(self.contents.pop(0))
            dest.data.append(self.data.pop(0))
            parent.contents[dest_index] = self.contents[0]
        else:
            # dest is the right neighbour: give away our last entry
            dest.contents.insert(0, self.contents.pop())
            dest.data.insert(0, self.data.pop())
            parent.contents[parent_index] = dest.contents[0]

    def split(self):
        """Split this leaf in two and link the new right half into the chain.

        Returns ``(sibling, separator)``; the separator is a copy of the
        sibling's first key (the entry itself stays in the leaf).
        """
        center = len(self.contents) // 2
        sibling = type(self)(
            self.tree,
            self.contents[center:],
            self.data[center:],
            self.next)
        self.contents = self.contents[:center]
        self.data = self.data[:center]
        self.next = sibling
        return sibling, sibling.contents[0]

    def remove(self, index, ancestors):
        """Delete the entry at *index*, rebalancing the tree on underflow.

        :param index: position of the entry in this leaf. An index past the
            end designates the first entry of the next leaf (what a
            ``bisect_left`` descent yields when the key starts the next leaf).
        :param ancestors: ``(branch, child_index)`` pairs from the root down
            to this leaf's parent; consumed while rebalancing.
        :raises ValueError: if *index* is past the end of the last leaf.
        """
        leaf = self
        if index >= len(leaf.contents):
            if leaf.next is None:
                raise ValueError("leaf index %d out of range" % index)
            # the supplied ancestry describes this leaf, not the next one
            leaf, index = leaf.next, 0
            ancestors = leaf._ancestry()
        leaf.contents.pop(index)
        leaf.data.pop(index)
        if ancestors and len(leaf.contents) < self.tree.order // 2:
            leaf.grow(ancestors)

    def _ancestry(self):
        """Return the ``(branch, child_index)`` pairs from the root to this leaf."""
        key = self.contents[0]

        def search(node, path):
            if node is self:
                return path
            if not hasattr(node, "children"):
                return None
            # separators bound inclusively, so every child between these
            # two positions may hold a leaf starting with *key*
            low = bisect.bisect_left(node.contents, key)
            high = bisect.bisect_right(node.contents, key)
            for child_index in range(low, high + 1):
                path.append((node, child_index))
                if search(node.children[child_index], path) is not None:
                    return path
                path.pop()
            return None

        found = search(self.tree._root, [])
        return [] if found is None else found

    def grow(self, ancestors):
        """Fix a leaf holding fewer than ``tree.order // 2`` entries.

        Borrows an entry from a sibling leaf that can spare one (right, then
        left). Otherwise merges with a sibling (left first), unlinks the
        absorbed leaf from the ``next`` chain, drops the separator and child
        from the parent and rebalances the parent (via ``_BNode.grow``) or,
        when the parent is an emptied root, promotes its only child.
        """
        minimum = self.tree.order // 2
        parent, parent_index = ancestors.pop()
        left_sib = right_sib = None

        # try borrowing from a neighbor - try right first
        if parent_index + 1 < len(parent.children):
            right_sib = parent.children[parent_index + 1]
            if len(right_sib.contents) > minimum:
                right_sib.lateral(parent, parent_index + 1, self, parent_index)
                return

        # fallback to left
        if parent_index:
            left_sib = parent.children[parent_index - 1]
            if len(left_sib.contents) > minimum:
                left_sib.lateral(parent, parent_index - 1, self, parent_index)
                return

        # join with a neighbor - try left first
        if left_sib is not None:
            left_sib.contents.extend(self.contents)
            left_sib.data.extend(self.data)
            left_sib.next = self.next
            parent.contents.pop(parent_index - 1)
            parent.children.pop(parent_index)
        elif right_sib is not None:
            self.contents.extend(right_sib.contents)
            self.data.extend(right_sib.data)
            self.next = right_sib.next
            parent.contents.pop(parent_index)
            parent.children.pop(parent_index + 1)
        # else: only child of a separator-less parent -- nothing to merge

        if ancestors:
            if len(parent.contents) < minimum:
                parent.grow(ancestors)
        elif not parent.contents:
            # parent is the root and has no separators left
            self.tree._root = parent.children[0]
class BTree(object):
    """Classic B-tree of sortable items.

    ``order`` is the maximum number of items per node; removals rebalance
    nodes that drop below ``order // 2`` items.
    """
    BRANCH = LEAF = _BNode

    # Match operators understood by _path_to_range:
    OP_EQ = 1  # equals
    OP_RANGE = 2  # item <= key <= item2
    OP_SW = 6  # key.startswith(item)

    def __init__(self, order):
        self.order = order
        self._root = self._bottom = self.LEAF(self)

    def _path_to(self, item):
        """Return ``(node, index)`` pairs from the root towards *item*.

        Descends with ``bisect_left``; stops at the first branch whose
        ``contents[index]`` equals *item*, otherwise ends at a leaf. The last
        pair's index is where *item* is, or would be inserted.
        """
        current = self._root
        ancestry = []

        while getattr(current, "children", None):
            index = bisect.bisect_left(current.contents, item)
            ancestry.append((current, index))
            if index < len(current.contents) \
                    and current.contents[index] == item:
                return ancestry
            current = current.children[index]

        index = bisect.bisect_left(current.contents, item)
        ancestry.append((current, index))
        return ancestry

    def _path_to_range(self, item, item2=None, operator=OP_EQ):
        """Like ``_path_to`` but the early stop at a branch uses *operator*.

        *operator* is one of ``OP_EQ``, ``OP_RANGE`` (with upper bound
        *item2*) or ``OP_SW`` (keys must support ``startswith``).
        """
        current = self._root
        ancestry = []

        while getattr(current, "children", None):
            index = bisect.bisect_left(current.contents, item)
            ancestry.append((current, index))
            if index < len(current.contents):
                candidate = current.contents[index]
                if (operator == BTree.OP_EQ and candidate == item) \
                        or (operator == BTree.OP_RANGE
                            and item <= candidate <= item2) \
                        or (operator == BTree.OP_SW
                            and candidate.startswith(item)):
                    return ancestry
            current = current.children[index]

        index = bisect.bisect_left(current.contents, item)
        ancestry.append((current, index))
        return ancestry

    def _present(self, item, ancestors):
        """Tell whether the last step of a ``_path_to`` result holds *item*."""
        last, index = ancestors[-1]
        return index < len(last.contents) and last.contents[index] == item

    def insert(self, item):
        """Insert *item* (duplicates allowed) into a leaf."""
        ancestors = self._path_to(item)
        node, index = ancestors[-1]
        # an equal item may have been met in a branch: keep descending
        while getattr(node, "children", None):
            node = node.children[index]
            index = bisect.bisect_left(node.contents, item)
            ancestors.append((node, index))
        node, index = ancestors.pop()
        node.insert(index, item, ancestors)

    def remove(self, item):
        """Remove one occurrence of *item*.

        :raises ValueError: if *item* is not in the tree.
        """
        ancestors = self._path_to(item)

        if self._present(item, ancestors):
            node, index = ancestors.pop()
            node.remove(index, ancestors)
        else:
            raise ValueError("%r not in %s" % (item, self.__class__.__name__))

    def __contains__(self, item):
        return self._present(item, self._path_to(item))

    def __iter__(self):
        """Yield every item in sorted (in-order) order."""
        def _recurse(node):
            if node.children:
                for child, item in zip(node.children, node.contents):
                    for child_item in _recurse(child):
                        yield child_item
                    yield item
                for child_item in _recurse(node.children[-1]):
                    yield child_item
            else:
                for item in node.contents:
                    yield item

        for item in _recurse(self._root):
            yield item

    def __repr__(self):
        def recurse(node, accum, depth):
            accum.append((" " * depth) + repr(node))
            for child in getattr(node, "children", []):
                recurse(child, accum, depth + 1)

        accum = []
        recurse(self._root, accum, 0)
        return "\n".join(accum)

    @classmethod
    def bulkload(cls, items, order):
        """Build a tree bottom-up from *items*, presumably already sorted."""
        tree = object.__new__(cls)
        tree.order = order

        leaves = tree._build_bulkloaded_leaves(items)
        tree._build_bulkloaded_branches(leaves)

        return tree

    def _build_bulkloaded_leaves(self, items):
        """Pack *items* into leaves of at most ``order`` items.

        The item between two consecutive leaves becomes their separator and
        lives only in a branch. A short last leaf is rebalanced with its
        left neighbour. Returns ``(leaves, separators)``.
        """
        minimum = self.order // 2
        leaves, seps = [[]], []

        for item in items:
            if len(leaves[-1]) < self.order:
                leaves[-1].append(item)
            else:
                seps.append(item)
                leaves.append([])

        if len(leaves[-1]) < minimum and seps:
            last_two = leaves[-2] + [seps.pop()] + leaves[-1]
            leaves[-2] = last_two[:minimum]
            leaves[-1] = last_two[minimum + 1:]
            seps.append(last_two[minimum])

        return [self.LEAF(self, contents=node) for node in leaves], seps

    def _build_bulkloaded_branches(self, leaves_and_seps):
        """Stack branch levels over ``(leaves, seps)`` and set the root.

        Takes one tuple argument (the ``_build_bulkloaded_leaves`` result);
        it is unpacked in the body because tuple parameters are a syntax
        error on Python 3.
        """
        leaves, seps = leaves_and_seps
        minimum = self.order // 2
        levels = [leaves]

        # a root may hold at most ``order`` separators
        while len(seps) > self.order:
            items, nodes, seps = seps, [[]], []

            for item in items:
                if len(nodes[-1]) < self.order:
                    nodes[-1].append(item)
                else:
                    seps.append(item)
                    nodes.append([])

            if len(nodes[-1]) < minimum and seps:
                last_two = nodes[-2] + [seps.pop()] + nodes[-1]
                nodes[-2] = last_two[:minimum]
                nodes[-1] = last_two[minimum + 1:]
                seps.append(last_two[minimum])

            offset = 0
            for i, node in enumerate(nodes):
                children = levels[-1][offset:offset + len(node) + 1]
                nodes[i] = self.BRANCH(self, contents=node, children=children)
                offset += len(node) + 1

            levels.append(nodes)

        if seps:
            self._root = self.BRANCH(self, contents=seps, children=levels[-1])
        else:
            # a single node needs no separator-less branch above it
            self._root = levels[-1][0]
class BPlusTree(BTree):
    """B+ tree mapping keys to values; duplicate keys are allowed.

    Entries live only in the leaves, which are chained through ``next`` so
    ordered scans walk the leaf level without climbing back up.
    """
    LEAF = _BPlusLeaf

    def _get(self, key, key2=None, startswith=None):
        """Yield the values whose keys match, in key order.

        :param key: exact key, lower range bound, or prefix.
        :param key2: when given (and *startswith* is falsy), match keys ``k``
            with ``key <= k <= key2``.
        :param startswith: when truthy, match keys starting with the prefix
            *key* (keys must support ``startswith``).
        """
        if startswith:
            op = BTree.OP_SW
        elif key2 is not None:
            op = BTree.OP_RANGE
        else:
            op = BTree.OP_EQ

        def matches(candidate):
            if op == BTree.OP_SW:
                return candidate.startswith(key)
            if op == BTree.OP_RANGE:
                return key <= candidate <= key2
            return candidate == key

        node, index = self._path_to(key, key2, op)[-1]
        while node is not None:
            if index >= len(node.contents):
                # end of this leaf (or landed past it): continue on the next
                node, index = node.next, 0
                continue
            if not matches(node.contents[index]):
                return
            yield node.data[index]
            index += 1

    def _path_to(self, item, item2=None, operator=BTree.OP_EQ):
        """Extend ``_path_to_range`` down to the leaf where *item* would sit."""
        path = super(BPlusTree, self)._path_to_range(item, item2, operator)
        node, index = path[-1]
        while hasattr(node, "children"):
            node = node.children[index]
            index = bisect.bisect_left(node.contents, item)
            path.append((node, index))
        return path

    def _path_to_entry(self, key):
        """Locate the leftmost stored entry whose key equals *key*.

        Returns ``(ancestors, leaf, index)``, where ``ancestors`` are the
        ``(branch, child_index)`` pairs from the root down to the leaf's
        parent, or ``None`` when the key is absent. Separators bound their
        children inclusively (an equal key may sit on either side), so every
        child between ``bisect_left`` and ``bisect_right`` is searched.
        """
        def search(node, ancestors):
            if not hasattr(node, "children"):
                index = bisect.bisect_left(node.contents, key)
                if index < len(node.contents) and node.contents[index] == key:
                    return ancestors, node, index
                return None
            low = bisect.bisect_left(node.contents, key)
            high = bisect.bisect_right(node.contents, key)
            for child_index in range(low, high + 1):
                ancestors.append((node, child_index))
                found = search(node.children[child_index], ancestors)
                if found is not None:
                    return found
                ancestors.pop()
            return None

        return search(self._root, [])

    def get(self, key, key2=None, startswith=None, default=None):
        """Return the first matching value (see ``_get``), else *default*."""
        return next(self._get(key, key2, startswith), default)

    def getlist(self, key, key2=None, startswith=None, default=None):
        """Return all matching values (see ``_get``) as a list.

        *default* is accepted for symmetry with ``get`` and is not used: no
        match yields an empty list.
        """
        return list(self._get(key, key2, startswith))

    def insert(self, key, data):
        """Store *data* under *key*; existing entries are never overwritten."""
        path = self._path_to(key)
        node, index = path.pop()
        node.insert(index, key, data, path)

    def remove(self, key):
        """Remove one entry stored under *key* (the leftmost one).

        :raises ValueError: if *key* is not in the tree, as ``BTree.remove``.
        """
        found = self._path_to_entry(key)
        if found is None:
            raise ValueError("%r not in %s" % (key, self.__class__.__name__))
        path, node, index = found
        node.remove(index, path)

    __getitem__ = get
    __setitem__ = insert
    __delitem__ = remove

    def __contains__(self, key):
        for _ in self._get(key):
            return True
        return False

    def iteritems(self):
        """Yield ``(key, value)`` pairs in key order by walking the leaf chain."""
        node = self._root
        while hasattr(node, "children"):
            node = node.children[0]

        while node is not None:
            for pair in zip(node.contents, node.data):
                yield pair
            node = node.next

    def iterkeys(self):
        return (key for key, _ in self.iteritems())

    def itervalues(self):
        return (value for _, value in self.iteritems())

    __iter__ = iterkeys

    def items(self):
        return list(self.iteritems())

    def keys(self):
        return list(self.iterkeys())

    def values(self):
        return list(self.itervalues())

    def _build_bulkloaded_leaves(self, items):
        """Pack ``(key, value)`` *items* into chained leaves.

        Each new leaf's first pair also supplies the separator before it. A
        short last leaf is rebalanced with its left neighbour, and the
        separator before it is replaced (not added) to match its new first
        key. Returns ``(leaves, separator_keys)``.
        """
        minimum = self.order // 2
        leaves, seps = [[]], []

        for item in items:
            if len(leaves[-1]) >= self.order:
                seps.append(item)
                leaves.append([])
            leaves[-1].append(item)

        if len(leaves[-1]) < minimum and seps:
            last_two = leaves[-2] + leaves[-1]
            leaves[-2] = last_two[:minimum]
            leaves[-1] = last_two[minimum:]
            seps[-1] = last_two[minimum]

        leaves = [self.LEAF(
            self,
            contents=[p[0] for p in pairs],
            data=[p[1] for p in pairs])
            for pairs in leaves]
        for left, right in zip(leaves, leaves[1:]):
            left.next = right

        return leaves, [s[0] for s in seps]

    def getregex(self, regex):
        """
        Get all node values whose key matches a regex
        """
        for k, v in self.iteritems():
            if regex.search(k):
                yield v
class BTreeTests(unittest.TestCase):
    """Functional tests for BTree (portable across Python 2 and 3)."""

    def test_additions(self):
        bt = BTree(20)
        l = list(range(2000))
        for i, item in enumerate(l):
            bt.insert(item)
            self.assertEqual(list(bt), l[:i + 1])

    def test_bulkloads(self):
        bt = BTree.bulkload(range(2000), 20)
        self.assertEqual(list(bt), list(range(2000)))

    def test_removals(self):
        bt = BTree(20)
        l = list(range(2000))
        # an explicit loop: map() is lazy on Python 3 and would insert nothing
        for item in l:
            bt.insert(item)
        rand = l[:]
        random.shuffle(rand)
        while l:
            self.assertEqual(list(bt), l)
            rem = rand.pop()
            l.remove(rem)
            bt.remove(rem)
        self.assertEqual(list(bt), l)

    def test_insert_regression(self):
        bt = BTree.bulkload(range(2000), 50)

        for _ in range(100000):
            bt.insert(random.randrange(2000))
class BPlusTreeTests(unittest.TestCase):
    """Functional tests and rough timing comparisons for BPlusTree."""

    @classmethod
    def setUpClass(cls):
        # Large shared trees used by the timing comparisons below.
        cls.strbt = BPlusTree(20)
        for _ in range(200000):
            cls.strbt.insert(rand_str(), random.randint(1000000, 2000000))

        cls.intbt = BPlusTree(20)
        for _ in range(200000):
            cls.intbt.insert(random.randint(0, 200000),
                             random.randint(1000000, 2000000))

        cls.floatbt = BPlusTree(20)
        for _ in range(200000):
            cls.floatbt.insert(random.uniform(0.1, 100),
                               random.randint(1000000, 2000000))

    def test_additions_sorted(self):
        bt = BPlusTree(20)
        l = list(range(2000))

        for item in l:
            bt.insert(item, str(item))

        for item in l:
            self.assertEqual(str(item), bt[item])

        self.assertEqual(l, list(bt))

    def test_additions_random(self):
        bt = BPlusTree(20)
        l = list(range(2000))
        random.shuffle(l)

        for item in l:
            bt.insert(item, str(item))

        for item in l:
            self.assertEqual(str(item), bt[item])

        self.assertEqual(list(range(2000)), list(bt))

    def test_bulkload(self):
        pairs = list(zip(range(2000), map(str, range(2000))))
        bt = BPlusTree.bulkload(pairs, 20)

        self.assertEqual(list(bt), list(range(2000)))

        self.assertEqual(list(bt.iteritems()), pairs)

    def test_get_string_startswith_old(self):
        s = time.time()
        # ... Slow way (23.5ms)
        done = False
        for k, v in BPlusTreeTests.strbt.iteritems():
            if k.startswith("B1"):
                done = True
            elif done:
                break
        e = time.time()
        print("Get String Old: %.2f" % ((e - s) * 1000))

    def test_get_string_startswith_new(self):
        """
        Test string startswith retrieval
        """
        s = time.time()
        # ... Fast way (0.25ms)
        BPlusTreeTests.strbt.getlist("B1", startswith=True)
        e = time.time()
        print("Get String New: %.2f" % ((e - s) * 1000))

    def test_get_string_regex(self):
        s = time.time()
        # Get all that start with B followed by number, contain A and
        # end with D! (120 ms)
        regex = re.compile("^B[0-9].*A.*D$")
        list(BPlusTreeTests.strbt.getregex(regex))
        e = time.time()
        print("Get String Regex: %.2f" % ((e - s) * 1000))

    def test_get_int_in_range_old_close_to_root(self):
        s = time.time()
        # ... Slow way (4.3ms)
        done = False
        for k, v in BPlusTreeTests.intbt.iteritems():
            if k > 100 and k < 200:
                done = True
            elif done:
                break
        e = time.time()
        print("Get Int Old (Close to ROOT): %.2f" % ((e - s) * 1000))

    def test_get_int_in_range_new_close_to_root(self):
        s = time.time()
        BPlusTreeTests.intbt.getlist(100, key2=200)
        e = time.time()
        print("Get Int New (Close to ROOT): %.2f" % ((e - s) * 1000))

    def test_get_int_in_range_old(self):
        s = time.time()
        # ... Slow way (4.3ms)
        done = False
        for k, v in BPlusTreeTests.intbt.iteritems():
            if k > 20100 and k < 20200:
                done = True
            elif done:
                break
        e = time.time()
        print("Get Int Old: %.2f" % ((e - s) * 1000))

    def test_get_int_in_range_new(self):
        s = time.time()
        BPlusTreeTests.intbt.getlist(20100, key2=20200)
        e = time.time()
        print("Get Int New: %.2f" % ((e - s) * 1000))

    def test_get_float_in_range_new(self):
        s = time.time()
        BPlusTreeTests.floatbt.getlist(.5, key2=.6)
        e = time.time()
        print("Get Float: %.2f" % ((e - s) * 1000))
if __name__ == '__main__':
    # Executed as a script: hand control to the unittest runner.
    unittest.main()
# (end of gist)