"""RSA encryption and decryption (RSAES-OAEP with SHA-256 and MGF1-SHA-256, RFC 8017)."""
import base64
import hashlib
import os
import secrets
import sys
from typing import Tuple
from Crypto.Util.asn1 import DerBitString, DerOctetString, DerSequence
def powmod(base: int, exponent: int, modulus: int) -> int:
    """Return ``base ** exponent mod modulus`` for a non-negative exponent.

    Delegates to the built-in three-argument ``pow``, which performs modular
    exponentiation in C and is considerably faster for RSA-sized operands
    than a Python-level square-and-multiply loop.

    Args:
        base: the integer to raise (reduced modulo ``modulus``).
        exponent: the non-negative exponent.
        modulus: the modulus; a modulus of 1 yields 0.

    Raises:
        ValueError: if ``exponent`` is negative, or if ``modulus`` is 0.
    """
    if exponent < 0:
        # A negative exponent has no meaning for RSAEP/RSADP; reject it
        # rather than returning a bogus result.
        raise ValueError("exponent must be non-negative")
    return pow(base, exponent, modulus)
def count_octet(n: int) -> int:
    """Return how many octets are needed to write *n* in big-endian form.

    Used to obtain k, the length in octets of the RSA modulus.
    Non-positive inputs yield 0.
    """
    if n <= 0:
        return 0
    # Round the bit length up to a whole number of 8-bit octets.
    return (n.bit_length() + 7) // 8
# ref. RFC 8017 (PKCS #1 v2.2), Appendix B.2.1 - MGF1
def mgf1(seed: bytes, length: int, hash_func) -> bytes:
    """MGF1 mask generation function (RFC 8017, Appendix B.2.1).

    Args:
        seed: seed Z from which the mask is generated.
        length: intended mask length l, in octets.
        hash_func: hashlib-style constructor (e.g. ``hashlib.sha256``) that
            accepts initial data and returns an object providing
            ``digest()`` and ``digest_size``.

    Returns:
        The leading ``length`` octets of
        Hash(Z || I2OSP(0, 4)) || Hash(Z || I2OSP(1, 4)) || ...

    Raises:
        ValueError: "mask too long" if ``length > 2**32 * hLen``.
    """
    digest_len = hash_func().digest_size
    # Step 1: l must not exceed 2^32 * hLen.
    if length > (digest_len << 32):
        raise ValueError("mask too long")
    # Step 3: ceil(l / hLen) hash blocks, using integer-only ceiling division.
    block_count = -(-length // digest_len)
    mask = b"".join(
        hash_func(seed + index.to_bytes(4, "big")).digest()
        for index in range(block_count)
    )
    # Step 4: keep only the leading l octets.
    return mask[:length]
# ref. RFC 8017 (PKCS #1 v2.2), Section 7.1.1 - RSAES-OAEP-ENCRYPT, EME-OAEP encoding
def oaep_encode_sha256(message: bytes, k: int, label: bytes = b"") -> bytes:
    """EME-OAEP encode *message* using SHA-256 and MGF1-SHA-256.

    Implements the length check and EME-OAEP encoding steps of
    RSAES-OAEP-ENCRYPT (RFC 8017, Section 7.1.1, steps 1.b and 2).

    Args:
        message: plaintext octets M.
        k: length in octets of the RSA modulus n.
        label: optional label L. The default empty label gives
            lHash = SHA-256(b"") = e3b0c442...7852b855.

    Returns:
        The encoded message EM = 0x00 || maskedSeed || maskedDB, exactly
        ``k`` octets long.

    Raises:
        ValueError: "message too long" if ``len(message) > k - 2*hLen - 2``.
            This also rejects moduli too small for OAEP with SHA-256.
    """
    hash_func = hashlib.sha256
    hLen = hash_func().digest_size
    mLen = len(message)
    # 1.b. Length checking. Without it PS would silently become empty, DB would
    # be longer than dbMask, and zip() below would truncate the message.
    if mLen > k - 2 * hLen - 2:
        raise ValueError("message too long")
    # 2.a. lHash = Hash(L).
    lHash = hash_func(label).digest()
    # 2.b. PS: k - mLen - 2hLen - 2 zero octets (non-negative after the check).
    PS = bytes(k - mLen - 2 * hLen - 2)
    # 2.c. DB = lHash || PS || 0x01 || M, of length k - hLen - 1 octets.
    DB = lHash + PS + b"\x01" + message
    # 2.d. Random seed of length hLen.
    seed = secrets.token_bytes(hLen)
    # 2.e-f. maskedDB = DB xor MGF(seed, k - hLen - 1).
    dbMask = mgf1(seed, k - hLen - 1, hash_func)
    maskedDB = bytes(x ^ y for x, y in zip(DB, dbMask))
    # 2.g-h. maskedSeed = seed xor MGF(maskedDB, hLen).
    seedMask = mgf1(maskedDB, hLen, hash_func)
    maskedSeed = bytes(x ^ y for x, y in zip(seed, seedMask))
    # 2.i. EM = 0x00 || maskedSeed || maskedDB.
    return b"\x00" + maskedSeed + maskedDB
# ref. RFC 8017 (PKCS #1 v2.2), Section 7.1.2 - RSAES-OAEP-DECRYPT, EME-OAEP decoding
def oaep_decode_sha256(EM: bytes, k: int, label: bytes = b"") -> bytes:
    """EME-OAEP decode *EM* (SHA-256, MGF1-SHA-256) and return the message M.

    Implements the length checks and EME-OAEP decoding steps of
    RSAES-OAEP-DECRYPT (RFC 8017, Section 7.1.2, steps 1.c and 3).

    Args:
        EM: encoded message, expected to be exactly ``k`` octets.
        k: length in octets of the RSA modulus n.
        label: optional label L; must match the label used when encoding.

    Returns:
        The recovered message M.

    Raises:
        ValueError: "decryption error" for any malformed encoding: wrong
            length, k too small, non-zero leading octet Y, lHash mismatch,
            or a missing 0x01 separator. A single message is used so that
            callers cannot distinguish failure causes.
    """
    hash_func = hashlib.sha256
    hLen = hash_func().digest_size
    # 1.c. k must leave room for 0x00 || maskedSeed || lHash' || 0x01,
    # and EM must be exactly k octets.
    if k < 2 * hLen + 2 or len(EM) != k:
        raise ValueError("decryption error")
    # 3.a. lHash = Hash(L).
    lHash = hash_func(label).digest()
    # 3.b. EM = Y || maskedSeed || maskedDB.
    Y = EM[0]
    maskedSeed = EM[1:1 + hLen]
    maskedDB = EM[1 + hLen:]
    # 3.c-d. seed = maskedSeed xor MGF(maskedDB, hLen).
    seedMask = mgf1(maskedDB, hLen, hash_func)
    seed = bytes(x ^ y for x, y in zip(maskedSeed, seedMask))
    # 3.e-f. DB = maskedDB xor MGF(seed, k - hLen - 1).
    dbMask = mgf1(seed, k - hLen - 1, hash_func)
    DB = bytes(x ^ y for x, y in zip(maskedDB, dbMask))
    # 3.g. DB = lHash' || PS || 0x01 || M, where PS is zero or more 0x00 octets.
    lHash2 = DB[:hLen]
    after_ps = DB[hLen:].lstrip(b"\x00")
    separator_ok = after_ps[:1] == b"\x01"
    M = after_ps[1:]
    # Evaluate every condition before reporting, and compare lHash in constant
    # time. This reduces (Python cannot fully eliminate) timing differences
    # between failure causes that padding-oracle attacks could exploit.
    lhash_ok = secrets.compare_digest(lHash2, lHash)
    if not (lhash_ok & separator_ok & (Y == 0x00)):
        raise ValueError("decryption error")
    return M
# ref. RFC 8017 (PKCS #1 v2.2), Section 4.2 - OS2IP
def OS2IP(X: bytes) -> int:
    """Convert an octet string to a non-negative integer (RFC 8017, Section 4.2).

    The first octet of *X* is the most significant. An empty string yields 0.
    """
    # int.from_bytes does the big-endian accumulation in C, in linear time.
    return int.from_bytes(X, "big")
# ref. RFC 8017 (PKCS #1 v2.2), Section 4.1 - I2OSP
def I2OSP(x: int, xLen: int) -> bytes:
    """Convert a non-negative integer to an octet string (RFC 8017, Section 4.1).

    Args:
        x: integer to convert; must satisfy ``0 <= x < 256**xLen``.
        xLen: intended length of the result in octets.

    Returns:
        The big-endian encoding of *x*, left-padded with zero octets to
        exactly ``xLen`` octets.

    Raises:
        ValueError: "integer too large" if ``x >= 256**xLen``, or
            "x should be positive integer" if *x* is negative.
    """
    if x >= 256 ** xLen:
        raise ValueError("integer too large")
    if x < 0:
        raise ValueError("x should be positive integer")
    # to_bytes performs both the big-endian conversion and the zero padding.
    return x.to_bytes(xLen, "big")
# ref. RFC 8017 (PKCS #1 v2.2), Section 5.1.1 - RSAEP
def RSAEP(n: int, e: int, m: int) -> int:
    """RSA encryption primitive (RFC 8017, Section 5.1.1).

    Args:
        n: RSA modulus.
        e: public exponent.
        m: message representative, required to lie in ``[0, n)``.

    Returns:
        The ciphertext representative ``m**e mod n``.

    Raises:
        Exception: "message representative out of range" when m is outside
            ``[0, n)``.
    """
    in_range = 0 <= m < n
    if not in_range:
        raise Exception("message representative out of range")
    return powmod(m, e, n)
# ref. RFC 8017 (PKCS #1 v2.2), Section 5.1.2 - RSADP
def RSADP(n: int, d: int, c: int) -> int:
    """RSA decryption primitive (RFC 8017, Section 5.1.2), private key in (n, d) form.

    Args:
        n: RSA modulus.
        d: private exponent.
        c: ciphertext representative, required to lie in ``[0, n)``.

    Returns:
        The message representative ``c**d mod n``.

    Raises:
        Exception: "ciphertext representative out of range" when c is
            outside ``[0, n)``.
    """
    in_range = 0 <= c < n
    if not in_range:
        raise Exception("ciphertext representative out of range")
    return powmod(c, d, n)
def rsa_encrypt(EM: bytes, n: int, e: int, k: int) -> bytes:
    """Apply the RSA encryption step of RSAES-OAEP-ENCRYPT (RFC 8017, 7.1.1 step 3).

    Converts the encoded message EM to an integer, applies RSAEP with the
    public key (n, e), and returns the ciphertext as exactly ``k`` octets.
    """
    return I2OSP(RSAEP(n, e, OS2IP(EM)), k)
def rsa_decrypt(C: bytes, n: int, d: int, k: int) -> bytes:
    """Apply the RSA decryption step of RSAES-OAEP-DECRYPT (RFC 8017, 7.1.2).

    Args:
        C: ciphertext; must be exactly ``k`` octets.
        n: RSA modulus.
        d: private exponent.
        k: length in octets of the modulus n.

    Returns:
        The encoded message EM as exactly ``k`` octets.

    Raises:
        ValueError: "decryption error" if ``len(C) != k`` (step 1.b).
        Exception: propagated from RSADP when the ciphertext representative
            is not below n.
    """
    # 1.b. A ciphertext of the wrong length is rejected before any RSA math.
    if len(C) != k:
        raise ValueError("decryption error")
    # 2.a-c. c = OS2IP(C); m = RSADP((n, d), c); EM = I2OSP(m, k).
    m = RSADP(n, d, OS2IP(C))
    return I2OSP(m, k)
def read_rsa_pub_key(filename: str) -> Tuple[int, int]:
    """Read an RSA public key file and return ``(n, e)``.

    The file must hold one PEM-armoured SubjectPublicKeyInfo structure
    (``-----BEGIN PUBLIC KEY-----`` ... ``-----END PUBLIC KEY-----``) whose
    algorithm is rsaEncryption with NULL parameters. Blank lines and
    surrounding whitespace on each line are ignored.

    Raises:
        ValueError: "illegal format of public key" for malformed input
            (bad armour, invalid Base64, invalid DER, or wrong algorithm).
        OSError: if the file cannot be opened or read.
    """
    error = "illegal format of public key"
    # DER encodings of OID 1.2.840.113549.1.1.1 (rsaEncryption) and of NULL.
    rsa_encryption_oid = bytes([0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01])
    der_null = bytes([0x05, 0x00])

    def decode_der(der_obj, data):
        # Decode one DER element. Non-bytes input (e.g. an INTEGER where a
        # nested structure was expected) or a parse failure is a format error.
        if not isinstance(data, bytes):
            raise ValueError(error)
        try:
            return der_obj.decode(data)
        except ValueError as err:
            raise ValueError(error) from err

    with open(filename, "r") as f:
        lines = [line.strip() for line in f if line.strip()]
    if (len(lines) < 3
            or lines[0] != "-----BEGIN PUBLIC KEY-----"
            or lines[-1] != "-----END PUBLIC KEY-----"):
        raise ValueError(error)
    try:
        # binascii.Error (invalid Base64) is a subclass of ValueError.
        spki_data = base64.b64decode("".join(lines[1:-1]), validate=True)
    except ValueError as err:
        raise ValueError(error) from err

    # SubjectPublicKeyInfo ::= SEQUENCE {
    #     algorithm         AlgorithmIdentifier,
    #     subjectPublicKey  BIT STRING }
    spki_der = decode_der(DerSequence(), spki_data)
    if len(spki_der) != 2:
        raise ValueError(error)

    # The algorithm must be rsaEncryption with NULL parameters.
    algorithm_der = decode_der(DerSequence(), spki_der[0])
    if (len(algorithm_der) != 2
            or algorithm_der[0] != rsa_encryption_oid
            or algorithm_der[1] != der_null):
        raise ValueError(error)

    # RSAPublicKey ::= SEQUENCE {
    #     modulus         INTEGER,  -- n
    #     publicExponent  INTEGER } -- e
    spk_der = decode_der(DerBitString(), spki_der[1])
    rsa_der = decode_der(DerSequence(), spk_der.value)
    if len(rsa_der) != 2:
        raise ValueError(error)
    n, e = rsa_der[0], rsa_der[1]
    if not (isinstance(n, int) and isinstance(e, int)):
        raise ValueError(error)
    return n, e
def read_rsa_private_key(filename: str) -> Tuple[int, int, int]:
    """Read an RSA private key file and return ``(n, e, d)``.

    The file must hold one PEM-armoured PKCS#8 PrivateKeyInfo structure
    (``-----BEGIN PRIVATE KEY-----`` ... ``-----END PRIVATE KEY-----``) with
    version 0 and algorithm rsaEncryption (NULL parameters). The embedded
    RSAPrivateKey must have version 0 and at least nine fields. Blank lines
    and surrounding whitespace on each line are ignored.

    Raises:
        ValueError: "illegal format of private key" for malformed input
            (bad armour, invalid Base64, invalid DER, wrong algorithm or
            version).
        OSError: if the file cannot be opened or read.
    """
    error = "illegal format of private key"
    # DER encodings of OID 1.2.840.113549.1.1.1 (rsaEncryption) and of NULL.
    rsa_encryption_oid = bytes([0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01])
    der_null = bytes([0x05, 0x00])

    def decode_der(der_obj, data):
        # Decode one DER element. Non-bytes input (e.g. an INTEGER where a
        # nested structure was expected) or a parse failure is a format error.
        if not isinstance(data, bytes):
            raise ValueError(error)
        try:
            return der_obj.decode(data)
        except ValueError as err:
            raise ValueError(error) from err

    with open(filename, "r") as f:
        lines = [line.strip() for line in f if line.strip()]
    if (len(lines) < 3
            or lines[0] != "-----BEGIN PRIVATE KEY-----"
            or lines[-1] != "-----END PRIVATE KEY-----"):
        raise ValueError(error)
    try:
        # binascii.Error (invalid Base64) is a subclass of ValueError.
        pki_data = base64.b64decode("".join(lines[1:-1]), validate=True)
    except ValueError as err:
        raise ValueError(error) from err

    # PrivateKeyInfo ::= SEQUENCE {
    #     version              Version,
    #     privateKeyAlgorithm  PrivateKeyAlgorithmIdentifier,
    #     privateKey           PrivateKey,
    #     attributes           [0] IMPLICIT Attributes OPTIONAL }
    pki_der = decode_der(DerSequence(), pki_data)
    if len(pki_der) < 3 or pki_der[0] != 0:
        raise ValueError(error)

    # The algorithm must be rsaEncryption with NULL parameters.
    algorithm_der = decode_der(DerSequence(), pki_der[1])
    if (len(algorithm_der) != 2
            or algorithm_der[0] != rsa_encryption_oid
            or algorithm_der[1] != der_null):
        raise ValueError(error)

    # privateKey is an OCTET STRING wrapping:
    # RSAPrivateKey ::= SEQUENCE {
    #     version           Version,
    #     modulus           INTEGER,  -- n
    #     publicExponent    INTEGER,  -- e
    #     privateExponent   INTEGER,  -- d
    #     prime1            INTEGER,  -- p
    #     prime2            INTEGER,  -- q
    #     exponent1         INTEGER,  -- d mod (p-1)
    #     exponent2         INTEGER,  -- d mod (q-1)
    #     coefficient       INTEGER,  -- (inverse of q) mod p
    #     otherPrimeInfos   OtherPrimeInfos OPTIONAL }
    private_key_der = decode_der(DerOctetString(), pki_der[2])
    rsa_der = decode_der(DerSequence(), private_key_der.payload)
    # Version 0 here means a two-prime key.
    if len(rsa_der) < 9 or rsa_der[0] != 0:
        raise ValueError(error)
    n, e, d = rsa_der[1], rsa_der[2], rsa_der[3]
    if not all(isinstance(value, int) for value in (n, e, d)):
        raise ValueError(error)
    return n, e, d
def usage():
    """Print command-line usage to standard error.

    The script path is taken from ``sys.argv[0]`` so the printed commands
    can be copied and run as shown.
    """
    prog = sys.argv[0]
    print(f"""
python3 {prog} encrypt <public_key_file> <plain_text_file> <output_encrypted_file>
python3 {prog} decrypt <private_key_file> <input_encrypted_file>
""".strip(), file=sys.stderr)
def main():
    """Command-line entry point for RSAES-OAEP-SHA256 encryption and decryption.

    Commands (taken from ``sys.argv``):
        encrypt <public_key_file> <plain_text_file> <output_encrypted_file>
            Encrypts the raw bytes of plain_text_file and writes the k-octet
            ciphertext to output_encrypted_file.
        decrypt <private_key_file> <input_encrypted_file>
            Decrypts input_encrypted_file and writes the recovered plaintext
            bytes, unchanged, to standard output.

    A missing or unknown command, or too few arguments, prints usage to
    standard error and exits with status 1.
    """
    if len(sys.argv) < 2:
        usage()
        sys.exit(1)
    command = sys.argv[1]
    if command == "encrypt":
        if len(sys.argv) < 5:
            usage()
            sys.exit(1)
        public_key_file, plain_text_file, output_encrypted_file = sys.argv[2:5]
        n, e = read_rsa_pub_key(public_key_file)
        k = count_octet(n)  # length in octets of the RSA modulus n
        with open(plain_text_file, "rb") as f:
            plain_text = f.read()
        EM = oaep_encode_sha256(plain_text, k)
        C = rsa_encrypt(EM, n, e, k)
        with open(output_encrypted_file, "wb") as f:
            f.write(C)
    elif command == "decrypt":
        if len(sys.argv) < 4:
            usage()
            sys.exit(1)
        private_key_file, input_encrypted_file = sys.argv[2:4]
        with open(input_encrypted_file, "rb") as f:
            encrypted = f.read()
        n, e, d = read_rsa_private_key(private_key_file)
        k = count_octet(n)  # length in octets of the RSA modulus n
        EM = rsa_decrypt(encrypted, n, d, k)
        decrypted_text = oaep_decode_sha256(EM, k)
        # The encrypt command reads the plaintext in binary mode, so it may
        # not be valid UTF-8. Emit the bytes unchanged instead of decoding.
        sys.stdout.buffer.write(decrypted_text)
        sys.stdout.buffer.flush()
    else:
        usage()
        sys.exit(1)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()