@pervognsen
Last active October 23, 2022 10:57
# Here's a probably-not-new data structure I discovered after implementing weight-balanced trees with
# amortized subtree rebuilds (e.g. http://jeffe.cs.illinois.edu/teaching/algorithms/notes/10-scapegoat-splay.pdf)
# and realizing it was silly to turn a subtree into a sorted array with an in-order traversal only as an
# aid in constructing a balanced subtree in linear time. Why not replace the subtree by the sorted array and
# use binary search when hitting that leaf array? Then you'd defer any splitting of the array until inserts and
# deletes hit that leaf array. Only in hindsight did I realize this is similar to the rope data structure for strings.
# Unlike ropes, it's a key-value search tree for arbitrary ordered keys.
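# Sketch of the idea above (hypothetical helpers, not code from the gist;
# BstNode stands in for any binary search tree node): an in-order traversal
# turns a subtree into sorted parallel key/value arrays. Instead of rebuilding
# a balanced subtree from them, we keep the arrays as a leaf and answer lookups
# with a binary search.
from bisect import bisect_left


class BstNode:
    # Hypothetical minimal BST node used only for this illustration.
    def __init__(self, key, value, left=None, right=None):
        self.key = key
        self.value = value
        self.left = left
        self.right = right


def flatten_to_leaf_array(node, keys, values):
    # In-order traversal appends keys in ascending order.
    if node is not None:
        flatten_to_leaf_array(node.left, keys, values)
        keys.append(node.key)
        values.append(node.value)
        flatten_to_leaf_array(node.right, keys, values)
    return keys, values


def leaf_array_search(keys, values, key, default=None):
    # Binary search in the sorted leaf array replaces walking the subtree.
    i = bisect_left(keys, key)
    if i < len(keys) and keys[i] == key:
        return values[i]
    return default

# Usage: a degenerate right-leaning subtree 1 -> 2 -> 3 flattens to
# keys [1, 2, 3] and values ["a", "b", "c"].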
#
# The main attraction of this data structure is its simplicity (on par with a standard weight-balanced tree) and that it
# coalesces subtrees into contiguous arrays, which reduces memory overhead and boosts the performance of in-order traversals
# and range queries.
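# Sketch (hypothetical helper; the gist itself has no range-query method) of
# why contiguous leaf arrays help range queries: within one leaf, every key in
# [lo, hi) is found with two binary searches and returned as one contiguous
# slice, with no pointer chasing between tree nodes. The searches are limited
# to the leaf's window [start, end) of a possibly shared array.
from bisect import bisect_left


def leaf_array_range(keys, values, start, end, lo, hi):
    i = bisect_left(keys, lo, start, end)  # first index in the window with keys[i] >= lo
    j = bisect_left(keys, hi, i, end)      # first index in the window with keys[j] >= hi
    return list(zip(keys[i:j], values[i:j]))

# Usage: leaf_array_range([1, 3, 5, 7, 9], [10, 30, 50, 70, 90], 0, 5, 3, 8)
# returns [(3, 30), (5, 50), (7, 70)].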
#
# This chain tree has a worst-case search time of O(log n) and insertion and deletion take O(log n) amortized time. The
# worst-case time for insertions/deletions is O(n), similar to table doubling for dynamic arrays, but the spikes
# for rebuilds are usually much smaller than that since they are confined to subtrees.
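# Sketch of the table-doubling analogy above (illustrative only, not code from
# the gist): a dynamic array that doubles its capacity when full. A single
# append can copy every stored element (the O(n) spike), yet the total number
# of copies over n appends stays below 2n, i.e. O(1) amortized per append.
def doubling_copy_counts(n):
    capacity, size, total_copies, worst_spike = 1, 0, 0, 0
    for _ in range(n):
        if size == capacity:
            # Reallocate: copy all `size` existing elements into a table twice as large.
            total_copies += size
            worst_spike = max(worst_spike, size)
            capacity *= 2
        size += 1
    return total_copies, worst_spike

# Usage: doubling_copy_counts(1000) returns (1023, 512): one spike copies 512
# elements, but all 1000 appends together copy fewer than 2000.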
#
# In these bounds the number n is the number of live items + the number of deletions. Past deletions can affect the
# time taken by future operations because deleted items initially remain in interior nodes as tombstones/markers.
# There's a standard amortization trick to get rid of this O(log(deletions)) tax: You just rebuild the whole tree if
# the fraction of tombstones gets too high; you've probably seen this in open-addressed hash tables. The amortized cost
# of this global rebuild is only O(1) per deletion. Even without global rebuilds, tombstones are dropped whenever a subtree is rebuilt,
# so in practice this worst-case O(log(deletions)) tax is usually much smaller. A lot of databases handle B-tree deletions in a similar way
# by letting nodes underflow below the B/2 limit, and they can't afford the O(n) rebuild time since n could be trillions.
# You can show that this incurs a similar O(log(deletions)) tax: http://sidsen.azurewebsites.net/papers/relaxed-b-tree-tods.pdf
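# Sketch of the global-rebuild trick for tombstones. TombstoneSortedMap is a
# hypothetical sorted-array map, not the gist's ChainTree (which counts
# operations rather than tombstones). Deletions only mark entries dead; once
# dead entries make up at least half of the stored entries, one O(n) rebuild
# drops them all. That rebuild costs at most twice the number of deletions since
# the previous one, so it adds O(1) amortized per deletion.
from bisect import bisect_left

_TOMBSTONE = object()  # hypothetical marker for deleted values


class TombstoneSortedMap:
    def __init__(self):
        self.keys, self.values = [], []
        self.dead = 0
        self.rebuilds = 0

    def insert(self, key, value):
        i = bisect_left(self.keys, key)
        if i < len(self.keys) and self.keys[i] == key:
            if self.values[i] is _TOMBSTONE:
                self.dead -= 1  # reusing a tombstoned slot
            self.values[i] = value
        else:
            self.keys.insert(i, key)
            self.values.insert(i, value)

    def delete(self, key):
        i = bisect_left(self.keys, key)
        if i < len(self.keys) and self.keys[i] == key and self.values[i] is not _TOMBSTONE:
            self.values[i] = _TOMBSTONE
            self.dead += 1
            if 2 * self.dead >= len(self.keys):
                self._rebuild()

    def _rebuild(self):
        # Global rebuild: keep only live entries.
        live = [(k, v) for k, v in zip(self.keys, self.values) if v is not _TOMBSTONE]
        self.keys = [k for k, _ in live]
        self.values = [v for _, v in live]
        self.dead = 0
        self.rebuilds += 1

    def search(self, key, default=None):
        i = bisect_left(self.keys, key)
        if i < len(self.keys) and self.keys[i] == key and self.values[i] is not _TOMBSTONE:
            return self.values[i]
        return default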
#
# If you implement this in a language without garbage collection, you'll need to do your own reference counting for leaf arrays
# since they are shared among different chains. This also illustrates a potential issue with space bloat. Like with ropes, the
# array sharing means a single item can keep alive an arbitrarily large array.
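# Sketch of the manual reference counting a non-GC port would need. It is
# written in Python only for illustration; SharedArray, ArraySlice and the
# explicit release() calls are hypothetical. Several chains can point into one
# leaf array, so the array may only be freed when the last chain referring to it
# is released. This also shows the space bloat: a one-item slice keeps the
# whole array alive.
class SharedArray:
    def __init__(self, items):
        self.items = items
        self.refcount = 0
        self.freed = False


class ArraySlice:
    def __init__(self, array, start, end):
        array.refcount += 1  # each chain holds one reference to its backing array
        self.array, self.start, self.end = array, start, end

    def release(self):
        self.array.refcount -= 1
        if self.array.refcount == 0:
            # In C this would be free(); here we just drop the items and record it.
            self.array.items = None
            self.array.freed = True

# Usage: after splitting a 1000-item array into [0, 500) and [500, 501) and
# releasing the left slice, the one-item right slice still pins all 1000 items.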
#
# This is probably the best engineering argument for amortized global rebuilds: Instead of worrying about tombstones directly,
# just track the amount of wasted space and trigger a global rebuild when it reaches too large a fraction of the tree's total space.
# That gives you a worst-case space bound of O(n) with O(1) additional amortized cost per insertion/deletion. Waste is introduced
# both by deletions (waste += 1) and by insertions when a subtree is rebuilt (waste += subtree_size if the shared array is still live).
# We charge $1 of the global rebuild cost to each deletion. For an insertion that triggers a subtree rebuild, we charge $1 to each of
# the subtree_size item copies made during that rebuild. Due to weight balance, the sibling subtree is also O(subtree_size) in size.
# Hence we've charged enough to pay for the eventual global rebuild of both the rebuilt and non-rebuilt subtrees.
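# Sketch of the waste-based trigger described above (needs_global_rebuild is a
# hypothetical helper; the gist's ChainTree uses a simpler per-operation
# account). For clarity it recomputes waste from scratch instead of tracking it
# incrementally: given the leaves as (array, start, end) windows, waste is the
# number of slots in the distinct backing arrays that no leaf covers. A global
# rebuild is due once waste exceeds max_waste_fraction of the total space; that
# threshold is an assumed tuning knob.
def needs_global_rebuild(leaves, max_waste_fraction=0.5):
    array_sizes = {}
    live = 0
    for array, start, end in leaves:
        array_sizes[id(array)] = len(array)  # count each shared array only once
        live += end - start
    total = sum(array_sizes.values())
    waste = total - live
    return waste > max_waste_fraction * total

# Usage: two 10-item leaves sharing a 100-slot array waste 80 slots, so
# needs_global_rebuild returns True.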
#
# NOTE: This version doesn't store values in nodes, so they're just key splitters. Splits are around medians from sorted chains,
# so this is very close to the original inspiration of weight-balanced trees with lazy splitting. Some of the changes also make
# it look closer to B-trees (specifically B+trees since values only live in leaves).
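# Sketch of a median split as a standalone function (median_split and route
# are hypothetical names; Chain.split below does the same thing on its own
# window). The median key becomes a value-less splitter and stays in the left
# window, so a search goes left when key <= splitter, like a B+tree whose
# values live only in leaves. Both windows keep pointing into the same arrays.
def median_split(keys, start, end):
    # Assumes end - start >= 1; roughly halves the window [start, end).
    mid = (start + end) // 2
    return keys[mid], (start, mid + 1), (mid + 1, end)


def route(splitter, key):
    # Same comparison as Node.search: ties go to the left subtree.
    return "left" if key <= splitter else "right"

# Usage: median_split([10, 20, 30, 40, 50], 0, 5) returns (30, (0, 3), (3, 5)).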
from random import randrange, seed
from itertools import islice
from bisect import bisect_left
local_rebuild_cost = 0
global_rebuild_cost = 0
DELETED = object()
SMALL = 32
HALF_SMALL = 16
# A chain is a window [start, end) into sorted parallel key/value arrays.
# After a split, several chains may share the same arrays.
class Chain:
    def __init__(self, keys, values, start, end):
        assert len(keys) == len(values)
        assert 0 <= start <= end <= len(keys)
        self.keys = keys
        self.values = values
        self.start = start
        self.end = end

    def weight(self):
        return self.end - self.start

    def __iter__(self):
        for i in range(self.start, self.end):
            yield self.keys[i], self.values[i]

    def emit(self, keys, values):
        keys.extend(self.keys[self.start:self.end])
        values.extend(self.values[self.start:self.end])

    def slice(self, start, end):
        return Chain(self.keys, self.values, start, end)

    def key_index(self, key):
        i = bisect_left(self.keys, key, self.start, self.end)
        return i, i < self.end and self.keys[i] == key

    def split(self):
        i = (self.start + self.end) // 2
        return i, self.keys[i], self.slice(self.start, i + 1), self.slice(i + 1, self.end)

    def search(self, key, none):
        i, found = self.key_index(key)
        return self.values[i] if found else none

    def is_small_array(self):
        return self.end - self.start == len(self.keys) <= SMALL

    def insert(self, key, value):
        i, found = self.key_index(key)
        if found:
            self.values[i] = value
            return self
        # This is a fast path for small arrays or appends. No point in splitting when it's cheaper to just do it in place.
        # It's especially helpful since appends are common and would otherwise provoke the maximal number of median splits.
        # It also helps with insert locality: if you do an insert at a key and split around it and then the left subtree
        # is eventually rebuilt into its own dynamic array then you can do fast in-place appends at the split location,
        # so every split which is adjacent to a complete chain effectively becomes a fast insert cursor.
        if self.is_small_array() or i == len(self.keys):
            return self.insert_in_place(key, value, i)
        return self.insert_at(key, value, i)

    def insert_in_place(self, key, value, i):
        self.keys.insert(i, key)
        self.values.insert(i, value)
        self.end += 1
        return self

    def insert_at(self, key, value, i):
        # We use HALF_SMALL so that new arrays can double in size before needing splits.
        if self.weight() <= HALF_SMALL:
            return rebuild(self).insert_in_place(key, value, i - self.start)
        split, split_key, left, right = self.split()
        if i <= split:
            left = left.insert_at(key, value, i)
        else:
            right = right.insert_at(key, value, i)
        return Node(split_key, left, right)

    def delete(self, key):
        i, found = self.key_index(key)
        if not found:
            return self
        # This is a fast path for deleting at the start/end of a chain; you might have seen the same idea for ropes.
        if i == self.start:
            self.start += 1
            return self
        elif i == self.end - 1:
            self.end -= 1
            return self
        return self.delete_at(key, i)

    def delete_at(self, key, i):
        if self.weight() <= SMALL:
            # Small chain: close the gap by shifting the rest of this chain's window left in place.
            self.keys[i:self.end-1] = self.keys[i+1:self.end]
            self.values[i:self.end-1] = self.values[i+1:self.end]
            self.end -= 1
            return self
        split, split_key, left, right = self.split()
        if i <= split:
            left = left.delete_at(key, i)
        else:
            right = right.delete_at(key, i)
        return Node(split_key, left, right)
def rebuild(node):
    # Flatten any subtree (Chain or Node) into a fresh, unshared sorted chain.
    keys = []
    values = []
    node.emit(keys, values)
    return Chain(keys, values, 0, len(keys))
# Interior node: key is only a splitter with no value; keys <= key live in the left subtree.
class Node:
    def __init__(self, key, left, right):
        self.key = key
        self.left = left
        self.right = right
        self.cached_weight = 1 + left.weight() + right.weight()

    def weight(self):
        return self.cached_weight

    def __iter__(self):
        yield from self.left
        yield from self.right

    def emit(self, keys, values):
        self.left.emit(keys, values)
        self.right.emit(keys, values)

    def search(self, key, none):
        if key <= self.key:
            return self.left.search(key, none)
        else:
            return self.right.search(key, none)

    def balance(self):
        global local_rebuild_cost
        weights = self.left.weight(), self.right.weight()
        self.cached_weight = 1 + sum(weights)
        # Allow up to 1:2 weight balance before a local rebuild.
        if max(weights) <= 2 * min(weights):
            return self
        else:
            local_rebuild_cost += self.cached_weight
            return rebuild(self)

    def insert(self, key, value):
        if key <= self.key:
            self.left = self.left.insert(key, value)
        else:
            self.right = self.right.insert(key, value)
        return self.balance()

    def delete(self, key):
        if key <= self.key:
            self.left = self.left.delete(key)
        else:
            self.right = self.right.delete(key)
        return self.balance()
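# Standalone restatement of the 1:2 weight-balance test in Node.balance
# (within_weight_balance is a hypothetical name). A node keeps its shape while
# the heavier child weighs at most twice the lighter one. Anything more
# lopsided, including an empty child next to a non-empty one, triggers a local
# rebuild of that node into a single chain.
def within_weight_balance(left_weight, right_weight):
    return max(left_weight, right_weight) <= 2 * min(left_weight, right_weight)

# Usage: within_weight_balance(10, 20) is True; within_weight_balance(10, 21)
# and within_weight_balance(0, 1) are False.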
class ChainTree:
    def __init__(self):
        self.root = Chain([], [], 0, 0)
        self.account = 0

    def flatten(self):
        self.root = rebuild(self.root)
        self.account = 0

    def charge(self):
        global global_rebuild_cost
        self.account += 1
        cost = self.root.weight()
        # This is a simple amortization scheme for global rebuilds that's guaranteed not to overcharge.
        if self.account >= cost:
            global_rebuild_cost += cost
            self.flatten()

    def __iter__(self):
        return iter(self.root)

    def search(self, key, none=None):
        return self.root.search(key, none)

    def insert(self, key, value):
        self.root = self.root.insert(key, value)
        self.charge()

    def delete(self, key):
        self.root = self.root.delete(key)
        self.charge()
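# Standalone simulation of the accounting in ChainTree.charge
# (simulate_global_rebuilds is hypothetical and tracks sizes only). It assumes
# the tree's weight equals its number of live items, whereas the real
# root.weight() also counts interior nodes. Every operation deposits $1, and a
# global rebuild of a tree of weight w happens only once the account holds at
# least w. The total rebuild cost therefore never exceeds the number of
# operations, which is the "guaranteed not to overcharge" property.
def simulate_global_rebuilds(ops):
    # ops: a sequence of +1 (insert of a new key) and -1 (delete of a live key).
    size = account = total_cost = rebuilds = 0
    for delta in ops:
        size += delta
        account += 1
        if account >= size:
            total_cost += size
            rebuilds += 1
            account = 0
    return total_cost, rebuilds

# Usage: 10 inserts followed by 10 deletes give 6 rebuilds with total cost 17,
# within the 20 dollars deposited.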
# Example
seed(0)
tree = ChainTree()
n = 100000
d = {}
print("Inserting")
for i in range(n // 4):
    key = randrange(2 * n)
    # print("%s => %s" % (key, i))
    d[key] = i
    tree.insert(key, i)
for i in range(n // 4):
    key = randrange(n - 10, n + 10)
    # print("%s => %s" % (key, i))
    d[key] = i
    tree.insert(key, i)
for i in range(n // 4):
    key = randrange(2 * n)
    # print("%s => %s" % (key, i))
    d[key] = i
    tree.insert(key, i)
for i in range(n // 4):
    key = randrange(n - 10, n + 10)
    # print("%s => %s" % (key, i))
    d[key] = i
    tree.insert(key, i)
print("Local rebuild cost:", local_rebuild_cost)
print("Global rebuild cost:", global_rebuild_cost)
#tree.flatten()
print("Searching")
for key, value in d.items():
    assert tree.search(key) == value
print("Iterating")
for key, value in tree:
    assert d[key] == value
print("Deleting")
for i in range(n // 2):
    key = randrange(2 * n)
    if key in d:
        # print("Deleting", key)
        del d[key]
        tree.delete(key)
assert sorted(d.items()) == list(tree)
print("Local rebuild cost:", local_rebuild_cost)
print("Global rebuild cost:", global_rebuild_cost)
#for key, value in tree:
#    print(key, value)