Last active
October 18, 2023 11:25
-
-
Save parodyBit/6556a1ac9535d8875f989d9b6808b454 to your computer and use it in GitHub Desktop.
Verify Witnet Signed Message in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from typing import List, AnyStr, Union, Tuple | |
import hashlib | |
import random | |
import unicodedata | |
# ---------------------------------------------------------------------------
# Transformation tools
# ---------------------------------------------------------------------------

# A plain list of small non-negative integers: byte values, or the 5-bit groups
# that convert_bits produces for bech32.
Bytes = List[int]
def sha256(x: bytes) -> bytes:
    """Return the raw 32-byte SHA-256 digest of ``x``."""
    hasher = hashlib.sha256()
    hasher.update(x)
    return hasher.digest()
def bytes_to_hex(b: bytes) -> str:
    """Return ``b`` as a lower-case hex string (two digits per byte, no prefix)."""
    hex_string = b.hex()
    return hex_string
def bytes_to_int(bts: bytes) -> int:
    """Interpret ``bts`` as an unsigned big-endian integer (empty input gives 0)."""
    return int.from_bytes(bts, byteorder='big')
def int_to_bytes(i: int) -> bytes:
    """Minimal big-endian encoding of a non-negative integer (0 encodes as one zero byte)."""
    n_bytes = (i.bit_length() + 7) // 8 or 1
    return i.to_bytes(n_bytes, byteorder='big')
def int_to_hex(i: int) -> str:
    """Lower-case hex digits of ``i`` without ``0x`` prefix or zero padding."""
    return f'{i:x}'
def hex_to_bytes(h: str) -> bytes:
    """Parse a hex string (whitespace between byte pairs allowed) into bytes."""
    decoded = bytes.fromhex(h)
    return decoded
def concat_string(values: List[str]) -> str:
    """Concatenate a list of strings with no separator."""
    separator = ''
    return separator.join(values)
def concat_bytes(values: List[bytes]) -> bytes:
    """Concatenate a list of bytes-like objects into a single ``bytes``."""
    separator = b''
    return separator.join(values)
def concat(values: Union[List[str], List[bytes]]) -> Union[str, bytes]:
    """Join a non-empty list whose items are all ``str`` or all bytes-like.

    The type of the first item decides between string joining and bytes joining.

    Raises:
        ValueError: if ``values`` is empty (the result type would be ambiguous).
        TypeError: if the first item is neither ``str`` nor ``bytes``/``bytearray``.
            Previously such input fell off the end of the function and silently
            returned ``None``.
    """
    if not values:
        raise ValueError('concat() needs at least one value to choose str or bytes joining')
    first = values[0]
    if isinstance(first, str):
        return concat_string(values)
    if isinstance(first, (bytes, bytearray)):
        return concat_bytes(values)
    raise TypeError(f'concat() expects str or bytes items, got {type(first).__name__}')
class BaseConversionError(Exception):
    """Raised by convert_bits for out-of-range input values or invalid padding."""
def convert_bits(data: Bytes, from_bits: int, to_bits: int, pad=True) -> Bytes:
    """Regroup ``from_bits``-wide integers into ``to_bits``-wide ones (general power-of-2 base conversion).

    Bits are consumed and emitted most-significant first.

    Args:
        data: input values, each in ``range(2 ** from_bits)``.
        from_bits: width of each input value.
        to_bits: width of each output value.
        pad: when True, leftover bits are zero-padded into one final output value.
            When False, leftover bits must be fewer than ``from_bits`` and all zero.

    Raises:
        BaseConversionError: on an out-of-range input value, or invalid leftover bits
            when ``pad`` is False.
    """
    out_mask = (1 << to_bits) - 1
    acc_mask = (1 << (from_bits + to_bits - 1)) - 1
    accumulator = 0
    pending = 0  # number of bits in `accumulator` not yet emitted
    groups = []
    for item in data:
        if item < 0 or item >> from_bits:
            raise BaseConversionError
        accumulator = ((accumulator << from_bits) | item) & acc_mask
        pending += from_bits
        while pending >= to_bits:
            pending -= to_bits
            groups.append((accumulator >> pending) & out_mask)
    # Left-align the remaining pending bits inside one output group.
    leftover = (accumulator << (to_bits - pending)) & out_mask
    if pad:
        if pending:
            groups.append(leftover)
    elif pending >= from_bits or leftover:
        raise BaseConversionError
    return groups
def normalize_string(txt: AnyStr) -> str:
    """Return ``txt`` in Unicode NFKD form; bytes input is first decoded as UTF-8.

    Raises:
        TypeError: if ``txt`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(txt, str):
        text = txt
    elif isinstance(txt, bytes):
        text = txt.decode('utf8')
    else:
        raise TypeError('String value expected')
    return unicodedata.normalize('NFKD', text)
# ---------------------------------------------------------------------------
# Protobuf Utilities
# ---------------------------------------------------------------------------

# A field key is (field_number << TAG_TYPE_BITS) | wire_type (see make_tag).
TAG_TYPE_BITS = 3
TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1
# Wire types handled by pb_field.
VAR_INT = 0
FIXED64 = 1
LENGTH_DELIMITED = 2
def var_int(value: int) -> bytes:
    """Encode a non-negative integer as a protobuf base-128 varint.

    Each output byte carries 7 value bits, least-significant group first. The high
    bit of every byte except the last is set.

    Args:
        value: the integer to encode. A decimal string is also accepted and is
            converted with ``int()``.

    Returns:
        The encoded bytes. This returns a value; it does not write to a stream.

    Raises:
        ValueError: if ``value`` is negative (unsigned varints only).
    """
    if isinstance(value, str):
        value = int(value)
    if value < 0:
        raise ValueError(f'VarInt value must be non-negative, got {value}')
    out = bytearray()
    while value > 0x7F:
        out.append((value & 0x7F) | 0x80)  # low 7 bits plus continuation flag
        value >>= 7
    out.append(value)
    return bytes(out)
def var_int_serializer(value: int):
    """Serialize ``value`` as a protobuf varint (thin alias of ``var_int``)."""
    encoded = var_int(value)
    return encoded
def bytes_serializer(value: bytes):
    """Length-prefix ``value``: the varint of ``len(value)`` followed by the raw bytes."""
    length_prefix = var_int(len(value))
    return concat([length_prefix, value])
def make_tag(field_number: int, tag: int) -> int:
    """Build a protobuf field key: the field number shifted above the wire-type bits."""
    shifted = field_number << TAG_TYPE_BITS
    return shifted | tag
def make_tag_bytes(field_number: int, tag: int) -> bytes:
    """Return the varint-encoded field key for ``field_number`` and wire type ``tag``."""
    key = make_tag(field_number, tag)
    return var_int_serializer(key)
def pb_field(field_number: int, tag: int, value):
    """Serialize one protobuf field: its varint key followed by the encoded payload.

    Args:
        field_number: protobuf field number.
        tag: wire type, one of VAR_INT, FIXED64 or LENGTH_DELIMITED.
        value: an int for VAR_INT / FIXED64, or bytes for LENGTH_DELIMITED.

    Raises:
        ValueError: for any other wire type. Previously the unsupported branch left
            the payload as an empty list, which failed later inside ``concat`` with a
            misleading TypeError.
    """
    if tag == VAR_INT:
        payload = var_int_serializer(value=value)
    elif tag == LENGTH_DELIMITED:
        payload = bytes_serializer(value=value)
    elif tag == FIXED64:
        # fixed64 payloads are always exactly 8 bytes, little-endian (unsigned).
        payload = int(value).to_bytes(8, 'little')
    else:
        raise ValueError(f'Unsupported protobuf wire type: {tag}')
    return concat([make_tag_bytes(field_number=field_number, tag=tag), payload])
# ---------------------------------------------------------------------------
# Bech32
# ---------------------------------------------------------------------------


class Bech32DecodeError(Exception):
    """Signals a Bech32 string that fails validation in bech32_decode_address."""


# Bech32 alphabet: the character at index i encodes the 5-bit value i.
CHARSET = 'qpzry9x8gf2tvdw0s3jn54khce6mua7l'
def bech32_poly_mod(values) -> int:
    """Compute the BCH checksum polynomial residue over a sequence of 5-bit values.

    A correctly checksummed string (HRP-expanded prefix, data, checksum) yields 1.
    """
    generators = (0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3)
    checksum = 1
    for v in values:
        high = checksum >> 25
        checksum = ((checksum & 0x1ffffff) << 5) ^ v
        for bit, gen in enumerate(generators):
            if (high >> bit) & 1:
                checksum ^= gen
    return checksum
def bech32_hrp_expand(hrp: str) -> List[int]:
    """Expand the human-readable part for checksumming.

    The result is the high 3 bits of each character, a 0 separator, then the low
    5 bits of each character.
    """
    high = [ord(c) >> 5 for c in hrp]
    low = [ord(c) & 31 for c in hrp]
    return high + [0] + low
def bech32_verify_checksum(hrp: str, data) -> bool:
    """Return True if ``data`` (5-bit values including the 6 checksum symbols) checks out for ``hrp``."""
    residue = bech32_poly_mod(bech32_hrp_expand(hrp) + data)
    return residue == 1
def bech32_create_checksum(hrp: str, data):
    """Return the six 5-bit checksum symbols for ``hrp`` and the 5-bit ``data`` values."""
    padded = bech32_hrp_expand(hrp) + data + [0] * 6
    residue = bech32_poly_mod(padded) ^ 1
    # Emit the 30-bit residue as six 5-bit symbols, most significant first.
    return [(residue >> shift) & 31 for shift in range(25, -1, -5)]
def bech32_encode_address(hrp: str, data: str) -> str:
    """Bech32-encode a hex payload under the human-readable part ``hrp``.

    Args:
        hrp: human-readable prefix, e.g. ``'wit'``.
        data: payload as a hex string.

    Returns:
        ``hrp + '1' + data characters + 6 checksum characters``, NFKD-normalized.
    """
    payload = list(bytes.fromhex(data))
    # BIP-173 encodes 8->5 *with* zero padding. The previous pad=False raised
    # BaseConversionError whenever the payload's bit length was not a multiple of 5.
    # For 20-byte hashes (160 bits) the output is unchanged.
    groups = convert_bits(data=payload, from_bits=8, to_bits=5, pad=True)
    checksum = bech32_create_checksum(hrp=hrp, data=groups)
    return normalize_string(hrp + '1' + ''.join(CHARSET[i] for i in groups + checksum))
def bech32_decode_address(bech: str):
    """Decode a Bech32 string and return its data payload as a list of byte values.

    Validation follows BIP-173: printable US-ASCII only, no mixed case, the last
    '1' as separator, a non-empty HRP, at least 6 checksum characters, a valid
    checksum, and zero padding bits. The HRP itself is not checked against any
    expected prefix.

    Raises:
        Bech32DecodeError: on any validation failure. Previously errors were only
            printed and decoding continued, so strings with bad checksums or
            invalid characters were silently decoded into garbage.
    """
    if any(ord(x) < 33 or ord(x) > 126 for x in bech):
        raise Bech32DecodeError('Character outside US-ASCII [33-126] range')
    # Must be checked on the original input; after lowercasing it can never fail.
    if (bech.lower() != bech) and (bech.upper() != bech):
        raise Bech32DecodeError('Mixed upper and lower case')
    bech = bech.lower()
    # The separator is the *last* '1': the human-readable part may itself contain '1'.
    separator = bech.rfind('1')
    if separator == 0:
        raise Bech32DecodeError('Empty human readable part')
    if separator == -1:
        raise Bech32DecodeError('No separator character')
    if separator + 7 > len(bech):
        raise Bech32DecodeError('Checksum too short')
    hrp_ = bech[:separator]
    data_part = bech[separator + 1:]
    if not all(x in CHARSET for x in data_part):
        raise Bech32DecodeError('Character not in charset')
    data = [CHARSET.find(x) for x in data_part]
    if not bech32_verify_checksum(hrp_, data):
        raise Bech32DecodeError('Invalid checksum')
    decoded = data[:-6]  # strip the 6 checksum symbols
    if len(decoded) < 2:
        raise Bech32DecodeError('Witness program too short')
    try:
        # pad=False rejects non-zero padding bits instead of inventing a partial byte.
        return convert_bits(decoded, from_bits=5, to_bits=8, pad=False)
    except BaseConversionError as error:
        raise Bech32DecodeError('Invalid padding in data part') from error
# ---------------------------------------------------------------------------
# Schema Items
# ---------------------------------------------------------------------------


@dataclass
class PublicKey:
    """Schema item holding a public key split into its prefix byte and key bytes.

    ``compressed`` holds the leading SEC1 prefix byte: WitPublicKey stores 2 or 3
    for compressed keys, and 4 for uncompressed keys via to_json. ``_bytes`` holds
    the key bytes that follow the prefix.
    """

    _bytes: bytes  # key bytes without the prefix byte
    compressed: int  # SEC1 prefix byte

    @classmethod
    def from_json(cls, data: dict):
        """Build from ``{'bytes': ..., 'compressed': int}``.

        ``bytes`` may be a list of ints (``to_json(as_hex=False)``) or a hex string
        (``to_json()``'s default output). Hex input previously raised TypeError, so
        ``to_json()`` output could not be read back.
        """
        raw = data['bytes']
        key_bytes = bytes.fromhex(raw) if isinstance(raw, str) else bytes(raw)
        return cls(_bytes=key_bytes, compressed=data['compressed'])

    def to_json(self, as_hex: bool = True):
        """Serialize to a dict; ``bytes`` is a hex string when ``as_hex``, else a list of ints."""
        key_bytes = bytes_to_hex(self._bytes) if as_hex else list(self._bytes)
        return {'bytes': key_bytes, 'compressed': self.compressed}

    def pb_bytes(self) -> bytes:
        """Protobuf field 1 (length-delimited) containing the prefix byte followed by the key bytes."""
        return pb_field(
            field_number=1,
            tag=LENGTH_DELIMITED,
            value=concat([int_to_bytes(self.compressed), self._bytes]),
        )

    def hash(self):
        """SHA-256 digest of :meth:`pb_bytes`."""
        return sha256(self.pb_bytes())
@dataclass
class PublicKeyHash:
    """Schema item wrapping raw public-key-hash bytes, convertible to and from a ``wit`` address.

    The digest helper is named ``pb_hash`` rather than ``hash``. A method called
    ``hash`` collided with the ``hash`` field: the dataclass took the function as
    the field's default value, and every instance shadowed it, so ``pkh.hash()``
    raised ``TypeError: 'bytes' object is not callable``.
    """

    hash: bytes  # raw public-key-hash bytes (the bech32 payload of the address)

    @classmethod
    def from_address(cls, data: str):
        """Decode a bech32 address string into a PublicKeyHash."""
        return cls(hash=bytes(bech32_decode_address(data)))

    def to_address(self):
        """Encode the hash as a bech32 address with the ``wit`` prefix."""
        return bech32_encode_address(hrp='wit', data=bytes_to_hex(self.hash))

    def pb_bytes(self) -> bytes:
        """Protobuf field 1 (length-delimited) containing the hash bytes."""
        return pb_field(field_number=1, tag=LENGTH_DELIMITED, value=self.hash)

    def pb_hash(self) -> bytes:
        """SHA-256 digest of :meth:`pb_bytes`."""
        return sha256(self.pb_bytes())
# ---------------------------------------------------------------------------
# Number Theory
# ---------------------------------------------------------------------------


def miller_rabin(n, runs=40):
    """Probabilistic Miller-Rabin primality test.

    Args:
        n: integer to test.
        runs: number of random-witness rounds. 40 is the commonly cited choice for
            cryptographic use, see
            http://stackoverflow.com/questions/6325576/how-many-iterations-of-rabin-miller-should-i-use-for-cryptographic-safe-primes

    Returns:
        False if ``n`` is certainly composite or less than 2. True if ``n`` passed
        every round, in which case it is prime with high probability.
    """
    if n < 2:
        # 0, 1 and negative numbers are not prime. Negative n used to index the
        # small-prime table from its end (e.g. n = -1 returned True).
        return False
    if n < 6:
        return n in (2, 3, 5)
    if n & 1 == 0:  # even numbers above 2 are composite
        return False
    # Write n - 1 = 2**r * s with s odd.
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2
    for _ in range(runs):
        a = random.randrange(3, n - 1, 2)
        x = pow(a, s, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            # `a` is a witness that n is composite.
            return False
    return True
def legendre(a, p):
    """Legendre symbol (a | p) via Euler's criterion: 1, -1, or 0 when p divides a.

    https://en.wikipedia.org/wiki/Legendre_symbol
    Asserts that ``p`` is (probably) prime using miller_rabin.
    """
    assert miller_rabin(p), f"{p} is not a prime"
    symbol = pow(a, (p - 1) // 2, p)
    if symbol == p - 1:
        return -1
    return symbol
def xgcd(b: int, n: int) -> Tuple[int, int, int]:
    """Extended Euclid: return ``(g, x, y)`` with ``b*x + n*y == g == gcd(b, n)`` for positive b, n."""
    old_x, x = 1, 0
    old_y, y = 0, 1
    while n:
        quotient = b // n
        b, n = n, b - quotient * n
        old_x, x = x, old_x - quotient * x
        old_y, y = y, old_y - quotient * y
    return b, old_x, old_y
def mulinv(b, n, algo=xgcd):
    """Return the inverse of ``b`` modulo ``n``, reduced into ``range(n)``.

    ``algo`` must return ``(g, x, y)`` with ``b*x + n*y == g``; only g and x are used.
    Asserts that ``b`` and ``n`` are coprime.
    """
    gcd, coeff, _unused = algo(b, n)
    assert gcd == 1, 'Numbers must be coprimes'
    return coeff % n
def modsqrt(a, p):
    """
    Find a square root of ``a`` modulo the prime ``p``.

    Solves x^2 = a (mod p) and returns x; note that p - x is also a root.
    0 is returned when no square root exists, and also when a = 0 (mod p).

    Based on https://eli.thegreenplace.net/2009/03/07/computing-modular-square-roots-in-python
    The Tonelli-Shanks algorithm is used, except for the simple cases below whose
    solution follows from an identity.
    """
    # p == 2 must be handled before the Legendre check. Euler's criterion degenerates
    # for p == 2 (legendre() returns -1 for every a), so the former `p == 2` branch
    # placed after that check was unreachable. It also returned p instead of a root.
    if p == 2:
        return a % 2
    # Non-residues have no root. a = 0 (mod p) gives a Legendre symbol of 0 and
    # also returns 0 here, which is its (trivial) root.
    if legendre(a, p) != 1:
        return 0
    # For p = 3 (mod 4) the root is a closed-form power; secp256k1's P is such a prime.
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)
    # Partition p-1 to s * 2^e for an odd s (i.e. reduce all the powers of 2 from p-1).
    s = p - 1
    e = 0
    while s % 2 == 0:
        s //= 2
        e += 1
    # Find some 'n' with a legendre symbol n|p = -1. Shouldn't take long.
    n = 2
    while legendre(n, p) != -1:
        n += 1
    # See "Square roots from 1; 24, 51, 10 to Dan Shanks" by Ezra Brown.
    # x is a guess of the square root that improves with each iteration.
    # b is the "fudge factor": how far off the guess is. The invariant
    # x^2 = ab (mod p) holds throughout the loop.
    # g holds successive powers of n used to update both x and b.
    # r is the exponent and decreases with each update.
    x = pow(a, (s + 1) // 2, p)
    b = pow(a, s, p)
    g = pow(n, s, p)
    r = e
    while True:
        t = b
        m = 0
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)
        if m == 0:
            return x
        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m
# ---------------------------------------------------------------------------
# Secp256k1 Stuff
# ---------------------------------------------------------------------------

# Field prime: coordinates are integers modulo P.
P = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
# Generator point (x, y)
G = (
    0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
    0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8,
)
# Order of the generator
N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
# Curve coefficients A and B of y^2 = x^3 + A*x + B
A: int = 0
B: int = 7
class Point:
    """An affine point (x, y) on ``curve``.

    The constructor asserts that the point satisfies the curve equation.
    Arithmetic is delegated to the curve: ``p + q`` calls ``curve.point_add`` and
    ``p * k`` / ``k * p`` call ``curve.point_mul``. ``None`` is treated as the point
    at infinity (the identity); Curve.point_mul returns ``None`` for scalars that
    are multiples of the curve order.
    """

    def __init__(self, x, y, curve=None):
        if curve is None:
            # Same exception type as before (``self in None``), with a clearer message.
            raise TypeError('Point requires a curve')
        self.x = x
        self.y = y
        self.curve = curve
        assert self in curve, f"Point {x}, {y} not in curve"

    def __add__(self, other):
        if other is None:  # P + O = P
            return self
        assert self.curve == other.curve, 'Cannot add points on different curves'
        return self.curve.point_add(self, other)

    def __radd__(self, other):
        # Supports ``None + point`` (O + P = P); other operand types are not supported.
        if other is None:
            return self
        return NotImplemented

    def __sub__(self, other):
        return self + (other * -1)

    def __mul__(self, other: int):
        assert isinstance(other, int), 'Multiplication is only defined between a point and an integer'
        return self.curve.point_mul(self, other)

    # Scalar multiplication commutes: allow ``k * point`` as well as ``point * k``.
    __rmul__ = __mul__

    def __repr__(self):
        return f"Point({self.x}, {self.y}, {self.curve.name})"

    def __eq__(self, other):
        # Returning NotImplemented lets ``point == non_point`` evaluate to False
        # instead of raising AttributeError.
        if not isinstance(other, Point):
            return NotImplemented
        prime = self.curve.prime
        return self.x % prime == other.x % prime and self.y % prime == other.y % prime
@dataclass
class Curve:
    """Short Weierstrass curve ``y^2 = x^3 + a*x + b`` over GF(prime).

    ``generator`` may be given as an ``(x, y)`` tuple; __post_init__ turns it into a
    Point bound to this curve. Point addition uses affine chord/tangent formulas with
    Fermat inverses (``pow(v, prime - 2, prime)``), which assumes ``prime`` is prime.
    ``None`` represents the point at infinity (the group identity).
    """

    prime: int  # P
    a: int
    b: int
    generator: Union[Tuple, Point]
    order: int  # N
    name: str

    def __post_init__(self):
        # isinstance (rather than comparing the type's name) also accepts tuple subclasses.
        if isinstance(self.generator, tuple):
            self.generator = Point(*self.generator, curve=self)

    def point_add(self, p, q):
        """Return ``p + q``.

        https://en.wikipedia.org/wiki/Elliptic_curve_point_multiplication#Point_addition

        ``None`` operands are the point at infinity and act as the identity.
        NOTE: ``p + (-p)`` (whose true sum is the point at infinity) is not
        special-cased. The formula then yields coordinates that generally fail
        Point's on-curve assertion.
        """
        _p = self.prime
        if p is None:
            return q
        if q is None:
            return p
        if p == q:
            # Tangent slope (3x^2 + a) / 2y. The `a` term was previously omitted, which
            # is only correct for curves with a == 0 (such as secp256k1). Reducing mod
            # _p keeps lam small; the result is the same modulo _p.
            lam = (3 * p.x * p.x + self.a) * pow(2 * p.y % _p, _p - 2, _p) % _p
        else:
            lam = pow(q.x - p.x, _p - 2, _p) * (q.y - p.y) % _p
        rx = lam ** 2 - p.x - q.x
        ry = lam * (p.x - rx) - p.y
        return Point(rx % _p, ry % _p, curve=self)

    def point_mul(self, p, d):
        """Return ``d * p`` by double-and-add over the bits of ``d mod order`` (LSB first).

        Returns ``None`` (the point at infinity) when ``d`` is a multiple of ``order``.
        """
        d = d % self.order
        n = p
        q = None
        for i in reversed(format(d, 'b')):
            if i == '1':
                if q is None:
                    q = n
                else:
                    q = self.point_add(q, n)
            n = self.point_add(n, n)
        return q

    def __contains__(self, point):
        return point.y ** 2 % self.prime == (point.x ** 3 + self.a * point.x + self.b) % self.prime

    def f(self, x):
        """Compute y**2 = x^3 + ax + b in field FP"""
        return (x ** 3 + self.a * x + self.b) % self.prime
# The secp256k1 curve instance used by WitPublicKey.decode and Signature.verify_hash.
CURVE = Curve(prime=P, a=A, b=B, generator=G, order=N, name='secp256k1')
class WitPublicKey:
    """A secp256k1 public key (a Point on CURVE) with SEC1 encoding and Witnet address helpers."""

    def __init__(self, point: 'Point'):
        self.point = point

    def __eq__(self, other: 'WitPublicKey') -> bool:
        if not isinstance(other, WitPublicKey):
            return NotImplemented
        return self.point == other.point

    def __repr__(self) -> str:
        return f"PublicKey({self.encode().hex()})"

    @classmethod
    def decode(cls, key: bytes) -> 'WitPublicKey':
        """Parse a SEC1 public key.

        Accepts 65-byte uncompressed keys (``0x04 || x || y``) and 33-byte compressed
        keys (``0x02``/``0x03 || x``, where the prefix gives the parity of y).

        Raises:
            AssertionError: on a wrong length, unknown prefix, out-of-range coordinate,
                or a point that is not on the curve.
        """
        if key.startswith(b'\x04'):  # uncompressed key
            assert len(key) == 65, 'An uncompressed public key must be 65 bytes long'
            x, y = bytes_to_int(key[1:33]), bytes_to_int(key[33:])
        else:  # compressed key
            assert len(key) == 33, 'A compressed public key must be 33 bytes long'
            assert key[0] in (2, 3), 'Wrong key format'
            x = bytes_to_int(key[1:])
            root = modsqrt(CURVE.f(x), P)
            # Pick the root whose parity matches the prefix (0x03 odd, 0x02 even).
            want_odd = key[0] == 3
            y = root if (root % 2 == 1) == want_odd else -root % P
        # SEC1 requires canonical coordinates. Point's on-curve check works modulo P,
        # so without this test x + P would be accepted as well.
        assert x < P and y < P, 'Public key coordinate out of range'
        return cls(Point(x, y, curve=CURVE))

    @classmethod
    def from_hex(cls, hexstring: str) -> 'WitPublicKey':
        return cls.decode(hex_to_bytes(hexstring))

    @property
    def x(self) -> int:
        """X coordinate of the (X, Y) point"""
        return self.point.x

    @property
    def y(self) -> int:
        """Y coordinate of the (X, Y) point"""
        return self.point.y

    def encode(self, compressed=True) -> bytes:
        """SEC1-encode as ``0x02``/``0x03 || x`` when compressed, else ``0x04 || x || y``."""
        x_bytes = int_to_bytes(self.x).rjust(32, b'\x00')
        if compressed:
            return (b'\x03' if self.y & 1 else b'\x02') + x_bytes
        return b'\x04' + x_bytes + int_to_bytes(self.y).rjust(32, b'\x00')

    def to_json(self, compressed=True):
        """Return ``{'bytes': <key bytes after the prefix>, 'compressed': <prefix byte>}``."""
        enc = self.encode(compressed=compressed)
        return {'bytes': list(enc[1::]), 'compressed': enc[0]}

    def pub_key(self):
        """Return the PublicKey schema item for the compressed form of this key."""
        return PublicKey(_bytes=int_to_bytes(self.x).rjust(32, b'\x00'), compressed=3 if self.y & 1 else 2)

    @classmethod
    def from_json(cls, data: dict):
        """Build a key from ``{'bytes': ..., 'compressed': prefix}``.

        ``bytes`` may be a list of ints or a hex string. Fields are looked up by name;
        the previous ``data.values()`` unpacking depended on key order and failed on
        hex strings, which is what ``PublicKey.to_json()`` emits by default.
        """
        raw, prefix = data['bytes'], data['compressed']
        key_bytes = hex_to_bytes(raw) if isinstance(raw, str) else bytes(raw)
        return cls.decode(int_to_bytes(prefix) + key_bytes.rjust(32, b'\x00'))

    @classmethod
    def from_schema(cls, public_key: PublicKey):
        """Build a key from a PublicKey schema item."""
        return cls.from_json(public_key.to_json(as_hex=False))

    def hex(self, compressed=True) -> str:
        return bytes_to_hex(self.encode(compressed=compressed))

    def to_address(self) -> str:
        """Return the ``wit`` bech32 address of :meth:`to_pkh`."""
        groups = convert_bits(list(self.to_pkh()), from_bits=8, to_bits=5, pad=True)
        checksum = bech32_create_checksum('wit', groups)
        return 'wit' + '1' + ''.join(CHARSET[i] for i in groups + checksum)

    def to_pkh(self):
        """Public key hash: the first 20 bytes of SHA-256 over the compressed SEC1 encoding."""
        return sha256(self.encode(compressed=True))[:20]
class Signature:
    """An ECDSA signature (r, s) over secp256k1 with DER encoding and verification.

    By default ``s`` is normalized to the lower half of the order (BIP-62 "low S").
    ``(r, s)`` and ``(r, N - s)`` verify identically, so this does not change validity.
    """

    def __init__(self, r, s, force_low_s=True):
        self.r = r
        if force_low_s:
            # https://github.com/bitcoin/bips/blob/master/bip-0062.mediawiki#low-s-values-in-signatures
            self.s = s if s <= N // 2 else N - s
        else:
            self.s = s

    @classmethod
    def decode(cls, bts):
        """Parse a DER signature ``0x30 len 0x02 len_r r 0x02 len_s s``.

        Raises:
            AssertionError: on a structural error (bad tag, bad length, trailing bytes).
            IndexError: if the input ends before a required byte.
        """
        from collections import deque
        data = deque(bts)
        lead = data.popleft()
        # Assert on the byte itself so the message shows it (previously a bool was
        # formatted, printing 0x0).
        assert lead == 0x30, f'Invalid leading byte: 0x{lead:x}'  # ASN1 SEQUENCE
        sequence_length = data.popleft()
        assert sequence_length <= 70, f'Invalid Sequence length: {sequence_length}'
        assert sequence_length == len(data), \
            f'Sequence length {sequence_length} does not match the {len(data)} remaining bytes'
        lead = data.popleft()
        assert lead == 0x02, f'Invalid r leading byte: 0x{lead:x}'  # 0x02 byte before r
        len_r = data.popleft()
        assert len_r <= 33, f'Invalid r length: {len_r}'
        bts = bytes(data)
        r, data = bytes_to_int(bts[:len_r]), deque(bts[len_r:])
        lead = data.popleft()
        assert lead == 0x02, f'Invalid s leading byte: 0x{lead:x}'  # 0x02 byte before s
        len_s = data.popleft()
        assert len_s <= 33, f'Invalid s length: {len_s}'
        bts = bytes(data)
        s, rest = bytes_to_int(bts[:len_s]), bts[len_s:]
        assert len(rest) == 0, f'{len(rest)} leftover bytes'
        return cls(r, s)

    def encode(self, compact=False):
        """Return the DER encoding, or the 64-byte ``r || s`` form when ``compact``.

        DER rules: https://github.com/bitcoin/bips/blob/master/bip-0062.mediawiki#der-encoding
        """
        if compact:
            # Fixed-width 32-byte big-endian r and s without DER sign padding, so the
            # result splits unambiguously at byte 32. (Previously the variable-length
            # DER integers, including any 0x00 sign byte, were concatenated.)
            return int_to_bytes(self.r).rjust(32, b'\x00') + int_to_bytes(self.s).rjust(32, b'\x00')
        r = int_to_bytes(self.r)
        if r[0] > 0x7f:  # DER integers are signed: prefix 0x00 to keep them positive
            r = b'\x00' + r
        s = int_to_bytes(self.s)
        if s[0] > 0x7f:
            s = b'\x00' + s
        len_r = int_to_bytes(len(r))
        len_s = int_to_bytes(len(s))
        len_sig = int_to_bytes(len(r) + len(s) + 4)
        return b'\x30' + len_sig + b'\x02' + len_r + r + b'\x02' + len_s + s

    def verify_hash(self, _hash, public_key):
        """Return True if this signature is valid for digest ``_hash`` under ``public_key``.

        ``public_key`` is expected to be a WitPublicKey (its ``.point`` is used).
        """
        if not (1 <= self.r < N and 1 <= self.s < N):
            return False
        e = bytes_to_int(_hash)
        # ECDSA uses only the leftmost bit_length(N) bits of the digest (no-op for SHA-256).
        excess_bits = len(_hash) * 8 - N.bit_length()
        if excess_bits > 0:
            e >>= excess_bits
        w = mulinv(self.s, N)
        u1 = (e * w) % N
        u2 = (self.r * w) % N
        # A scalar that is 0 mod N makes point_mul return None (the point at infinity).
        # Previously `None + point` raised TypeError.
        p1 = CURVE.generator * u1
        p2 = public_key.point * u2
        if p1 is None:
            point = p2
        elif p2 is None:
            point = p1
        else:
            point = p1 + p2
        if point is None:
            return False
        return self.r % N == point.x % N

    @classmethod
    def from_hex(cls, hex_string):
        return cls.decode(hex_to_bytes(hex_string))

    def __repr__(self):
        return f"{self.__class__.__name__}({int_to_hex(self.r)}, {int_to_hex(self.s)})"

    def __eq__(self, other):
        if not isinstance(other, Signature):
            return NotImplemented
        return self.r == other.r and self.s == other.s

    def hex(self):
        return bytes_to_hex(self.encode())
def is_signature(hex_string):
    """Return True if ``hex_string`` (a hex ``str`` or raw ``bytes``) parses as a DER signature.

    Any parse failure yields False:
    - structural errors (AssertionError);
    - truncated input (IndexError);
    - non-hex text (ValueError from bytes.fromhex, which previously escaped);
    - non-string input (TypeError).
    """
    try:
        if isinstance(hex_string, bytes):
            Signature.decode(hex_string)
        else:
            Signature.from_hex(hex_string)
    except (AssertionError, IndexError, ValueError, TypeError):
        return False
    return True
def verify_message(data) -> bool:
    """Verify a Witnet signed message.

    Args:
        data: mapping with ``address``, ``message`` (str), ``public_key`` (hex SEC1 key)
            and ``signature`` (hex string or raw bytes of a DER signature).

    Returns:
        True only if the public key hashes to ``address`` and the signature is valid
        over SHA-256 of the UTF-8 encoded message. Missing or malformed input is
        reported on stdout and yields False instead of raising.
    """
    try:
        address = data["address"]
        message = data["message"]
        raw_signature = data["signature"]
        public_key = WitPublicKey.from_hex(data["public_key"])
        # Cheap checks first, so no EC arithmetic runs for a key that does not belong
        # to the claimed address or for an unparsable signature.
        assert address == public_key.to_address(), "The Public key does not match the given address."
        assert is_signature(raw_signature), "The Signature bytes are not formatted correctly."
        assert isinstance(message, str), "The message must be a string."
        if isinstance(raw_signature, bytes):
            signature = Signature.decode(raw_signature)
        else:
            signature = Signature.from_hex(raw_signature)
        message_hash = sha256(message.encode('utf-8'))
        return signature.verify_hash(message_hash, public_key)
    except (AssertionError, KeyError, ValueError, IndexError) as e:
        print(f" Error: {e}")
        return False
if __name__ == "__main__":
    # Known-good signed message: must verify.
    sig = {
        "address": "wit174la8pevl74hczcpfepgmt036zkmjen4hu8zzs",
        "message": "Hello World",
        "public_key": "03d5761050a170c53bd09dcd1d6a69e2053197ad55bdee169c65dc580eaec6bf4c",
        "signature": "3045022100bb317d58ef1aced559b18ba4b0d80c8fa801f7faa707c1150f3c0bf43e03043402204db123654b58969ea64b3014df52c33db469a741ba2524c0f2c0cb1800f73370",
    }
    assert verify_message(sig) is True

    # Tampering with the message must invalidate the signature.
    sig["message"] = "Hello"
    assert verify_message(sig) is False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment