Skip to content

Instantly share code, notes, and snippets.

@sunfkny
Created June 2, 2023 07:52
Show Gist options
  • Save sunfkny/2a3f76555a23c38e515a7470b6ac47ea to your computer and use it in GitHub Desktop.
Save sunfkny/2a3f76555a23c38e515a7470b6ac47ea to your computer and use it in GitHub Desktop.
from typing import Iterable, List, Literal
from typing_extensions import TypedDict
from eth_typing import HexStr
from eth_utils.crypto import keccak
from eth_utils.hexadecimal import encode_hex, remove_0x_prefix
def keccak_hex(hexstr: str):
    """Keccak-256 hash the bytes spelled by *hexstr*; return the digest hex-encoded.

    The raw digest from eth_utils ``keccak`` is converted with eth_utils
    ``encode_hex`` before being returned.
    """
    digest = keccak(hexstr=hexstr)
    return encode_hex(digest)
def keccak_bytes(hexstr: str):
    """Keccak-256 hash the bytes spelled by *hexstr*; return the raw digest bytes."""
    digest = keccak(hexstr=hexstr)
    return digest
class MerkleProof:
    """Verify sorted-pair Keccak-256 Merkle proofs expressed as hex strings.

    Each step hashes the running value together with one sibling, ordering
    the two so the smaller comes first. A proof is therefore just the list
    of sibling hashes, with no left/right flags.

    Hex inputs may be given with or without a ``0x`` prefix and in any
    letter case. They are normalized before ordering and before the final
    comparison, so only the encoded bytes matter.
    """

    @staticmethod
    def _normalize(value: HexStr) -> str:
        """Return *value* without its ``0x`` prefix, lower-cased.

        Python compares strings case-sensitively ("A" < "a"). Without this
        step, the same hash spelled differently could sort differently or
        compare unequal.
        """
        return remove_0x_prefix(value).lower()

    @classmethod
    def _hash_pair(cls, a: HexStr, b: HexStr):
        """Hash two nodes smallest-first and return the ``keccak_hex`` digest.

        Ordering compares the normalized hex text. For equal-length digests
        (e.g. 32-byte hashes) this is the same as comparing their byte values.
        """
        a_hex, b_hex = cls._normalize(a), cls._normalize(b)
        if a_hex > b_hex:
            a_hex, b_hex = b_hex, a_hex
        return keccak_hex(a_hex + b_hex)

    @classmethod
    def _process_proof(cls, proof: List[HexStr], leaf: HexStr):
        """Fold *proof* into *leaf* and return the resulting root candidate.

        With an empty *proof*, *leaf* is returned exactly as given.
        """
        computed_hash = leaf
        for sibling in proof:
            computed_hash = cls._hash_pair(computed_hash, sibling)
        return computed_hash

    @classmethod
    def verify(cls, proof: List[HexStr], root: HexStr, leaf: HexStr):
        """Return True if *proof* links *leaf* to *root*.

        The computed root and *root* are compared after normalization. A
        correct proof is therefore accepted regardless of ``0x`` prefix or
        hex letter case.
        """
        computed = cls._process_proof(proof, leaf)
        return cls._normalize(computed) == cls._normalize(root)
# One step of a MerkleTree proof:
#   "data"     - the sibling node's hash (raw bytes)
#   "position" - which side of the current node the sibling sits on
ProofItem = TypedDict(
    "ProofItem",
    {
        "position": Literal["left", "right"],
        "data": bytes,
    },
)
class MerkleTree:
    """Sorted-leaf, sorted-pair Keccak-256 Merkle tree over raw ``bytes`` nodes.

    Construction:
    - Leaves are sorted.
    - Each parent is the Keccak-256 hash of its two children concatenated
      smallest-first. This is the same pairing rule ``MerkleProof`` uses.
    - When a layer has an odd number of nodes, the last node is carried up
      to the next layer unchanged; it is not hashed with itself.

    Attributes:
        leaves: the sorted leaf hashes (a private copy of the input).
        layers: ``layers[0]`` is ``leaves``. Each later layer holds the
            parents of the one before it, and ``layers[-1]`` holds the root.
    """

    @staticmethod
    def make_leaves(address_list: Iterable[HexStr]):
        """Hash each hex-encoded entry into a leaf with ``keccak_bytes``."""
        return [keccak_bytes(address) for address in address_list]

    def __init__(self, leaves: Iterable[bytes]):
        """Build the whole tree from *leaves*.

        The input is copied before sorting, so the caller's list is not
        reordered as a side effect.
        """
        self.leaves = list(leaves)
        self._process_leaves()

    def _process_leaves(self):
        """Sort the leaves and build every layer up to the root."""
        self.leaves.sort()
        self.layers = [self.leaves]
        self._create_hashes(self.leaves)

    def _create_hashes(self, nodes: List[bytes]):
        """Append parent layers to ``self.layers`` until one node remains."""
        while len(nodes) > 1:
            parents: List[bytes] = []
            # Hash complete pairs only; an odd trailing node is handled below.
            for i in range(0, len(nodes) - 1, 2):
                left, right = nodes[i], nodes[i + 1]
                # Sorted-pair hashing: the parent does not depend on which
                # side each child was on, so proofs need no position flags.
                combined = left + right if left < right else right + left
                parents.append(keccak(primitive=combined))
            if len(nodes) % 2 == 1:
                # Promote the unpaired last node unchanged.
                parents.append(nodes[-1])
            self.layers.append(parents)
            nodes = parents

    def get_hex_layers(self) -> List[List[HexStr]]:
        """Return every layer with each node converted by ``encode_hex``."""
        return [[encode_hex(node) for node in layer] for layer in self.layers]

    def get_root(self) -> bytes:
        """Return the root hash (for a single leaf, the leaf itself).

        Raises:
            IndexError: if the tree was built from no leaves.
        """
        top_layer = self.layers[-1]
        if not top_layer:
            raise IndexError("cannot get the root of an empty Merkle tree")
        return top_layer[0]

    def get_hex_root(self):
        """Return the root hash converted by ``encode_hex``."""
        return encode_hex(self.get_root())

    def get_proof(self, leaf: bytes, index: int = 0) -> List[ProofItem]:
        """Return the sibling nodes that link *leaf* to the root.

        Args:
            leaf: the leaf hash to prove.
            index: position of *leaf* within the SORTED ``self.leaves``.
                Any falsy value, including the default ``0``, means
                "look *leaf* up", so an explicit ``0`` also triggers the
                lookup.

        Returns:
            One ``ProofItem`` per layer where the current node has a sibling,
            ordered from the leaf layer upward. A node promoted without a
            sibling contributes nothing. The list is empty when the lookup
            does not find *leaf*, or when the tree has a single leaf.
        """
        proof: List[ProofItem] = []
        if not index:
            try:
                index = self.leaves.index(leaf)
            except ValueError:
                return []
        for layer in self.layers:
            is_right_node = index % 2 == 1
            sibling = index - 1 if is_right_node else index + 1
            if sibling < len(layer):
                proof.append(
                    {
                        "position": "left" if is_right_node else "right",
                        "data": layer[sibling],
                    }
                )
            # The parent of node i sits at i // 2 in the next layer; this
            # also holds for a promoted odd node.
            index //= 2
        return proof

    def get_hex_proof(self, leaf: bytes, index: int = 0):
        """Return ``get_proof`` sibling hashes converted by ``encode_hex``."""
        return [encode_hex(item["data"]) for item in self.get_proof(leaf, index)]

    def get_hex_proof_by_hexstr(self, hexstr: HexStr):
        """Hash *hexstr* into a leaf with ``keccak_bytes``, then return its hex proof."""
        return self.get_hex_proof(leaf=keccak_bytes(hexstr))

    def get_depth(self) -> int:
        """Return the number of layers above the leaf layer."""
        return len(self.layers) - 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment