# avalonalex/RSA.py

Last active Mar 6, 2021
An implementation of the RSA public-key encryption algorithm in Python. This implementation is for educational purposes and is not intended for real-world use. I hope anyone who wants to compute (a^b mod n) efficiently finds it useful.

    #!/usr/bin/env python

    import argparse
    import copy
    import math
    import pickle
    import random
    from itertools import combinations


    def euclid(a, b):
        """returns the Greatest Common Divisor of a and b"""
        a = abs(a)
        b = abs(b)
        if a < b:
            a, b = b, a
        while b != 0:
            a, b = b, a % b
        return a


    def coPrime(l):
        """returns 'True' if the values in the list L are all co-prime
        otherwise, it returns 'False'. """
        for i, j in combinations(l, 2):
            if euclid(i, j) != 1:
                return False
        return True


    def extendedEuclid(a, b):
        """return a tuple of three values: x, y and z, such that x is
        the GCD of a and b, and x = y * a + z * b"""
        if a == 0:
            return b, 0, 1
        else:
            g, y, x = extendedEuclid(b % a, a)
            return g, x - (b // a) * y, y


    def modInv(a, m):
        """returns the multiplicative inverse of a in modulo m as a
        positive value between zero and m-1"""
        # note that a and m need to be co-prime to each other.
        if coPrime([a, m]):
            linearCombination = extendedEuclid(a, m)
            # index 1 holds the coefficient of a in g = y * a + z * m
            return linearCombination[1] % m
        else:
            return 0


    def extractTwos(m):
        """m is a positive integer. A tuple (s, d) of integers is returned
        such that m = (2 ** s) * d."""
        # the problem can be broken down to counting how many '0's there are
        # at the end of bin(m): test each bit with m & (2 ** i), starting from
        # the lowest, until a set bit is found.
        assert m > 0
        i = 0
        while m & (2 ** i) == 0:
            i += 1
        return i, m >> i


    def int2baseTwo(x):
        """x is a positive integer. Convert it to base two as a list of
        bits, least significant bit first."""
        # repeating x >>= 1 and x & 1 will do the trick
        assert x >= 0
        bitInverse = []
        while x != 0:
            bitInverse.append(x & 1)
            x >>= 1
        return bitInverse


    def modExp(a, d, n):
        """returns a ** d (mod n)"""
        assert d >= 0
        assert n >= 0
        base2D = int2baseTwo(d)
        base2DLength = len(base2D)
        modArray = []
        result = 1
        # modArray[i] holds a ** (2 ** i) (mod n)
        for i in range(1, base2DLength + 1):
            if i == 1:
                modArray.append(a % n)
            else:
                modArray.append((modArray[i - 2] ** 2) % n)
        for i in range(0, base2DLength):
            if base2D[i] == 1:
                result *= base2D[i] * modArray[i]
        return result % n


    def millerRabin(n, k):
        """
        Miller Rabin pseudo-prime test
        return True means likely a prime, (how sure about that, depending on k)
        return False means definitely a composite.
        Raise assertion error when n is not an integer bigger than 1
        or k is not a positive integer
        """
        assert n > 1
        # ensure n is bigger than 1
        assert k > 0
        # ensure k is a positive integer so everything down here makes sense

        if n == 2 or n == 3:
            return True
        # 2 and 3 are prime; for n == 3 there is no witness in [2, n - 2]

        if n % 2 == 0:
            return False
        # immediately return False for all the even numbers bigger than 2

        extract2 = extractTwos(n - 1)
        s = extract2[0]
        d = extract2[1]
        assert 2 ** s * d == n - 1

        def tryComposite(a):
            """Inner function which will inspect whether a given witness
            will reveal the true identity of n. Will only be called within
            millerRabin"""
            x = modExp(a, d, n)
            if x == 1 or x == n - 1:
                return None
            else:
                for j in range(1, s):
                    x = modExp(x, 2, n)
                    if x == 1:
                        return False
                    elif x == n - 1:
                        return None
                return False

        for i in range(0, k):
            a = random.randint(2, n - 2)
            if tryComposite(a) == False:
                return False
        return True  # actually, we should return probably true.


    def primeSieve(k):
        """return a list with length k + 1, showing if list[i] == 1, i is a prime
        else if list[i] == 0, i is a composite, if list[i] == -1, not defined"""

        def isPrime(n):
            """return True if given number n is absolutely prime,
            return False otherwise."""
            for i in range(2, int(n ** 0.5) + 1):
                if n % i == 0:
                    return False
            return True

        result = [-1] * (k + 1)
        for i in range(2, int(k + 1)):
            if isPrime(i):
                result[i] = 1
            else:
                result[i] = 0
        return result


    def findAPrime(a, b, k):
        """Return a pseudo prime number roughly between a and b,
        (could be larger than b). Raise ValueError if cannot find a
        pseudo prime after 10 * ln(x) + 3 tries. """
        x = random.randint(a, b)
        for i in range(0, int(10 * math.log(x) + 3)):
            if millerRabin(x, k):
                return x
            else:
                x += 1
        raise ValueError


    def newKey(a, b, k):
        """ Try to find two large pseudo primes roughly between a and b.
        Generate public and private keys for RSA encryption.
        Raises ValueError if it fails to find one"""
        try:
            p = findAPrime(a, b, k)
            while True:
                q = findAPrime(a, b, k)
                if q != p:
                    break
        except:
            raise ValueError

        n = p * q
        m = (p - 1) * (q - 1)

        while True:
            e = random.randint(1, m)
            if coPrime([e, m]):
                break

        d = modInv(e, m)
        return (n, e, d)


    def string2numList(strn):
        """Converts a string to a list of integers based on ASCII values"""
        return [ord(chars) for chars in pickle.dumps(strn)]


    def numList2string(l):
        """Converts a list of integers to a string based on ASCII values"""
        return pickle.loads(''.join(map(chr, l)))


    def numList2blocks(l, n):
        """Take a list of integers(each between 0 and 127), and combines them
        into block size n using base 256. If len(L) % n != 0, use some random
        junk to fill L to make it."""
        # Note that ASCII printable characters range is 0x20 - 0x7E
        returnList = []
        toProcess = copy.copy(l)
        if len(toProcess) % n != 0:
            for i in range(0, n - len(toProcess) % n):
                toProcess.append(random.randint(32, 126))
        for i in range(0, len(toProcess), n):
            block = 0
            for j in range(0, n):
                block += toProcess[i + j] << (8 * (n - j - 1))
            returnList.append(block)
        return returnList


    def blocks2numList(blocks, n):
        """inverse function of numList2blocks."""
        toProcess = copy.copy(blocks)
        returnList = []
        for numBlock in toProcess:
            inner = []
            for i in range(0, n):
                inner.append(numBlock % 256)
                numBlock >>= 8
            inner.reverse()
            returnList.extend(inner)
        return returnList


    def encrypt(message, modN, e, blockSize):
        """given a string message, public keys and blockSize, encrypt using
        RSA algorithms."""
        numList = string2numList(message)
        numBlocks = numList2blocks(numList, blockSize)
        return [modExp(blocks, e, modN) for blocks in numBlocks]


    def decrypt(secret, modN, d, blockSize):
        """reverse function of encrypt"""
        numBlocks = [modExp(blocks, d, modN) for blocks in secret]
        numList = blocks2numList(numBlocks, blockSize)
        return numList2string(numList)


    def block_size(val):
        try:
            v = int(val)
            assert(v >= 10 and v <= 1000)
        except:
            raise argparse.ArgumentTypeError("{} is not a valid block size".format(val))
        # return the converted integer, not the raw command-line string
        return v


    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument("-m", "--message", help="Text message to encrypt")
        group.add_argument("-f", "--file", type=file, help="Text file to encrypt")
        parser.add_argument("-b", "--block-size", type=block_size, default=15,
                            help="Block size to break message into smaller chunks")
        args = parser.parse_args()

        print """
    ------------------------------------------------------
    This program is intended for pedagogical purposes only
    ------------------------------------------------------
    """

        n, e, d = newKey(10 ** 100, 10 ** 101, 50)

        if args.message is not None:
            message = args.message
        else:
            print args.file
            try:
                message = args.file.read()
            finally:
                args.file.close()

        print "original message is {}".format(message)
        print "-"*80
        cipher = encrypt(message, n, e, 15)
        print "cipher text is {}".format(cipher)
        print "-"*80
        deciphered = decrypt(cipher, n, d, 15)
        print "decrypted message is {}".format(deciphered)
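For readers who mainly want the (a^b mod n) computation, here is a minimal sketch of square-and-multiply that reduces modulo n after every multiplication, so intermediate values never grow beyond n squared. The name `mod_exp` is hypothetical and not part of the gist; Python's built-in three-argument `pow(a, b, n)` computes the same value and is used below only as a cross-check.

```python
def mod_exp(base, exponent, modulus):
    """Right-to-left binary exponentiation, reducing after every multiply.

    Illustrative helper (not from the gist); assumes exponent >= 0 and
    modulus > 0.
    """
    if exponent < 0 or modulus <= 0:
        raise ValueError("exponent must be >= 0 and modulus must be > 0")
    result = 1 % modulus  # gives 0 when modulus == 1
    base %= modulus
    while exponent:
        if exponent & 1:
            # this bit of the exponent is set: fold the current power in
            result = (result * base) % modulus
        # square the base for the next bit
        base = (base * base) % modulus
        exponent >>= 1
    return result


if __name__ == "__main__":
    # Cross-check against the built-in three-argument pow().
    assert mod_exp(4, 13, 497) == pow(4, 13, 497)
    print(mod_exp(4, 13, 497))
```

Outside of a learning exercise, calling `pow(a, b, n)` directly is probably the simpler choice.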

### Knexman427 commented Aug 1, 2014

 I couldn't get my program to calculate large power sums. This code really helped, thank you for posting.

### sadmicrowave commented Jan 13, 2015

 I think there may be an issue with your bit padding. For example, take out the word "earth" from your message variable and run again. You'll see a random set of decrypted values at the end of your plaintext message string. I'm thinking this is because we are not taking into consideration variable block sizes, which would typically only be an issue for the last block. Either it fits perfectly, or needs padding... what do you think?

### avalonalex commented May 31, 2015

 `numList2blocks(l, n)` pads with random characters as a naive way to prevent dictionary attacks. A better way may be to serialize the message into a format that carries the message length before encrypting, and to deserialize after decrypting so the padding is removed.
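The length-carrying format suggested here could look roughly like the sketch below. It is only an illustration under stated assumptions: the helper names `pad_with_length` and `unpad_with_length`, the 4-byte big-endian length prefix, and the use of `os.urandom` for filler are all choices made for this example, not something the gist does.

```python
import os
import struct


def pad_with_length(data, block_size):
    """Prefix a 4-byte big-endian length, then pad to a multiple of block_size.

    Hypothetical helper: the random filler plays the role of the junk that
    numList2blocks appends, but the prefix records where the message ends.
    """
    framed = struct.pack(">I", len(data)) + data
    shortfall = -len(framed) % block_size  # bytes missing from the last block
    return framed + os.urandom(shortfall)


def unpad_with_length(framed):
    """Read the length prefix and return only the original message bytes."""
    (length,) = struct.unpack(">I", framed[:4])
    return framed[4:4 + length]


if __name__ == "__main__":
    padded = pad_with_length(b"hello", 15)
    assert len(padded) % 15 == 0
    assert unpad_with_length(padded) == b"hello"
```

With a frame like this, the trailing random bytes that show up after decryption can be dropped because the receiver knows the true message length.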

### avalonalex commented Jun 3, 2015

 (1) Address @sadmicrowave's concern. (2) Better user interaction.

### tomtyiu commented Jul 14, 2015

 Can I use it for encrypted virus scan purpose? It will not be for real world use.

### AlexGoodwin commented Sep 16, 2015

 Thanks for this, very helpful from an academic standpoint!

### Magibela commented Jul 13, 2016

 I couldn't get my program running; there were a lot of errors and I was using Python 3.3... Are there any libraries I need to add to this code for it to work?

### stuaxo commented Oct 11, 2016

 @Magibela start by adding brackets to the print statements + add a `from __future__ import print_function` to the top of the code. If there are still errors, post them; it usually takes a little work to port scripts from Python 2 to Python 3.
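Beyond the print statements, the string helpers are another likely source of errors on Python 3, because `pickle.dumps` returns `bytes` there and iterating over `bytes` already yields integers. A minimal sketch of Python 3 variants follows; the `_py3` names are hypothetical and chosen only to avoid confusion with the gist's own functions.

```python
import pickle


def string2numList_py3(strn):
    """Python 3 variant: pickle.dumps returns bytes, which iterate as ints."""
    return list(pickle.dumps(strn))


def numList2string_py3(l):
    """Python 3 variant: rebuild the bytes object before unpickling."""
    return pickle.loads(bytes(l))


if __name__ == "__main__":
    nums = string2numList_py3("hello")
    assert numList2string_py3(nums) == "hello"
    print(nums[:4])
```

The `type=file` argument in the argparse setup would also need attention, since the `file` builtin does not exist in Python 3; `argparse.FileType('r')` is one possible replacement.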