Skip to content

Instantly share code, notes, and snippets.

@parodyBit
Last active October 18, 2023 11:25
Show Gist options
  • Save parodyBit/6556a1ac9535d8875f989d9b6808b454 to your computer and use it in GitHub Desktop.
Save parodyBit/6556a1ac9535d8875f989d9b6808b454 to your computer and use it in GitHub Desktop.
Verify Witnet Signed Message in Python
from dataclasses import dataclass
from typing import List, AnyStr, Union, Tuple
import hashlib
import random
import unicodedata
'''
Transformation tools
'''
# Alias for a list of small non-negative ints: the 8-bit byte values and
# 5-bit groups that convert_bits consumes and produces.
Bytes = List[int]
def sha256(x: bytes) -> bytes:
    """Return the raw 32-byte SHA-256 digest of *x*."""
    hasher = hashlib.sha256()
    hasher.update(x)
    return hasher.digest()
def bytes_to_hex(b: bytes) -> str:
    """Return the lowercase hexadecimal text of *b*, two characters per byte."""
    hex_text = b.hex()
    return hex_text
def bytes_to_int(bts: bytes) -> int:
    """Interpret *bts* as an unsigned big-endian integer (empty input gives 0)."""
    return int.from_bytes(bts, byteorder='big')
def int_to_bytes(i: int) -> bytes:
    """Encode non-negative *i* as minimal-length big-endian bytes.

    Zero becomes a single zero byte; negative values raise ``OverflowError``
    (from ``int.to_bytes``).
    """
    n_bytes = (i.bit_length() + 7) // 8 or 1
    return i.to_bytes(n_bytes, byteorder='big')
def int_to_hex(i: int) -> str:
    """Return *i* as lowercase hex without a ``0x`` prefix or zero padding."""
    return f'{i:x}'
def hex_to_bytes(h: str) -> bytes:
    """Parse hexadecimal text *h* into bytes (``bytes.fromhex`` semantics)."""
    decoded = bytes.fromhex(h)
    return decoded
def concat_string(values: List[str]) -> str:
    """Join the strings in *values* with no separator."""
    separator = ""
    return separator.join(values)
def concat_bytes(values: List[bytes]) -> bytes:
    """Join the byte strings in *values* with no separator."""
    separator = b""
    return separator.join(values)
def concat(values: Union[List[str], List[bytes]]) -> Union[str, bytes]:
    """Join a homogeneous sequence of text or binary chunks.

    The type of the first item selects the result: ``str`` items are joined
    into a ``str``; bytes-like items (``bytes``, ``bytearray``, ``memoryview``)
    are joined into ``bytes``.

    Raises:
        ValueError: if *values* is empty (the result type cannot be inferred).
        TypeError: if the first item is neither text nor bytes-like.
    """
    if not values:
        raise ValueError('concat() needs at least one value to infer the result type')
    first = values[0]
    if isinstance(first, str):
        return ''.join(values)
    if isinstance(first, (bytes, bytearray, memoryview)):
        return b''.join(values)
    # Previously fell through and returned None, deferring the failure to a caller.
    raise TypeError(f'concat() expects str or bytes-like items, got {type(first).__name__}')
class BaseConversionError(Exception):
    """Raised by convert_bits for an out-of-range input value or invalid padding."""
def convert_bits(data: Bytes, from_bits: int, to_bits: int, pad=True) -> Bytes:
    """Regroup a sequence of *from_bits*-wide values into *to_bits*-wide values.

    Bits are consumed most-significant first. With *pad*, a trailing partial
    group is zero-filled and emitted; without it, leftover bits must be fewer
    than *from_bits* and all zero.

    Raises:
        BaseConversionError: for a negative or too-wide input value, or for
            invalid leftover bits when *pad* is False.
    """
    accumulator = 0
    pending = 0
    out = []
    out_mask = (1 << to_bits) - 1
    # Only the bits that can still contribute to an output group are kept.
    acc_mask = (1 << (from_bits + to_bits - 1)) - 1
    for item in data:
        if item < 0 or item >> from_bits:
            raise BaseConversionError
        accumulator = ((accumulator << from_bits) | item) & acc_mask
        pending += from_bits
        while pending >= to_bits:
            pending -= to_bits
            out.append((accumulator >> pending) & out_mask)
    if not pad:
        if pending >= from_bits or ((accumulator << (to_bits - pending)) & out_mask):
            raise BaseConversionError
    elif pending:
        out.append((accumulator << (to_bits - pending)) & out_mask)
    return out
def normalize_string(txt: AnyStr) -> str:
    """Return *txt* as NFKD-normalized text.

    ``bytes`` input is decoded as UTF-8 first; any other non-``str`` type
    raises ``TypeError``.
    """
    if isinstance(txt, str):
        text = txt
    elif isinstance(txt, bytes):
        text = txt.decode('utf8')
    else:
        raise TypeError('String value expected')
    return unicodedata.normalize('NFKD', text)
'''
Protobuf Utilities
'''
# A protobuf field key carries the wire type in its low TAG_TYPE_BITS bits
# and the field number in the remaining high bits (see make_tag).
TAG_TYPE_BITS = 3
TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1  # 0b111
# Wire types: 0 = varint, 1 = 64-bit fixed, 2 = length-delimited.
VAR_INT, FIXED64, LENGTH_DELIMITED = 0, 1, 2
def var_int(value: int) -> bytes:
    """Encode a non-negative integer as a protobuf base-128 ``VarInt``.

    Seven payload bits are emitted per byte, least-significant group first;
    every byte except the last has its high bit (0x80) set to signal that
    more bytes follow. Decimal strings are accepted and converted with ``int``.

    Returns:
        The encoded bytes (always at least one byte; 0 encodes as a zero byte).

    Raises:
        ValueError: if *value* is negative (only unsigned varints are supported).
    """
    if isinstance(value, str):
        value = int(value)
    if value < 0:
        raise ValueError(f'var_int() expects a non-negative integer, got {value}')
    out = bytearray()
    while value > 0x7F:
        out.append((value & 0x7F) | 0x80)
        value >>= 7
    out.append(value)
    return bytes(out)
def var_int_serializer(value: int) -> bytes:
    """Serialize *value* as an unsigned VarInt (thin alias for ``var_int``)."""
    encoded = var_int(value)
    return encoded
def bytes_serializer(value: bytes) -> bytes:
    """Serialize *value* length-delimited: VarInt byte count followed by the raw bytes."""
    length_prefix = var_int(len(value))
    return concat([length_prefix, value])
def make_tag(field_number: int, tag: int) -> int:
    """Build a protobuf field key: *field_number* in the high bits, wire type *tag* in the low bits."""
    shifted = field_number << TAG_TYPE_BITS
    return shifted | tag
def make_tag_bytes(field_number: int, tag: int) -> bytes:
    """Return the VarInt-encoded field key for *field_number* and wire type *tag*."""
    key = make_tag(field_number, tag)
    return var_int_serializer(key)
def pb_field(field_number: int, tag: int, value):
    """Serialize one protobuf field: its VarInt key followed by the payload.

    Args:
        field_number: protobuf field number.
        tag: wire type, one of ``VAR_INT``, ``FIXED64`` or ``LENGTH_DELIMITED``.
        value: an int for ``VAR_INT``/``FIXED64``; bytes for ``LENGTH_DELIMITED``.

    Returns:
        The encoded field as ``bytes``.

    Raises:
        ValueError: for any other wire type (previously such tags fell through
            with an empty list payload and failed inside ``b''.join``).
    """
    if tag == VAR_INT:
        payload = var_int_serializer(value=value)
    elif tag == LENGTH_DELIMITED:
        payload = bytes_serializer(value=value)
    elif tag == FIXED64:
        # Wire type 1 is always exactly 8 bytes, little-endian.
        payload = int(value).to_bytes(8, 'little')
    else:
        raise ValueError(f'Unsupported protobuf wire type: {tag}')
    return concat([make_tag_bytes(field_number=field_number, tag=tag), payload])
'''
Bech32
'''
class Bech32DecodeError(Exception):
    """Signals a malformed Bech32 string detected by bech32_decode_address."""
# Bech32 data alphabet: the character at index i encodes the 5-bit value i.
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
def bech32_poly_mod(values) -> int:
    """Compute the Bech32 BCH checksum polynomial over a sequence of 5-bit values.

    A string is valid when the result for its expanded HRP followed by its
    data (checksum included) equals 1.
    """
    generators = (0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3)
    checksum = 1
    for item in values:
        high = checksum >> 25
        checksum = ((checksum & 0x1ffffff) << 5) ^ item
        for bit, gen in enumerate(generators):
            if (high >> bit) & 1:
                checksum ^= gen
    return checksum
def bech32_hrp_expand(hrp: str) -> List[int]:
    """Expand the human-readable part for checksumming.

    The result is the high 3 bits of every character, a 0 separator, then
    the low 5 bits of every character.
    """
    codes = [ord(ch) for ch in hrp]
    high_bits = [c >> 5 for c in codes]
    low_bits = [c & 31 for c in codes]
    return high_bits + [0] + low_bits
def bech32_verify_checksum(hrp: str, data) -> bool:
    """Return True if *data* (5-bit values, checksum included) has a valid checksum for *hrp*."""
    combined = bech32_hrp_expand(hrp) + data
    residue = bech32_poly_mod(combined)
    return residue == 1
def bech32_create_checksum(hrp: str, data):
    """Return the six 5-bit checksum values for *hrp* and 5-bit *data*."""
    polymod = bech32_poly_mod(bech32_hrp_expand(hrp) + data + [0] * 6) ^ 1
    # Emit the 30-bit result as six 5-bit groups, most significant first.
    return [(polymod >> shift) & 31 for shift in range(25, -1, -5)]
def bech32_encode_address(hrp: str, data: str) -> str:
    """Encode the hex-encoded payload *data* as a Bech32 string with prefix *hrp*.

    The payload bytes are regrouped from 8-bit to 5-bit values, a six-value
    checksum is appended, and the result is ``<hrp>1<data><checksum>`` spelled
    with CHARSET and NFKD-normalized.
    """
    payload = list(bytes.fromhex(data))
    # pad=True: Bech32 zero-pads a trailing partial 5-bit group. With pad=False,
    # payloads whose bit length is not a multiple of 5 were rejected or had
    # bits dropped. 20-byte hashes (160 bits) need no padding either way.
    five_bit = convert_bits(data=payload, from_bits=8, to_bits=5, pad=True)
    checksum = bech32_create_checksum(hrp=hrp, data=five_bit)
    encoded = ''.join(CHARSET[v] for v in five_bit + checksum)
    return normalize_string(hrp + '1' + encoded)
def bech32_decode_address(bech: str):
    """Validate a Bech32 string and return its payload as a list of byte values.

    Checks, in order:
      - printable US-ASCII only;
      - not mixed case;
      - a non-empty human-readable part before the last '1';
      - at least six checksum characters;
      - data characters from CHARSET;
      - a valid checksum;
      - a payload of at least two 5-bit values with valid zero padding.

    The human-readable part is not compared with any expected prefix.

    Raises:
        Bech32DecodeError: if any check fails. Previously errors were only
            printed, and a payload decoded from invalid input was returned.
    """
    if any(ord(ch) < 33 or ord(ch) > 126 for ch in bech):
        raise Bech32DecodeError('Character outside US-ASCII [33-126] range')
    # Must inspect the raw input: after lowercasing, mixed case is undetectable.
    if bech.lower() != bech and bech.upper() != bech:
        raise Bech32DecodeError('Mixed upper and lower case')
    bech = bech.lower()
    # The separator is the LAST '1' (CHARSET has no '1'; the HRP may contain one).
    separator = bech.rfind('1')
    if separator == 0:
        raise Bech32DecodeError('Empty human readable part')
    if separator == -1:
        raise Bech32DecodeError('No separator character')
    if separator + 7 > len(bech):
        raise Bech32DecodeError('Checksum too short')
    hrp_, data_part = bech[:separator], bech[separator + 1:]
    if not all(ch in CHARSET for ch in data_part):
        raise Bech32DecodeError('Character not in charset')
    data = [CHARSET.find(ch) for ch in data_part]
    if not bech32_verify_checksum(hrp_, data):
        raise Bech32DecodeError('Invalid checksum')
    decoded = data[:-6]
    if len(decoded) < 2:
        raise Bech32DecodeError('Witness program too short')
    try:
        # pad=False on decode: leftover bits must be < 8 and all zero.
        return convert_bits(decoded, from_bits=5, to_bits=8, pad=False)
    except BaseConversionError as error:
        raise Bech32DecodeError('Invalid padding in data part') from error
'''
Schema Items
'''
@dataclass
class PublicKey:
    """Protobuf-schema form of a public key.

    ``compressed`` holds the leading encoding byte (2 or 3 as produced by
    ``WitPublicKey.pub_key``). ``_bytes`` holds the remaining key bytes (the
    32-byte X coordinate in that case).
    """
    _bytes: bytes
    compressed: int

    @classmethod
    def from_json(cls, data: dict):
        """Build from ``{'bytes': <iterable of ints>, 'compressed': <int>}``.

        Uses ``cls`` so subclasses get instances of their own type.
        """
        return cls(_bytes=bytes(data['bytes']), compressed=data['compressed'])

    def to_json(self, as_hex: bool = True):
        """Return a JSON-friendly dict.

        ``bytes`` is a hex string when *as_hex* is true, else a list of ints.
        """
        key_bytes = bytes_to_hex(self._bytes) if as_hex else list(self._bytes)
        return {'bytes': key_bytes, 'compressed': self.compressed}

    def pb_bytes(self) -> bytes:
        """Serialize as protobuf field 1 (length-delimited): prefix byte, then key bytes."""
        return pb_field(
            field_number=1,
            tag=LENGTH_DELIMITED,
            value=concat([int_to_bytes(self.compressed), self._bytes]),
        )

    def hash(self):
        """Return the SHA-256 digest of ``pb_bytes()``."""
        return sha256(self.pb_bytes())
@dataclass
class PublicKeyHash:
    """Protobuf-schema form of a public key hash.

    The digest helper is named ``pb_hash``. A method named ``hash`` collided
    with the ``hash`` field: the dataclass took the function as the field's
    default, and each instance's ``hash`` attribute then shadowed it, so the
    method could not be called.
    """
    hash: bytes

    @classmethod
    def from_address(cls, data: str):
        """Build from a Bech32 address by decoding its payload bytes."""
        return cls(hash=bytes(bech32_decode_address(data)))

    def to_address(self):
        """Encode ``hash`` as a Bech32 address with the 'wit' prefix."""
        return bech32_encode_address(hrp='wit', data=bytes_to_hex(self.hash))

    def pb_bytes(self) -> bytes:
        """Serialize as protobuf field 1 (length-delimited)."""
        return pb_field(field_number=1, tag=LENGTH_DELIMITED, value=self.hash)

    def pb_hash(self) -> bytes:
        """Return the SHA-256 digest of ``pb_bytes()``."""
        return sha256(self.pb_bytes())
'''
Number Theory
'''
def miller_rabin(n, runs=40):
    """Probabilistic Miller-Rabin primality test.

    Every n < 2 (including negatives) is reported composite. n < 6 and even n
    are answered exactly. Otherwise *runs* rounds with random witnesses are
    performed, so a prime is never rejected and a composite passes with
    probability at most 4**-runs.

    40 rounds is the commonly recommended count, see
    http://stackoverflow.com/questions/6325576/how-many-iterations-of-rabin-miller-should-i-use-for-cryptographic-safe-primes
    Note: witnesses come from ``random`` (not a CSPRNG).
    """
    if n < 2:
        # 0, 1 and negatives are not prime. A negative index into a lookup
        # table used to report e.g. -1 as prime.
        return False
    if n < 6:
        return n in (2, 3, 5)
    if n % 2 == 0:
        return False
    # Write n - 1 = s * 2**r with s odd.
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2
    for _ in range(runs):
        # Uniform witness in [2, n - 2]; the 4**-runs error bound assumes this range.
        a = random.randrange(2, n - 1)
        x = pow(a, s, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True
def legendre(a, p):
    """Legendre symbol (a | p) via Euler's criterion.

    https://en.wikipedia.org/wiki/Legendre_symbol

    For an odd prime *p* the result is 1, -1 or 0. For p == 2, p - 1 == 1,
    so the result is always -1. *p* is checked with ``miller_rabin`` in an
    ``assert``.
    """
    assert miller_rabin(p), f"{p} is not a prime"
    euler = pow(a, (p - 1) // 2, p)
    if euler == p - 1:
        return -1
    return euler
def xgcd(b: int, n: int) -> Tuple[int, int, int]:
    """Extended Euclidean algorithm.

    Return ``(g, x, y)`` with ``b * x + n * y == g``, where g is the last
    non-zero remainder (``gcd(b, n)`` for non-negative inputs).
    """
    prev_x, cur_x = 1, 0
    prev_y, cur_y = 0, 1
    a, m = b, n
    while m:
        quotient, remainder = divmod(a, m)
        a, m = m, remainder
        prev_x, cur_x = cur_x, prev_x - quotient * cur_x
        prev_y, cur_y = cur_y, prev_y - quotient * cur_y
    return a, prev_x, prev_y
def mulinv(b, n, algo=xgcd):
    """Return the inverse of *b* modulo *n* using an extended-GCD routine.

    *algo* must return ``(g, x, y)`` with ``b * x + n * y == g``. An
    ``assert`` rejects inputs whose gcd is not 1. The result is reduced
    with ``% n``.
    """
    gcd_value, coefficient, _unused = algo(b, n)
    assert gcd_value == 1, 'Numbers must be coprimes'
    return coefficient % n
def modsqrt(a, p):
    """
    Return x with x**2 == a (mod p) for a prime p; p - x is also a root.

    https://eli.thegreenplace.net/2009/03/07/computing-modular-square-roots-in-python

    0 is returned when no square root exists. This is indistinguishable from
    a == 0 (mod p), whose only root is 0. Tonelli-Shanks is used, except for
    p == 2 and p == 3 (mod 4), which have closed forms.
    """
    if p == 2:
        # Every residue is its own square root mod 2. This case must precede the
        # residuosity test: legendre(a, 2) evaluates to -1 (p - 1 == 1), so it
        # would otherwise always report "no root".
        return a % p
    if legendre(a, p) != 1:
        # Non-residue, or a == 0 (mod p).
        return 0
    if p % 4 == 3:
        # Closed form for p = 3 (mod 4); this covers secp256k1's field prime.
        return pow(a, (p + 1) // 4, p)
    # Tonelli-Shanks. Write p - 1 = s * 2**e with s odd.
    s = p - 1
    e = 0
    while s % 2 == 0:
        s //= 2
        e += 1
    # Find some n with Legendre symbol (n | p) == -1. Shouldn't take long.
    n = 2
    while legendre(n, p) != -1:
        n += 1
    # Invariant: x**2 == a * b (mod p).
    #   x: current guess of the root; b: "fudge factor" (how far off x is);
    #   g: successive powers of n used to update x and b;
    #   r: exponent that decreases with each update.
    # See "Square roots from 1; 24, 51, 10 to Dan Shanks" by Ezra Brown.
    x = pow(a, (s + 1) // 2, p)
    b = pow(a, s, p)
    g = pow(n, s, p)
    r = e
    while True:
        t = b
        m = 0
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)
        if m == 0:
            return x
        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m
'''
Secp256k1 Stuff
'''
# secp256k1 domain parameters (SEC 2).
# Field prime: P = 2**256 - 2**32 - 977.
P = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
# Generator point as an (x, y) tuple.
G = (
    0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
    0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8,
)
# Order of the group generated by G.
N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
# Coefficients of the short Weierstrass curve y**2 = x**3 + A*x + B.
A: int = 0
B: int = 7
class Point:
    """An affine point (x, y) on *curve*. The constructor asserts membership.

    Addition and scalar multiplication are delegated to the curve object's
    ``point_add`` / ``point_mul``.
    """

    def __init__(self, x, y, curve=None):
        self.x = x
        self.y = y
        self.curve = curve
        assert self in curve, f"Point {x}, {y} not in curve"

    def __add__(self, other):
        assert self.curve == other.curve, 'Cannot add points on different curves'
        return self.curve.point_add(self, other)

    def __neg__(self):
        """Return the additive inverse: the reflection (x, -y mod prime)."""
        return Point(self.x, -self.y % self.curve.prime, curve=self.curve)

    def __sub__(self, other):
        # Reflect instead of multiplying by -1. Multiplying reduced -1 to
        # (order - 1), a full scalar multiplication that is only correct when
        # the point's order divides the curve's ``order``.
        return self + (-other)

    def __mul__(self, other: int):
        assert isinstance(other, int), 'Multiplication is only defined between a point and an integer'
        return self.curve.point_mul(self, other)

    # Scalar multiplication commutes, so ``k * point`` is accepted as well.
    __rmul__ = __mul__

    def __repr__(self):
        return f"Point({self.x}, {self.y}, {self.curve.name})"

    def __eq__(self, other):
        # Let Python fall back (False for ==) instead of raising AttributeError.
        if not isinstance(other, Point):
            return NotImplemented
        prime = self.curve.prime
        return self.x % prime == other.x % prime and self.y % prime == other.y % prime
@dataclass
class Curve:
    """Short Weierstrass curve y**2 = x**3 + a*x + b over the prime field of ``prime``.

    ``generator`` may be given as an (x, y) tuple; it is converted to a Point
    after initialisation. The point at infinity is not modelled:
    ``point_mul`` returns None for a scalar that is 0 modulo ``order``.
    """
    prime: int  # P
    a: int
    b: int
    generator: Union[Tuple, Point]
    order: int  # N
    name: str

    def __post_init__(self):
        # isinstance also accepts tuple subclasses such as namedtuples.
        if isinstance(self.generator, tuple):
            self.generator = Point(*self.generator, curve=self)

    def point_add(self, p, q):
        """Return p + q, using the tangent slope when p == q.

        https://en.wikipedia.org/wiki/Elliptic_curve_point_multiplication#Point_addition
        Modular inverses use Fermat's little theorem (``prime`` must be prime).
        """
        _p = self.prime
        if p == q:
            # Tangent slope (3x^2 + a) / (2y). The ``a`` term used to be omitted,
            # which is only correct for a == 0 (e.g. secp256k1).
            lam = (3 * p.x * p.x + self.a) * pow(2 * p.y % _p, _p - 2, _p) % _p
        else:
            lam = pow(q.x - p.x, _p - 2, _p) * (q.y - p.y) % _p
        rx = (lam * lam - p.x - q.x) % _p
        ry = (lam * (p.x - rx) - p.y) % _p
        return Point(rx, ry, curve=self)

    def point_mul(self, p, d):
        """Return d * p by right-to-left double-and-add.

        *d* is first reduced modulo ``order``. Returns None when it becomes 0.
        """
        d = d % self.order
        addend = p
        result = None
        while d:
            if d & 1:
                result = addend if result is None else self.point_add(result, addend)
            d >>= 1
            if d:  # skip the final doubling, whose result would be unused
                addend = self.point_add(addend, addend)
        return result

    def __contains__(self, point):
        """Return True if *point* satisfies the curve equation modulo ``prime``."""
        return point.y ** 2 % self.prime == (point.x ** 3 + self.a * point.x + self.b) % self.prime

    def f(self, x):
        """Return x**3 + a*x + b mod prime (the value of y**2 at this x)."""
        return (x ** 3 + self.a * x + self.b) % self.prime
# The secp256k1 curve used by WitPublicKey decoding and Signature verification.
CURVE = Curve(prime=P, a=A, b=B, generator=G, order=N, name='secp256k1')
class WitPublicKey:
    """A public key (a point on CURVE) with SEC1 encodings and Witnet address helpers."""

    def __init__(self, point: 'Point'):
        self.point = point

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, WitPublicKey):
            return NotImplemented
        return self.point == other.point

    def __repr__(self) -> str:
        return f"PublicKey({self.encode().hex()})"

    @classmethod
    def decode(cls, key: bytes) -> 'WitPublicKey':
        """Parse a SEC1 key: 0x04 + X + Y (65 bytes) or 0x02/0x03 + X (33 bytes).

        Compressed keys are decompressed with ``modsqrt(CURVE.f(x), P)``. The
        root whose parity matches the prefix is chosen (0x03 odd, 0x02 even).
        Length/prefix problems and off-curve points raise AssertionError.
        """
        if key.startswith(b'\x04'):  # uncompressed key
            assert len(key) == 65, 'An uncompressed public key must be 65 bytes long'
            x, y = bytes_to_int(key[1:33]), bytes_to_int(key[33:])
        else:  # compressed key
            assert len(key) == 33, 'A compressed public key must be 33 bytes long'
            x = bytes_to_int(key[1:])
            root = modsqrt(CURVE.f(x), P)
            if key.startswith(b'\x03'):  # odd root
                y = root if root % 2 == 1 else -root % P
            elif key.startswith(b'\x02'):  # even root
                y = root if root % 2 == 0 else -root % P
            else:
                # Explicit raise: ``assert False`` is stripped under -O, leaving y unbound.
                raise AssertionError('Wrong key format')
        return cls(Point(x, y, curve=CURVE))

    @classmethod
    def from_hex(cls, hexstring: str) -> 'WitPublicKey':
        """Parse a hex-encoded SEC1 key."""
        return cls.decode(hex_to_bytes(hexstring))

    @property
    def x(self) -> int:
        """X coordinate of the (X, Y) point"""
        return self.point.x

    @property
    def y(self) -> int:
        """Y coordinate of the (X, Y) point"""
        return self.point.y

    def encode(self, compressed=True) -> bytes:
        """SEC1-encode: 33 bytes (0x02/0x03 + X) or, uncompressed, 65 bytes (0x04 + X + Y)."""
        x_bytes = int_to_bytes(self.x).rjust(32, b'\x00')
        if compressed:
            prefix = b'\x03' if self.y & 1 else b'\x02'  # odd / even root
            return prefix + x_bytes
        return b'\x04' + x_bytes + int_to_bytes(self.y).rjust(32, b'\x00')

    def to_json(self, compressed=True):
        """Return ``{'bytes': [key bytes after the prefix], 'compressed': <prefix byte>}``."""
        enc = self.encode(compressed=compressed)
        return {'bytes': list(enc[1:]), 'compressed': enc[0]}

    def pub_key(self):
        """Return the protobuf-schema PublicKey (prefix byte 2/3 plus 32-byte X)."""
        enc = self.encode(compressed=True)
        return PublicKey(_bytes=enc[1:], compressed=enc[0])

    @classmethod
    def from_json(cls, data: dict):
        """Inverse of ``to_json``. Reads the ``bytes`` and ``compressed`` keys by name.

        The keys were previously read positionally from ``data.values()``.
        """
        prefix = int_to_bytes(data['compressed'])
        body = bytes(data['bytes']).rjust(32, b'\x00')
        return cls.decode(prefix + body)

    @classmethod
    def from_schema(cls, public_key: 'PublicKey'):
        """Build from a protobuf-schema PublicKey."""
        # PublicKey.to_json() defaults to a hex string for 'bytes', which
        # bytes() cannot consume; request the list-of-ints form instead.
        return cls.from_json(public_key.to_json(as_hex=False))

    def hex(self, compressed=True) -> str:
        """Hex string of ``encode(compressed)``."""
        return bytes_to_hex(self.encode(compressed=compressed))

    def to_address(self) -> str:
        """Bech32 address with the 'wit' prefix over ``to_pkh()``."""
        five_bit = convert_bits(list(self.to_pkh()), from_bits=8, to_bits=5)
        checksum = bech32_create_checksum('wit', five_bit)
        return 'wit1' + ''.join(CHARSET[v] for v in five_bit + checksum)

    def to_pkh(self):
        """Public key hash: the first 20 bytes of SHA-256 over the compressed encoding."""
        return sha256(self.encode(compressed=True))[:20]
class Signature:
    """ECDSA signature (r, s) over CURVE with DER encoding and decoding.

    By default ``s`` is normalized to the lower half of the group order
    (BIP-62 low-S). Malformed DER input raises AssertionError, which
    ``is_signature`` and ``verify_message`` handle.
    """

    def __init__(self, r, s, force_low_s=True):
        self.r = r
        if force_low_s:
            # https://github.com/bitcoin/bips/blob/master/bip-0062.mediawiki#low-s-values-in-signatures
            self.s = s if s <= N // 2 else N - s
        else:
            self.s = s

    @staticmethod
    def _read_der_int(data, offset, label):
        """Read one DER INTEGER (0x02, length, value) at *offset*.

        Returns ``(value, next_offset)``.
        """
        assert offset + 2 <= len(data), f'Missing {label} header'
        lead = data[offset]
        assert lead == 0x02, f'Invalid {label} leading byte: 0x{lead:x}'  # 0x02 byte before the integer
        length = data[offset + 1]
        assert length <= 33, f'Invalid {label} length: {length}'
        start = offset + 2
        end = start + length
        # Explicit check so truncation raises AssertionError rather than IndexError.
        assert end <= len(data), f'Truncated {label} value'
        return bytes_to_int(data[start:end]), end

    @classmethod
    def decode(cls, bts):
        """Parse a DER ``SEQUENCE { INTEGER r, INTEGER s }`` signature."""
        data = bytes(bts)
        assert len(data) >= 2, f'Signature too short: {len(data)} bytes'
        lead = data[0]
        # Format the byte itself (the old message formatted a bool, always printing 0x0).
        assert lead == 0x30, f'Invalid leading byte: 0x{lead:x}'  # ASN1 SEQUENCE
        sequence_length = data[1]
        assert sequence_length <= 70, f'Invalid Sequence length: {sequence_length}'
        assert sequence_length == len(data) - 2, \
            f'Sequence length {sequence_length} does not match {len(data) - 2} content bytes'
        r, offset = cls._read_der_int(data, 2, 'r')
        s, offset = cls._read_der_int(data, offset, 's')
        rest = data[offset:]
        assert len(rest) == 0, f'{len(rest)} leftover bytes'
        return cls(r, s)

    def encode(self, compact=False):
        """DER-encode, or with *compact* return the fixed 64-byte ``r || s`` form.

        https://github.com/bitcoin/bips/blob/master/bip-0062.mediawiki#der-encoding
        """
        if compact:
            # Two 32-byte big-endian integers. DER's 0x00 sign padding would make
            # the length vary and the halves impossible to split.
            return self.r.to_bytes(32, 'big') + self.s.to_bytes(32, 'big')
        r = int_to_bytes(self.r)
        if r[0] > 0x7f:
            r = b'\x00' + r
        s = int_to_bytes(self.s)
        if s[0] > 0x7f:
            s = b'\x00' + s
        len_r = int_to_bytes(len(r))
        len_s = int_to_bytes(len(s))
        len_sig = int_to_bytes(len(r) + len(s) + 4)
        return b'\x30' + len_sig + b'\x02' + len_r + r + b'\x02' + len_s + s

    def verify_hash(self, _hash, public_key):
        """ECDSA-verify digest *_hash* against *public_key* (a WitPublicKey)."""
        if not (1 <= self.r < N and 1 <= self.s < N):
            return False
        e = bytes_to_int(_hash)
        w = mulinv(self.s, N)
        u1 = (e * w) % N
        u2 = (self.r * w) % N
        # Scalar multiplication by 0 yields None (point at infinity); combine
        # the two terms explicitly instead of failing on ``None + point``.
        p1 = CURVE.generator * u1
        p2 = public_key.point * u2
        if p1 is None or p2 is None:
            point = p2 if p1 is None else p1
        else:
            point = p1 + p2
        if point is None:
            return False
        return self.r % N == point.x % N

    @classmethod
    def from_hex(cls, hex_string):
        """Decode a hex-encoded DER signature."""
        return cls.decode(hex_to_bytes(hex_string))

    def __repr__(self):
        return f"{self.__class__.__name__}({int_to_hex(self.r)}, {int_to_hex(self.s)})"

    def __eq__(self, other):
        if not isinstance(other, Signature):
            return NotImplemented
        return self.r == other.r and self.s == other.s

    def hex(self):
        """Hex string of the DER encoding."""
        return bytes_to_hex(self.encode())
def is_signature(hex_string):
    """Return True if *hex_string* (hex text, or raw ``bytes``) parses as a DER signature.

    Malformed input yields False instead of raising: bad DER structure,
    truncated data, or invalid hexadecimal for text input (the ValueError
    from hex parsing was previously not caught).
    """
    try:
        if isinstance(hex_string, bytes):
            Signature.decode(hex_string)
        else:
            Signature.from_hex(hex_string)
    except (AssertionError, IndexError, ValueError):
        return False
    return True
def verify_message(data) -> bool:
    """Verify a Witnet signed message.

    *data* must provide:
      - ``address``: the expected Witnet address;
      - ``message``: text, hashed as UTF-8 with SHA-256;
      - ``public_key``: hex SEC1 key;
      - ``signature``: hex DER signature.

    Returns True only when the signature over the message hash is valid and
    the public key maps to ``address``. If a field is missing, the hex/DER/key
    data is malformed, or the address does not match, the error is printed
    and False is returned.
    """
    try:
        address = data["address"]
        public_key = WitPublicKey.from_hex(data["public_key"])
        # from_hex already validates the DER structure; a separate
        # is_signature() re-parse of the same input could never fail here.
        signature = Signature.from_hex(data["signature"])
        message_hash = sha256(data["message"].encode('utf-8'))
        is_valid = signature.verify_hash(message_hash, public_key)
        assert address == public_key.to_address(), "The Public key does not match the given address."
        return is_valid
    except (AssertionError, KeyError, ValueError, IndexError) as e:
        print(f" Error: {e}")
        return False
if __name__ == "__main__":
    # A known-good Witnet signed message: it must verify.
    sig = {
        "address": "wit174la8pevl74hczcpfepgmt036zkmjen4hu8zzs",
        "message": "Hello World",
        "public_key": "03d5761050a170c53bd09dcd1d6a69e2053197ad55bdee169c65dc580eaec6bf4c",
        "signature": "3045022100bb317d58ef1aced559b18ba4b0d80c8fa801f7faa707c1150f3c0bf43e03043402204db123654b58969ea64b3014df52c33db469a741ba2524c0f2c0cb1800f73370",
    }
    assert verify_message(sig) is True
    # Changing the message must invalidate the same signature.
    sig["message"] = "Hello"
    assert verify_message(sig) is False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment