Skip to content

Instantly share code, notes, and snippets.

@malthe
Forked from gaubert/btree.py
Created March 2, 2012 22:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save malthe/1962054 to your computer and use it in GitHub Desktop.
Save malthe/1962054 to your computer and use it in GitHub Desktop.
A pure-python B+ tree implementation (interpreter versions 2.6 thru 3.2)
# Based on code by Travis Parker
import bisect
import operator
try:
xrange
except NameError:
xrange = range
try:
from itertools import izip as zip, imap as map
except ImportError:
pass
class Node(object):
__slots__ = "tree", "contents", "children"
def __init__(self, tree, contents=None, children=None):
self.tree = tree
self.contents = contents or []
self.children = children or []
if self.children:
assert len(self.contents) + 1 == len(self.children), \
"one more child than data item required"
def __repr__(self):
name = getattr(self, "children", 0) and "Branch" or "Leaf"
return "<%s %s>" % (name, ", ".join(map(str, self.contents)))
def lateral(self, parent, parent_index, dest, dest_index):
if parent_index > dest_index:
dest.contents.append(parent.contents[dest_index])
parent.contents[dest_index] = self.contents.pop(0)
if self.children:
dest.children.append(self.children.pop(0))
else:
dest.contents.insert(0, parent.contents[parent_index])
parent.contents[parent_index] = self.contents.pop()
if self.children:
dest.children.insert(0, self.children.pop())
def shrink(self, ancestors):
parent = None
if ancestors:
parent, parent_index = ancestors.pop()
# try to lend to the left neighboring sibling
if parent_index:
left_sib = parent.children[parent_index - 1]
if len(left_sib.contents) < self.tree.order:
self.lateral(
parent, parent_index, left_sib, parent_index - 1)
return
# try the right neighbor
if parent_index + 1 < len(parent.children):
right_sib = parent.children[parent_index + 1]
if len(right_sib.contents) < self.tree.order:
self.lateral(
parent, parent_index, right_sib, parent_index + 1)
return
# center = len(self.contents) // 2
sibling, push = self.split()
if not parent:
parent, parent_index = Node(
self.tree, children=[self]), 0
self.tree._root = parent
# pass the median up to the parent
parent.contents.insert(parent_index, push)
parent.children.insert(parent_index + 1, sibling)
if len(parent.contents) > parent.tree.order:
parent.shrink(ancestors)
def grow(self, ancestors):
parent, parent_index = ancestors.pop()
minimum = self.tree.order // 2
left_sib = right_sib = None
# try to borrow from the right sibling
if parent_index + 1 < len(parent.children):
right_sib = parent.children[parent_index + 1]
if len(right_sib.contents) > minimum:
right_sib.lateral(parent, parent_index + 1, self, parent_index)
return
# try to borrow from the left sibling
if parent_index:
left_sib = parent.children[parent_index - 1]
if len(left_sib.contents) > minimum:
left_sib.lateral(parent, parent_index - 1, self, parent_index)
return
# consolidate with a sibling - try left first
if left_sib:
left_sib.contents.append(parent.contents[parent_index - 1])
left_sib.contents.extend(self.contents)
if self.children:
left_sib.children.extend(self.children)
parent.contents.pop(parent_index - 1)
parent.children.pop(parent_index)
else:
self.contents.append(parent.contents[parent_index])
self.contents.extend(right_sib.contents)
if self.children:
self.children.extend(right_sib.children)
parent.contents.pop(parent_index)
parent.children.pop(parent_index + 1)
if len(parent.contents) < minimum:
if ancestors:
# parent is not the root
parent.grow(ancestors)
elif not parent.contents:
# parent is root, and its now empty
self.tree._root = left_sib or self
def split(self):
center = len(self.contents) // 2
median = self.contents[center]
sibling = type(self)(
self.tree,
self.contents[center + 1:],
self.children[center + 1:])
self.contents = self.contents[:center]
self.children = self.children[:center + 1]
return sibling, median
def insert(self, index, item, ancestors):
self.contents.insert(index, item)
if len(self.contents) > self.tree.order:
self.shrink(ancestors)
def remove(self, index, ancestors):
minimum = self.tree.order // 2
if self.children:
# try promoting from the right subtree first,
# but only if it won't have to resize
additional_ancestors = [(self, index + 1)]
descendent = self.children[index + 1]
while descendent.children:
additional_ancestors.append((descendent, 0))
descendent = descendent.children[0]
if len(descendent.contents) > minimum:
ancestors.extend(additional_ancestors)
self.contents[index] = descendent.contents[0]
descendent.remove(0, ancestors)
return
# fall back to the left child
additional_ancestors = [(self, index)]
descendent = self.children[index]
while descendent.children:
additional_ancestors.append(
(descendent, len(descendent.children) - 1))
descendent = descendent.children[-1]
ancestors.extend(additional_ancestors)
self.contents[index] = descendent.contents[-1]
descendent.remove(len(descendent.children) - 1, ancestors)
else:
self.contents.pop(index)
if len(self.contents) < minimum and ancestors:
self.grow(ancestors)
class Leaf(object):
__slots__ = "tree", "contents", "data", "next"
def __init__(self, tree, contents=None, data=None, next=None):
self.tree = tree
self.contents = contents or []
self.data = data or []
self.next = next
assert len(self.contents) == len(self.data), "one data per key"
def __repr__(self):
name = getattr(self, "children", 0) and "Branch" or "Leaf"
return "<%s %s>" % (name, ", ".join(map(str, self.contents)))
def insert(self, index, key, data, ancestors):
self.contents.insert(index, key)
self.data.insert(index, data)
if len(self.contents) > self.tree.order:
self.shrink(ancestors)
def lateral(self, parent, parent_index, dest, dest_index):
if parent_index > dest_index:
dest.contents.append(self.contents.pop(0))
dest.data.append(self.data.pop(0))
parent.contents[dest_index] = self.contents[0]
else:
dest.contents.insert(0, self.contents.pop())
dest.data.insert(0, self.data.pop())
parent.contents[parent_index] = dest.contents[0]
def split(self):
center = len(self.contents) // 2
# median = self.contents[center - 1]
sibling = type(self)(
self.tree,
self.contents[center:],
self.data[center:],
self.next)
self.contents = self.contents[:center]
self.data = self.data[:center]
self.next = sibling
return sibling, sibling.contents[0]
def remove(self, index, ancestors):
minimum = self.tree.order // 2
if index >= len(self.contents):
self, index = self.next, 0
key = self.contents[index]
# if any leaf that could accept the key can do so
# without any rebalancing necessary, then go that route
current = self
while current is not None and current.contents[0] == key:
if len(current.contents) > minimum:
if current.contents[0] == key:
index = 0
else:
index = bisect.bisect_left(current.contents, key)
current.contents.pop(index)
current.data.pop(index)
return
current = current.next
self.grow(ancestors)
def grow(self, ancestors):
minimum = self.tree.order // 2
parent, parent_index = ancestors.pop()
left_sib = right_sib = None
# try borrowing from a neighbor - try right first
if parent_index + 1 < len(parent.children):
right_sib = parent.children[parent_index + 1]
if len(right_sib.contents) > minimum:
right_sib.lateral(parent, parent_index + 1, self, parent_index)
return
# fallback to left
if parent_index:
left_sib = parent.children[parent_index - 1]
if len(left_sib.contents) > minimum:
left_sib.lateral(parent, parent_index - 1, self, parent_index)
return
# join with a neighbor - try left first
if left_sib:
left_sib.contents.extend(self.contents)
left_sib.data.extend(self.data)
parent.remove(parent_index - 1, ancestors)
return
# fallback to right
self.contents.extend(right_sib.contents)
self.data.extend(right_sib.data)
parent.remove(parent_index, ancestors)
def shrink(self, ancestors):
parent = None
if ancestors:
parent, parent_index = ancestors.pop()
# try to lend to the left neighboring sibling
if parent_index:
left_sib = parent.children[parent_index - 1]
if len(left_sib.contents) < self.tree.order:
self.lateral(
parent, parent_index, left_sib, parent_index - 1)
return
# try the right neighbor
if parent_index + 1 < len(parent.children):
right_sib = parent.children[parent_index + 1]
if len(right_sib.contents) < self.tree.order:
self.lateral(
parent, parent_index, right_sib, parent_index + 1)
return
# center = len(self.contents) // 2
sibling, push = self.split()
if not parent:
parent, parent_index = Node(
self.tree, children=[self]), 0
self.tree._root = parent
# pass the median up to the parent
parent.contents.insert(parent_index, push)
parent.children.insert(parent_index + 1, sibling)
if len(parent.contents) > parent.tree.order:
parent.shrink(ancestors)
class BTree(object):
def __init__(self, order):
self.order = order
self._root = Leaf(self)
def __contains__(self, key):
for item in self._get(key):
return True
return False
def __iter__(self):
def _recurse(node):
if node.children:
for child, item in zip(node.children, node.contents):
for child_item in _recurse(child):
yield child_item
yield item
for child_item in _recurse(node.children[-1]):
yield child_item
else:
for item in node.contents:
yield item
for item in _recurse(self._root):
yield item
def __repr__(self):
def recurse(node, accum, depth):
accum.append((" " * depth) + repr(node))
for node in getattr(node, "children", []):
recurse(node, accum, depth + 1)
accum = []
recurse(self._root, accum, 0)
return "\n".join(accum)
@classmethod
def bulkload(cls, items, order):
tree = object.__new__(cls)
tree.order = order
leaves = tree._build_bulkloaded_leaves(items)
tree._build_bulkloaded_branches(leaves)
return tree
def get(self, key, default=None):
try:
iterable = self._get(key)
return next(iterable)
except StopIteration:
return default
def getlist(self, key):
return list(self._get(key))
def insert(self, key, data):
path = self._path_to(key)
node, index = path.pop()
node.insert(index, key, data, path)
def remove(self, key):
path = self._path_to(key)
node, index = path.pop()
node.remove(index, path)
__getitem__ = get
__setitem__ = insert
__delitem__ = remove
def iteritems(self):
node = self._root
while hasattr(node, "children"):
node = node.children[0]
while node:
for pair in zip(node.contents, node.data):
yield pair
node = node.next
def iterkeys(self):
return map(operator.itemgetter(0), self.iteritems())
def itervalues(self):
return map(operator.itemgetter(1), self.iteritems())
__iter__ = iterkeys
def items(self):
return list(self.iteritems())
def keys(self):
return list(self.iterkeys())
def values(self):
return list(self.itervalues())
def _get(self, key):
node, index = self._path_to(key)[-1]
if index == len(node.contents):
if node.next:
node, index = node.next, 0
else:
return
while node.contents[index] == key:
yield node.data[index]
index += 1
if index == len(node.contents):
if node.next:
node, index = node.next, 0
else:
return
def _path_to(self, item):
current = self._root
ancestry = []
while getattr(current, "children", None):
index = bisect.bisect_left(current.contents, item)
ancestry.append((current, index))
if index < len(current.contents) \
and current.contents[index] == item:
break
current = current.children[index]
else:
index = bisect.bisect_left(current.contents, item)
ancestry.append((current, index))
# present = index < len(current.contents)
# present = present and current.contents[index] == item
path = ancestry
node, index = path[-1]
while hasattr(node, "children"):
node = node.children[index]
index = bisect.bisect_left(node.contents, item)
path.append((node, index))
return path
def _present(self, item, ancestors):
last, index = ancestors[-1]
return index < len(last.contents) and last.contents[index] == item
def _build_bulkloaded_leaves(self, items):
minimum = self.order // 2
leaves, seps = [[]], []
for item in items:
if len(leaves[-1]) >= self.order:
seps.append(item)
leaves.append([])
leaves[-1].append(item)
if len(leaves[-1]) < minimum and seps:
last_two = leaves[-2] + leaves[-1]
leaves[-2] = last_two[:minimum]
leaves[-1] = last_two[minimum:]
seps.append(last_two[minimum])
leaves = [Leaf(
self,
contents=[p[0] for p in pairs],
data=[p[1] for p in pairs])
for pairs in leaves]
for i in xrange(len(leaves) - 1):
leaves[i].next = leaves[i + 1]
return leaves, [s[0] for s in seps]
def _build_bulkloaded_branches(self, items):
leaves, seps = items
minimum = self.order // 2
levels = [leaves]
while len(seps) > self.order + 1:
items, nodes, seps = seps, [[]], []
for item in items:
if len(nodes[-1]) < self.order:
nodes[-1].append(item)
else:
seps.append(item)
nodes.append([])
if len(nodes[-1]) < minimum and seps:
last_two = nodes[-2] + [seps.pop()] + nodes[-1]
nodes[-2] = last_two[:minimum]
nodes[-1] = last_two[minimum + 1:]
seps.append(last_two[minimum])
offset = 0
for i, node in enumerate(nodes):
children = levels[-1][offset:offset + len(node) + 1]
nodes[i] = Node(self, contents=node, children=children)
offset += len(node) + 1
levels.append(nodes)
self._root = Node(self, contents=seps, children=levels[-1])
import random
import unittest
class TreeTestCase(unittest.TestCase):
@property
def factory(self):
from dobbin.btree import BTree
return BTree
def test_additions_sorted(self):
bt = self.factory(20)
l = tuple(range(2000))
for item in l:
bt.insert(item, str(item))
for item in l:
self.assertEqual(str(item), bt[item])
self.assertEqual(l, tuple(bt))
def test_additions_random(self):
bt = self.factory(20)
l = list(range(2000))
random.shuffle(l)
for item in l:
bt.insert(item, str(item))
for item in l:
self.assertEqual(str(item), bt[item])
self.assertEqual(tuple(range(2000)), tuple(bt))
def test_bulkload(self):
bt = self.factory.bulkload(zip(range(2000), map(str, range(2000))), 20)
self.assertEqual(tuple(bt), tuple(range(2000)))
self.assertEqual(
list(bt.iteritems()),
list(zip(range(2000), map(str, range(2000)))))
if __name__ == '__main__':
unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment