# Source: GitHub Gist proger/4c4d4fd0eebce88388f976087f27da76 by @proger, created July 14, 2024 20:02.
"""
Randomized Binary Search Trees
https://www.cs.upc.edu/~conrado/research/papers/jacm-mr98.pdf
"""
import math
import random
from collections import Counter
class root:
    """Node of a binary search tree that also records its subtree size.

    Attributes (declared via ``__slots__``, so instances have no ``__dict__``):
        x:     the key stored in this node.
        left:  left child node, or ``None`` for an empty subtree.
        right: right child node, or ``None`` for an empty subtree.
        size:  number of nodes in the subtree rooted at this node.

    ``size`` is computed once in ``__init__`` from the children's sizes and
    is not updated afterwards.
    """

    __slots__ = 'x', 'left', 'right', 'size'

    def __init__(self, x, left=None, right=None):
        """Create a node with key ``x`` and the given (possibly empty) children."""
        self.x = x
        self.left = left
        self.right = right
        # An empty child contributes zero nodes to the subtree size.
        left_size = self.left.size if left else 0
        right_size = self.right.size if right else 0
        self.size = 1 + left_size + right_size

    def __repr__(self):
        """Return ``(x left right)``, printing ``_`` for an empty child."""
        return f'({self.x} {self.left or "_"} {self.right or "_"})'
def depth(self):
    """Return the number of nodes on the longest path from ``self`` to a leaf.

    ``self`` is a tree node (anything with ``left`` and ``right`` attributes)
    or ``None``. An empty tree has depth 0 and a single node has depth 1.
    """
    if not self:
        return 0
    # One for this node plus the deeper of the two subtrees.
    return 1 + max(depth(self.left), depth(self.right))
def split(x, tree):
    """Split ``tree`` around key ``x`` and return the pair ``(lower, upper)``.

    Assuming ``tree`` is a valid search tree, ``lower`` receives the keys
    strictly less than ``x`` and ``upper`` receives the keys greater than or
    equal to ``x``. Either result is ``None`` when it would be empty. The
    nodes along the search path are rebuilt with ``root(...)``; the input
    tree is not modified.
    """
    if not tree:
        return None, None
    if x <= tree.x:
        # This node and its whole right subtree belong to the upper part;
        # only the left subtree still needs to be divided.
        left_left, left_right = split(x, tree.left)
        return left_left, root(tree.x, left=left_right, right=tree.right)
    # tree.x < x: this node and its left subtree belong to the lower part.
    right_left, right_right = split(x, tree.right)
    return root(tree.x, left=tree.left, right=right_left), right_right
def insert(x, tree):
    """Return a new tree holding every key of ``tree`` plus ``x``.

    With probability ``1 / (tree.size + 1)`` the new key becomes the root of
    this subtree: the existing tree is split around ``x`` and the two halves
    become its children. Otherwise ``x`` is inserted recursively into the
    right subtree when ``tree.x <= x`` and into the left subtree otherwise.
    The input tree is not modified; nodes on the path are rebuilt.
    """
    if not tree:
        return root(x)
    n = tree.size + 1  # size of the tree after the insertion
    if random.random() < 1 / n:
        left, right = split(x, tree)
        return root(x, left=left, right=right)
    if tree.x <= x:
        # Keys equal to an existing key go to the right, matching split(),
        # which also places keys >= x on the right-hand side.
        return root(tree.x, left=tree.left, right=insert(x, tree.right))
    # tree.x > x
    return root(tree.x, left=insert(x, tree.left), right=tree.right)
def join(left, right):
    """Merge two trees into one and return the merged tree.

    If either argument is empty the other one is returned as-is. Otherwise
    the root of the result is taken from ``left`` with probability
    ``left.size / (left.size + right.size)`` and from ``right`` otherwise,
    and the remaining parts are joined recursively.

    NOTE(review): no keys are compared here, so the result is only a valid
    search tree if no key in ``left`` exceeds any key in ``right`` --
    presumably guaranteed by callers (``delete`` joins the two children of a
    single node); confirm for any new caller.
    """
    if not left:
        return right
    if not right:
        return left
    size = left.size + right.size
    if random.random() < left.size / size:
        # Keep left's root; right is merged below it on the right-hand side.
        return root(left.x, left=left.left, right=join(left.right, right))
    # Keep right's root; left is merged below it on the left-hand side.
    return root(right.x, left=join(left, right.left), right=right.right)
def delete(x, tree):
    """Return a new tree with one node whose key equals ``x`` removed.

    The search follows the usual search-tree path; the first node found with
    key ``x`` is replaced by the ``join`` of its two children. If ``x`` is
    not found, the result contains the same keys as ``tree`` (the nodes on
    the search path are still rebuilt). The input tree is not modified.
    """
    if not tree:
        return None
    if tree.x == x:
        return join(tree.left, tree.right)
    if x < tree.x:
        return root(tree.x, left=delete(x, tree.left), right=tree.right)
    return root(tree.x, left=tree.left, right=delete(x, tree.right))
#
# let's simulate multiple trees with the same input and see how deep they get
#
def go(xs, tree=None):
    """Insert every element of ``xs`` into ``tree`` and return the result.

    Args:
        xs:   iterable of keys, inserted in iteration order.
        tree: tree to start from; ``None`` (the default) means an empty tree.

    Returns:
        The tree after all insertions, or ``tree`` unchanged if ``xs`` is
        empty.
    """
    # The starting tree supplied by the caller is used as-is; it must not be
    # reset here, otherwise the ``tree`` argument would be silently ignored.
    for x in xs:
        tree = insert(x, tree)
    return tree
# Fixed seed so that every run of the experiment prints the same results.
random.seed(42)
N = 1000  # number of independent trials
S = 32  # number of keys inserted in each trial

# Build N trees from the same sorted input 1..S and tally the depths reached.
counter = Counter(
    depth(go(range(1, S + 1))) for _ in range(N)
)
print('depth distribution after', N, 'trials:')
for d, c in sorted(counter.items()):
    # One bar character per percentage point, followed by the raw fraction.
    print(d, '|' * int(c * 100 / N), c / N)
# depth() counts nodes (a single node has depth 1), and a tree of depth d
# holds at most 2**d - 1 keys, so the smallest depth that can hold S keys is
# ceil(log2(S + 1)); plain log2(S) would under-report it.
print('perfect depth:', math.ceil(math.log2(S + 1)))

xs = list(range(1, S + 1))
t = go(xs)
print(f'after inserting {S} elements', t)
for x in xs[::2]:  # every other key starting from 1, i.e. the odd keys
    t = delete(x, t)
print('after deleting odds:', t)
# End of gist (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment").