

@robot-dreams
Last active February 8, 2022 21:32
Challenge 003

Previous: Challenge 002

The signer, sick of getting funds stolen, wants to try Schnorr multisig. Unfortunately, they fell for one of the classic blunders (the other being, of course, "never get involved in a land war in Asia").

The signer is using the following insecure scheme: whenever multiple signers with public keys X1, ..., Xn want to collaboratively sign a message, they proceed as follows:

  • Generate private nonces r1, ..., rn
  • Exchange all the corresponding public nonces R1, ..., Rn
  • When generating the SHA256 challenge value:
    • Use X = X1 + ... + Xn as the aggregate public key
    • Use R = R1 + ... + Rn as the aggregate public nonce

Each signer generates the partial signature si = ri + H(R, X, m) * xi, where xi is the i-th signer's private key. The sum of all the partial signatures, together with the aggregate nonce R, is a valid Schnorr signature for the message m under the aggregate public key X.
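Why the summed partial signatures verify: s = s1 + ... + sn = (r1 + ... + rn) + e * (x1 + ... + xn), so s*G = R + e*X, which is exactly the Schnorr verification equation for X. The sketch below checks this numerically on secp256k1. It is a simplification under stated assumptions: it keeps full (x, y) points and skips BIP-340's even-y normalization (which the challenge code handles via xonly_int()), and the helper names (challenge, naive_multisig, verify) are illustrative, not the challenge's API.

```python
import hashlib
import secrets

# secp256k1 domain parameters (same constants as reference.py)
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
     0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)


def point_add(P1, P2):
    """Affine addition on secp256k1; None represents the point at infinity."""
    if P1 is None:
        return P2
    if P2 is None:
        return P1
    if P1[0] == P2[0] and P1[1] != P2[1]:
        return None  # P + (-P) = infinity
    if P1 == P2:
        lam = 3 * P1[0] * P1[0] * pow(2 * P1[1], p - 2, p) % p
    else:
        lam = (P2[1] - P1[1]) * pow(P2[0] - P1[0], p - 2, p) % p
    x3 = (lam * lam - P1[0] - P2[0]) % p
    return (x3, (lam * (P1[0] - x3) - P1[1]) % p)


def point_mul(P, k):
    """Double-and-add scalar multiplication k*P."""
    R = None
    while k:
        if k & 1:
            R = point_add(R, P)
        P = point_add(P, P)
        k >>= 1
    return R


def challenge(R, X, msg):
    """e = H(R, X, m) mod n: BIP-340 tagged hash over the x-coordinates of R and X."""
    tag = hashlib.sha256(b"BIP0340/challenge").digest()
    data = R[0].to_bytes(32, "big") + X[0].to_bytes(32, "big") + msg
    return int.from_bytes(hashlib.sha256(tag + tag + data).digest(), "big") % n


def naive_multisig(seckeys, msg):
    """Naive n-of-n signing: sum the keys, sum the nonces, sum the partial sigs."""
    X = None
    for x_i in seckeys:
        X = point_add(X, point_mul(G, x_i))  # X = X1 + ... + Xn
    nonces = [secrets.randbelow(n - 1) + 1 for _ in seckeys]
    R = None
    for r_i in nonces:
        R = point_add(R, point_mul(G, r_i))  # R = R1 + ... + Rn
    e = challenge(R, X, msg)
    # s_i = r_i + e * x_i for each signer
    partial_sigs = [(r_i + e * x_i) % n for r_i, x_i in zip(nonces, seckeys)]
    return X, R, sum(partial_sigs) % n


def verify(X, R, s, msg):
    """Check s*G == R + e*X."""
    e = challenge(R, X, msg)
    return point_mul(G, s) == point_add(R, point_mul(X, e))
```

For example, with three random keys, `X, R, s = naive_multisig(keys, msg)` followed by `verify(X, R, s, msg)` should succeed, because the verifier only ever sees the sums.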

Can you exploit the weakness of this scheme? This is a good challenge for rogue cryptographers 😎


This challenge is interactive, and involves making a test pass. You will find three files:

  • naive_multisig.py: An implementation of the insecure multisignature scheme; you should in particular consult the test_normal_multisig() function for more details on how the scheme works
  • reference.py: A naive and slow Python 3.7 implementation of BIP-340 (the version below was copied directly from the version in the bips repo). WARNING: Do not use this in production!
  • util.py: A small collection of helper functions; you shouldn't need to edit these

Your task is to fill in the implementation of forge_signature() in naive_multisig.py to make test_forgery() pass. Good luck!

naive_multisig.py

```python
import secrets

from reference import *
from util import *


def forge_signature(honest_signer, msg):
    """
    TODO: Your implementation here!

    Your goal is to return a tuple with two elements:
    - A list of public keys, at least one of which is the
      honest signer's public key
    - A valid BIP-340 signature for the input msg, when verified
      against the aggregate public key which is the sum of the
      individual public keys in the above list

    In trying to generate a forgery, you may interact with the
    honest signer by calling public methods as much as you want;
    however, to make this attack realistic, you're NOT allowed to:
    - Access private fields of the honest signer
    - Ask the honest signer to generate a partial signature
      on the same (pubkeys, msg) pair as the forgery you
      output

    Hopefully these restrictions are obvious and sensible; otherwise
    the challenge would be trivial.

    Good luck!
    """
    return None, None


class NaiveMultisigSigner:
    def __init__(self, seckey=None):
        if seckey is None:
            seckey = secrets.token_bytes(32)
        self.seckey = seckey
        self.pubkey = pubkey_gen(self.seckey)
        self.seen_queries = set()

    def get_pubkey(self):
        return self.pubkey

    def gen_partial_pubnonce(self):
        self.secnonce = secrets.token_bytes(32)
        return pubkey_gen(self.secnonce)

    def gen_partial_sig(self, pubkeys, aggnonce, msg):
        assert pubkey_gen(self.seckey) in pubkeys
        assert len(aggnonce) == 33
        assert len(msg) == 32
        X = xonly_point_agg(pubkeys)
        R = point_from_cbytes(aggnonce)
        r1 = xonly_int(self.secnonce, R)
        x1 = xonly_int(self.seckey, X)
        agg_pubkey = bytes_from_point(X)
        e = int_from_bytes(tagged_hash("BIP0340/challenge", bytes_from_point(R) + agg_pubkey + msg)) % n
        self.seen_queries.add((agg_pubkey, msg))
        return bytes_from_int((r1 + e * x1) % n)


def test_normal_multisig():
    signer1 = NaiveMultisigSigner()
    signer2 = NaiveMultisigSigner()
    X1 = signer1.get_pubkey()
    X2 = signer2.get_pubkey()
    pubkeys = [X1, X2]
    agg_pubkey = bytes_from_point(xonly_point_agg(pubkeys))
    R1 = signer1.gen_partial_pubnonce()
    R2 = signer2.gen_partial_pubnonce()
    R = xonly_point_agg([R1, R2])
    msg = b'msg signed by both Alice and Bob'
    aggnonce = cbytes_from_point(R)
    s1 = signer1.gen_partial_sig(pubkeys, aggnonce, msg)
    s2 = signer2.gen_partial_sig(pubkeys, aggnonce, msg)
    sig = bytes_from_point(R) + partial_sig_agg([s1, s2])
    assert schnorr_verify(msg, agg_pubkey, sig)


def test_forgery():
    honest_signer = NaiveMultisigSigner()
    msg = b'send all of Bob\'s coins to Alice'
    pubkeys, sig = forge_signature(honest_signer, msg)
    agg_pubkey = bytes_from_point(xonly_point_agg(pubkeys))
    assert honest_signer.get_pubkey() in pubkeys
    assert (agg_pubkey, msg) not in honest_signer.seen_queries
    assert schnorr_verify(msg, agg_pubkey, sig)


if __name__ == '__main__':
    test_normal_multisig()
    test_forgery()
```
reference.py

```python
from typing import Tuple, Optional, Any
import hashlib
import binascii

# Set DEBUG to True to get a detailed debug output including
# intermediate values during key generation, signing, and
# verification. This is implemented via calls to the
# debug_print_vars() function.
#
# If you want to print values on an individual basis, use
# the pretty() function, e.g., print(pretty(foo)).
DEBUG = False

p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

# Points are tuples of X and Y coordinates and the point at infinity is
# represented by the None keyword.
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)

Point = Tuple[int, int]

# This implementation can be sped up by storing the midstate after hashing
# tag_hash instead of rehashing it all the time.
def tagged_hash(tag: str, msg: bytes) -> bytes:
    tag_hash = hashlib.sha256(tag.encode()).digest()
    return hashlib.sha256(tag_hash + tag_hash + msg).digest()

def is_infinite(P: Optional[Point]) -> bool:
    return P is None

def x(P: Point) -> int:
    assert not is_infinite(P)
    return P[0]

def y(P: Point) -> int:
    assert not is_infinite(P)
    return P[1]

def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]:
    if P1 is None:
        return P2
    if P2 is None:
        return P1
    if (x(P1) == x(P2)) and (y(P1) != y(P2)):
        return None
    if P1 == P2:
        lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p
    else:
        lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p
    x3 = (lam * lam - x(P1) - x(P2)) % p
    return (x3, (lam * (x(P1) - x3) - y(P1)) % p)

def point_mul(P: Optional[Point], n: int) -> Optional[Point]:
    R = None
    for i in range(256):
        if (n >> i) & 1:
            R = point_add(R, P)
        P = point_add(P, P)
    return R

def bytes_from_int(x: int) -> bytes:
    return x.to_bytes(32, byteorder="big")

def bytes_from_point(P: Point) -> bytes:
    return bytes_from_int(x(P))

def xor_bytes(b0: bytes, b1: bytes) -> bytes:
    return bytes(x ^ y for (x, y) in zip(b0, b1))

def lift_x(b: bytes) -> Optional[Point]:
    x = int_from_bytes(b)
    if x >= p:
        return None
    y_sq = (pow(x, 3, p) + 7) % p
    y = pow(y_sq, (p + 1) // 4, p)
    if pow(y, 2, p) != y_sq:
        return None
    return (x, y if y & 1 == 0 else p-y)

def int_from_bytes(b: bytes) -> int:
    return int.from_bytes(b, byteorder="big")

def hash_sha256(b: bytes) -> bytes:
    return hashlib.sha256(b).digest()

def has_even_y(P: Point) -> bool:
    assert not is_infinite(P)
    return y(P) % 2 == 0

def pubkey_gen(seckey: bytes) -> bytes:
    d0 = int_from_bytes(seckey)
    if not (1 <= d0 <= n - 1):
        raise ValueError('The secret key must be an integer in the range 1..n-1.')
    P = point_mul(G, d0)
    assert P is not None
    return bytes_from_point(P)

def schnorr_sign(msg: bytes, seckey: bytes, aux_rand: bytes) -> bytes:
    if len(msg) != 32:
        raise ValueError('The message must be a 32-byte array.')
    d0 = int_from_bytes(seckey)
    if not (1 <= d0 <= n - 1):
        raise ValueError('The secret key must be an integer in the range 1..n-1.')
    if len(aux_rand) != 32:
        raise ValueError('aux_rand must be 32 bytes instead of %i.' % len(aux_rand))
    P = point_mul(G, d0)
    assert P is not None
    d = d0 if has_even_y(P) else n - d0
    t = xor_bytes(bytes_from_int(d), tagged_hash("BIP0340/aux", aux_rand))
    k0 = int_from_bytes(tagged_hash("BIP0340/nonce", t + bytes_from_point(P) + msg)) % n
    if k0 == 0:
        raise RuntimeError('Failure. This happens only with negligible probability.')
    R = point_mul(G, k0)
    assert R is not None
    k = n - k0 if not has_even_y(R) else k0
    e = int_from_bytes(tagged_hash("BIP0340/challenge", bytes_from_point(R) + bytes_from_point(P) + msg)) % n
    sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n)
    debug_print_vars()
    if not schnorr_verify(msg, bytes_from_point(P), sig):
        raise RuntimeError('The created signature does not pass verification.')
    return sig

def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool:
    if len(msg) != 32:
        raise ValueError('The message must be a 32-byte array.')
    if len(pubkey) != 32:
        raise ValueError('The public key must be a 32-byte array.')
    if len(sig) != 64:
        raise ValueError('The signature must be a 64-byte array.')
    P = lift_x(pubkey)
    r = int_from_bytes(sig[0:32])
    s = int_from_bytes(sig[32:64])
    if (P is None) or (r >= p) or (s >= n):
        debug_print_vars()
        return False
    e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n
    R = point_add(point_mul(G, s), point_mul(P, n - e))
    if (R is None) or (not has_even_y(R)) or (x(R) != r):
        debug_print_vars()
        return False
    debug_print_vars()
    return True

#
# The following code is only used to verify the test vectors.
#
import csv
import os
import sys

def test_vectors() -> bool:
    all_passed = True
    with open(os.path.join(sys.path[0], 'test-vectors.csv'), newline='') as csvfile:
        reader = csv.reader(csvfile)
        reader.__next__()
        for row in reader:
            (index, seckey_hex, pubkey_hex, aux_rand_hex, msg_hex, sig_hex, result_str, comment) = row
            pubkey = bytes.fromhex(pubkey_hex)
            msg = bytes.fromhex(msg_hex)
            sig = bytes.fromhex(sig_hex)
            result = result_str == 'TRUE'
            print('\nTest vector', ('#' + index).rjust(3, ' ') + ':')
            if seckey_hex != '':
                seckey = bytes.fromhex(seckey_hex)
                pubkey_actual = pubkey_gen(seckey)
                if pubkey != pubkey_actual:
                    print(' * Failed key generation.')
                    print('   Expected key:', pubkey.hex().upper())
                    print('     Actual key:', pubkey_actual.hex().upper())
                aux_rand = bytes.fromhex(aux_rand_hex)
                try:
                    sig_actual = schnorr_sign(msg, seckey, aux_rand)
                    if sig == sig_actual:
                        print(' * Passed signing test.')
                    else:
                        print(' * Failed signing test.')
                        print('   Expected signature:', sig.hex().upper())
                        print('     Actual signature:', sig_actual.hex().upper())
                        all_passed = False
                except RuntimeError as e:
                    print(' * Signing test raised exception:', e)
                    all_passed = False
            result_actual = schnorr_verify(msg, pubkey, sig)
            if result == result_actual:
                print(' * Passed verification test.')
            else:
                print(' * Failed verification test.')
                print('   Expected verification result:', result)
                print('     Actual verification result:', result_actual)
                if comment:
                    print('   Comment:', comment)
                all_passed = False
    print()
    if all_passed:
        print('All test vectors passed.')
    else:
        print('Some test vectors failed.')
    return all_passed

#
# The following code is only used for debugging
#
import inspect

def pretty(v: Any) -> Any:
    if isinstance(v, bytes):
        return '0x' + v.hex()
    if isinstance(v, int):
        return pretty(bytes_from_int(v))
    if isinstance(v, tuple):
        return tuple(map(pretty, v))
    return v

def debug_print_vars() -> None:
    if DEBUG:
        current_frame = inspect.currentframe()
        assert current_frame is not None
        frame = current_frame.f_back
        assert frame is not None
        print('   Variables in function ', frame.f_code.co_name, ' at line ', frame.f_lineno, ':', sep='')
        for var_name, var_val in frame.f_locals.items():
            print('   ' + var_name.rjust(11, ' '), '==', pretty(var_val))

if __name__ == '__main__':
    test_vectors()
```
util.py

```python
from typing import List

from reference import *


def cbytes_from_point(P: Point) -> bytes:
    prefix = b'\x02' if has_even_y(P) else b'\x03'
    return prefix + bytes_from_point(P)


def point_from_cbytes(b: bytes) -> Point:
    prefix = b[:1]
    x, y = lift_x(b[1:])
    return (x, y) if prefix == b'\x02' else (x, p-y)


def xonly_point_agg(xonly_points: List[bytes]) -> Point:
    P = None  # point at infinity
    for xonly_point in xonly_points:
        P = point_add(P, lift_x(xonly_point))
    return P


def xonly_int(b: bytes, P_agg: Point) -> int:
    k = int_from_bytes(b)
    if has_even_y(point_mul(G, k)) != has_even_y(P_agg):
        k = n - k
    return k


def partial_sig_agg(partial_sigs: List[bytes]) -> bytes:
    s = 0
    for partial_sig in partial_sigs:
        s = (s + int_from_bytes(partial_sig)) % n
    return bytes_from_int(s)
```

siv2r commented Feb 7, 2022

Is a rogue key attack possible here?
Let's say:

pubkeys = [X1, X2, X3]
X1 = honest signer's pubkey
X2 = forger's pubkey
X3 = negation of the honest signer's pubkey

In test_forgery(), the aggregate pubkey is computed as agg_pubkey = bytes_from_point(xonly_point_agg(pubkeys)). xonly_point_agg() always lifts each key to the point with an even y-coordinate (since it uses lift_x()).
Therefore, passing X3 (the negation of X1) is treated as X1 during point aggregation.
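The collapse described here can be checked in a few lines: a point and its negation share the same x-coordinate, so they have the same 32-byte x-only encoding, and lift_x always returns the even-y point. A minimal sketch, assuming the secp256k1 generator G stands in for the honest key (its y-coordinate is even) and using illustrative helpers; unlike reference.py, this lift_x takes an integer rather than bytes.

```python
# secp256k1 field prime and generator (same constants as reference.py)
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
     0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)


def lift_x(x):
    """BIP-340 lift_x on an integer: the curve point with x-coordinate x and even y."""
    if x >= p:
        return None
    y_sq = (pow(x, 3, p) + 7) % p
    y = pow(y_sq, (p + 1) // 4, p)
    if pow(y, 2, p) != y_sq:
        return None
    return (x, y if y % 2 == 0 else p - y)


def negate(P):
    """-P keeps the x-coordinate and replaces y with p - y."""
    return (P[0], (p - P[1]) % p)


X1 = G           # stand-in for the honest key; G already has an even y
X3 = negate(X1)  # the "negated honest signer" key, with an odd y

# Both points have the same x-only encoding ...
assert X1[0] == X3[0]
# ... and lifting that encoding always gives back X1, never X3.
assert lift_x(X3[0]) == X1
```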

I could set pubkeys = n*[X1] + [X2] to cancel out X1 (n copies of X1 sum to the point at infinity), but Python raises an OverflowError because n is far too large to build a list of that length. A for loop doesn't work either, since it would never finish.
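The arithmetic behind "n copies of X1" can be done without building the list: double-and-add scalar multiplication needs only about 256 doublings for a 256-bit scalar. The sketch below confirms that n*X1 is the point at infinity and that (n - 1)*X1 = -X1. Assumptions: the toy secret 0xC0FFEE is hypothetical, and the helper names are illustrative. This only settles the math; test_forgery() still aggregates a Python list of x-only keys, so the n-element list itself stays infeasible.

```python
# secp256k1 domain parameters (same constants as reference.py)
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
     0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)


def point_add(P1, P2):
    """Affine addition on secp256k1; None represents the point at infinity."""
    if P1 is None:
        return P2
    if P2 is None:
        return P1
    if P1[0] == P2[0] and P1[1] != P2[1]:
        return None  # P + (-P) = infinity
    if P1 == P2:
        lam = 3 * P1[0] * P1[0] * pow(2 * P1[1], p - 2, p) % p
    else:
        lam = (P2[1] - P1[1]) * pow(P2[0] - P1[0], p - 2, p) % p
    x3 = (lam * lam - P1[0] - P2[0]) % p
    return (x3, (lam * (P1[0] - x3) - P1[1]) % p)


def point_mul(P, k):
    """Double-and-add: one doubling per bit of k, plus one addition per set bit."""
    R = None
    while k:
        if k & 1:
            R = point_add(R, P)
        P = point_add(P, P)
        k >>= 1
    return R


# Hypothetical honest-signer key from a toy secret, for illustration only.
X1 = point_mul(G, 0xC0FFEE)

# n copies of X1 add up to the point at infinity (None) ...
assert point_mul(X1, n) is None
# ... and n - 1 copies give -X1 = (x1, p - y1).
minus_X1 = point_mul(X1, n - 1)
assert minus_X1 == (X1[0], (p - X1[1]) % p)
```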

Am I missing something?

robot-dreams (author) commented:

@siv2r You're definitely on the right track, but you might be trying something that's more complicated than you need.

As a starting point, given a public key X, how do you find a public key Y such that X + Y = 0? What does 0 even mean in that equation?


siv2r commented Feb 8, 2022

> given a public key X, how do you find a public key Y such that X + Y = 0? What does 0 even mean in that equation?

If X = (x, y), then Y = (x, p - y), and 0 is the point at infinity (represented as None in Python).

I've thought about it for a long time, and I still suspect a rogue key attack may not be possible here, because xonly_point_agg() is used to aggregate the pubkeys.

For example,

X1 = (x1, y1)           --> honest signer's pubkey
X2 = (x2, y2)           --> forger's pubkey
X3 = -X1 = (x1, p - y1) --> negated honest signer's pubkey

Note: y1 is even (BIP-340 implicitly assumes this); hence p - y1 is odd.

pubkeys = [x(X1), x(X2), x(X3)]   --> x(Xi) is the x-coordinate of Xi as 32 bytes (e.g., x(X1) = x1)
        = [x1, x2, x1]

Now, xonly_point_agg(pubkeys) computes:

Xagg = lift_x(x1) + lift_x(x2) + lift_x(x1)
     = X1 + X2 + X1   (fails to make Xagg = X2)

If a different aggregation function were used instead (say, one that takes full or compressed points):

Xagg = X1 + X2 + X3
     = X2   (successful rogue key attack)
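The contrast above can be checked numerically. In this sketch, agg_xonly mirrors what util.py's xonly_point_agg() does (lift every x-coordinate to even y, then add), while agg_full is the hypothetical full-point aggregation from the comment. Assumptions: the two toy secrets are hypothetical, both keys are normalized to even y as the comment assumes, and lift_x here takes an integer rather than bytes.

```python
# secp256k1 field prime and generator (same constants as reference.py)
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
     0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)


def point_add(P1, P2):
    """Affine addition on secp256k1; None represents the point at infinity."""
    if P1 is None:
        return P2
    if P2 is None:
        return P1
    if P1[0] == P2[0] and P1[1] != P2[1]:
        return None
    if P1 == P2:
        lam = 3 * P1[0] * P1[0] * pow(2 * P1[1], p - 2, p) % p
    else:
        lam = (P2[1] - P1[1]) * pow(P2[0] - P1[0], p - 2, p) % p
    x3 = (lam * lam - P1[0] - P2[0]) % p
    return (x3, (lam * (P1[0] - x3) - P1[1]) % p)


def point_mul(P, k):
    """Double-and-add scalar multiplication k*P."""
    R = None
    while k:
        if k & 1:
            R = point_add(R, P)
        P = point_add(P, P)
        k >>= 1
    return R


def lift_x(x):
    """The curve point with x-coordinate x and even y (BIP-340 convention)."""
    y_sq = (pow(x, 3, p) + 7) % p
    y = pow(y_sq, (p + 1) // 4, p)
    assert pow(y, 2, p) == y_sq, "x is not on the curve"
    return (x, y if y % 2 == 0 else p - y)


def agg_full(points):
    """Hypothetical aggregation over full (x, y) points."""
    P = None
    for Q in points:
        P = point_add(P, Q)
    return P


def agg_xonly(xs):
    """Mirrors xonly_point_agg(): every key is lifted to even y before adding."""
    P = None
    for x in xs:
        P = point_add(P, lift_x(x))
    return P


X1 = lift_x(point_mul(G, 0xC0FFEE)[0])   # honest key (toy secret), even y
X2 = lift_x(point_mul(G, 0xBADC0DE)[0])  # forger key (toy secret), even y
X3 = (X1[0], (p - X1[1]) % p)            # -X1, odd y

assert agg_full([X1, X2, X3]) == X2                                         # X1 cancels
assert agg_xonly([X1[0], X2[0], X3[0]]) == point_add(point_mul(X1, 2), X2)  # 2*X1 + X2
```

So with x-only aggregation, passing the negated key only adds X1 a second time; cancelling X1 would require a key whose even-y lift already accounts for it.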

Happy to take this conversation elsewhere since it might be spoiling the challenge for others.
