Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Last active January 2, 2024 19:35
Show Gist options
  • Save mbillingr/28d6260828dc734d44de54731831d952 to your computer and use it in GitHub Desktop.
Save mbillingr/28d6260828dc734d44de54731831d952 to your computer and use it in GitHub Desktop.
Prototype implementation of Hash Array Mapped Tries (HAMT), the data structure from "Ideal Hash Trees"
from __future__ import annotations
import ctypes
import dataclasses
from typing import Any
# Number of slots per trie node; each level consumes log2(LEAF_SIZE) = 5 hash bits.
LEAF_SIZE = 1 << 5
# Mask that extracts the slot index for the current level from a hash value.
LEAF_MASK = LEAF_SIZE - 1
# Nominal hash width in bits (not referenced by the rest of the visible code).
HASH_BITS = 64
class Trie:
    """Common base type for the trie entries `Leaf` and `Node`."""
@dataclasses.dataclass
class Leaf(Trie):
    """Terminal trie entry holding a single key/value pair."""

    key: Any
    val: Any

    def items(self):
        """Yield this leaf's one `(key, val)` pair."""
        pair = (self.key, self.val)
        yield pair
@dataclasses.dataclass
class Node(Trie):
    """Interior trie entry: an occupancy bitmap plus a compact list of entries."""

    # Bit i is set when slot i of this node is occupied.
    mapping: int
    # Entries of the occupied slots only, kept in ascending slot order.
    subtrie: list[Trie]

    def items(self):
        """Yield every `(key, val)` pair stored beneath this node."""
        for entry in self.subtrie:
            for pair in entry.items():
                yield pair

    def __repr__(self):
        """Render as `hamt({k: v, ...})`, formatting keys and values via f-string."""
        rendered = ", ".join(f"{key}: {val}" for key, val in self.items())
        return "hamt({" + rendered + "})"
# Sentinel meaning "no positional argument was passed to hamt()".
_notset = object()


def hamt(seq: Any = _notset, **kwargs: Any):
    """Build a trie from a dict, an iterable of `(key, value)` pairs, or keywords.

    Args:
        seq: A dict or an iterable of 2-item pairs. Optional.
        **kwargs: Entries given as keyword arguments.

    Returns:
        A trie (root `Node`) containing all given entries. Calling `hamt()`
        with no arguments returns an empty trie.

    Note:
        When both `seq` and keyword arguments are given, `seq` is not treated
        as a source of pairs: it is stored as the value of the key ``"seq"``
        alongside the keyword entries.
    """
    if seq is _notset:
        # Only keywords (possibly none at all): an empty dict yields an
        # empty trie instead of trying to iterate the sentinel object.
        pairs = kwargs
    elif kwargs:
        pairs = {"seq": seq, **kwargs}
    else:
        pairs = seq

    if isinstance(pairs, dict):
        pairs = pairs.items()

    trie = empty()
    for k, v in pairs:
        trie = insert(k, v, trie)
    return trie
def empty() -> Node:
    """Return a fresh root node with a zero bitmap and no entries."""
    return Node(mapping=0, subtrie=[])
def contains(key: Any, trie: Node) -> bool:
    """Report whether `key` is stored in `trie`.

    Implemented in terms of `lookup`: a `KeyError` from it means "absent".
    """
    try:
        lookup(key, trie)
    except KeyError:
        return False
    else:
        return True
def lookup(key: Any, trie: Node) -> Any:
    """Return the value stored under `key` in `trie`.

    Raises:
        KeyError: If `key` is not present.
    """
    return lookup_(key, uhash(key), trie)
def lookup_(key: Any, k: int, trie: Node) -> Any:
    """Recursive worker for `lookup`.

    Args:
        key: The key being searched for.
        k: The hash bits of `key` not yet consumed by the levels above.
        trie: The node to search at this level.

    Returns:
        The value stored with `key`.

    Raises:
        KeyError: If the slot is empty or holds a leaf with a different key.
    """
    bit = 1 << (k & LEAF_MASK)
    if (trie.mapping & bit) == 0:
        raise KeyError(key)

    # The entry's position in the compact list is the number of occupied
    # slots below this one.
    position = ctpop(trie.mapping & (bit - 1))
    entry = trie.subtrie[position]

    if isinstance(entry, Leaf):
        if entry.key == key:
            return entry.val
        raise KeyError(key)
    if isinstance(entry, Node):
        # Descend one level, dropping the hash bits consumed here.
        return lookup_(key, k // LEAF_SIZE, entry)
def insert(key: Any, value: Any, trie: Node) -> Node:
    """Return a new trie equal to `trie` with `key` mapped to `value`.

    The input trie is not modified; an existing entry for `key` is replaced.
    """
    return insert_(key, value, uhash(key), 1, trie)
def insert_(key: Any, value: Any, k: int, depth: int, trie: Node) -> Node:
    """Recursive worker for `insert`.

    Args:
        key: Key to insert.
        value: Value to associate with `key`.
        k: The hash bits of `key` not yet consumed by the levels above.
        depth: Despite the name, a divisor rather than a level count: it is 1
            at the root and is multiplied by `LEAF_SIZE` on each descent, so
            `depth * LEAF_SIZE` strips the bits consumed up to and including
            this level from a full hash.
        trie: The node to insert into at this level.

    Returns:
        A new node replacing `trie`; `trie` itself is left untouched.
    """
    bit = 1 << (k & LEAF_MASK)
    # Position in the compact list = number of occupied slots below this one.
    position = ctpop(trie.mapping & (bit - 1))

    if not trie.mapping & bit:
        # Free slot: add a leaf and mark the slot as occupied.
        entries = trie.subtrie.copy()
        entries.insert(position, Leaf(key, value))
        return Node(trie.mapping | bit, entries)

    occupant = trie.subtrie[position]
    if isinstance(occupant, Leaf):
        if occupant.key == key:
            # Same key: overwrite the value.
            replacement = Leaf(key, value)
        else:
            # Different key in the same slot: push both leaves one level
            # down, each with its own remaining hash bits.
            replacement = split_(
                Leaf(key, value),
                k // LEAF_SIZE,
                occupant,
                uhash(occupant.key) // (depth * LEAF_SIZE),
            )
    elif isinstance(occupant, Node):
        replacement = insert_(key, value, k // LEAF_SIZE, depth * LEAF_SIZE, occupant)

    entries = trie.subtrie.copy()
    entries[position] = replacement
    return Node(trie.mapping, entries)
def split_(leaf1: Leaf, k1: int, leaf2: Leaf, k2: int) -> Node:
    """Build the subtree that separates two leaves which shared a slot.

    Args:
        leaf1: First leaf.
        k1: The hash bits of `leaf1`'s key not yet consumed by the levels above.
        leaf2: Second leaf.
        k2: The hash bits of `leaf2`'s key not yet consumed by the levels above.

    Returns:
        A node holding both leaves, nested under single-entry nodes for every
        further level on which their slot indices still coincide.

    Raises:
        NotImplementedError: If `k1 == k2`, i.e. the two keys have identical
            remaining hash bits. No level could ever tell them apart, so the
            recursion below would never terminate; this prototype has no
            collision buckets to store such keys.
    """
    if k1 == k2:
        # Fail fast instead of recursing until the interpreter's stack limit.
        raise NotImplementedError(
            f"hash collision between {leaf1!r} and {leaf2!r}: "
            "keys with identical hashes are not supported"
        )

    idx1 = k1 & LEAF_MASK
    idx2 = k2 & LEAF_MASK
    mb1 = 1 << idx1
    mb2 = 1 << idx2

    if idx1 == idx2:
        # Still the same slot on this level: chain a single-entry node and
        # retry with the next group of hash bits.
        return Node(mb1, [split_(leaf1, k1 // LEAF_SIZE, leaf2, k2 // LEAF_SIZE)])

    # Entries are stored in ascending slot order.
    if mb1 < mb2:
        subtree = [leaf1, leaf2]
    else:
        subtree = [leaf2, leaf1]
    return Node(mb1 | mb2, subtree)
def remove(key: Any, trie: Node) -> Node:
k = uhash(key)
return remove_(key, k, trie)
def remove_(key: Any, k: int, trie: Node) -> Trie | None:
idx = k & LEAF_MASK
mask_bit = 1 << idx
if not trie.mapping & mask_bit:
raise KeyError(key)
idx_ = ctpop(trie.mapping & (mask_bit - 1))
match trie.subtrie[idx_]:
case Leaf(key_, val_):
if key_ == key:
match len(trie.subtrie):
case 1:
return None
case 2:
return trie.subtrie[1 - idx_]
case _:
new_subtree = trie.subtrie.copy()
new_subtree.pop(idx_)
return Node(trie.mapping & ~mask_bit, new_subtree)
else:
raise KeyError(key)
case Node() as child:
match remove_(key, k // LEAF_SIZE, child):
case None:
match len(trie.subtrie):
case 1:
return None
case 2:
return trie.subtrie[1 - idx_]
case _:
new_subtree = trie.subtrie.copy()
new_subtree.pop(idx_)
return Node(trie.mapping & ~mask_bit, new_subtree)
case other:
new_subtree = trie.subtrie.copy()
new_subtree[idx_] = other
return Node(trie.mapping, new_subtree)
def uhash(x: Any) -> int:
    """Return `hash(x)` reinterpreted as an unsigned `size_t`, so it is never negative."""
    signed = hash(x)
    unsigned = ctypes.c_size_t(signed)
    return unsigned.value
def ctpop(x: int) -> int:
    """Count the set bits of `x` (population count).

    Zero and negative inputs give 0.
    """
    if x <= 0:
        return 0
    return bin(x).count("1")
# Import-time sanity check of ctpop against hand-computed bit counts.
assert all(
    ctpop(number) == bits
    for number, bits in ((0, 0), (1, 1), (2, 1), (3, 2), (4, 1), (5, 2), (255, 8))
)
def show_trie(trie: Trie, indent=""):
    """Print a dump of `trie`, one entry per line, nested entries indented further.

    A leaf is shown as the binary form of its key's unsigned hash followed by
    the key and value; a node is shown as the binary form of its bitmap,
    followed by its entries.
    """
    if isinstance(trie, Leaf):
        print(indent, bin(uhash(trie.key)), trie.key, ":", trie.val)
    elif isinstance(trie, Node):
        print(indent, bin(trie.mapping))
        deeper = indent + " "
        for entry in trie.subtrie:
            show_trie(entry, deeper)
# Import-time smoke test: fill a trie with the keys "0".."99", spot-check a
# few lookups, then check that a removed key is no longer found.
m = empty()
for k in range(100):
    m = insert(key=str(k), value=k, trie=m)
assert all(lookup(str(number), m) == number for number in (0, 1, 2, 99))
n = remove("2", m)
assert contains("2", n) is False
# Ad-hoc insert benchmark, disabled by the constant-false condition. Change the
# condition to run it; doing so requires matplotlib for the final plot.
if False:
    import random
    import time
    from matplotlib import pyplot as plt
    m = empty()
    # Wall-clock duration of every timed insert, in call order.
    times = []
    # Phase 1: grow the trie from empty to (up to) 10000 random keys, timing
    # each insert.
    for _ in range(10000):
        k = str(random.randint(0, 2**60))
        a = time.time()
        m = insert(str(k), k, m)
        b = time.time() - a
        times.append(b)
    # Phase 2: time 10000 more inserts against the trie built above; the
    # result is discarded, so the trie stays the same size throughout.
    for _ in range(10000):
        k = str(random.randint(0, 2**60))
        a = time.time()
        _ = insert(str(k), k, m)
        b = time.time() - a
        times.append(b)
    # Plot per-insert duration against insert number.
    plt.plot(times)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment