Created
May 11, 2020 16:24
-
-
Save kennyyu/73775f1cbdeacdd0dab9194093024098 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import base64 | |
import json | |
import random | |
import sys | |
from typing import Dict, List, NamedTuple, Optional, Tuple | |
def exp(n: int, e: int, mod: Optional[int] = None) -> int:
    """
    Returns n^e (reduced modulo `mod` if specified), computed by
    square-and-multiply over the bits of e.

    A falsy `mod` (None or 0) means no modular reduction is applied.

    Raises ValueError if e is negative.
    """
    if e < 0:
        # Shifting a negative int right never reaches 0, so the loop below
        # would spin forever.
        raise ValueError("exponent must be non-negative")
    result = 1
    current_pow = n % mod if mod else n
    while e != 0:
        if e & 1:
            result *= current_pow
            if mod:
                result %= mod
        e >>= 1
        # Only square when another bit remains; the final squaring is never
        # used and is expensive for big unreduced values.
        if e != 0:
            current_pow *= current_pow
            if mod:
                current_pow %= mod
    # Reduce once more so that e == 0 also respects the modulus
    # (for example anything mod 1 is 0).
    return result % mod if mod else result
def extended_euclid(a: int, b: int) -> Tuple[int, int, int]:
    """
    Returns (d, x, y) where d = gcd(a, b)
    and ax + by = d

    Implemented iteratively so that large operands cannot exhaust the
    interpreter's recursion limit.
    """
    # Invariants maintained by the loop:
    #   a * prev_x + b * prev_y == prev_r
    #   a * x      + b * y      == r
    prev_r, r = a, b
    prev_x, x = 1, 0
    prev_y, y = 0, 1
    while r != 0:
        # prev_r == r * k + (prev_r % r)
        k = prev_r // r
        prev_r, r = r, prev_r - k * r
        prev_x, x = x, prev_x - k * x
        prev_y, y = y, prev_y - k * y
    return (prev_r, prev_x, prev_y)
def is_probably_prime(n: int, num_iter: int = 50) -> bool:
    """
    Rabin-Miller:
    Returns True if n is probably prime, where
    P(n is not prime) < 1 / (4^num_iter)

    Values below 2 and even values greater than 2 are rejected outright.
    """
    if n < 2:
        return False
    if n < 4:
        # 2 and 3 are prime; they are also too small to draw a witness
        # from the range [2, n - 2].
        return True
    if n % 2 == 0:
        return False

    # Write n - 1 = 2^t * u with u odd.
    t = 0
    u = n - 1
    while u % 2 == 0:
        t += 1
        u >>= 1

    for _ in range(num_iter):
        a = random.randint(2, n - 2)
        x = pow(a, u, n)
        if x == 1 or x == n - 1:
            # inconclusive, try another a
            continue
        # Square up to t - 1 more times looking for -1; if it never shows
        # up, a is a witness that n is composite.
        for _ in range(t - 1):
            x = (x * x) % n
            if x == n - 1:
                break
        else:
            return False
    return True
class PublicKey(NamedTuple):
    """
    RSA public key: the modulus together with the exponent that is
    used to encrypt messages.
    """

    # modulus, n = p * q
    n: int
    # encryption exponent
    e: int

    @staticmethod
    def write_to_file(public_key: "PublicKey", name: str) -> None:
        """Store the key in the file `name` as a JSON object with keys "n" and "e"."""
        data: Dict[str, int] = dict(n=public_key.n, e=public_key.e)
        serialized = json.dumps(data, indent=4)
        with open(name, "w") as f:
            f.write(serialized)

    @staticmethod
    def read_from_file(name: str) -> "PublicKey":
        """Load a key previously stored as JSON in the file `name`."""
        with open(name, "r") as f:
            data = json.loads(f.read())
        modulus = data["n"]
        exponent = data["e"]
        return PublicKey(modulus, exponent)
class PrivateKey(NamedTuple):
    """
    RSA private key: the two primes and the exponent that is used to
    decrypt messages.
    """

    # the two large primes whose product is the modulus
    p: int
    q: int
    # decryption exponent, d = e^(-1) mod (p-1)(q-1)
    d: int

    @staticmethod
    def write_to_file(private_key: "PrivateKey", name: str) -> None:
        """Store the key in the file `name` as a JSON object with keys "p", "q" and "d"."""
        data: Dict[str, int] = dict(
            p=private_key.p,
            q=private_key.q,
            d=private_key.d,
        )
        serialized = json.dumps(data, indent=4)
        with open(name, "w") as f:
            f.write(serialized)

    @staticmethod
    def read_from_file(name: str) -> "PrivateKey":
        """Load a key previously stored as JSON in the file `name`."""
        with open(name, "r") as f:
            data = json.loads(f.read())
        first_prime = data["p"]
        second_prime = data["q"]
        exponent = data["d"]
        return PrivateKey(first_prime, second_prime, exponent)
def rsa_make_keys(
    prime_min: int = 1 << 512, prime_max: int = 1 << 550
) -> Tuple[PublicKey, PrivateKey]:
    """
    Returns ((n, e), (p, q, d)) representing (public key, private key) where:
    - p and q are distinct large primes drawn from [prime_min, prime_max]
    - n = p * q
    - e is the smallest odd number >= 3 where gcd((p - 1)(q - 1), e) == 1
    - d = e^(-1) mod (p - 1)(q - 1)

    The range [prime_min, prime_max] must contain at least two primes,
    otherwise the search for p and q does not terminate.

    Raises ValueError if prime_min > prime_max.
    """
    if prime_min > prime_max:
        raise ValueError("prime_min must not be greater than prime_max")

    # Key material must come from the operating system's CSPRNG; the default
    # Mersenne Twister generator is predictable and unsuitable for secrets.
    secure_random = random.SystemRandom()

    def generate_prime(prime_min: int, prime_max: int) -> int:
        """
        Returns a probable prime in [prime_min, prime_max]
        """
        a = secure_random.randint(prime_min, prime_max)
        while not is_probably_prime(a):
            a = secure_random.randint(prime_min, prime_max)
        return a

    def generate_e_d(p: int, q: int) -> Tuple[int, int]:
        """
        Finds an e such that gcd((p - 1)(q - 1), e) == 1,
        and d such that d = e^(-1) mod (p - 1)(q - 1)
        Returns (e, d).
        """
        phi = (p - 1) * (q - 1)
        e = 3
        while True:
            # d * e + _ * (p - 1)(q - 1) = 1
            # d * e = 1 mod (p - 1)(q - 1)
            # d = e^(-1) mod (p - 1)(q - 1)
            (gcd, d, _) = extended_euclid(e, phi)
            if gcd == 1:
                break
            # phi is even for distinct primes, so an even e can never be
            # coprime to it; only odd candidates need to be tried.
            e += 2
        # d might be negative, return the mod of it
        return (e, d % phi)

    p = generate_prime(prime_min, prime_max)
    q = generate_prime(prime_min, prime_max)
    # With p == q, n would be a square and (p - 1)(q - 1) would not be the
    # right totient, so messages would not decrypt correctly.
    while q == p:
        q = generate_prime(prime_min, prime_max)
    n = p * q
    e, d = generate_e_d(p, q)
    return (PublicKey(n=n, e=e), PrivateKey(p=p, q=q, d=d))
def encode_with_public_key(public_key: PublicKey, message: int) -> int:
    """
    Encrypts the integer `message` as message^e mod n.

    Raises ValueError if message is not in the range [0, n): such a value
    would be reduced modulo n and could not be recovered by decryption.
    """
    if not 0 <= message < public_key.n:
        raise ValueError("message must satisfy 0 <= message < n")
    return exp(message, public_key.e, mod=public_key.n)
def decode_with_private_key(private_key: PrivateKey, message_encrypted: int) -> int:
    """
    Decrypts the integer `message_encrypted` as message_encrypted^d mod (p * q).
    """
    modulus = private_key.p * private_key.q
    return exp(message_encrypted, private_key.d, mod=modulus)
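# Number of bytes of plaintext packed into each RSA-encrypted chunk.
# The modulus n must be larger than any chunk read as a big-endian integer,
# i.e. n >= 2^(8 * MESSAGE_SIZE_BYTES).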
MESSAGE_SIZE_BYTES: int = 128 | |
def _split_by_utf8_size(text: str, max_bytes: int) -> List[str]:
    """
    Splits text into consecutive pieces whose UTF-8 encoding is at most
    max_bytes long, never cutting through a character.
    """
    pieces: List[str] = []
    start = 0
    used = 0
    for pos, char in enumerate(text):
        width = len(char.encode())
        # "used" guards against emitting an empty piece when a single
        # character is wider than max_bytes.
        if used and used + width > max_bytes:
            pieces.append(text[start:pos])
            start = pos
            used = 0
        used += width
    if start < len(text):
        pieces.append(text[start:])
    return pieces


def encode_message(public_key: PublicKey, message_str: str) -> List[str]:
    """
    Encodes a message with the public key. If the message
    is large, this will divide up the message into chunks.

    Chunks are sized by their UTF-8 byte length (not their character
    count), so text with multi-byte characters still fits into
    MESSAGE_SIZE_BYTES per chunk.
    """
    chunks = _split_by_utf8_size(message_str, MESSAGE_SIZE_BYTES)
    return [encode_message_chunk(public_key, chunk) for chunk in chunks]


def decode_message(private_key: PrivateKey, encrypted_chunks: List[str]) -> str:
    """
    Decodes a set of encrypted chunks and returns the final message
    """
    return "".join(
        decode_message_chunk(private_key, chunk) for chunk in encrypted_chunks
    )


def encode_message_chunk(public_key: PublicKey, message_str: str) -> str:
    """
    Encodes a string with the public key. Adds padding if needed.

    Raises ValueError if the UTF-8 encoding of message_str is longer than
    MESSAGE_SIZE_BYTES.
    """
    # convert from str -> bytes
    message_bytes = message_str.encode()
    if len(message_bytes) > MESSAGE_SIZE_BYTES:
        raise ValueError("chunk does not fit into MESSAGE_SIZE_BYTES bytes")
    # add padding; measured in bytes because a character may take several
    message_bytes = message_bytes.ljust(MESSAGE_SIZE_BYTES, b"\0")
    # convert from bytes -> int
    message_int = int.from_bytes(message_bytes, byteorder="big", signed=False)
    # encrypt using RSA
    encrypted_int = encode_with_public_key(public_key, message_int)
    # convert from int -> str -> bytes -> base64 str
    encrypted_int_bytes = str(encrypted_int).encode()
    return base64.b64encode(encrypted_int_bytes).decode()
def decode_message_chunk(private_key: PrivateKey, encrypted_message_str: str) -> str:
    """
    Decodes an encrypted string using the private key. Removes padding if
    it was added.

    Only trailing NUL characters are treated as padding, so a NUL in the
    middle of a chunk is preserved. NULs at the very end of a chunk cannot
    be told apart from padding and are dropped.
    """
    # convert from base64 str -> bytes -> str -> int
    encrypted_int_bytes = base64.b64decode(encrypted_message_str.encode())
    encrypted_int = int(encrypted_int_bytes.decode())
    # decrypt using RSA
    message_int = decode_with_private_key(private_key, encrypted_int)
    # convert from int -> bytes -> str
    message_bytes = message_int.to_bytes(
        length=MESSAGE_SIZE_BYTES, byteorder="big", signed=False
    )
    message_str = message_bytes.decode()
    # remove padding
    return message_str.rstrip("\0")
def command_keygen(args: argparse.Namespace) -> None:
    """
    Creates a fresh key pair whose primes lie between 2^bits_min and
    2^bits_max, then stores the public and the private key in the
    files named on the command line.
    """
    lower_bound = 1 << args.bits_min
    upper_bound = 1 << args.bits_max
    key_pair = rsa_make_keys(prime_min=lower_bound, prime_max=upper_bound)
    public_key, private_key = key_pair
    PublicKey.write_to_file(public_key, args.public_key_file)
    PrivateKey.write_to_file(private_key, args.private_key_file)
def command_encrypt(args: argparse.Namespace) -> None:
    """
    Reads the whole of stdin as the message, encrypts it with the public
    key stored in the given file, and prints the encrypted chunks as a
    JSON object to stdout.
    """
    message = sys.stdin.read()
    public_key = PublicKey.read_from_file(args.public_key_file)
    payload = {"chunks": encode_message(public_key, message)}
    print(json.dumps(payload, indent=4))
def command_decrypt(args: argparse.Namespace) -> None:
    """
    Reads a JSON object with encrypted chunks from stdin, decrypts it with
    the private key stored in the given file, and prints the message.
    """
    encrypted_data = json.loads(sys.stdin.read())
    private_key = PrivateKey.read_from_file(args.private_key_file)
    encrypted_chunks = encrypted_data["chunks"]
    message = decode_message(private_key, encrypted_chunks)
    # No trailing newline, so the output is exactly the decrypted message
    print(message, end="")
def main() -> None:
    """
    Parses the command line and runs the selected subcommand
    (keygen, encrypt or decrypt).

    Exits with a usage error if no subcommand is given.
    """
    parser = argparse.ArgumentParser(description="RSA utility")
    subparsers = parser.add_subparsers()

    # Subcommand to generate keys
    parser_keygen = subparsers.add_parser("keygen", description="Create RSA keys")
    parser_keygen.add_argument(
        "public_key_file", type=str, help="Output file for public key"
    )
    parser_keygen.add_argument(
        "private_key_file", type=str, help="Output file for private key"
    )
    parser_keygen.add_argument(
        "--bits_min",
        type=int,
        help="Minimum number of bits for primes used when generating keys",
        default=512,
    )
    parser_keygen.add_argument(
        "--bits_max",
        type=int,
        help="Maximum number of bits for primes used when generating keys",
        default=550,
    )
    parser_keygen.set_defaults(func=command_keygen)

    # Subcommand to encrypt a message
    parser_encrypt = subparsers.add_parser(
        "encrypt", description="Encrypt a message. Reads input from stdin."
    )
    parser_encrypt.add_argument(
        "public_key_file", type=str, help="Public key to use to encrypt"
    )
    parser_encrypt.set_defaults(func=command_encrypt)

    # Subcommand to decrypt a message
    parser_decrypt = subparsers.add_parser(
        "decrypt", description="Decrypt a message. Reads input from stdin"
    )
    parser_decrypt.add_argument(
        "private_key_file", type=str, help="Private key to use to decrypt"
    )
    parser_decrypt.set_defaults(func=command_decrypt)

    # Run the commands
    args = parser.parse_args()
    # Subcommands are optional for argparse, so without one there is no
    # "func" attribute; report a usage error instead of an AttributeError.
    if not hasattr(args, "func"):
        parser.error("a subcommand is required: keygen, encrypt or decrypt")
    args.func(args)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment