Last active
October 9, 2023 03:08
-
-
Save tiqwab/3fc0cba339940c6687f2b05f060135fb to your computer and use it in GitHub Desktop.
RSA encryption and decryption
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64
import hashlib
import hmac
import os
import secrets
import sys
from typing import Tuple

from Crypto.Util.asn1 import DerBitString, DerOctetString, DerSequence
def powmod(base: int, exponent: int, modulus: int) -> int:
    """Return ``base ** exponent mod modulus`` by right-to-left square-and-multiply.

    A modulus of 1 yields 0. A non-positive exponent performs no multiplication
    steps, so the result is 1 (for any modulus other than 1).
    """
    if modulus == 1:
        return 0
    result, square = 1, base % modulus
    remaining = exponent
    while remaining > 0:
        # Peel off the lowest bit of the exponent.
        remaining, bit = divmod(remaining, 2)
        if bit:
            result = result * square % modulus
        square = square * square % modulus
    return result
def count_octet(n: int) -> int:
    """Return how many base-256 digits (octets) are needed to write ``n``.

    main() uses this to derive k, the length in octets of the RSA modulus.
    Returns 0 when ``n <= 0``.
    """
    if n <= 0:
        return 0
    # The bit length rounded up to whole bytes equals the number of base-256 digits.
    return (n.bit_length() + 7) // 8
# ref. https://en.wikipedia.org/wiki/Mask_generation_function
def mgf1(seed: bytes, length: int, hash_func) -> bytes:
    """Mask generation function MGF1 (https://www.ietf.org/rfc/rfc2437.txt).

    Concatenates ``hash_func(seed || I2OSP(counter, 4))`` for counter = 0, 1, ...
    until at least ``length`` octets exist, then returns the first ``length``.

    Args:
        seed: seed Z from which the mask is generated.
        length: requested mask length l in octets.
        hash_func: hashlib-style constructor, e.g. ``hashlib.sha256``.

    Raises:
        ValueError: "mask too long" if length > 2^32 * hLen.
    """
    digest_len = hash_func().digest_size
    # Step 1: the 4-octet counter bounds the mask at 2^32 digests.
    if length > (digest_len << 32):
        raise ValueError("mask too long")
    pieces = []
    produced = 0
    counter = 0
    # Step 3: append Hash(Z || C) blocks until the mask is long enough.
    while produced < length:
        piece = hash_func(seed + counter.to_bytes(4, "big")).digest()
        pieces.append(piece)
        produced += len(piece)
        counter += 1
    # Step 4: keep only the leading l octets.
    return b"".join(pieces)[:length]
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-7.1.1
def oaep_encode_sha256(message: bytes, k: int, label: bytes = b"") -> bytes:
    """EME-OAEP encode ``message`` with SHA-256 and MGF1-SHA-256.

    Args:
        message: the message M to encode.
        k: length in octets of the RSA modulus n.
        label: optional label L. The default empty label gives
            lHash = SHA-256("") = e3b0c442...7852b855.

    Returns:
        The encoded message EM = 0x00 || maskedSeed || maskedDB, k octets long.

    Raises:
        ValueError: "message too long" if len(message) > k - 2*hLen - 2.
    """
    hash_func = hashlib.sha256
    hLen = hash_func().digest_size
    mLen = len(message)
    # 1.b. Length check. Without it a long message produces an empty PS and an
    # oversized DB, and the XOR below (zip) silently truncates the message.
    if mLen > k - 2 * hLen - 2:
        raise ValueError("message too long")
    # 2.a. lHash = Hash(L).
    lHash = hash_func(label).digest()
    # 2.b. PS is k - mLen - 2hLen - 2 zero octets (possibly empty).
    PS = b"\x00" * (k - mLen - 2 * hLen - 2)
    # 2.c. DB = lHash || PS || 0x01 || M, which is k - hLen - 1 octets.
    DB = lHash + PS + b"\x01" + message
    # 2.d. Random seed of length hLen.
    seed = secrets.token_bytes(hLen)
    # 2.e-f. maskedDB = DB xor MGF(seed, k - hLen - 1).
    dbMask = mgf1(seed, k - hLen - 1, hash_func)
    maskedDB = bytes(x ^ y for x, y in zip(DB, dbMask))
    # 2.g-h. maskedSeed = seed xor MGF(maskedDB, hLen).
    seedMask = mgf1(maskedDB, hLen, hash_func)
    maskedSeed = bytes(x ^ y for x, y in zip(seed, seedMask))
    # 2.i. EM = 0x00 || maskedSeed || maskedDB.
    return b"\x00" + maskedSeed + maskedDB
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-7.1.2
def oaep_decode_sha256(EM: bytes, k: int, label: bytes = b"") -> bytes:
    """EME-OAEP decode ``EM`` (SHA-256, MGF1-SHA-256) and return the message M.

    Args:
        EM: encoded message, expected to be exactly k octets.
        k: length in octets of the RSA modulus n.
        label: label L that was used at encoding time (default empty).

    Returns:
        The recovered message M.

    Raises:
        ValueError: "decryption error" for any malformed input. One message is
            used for every failure so the cause is not revealed.
    """
    hash_func = hashlib.sha256
    hLen = hash_func().digest_size
    # 1.b / 1.c: EM must be k octets and k must be large enough for the OAEP layout.
    if k < 2 * hLen + 2 or len(EM) != k:
        raise ValueError("decryption error")
    # 3.a. lHash = Hash(L).
    lHash = hash_func(label).digest()
    # 3.b. EM = Y || maskedSeed || maskedDB.
    Y = EM[0]
    maskedSeed = EM[1:1 + hLen]
    maskedDB = EM[1 + hLen:]
    # 3.c-d. seed = maskedSeed xor MGF(maskedDB, hLen).
    seedMask = mgf1(maskedDB, hLen, hash_func)
    seed = bytes(x ^ y for x, y in zip(maskedSeed, seedMask))
    # 3.e-f. DB = maskedDB xor MGF(seed, k - hLen - 1).
    dbMask = mgf1(seed, k - hLen - 1, hash_func)
    DB = bytes(x ^ y for x, y in zip(maskedDB, dbMask))
    # 3.g. DB = lHash' || PS || 0x01 || M. Skip the zero padding PS; what
    # remains must begin with the 0x01 separator.
    lHash2 = DB[:hLen]
    rest = DB[hLen:].lstrip(b"\x00")
    # Evaluate every check before raising. RFC 8017 notes that an attacker must
    # not be able to tell the failure conditions apart. This reduces, but in
    # pure Python cannot fully remove, timing differences.
    bad = Y != 0
    bad |= not hmac.compare_digest(lHash2, lHash)
    bad |= not rest or rest[0] != 0x01
    if bad:
        raise ValueError("decryption error")
    return rest[1:]
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-4.2
def OS2IP(X: bytes) -> int:
    """Convert an octet string to a nonnegative integer (big-endian).

    The first octet is the most significant one; an empty string maps to 0.
    """
    return int.from_bytes(X, "big")
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-4.1
def I2OSP(x: int, xLen: int) -> bytes:
    """Convert a nonnegative integer to a big-endian octet string of length ``xLen``.

    Shorter values are left-padded with 0x00 octets.

    Raises:
        ValueError: "x should be non-negative integer" if x < 0, or
            "integer too large" if x >= 256^xLen.
    """
    if x < 0:
        raise ValueError("x should be non-negative integer")
    if x >= 256 ** xLen:
        raise ValueError("integer too large")
    return x.to_bytes(xLen, "big")
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-5.1.1
def RSAEP(n: int, e: int, m: int) -> int:
    """RSA encryption primitive: return ``m^e mod n`` for a representative 0 <= m < n.

    Raises:
        Exception: "message representative out of range" if m is outside [0, n).
    """
    in_range = 0 <= m < n
    if not in_range:
        raise Exception("message representative out of range")
    return powmod(m, e, n)
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-5.1.2
def RSADP(n: int, d: int, c: int) -> int:
    """RSA decryption primitive: return ``c^d mod n`` for a representative 0 <= c < n.

    Raises:
        Exception: "ciphertext representative out of range" if c is outside [0, n).
    """
    in_range = 0 <= c < n
    if not in_range:
        raise Exception("ciphertext representative out of range")
    return powmod(c, d, n)
def rsa_encrypt(EM: bytes, n: int, e: int, k: int) -> bytes:
    """Encrypt an encoded message with the public key (n, e).

    Computes I2OSP(RSAEP(n, e, OS2IP(EM)), k) and returns the k-octet ciphertext.
    """
    return I2OSP(RSAEP(n, e, OS2IP(EM)), k)
def rsa_decrypt(C: bytes, n: int, d: int, k: int) -> bytes:
    """Decrypt a k-octet ciphertext with the private exponent d.

    Returns the encoded message EM = I2OSP(RSADP(n, d, OS2IP(C)), k).

    Raises:
        ValueError: "decryption error" if C is not exactly k octets long
            (RFC 8017 section 7.1.2, step 1.b).
    """
    if len(C) != k:
        raise ValueError("decryption error")
    c = OS2IP(C)
    m = RSADP(n, d, c)
    return I2OSP(m, k)
def read_rsa_pub_key(filename: str) -> Tuple[int, int]:
    """Read an RSA public key file and return ``(n, e)``.

    The file must hold a SubjectPublicKeyInfo (SPKI) structure in PEM armor
    ("-----BEGIN PUBLIC KEY-----" ... "-----END PUBLIC KEY-----") whose
    algorithm is rsaEncryption with NULL parameters.

    Raises:
        Exception: "illegal format of public key" if the armor or a DER layer
            does not have the expected shape.
    """
    with open(filename, 'r') as f:
        # Drop blank lines so a trailing newline after the END line is accepted.
        lines = [line.strip() for line in f if line.strip()]
    prefix = "-----BEGIN PUBLIC KEY-----"
    suffix = "-----END PUBLIC KEY-----"
    if len(lines) < 3 or lines[0] != prefix or lines[-1] != suffix:
        raise Exception("illegal format of public key")
    b64_text = "".join(lines[1:-1])

    # SubjectPublicKeyInfo ::= SEQUENCE {
    #     algorithm         AlgorithmIdentifier,
    #     subjectPublicKey  BIT STRING }
    spki_der = DerSequence()
    spki_der.decode(base64.b64decode(b64_text))
    if len(spki_der) != 2:
        raise Exception("illegal format of public key")

    # algorithm must be rsaEncryption (OID 1.2.840.113549.1.1.1) with NULL parameters
    algorithm_der = DerSequence()
    algorithm_der.decode(spki_der[0])
    if (len(algorithm_der) != 2
            or algorithm_der[0] != bytes.fromhex("06092a864886f70d010101")
            or algorithm_der[1] != b"\x05\x00"):
        raise Exception("illegal format of public key")

    # subjectPublicKey wraps the DER-encoded
    #   RSAPublicKey ::= SEQUENCE {
    #       modulus         INTEGER,  -- n
    #       publicExponent  INTEGER   -- e
    #   }
    spk_der = DerBitString()
    spk_der.decode(spki_der[1])
    rsa_der = DerSequence()
    rsa_der.decode(spk_der.value)
    if len(rsa_der) != 2:
        raise Exception("illegal format of public key")
    return rsa_der[0], rsa_der[1]
def read_rsa_private_key(filename: str) -> Tuple[int, int, int]:
    """Read an RSA private key file and return ``(n, e, d)``.

    The file must hold a PKCS#8 PrivateKeyInfo structure in PEM armor
    ("-----BEGIN PRIVATE KEY-----" ... "-----END PRIVATE KEY-----") with
    version 0, an rsaEncryption algorithm with NULL parameters, and an inner
    RSAPrivateKey of version 0.

    Raises:
        Exception: "illegal format of private key" if the armor or a DER layer
            does not have the expected shape.
    """
    with open(filename, 'r') as f:
        # Drop blank lines so a trailing newline after the END line is accepted.
        lines = [line.strip() for line in f if line.strip()]
    prefix = "-----BEGIN PRIVATE KEY-----"
    suffix = "-----END PRIVATE KEY-----"
    if len(lines) < 3 or lines[0] != prefix or lines[-1] != suffix:
        raise Exception("illegal format of private key")
    b64_text = "".join(lines[1:-1])

    # PrivateKeyInfo ::= SEQUENCE {
    #     version                 Version,
    #     privateKeyAlgorithm     PrivateKeyAlgorithmIdentifier,
    #     privateKey              PrivateKey,
    #     attributes          [0] IMPLICIT Attributes OPTIONAL }
    pki_der = DerSequence()
    pki_der.decode(base64.b64decode(b64_text))
    # Need at least version/algorithm/privateKey; version must be 0.
    if len(pki_der) < 3 or pki_der[0] != 0:
        raise Exception("illegal format of private key")

    # algorithm must be rsaEncryption (OID 1.2.840.113549.1.1.1) with NULL parameters
    algorithm_der = DerSequence()
    algorithm_der.decode(pki_der[1])
    if (len(algorithm_der) != 2
            or algorithm_der[0] != bytes.fromhex("06092a864886f70d010101")
            or algorithm_der[1] != b"\x05\x00"):
        raise Exception("illegal format of private key")

    # privateKey is an OCTET STRING holding the DER-encoded
    #   RSAPrivateKey ::= SEQUENCE {
    #       version           Version,
    #       modulus           INTEGER,  -- n
    #       publicExponent    INTEGER,  -- e
    #       privateExponent   INTEGER,  -- d
    #       prime1            INTEGER,  -- p
    #       prime2            INTEGER,  -- q
    #       exponent1         INTEGER,  -- d mod (p-1)
    #       exponent2         INTEGER,  -- d mod (q-1)
    #       coefficient       INTEGER,  -- (inverse of q) mod p
    #       otherPrimeInfos   OtherPrimeInfos OPTIONAL
    #   }
    private_key_der = DerOctetString()
    private_key_der.decode(pki_der[2])
    rsa_der = DerSequence()
    rsa_der.decode(private_key_der.payload)
    # version must be 0 here
    if len(rsa_der) < 9 or rsa_der[0] != 0:
        raise Exception("illegal format of private key")
    return rsa_der[1], rsa_der[2], rsa_der[3]
def usage():
    """Print the command-line synopsis for both subcommands to stderr."""
    help_lines = [
        "usage:",
        "encrypt:",
        "python3 rsa.py encrypt <public_key_file> <plain_text_file> <output_encrypted_file>",
        "decrypt:",
        "python3 rsa.py decrypt <private_key_file> <input_encrypted_file>",
    ]
    print("\n".join(help_lines), file=sys.stderr)
def main():
    """Command-line entry point for RSAES-OAEP (SHA-256) encryption/decryption.

    encrypt <public_key_file> <plain_text_file> <output_encrypted_file>
        Reads the plaintext as raw bytes and writes the k-octet ciphertext.
    decrypt <private_key_file> <input_encrypted_file>
        Writes the recovered plaintext bytes to standard output.

    Prints usage and exits with status 1 when the command is missing or
    unknown, or when too few arguments are given.
    """
    if len(sys.argv) < 2:
        usage()
        sys.exit(1)
    command = sys.argv[1]
    if command == "encrypt":
        if len(sys.argv) < 5:
            usage()
            sys.exit(1)
        public_key_file, plain_text_file, output_encrypted_file = sys.argv[2:5]
        n, e = read_rsa_pub_key(public_key_file)
        k = count_octet(n)  # length in octets of the RSA modulus n
        with open(plain_text_file, 'rb') as f:
            plain_text = f.read()
        EM = oaep_encode_sha256(plain_text, k)
        C = rsa_encrypt(EM, n, e, k)
        with open(output_encrypted_file, 'wb') as f:
            f.write(C)
    elif command == "decrypt":
        if len(sys.argv) < 4:
            usage()
            sys.exit(1)
        private_key_file, input_encrypted_file = sys.argv[2:4]
        with open(input_encrypted_file, 'rb') as f:
            encrypted = f.read()
        n, e, d = read_rsa_private_key(private_key_file)
        k = count_octet(n)  # length in octets of the RSA modulus n
        EM = rsa_decrypt(encrypted, n, d, k)
        decrypted_text = oaep_decode_sha256(EM, k)
        # encrypt accepts arbitrary bytes, so emit raw bytes rather than forcing
        # a UTF-8 decode, which would fail for binary plaintexts.
        out = getattr(sys.stdout, "buffer", None)
        if out is not None:
            out.write(decrypted_text)
            out.flush()
        else:
            sys.stdout.write(decrypted_text.decode('utf8'))
    else:
        usage()
        sys.exit(1)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment