|
from typing import Tuple, Optional, Any |
|
import hashlib |
|
import binascii |
|
|
|
# DEBUG switches on verbose tracing: when True, key generation, signing and
# verification dump their intermediate values via debug_print_vars().
#
# To inspect a single value by hand, format it with pretty(), e.g.
# print(pretty(foo)).
DEBUG = False

# secp256k1 base-field prime, p = 2**256 - 2**32 - 977.
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F

# Order of the group generated by G; scalars are reduced modulo n.
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

# An affine point is an (x, y) tuple of ints. The point at infinity is not a
# tuple at all: it is represented by None.
Point = Tuple[int, int]

# Generator point of secp256k1, as an (x, y) tuple.
G = (
    0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
    0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8,
)
|
|
|
# BIP340-style tagged hash: SHA256(SHA256(tag) || SHA256(tag) || msg).
# This implementation can be sped up by storing the midstate after hashing
# the doubled tag digest instead of rehashing it on every call.
def tagged_hash(tag: str, msg: bytes) -> bytes:
    """Return SHA256(SHA256(tag) || SHA256(tag) || msg) as 32 bytes.

    The tag is encoded with str.encode()'s default (UTF-8).
    """
    tag_digest = hashlib.sha256(tag.encode()).digest()
    hasher = hashlib.sha256()
    # Feeding the pieces incrementally is equivalent to hashing their concatenation.
    hasher.update(tag_digest)
    hasher.update(tag_digest)
    hasher.update(msg)
    return hasher.digest()
|
|
|
def is_infinite(P: Optional[Point]) -> bool:
    """Report whether P is the point at infinity, which this module encodes as None."""
    if P is None:
        return True
    return False
|
|
|
def x(P: Point) -> int:
    """Return the x-coordinate of P; asserts that P is not the point at infinity (None)."""
    assert P is not None
    first_coordinate = P[0]
    return first_coordinate
|
|
|
def y(P: Point) -> int:
    """Return the y-coordinate of P; asserts that P is not the point at infinity (None)."""
    assert P is not None
    second_coordinate = P[1]
    return second_coordinate
|
|
|
def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]:
    """Add two points on y^2 = x^3 + 7 over GF(p), in affine coordinates.

    None stands for the point at infinity and acts as the identity element.
    Returns None when the inputs share an x-coordinate but differ in y.
    Field division is done with Fermat inverses, pow(v, p - 2, p).
    """
    # Identity element: O + Q = Q and Q + O = Q.
    if P1 is None:
        return P2
    if P2 is None:
        return P1

    x1, y1 = x(P1), y(P1)
    x2, y2 = x(P2), y(P2)

    # Same x but different y: the points are mutual negatives, so the sum is O.
    if x1 == x2 and y1 != y2:
        return None

    if P1 == P2:
        # Doubling: tangent slope 3*x^2 / (2*y).
        lam = (3 * x1 * x1 * pow(2 * y1, p - 2, p)) % p
    else:
        # Distinct points: chord slope (y2 - y1) / (x2 - x1).
        lam = ((y2 - y1) * pow(x2 - x1, p - 2, p)) % p

    x3 = (lam * lam - x1 - x2) % p
    y3 = (lam * (x1 - x3) - y1) % p
    return (x3, y3)
|
|
|
def point_mul(P: Optional[Point], n: int) -> Optional[Point]:
    """Return the scalar multiple n*P using right-to-left double-and-add.

    Note that the parameter ``n`` is the scalar; inside this function it
    shadows the module-level group order of the same name.

    Every bit of the scalar is processed, so scalars of any size are handled
    correctly. A fixed 256-iteration loop would silently ignore bits at and
    above 2**256.

    Args:
        P: the point to multiply, or None for the point at infinity.
        n: a non-negative integer scalar.

    Returns:
        n*P, or None for the point at infinity (e.g. when n == 0).

    Raises:
        ValueError: if n is negative. Bit-testing a negative Python int would
            walk its infinite two's-complement expansion and produce a
            meaningless result.
    """
    if n < 0:
        raise ValueError('The scalar must be a non-negative integer.')
    R = None
    for i in range(n.bit_length()):
        if (n >> i) & 1:
            R = point_add(R, P)
        # P now holds 2**(i+1) times the original point.
        P = point_add(P, P)
    return R
|
|
|
def bytes_from_int(x: int) -> bytes:
    """Encode x as exactly 32 big-endian bytes.

    Raises OverflowError if x is negative or does not fit in 256 bits.
    """
    return int.to_bytes(x, 32, "big")
|
|
|
def bytes_from_point(P: Point) -> bytes:
    """Serialize P in x-only form: its x-coordinate as 32 big-endian bytes."""
    x_coordinate = x(P)
    return bytes_from_int(x_coordinate)
|
|
|
def xor_bytes(b0: bytes, b1: bytes) -> bytes:
    """Return the byte-wise XOR of two equal-length byte strings.

    Raises:
        ValueError: if b0 and b1 differ in length. zip() would otherwise
            silently truncate to the shorter input and return a short result.
    """
    if len(b0) != len(b1):
        raise ValueError('xor_bytes requires inputs of equal length, got %i and %i.' % (len(b0), len(b1)))
    # Loop variables are named a/b so they do not shadow the module's x()/y() helpers.
    return bytes(a ^ b for a, b in zip(b0, b1))
|
|
|
def lift_x(b: bytes) -> Optional[Point]:
    """Decode a big-endian x-coordinate into the curve point with even y.

    Returns None if x >= p, or if x^3 + 7 has no square root mod p (no curve
    point has that x). The candidate root is pow(rhs, (p + 1) // 4, p) and is
    then checked by squaring.
    """
    x_int = int_from_bytes(b)
    if x_int >= p:
        return None
    rhs = (pow(x_int, 3, p) + 7) % p
    root = pow(rhs, (p + 1) // 4, p)
    if root * root % p != rhs:
        return None
    # Of the two roots (root, p - root), pick the even one.
    even_root = p - root if root & 1 else root
    return (x_int, even_root)
|
|
|
def int_from_bytes(b: bytes) -> int:
    """Interpret b as an unsigned big-endian integer (empty input gives 0)."""
    return int.from_bytes(b, "big")
|
|
|
def hash_sha256(b: bytes) -> bytes:
    """Return the 32-byte SHA-256 digest of b."""
    hasher = hashlib.sha256()
    hasher.update(b)
    return hasher.digest()
|
|
|
def has_even_y(P: Point) -> bool:
    """Return True iff P's y-coordinate is even; asserts P is not infinity (None)."""
    assert P is not None
    return not (y(P) % 2)
|
|
|
def pubkey_gen(seckey: bytes) -> bytes:
    """Derive the 32-byte x-only public key for a secret key.

    seckey is read as a big-endian integer d and the result is the
    x-coordinate of d*G.

    Raises:
        ValueError: unless d lies in 1..n-1.
    """
    secret = int_from_bytes(seckey)
    if secret < 1 or secret >= n:
        raise ValueError('The secret key must be an integer in the range 1..n-1.')
    public_point = point_mul(G, secret)
    assert public_point is not None
    return bytes_from_point(public_point)
|
|
|
def schnorr_sign(msg: bytes, seckey: bytes, aux_rand: bytes) -> bytes:
    """Create a 64-byte BIP340 Schnorr signature over a 32-byte message.

    Args:
        msg: the 32-byte message to sign.
        seckey: secret key bytes, read big-endian; must encode an integer in 1..n-1.
        aux_rand: 32 bytes of auxiliary randomness mixed into the nonce derivation.

    Returns:
        sig = bytes(x(R)) || bytes(s), 64 bytes.

    Raises:
        ValueError: if msg or aux_rand is not 32 bytes, or seckey is out of range.
        RuntimeError: if the derived nonce is 0, or if the produced signature
            fails schnorr_verify().

    NOTE: debug_print_vars() prints this function's local variables by name
    when DEBUG is set, so local names are part of the debug output.
    """
    if len(msg) != 32:
        raise ValueError('The message must be a 32-byte array.')
    d0 = int_from_bytes(seckey)
    if not (1 <= d0 <= n - 1):
        raise ValueError('The secret key must be an integer in the range 1..n-1.')
    if len(aux_rand) != 32:
        raise ValueError('aux_rand must be 32 bytes instead of %i.' % len(aux_rand))
    P = point_mul(G, d0)
    assert P is not None
    # The public key is x-only and verifiers lift it to the even-y point
    # (lift_x), so negate the secret key when P has odd y.
    d = d0 if has_even_y(P) else n - d0
    # Mask the secret key with a hash of aux_rand before it enters the nonce hash.
    t = xor_bytes(bytes_from_int(d), tagged_hash("BIP0340/aux", aux_rand))
    k0 = int_from_bytes(tagged_hash("BIP0340/nonce", t + bytes_from_point(P) + msg)) % n
    if k0 == 0:
        raise RuntimeError('Failure. This happens only with negligible probability.')
    R = point_mul(G, k0)
    assert R is not None
    # Only x(R) goes into the signature, so likewise negate the nonce for an odd-y R.
    k = n - k0 if not has_even_y(R) else k0
    e = int_from_bytes(tagged_hash("BIP0340/challenge", bytes_from_point(R) + bytes_from_point(P) + msg)) % n
    sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n)
    debug_print_vars()
    # Self-check: never return a signature that does not verify.
    if not schnorr_verify(msg, bytes_from_point(P), sig):
        raise RuntimeError('The created signature does not pass verification.')
    return sig
|
|
|
def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool:
    """Verify a 64-byte BIP340 Schnorr signature.

    Args:
        msg: the 32-byte message.
        pubkey: 32-byte x-only public key; lifted with lift_x() to the even-y point.
        sig: 64-byte signature, bytes(r) || bytes(s).

    Returns:
        True if the signature is valid, otherwise False.

    Raises:
        ValueError: if any argument has the wrong length. Malformed content
            (an x with no curve point, x >= p, r >= p, s >= n) returns False instead.

    NOTE: debug_print_vars() prints this function's local variables by name
    when DEBUG is set, so local names are part of the debug output.
    """
    if len(msg) != 32:
        raise ValueError('The message must be a 32-byte array.')
    if len(pubkey) != 32:
        raise ValueError('The public key must be a 32-byte array.')
    if len(sig) != 64:
        raise ValueError('The signature must be a 64-byte array.')
    P = lift_x(pubkey)
    r = int_from_bytes(sig[0:32])
    s = int_from_bytes(sig[32:64])
    # Reject public keys that do not lift to a curve point, and out-of-range r/s.
    if (P is None) or (r >= p) or (s >= n):
        debug_print_vars()
        return False
    e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n
    # R = s*G - e*P; since n is the group order, (n - e)*P equals -e*P.
    R = point_add(point_mul(G, s), point_mul(P, n - e))
    # Valid iff R is finite, has even y, and its x-coordinate equals r.
    if (R is None) or (not has_even_y(R)) or (x(R) != r):
        debug_print_vars()
        return False
    debug_print_vars()
    return True
|
|
|
# |
|
# The following code is only used to verify the test vectors. |
|
# |
|
import csv |
|
import os |
|
import sys |
|
|
|
def test_vectors() -> bool:
    """Check pubkey_gen, schnorr_sign and schnorr_verify against test-vectors.csv.

    The CSV is read from the directory containing this source file. Its first
    row is a header, which is skipped. Each remaining row holds:
    index, seckey, pubkey, aux_rand, msg, sig (hex strings), expected
    verification result ('TRUE' means valid), comment.

    Rows with an empty seckey are verification-only. Otherwise the row also
    checks key generation and signing.

    Prints a per-vector report and returns True only if every check passed,
    including key generation.
    """
    all_passed = True
    # Resolve the CSV next to this file. sys.path[0] names the *entry script's*
    # directory, which is wrong when this module is imported from elsewhere.
    vectors_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test-vectors.csv')
    with open(vectors_path, newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # skip the header row
        for row in reader:
            (index, seckey_hex, pubkey_hex, aux_rand_hex, msg_hex, sig_hex, result_str, comment) = row
            pubkey = bytes.fromhex(pubkey_hex)
            msg = bytes.fromhex(msg_hex)
            sig = bytes.fromhex(sig_hex)
            result = result_str == 'TRUE'
            print('\nTest vector', ('#' + index).rjust(3, ' ') + ':')
            if seckey_hex != '':
                seckey = bytes.fromhex(seckey_hex)
                pubkey_actual = pubkey_gen(seckey)
                if pubkey != pubkey_actual:
                    print(' * Failed key generation.')
                    print(' Expected key:', pubkey.hex().upper())
                    print(' Actual key:', pubkey_actual.hex().upper())
                    # A key-generation mismatch must fail the run, like the other checks.
                    all_passed = False
                aux_rand = bytes.fromhex(aux_rand_hex)
                try:
                    sig_actual = schnorr_sign(msg, seckey, aux_rand)
                    if sig == sig_actual:
                        print(' * Passed signing test.')
                    else:
                        print(' * Failed signing test.')
                        print(' Expected signature:', sig.hex().upper())
                        print(' Actual signature:', sig_actual.hex().upper())
                        all_passed = False
                except RuntimeError as e:
                    print(' * Signing test raised exception:', e)
                    all_passed = False
            result_actual = schnorr_verify(msg, pubkey, sig)
            if result == result_actual:
                print(' * Passed verification test.')
            else:
                print(' * Failed verification test.')
                print(' Expected verification result:', result)
                print(' Actual verification result:', result_actual)
                if comment:
                    print(' Comment:', comment)
                all_passed = False
    print()
    if all_passed:
        print('All test vectors passed.')
    else:
        print('Some test vectors failed.')
    return all_passed
|
|
|
# |
|
# The following code is only used for debugging |
|
# |
|
import inspect |
|
|
|
def pretty(v: Any) -> Any:
    """Render a value for debug output.

    - bytes become '0x' + lowercase hex.
    - ints become '0x' + the hex of their 32-byte big-endian encoding; this
      raises OverflowError if the int is negative or needs more than 256 bits.
    - tuples are rendered element by element.
    - anything else is returned unchanged.
    """
    if isinstance(v, bytes):
        return '0x' + v.hex()
    if isinstance(v, int):
        return '0x' + v.to_bytes(32, byteorder="big").hex()
    if isinstance(v, tuple):
        return tuple(pretty(item) for item in v)
    return v
|
|
|
def debug_print_vars() -> None:
    """Print the caller's local variables, formatted with pretty(), if DEBUG is set.

    Prints a header with the caller's function name and current line number,
    then one 'name == value' line per local. Does nothing when DEBUG is False.
    """
    if not DEBUG:
        return
    current_frame = inspect.currentframe()
    frame = None
    try:
        assert current_frame is not None
        frame = current_frame.f_back
        assert frame is not None
        print(' Variables in function ', frame.f_code.co_name, ' at line ', frame.f_lineno, ':', sep='')
        for var_name, var_val in frame.f_locals.items():
            print(' ' + var_name.rjust(11, ' '), '==', pretty(var_val))
    finally:
        # Holding our own frame in a local creates a reference cycle. Drop the
        # references explicitly so the frames are freed deterministically
        # rather than waiting for the cycle collector.
        del current_frame, frame
|
|
|
if __name__ == '__main__':
    # Propagate the overall result as the process exit status, so shells and CI
    # can detect failing vectors; the status was previously always 0.
    sys.exit(0 if test_vectors() else 1)
Is a rogue key attack possible here?
Let's say that in `test_forger()` the aggregate public key is computed as `agg_pubkey = bytes_from_point(xonly_point_agg(pubkeys))`. `xonly_point_agg()` always treats the y-coordinate of each given point as even, since it uses `lift_x()`. Therefore, passing `X3` (the negation of `X1`) will be treated as `X1` during point aggregation.

I could set `pubkeys = n*[X1] + [X2]` to negate `X1`, but Python raises an overflow error because `n` is huge. I also can't use a `for` loop, since the computation would take forever. Am I missing something?