Skip to content

Instantly share code, notes, and snippets.

@eatonphil
Created August 27, 2023 16:47
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eatonphil/51b91bb30bc7dbebd9bfb3c33248e563 to your computer and use it in GitHub Desktop.
Save eatonphil/51b91bb30bc7dbebd9bfb3c33248e563 to your computer and use it in GitHub Desktop.
Python In-memory B-Tree
import math
import uuid
class BTree:
    """In-memory B-tree whose ``insert`` returns a new tree.

    The tree is a thin wrapper around a root ``BTreeNode``. Inserting never
    modifies the receiving tree; callers must keep the returned tree
    (``t = t.insert(...)``).
    """

    def __init__(self, order=3):
        """Create an empty tree whose nodes hold up to ``order`` children."""
        self.root = BTreeNode(order)

    def insert(self, toinsert):
        """Return a new tree that additionally contains ``toinsert``.

        ``toinsert`` is a mapping with at least a ``"key"`` entry. After the
        insertion every previously stored element plus the new one is looked
        up again in the result as a sanity check; a miss raises
        ``AssertionError``.
        """
        # Snapshot of what was stored before, used by the check below.
        previously_stored = self.list()
        result = BTree(self.root.order)

        replacement, promoted, lower, upper = self.root.insert(toinsert)
        if promoted is not None:
            # The old root overflowed: the fresh root built above receives
            # the promoted element and the two halves as its first children.
            fresh_root = result.root
            fresh_root.elements[0] = promoted
            fresh_root.children[0] = lower
            fresh_root.children[1] = upper
        else:
            assert replacement is not None
            result.root = replacement

        # Nothing may have been lost while rebuilding the path to the leaf.
        for el in previously_stored + [toinsert]:
            assert result.contains(el["key"]), f"Expected {el} in {result.list()}\n{result.print()}"
        return result

    def contains(self, key):
        """Return True when an element with ``key`` is stored in the tree."""
        return self.root.contains(key)

    def list(self):
        """Return all stored elements in the order the root's walk visits them."""
        return self.root.list()

    def print(self):
        """Print one line per stored element between ``[`` and ``]`` lines.

        Each line shows ``[node id, parent node id]`` followed by the element,
        indented by one space per tree level.
        """
        def show(element, depth, node_id, parent_id):
            print([node_id, parent_id], " " * (depth + 1) + str(element))

        print("[")
        self.root.walk(show)
        print("]")
class BTreeNode:
    """A single node of the B-tree.

    A node owns two fixed-length lists:

    * ``elements`` -- ``order - 1`` slots; stored elements are mappings with a
      ``"key"`` entry, kept in ascending key order and packed to the left,
      with ``None`` in the unused trailing slots.
    * ``children`` -- ``order`` slots; ``children[i]`` is the subtree for keys
      that sort before ``elements[i]``. A node whose children are all
      ``None`` is treated as a leaf.

    The insert methods never modify the node they are called on. They return
    a 4-tuple ``(node, middle, left, right)``:

    * ``(node, None, None, None)`` -- ``node`` is the replacement for this
      node with the element added.
    * ``(None, middle, left, right)`` -- this node overflowed; the caller must
      store ``middle`` itself, with ``left`` and ``right`` as the children on
      either side of it.
    """

    def __init__(self, order=3):
        """Create an empty node.

        Raises:
            ValueError: if ``order`` is smaller than 3.
        """
        if order < 3:
            # The guard has always required order >= 3; the message now says so.
            raise ValueError(f"BTree order must be at least 3, got {order}.")
        self.id = str(uuid.uuid4())[:8]  # short label, only used when printing
        self.order = order
        self.elements = [None] * (order - 1)
        self.children = [None] * order

    def walk(self, fn, level=0, parent_id=None):
        """Visit the subtree depth-first, child ``i`` before element ``i``.

        ``fn`` is called as ``fn(element, level, node_id, parent_id)`` for
        every stored (non-``None``) element.
        """
        assert len(self.children) == self.order
        assert len(self.elements) == self.order - 1
        assert self.order >= 3
        for i, child in enumerate(self.children):
            if child is not None:
                child.walk(fn, level + 1, self.id)
            # The last child has no element to its right.
            if i < len(self.elements) and self.elements[i] is not None:
                fn(self.elements[i], level, self.id, parent_id)

    def contains(self, key):
        """Return True when an element with ``key`` exists in this subtree."""
        for i, child in enumerate(self.children):
            element = self.elements[i] if i < len(self.elements) else None
            if element is None:
                # No key at this position (unused slot, or the slot past the
                # last element): descend into the child here if there is one.
                if child is not None:
                    return child.contains(key)
                continue
            if key == element["key"]:
                return True
            if key < element["key"] and child is not None:
                return child.contains(key)
        return False

    def list(self):
        """Return the subtree's elements in the order ``walk`` visits them."""
        collected = []
        self.walk(lambda element, *_: collected.append(element))
        return collected

    def split(self, copy, children_copy):
        """Split over-full lists around their middle element.

        Args:
            copy: element list one longer than a node can hold.
            children_copy: the matching child list, also one longer.

        Returns:
            ``(None, middle, left, right)`` where ``left`` and ``right`` are
            new nodes holding what was before and after ``middle``.
        """
        mid = self.order // 2

        left = BTreeNode(self.order)
        left_elements = copy[:mid]
        left.elements[:len(left_elements)] = left_elements
        left_children = children_copy[:mid + 1]
        left.children[:len(left_children)] = left_children
        assert len(left.elements) == self.order - 1
        assert len(left.children) == self.order

        right = BTreeNode(self.order)
        right_elements = copy[mid + 1:]
        right.elements[:len(right_elements)] = right_elements
        right_children = children_copy[mid + 1:]
        right.children[:len(right_children)] = right_children
        assert len(right.elements) == self.order - 1
        assert len(right.children) == self.order

        return None, copy[mid], left, right

    @staticmethod
    def _slot_for(elements, key):
        """Return the index of the first slot that is empty or holds a larger key."""
        for index, element in enumerate(elements):
            if element is None or key < element["key"]:
                return index
        return len(elements)

    @staticmethod
    def _is_ordered(elements):
        """Return True when keys never decrease and empty slots are all trailing.

        Only ``<`` is applied to keys, so any key type the rest of the class
        accepts (for example ``str``) can be checked.
        """
        previous = None
        seen_empty = False
        for element in elements:
            if element is None:
                seen_empty = True
                continue
            if seen_empty:
                return False
            if previous is not None and element["key"] < previous["key"]:
                return False
            previous = element
        return True

    def _fit_or_split(self, elements, children):
        """Build the insert result from lists that are one entry too long.

        When this node had a free slot, the surplus trailing ``None`` is
        dropped from both lists and a replacement node is returned; otherwise
        the lists are handed to ``split``.
        """
        if None not in self.elements:
            return self.split(elements, children)
        assert elements[-1] is None
        assert children[-1] is None
        elements.pop()
        children.pop()
        node = BTreeNode(self.order)
        node.elements = elements
        node.children = children
        assert len(node.elements) == node.order - 1
        assert len(node.children) == node.order
        return node, None, None, None

    def insert_leaf(self, toinsert):
        """Insert ``toinsert`` into this leaf; see the class docstring for the result."""
        position = self._slot_for(self.elements, toinsert["key"])
        elements = self.elements.copy()
        elements.insert(position, toinsert)
        children = self.children.copy()
        children.insert(position, None)  # keep both lists aligned
        return self._fit_or_split(elements, children)

    def insert_child(self, toinsert, i, child):
        """Insert ``toinsert`` below ``child``, which sits at ``children[i]``.

        See the class docstring for the shape of the result.
        """
        ret, middle, left, right = child.insert(toinsert)
        elements = self.elements.copy()
        children = self.children.copy()

        if middle is None:
            # The child absorbed the element; just point at its replacement.
            node = BTreeNode(self.order)
            node.elements = elements
            node.children = children
            node.children[i] = ret
            return node, None, None, None

        # The child split: take its middle element in at the sorted position.
        position = self._slot_for(elements, toinsert["key"])
        elements.insert(position, middle)
        assert self._is_ordered(elements)
        # `left` takes the newly opened child slot; `right` overwrites the
        # child that overflowed, which the insert shifted one slot right.
        children.insert(position, left)
        children[position + 1] = right
        return self._fit_or_split(elements, children)

    def insert(self, toinsert):
        """Insert ``toinsert`` into this subtree without modifying it.

        Elements whose key equals an existing key are stored as well, after
        the existing ones. See the class docstring for the result.
        """
        if all(child is None for child in self.children):
            return self.insert_leaf(toinsert)
        for i, child in enumerate(self.children):
            if i < len(self.elements):
                element = self.elements[i]
                if element is None or toinsert["key"] < element["key"]:
                    return self.insert_child(toinsert, i, child)
            elif child is not None:
                return self.insert_child(toinsert, i, child)
        # Raised explicitly so the failure is not lost when asserts are
        # disabled with -O.
        raise AssertionError("non-leaf node has no child to descend into")
import os
import struct
from btree import BTree
def test(order, generated, debug=False):
    """Insert ``generated`` into a fresh ``BTree`` and verify its contents.

    After every insertion the keys reported by ``list()`` must equal the
    sorted prefix of ``generated`` inserted so far; at the end the full
    listing must equal ``sorted(generated)``.

    Args:
        order: order passed to ``BTree``.
        generated: keys to insert, in insertion order.
        debug: when true, print the input, each intermediate tree and the
            final listing.

    Raises:
        AssertionError: when the tree's contents differ from the expectation.
    """
    if debug:
        print("Input", generated)

    t = BTree(order)
    for i, v in enumerate(generated):
        if debug:
            print("nth", i + 1)
        # insert() returns a new tree, so the result must be kept.
        t = t.insert({
            "key": v,
            "value": v,
        })
        expected = sorted(generated[:i + 1])
        actual = [x["key"] for x in t.list()]
        if debug:
            print(expected, actual)
        assert expected == actual
        if debug:
            t.print()

    if debug:
        print("Output")
        t.print()

    l = [x["key"] for x in t.list()]
    if debug:
        print(l)

    s = sorted(generated)
    # Compare lengths first so a short listing produces this message rather
    # than an IndexError from the element-wise comparison.
    assert len(l) == len(s), f"wanted: {len(s)}, got {len(l)}"
    for want, got in zip(s, l):
        assert want == got, f"wanted {want}, got {got}"
    assert s == [x["key"] for x in t.list()]
# Descending keys first.
generated = list(reversed(range(10)))
test(3, generated)

# Then ascending keys.
generated = list(range(10))
test(3, generated)

# Then random 16-bit keys, on a small tree and on a wider one.
generated = [struct.unpack('H', os.urandom(2))[0] for _ in range(10)]
test(3, generated)

generated = [struct.unpack('H', os.urandom(2))[0] for _ in range(20)]
test(8, generated)

# Finally, larger runs across several tree orders.
for tree_size in (3, 100, 4096):
    generated = [struct.unpack('H', os.urandom(2))[0] for _ in range(10000)]
    test(tree_size, generated)
    print(tree_size, len(generated))