Skip to content

Instantly share code, notes, and snippets.

@kennyyu
Created May 8, 2020 08:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kennyyu/5fe34db4c039ff5d04b5659fd1bc7735 to your computer and use it in GitHub Desktop.
Save kennyyu/5fe34db4c039ff5d04b5659fd1bc7735 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import base64
import random
def exp(n, e, mod=None):
    """
    Returns n^e, reduced modulo `mod` when a (truthy) modulus is given.

    Uses right-to-left square-and-multiply: the exponent is consumed one bit
    at a time while `base` holds n^(2^i) for the current bit position i.

    Raises ValueError for a negative exponent (the bit loop would otherwise
    never terminate, since -1 >> 1 == -1).
    """
    if e < 0:
        raise ValueError(f"exponent must be non-negative, got {e}")
    result = 1
    base = n
    if mod:
        # Reduce up front so that e == 0 with mod == 1 yields 0, not 1.
        result = result % mod
        base = base % mod
    while e != 0:
        if e & 1:
            result = result * base
            if mod:
                result = result % mod
        e = e >> 1
        if e != 0:
            # Skip the final squaring: it would never be used, and without a
            # modulus it doubles the size of an already huge number.
            base = base * base
            if mod:
                base = base % mod
    return result
# print(f"5, 0, {exp(5, 0)}")
# print(f"5, 1, {exp(5, 1)}")
# print(f"5, 2, {exp(5, 2)}")
# print(f"5, 3, {exp(5, 3)}")
# print(f"5, 4, {exp(5, 4)}")
# print(f"5, 5, {exp(5, 5)}")
def extended_euclid(a, b):
    """
    Extended Euclidean algorithm.

    Returns a triple (d, x, y) where d = gcd(a, b) and a * x + b * y == d.
    Works iteratively: the Bezout coefficients of each remainder are carried
    forward alongside the remainder sequence, so no recursion is needed.
    """
    prev_r, r = a, b
    prev_x, x = 1, 0
    prev_y, y = 0, 1
    while r != 0:
        quotient = prev_r // r
        # prev_r - quotient * r is exactly prev_r % r under floor division.
        prev_r, r = r, prev_r - quotient * r
        prev_x, x = x, prev_x - quotient * x
        prev_y, y = y, prev_y - quotient * y
    return (prev_r, prev_x, prev_y)
# print(f"10, 6, {extended_euclid(10, 6)}")
# print(f"20, 17, {extended_euclid(20, 17)}")
# print(f"24, 16, {extended_euclid(24, 16)}")
def is_probably_prime(n, num_iter=50):
    """
    Rabin-Miller:
    Returns True if n is probably prime. For composite n, the chance of a
    True answer is at most 1 / (4^num_iter), which is well below 1 / (2^num_iter).

    n < 2 returns False, 2 and 3 return True, and even n > 2 returns False,
    all without running any random rounds.
    """
    if n < 2:
        return False
    if n in (2, 3):
        return True
    if n % 2 == 0:
        return False

    def get_t_and_u(n):
        """
        Returns (t, u) where n = 2^t * u + 1 and u is odd.
        Requires n > 1; for n == 1 the halving loop would never end.
        """
        n_1 = n - 1
        t = 0
        while n_1 % 2 == 0:
            t += 1
            n_1 = n_1 >> 1
        u = n_1
        return (t, u)

    t, u = get_t_and_u(n)
    for _ in range(num_iter):
        # Generate all the powers:
        # a^u, a^(2 * u), a^(4 * u), ..., a^(2^t * u) == a^(n - 1)
        a = random.randint(2, n - 2)
        powers = [exp(a, u, mod=n)]
        for _ in range(t):
            curr_pow = powers[-1]
            powers.append((curr_pow * curr_pow) % n)
        if powers[-1] != 1:
            # Fermat test fails: a^(n-1) != 1 (mod n), so n is composite.
            # This includes a^(n-1) == n - 1, which the backward scan below
            # would otherwise mistake for an inconclusive round.
            print(
                f"Found composite! n: {n}, t:{t}, u: {u}, n == 2^t * u + 1, a: {a}, powers: {powers}"
            )
            return False
        # Walk backwards: each power is a square root of the next one, which
        # is 1. The first power that is not 1 must be -1 (n - 1) if n is prime.
        for curr_pow in reversed(powers):
            if curr_pow == n - 1:
                # inconclusive, try another a
                print(
                    f"Inconclusive. n: {n}, t:{t}, u: {u}, n == 2^t * u + 1, a: {a}, powers: {powers}"
                )
                break
            if curr_pow != 1:
                # found a non-trivial square root of 1
                print(
                    f"Found composite! n: {n}, t:{t}, u: {u}, n == 2^t * u + 1, a: {a}, powers: {powers}"
                )
                return False
    print(f"Found probable prime! n: {n}, t:{t}, u: {u}, n == 2^t * u + 1")
    return True
# print(f"5, {is_probably_prime(5)}")
# print(f"109, {is_probably_prime(109)}")
# print(f"221, {is_probably_prime(221)}")
# print(f"222, {is_probably_prime(222)}")
# print(f"65, {is_probably_prime(65)}")
# print(f"1012313453, {is_probably_prime(1012313453)}")
def rsa_make_keys(
    prime_min=100000000000000000000000000000000000000000,
    prime_max=1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
):
    """
    Returns ((n, e), (p, q, d)) representing (public key, private key) where:
    - p and q are distinct probable primes drawn from [prime_min, prime_max]
    - n = p * q
    - e is the smallest integer >= 3 with gcd((p - 1)(q - 1), e) == 1
    - d = e^(-1) mod (p - 1)(q - 1)

    Candidates come from random.SystemRandom (the OS entropy source). The
    module-level `random` functions use Mersenne Twister, whose output can
    be predicted and is unsuitable for key material.

    The range must contain at least two primes, otherwise the search for a
    q distinct from p never ends. NOTE: this is textbook RSA without padding
    and is for illustration only.
    """
    rng = random.SystemRandom()

    def generate_prime(prime_min, prime_max):
        """
        Returns a probable prime drawn uniformly from [prime_min, prime_max]
        """
        a = rng.randint(prime_min, prime_max)
        while not is_probably_prime(a):
            a = rng.randint(prime_min, prime_max)
        return a

    def generate_e_d(p, q):
        """
        Finds the smallest e >= 3 such that gcd((p - 1)(q - 1), e) == 1,
        and d such that d = e^(-1) mod (p - 1)(q - 1)
        Returns (e, d).
        """
        phi = (p - 1) * (q - 1)
        e = 3
        while True:
            # d * e + _ * phi = 1
            # d * e = 1 mod phi
            # d = e^(-1) mod phi
            (gcd, d, _) = extended_euclid(e, phi)
            if gcd == 1:
                break
            e += 1
        # d might be negative, return the mod of it
        return (e, d % phi)

    p = generate_prime(prime_min, prime_max)
    q = generate_prime(prime_min, prime_max)
    # If p == q, then n = p^2 has totient p * (p - 1), not (p - 1)^2, and
    # the d computed below would not invert e.
    while q == p:
        q = generate_prime(prime_min, prime_max)
    n = p * q
    e, d = generate_e_d(p, q)
    return ((n, e), (p, q, d))
def encode_with_public_key(n, e, num):
    """
    Textbook RSA encryption: returns num^e mod n.

    Raises ValueError unless 0 <= num < n. A value outside that range cannot
    be recovered, because decryption only ever yields a residue mod n.
    """
    if not 0 <= num < n:
        raise ValueError("message representative out of range: need 0 <= num < n")
    return exp(num, e, mod=n)
def decode_with_private_key(p, q, d, num_encrypted):
    """
    Textbook RSA decryption: returns num_encrypted^d mod (p * q).

    Raises ValueError unless 0 <= num_encrypted < p * q. A ciphertext outside
    that range cannot have been produced by encrypting under n = p * q.
    """
    n = p * q
    if not 0 <= num_encrypted < n:
        raise ValueError(
            "ciphertext representative out of range: need 0 <= num_encrypted < p * q"
        )
    return exp(num_encrypted, d, mod=n)
# Size in bytes of each plaintext chunk. A chunk is read as a big-endian
# unsigned integer of at most 256**MESSAGE_SIZE_BYTES - 1, so the modulus n
# must be at least 256**MESSAGE_SIZE_BYTES (2**128 for 16 bytes) for
# decryption to recover it.
MESSAGE_SIZE_BYTES = 16
def _split_message(message_str, max_bytes):
    """
    Splits message_str into consecutive pieces, each of whose UTF-8 encodings
    is at most max_bytes long. A character is never split across pieces, so
    each piece decodes on its own. For pure-ASCII input this is the same as
    cutting the string every max_bytes characters.

    Raises ValueError if a single character needs more than max_bytes bytes.
    """
    chunks = []
    current = []
    current_size = 0
    for char in message_str:
        char_size = len(char.encode())
        if char_size > max_bytes:
            raise ValueError(
                f"character {char!r} needs {char_size} bytes, more than the chunk size {max_bytes}"
            )
        if current_size + char_size > max_bytes:
            chunks.append("".join(current))
            current = []
            current_size = 0
        current.append(char)
        current_size += char_size
    if current:
        chunks.append("".join(current))
    return chunks


def encode_message(n, e, message_str):
    """
    Encodes a message with the public key. The message is divided into
    chunks whose UTF-8 encoding fits in MESSAGE_SIZE_BYTES bytes, without
    splitting a character. Returns one base64-encoded ciphertext (bytes)
    per chunk; an empty message gives an empty list.
    """
    return [
        encode_message_chunk(n, e, chunk)
        for chunk in _split_message(message_str, MESSAGE_SIZE_BYTES)
    ]


def decode_message(p, q, d, encrypted_chunks):
    """
    Decodes a sequence of encrypted chunks (as produced by encode_message)
    with the private key and returns the concatenated message.
    """
    return "".join(decode_message_chunk(p, q, d, chunk) for chunk in encrypted_chunks)


def encode_message_chunk(n, e, message_str):
    """
    Encodes a single chunk with the public key.

    Steps:
    1. UTF-8 encode the chunk and right-pad it with NUL bytes to exactly
       MESSAGE_SIZE_BYTES bytes.
    2. Read the bytes as a big-endian unsigned integer and encrypt it.
    3. Return the ciphertext's decimal string, base64-encoded.

    Raises ValueError if the chunk's UTF-8 encoding is longer than
    MESSAGE_SIZE_BYTES. Decryption could not restore such a chunk.
    """
    message_bytes = message_str.encode()
    if len(message_bytes) > MESSAGE_SIZE_BYTES:
        raise ValueError(
            f"chunk encodes to {len(message_bytes)} bytes, more than MESSAGE_SIZE_BYTES={MESSAGE_SIZE_BYTES}"
        )
    # add padding in bytes, not characters: a non-ASCII character takes
    # several bytes, and character-count padding would overflow the chunk
    message_bytes = message_bytes.ljust(MESSAGE_SIZE_BYTES, b"\0")
    message_int = int.from_bytes(message_bytes, byteorder="big", signed=False)
    encrypted_int = encode_with_public_key(n, e, message_int)
    return base64.b64encode(str(encrypted_int).encode())


def decode_message_chunk(p, q, d, encrypted_message_str):
    """
    Decodes one encrypted chunk using the private key.

    Reverses encode_message_chunk:
    1. Base64-decode to the decimal ciphertext and decrypt it.
    2. Convert the result to MESSAGE_SIZE_BYTES big-endian bytes.
    3. Strip the trailing NUL padding and decode as UTF-8.

    Only trailing NULs are removed, so NUL characters inside a chunk
    survive. A chunk that itself ends in NUL characters loses them, because
    they cannot be told apart from padding.
    """
    encrypted_int = int(base64.b64decode(encrypted_message_str).decode())
    message_int = decode_with_private_key(p, q, d, encrypted_int)
    message_bytes = message_int.to_bytes(
        length=MESSAGE_SIZE_BYTES, byteorder="big", signed=False
    )
    # remove padding
    return message_bytes.rstrip(b"\0").decode()
if __name__ == "__main__":
    # Demo only. It is guarded so that importing this module (e.g. to reuse
    # exp or extended_euclid) does not generate keys or print the private key.
    ((n, e), (p, q, d)) = rsa_make_keys()
    print(f"n:{n}")
    print(f"e:{e}")
    print(f"p:{p}")
    print(f"q:{q}")
    print(f"d:{d}")
    # Round-trip a single integer through encrypt/decrypt.
    num = 171717
    num_enc = encode_with_public_key(n, e, num)
    num_dec = decode_with_private_key(p, q, d, num_enc)
    print(f"num: {num}, num_enc: {num_enc}, num_dec: {num_dec}")
    # Round-trip a message long enough to span several chunks.
    message = "Hello World Everyone! Hello World Everyone! Hello World Everyone! Hello World Everyone!"
    encrypted_message = encode_message(n, e, message)
    decrypted_message = decode_message(p, q, d, encrypted_message)
    print(f" message: {message}")
    print(f"encrypted: {encrypted_message}")
    print(f"decrypted: {decrypted_message}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment