Skip to content

Instantly share code, notes, and snippets.

@itzmeanjan
Last active August 5, 2022 18:04
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save itzmeanjan/d4853347dfdfa853993f5ea059824de6 to your computer and use it in GitHub Desktop.
Save itzmeanjan/d4853347dfdfa853993f5ea059824de6 to your computer and use it in GitHub Desktop.
Montgomery Modular Arithmetic for the 256-bit `secp256k1` Prime Field
#!/usr/bin/python3
from math import ceil
from random import randint
from typing import List, Optional, Tuple
def bit_count(num: int) -> int:
    '''
    Return the number of bits needed to represent the non-negative integer `num`,
    i.e. the 1-based position of its most significant set bit.

    Equals len(bin(num)[2:]) for every num > 0; for num == 0 it returns 0
    (bin(0) is '0b0', whose digit string has length 1).

    Raises ValueError if num is negative.
    '''
    if num < 0:
        raise ValueError(f'bit_count expects a non-negative integer, got {num}')
    return num.bit_length()
def calculate_mu(prime: Optional[int] = None, radix_bit_len: Optional[int] = None) -> int:
    '''
    Compute mu = -prime^-1 mod 2^radix_bit_len, the per-limb Montgomery constant
    used to choose the reduction multiplier q = mu * c mod r.

    See algorithm 3 of https://eprint.iacr.org/2017/1057.pdf : the inverse y of
    `prime` modulo 2^w is built one bit at a time, then negated modulo 2^w.

    prime         : odd modulus; defaults to the module-level PRIME
    radix_bit_len : limb size w in bits; defaults to the module-level RADIX_BIT_LEN

    Returns a value in [0, 2^radix_bit_len).
    Raises ValueError if prime is even (it has no inverse modulo a power of two).
    '''
    if prime is None:
        prime = PRIME
    if radix_bit_len is None:
        radix_bit_len = RADIX_BIT_LEN
    if prime % 2 == 0:
        raise ValueError(f'Montgomery constant needs an odd modulus, got {prime}')
    radix = 1 << radix_bit_len
    y = 1
    for i in range(2, radix_bit_len + 1):
        # Invariant: prime * y == 1 (mod 2^(i-1)); fix bit i-1 if needed.
        if (prime * y) % (1 << i) != 1:
            y += 1 << (i - 1)
    return radix - y
# secp256k1 base-field modulus p = 2^256 - 2^32 - 977; see https://en.bitcoin.it/wiki/Secp256k1
PRIME: int = 0x_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_FFFFFC2F
# Bit length of p (256)
PRIME_BIT_LEN: int = bit_count(PRIME)

# Limb size in bits; 32 so that every limb maps onto a C uint32_t
RADIX_BIT_LEN: int = 32
# r = 2^32, the base of the radix-r limb representation
RADIX: int = 2 ** RADIX_BIT_LEN
# Limbs per field element: ceil(256 / 32) = 8
LIMB_COUNT: int = ceil(PRIME_BIT_LEN / RADIX_BIT_LEN)

# Montgomery radix R = r^LIMB_COUNT = 2^256, stored reduced mod p
R: int = pow(RADIX, LIMB_COUNT, PRIME)
# R^2 mod p, used by to_mont to enter Montgomery form
R2: int = pow(R, 2, PRIME)
# -p^-1 mod r, the per-limb Montgomery reduction constant
MU: int = calculate_mu()

# Number of random samples drawn by each test below
TEST_CNT: int = 2 ** 10
def to_radix_r(num: int) -> List[int]:
    '''
    Split an integer into LIMB_COUNT little-endian limbs of RADIX_BIT_LEN bits each
    (radix-r representation, r = 2^32); limb 0 is the least significant.

    Every returned limb lies in [0, RADIX).
    Raises ValueError if num is negative or does not fit in LIMB_COUNT limbs
    (i.e. num >= RADIX ** LIMB_COUNT = 2^256).
    '''
    if num < 0 or num >= RADIX ** LIMB_COUNT:
        raise ValueError(f'{num} does not fit in {LIMB_COUNT} limbs of {RADIX_BIT_LEN} bits')
    mask = RADIX - 1
    return [(num >> (RADIX_BIT_LEN * idx)) & mask for idx in range(LIMB_COUNT)]
def from_radix_r(limbs: List[int]) -> int:
    '''
    Rebuild the integer whose little-endian radix-r limbs (r = RADIX = 2^32) are
    `limbs`, i.e. return sum(limbs[i] * r^i).

    Any number of limbs is accepted (an empty list gives 0), and limbs are not
    required to be below r.
    '''
    return sum(limb * RADIX ** position for position, limb in enumerate(limbs))
def adc(a: int, b: int, carry: int) -> Tuple[int, int]:
    '''
    Add with carry on 32-bit limbs.

    Returns (low, high) where low is (a + b + carry) mod 2^32 and high is the
    part above bit 31, i.e. floor((a + b + carry) / 2^32).

    See https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/util.rs#L1-L6
    '''
    high, low = divmod(a + b + carry, 1 << 32)
    return low, high
def mac(a: int, b: int, c: int, carry: int) -> Tuple[int, int]:
    '''
    Multiply-accumulate on 32-bit limbs: computes a + b * c + carry.

    Returns (low, high) where low is the result mod 2^32 and high is
    floor(result / 2^32).

    See https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/util.rs#L15-L20
    '''
    high, low = divmod(a + b * c + carry, 1 << 32)
    return low, high
def bitwise_not(a: int) -> int:
    '''
    Bitwise complement of a 32-bit limb: the same as `~a` on a C uint32_t
    (Rust's `!a` on a u32), i.e. RADIX - 1 - a for a in [0, RADIX).
    '''
    # Python's ~a is -a - 1, so shifting it up by RADIX gives RADIX - 1 - a.
    return ~a + RADIX
def u256xu32(a: List[int], b: int, c: List[int]) -> List[int]:
    '''
    Multiply-accumulate a multi-limb integer by one limb: a += b * c.

    Inspired by https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/fp.rs#L517-L522

    a : accumulator, little-endian 32-bit limbs, exactly one limb longer than c
        (9 limbs in the 256-bit case); updated in place and returned
    b : single multiplier limb
    c : multiplicand, little-endian 32-bit limbs (8 limbs in the 256-bit case)

    Limbs a[0 .. len(c) - 1] come out reduced to 32 bits. The final carry is
    added into the top limb a[len(c)], which is not masked, so the value
    represented by `a` grows by exactly b * value(c).

    Raises ValueError if len(a) != len(c) + 1.
    '''
    if len(a) != len(c) + 1:
        raise ValueError(f'accumulator needs {len(c) + 1} limbs, got {len(a)}')
    carry = 0
    for i, c_i in enumerate(c):
        tmp = a[i] + b * c_i + carry
        a[i] = tmp & 0xffff_ffff
        carry = tmp >> 32
    # Accumulate (not overwrite) so a non-zero incoming top limb is preserved.
    a[-1] += carry
    return a
def _sbb(a: int, b: int, borrow: int) -> Tuple[int, int]:
    '''
    Subtract with borrow on 32-bit limbs, the counterpart of `adc`.

    Returns ((a - b - borrow) mod 2^32, borrow_out), where borrow_out is 1 if
    a - b - borrow is negative and 0 otherwise. Assumes a, b in [0, 2^32) and
    borrow in {0, 1}.
    '''
    tmp = a - b - borrow
    return tmp & 0xffff_ffff, 1 if tmp < 0 else 0


def mont_mult(a: List[int], b: List[int]) -> List[int]:
    '''
    Montgomery product a * b * R^-1 mod PRIME, where R = RADIX ** LIMB_COUNT = 2^256.

    Operands and result are little-endian radix-r limbs (r = 2^32). This is the
    word-by-word interleaved algorithm (algorithm 2 of
    https://eprint.iacr.org/2017/1057.pdf), arranged in CIOS form. For each limb
    a[i], the accumulator t receives t += a[i] * b. Then q * PRIME is added,
    with q = MU * t[0] mod r, which clears the lowest limb (MU = -PRIME^-1 mod r).
    Finally t is shifted down one limb. A single trial subtraction at the end
    brings the result below PRIME.

    Inspired by https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/fp.rs#L437-L560

    For operands whose values are below PRIME, every returned limb is below 2^32
    and the returned value lies in [0, PRIME).

    Raises ValueError if either operand does not have LIMB_COUNT limbs.
    '''
    prime = to_radix_r(PRIME)
    n = len(prime)
    if len(a) != n or len(b) != n:
        raise ValueError(f'expected two operands of {n} limbs, got {len(a)} and {len(b)}')

    # n + 2 limbs: for operands below PRIME the accumulator stays below 2 * PRIME * r.
    t = [0] * (n + 2)
    for a_i in a:
        # t += a_i * b
        carry = 0
        for j in range(n):
            t[j], carry = mac(t[j], a_i, b[j], carry)
        t[n], t[n + 1] = adc(t[n], carry, 0)

        # t += q * PRIME makes t[0] == 0; store the sum shifted down by one limb.
        q = (MU * t[0]) % RADIX
        _, carry = mac(t[0], q, prime[0], 0)
        for j in range(1, n):
            t[j - 1], carry = mac(t[j], q, prime[j], carry)
        t[n - 1], carry = adc(t[n], carry, 0)
        t[n] = t[n + 1] + carry

    # Value of t[0 .. n] is now below 2 * PRIME: subtract PRIME once if t >= PRIME.
    reduced = [0] * n
    borrow = 0
    for j in range(n):
        reduced[j], borrow = _sbb(t[j], prime[j], borrow)
    _, borrow = _sbb(t[n], 0, borrow)
    # A final borrow means t < PRIME, so t itself is already fully reduced.
    return t[:n] if borrow else reduced
def mont_add(a: List[int], b: List[int]) -> List[int]:
    '''
    Modular addition (a + b) mod PRIME on little-endian radix-r limbs (r = 2^32).

    The same routine serves Montgomery-form and plain operands, since
    x * R + y * R = (x + y) * R. Both operands are expected to have LIMB_COUNT
    limbs, each below 2^32, with values below PRIME. The result then has every
    limb below 2^32 and a value in [0, PRIME).

    Collects some inspiration from https://github.com/dusk-network/bls12_381/blob/2c679a284c008475b543a67ee2300ee58ffe5d11/src/fp.rs#L394-L405

    Raises ValueError if either operand does not have LIMB_COUNT limbs.
    '''
    prime = to_radix_r(PRIME)
    n = len(prime)
    if len(a) != n or len(b) != n:
        raise ValueError(f'expected two operands of {n} limbs, got {len(a)} and {len(b)}')

    total = [0] * n
    carry = 0
    for i, (a_i, b_i) in enumerate(zip(a, b)):
        total[i], carry = adc(a_i, b_i, carry)

    # Trial subtraction total - PRIME, propagating the borrow limb by limb.
    reduced = [0] * n
    borrow = 0
    for i, (t_i, p_i) in enumerate(zip(total, prime)):
        tmp = t_i - p_i - borrow
        reduced[i] = tmp & 0xffff_ffff
        borrow = 1 if tmp < 0 else 0

    # a + b >= PRIME exactly when the addition overflowed 2^256 or the subtraction
    # did not borrow; on overflow the wrapped difference equals a + b - PRIME.
    return reduced if carry or not borrow else total
def mont_inv(a: List[int]) -> List[int]:
    '''
    Modular multiplicative inverse in Montgomery representation.

    Given `a` = x * R mod PRIME (Montgomery form of x), returns the Montgomery
    form of x^-1 mod PRIME. It uses Fermat's little theorem, x^-1 = x^(PRIME - 2),
    evaluated with left-to-right square-and-multiply. Every step is a mont_mult,
    so the running value stays in Montgomery form. It starts from R mod PRIME,
    the Montgomery form of 1.

    Zero has no inverse. For an all-zero input this returns zero (0^(PRIME - 2) = 0)
    instead of raising, so callers must exclude zero themselves.

    Collects inspiration from https://github.com/dusk-network/bls12_381/blob/2c679a284c008475b543a67ee2300ee58ffe5d11/src/fp.rs#L355-L370
    '''
    def _mont_pow(base: List[int], exponent: List[int]) -> List[int]:
        # `exponent` is little-endian limbs: walk limbs from the most significant
        # one, and the RADIX_BIT_LEN bits inside each limb from the top bit down.
        res = to_radix_r(R)
        for limb in reversed(exponent):
            for j in reversed(range(RADIX_BIT_LEN)):
                res = mont_mult(res, res)
                if (limb >> j) & 1:
                    res = mont_mult(res, base)
        return res

    return _mont_pow(a, to_radix_r(PRIME - 2))
def to_mont(a: List[int]) -> List[int]:
    '''
    Map the radix-r limbs of x into Montgomery form x * R mod PRIME.

    Montgomery-multiplying by R^2 does the job:
    mont_mult(x, R^2) = x * R^2 * R^-1 = x * R (mod PRIME).
    Same approach as https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/fp.rs#L251-L253;
    see also section 2.2 of https://eprint.iacr.org/2017/1057.pdf
    '''
    r_squared = to_radix_r(R2)
    return mont_mult(a, r_squared)
def from_mont(a: List[int]) -> List[int]:
    '''
    Map Montgomery-form limbs x * R back to the plain radix-r limbs of x.

    Montgomery-multiplying by 1 strips one factor of R:
    mont_mult(x * R, 1) = x * R * R^-1 = x (mod PRIME).
    See section 2.2 of https://eprint.iacr.org/2017/1057.pdf
    '''
    one = to_radix_r(1)
    return mont_mult(a, one)
# --- Testing ---
def test_to_and_from_mont_repr():
    '''
    Converting a secp256k1 field element from radix-r limbs into Montgomery form
    and back must return the original value. Checked for the boundary elements
    0, 1 and PRIME - 1, plus TEST_CNT uniformly random elements.
    '''
    samples = [0, 1, PRIME - 1] + [randint(0, PRIME - 1) for _ in range(TEST_CNT)]
    for a in samples:
        b = from_radix_r(from_mont(to_mont(to_radix_r(a))))
        assert a == b, f'expected {a}, found {b}'
def test_mont_mult():
    '''
    Take two random secp256k1 field elements, Montgomery-multiply their
    Montgomery forms, and convert the product back. It must equal the plain
    (a * b) mod PRIME.
    '''
    for _ in range(TEST_CNT):
        a, b = randint(0, PRIME - 1), randint(0, PRIME - 1)
        expected = (a * b) % PRIME
        a_mont = to_mont(to_radix_r(a))
        b_mont = to_mont(to_radix_r(b))
        found = from_radix_r(from_mont(mont_mult(a_mont, b_mont)))
        assert expected == found, f'expected {expected}, found {found}'
def test_mont_add():
    '''
    Add the Montgomery forms of two random secp256k1 field elements and convert
    the sum back. It must equal the plain (a + b) mod PRIME.
    '''
    for _ in range(TEST_CNT):
        a, b = randint(0, PRIME - 1), randint(0, PRIME - 1)
        expected = (a + b) % PRIME
        a_mont = to_mont(to_radix_r(a))
        b_mont = to_mont(to_radix_r(b))
        found = from_radix_r(from_mont(mont_add(a_mont, b_mont)))
        assert expected == found, f'expected {expected}, found {found}'
def test_mont_inv():
    '''
    For random non-zero secp256k1 field elements a, check two things:
    - mont_inv yields a genuine inverse, i.e. a * a^-1 = 1 (mod PRIME);
    - inverting twice gives back a.
    The first check is required because a double-inversion round trip alone
    would also pass for an "inverse" that returns its input unchanged.
    '''
    for _ in range(TEST_CNT):
        a = randint(1, PRIME - 1)
        a_inv_mont = mont_inv(to_mont(to_radix_r(a)))
        a_inv = from_radix_r(from_mont(a_inv_mont))
        assert (a * a_inv) % PRIME == 1, f'{a_inv} is not the inverse of {a}'
        c = from_radix_r(from_mont(mont_inv(a_inv_mont)))
        assert a == c, f'expected {a}, found {c}'
def _main() -> None:
    '''Script entry point: the checks in this file are pytest tests, so only print a hint.'''
    print('Use `pytest` to run test cases !')


if __name__ == '__main__':
    _main()
@itzmeanjan
Copy link
Author

itzmeanjan commented Mar 26, 2022

This is a reference implementation of Montgomery modular arithmetic for the prime field

F_p, where p = 0x_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_FFFFFC2F = 2^256 - 2^32 - 977

To run the tests:

  • Install pytest
python3 -m pip install --user pytest
  • Run tests
pytest -v

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment