Skip to content

Instantly share code, notes, and snippets.

@npodonnell
Last active July 25, 2023 12:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save npodonnell/f00b4773f9d2421c371a205830847570 to your computer and use it in GitHub Desktop.
Save npodonnell/f00b4773f9d2421c371a205830847570 to your computer and use it in GitHub Desktop.
Crypto Number Theory
#!/usr/bin/env python3
#
# Crypto & Number Theory
# N. P. O'Donnell, 2020
def odd(a):
    """
    Return True when the integer a is odd (not divisible by 2).
    Negative values work too: with a positive modulus, Python's %
    always yields 0 or 1.
    """
    assert type(a) is int
    remainder = a % 2
    return remainder != 0
def even(a):
    """
    Return True when the integer a is even (divisible by 2), which
    includes zero and negative multiples of 2.
    """
    assert type(a) is int
    return not a % 2
def posint(a):
    """
    Return True when a is a strictly positive integer (a >= 1).
    A non-int argument trips an AssertionError.
    """
    assert type(a) is int
    return a > 0
def negint(a):
    """
    Return True when a is a strictly negative integer (a <= -1).
    A non-int argument trips an AssertionError.
    """
    assert type(a) is int
    return a < 0
def divides(a, n):
    """
    Does a divide n?

    Returns True when n == a * k for some integer k. Signs do not matter.
    0 divides only 0, so divides(0, 0) is True and divides(0, n) is False
    for any nonzero n, instead of raising ZeroDivisionError.
    """
    assert type(a) is int
    assert type(n) is int
    if a == 0:
        # n % 0 would raise; the only multiple of 0 is 0 itself.
        return n == 0
    return n % a == 0
def congruent(a, b, n):
    """
    Returns True if a is congruent to b modulo n, i.e. n divides a - b.

    The sign of n does not matter. Congruence modulo 0 reduces to plain
    equality, so congruent(a, b, 0) is a == b rather than a
    ZeroDivisionError.
    """
    assert type(a) is int
    assert type(b) is int
    assert type(n) is int
    if n == 0:
        return a == b
    return (a % n) == (b % n)
def factors(n):
    """
    Yield every positive divisor of the positive integer n, in
    increasing order (1 first, n last).
    """
    # Same precondition as posint(n): n must be an int >= 1.
    assert type(n) is int
    assert n >= 1
    yield from (d for d in range(1, n + 1) if not n % d)
def prime(a):
    """
    Returns True if positive integer a is prime.

    Uses trial division by 2 and then by odd candidates up to sqrt(a).
    This is O(sqrt(a)) work instead of enumerating every divisor of a.
    """
    # Same precondition as posint(a): a must be an int >= 1.
    assert type(a) is int
    assert a >= 1
    if a < 2:
        return False
    if a % 2 == 0:
        return a == 2
    d = 3
    # Any composite a has a divisor d with d * d <= a.
    while d * d <= a:
        if a % d == 0:
            return False
        d += 2
    return True
def composite(a):
    """
    Returns True if positive integer a is composite: a > 1 and a has a
    divisor other than 1 and itself.

    Searches for the smallest such divisor, which is always at most
    sqrt(a), so 1 and primes return False after O(sqrt(a)) work.
    """
    # Same precondition as posint(a): a must be an int >= 1.
    assert type(a) is int
    assert a >= 1
    d = 2
    while d * d <= a:
        if a % d == 0:
            return True
        d += 1
    return False
def odd_prime(a):
    """
    Returns True if the positive integer a is a prime other than 2.
    """
    assert posint(a)
    if not odd(a):
        return False
    return prime(a)
def primes(n):
    """
    Yields the prime numbers up to and including n, in increasing order.
    """
    assert posint(n)
    if n >= 2:
        yield 2
    # After 2, only odd numbers can be prime (1 is not prime, so start at 3).
    for candidate in range(3, n + 1, 2):
        if prime(candidate):
            yield candidate
def prime_factorization(n):
    """
    Yields the prime factorization of positive integer n in non-decreasing
    order. Each prime is yielded once per multiplicity (12 -> 2, 2, 3),
    and n == 1 yields nothing.

    Uses iterative trial division. Each candidate d is divided out
    completely before moving on, so any d that still divides the remaining
    cofactor must be prime. Once d * d exceeds the cofactor, whatever is
    left (if greater than 1) is itself prime. This avoids both the
    recursion and the repeated primality tests of the naive approach.
    """
    # Same precondition as posint(n): n must be an int >= 1.
    assert type(n) is int
    assert n >= 1
    d = 2
    while d * d <= n:
        while n % d == 0:
            yield d
            n //= d
        d += 1
    if n > 1:
        yield n
def prime_power(n):
    """
    Returns True if n == p**k for some prime p and k >= 1, i.e. n has
    exactly one distinct prime factor. 1 is not a prime power.
    """
    assert posint(n)
    distinct_primes = {p for p in prime_factorization(n)}
    return len(distinct_primes) == 1
def power_of_two(n):
    """
    Returns True if n is 2**k with k >= 1 (2, 4, 8, ...).

    1 == 2**0 is deliberately NOT treated as a power of two, matching the
    original definition "even and a prime power". A positive integer is a
    power of two iff it has exactly one set bit, which
    n & (n - 1) == 0 checks in O(1) without factorizing n.
    """
    # Same precondition as posint(n): n must be an int >= 1.
    assert type(n) is int
    assert n >= 1
    return n >= 2 and (n & (n - 1)) == 0
def gcd(a, b):
    """
    Greatest common divisor of a and b, via the Euclidean algorithm.

    The result is always non-negative (gcd(4, -6) == 2), matching the
    mathematical definition and math.gcd; gcd(0, 0) is 0. The loop is
    iterative rather than recursive.
    """
    assert type(a) is int
    assert type(b) is int
    while b != 0:
        a, b = b, a % b
    # With mixed signs, Python's % can leave a negative remainder chain;
    # the greatest common divisor itself is the absolute value.
    return abs(a)
def egcd(a, b):
    """
    Extended Euclidean Algorithm.

    Returns a triple (g, x, y) with a * x + b * y == g, where g is the gcd
    of a and b up to sign (it can come out negative for negative inputs).
    egcd(0, b) is (b, 0, 1).
    """
    assert type(a) is int
    assert type(b) is int
    if a != 0:
        # Solve for (b mod a, a), then back-substitute using
        # b mod a == b - (b // a) * a.
        g, s, t = egcd(b % a, a)
        return (g, t - (b // a) * s, s)
    return (b, 0, 1)
def lcm(a, b):
    """
    Least Common Multiple of a and b.

    The result is non-negative whatever the signs of a and b. If either
    argument is 0 the result is 0, which also avoids dividing by
    gcd(0, 0) == 0. Dividing before multiplying keeps the intermediate
    value small.
    """
    assert type(a) is int
    assert type(b) is int
    if a == 0 or b == 0:
        return 0
    return abs(a // gcd(a, b) * b)
def coprime(a, b):
    """
    Returns True if a and b are coprime, otherwise False.

    a and b are coprime when 1 is their only common positive divisor,
    i.e. gcd(|a|, |b|) == 1. Consequently coprime(0, 0) is False and
    coprime(1, 0) is True.
    """
    assert type(a) is int
    assert type(b) is int
    common = gcd(abs(a), abs(b))
    return common == 1
def mai(a, n):
    """
    Modular Additive Inverse (MAI) of a mod n: the residue x with
    (a + x) % n == 0, reduced by Python's % (so in [0, n) when n > 0).
    """
    assert type(a) is int
    assert type(n) is int
    # n - a and -a differ by a multiple of n, so they reduce identically.
    return (-a) % n
def mmi(a, n):
    """
    Modular Multiplicative Inverse (MMI) of a mod n: the unique x in
    [0, n) with (a * x) % n == 1 % n. The MMI exists iff a and n are
    coprime; otherwise ValueError is raised. Requires n >= 1.
    """
    assert(n >= 1)
    g, x, _ = egcd(a, n)
    # egcd guarantees a*x + n*y == g. With g == 1, x is an inverse; with
    # g == -1 (possible for negative a), -x is. Both are reduced into
    # [0, n), which also makes n == 1 yield 0 without a special case.
    if g == 1:
        return x % n
    if g == -1:
        return (-x) % n
    raise ValueError("MMI does not exist for {} modulo {}".format(a, n))
def mod_exp(a, x, n):
    """
    Fast modular exponentiation: a to the power x, reduced mod n.

    Uses iterative right-to-left square-and-multiply, so it needs
    O(log x) multiplications and no recursion. For x >= 0 and n >= 1
    the result equals the built-in pow(a, x, n); in particular
    mod_exp(a, 0, 1) == 0.

    Raises ValueError for a negative exponent, which previously recursed
    until RecursionError; use mmi() for modular inverses. Raises
    ZeroDivisionError if n == 0.
    """
    if x < 0:
        raise ValueError("mod_exp requires a non-negative exponent, got {}".format(x))
    # 1 % n rather than 1, so the x == 0 case is reduced too (0 when n == 1).
    result = 1 % n
    base = a % n
    while x > 0:
        if x & 1:
            result = (result * base) % n
        base = (base * base) % n
        x >>= 1
    return result
def totatives(n):
    """
    Yields the totatives of a positive integer n in increasing order.
    A totative is a positive integer k <= n that is coprime to n, so
    totatives(1) yields just 1.
    """
    assert posint(n)
    for i in range(1, n + 1):
        if coprime(i, n):
            yield i
    # (A stray `return totatives` used to sit here; in a generator it only
    # set StopIteration.value to the function object and has been removed.)
def totient(n):
    """
    Euler's totient function phi(n): the number of totatives of the
    positive integer n, i.e. how many k in 1..n are coprime to n.
    """
    assert posint(n)
    return sum(1 for _ in totatives(n))
def eulers_criterion(a, p):
    """
    Euler's criterion: for an odd prime p and an integer a not divisible
    by p, a is a quadratic residue modulo p iff a^((p-1)/2) == 1 (mod p).
    Returns the result of that test. When p divides a the power is 0,
    so the result is False.
    """
    assert type(a) is int
    half_exponent = (p - 1) >> 1
    return mod_exp(a, half_exponent, p) == 1
def legendre_symbol(a, p):
    """
    Legendre symbol (a/p) for a prime p: 0 when p divides a, 1 when a is
    a nonzero quadratic residue modulo p, and -1 when it is a non-residue.
    The residue test is delegated to Euler's criterion.
    """
    assert type(a) is int
    if a % p:
        return 1 if eulers_criterion(a, p) else -1
    return 0
def jacobi_symbol(a, n):
    """
    Jacobi symbol (a/n) for a positive odd integer n.

    It is the product of the Legendre symbols (a/p) over the prime factors
    p of n, counted with multiplicity. The result is 0 when a shares a
    prime factor with n, otherwise 1 or -1. Unlike the Legendre symbol, a
    value of 1 does NOT guarantee that a is a quadratic residue modulo a
    composite n; only -1 reliably signals a non-residue.
    jacobi_symbol(a, 1) is 1 (empty product).
    """
    assert type(a) is int
    assert odd(n)
    result = 1
    for factor in prime_factorization(n):
        result = result * legendre_symbol(a, factor)
    return result
def tonelli_shanks(a, p):
    """
    Tonelli-Shanks algorithm: solves x^2 = a mod p where p is prime and a
    is a quadratic residue modulo p (asserted via the Legendre symbol).
    Returns both solutions in a tuple - the even one followed by the odd one.

    Special cases:
      * if p divides a, returns (0, p);
      * for p == 2 (a odd) the single root 1 is returned twice as (1, 1).
        Z_2* has no quadratic non-residue, so the general search below
        would never terminate.

    https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm
    """
    def even_first(x1, x2):
        # Order a pair of roots so that the even one comes first.
        return (x1, x2) if x1 % 2 == 0 else (x2, x1)

    assert legendre_symbol(a, p) >= 0
    if a % p == 0:
        return 0, p
    if p == 2:
        return 1, 1
    if congruent(p, 3, 4):
        # Shortcut for p = 3 mod 4: a^((p+1)/4) is a square root of a.
        y1 = mod_exp(a, (p + 1) >> 2, p)
        return even_first(y1, p - y1)
    # Write p - 1 = q * 2^s with q odd.
    q = p - 1
    s = 0
    while q % 2 == 0:
        s += 1
        q >>= 1
    # Find a quadratic non-residue z (exists for every odd prime p).
    z = 2
    while legendre_symbol(z, p) >= 0:
        z += 1
    m = s
    c = mod_exp(z, q, p)
    t = mod_exp(a, q, p)
    r = mod_exp(a, (q + 1) >> 1, p)
    # Invariant: r^2 == a * t (mod p); t, c and r are kept reduced mod p.
    while t != 1:
        # Least i (0 < i < m) with t^(2^i) == 1 (mod p); i == 0 is
        # impossible here because t != 1.
        for i in range(0, m):
            if mod_exp(t, 2 ** i, p) == 1:
                break
        b = mod_exp(c, 2 ** (m - i - 1), p)
        m = i
        c = mod_exp(b, 2, p)
        # Reuse c == b^2 and reduce, so t stays below p instead of growing.
        t = (t * c) % p
        r = (r * b) % p
    return even_first(r, p - r)
def multiplicative_subgroups(n):
    """
    Returns the cyclic subgroups of the multiplicative group Z_n*, one per
    candidate generator.

    The result is a list of (g, [g, g^2, ..., 1]) pairs, one for every
    totative g > 1 of n. Powers are taken mod n, and each list stops at the
    first power equal to 1. Different generators may produce the same
    subgroup. For n == 1 or n == 2 the result is an empty list.
    """
    subgroups = []
    # totatives() is a generator and cannot be sliced directly; materialize
    # it, then drop its first element (always 1, the trivial generator).
    for tot in list(totatives(n))[1:]:
        power = tot
        subgroup = [power]
        # tot is coprime to n, so its powers eventually return to 1.
        while power != 1:
            power = (power * tot) % n
            subgroup.append(power)
        subgroups.append((tot, subgroup))
    return subgroups
def main():
    """
    Self-test for the helpers in this module. Every check is a plain
    assert, so the first failure raises AssertionError; returns None when
    everything passes.
    """
    # Test odd / even
    for a in range(-3, 4):
        assert odd(a) == (a in (-3, -1, 1, 3))
        assert even(a) == (a in (-2, 0, 2))

    # Test posint / negint
    for a in range(-3, 4):
        assert posint(a) == (a in (1, 2, 3))
        assert negint(a) == (a in (-3, -2, -1))

    # Test divides: a | n iff n == a * k for some integer k
    # (k in -3..3 is enough for every a and n checked here).
    for a in (-3, -2, -1, 1, 2, 3):
        for n in range(-3, 4):
            assert divides(a, n) == any(a * k == n for k in range(-3, 4))

    # Test congruent
    # See Theorem 3.1.3 - https://www.whitman.edu/mathematics/higher_math_online/section03.01.html
    moduli = [-3, -2, -1, 1, 2, 3]
    values = range(-3, 4)
    # 1. a=a for any a
    for n in moduli:
        for a in values:
            assert congruent(a, a, n)
    # 2. a=b -> b=a
    for n in moduli:
        for a in values:
            for b in values:
                if congruent(a, b, n):
                    assert congruent(b, a, n)
    # 3. a=b and b=c -> a=c
    for n in moduli:
        for a in values:
            for b in values:
                for c in values:
                    if congruent(a, b, n) and congruent(b, c, n):
                        assert congruent(a, c, n)
    # 4. a=0 <-> n|a
    for n in moduli:
        for a in values:
            assert congruent(a, 0, n) == divides(n, a)

    # Test factors
    expected_factors = {
        1: [1], 2: [1, 2], 3: [1, 3], 4: [1, 2, 4], 5: [1, 5],
        6: [1, 2, 3, 6], 7: [1, 7], 8: [1, 2, 4, 8], 9: [1, 3, 9],
        10: [1, 2, 5, 10], 11: [1, 11], 12: [1, 2, 3, 4, 6, 12],
    }
    for n, expected in expected_factors.items():
        assert list(factors(n)) == expected

    # Test prime, composite, odd_prime, primes, prime_power, power_of_two
    primes_to_12 = (2, 3, 5, 7, 11)
    for n in range(1, 13):
        assert prime(n) == (n in primes_to_12)
        assert composite(n) == (n in (4, 6, 8, 9, 10, 12))
        assert odd_prime(n) == (n in (3, 5, 7, 11))
        assert list(primes(n)) == [p for p in primes_to_12 if p <= n]
        assert prime_power(n) == (n in (2, 3, 4, 5, 7, 8, 9, 11))
        assert power_of_two(n) == (n in (2, 4, 8))

    # Test prime_factorization
    expected_factorization = {
        1: [], 2: [2], 3: [3], 4: [2, 2], 5: [5], 6: [2, 3], 7: [7],
        8: [2, 2, 2], 9: [3, 3], 10: [2, 5], 11: [11], 12: [2, 2, 3],
    }
    for n, expected in expected_factorization.items():
        assert list(prime_factorization(n)) == expected

    # Test gcd / lcm / coprime
    assert gcd(12, 18) == 6
    assert gcd(17, 5) == 1
    assert gcd(0, 7) == 7
    assert gcd(7, 0) == 7
    assert lcm(4, 6) == 12
    assert lcm(3, 5) == 15
    assert coprime(8, 15)
    assert not coprime(-4, 6)

    # Test egcd: Bezout identity a*x + b*y == gcd(a, b)
    for a in range(1, 13):
        for b in range(1, 13):
            g, x, y = egcd(a, b)
            assert g == gcd(a, b)
            assert a * x + b * y == g

    # Test mai / mmi
    for n in range(2, 13):
        for a in range(0, n):
            assert (a + mai(a, n)) % n == 0
            if coprime(a, n):
                assert (a * mmi(a, n)) % n == 1
            else:
                try:
                    mmi(a, n)
                except ValueError:
                    pass
                else:
                    raise AssertionError("mmi({}, {}) should not exist".format(a, n))

    # Test mod_exp against the built-in three-argument pow
    for a in range(0, 7):
        for x in range(0, 10):
            for n in range(2, 10):
                assert mod_exp(a, x, n) == pow(a, x, n)

    # Test totient
    assert [totient(n) for n in range(1, 13)] == [1, 1, 2, 2, 4, 2, 6, 4, 6, 4, 10, 4]

    # Test legendre_symbol against brute-force squares
    for p in (3, 5, 7, 11, 13):
        squares = {(k * k) % p for k in range(1, p)}
        for a in range(0, p):
            expected = 0 if a == 0 else (1 if a in squares else -1)
            assert legendre_symbol(a, p) == expected

    # Test jacobi_symbol
    assert jacobi_symbol(2, 3) == -1
    assert jacobi_symbol(2, 15) == 1
    assert jacobi_symbol(3, 9) == 0
    assert jacobi_symbol(5, 1) == 1

    # Test tonelli_shanks: both roots square to a, even root first
    for p in (3, 5, 7, 11, 13, 17):
        for a in range(1, p):
            if legendre_symbol(a, p) == 1:
                r_even, r_odd = tonelli_shanks(a, p)
                assert r_even % 2 == 0 and r_odd % 2 == 1
                assert (r_even * r_even) % p == a
                assert (r_odd * r_odd) % p == a
# Run the self-tests in main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment