Created
June 2, 2023 07:52
-
-
Save sunfkny/2a3f76555a23c38e515a7470b6ac47ea to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Iterable, List, Literal | |
from typing_extensions import TypedDict | |
from eth_typing import HexStr | |
from eth_utils.crypto import keccak | |
from eth_utils.hexadecimal import encode_hex, remove_0x_prefix | |
def keccak_hex(hexstr: str):
    """Hash the bytes described by ``hexstr`` and return the digest hex-encoded.

    ``hexstr`` is handed to ``keccak`` through its ``hexstr`` keyword and the
    resulting digest is passed through ``encode_hex``.
    """
    digest = keccak(hexstr=hexstr)
    return encode_hex(digest)
def keccak_bytes(hexstr: str):
    """Hash the bytes described by ``hexstr`` and return the raw digest.

    ``hexstr`` is handed to ``keccak`` through its ``hexstr`` keyword; the
    digest is returned exactly as ``keccak`` produced it.
    """
    digest = keccak(hexstr=hexstr)
    return digest
class MerkleProof:
    """Verify Merkle proofs whose hashes are given as hex strings.

    Each pair of sibling hashes is combined smallest-first, so a proof is just
    the list of sibling hashes from the leaf up to the root; no left/right
    position information is needed.
    """

    @staticmethod
    def _normalize(value: HexStr) -> str:
        """Return ``value`` without a ``0x`` prefix and in lower case.

        Normalising first makes ordering and equality independent of how the
        caller happened to spell the hex string.
        """
        return remove_0x_prefix(value).lower()

    @classmethod
    def _hash_pair(cls, a: HexStr, b: HexStr):
        """Return the hex keccak hash of ``a`` and ``b`` joined smallest-first.

        The two values are compared after normalisation; comparing the raw
        strings would mis-order mixed-case or inconsistently prefixed input
        (for example ``"B" < "a"`` as text although 0xa < 0xb).
        """
        first, second = sorted((cls._normalize(a), cls._normalize(b)))
        return keccak_hex(first + second)

    @classmethod
    def _process_proof(cls, proof: List[HexStr], leaf: HexStr):
        """Fold ``proof`` into ``leaf`` and return the resulting hash.

        With an empty ``proof`` the leaf itself is returned unchanged.
        """
        computed_hash = leaf
        for sibling in proof:
            computed_hash = cls._hash_pair(computed_hash, sibling)
        return computed_hash

    @classmethod
    def verify(cls, proof: List[HexStr], root: HexStr, leaf: HexStr):
        """Return True when ``proof`` links ``leaf`` to ``root``.

        The comparison ignores letter case and an optional ``0x`` prefix, so
        ``root`` may be supplied in either form. A ``root`` that is not a
        string never matches.
        """
        if not isinstance(root, str):
            return False
        computed_hash = cls._process_proof(proof, leaf)
        return cls._normalize(computed_hash) == cls._normalize(root)
class ProofItem(TypedDict):
    """One step of a Merkle proof: a sibling hash plus the side it sits on."""

    # Side of the sibling relative to the node being proven.
    position: Literal["left", "right"]

    # Raw sibling hash.
    data: bytes
class MerkleTree:
    """Merkle tree built from sorted leaves with sorted-pair keccak hashing.

    ``layers[0]`` holds the sorted leaves; every following layer holds the
    parents of the previous one, and the last layer holds the root. A node
    without a sibling is carried up to the next layer unchanged.
    """

    @staticmethod
    def make_leaves(address_list: Iterable[HexStr]):
        """Return the keccak digest of every hex string in ``address_list``."""
        return [keccak_bytes(address) for address in address_list]

    def __init__(self, leaves: Iterable[bytes]):
        """Build the tree from ``leaves`` (already-hashed byte strings).

        The leaves are copied before sorting so the caller's sequence is not
        reordered as a side effect.
        """
        self.leaves: List[bytes] = list(leaves)
        self._process_leaves()

    def _process_leaves(self):
        """Sort the leaves and rebuild every layer from scratch."""
        self.leaves.sort()
        self.layers: List[List[bytes]] = [self.leaves]
        self._create_hashes(self.leaves)

    def _create_hashes(self, nodes: List[bytes]):
        """Append parent layers to ``self.layers`` until one node remains."""
        while len(nodes) > 1:
            parents: List[bytes] = []
            # Walk complete pairs only; an odd trailing node is handled below.
            for i in range(0, len(nodes) - 1, 2):
                left, right = nodes[i], nodes[i + 1]
                # Smaller value goes first so proofs need no position data.
                combined = left + right if left < right else right + left
                parents.append(keccak(combined))
            if len(nodes) % 2 == 1:
                # A node without a sibling is promoted unchanged.
                parents.append(nodes[-1])
            self.layers.append(parents)
            nodes = parents

    def get_hex_layers(self) -> List[List[HexStr]]:
        """Return every layer with its nodes hex-encoded."""
        return [[encode_hex(node) for node in layer] for layer in self.layers]

    def get_root(self):
        """Return the root hash as bytes.

        Raises:
            IndexError: if the tree was built without any leaves.
        """
        return self.layers[-1][0]

    def get_hex_root(self):
        """Return the root hash hex-encoded."""
        return encode_hex(self.get_root())

    def get_proof(self, leaf: bytes, index: int = 0) -> List[ProofItem]:
        """Return the proof for ``leaf`` as sibling hashes from bottom to top.

        Args:
            leaf: Leaf hash to prove.
            index: Position of the leaf in the sorted leaves. A falsy value
                (the default ``0``) means "look the leaf up".

        Returns:
            One item per layer in which the node has a sibling. An empty list
            is returned when the leaf is not in the tree or when an explicit
            ``index`` is outside the leaf range.
        """
        if not index:
            try:
                index = self.leaves.index(leaf)
            except ValueError:
                return []
        elif not 0 <= index < len(self.leaves):
            # An out-of-range index would otherwise yield a meaningless proof.
            return []

        proof: List[ProofItem] = []
        for layer in self.layers:
            is_right_node = index % 2
            pair_index = index - 1 if is_right_node else index + 1
            # A promoted odd node has no sibling in this layer.
            if pair_index < len(layer):
                proof.append(
                    {
                        "position": "left" if is_right_node else "right",
                        "data": layer[pair_index],
                    }
                )
            index //= 2
        return proof

    def get_hex_proof(self, leaf: bytes, index: int = 0):
        """Return the proof for ``leaf`` as a list of hex-encoded hashes."""
        return [encode_hex(item["data"]) for item in self.get_proof(leaf, index)]

    def get_hex_proof_by_hexstr(self, hexstr: HexStr):
        """Hash ``hexstr`` into a leaf and return its hex-encoded proof."""
        return self.get_hex_proof(leaf=keccak_bytes(hexstr))

    def get_depth(self) -> int:
        """Return the number of layers above the leaf layer."""
        return len(self.layers) - 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment