Skip to content

Instantly share code, notes, and snippets.

@tiqwab
Last active October 9, 2023 03:08
Show Gist options
  • Save tiqwab/3fc0cba339940c6687f2b05f060135fb to your computer and use it in GitHub Desktop.
Save tiqwab/3fc0cba339940c6687f2b05f060135fb to your computer and use it in GitHub Desktop.
RSA encryption and decryption
import base64
import hashlib
import os
import secrets
import sys
from typing import Tuple
from Crypto.Util.asn1 import DerBitString, DerOctetString, DerSequence
def powmod(base: int, exponent: int, modulus: int) -> int:
    """Return base**exponent mod modulus via right-to-left binary exponentiation.

    A modulus of 1 always yields 0; a non-positive exponent yields 1 (the loop
    never runs).
    """
    if modulus == 1:
        return 0
    result = 1
    square = base % modulus
    while exponent > 0:
        # Multiply in the current square whenever the lowest exponent bit is set.
        if exponent & 1:
            result = (result * square) % modulus
        square = (square * square) % modulus
        exponent >>= 1
    return result
def count_octet(n: int) -> int:
    """Return how many octets are needed to hold the integer *n*.

    Zero and negative values need 0 octets; e.g. 255 -> 1, 256 -> 2.
    """
    if n <= 0:
        return 0
    # ceil(bits / 8) in integer arithmetic.
    return (n.bit_length() + 7) // 8
# ref. https://en.wikipedia.org/wiki/Mask_generation_function
def mgf1(seed: bytes, length: int, hash_func) -> bytes:
    """Mask generation function MGF1 (RFC 2437 / RFC 8017 B.2.1).

    Args:
        seed: seed octets Z from which the mask is generated.
        length: desired mask length l in octets.
        hash_func: hashlib-style constructor (e.g. hashlib.sha256); called with
            no arguments to learn the digest size and with data to hash.

    Returns:
        The first *length* octets of Hash(Z || C(0)) || Hash(Z || C(1)) || ...,
        where C(i) is the 4-octet big-endian counter.

    Raises:
        ValueError: "mask too long" if length > 2^32 * hLen.
    """
    digest_len = hash_func().digest_size
    # https://www.ietf.org/rfc/rfc2437.txt step 1: reject masks longer than 2^32 * hLen.
    if length > (digest_len << 32):
        raise ValueError("mask too long")
    # ceil(length / digest_len) hash blocks, one per counter value.
    block_count = -(-length // digest_len)
    blocks = [
        hash_func(seed + counter.to_bytes(4, "big")).digest()
        for counter in range(block_count)
    ]
    # Keep only the leading l octets of T.
    return b"".join(blocks)[:length]
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-7.1.1
def oaep_encode_sha256(message: bytes, k: int) -> bytes:
    """EME-OAEP encode *message* (SHA-256 hash, MGF1 with SHA-256, empty label).

    Args:
        message: the plaintext octets M.
        k: length in octets of the RSA modulus n.

    Returns:
        The encoded message EM = 0x00 || maskedSeed || maskedDB, k octets long.

    Raises:
        ValueError: "message too long" if len(message) > k - 2*hLen - 2
            (this also rejects a k too small for OAEP with SHA-256).
    """
    hash_func = hashlib.sha256
    hLen = hash_func().digest_size
    mLen = len(message)
    # 1.b. Length check. Without it PS is empty, DB ends up longer than dbMask,
    # and the zip() below would silently truncate the message.
    if mLen > k - 2 * hLen - 2:
        raise ValueError("message too long")
    # 2.a. lHash = Hash(L) for the empty label L.
    lHash = hash_func(b"").digest()
    # 2.b. PS: k - mLen - 2hLen - 2 zero octets (possibly none).
    PS = bytes(k - mLen - 2 * hLen - 2)
    # 2.c. DB = lHash || PS || 0x01 || M, which is k - hLen - 1 octets.
    DB = lHash + PS + b"\x01" + message
    # 2.d. Random seed of hLen octets.
    seed = secrets.token_bytes(hLen)
    # 2.e-f. maskedDB = DB xor MGF(seed, k - hLen - 1).
    dbMask = mgf1(seed, k - hLen - 1, hash_func)
    maskedDB = bytes(x ^ y for x, y in zip(DB, dbMask))
    # 2.g-h. maskedSeed = seed xor MGF(maskedDB, hLen).
    seedMask = mgf1(maskedDB, hLen, hash_func)
    maskedSeed = bytes(x ^ y for x, y in zip(seed, seedMask))
    # 2.i. EM = 0x00 || maskedSeed || maskedDB.
    return b"\x00" + maskedSeed + maskedDB
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-7.1.2
def oaep_decode_sha256(EM: bytes, k: int) -> bytes:
    """Decode an EME-OAEP encoded message (SHA-256, MGF1 with SHA-256, empty label).

    Args:
        EM: the encoded message; must be exactly k octets.
        k: length in octets of the RSA modulus n.

    Returns:
        The recovered message M.

    Raises:
        Exception: "decryption error" for every kind of malformed input
            (wrong length, nonzero leading octet, lHash mismatch, missing 0x01).
    """
    hash_func = hashlib.sha256
    hLen = hash_func().digest_size
    # EM must be k octets, and k must leave room for Y, the seed, lHash' and the
    # 0x01 separator; anything else cannot be a valid encoding.
    if len(EM) != k or k < 2 * hLen + 2:
        raise Exception("decryption error")
    # 3.a. lHash = Hash(L) for the empty label L.
    lHash = hash_func(b"").digest()
    # 3.b. EM = Y || maskedSeed || maskedDB.
    Y = EM[0]
    maskedSeed = EM[1:1 + hLen]
    maskedDB = EM[1 + hLen:]
    # 3.c-d. seed = maskedSeed xor MGF(maskedDB, hLen).
    seedMask = mgf1(maskedDB, hLen, hash_func)
    seed = bytes(x ^ y for x, y in zip(maskedSeed, seedMask))
    # 3.e-f. DB = maskedDB xor MGF(seed, k - hLen - 1).
    dbMask = mgf1(seed, k - hLen - 1, hash_func)
    DB = bytes(x ^ y for x, y in zip(maskedDB, dbMask))
    # 3.g. DB = lHash' || PS || 0x01 || M.
    lHash2 = DB[:hLen]
    rest = DB[hLen:]
    # Skip the zero padding PS; the bound stops an all-zero DB from raising IndexError.
    idx = 0
    while idx < len(rest) and rest[idx] == 0x00:
        idx += 1
    has_separator = idx < len(rest) and rest[idx] == 0x01
    # One error path for all checks: RFC 8017 7.1.2 warns that distinguishable
    # failures let an attacker mount an adaptive chosen-ciphertext attack.
    if Y != 0 or lHash2 != lHash or not has_separator:
        raise Exception("decryption error")
    return rest[idx + 1:]
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-4.2
def OS2IP(X: bytes) -> int:
    """Octet-String-to-Integer primitive: interpret *X* as a big-endian unsigned integer.

    An empty string converts to 0.
    """
    # int.from_bytes is the stdlib equivalent of the acc * 256 + octet fold.
    return int.from_bytes(X, "big")
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-4.1
def I2OSP(x: int, xLen: int) -> bytes:
    """Integer-to-Octet-String primitive: big-endian encoding of *x* in *xLen* octets.

    Raises:
        Exception: "integer too large" if x >= 256**xLen, or
            "x should be positive integer" if x < 0.
    """
    if not x < 256 ** xLen:
        raise Exception("integer too large")
    if x < 0:
        raise Exception("x should be positive integer")
    # Collect base-256 digits least-significant first, left-pad with zeros, then flip.
    digits = bytearray()
    while x > 0:
        x, low = divmod(x, 256)
        digits.append(low)
    digits.extend(b"\x00" * (xLen - len(digits)))
    digits.reverse()
    return bytes(digits)
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-5.1.1
def RSAEP(n: int, e: int, m: int) -> int:
    """RSA encryption primitive: return m**e mod n.

    Raises:
        Exception: "message representative out of range" unless 0 <= m < n.
    """
    if not 0 <= m < n:
        raise Exception("message representative out of range")
    return powmod(m, e, n)
# ref. https://datatracker.ietf.org/doc/html/rfc8017#section-5.1.2
def RSADP(n: int, d: int, c: int) -> int:
    """RSA decryption primitive: return c**d mod n.

    Raises:
        Exception: "ciphertext representative out of range" unless 0 <= c < n.
    """
    if not 0 <= c < n:
        raise Exception("ciphertext representative out of range")
    return powmod(c, d, n)
def rsa_encrypt(EM: bytes, n: int, e: int, k: int) -> bytes:
    """Encrypt an already-encoded message *EM* with public key (n, e).

    Converts EM to an integer, applies RSAEP, and returns the ciphertext as a
    k-octet string.
    """
    return I2OSP(RSAEP(n, e, OS2IP(EM)), k)
def rsa_decrypt(C: bytes, n: int, d: int, k: int) -> bytes:
    """Decrypt ciphertext *C* with private key (n, d) into a k-octet encoded message.

    Raises:
        Exception: "decryption error" if C is not exactly k octets
            (RFC 8017 7.1.2 step 1.b); RSADP's range error otherwise propagates.
    """
    # RFC 8017 7.1.2 step 1.b: a ciphertext of the wrong length is rejected outright.
    if len(C) != k:
        raise Exception("decryption error")
    c = OS2IP(C)
    m = RSADP(n, d, c)
    return I2OSP(m, k)
def read_rsa_pub_key(filename: str) -> Tuple[int, int]:
    """Read an RSA public key file (SubjectPublicKeyInfo, PEM-encoded) and return (n, e).

    Args:
        filename: path to a PEM file delimited by "-----BEGIN PUBLIC KEY-----"
            and "-----END PUBLIC KEY-----".

    Returns:
        (modulus n, public exponent e).

    Raises:
        Exception: "illegal format of public key" when the PEM framing or any
            checked ASN.1 field is not as expected. Errors raised while
            base64/DER decoding are not caught here.
    """
    error_message = "illegal format of public key"
    with open(filename, 'r') as f:
        # Drop blank lines so a stray empty line around the markers is tolerated.
        lines = [line.strip() for line in f if line.strip()]
    if len(lines) < 3:
        raise Exception(error_message)
    if lines[0] != "-----BEGIN PUBLIC KEY-----" or lines[-1] != "-----END PUBLIC KEY-----":
        raise Exception(error_message)
    b64_text = "".join(lines[1:-1])

    # SubjectPublicKeyInfo ::= SEQUENCE {
    #     algorithm         AlgorithmIdentifier,
    #     subjectPublicKey  BIT STRING }
    spki_der = DerSequence()
    spki_der.decode(base64.b64decode(b64_text))
    if len(spki_der) != 2:
        raise Exception(error_message)

    # algorithm must be rsaEncryption (OID 1.2.840.113549.1.1.1) with NULL parameters.
    algorithm_der = DerSequence()
    algorithm_der.decode(spki_der[0])
    if len(algorithm_der) != 2:
        raise Exception(error_message)
    if algorithm_der[0] != bytes([0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01]):
        raise Exception(error_message)
    if algorithm_der[1] != bytes([0x05, 0x00]):
        raise Exception(error_message)

    # The BIT STRING payload is the DER encoding of
    # RSAPublicKey ::= SEQUENCE { modulus INTEGER, publicExponent INTEGER }
    spk_der = DerBitString()
    spk_der.decode(spki_der[1])
    rsa_der = DerSequence()
    rsa_der.decode(spk_der.value)
    if len(rsa_der) != 2:
        raise Exception(error_message)
    return rsa_der[0], rsa_der[1]
def read_rsa_private_key(filename: str) -> Tuple[int, int, int]:
    """Read an RSA private key file (PKCS#8 PrivateKeyInfo, PEM-encoded) and return (n, e, d).

    Args:
        filename: path to a PEM file delimited by "-----BEGIN PRIVATE KEY-----"
            and "-----END PRIVATE KEY-----".

    Returns:
        (modulus n, public exponent e, private exponent d).

    Raises:
        Exception: "illegal format of private key" when the PEM framing or any
            checked ASN.1 field is not as expected. Errors raised while
            base64/DER decoding are not caught here.
    """
    error_message = "illegal format of private key"
    with open(filename, 'r') as f:
        # Drop blank lines so a stray empty line around the markers is tolerated.
        lines = [line.strip() for line in f if line.strip()]
    if len(lines) < 3:
        raise Exception(error_message)
    if lines[0] != "-----BEGIN PRIVATE KEY-----" or lines[-1] != "-----END PRIVATE KEY-----":
        raise Exception(error_message)
    b64_text = "".join(lines[1:-1])

    # PrivateKeyInfo ::= SEQUENCE {
    #     version              Version,
    #     privateKeyAlgorithm  PrivateKeyAlgorithmIdentifier,
    #     privateKey           PrivateKey,
    #     attributes           [0] IMPLICIT Attributes OPTIONAL }
    pki_der = DerSequence()
    pki_der.decode(base64.b64decode(b64_text))
    if len(pki_der) < 3:
        raise Exception(error_message)
    # version must be 0
    if pki_der[0] != 0:
        raise Exception(error_message)

    # algorithm must be rsaEncryption (OID 1.2.840.113549.1.1.1) with NULL parameters.
    algorithm_der = DerSequence()
    algorithm_der.decode(pki_der[1])
    if len(algorithm_der) != 2:
        raise Exception(error_message)
    if algorithm_der[0] != bytes([0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01]):
        raise Exception(error_message)
    if algorithm_der[1] != bytes([0x05, 0x00]):
        raise Exception(error_message)

    # privateKey is an OCTET STRING whose payload is the DER encoding of
    # RSAPrivateKey ::= SEQUENCE {
    #     version, modulus (n), publicExponent (e), privateExponent (d),
    #     prime1 (p), prime2 (q), exponent1, exponent2, coefficient,
    #     otherPrimeInfos OPTIONAL }
    private_key_der = DerOctetString()
    private_key_der.decode(pki_der[2])
    rsa_der = DerSequence()
    rsa_der.decode(private_key_der.payload)
    if len(rsa_der) < 9:
        raise Exception(error_message)
    # version must be 0 here (two-prime form)
    if rsa_der[0] != 0:
        raise Exception(error_message)
    return rsa_der[1], rsa_der[2], rsa_der[3]
def usage():
    """Print the command-line usage text to stderr."""
    help_text = """
usage:
encrypt:
python3 rsa.py encrypt <public_key_file> <plain_text_file> <output_encrypted_file>
decrypt:
python3 rsa.py decrypt <private_key_file> <input_encrypted_file>
"""
    # strip() removes the blank first/last lines introduced by the triple quotes.
    print(help_text.strip(), file=sys.stderr)
def main():
    """Command-line entry point that dispatches on sys.argv[1] ("encrypt" or "decrypt").

    encrypt: OAEP-encodes the plain-text file, RSA-encrypts it with the public
    key, and writes the ciphertext file.
    decrypt: RSA-decrypts the ciphertext file with the private key,
    OAEP-decodes it, and prints the result.
    On an unknown command or too few arguments it prints usage and exits with
    status 1.
    """
    # Minimum len(sys.argv) per command: script name + command + its arguments.
    required_argc = {"encrypt": 5, "decrypt": 4}
    command = sys.argv[1] if len(sys.argv) >= 2 else None
    if command not in required_argc or len(sys.argv) < required_argc[command]:
        usage()
        sys.exit(1)

    if command == "encrypt":
        public_key_path, plain_path, encrypted_path = sys.argv[2:5]
        n, e = read_rsa_pub_key(public_key_path)
        k = count_octet(n)  # length in octets of the RSA modulus n
        with open(plain_path, 'rb') as src:
            plain_text = src.read()
        ciphertext = rsa_encrypt(oaep_encode_sha256(plain_text, k), n, e, k)
        with open(encrypted_path, 'wb') as dst:
            dst.write(ciphertext)
        return

    # decrypt: read the ciphertext first, then load the private key.
    private_key_path, encrypted_path = sys.argv[2:4]
    with open(encrypted_path, 'rb') as src:
        ciphertext = src.read()
    n, _e, d = read_rsa_private_key(private_key_path)
    k = count_octet(n)  # length in octets of the RSA modulus n
    decrypted = oaep_decode_sha256(rsa_decrypt(ciphertext, n, d, k), k)
    # NOTE(review): output is decoded as UTF-8, so non-UTF-8 plaintext raises
    # UnicodeDecodeError here even though encryption accepts arbitrary bytes.
    print(decrypted.decode('utf8'), end="")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment