Skip to content

Instantly share code, notes, and snippets.

@kennyyu
Created May 11, 2020 16:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kennyyu/73775f1cbdeacdd0dab9194093024098 to your computer and use it in GitHub Desktop.
Save kennyyu/73775f1cbdeacdd0dab9194093024098 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import base64
import json
import os
import random
import sys
from typing import Dict, List, NamedTuple, Optional, Tuple
def exp(n: int, e: int, mod: Optional[int] = None) -> int:
    """
    Returns n^e, reduced modulo ``mod`` when ``mod`` is given.

    A falsy ``mod`` (None or 0) means plain integer exponentiation. The work is
    delegated to the built-in ``pow``, which does square-and-multiply in C and
    reduces intermediate values when a modulus is supplied.

    Raises ValueError if ``e`` is negative, since there is no integer result
    in that case.
    """
    if e < 0:
        raise ValueError("exponent must be non-negative")
    if not mod:
        return pow(n, e)
    return pow(n, e, mod)
def extended_euclid(a: int, b: int) -> Tuple[int, int, int]:
    """
    Returns (d, x, y) where d = gcd(a, b) and a * x + b * y == d.

    Iterative extended Euclidean algorithm. It walks the same quotient sequence
    as the recursive formulation and therefore returns the same Bezout
    coefficients. Its stack usage is constant, whereas a recursive version can
    exceed Python's default recursion limit on inputs that need many division
    steps (e.g. consecutive Fibonacci numbers a few hundred bits long).
    """
    # Invariants: a * old_x + b * old_y == old_r and a * x + b * y == r.
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        k = old_r // r
        old_r, r = r, old_r - k * r
        old_x, x = x, old_x - k * x
        old_y, y = y, old_y - k * y
    return (old_r, old_x, old_y)
def is_probably_prime(n: int, num_iter: int = 50) -> bool:
    """
    Rabin-Miller probabilistic primality test.

    Returns True if n is probably prime. A composite n is reported as prime
    with probability at most 1 / (4^num_iter).

    Values below 2 are not prime. 2 and 3 are answered directly because the
    random witness range [2, n - 2] is empty for them, and even n > 2 is
    rejected immediately.
    """
    if n < 2:
        return False
    if n in (2, 3):
        return True
    if n % 2 == 0:
        return False

    # Write n - 1 = 2^t * u with u odd.
    t, u = 0, n - 1
    while u % 2 == 0:
        t += 1
        u >>= 1

    for _ in range(num_iter):
        a = random.randint(2, n - 2)
        x = pow(a, u, n)
        if x == 1 or x == n - 1:
            # a^u is +-1: this witness says nothing, try another a.
            continue
        # Square up through a^(2^(t-1) * u); for a prime n we must hit -1
        # before reaching a^(n-1), otherwise some power is a non-trivial
        # square root of 1 (or a^(n-1) != 1) and n is composite.
        for _ in range(t - 1):
            x = x * x % n
            if x == n - 1:
                break
        else:
            return False
    return True
class PublicKey(NamedTuple):
    """
    An RSA public key, i.e. the pair (n, e).

    Whoever holds it can encrypt messages that only the matching private key
    can decrypt.
    """

    # modulus, the product of the two secret primes (n = p * q)
    n: int
    # public exponent used when encrypting
    e: int

    @staticmethod
    def write_to_file(public_key: "PublicKey", name: str) -> None:
        """Serializes ``public_key`` to the file ``name`` as indented JSON."""
        with open(name, "w") as out:
            payload: Dict[str, int] = dict(n=public_key.n, e=public_key.e)
            json.dump(payload, out, indent=4)

    @staticmethod
    def read_from_file(name: str) -> "PublicKey":
        """Loads a key previously saved by ``write_to_file``."""
        with open(name, "r") as src:
            raw = json.load(src)
        modulus, exponent = raw["n"], raw["e"]
        return PublicKey(n=modulus, e=exponent)
class PrivateKey(NamedTuple):
    """
    Represents an RSA private key. Messages will be decrypted with this.
    """

    # two large primes; the modulus is n = p * q
    p: int
    q: int
    # exponent used to decrypt, d = e^(-1) mod (p-1)(q-1)
    d: int

    @staticmethod
    def write_to_file(private_key: "PrivateKey", name: str) -> None:
        """
        Writes ``private_key`` to the file ``name`` as indented JSON.

        The file holds secret material, so it is created with owner-only
        permissions (0o600) instead of the process default. Where
        ``os.fchmod`` is available, a pre-existing file is also tightened to
        0o600 before it is overwritten.
        """
        fd = os.open(name, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        # fdopen takes ownership of fd, so it is closed even if a later step fails.
        with os.fdopen(fd, "w") as f:
            if hasattr(os, "fchmod"):
                # os.open's mode argument only applies to newly created files.
                os.fchmod(f.fileno(), 0o600)
            data: Dict[str, int] = {
                "p": private_key.p,
                "q": private_key.q,
                "d": private_key.d,
            }
            json.dump(data, f, indent=4)

    @staticmethod
    def read_from_file(name: str) -> "PrivateKey":
        """Loads a key previously saved by ``write_to_file``."""
        with open(name, "r") as f:
            data = json.load(f)
            return PrivateKey(p=data["p"], q=data["q"], d=data["d"])
def rsa_make_keys(
    prime_min: int = 1 << 512, prime_max: int = 1 << 550
) -> Tuple[PublicKey, PrivateKey]:
    """
    Generates an RSA key pair.

    Returns ((n, e), (p, q, d)) representing (public key, private key) where:
      - p and q are distinct probable primes drawn at random from
        [prime_min, prime_max] using the OS CSPRNG
      - n = p * q
      - e is the smallest odd integer >= 65537 with gcd((p - 1)(q - 1), e) == 1
      - d = e^(-1) mod (p - 1)(q - 1)

    Raises ValueError if prime_min > prime_max, or if the range is too small
    to yield two distinct primes. Prime search does not terminate if the range
    contains no primes at all.
    """
    # Key material must come from a cryptographically secure generator; the
    # module-level Mersenne Twister behind random.randint is predictable.
    rng = random.SystemRandom()

    def generate_prime(prime_min: int, prime_max: int) -> int:
        """
        Returns a probable prime from [prime_min, prime_max] by rejection sampling.
        """
        candidate = rng.randint(prime_min, prime_max)
        while not is_probably_prime(candidate):
            candidate = rng.randint(prime_min, prime_max)
        return candidate

    def generate_e_d(p: int, q: int) -> Tuple[int, int]:
        """
        Finds e >= 65537 with gcd((p - 1)(q - 1), e) == 1, and
        d = e^(-1) mod (p - 1)(q - 1). Returns (e, d).
        """
        phi = (p - 1) * (q - 1)
        # 65537 is the conventional public exponent. phi is even for distinct
        # primes, so only odd candidates can be coprime to it.
        e = 65537
        while True:
            # d * e + _ * phi = 1, i.e. d = e^(-1) mod phi
            (gcd, d, _) = extended_euclid(e, phi)
            if gcd == 1:
                break
            e += 2
        # d might be negative, return the mod of it
        return (e, d % phi)

    p = generate_prime(prime_min, prime_max)
    q = generate_prime(prime_min, prime_max)
    # With p == q, n = p^2 and (p - 1)^2 is not phi(n), so d would not invert e.
    retries = 0
    while q == p:
        retries += 1
        if retries > 100:
            raise ValueError("prime range too small to pick two distinct primes")
        q = generate_prime(prime_min, prime_max)
    n = p * q
    e, d = generate_e_d(p, q)
    return (PublicKey(n=n, e=e), PrivateKey(p=p, q=q, d=d))
def encode_with_public_key(public_key: PublicKey, message: int) -> int:
    """
    Textbook RSA encryption: returns message^e mod n.

    Raises ValueError unless 0 <= message < n. Larger values would be reduced
    mod n, so decryption could never recover them.
    """
    if not 0 <= message < public_key.n:
        raise ValueError("message must satisfy 0 <= message < n")
    return exp(message, public_key.e, mod=public_key.n)
def decode_with_private_key(private_key: PrivateKey, message_encrypted: int) -> int:
    """
    Textbook RSA decryption: returns message_encrypted^d mod (p * q).
    """
    modulus = private_key.p * private_key.q
    return exp(message_encrypted, private_key.d, mod=modulus)
# Size in bytes of each plaintext chunk. A chunk is turned into a big-endian
# unsigned integer and, on decryption, turned back into exactly
# MESSAGE_SIZE_BYTES bytes, so a chunk value can be as large as
# 2 ** (8 * MESSAGE_SIZE_BYTES) - 1. The RSA modulus n must therefore be at
# least 2 ** (8 * MESSAGE_SIZE_BYTES); with smaller keys some chunks exceed n
# and cannot be decrypted correctly.
MESSAGE_SIZE_BYTES: int = 128
def encode_message(public_key: PublicKey, message_str: str) -> List[str]:
    """
    Encrypts a message with the public key, returning one ciphertext string
    per chunk.

    The message is split into runs of whole characters whose UTF-8 encoding
    is at most MESSAGE_SIZE_BYTES bytes. Characters are never split across
    chunks, so each chunk decodes independently. For ASCII text this is the
    same as cutting every MESSAGE_SIZE_BYTES characters. An empty message
    yields an empty list.
    """
    chunks: List[str] = []
    current: List[str] = []
    current_size = 0
    for char in message_str:
        char_size = len(char.encode("utf-8"))
        # Start a new chunk when this character would overflow the byte budget.
        if current and current_size + char_size > MESSAGE_SIZE_BYTES:
            chunks.append("".join(current))
            current = []
            current_size = 0
        current.append(char)
        current_size += char_size
    if current:
        chunks.append("".join(current))
    return [encode_message_chunk(public_key, chunk) for chunk in chunks]
def decode_message(private_key: PrivateKey, encrypted_chunks: List[str]) -> str:
    """
    Decrypts each encrypted chunk in order with the private key and returns
    the concatenated plaintext.
    """
    return "".join(
        decode_message_chunk(private_key, encrypted) for encrypted in encrypted_chunks
    )
def encode_message_chunk(public_key: PublicKey, message_str: str) -> str:
    """
    Encrypts a single chunk with the public key and returns it as base64 text.

    The chunk's UTF-8 bytes are right-padded with NUL bytes to exactly
    MESSAGE_SIZE_BYTES bytes, matching the fixed length used on decryption.
    The bytes are read as a big-endian integer and encrypted. The ciphertext
    integer's decimal string is then base64-encoded.

    Raises ValueError if the chunk's UTF-8 encoding is longer than
    MESSAGE_SIZE_BYTES bytes. Padding by characters instead of bytes would
    let non-ASCII chunks overflow the block.
    """
    # convert from str -> bytes and add padding at the byte level
    message_bytes = message_str.encode("utf-8")
    if len(message_bytes) > MESSAGE_SIZE_BYTES:
        raise ValueError(
            f"chunk encodes to {len(message_bytes)} bytes; "
            f"at most {MESSAGE_SIZE_BYTES} fit in one block"
        )
    message_bytes = message_bytes.ljust(MESSAGE_SIZE_BYTES, b"\0")
    # convert from bytes -> int
    message_int = int.from_bytes(message_bytes, byteorder="big", signed=False)
    # encrypt using RSA
    encrypted_int = encode_with_public_key(public_key, message_int)
    # convert from int -> str -> bytes -> base64 str
    encrypted_int_bytes = str(encrypted_int).encode()
    return base64.b64encode(encrypted_int_bytes).decode()
def decode_message_chunk(private_key: PrivateKey, encrypted_message_str: str) -> str:
    """
    Decrypts one base64 ciphertext chunk with the private key and returns the
    plaintext with its padding removed.

    Only trailing NUL bytes are treated as padding. NUL characters elsewhere
    in the chunk are kept, where cutting at the first NUL would drop the rest
    of the chunk. NULs that genuinely end the original chunk are still
    indistinguishable from padding and are removed.
    """
    # convert from base64 str -> bytes -> str -> int
    encrypted_int_bytes = base64.b64decode(encrypted_message_str.encode())
    encrypted_int = int(encrypted_int_bytes.decode())
    # decrypt using RSA
    message_int = decode_with_private_key(private_key, encrypted_int)
    # convert from int -> fixed-length bytes, strip padding, then bytes -> str
    message_bytes = message_int.to_bytes(
        length=MESSAGE_SIZE_BYTES, byteorder="big", signed=False
    )
    return message_bytes.rstrip(b"\0").decode()
def command_keygen(args: argparse.Namespace) -> None:
    """
    Generates a public and private key pair, and writes them to the provided
    files.

    Primes are drawn from [2^bits_min, 2^bits_max]. Exits with an error
    message unless 1 <= bits_min <= bits_max. Prints a warning on stderr when
    bits_min is low enough that the modulus may be too small for
    MESSAGE_SIZE_BYTES-byte message chunks.
    """
    if not 1 <= args.bits_min <= args.bits_max:
        sys.exit(
            "error: expected 1 <= --bits_min <= --bits_max, got "
            f"bits_min={args.bits_min}, bits_max={args.bits_max}"
        )
    # p, q >= 2^bits_min only guarantees n >= 2^(2 * bits_min), while message
    # chunks need n >= 2^(8 * MESSAGE_SIZE_BYTES).
    if 2 * args.bits_min < 8 * MESSAGE_SIZE_BYTES:
        print(
            f"warning: --bits_min {args.bits_min} may produce a modulus too "
            f"small to encrypt {MESSAGE_SIZE_BYTES}-byte message chunks",
            file=sys.stderr,
        )
    (public_key, private_key) = rsa_make_keys(
        prime_min=(1 << args.bits_min), prime_max=(1 << args.bits_max)
    )
    PublicKey.write_to_file(public_key, args.public_key_file)
    PrivateKey.write_to_file(private_key, args.private_key_file)
def command_encrypt(args: argparse.Namespace) -> None:
    """
    Encrypts a message using the provided public key.

    The whole of stdin is read as the message. The key is loaded from
    ``args.public_key_file``, and the encrypted chunks are printed to stdout
    as a JSON object of the form {"chunks": [...]}.
    """
    plaintext = sys.stdin.read()
    key = PublicKey.read_from_file(args.public_key_file)
    payload = {"chunks": encode_message(key, plaintext)}
    print(json.dumps(payload, indent=4))
def command_decrypt(args: argparse.Namespace) -> None:
    """
    Decrypts a message using the provided private key.

    Reads a JSON object with a "chunks" list from stdin, loads the key from
    ``args.private_key_file``, and writes the recovered plaintext to stdout.
    """
    document = json.load(sys.stdin)
    key = PrivateKey.read_from_file(args.private_key_file)
    plaintext = decode_message(key, document["chunks"])
    # end="" so the output is byte-for-byte the original message
    print(plaintext, end="")
def main() -> None:
    """
    Command-line entry point. Parses ``sys.argv`` and runs the selected
    subcommand (keygen, encrypt or decrypt).

    Running the script without a subcommand prints usage and exits with
    status 2, instead of crashing with AttributeError on ``args.func``.
    """
    parser = argparse.ArgumentParser(description="RSA utility")
    subparsers = parser.add_subparsers()

    # Subcommand to generate keys
    parser_keygen = subparsers.add_parser("keygen", description="Create RSA keys")
    parser_keygen.add_argument(
        "public_key_file", type=str, help="Output file for public key"
    )
    parser_keygen.add_argument(
        "private_key_file", type=str, help="Output file for private key"
    )
    parser_keygen.add_argument(
        "--bits_min",
        type=int,
        help="Minimum number of bits for primes used when generating keys",
        default=512,
    )
    parser_keygen.add_argument(
        "--bits_max",
        type=int,
        help="Maximum number of bits for primes used when generating keys",
        default=550,
    )
    parser_keygen.set_defaults(func=command_keygen)

    # Subcommand to encrypt a message
    parser_encrypt = subparsers.add_parser(
        "encrypt", description="Encrypt a message. Reads input from stdin."
    )
    parser_encrypt.add_argument(
        "public_key_file", type=str, help="Public key to use to encrypt"
    )
    parser_encrypt.set_defaults(func=command_encrypt)

    # Subcommand to decrypt a message
    parser_decrypt = subparsers.add_parser(
        "decrypt", description="Decrypt a message. Reads input from stdin"
    )
    parser_decrypt.add_argument(
        "private_key_file", type=str, help="Private key to use to decrypt"
    )
    parser_decrypt.set_defaults(func=command_decrypt)

    # Run the commands
    args = parser.parse_args()
    # Subcommands are optional by default, so no subcommand leaves `func` unset.
    if getattr(args, "func", None) is None:
        parser.error("a subcommand is required (keygen, encrypt or decrypt)")
    args.func(args)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment