Elliptic Curves
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Elliptic Curve functions | |
# For educational purposes only. | |
# | |
# N. P. O'Donnell, 2020-2021 | |
from math import inf | |
import random | |
# Identity element of the curve group, the "point at infinity". A pair of
# float infinities can never equal a point with integer coordinates.
POINT_AT_INFINITY = tuple([inf] * 2)

# Each curve below is a tuple laid out as (p, a, b, G, n, h), describing
# y^2 = x^3 + ax + b over the integers mod p, with generator G = (x, y),
# n the order of G and h the cofactor.
#
# Note that the order (n) can be higher or lower than p, depending on the
# curve parameters. For example in CURVE_TOY it's higher, but in
# CURVE_SECP256K1 it's lower.

# Toy curve: y^2 = x^3 + 2x + 2 (mod 17).
CURVE_TOY = (
    17,        # p
    2,         # a
    2,         # b
    (6, 14),   # G = (G.x, G.y)
    19,        # n (order)
    1,         # h (cofactor)
)

# secp256k1 - a more serious curve... the one used in bitcoin:
# y^2 = x^3 + 7 (mod p), with p = 2**256 - 2**32 - 977.
CURVE_SECP256K1 = (
    # p
    115792089237316195423570985008687907853269984665640564039457584007908834671663,
    # a, b
    0,
    7,
    # G = (G.x, G.y)
    (
        55066263022277343669578718895168534326250603453777594175500187360389116729240,
        32670510020758816978083085130507043184471273380659243275938904335757337482424,
    ),
    # n (order)
    115792089237316195423570985008687907852837564279074904382605163141518161494337,
    # h (cofactor)
    1,
)
def egcd(a: int, b: int):
    """
    Extended Euclidean Algorithm.

    Returns a tuple (g, x, y) such that a*x + b*y == g, where g is the
    greatest common divisor of a and b up to sign. Because Python's // and %
    floor towards negative infinity, g can be negative when an input is
    negative, e.g. egcd(-3, 17) == (-1, 6, 1); callers must be prepared for
    g == -1 in that case.

    Implemented iteratively so very large inputs cannot exhaust the
    interpreter's recursion limit. It follows the same quotient sequence as
    the recursive formulation egcd(b % a, a), so it returns exactly the same
    coefficients.
    """
    # Loop invariants: old_r == a*old_x + b*old_y and r == a*x + b*y.
    # Starting from (b, a) mirrors the recursive version, which divides b by a
    # first.
    old_r, r = b, a
    old_x, x = 0, 1
    old_y, y = 1, 0
    while r != 0:
        q = old_r // r
        # old_r - q*r is exactly old_r % r in Python.
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    return (old_r, old_x, old_y)
def mai(a: int, n: int):
    """
    Modular Additive Inverse (MAI) of a mod n: the residue which, added to a,
    gives 0 modulo n (for positive n it lies in [0, n)).
    """
    # Unary minus binds tighter than %, so this is (-a) % n.
    return -a % n
def mmi(a: int, n: int):
    """
    Modular Multiplicative Inverse (MMI) of a mod n: the x in [0, n) with
    (a * x) % n == 1 (for n == 1 every residue is 0, so 0 is returned).
    The MMI exists iff a and n are coprime.

    Raises ValueError if n < 1, or if a has no inverse modulo n.
    """
    if n < 1:
        raise ValueError(f"Modulus must be positive, got {n}")
    # Reducing a first keeps both egcd arguments non-negative, so the gcd it
    # returns is never negative and only g == 1 signals an inverse.
    g, x, _ = egcd(a % n, n)
    if g != 1:
        raise ValueError(f"MMI does not exist for {a} modulo {n}")
    return x % n
def random_scalar(curve: tuple):
    """
    Return a uniformly random scalar in [1, n - 1], where n = curve[4] is the
    order of the curve's generator G. Multiplying G by it yields an element
    of the group G generates other than the identity; all such elements are
    equiprobable.

    The scalar comes from random.SystemRandom, which draws from the operating
    system's cryptographically secure source. The module-level Mersenne
    Twister behind random.randint is predictable from its outputs, so private
    keys or ECDSA nonces drawn from it could be recovered by an attacker.
    """
    n = curve[4]
    return random.SystemRandom().randint(1, n - 1)
def point_is_on_curve(curve: tuple, P: tuple):
    """
    Return True if P satisfies the curve equation y^2 = x^3 + ax + b (mod p).

    POINT_AT_INFINITY is treated as lying on every curve.
    """
    if P == POINT_AT_INFINITY:
        return True
    x, y = P
    p, a, b, _G, _n, _h = curve
    # Two integers are congruent mod p exactly when their difference is a
    # multiple of p.
    return (y * y - (x * x * x + a * x + b)) % p == 0
def add_points(curve: tuple, P: tuple, Q: tuple):
    """
    Add two points P and Q on curve using the elliptic-curve group law.

    POINT_AT_INFINITY is the identity. Two finite points with the same x
    coordinate (mod p) whose y coordinates sum to 0 mod p are inverses and
    add up to POINT_AT_INFINITY. This also covers doubling a point with
    y == 0, whose tangent line is vertical.

    Both inputs are asserted to be on the curve, and so is the result.
    """
    assert point_is_on_curve(curve, P)
    assert point_is_on_curve(curve, Q)
    p, a, *_ = curve
    if P == POINT_AT_INFINITY:
        return Q
    if Q == POINT_AT_INFINITY:
        return P
    x1, y1 = P
    x2, y2 = Q
    if (x1 - x2) % p == 0:
        if (y1 + y2) % p == 0:
            # Q == -P: the line through them is vertical.
            return POINT_AT_INFINITY
        # P == Q: slope of the tangent line, (3x^2 + a) / (2y).
        s = ((3 * x1 * x1 + a) * mmi(2 * y1, p)) % p
    else:
        # Slope of the chord through P and Q.
        s = ((y2 - y1) * mmi(x2 - x1, p)) % p
    # With x1 == x2 this reduces to the doubling formula s^2 - 2x.
    x3 = (s * s - x1 - x2) % p
    y3 = (s * (x1 - x3) - y1) % p
    R = (x3, y3)
    assert point_is_on_curve(curve, R)
    return R
def multiply_point(curve: tuple, P: tuple, x: int):
    """
    Scalar multiplication: return x * P, i.e. P added to itself x times
    (POINT_AT_INFINITY when x == 0).

    Uses left-to-right double-and-add over the binary digits of x. This
    performs the same doublings and additions as the recursive definition
    x*P = 2*((x // 2)*P), plus P when x is odd. Because it is a loop, the
    call depth does not grow with the bit length of x.

    Raises ValueError if x is negative.
    """
    assert point_is_on_curve(curve, P)
    if x < 0:
        raise ValueError(f"Scalar must be non-negative, got {x}")
    if x == 0:
        return POINT_AT_INFINITY
    R = P
    # bin(x) is '0b1...'; starting from P accounts for the leading 1 bit.
    for bit in bin(x)[3:]:
        R = add_points(curve, R, R)
        if bit == "1":
            R = add_points(curve, R, P)
    return R
def make_privkey(curve: tuple):
    """
    Return a fresh random private key for curve: a scalar produced by
    random_scalar, i.e. an integer in [1, n - 1].
    """
    privkey = random_scalar(curve)
    return privkey
def derive_pubkey(curve: tuple, privkey: int):
    """Return the public key for privkey on curve: the point privkey * G."""
    _p, _a, _b, generator, _n, _h = curve
    pubkey = multiply_point(curve, generator, privkey)
    return pubkey
def ecdsa_sign(curve: tuple, privkey: int, msg_hash: int):
    """
    Create an ECDSA signature (r, s) over msg_hash with privkey.

    Each attempt draws a fresh nonce k and computes R = k*G,
    r = x(R) mod n, and s = k^-1 * (msg_hash + r * privkey) mod n.
    If r or s is 0, the nonce is discarded and another one is drawn, as
    ECDSA requires. r == 0 would give a signature that does not involve the
    private key at all, and s == 0 cannot be inverted by the verifier.

    msg_hash is used as given: this function does not truncate it to the bit
    length of n (presumably the caller already has).

    Raises ValueError if privkey is a multiple of n. Such a key's public key
    is the identity, and if msg_hash is also 0 mod n every attempt would
    give s == 0.
    """
    _, _, _, G, n, _ = curve
    if privkey % n == 0:
        raise ValueError("Private key must not be a multiple of the group order n")
    while True:
        k = random_scalar(curve)
        R = multiply_point(curve, G, k)
        if R == POINT_AT_INFINITY:
            # Only possible if n is not really the order of G; retry.
            continue
        r = R[0] % n
        if r == 0:
            continue
        s = ((msg_hash + privkey * r) * mmi(k, n)) % n
        if s != 0:
            return r, s
def ecdsa_verify(curve: tuple, signature: tuple, pubkey: tuple, msg_hash: int):
    """
    Verify an ECDSA signature (r, s) over msg_hash against pubkey.

    Returns True iff the signature is valid. The function returns False,
    rather than raising, when:
    - r or s lies outside [1, n - 1];
    - pubkey is POINT_AT_INFINITY or is not on the curve.
    """
    _, _, _, G, n, _ = curve
    r, s = signature
    # r == 0 would make u2 == 0 below, dropping pubkey from the check entirely
    # (anyone could forge); s == 0 has no inverse mod n.
    if not (1 <= r < n and 1 <= s < n):
        return False
    if pubkey == POINT_AT_INFINITY or not point_is_on_curve(curve, pubkey):
        return False
    w = mmi(s, n)
    u1 = (w * msg_hash) % n
    u2 = (w * r) % n
    P = add_points(curve, multiply_point(curve, G, u1), multiply_point(curve, pubkey, u2))
    if P == POINT_AT_INFINITY:
        return False
    # Standard ECDSA compares x(P) mod n with r.
    return P[0] % n == r
def main():
    """Self-test: check point addition and scalar multiplication on CURVE_TOY."""
    curve = CURVE_TOY
    inf_pt = POINT_AT_INFINITY

    # (P, Q, expected P + Q)
    addition_cases = [
        (inf_pt, inf_pt, inf_pt),
        (inf_pt, (6, 14), (6, 14)),
        ((6, 14), inf_pt, (6, 14)),
        ((6, 14), (6, 3), inf_pt),
        ((6, 3), (6, 14), inf_pt),
        ((3, 1), (3, 1), (13, 7)),
        ((3, 1), (10, 11), (5, 1)),
        ((10, 11), (3, 1), (5, 1)),
    ]
    for P, Q, expected in addition_cases:
        assert add_points(curve, P, Q) == expected

    # Expected k * (7, 6) for k = 0, 1, ..., 20. The point has order 19, so
    # the sequence returns to the identity at k = 19 and repeats.
    base = (7, 6)
    multiples = [
        inf_pt,
        (7, 6), (5, 16), (13, 7), (6, 14), (0, 6),
        (10, 11), (16, 13), (3, 16), (9, 16), (9, 1),
        (3, 1), (16, 4), (10, 6), (0, 11), (6, 3),
        (13, 10), (5, 1), (7, 11), inf_pt, (7, 6),
    ]
    for k, expected in enumerate(multiples):
        assert multiply_point(curve, base, k) == expected


# Run the self-test only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment