Skip to content

Instantly share code, notes, and snippets.

@lontivero
Created July 31, 2018 17:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lontivero/1d75b11e9243c0f2a75466db05776623 to your computer and use it in GitHub Desktop.
Save lontivero/1d75b11e9243c0f2a75466db05776623 to your computer and use it in GitHub Desktop.
Verify Schnorr signatures in batch
import hashlib
import binascii
# secp256k1 domain parameters, shared by every function in this file.

# Prime modulus of the base field; all coordinate arithmetic is done mod p.
p = 2**256 - 2**32 - 977

# Order of the group generated by G; all scalar arithmetic is done mod n.
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

# Generator (base point) as an affine (x, y) pair.
G = (
    0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
    0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8,
)
def point_add(p1, p2):
    """Add two curve points in affine coordinates.

    Points are ``(x, y)`` pairs; ``None`` represents the point at infinity
    (the group identity).

    Returns the sum as an ``(x, y)`` tuple reduced mod the global ``p``, or
    ``None`` when the operands are inverses of each other.
    """
    # The identity leaves the other operand unchanged.
    if p1 is None:
        return p2
    if p2 is None:
        return p1

    x1, y1 = p1[0], p1[1]
    x2, y2 = p2[0], p2[1]

    # Same x but different y: the points are mirror images, so they cancel.
    if x1 == x2 and y1 != y2:
        return None

    if p1 == p2:
        # Doubling: slope of the tangent, 3*x^2 / (2*y).
        numerator, denominator = 3 * x1 * x1, 2 * y1
    else:
        # Distinct points: slope of the chord, (y2 - y1) / (x2 - x1).
        numerator, denominator = y2 - y1, x2 - x1

    # Division mod p via Fermat's little theorem: d^(p-2) == d^-1 (mod p).
    slope = (numerator * pow(denominator, p - 2, p)) % p
    x3 = (slope * slope - x1 - x2) % p
    y3 = (slope * (x1 - x3) - y1) % p
    return (x3, y3)
def point_mul(p, n):
    """Multiply the curve point ``p`` by the integer scalar ``n``.

    Uses double-and-add, scanning the scalar from its least-significant bit.

    Note that the parameters ``p`` and ``n`` shadow the module-level field
    prime and group order inside this function; ``point_add`` still uses the
    module-level prime.

    Args:
        p: Point as an ``(x, y)`` pair, or ``None`` for the point at infinity.
        n: Non-negative integer scalar of any size. Negative scalars are not
            meaningful here (their two's-complement bits are scanned), so
            callers should reduce scalars modulo the group order first.

    Returns:
        The product as an ``(x, y)`` tuple, or ``None`` for the point at
        infinity (e.g. when ``n`` is 0).
    """
    r = None
    # Always scan at least 256 bits (as before), but never fewer than the
    # scalar actually has: a fixed 256-bit loop silently dropped the high
    # bits of larger scalars such as the products formed in batch
    # verification.
    for i in range(max(256, n.bit_length())):
        if (n >> i) & 1:
            r = point_add(r, p)
        p = point_add(p, p)
    return r
def bytes_point(p):
    """Serialize a point ``(x, y)`` to 33 bytes in compressed form.

    The first byte is 0x03 when ``y`` is odd and 0x02 when it is even; the
    remaining 32 bytes are ``x`` in big-endian order.
    """
    prefix = b'\x02'
    if p[1] & 1:
        prefix = b'\x03'
    x_bytes = p[0].to_bytes(32, byteorder="big")
    return prefix + x_bytes
def sha256(b):
    """Return the SHA-256 digest of the bytes ``b`` as a big-endian integer."""
    digest = hashlib.sha256(b).digest()
    return int.from_bytes(digest, byteorder="big")
def on_curve(point):
    """Return True when ``point`` satisfies y^2 = x^3 + 7 (mod p).

    ``point`` is an ``(x, y)`` pair; the point at infinity (``None``) is not
    accepted.
    """
    y_squared = pow(point[1], 2, p)
    curve_rhs = (pow(point[0], 3, p) + 7) % p
    return y_squared == curve_rhs
def jacobi(x):
    """Evaluate Euler's criterion for ``x`` modulo the prime ``p``.

    Returns ``x ** ((p - 1) / 2) mod p``: 1 when ``x`` is a non-zero
    quadratic residue, ``p - 1`` when it is a non-residue, and 0 when ``x``
    is a multiple of ``p``.
    """
    half_order = (p - 1) // 2
    return pow(x, half_order, p)
def schnorr_sign(msg, seckey):
    """Create a 64-byte Schnorr signature of ``msg`` under ``seckey``.

    Args:
        msg: Message bytes.
        seckey: Secret key as an integer that fits in 32 bytes.

    Returns:
        ``R.x || s``: the 32-byte big-endian x coordinate of the nonce point
        followed by the 32-byte big-endian scalar ``s = k + e * seckey mod n``.
    """
    # Deterministic nonce derived from the secret key and the message.
    seckey_bytes = seckey.to_bytes(32, byteorder="big")
    nonce = sha256(seckey_bytes + msg)
    R = point_mul(G, nonce)

    # Only R.x is transmitted, so pick the nonce whose R.y is a quadratic
    # residue; negating the nonce flips the sign of R.y.
    if jacobi(R[1]) != 1:
        nonce = n - nonce

    rx_bytes = R[0].to_bytes(32, byteorder="big")
    pubkey_bytes = bytes_point(point_mul(G, seckey))
    challenge = sha256(rx_bytes + pubkey_bytes + msg)
    s = (nonce + challenge * seckey) % n
    return rx_bytes + s.to_bytes(32, byteorder="big")
def schnorr_verify(msg, pubkey, sig):
    """Verify a single Schnorr signature.

    Args:
        msg: Message bytes that were signed.
        pubkey: Public key as an ``(x, y)`` point.
        sig: Signature bytes, ``R.x || s`` (exactly 64 bytes).

    Returns:
        True when the signature is valid for ``msg`` under ``pubkey``,
        False otherwise (including malformed input).
    """
    # Reject anything that is not exactly 64 bytes. Without this, trailing
    # bytes were silently ignored and short signatures parsed as s = 0.
    if len(sig) != 64:
        return False
    if not on_curve(pubkey):
        return False

    r = int.from_bytes(sig[0:32], byteorder="big")
    s = int.from_bytes(sig[32:64], byteorder="big")
    if r >= p or s >= n:
        return False

    # Reduce the challenge mod n so that n - e is never negative; point_mul
    # would scan a negative scalar's two's-complement bits and use the wrong
    # multiple.
    e = sha256(sig[0:32] + bytes_point(pubkey) + msg) % n

    # R = s*G - e*P must equal the signer's nonce point.
    R = point_add(point_mul(G, s), point_mul(pubkey, n - e))
    return R is not None and jacobi(R[1]) == 1 and R[0] == r
def schnorr_batch_verify(pubkeys, ms, sigs, rands):
    """Verify several Schnorr signatures with one combined equation.

    Checks that ``S*G == sum(a_i * R_i) + sum(a_i * e_i * P_i)`` where
    ``S = sum(a_i * s_i) mod n``, ``a_0 = 1`` and ``a_1..`` come from
    ``rands``.

    Args:
        pubkeys: Sequence of public keys as ``(x, y)`` points.
        ms: Sequence of message bytes, one per public key.
        sigs: Sequence of 64-byte signatures, one per public key.
        rands: Sequence of integer randomizers for every signature after the
            first (at least ``len(pubkeys) - 1`` values; extras are ignored).
            They should be unpredictable and non-zero mod n: a zero
            randomizer removes its signature from the check. This list is
            not modified.

    Returns:
        True when the whole batch verifies (an empty batch is trivially
        valid), False when any signature is malformed or the combined
        equation fails.

    Raises:
        ValueError: If the input sequences have mismatched lengths or there
            are too few randomizers.
    """
    count = len(pubkeys)
    if len(ms) != count or len(sigs) != count:
        raise ValueError("pubkeys, ms and sigs must have the same length")
    if count == 0:
        return True
    if len(rands) < count - 1:
        raise ValueError("need a randomizer for every signature after the first")

    # The first signature always gets coefficient 1. Build a new list rather
    # than inserting into the caller's `rands`.
    coefficients = [1] + list(rands)

    S = 0
    Rs = None
    Ps = None
    for pubkey, msg, sig, a in zip(pubkeys, ms, sigs, coefficients):
        if len(sig) != 64:
            return False
        if not on_curve(pubkey):
            return False
        r = int.from_bytes(sig[0:32], byteorder="big")
        s = int.from_bytes(sig[32:64], byteorder="big")
        if r >= p or s >= n:
            return False
        e = sha256(sig[0:32] + bytes_point(pubkey) + msg) % n

        # Recover R from its x coordinate: y is the square root of x^3 + 7
        # that is itself a quadratic residue (p % 4 == 3, so the root is
        # c^((p+1)/4)). c must be reduced mod p or the check below can
        # reject a valid r.
        c = (pow(r, 3, p) + 7) % p
        y = pow(c, (p + 1) // 4, p)
        if pow(y, 2, p) != c:
            return False

        # Keep every scalar reduced mod n so none exceeds what point_mul
        # processes.
        S = (S + s * a) % n
        Rs = point_add(Rs, point_mul((r, y), a % n))
        Ps = point_add(Ps, point_mul(pubkey, (a * e) % n))

    return point_mul(G, S) == point_add(Rs, Ps)
if __name__ == "__main__":
    # Demo: sign an all-zero 32-byte message and check the signature with
    # both the batch verifier and the single-signature verifier.
    # Guarded so that importing this module does not run the (slow) curve
    # arithmetic.
    seckey = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
    pubkey = point_mul(G, seckey)
    message = bytes(32)
    signature = schnorr_sign(message, seckey)

    # The function is named schnorr_batch_verify; the previous call to
    # `batch_verify` raised NameError.
    print("batch verify: ", schnorr_batch_verify([pubkey], [message], [signature], []))
    print("single verify:", schnorr_verify(message, pubkey, signature))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment