Skip to content

Instantly share code, notes, and snippets.

@vkobel
Last active June 21, 2024 11:17
Show Gist options
  • Save vkobel/f609468a1d3b0fc8ed9c5e1177d3673a to your computer and use it in GitHub Desktop.
Save vkobel/f609468a1d3b0fc8ed9c5e1177d3673a to your computer and use it in GitHub Desktop.
Short incremental Merkle tree implementation in Python
import hashlib
import random
class IncrementalMerkleTree:
def __init__(self, depth: int) -> None:
if depth < 1:
raise ValueError("Depth must be at least 1")
self.depth: int = depth
self.leaf_count: int = 2 ** depth
self.empty_leaf: bytes = self._hash(b'')
# The total number of nodes in the tree (both internal and leaf nodes) is 2^(d+1) - 1.
self.tree: list[bytes] = [self.empty_leaf * 32] * (2 * self.leaf_count - 1)
self.next_leaf_index = 0
@staticmethod
def _hash(left: bytes, right: bytes = b'') -> bytes:
return hashlib.sha256(left + right).digest()
def update(self, leaf: str) -> None:
if self.next_leaf_index >= self.leaf_count:
raise ValueError("Tree is full")
index: int = self.leaf_count - 1 + self.next_leaf_index
self.tree[index] = self._hash(leaf.encode())
while index > 0:
parent: int = (index - 1) // 2
left_child: bytes = self.tree[2 * parent + 1]
right_child: bytes = self.tree[2 * parent + 2]
self.tree[parent] = self._hash(left_child, right_child)
index = parent
self.next_leaf_index += 1
def get_root(self) -> str:
return self.tree[0].hex()
def get_proof(self, leaf_index: int) -> list[dict[str, str]]:
if leaf_index < 0 or leaf_index >= self.next_leaf_index:
raise ValueError("Leaf index out of range")
proof = []
index: int = self.leaf_count - 1 + leaf_index
while index > 0:
sibling_index: int = index - 1 if index % 2 == 0 else index + 1
is_left: bool = sibling_index < index
proof.append({
'sibling': self.tree[sibling_index].hex(),
'is_left': is_left
})
index = (index - 1) // 2
return proof
@staticmethod
def verify_proof(leaf: str, proof: list[dict[str, str]], root: str) -> bool:
current: bytes = IncrementalMerkleTree._hash(leaf.encode())
for node in proof:
sibling: bytes = bytes.fromhex(node['sibling'])
if node['is_left']:
current = IncrementalMerkleTree._hash(sibling, current)
else:
current = IncrementalMerkleTree._hash(current, sibling)
return current.hex() == root
def print_tree(self) -> None:
def format_hash(hash_bytes: bytes) -> str:
return hash_bytes.hex()[:8] # Show first 8 characters for brevity
levels = []
for i in range(self.depth + 1):
start: int = 2**i - 1
end: int = 2**(i + 1) - 1
levels.append([format_hash(h) for h in self.tree[start:end]])
max_level = len(levels[-1])
max_width = max_level * 10 # Assuming each hash is 8 characters long
def center_text(text, width):
if len(text) >= width:
return text
space = (width - len(text)) // 2
return ' ' * space + text + ' ' * space
for i, level in enumerate(levels):
level_width = len(level) * 10
spacing = (max_width - level_width) // len(level)
padded_level = [center_text(h, 10 + spacing) for h in level]
print(''.join(padded_level).center(max_width))
def test_merkle_tree(depth=3, nb_leaves=0) -> None:
    """Build a tree, check every inclusion proof, then check that a forged leaf fails.

    Args:
        depth: Tree depth; the tree has ``2**depth`` leaf slots.
        nb_leaves: Number of leaves to insert. A value <= 0 picks a random
            count between 1 and ``2**depth`` inclusive.

    Raises:
        AssertionError: if a genuine proof fails or the forged one passes.
        ValueError: if ``nb_leaves`` exceeds the tree capacity (raised by
            ``IncrementalMerkleTree.update``).
    """
    tree = IncrementalMerkleTree(depth)
    if nb_leaves <= 0:
        nb_leaves = random.randint(1, tree.leaf_count)
    # Implicit concatenation instead of a line break inside {...}, which is
    # only legal from Python 3.12 (PEP 701).
    print(f"\nGenerating {nb_leaves} leaves, tree depth: {depth}, "
          f"max leaves: {tree.leaf_count}")

    leaves = [f"leaf_{i}" for i in range(nb_leaves)]
    for leaf in leaves:
        tree.update(leaf)

    # print_tree uses 10 characters per leaf, so only print small trees.
    if depth <= 5:
        print("Tree structure:")
        tree.print_tree()

    root = tree.get_root()
    print(f"Root: {root}")

    for i, leaf in enumerate(leaves):
        proof = tree.get_proof(i)
        is_valid = IncrementalMerkleTree.verify_proof(leaf, proof, root)
        if not is_valid:
            print(f" Proof: {proof}")
        assert is_valid

    # A leaf that was never inserted must not verify with leaf_0's proof.
    invalid_leaf = "invalid_leaf"
    invalid_proof = tree.get_proof(0)
    is_valid = IncrementalMerkleTree.verify_proof(invalid_leaf, invalid_proof, root)
    assert not is_valid
if __name__ == "__main__":
    # One keyword-argument set per run, executed in order. Omitting
    # nb_leaves leaves it at its default of 0, which makes test_merkle_tree
    # pick a random leaf count.
    runs = (
        {"depth": 14, "nb_leaves": 16_000},
        {"depth": 18, "nb_leaves": 1},
        {"depth": 1, "nb_leaves": 1},
        {"depth": 4},
    )
    for run_kwargs in runs:
        test_merkle_tree(**run_kwargs)
    print("All tests passed!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment