-
-
Save soreatu/38bd088cb6f735777f3a5eb368b55188 to your computer and use it in GitHub Desktop.
Improved Bleichenbacher RSA Padding Oracle Attack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3 | |
# Copyright (C) 2019 Karim Kanso. All Rights Reserved. | |
# | |
# This program is free software: you can redistribute it and/or modify | |
# it under the terms of the GNU General Public License as published by | |
# the Free Software Foundation, either version 3 of the License, or | |
# (at your option) any later version. | |
# | |
# This program is distributed in the hope that it will be useful, | |
# but WITHOUT ANY WARRANTY; without even the implied warranty of | |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
# GNU General Public License for more details. | |
# | |
# You should have received a copy of the GNU General Public License | |
# along with this program. If not, see <http://www.gnu.org/licenses/>. | |
# The program in this file is a modified version of the Bleichenbacher | |
# RSA padding oracle attack, as described in the paper by Bardou et al., | |
# "Efficient Padding Oracle Attacks on Cryptographic Hardware", which | |
# improves the original attack from a median of 163000 oracle calls to | |
# 14500. | |
# | |
# In most cases, the implementation here uses standard python | |
# definitions and does not attempt to do anything fancy that would | |
# otherwise obfuscate the implementation or add dependencies. | |
# | |
# This script was written to find the flag during a crypto challenge. | |
# | |
# To use the program, provide an alternative version of the | |
# local_setup function that returns a 4-tuple of the needed parameters | |
# for the algorithm. | |
# 1. ciphertext encoded as a byte object | |
# 2. an oracle function that goes from a byte object to boolean value | |
# (if oracle throws an exception the program will crash, ensure | |
# all needed error handling is provided) | |
# 3. public exponent | |
# 4. modulus | |
# | |
# The default local_setup function will either generate a new key pair | |
# or use a pre-generated pair. This is then fed into the oracle to | |
# decrypt the content. The oracle directly uses a decryption library, | |
# so it checks that the full message format is PKCS#1 v1.5 compliant | |
# (i.e. the padding as well) and not just the first two bytes. | |
# | |
import cryptography.hazmat.primitives.asymmetric.rsa as rsa | |
from cryptography.hazmat.backends import default_backend | |
from cryptography.hazmat.primitives.asymmetric import padding | |
from cryptography.hazmat.primitives import serialization | |
import binascii | |
import math | |
import textwrap | |
def local_setup(newkey=False):
    """Create an RSA key pair and a loop-back padding oracle for local testing.

    Args:
        newkey: if true, generate a fresh 1024-bit key (e = 65537) and
            encrypt ``b'hello world!!'`` with PKCS#1 v1.5 padding; otherwise
            use the hard-coded key pair and ciphertext embedded below.

    Returns:
        A 4-tuple ``(ciphertext, oracle, e, n)``:

        * ``ciphertext`` -- the ciphertext as ``bytes``;
        * ``oracle`` -- a function taking ciphertext ``bytes`` that returns
          ``True`` when the private key PKCS#1 v1.5-decrypts it, and
          ``False`` when decryption raises ``ValueError``;
        * ``e`` and ``n`` -- the public exponent and modulus as ints.

    All key material, including the private values p, q and d, is printed
    to stdout, because this setup exists only for local testing.
    """
    print('Using local loop back oracle for testing')
    if newkey:
        print('* Generating a new key pair')
        priv_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=1024,
            backend=default_backend(),
        )
    else:
        print('* Using hard coded key pair')
        # textwrap.dedent strips the common indentation of the PEM lines
        pem = '''\
        -----BEGIN RSA PRIVATE KEY-----
        MIIBOQIBAAJBALSh8S5eL8uEskw1L7QvoVOuLwK7rpKmdu0IQvu4e1gYf/HdsMp6
        P7V7h+5D8siAbBigk2lphU60skVYK6OXCmUCAwEAAQJACJ5dep/l2ekX9Mjo4MkR
        AoQiHBhGaRrmO8MUJxyTTg2wptfKN/4qWPMtssACA6V3db4x4fyKeZi0MMAeEzAt
        QQIhAOP2J3Uggj8y6RxqWdS21z+qpdznwl5tzxK+pOEiW8+1AiEAytmJckmUQmdu
        Q2fwmrMlcJuwkk5ef126/NFrlty9HfECIF1IY0kYrnOyH5YTJwNWdqqE6C6HYBBI
        Gw5umQXPi4ZpAiBJXL9+2+mI0otoSXEVIfFKdqQ3Zax7d9Smlr7IgvDKoQIgO5YK
        bzoHpEb1T/UZ1MZFWNhoMvPG2MQs9B6KGVz0gGM=
        -----END RSA PRIVATE KEY-----'''
        priv_key = serialization.load_pem_private_key(
            textwrap.dedent(pem).encode('utf-8'),
            password=None,
            backend=default_backend(),
        )
    pub_key = priv_key.public_key()
    pn = pub_key.public_numbers()
    priv_numbers = priv_key.private_numbers()
    print(' keysize: {}'.format(priv_key.key_size))
    print(' e: {}'.format(pn.e))
    print(' n: {}'.format(pn.n))
    print(' p: {}'.format(priv_numbers.p))
    print(' q: {}'.format(priv_numbers.q))
    print(' d: {}'.format(priv_numbers.d))
    if newkey:
        ciphertext = pub_key.encrypt(b'hello world!!', padding.PKCS1v15())
    else:
        ciphertext = binascii.unhexlify(
            '1e2fb249ddd03554d0a7c27cc276ded8'
            'cbf5e4daf1b84c28eccd37118adb7d46'
            '9c29eab603220057df68b84d9fdd40b8'
            'b5835c16e09ecbcf8ee7ec634e534f32'
        )
    # decode so the hex digits print without the b'...' wrapper
    print(' c: {}'.format(binascii.hexlify(ciphertext).decode('ascii')))
    print()

    def oracle(ct):
        """Return True iff the private key PKCS#1 v1.5-decrypts ``ct``."""
        # NOTE(review): some newer OpenSSL builds implement "implicit
        # rejection" for PKCS#1 v1.5 decryption (returning a synthetic
        # plaintext instead of raising); if the installed stack does that,
        # this oracle would always answer True -- confirm before relying on it.
        try:
            priv_key.decrypt(ct, padding.PKCS1v15())
        except ValueError:
            return False
        return True

    return ciphertext, oracle, pn.e, pn.n
# Exact integer division helpers: they never convert operands to float, so
# large integers (e.g. RSA-sized values) are divided without rounding error.
def ceildiv(a, b):
    """Return ceil(a / b), computed with floor division via divmod."""
    quotient, remainder = divmod(a, b)
    # a non-zero remainder means the floor quotient is one below the ceiling
    return quotient + (remainder != 0)
def floordiv(a, b):
    """Return floor(a / b) using exact floor division (no float rounding)."""
    quotient, _ = divmod(a, b)
    return quotient
def egcd(a, b):
    """Extended Euclidean algorithm.

    Returns ``(x, y, g)`` such that ``a*x + b*y == g``; for non-negative
    ``a`` and ``b``, ``g`` is their greatest common divisor.
    """
    x_prev, x_cur = 1, 0
    y_prev, y_cur = 0, 1
    while b != 0:
        quot = a // b
        # advance the remainder sequence and the Bezout coefficients together
        a, b = b, a - quot * b
        x_prev, x_cur = x_cur, x_prev - quot * x_cur
        y_prev, y_cur = y_cur, y_prev - quot * y_cur
    return x_prev, y_prev, a
def gcd(a, b):
    """Return the greatest common divisor of a and b (third item of egcd)."""
    _, _, divisor = egcd(a, b)
    return divisor
def inverse(a, b):
    """Return the multiplicative inverse of ``a`` modulo ``b``.

    The result ``x`` satisfies ``0 <= x < b`` and ``(a * x) % b == 1`` for a
    positive modulus ``b``.

    Raises:
        ValueError: if ``a`` is not invertible modulo ``b``, i.e.
            ``gcd(a, b) != 1``. Without this check the Bezout coefficient
            would be returned even though it is not an inverse.
    """
    x, _, g = egcd(a, b)
    if g != 1:
        raise ValueError('{} has no inverse modulo {}'.format(a, b))
    return x % b
def lcm(a, b):
    """Return the least common multiple of a and b, computed via gcd."""
    divisor = gcd(a, b)
    product = a * b
    return product // divisor
def lcm_n(l):
    """Return the least common multiple of every integer in ``l``.

    The values are combined pairwise in a balanced tree: neighbours first,
    then the results of that pass, and so on until one value remains. An
    empty input yields 1, the identity element for lcm.

    ``l`` may be any iterable; it is copied and never modified.
    """
    ns = list(l)
    if not ns:
        return 1
    # Loop on the actual length rather than a float log2-based pass count.
    while len(ns) > 1:
        merged = [lcm(x, y) for x, y in zip(ns[0::2], ns[1::2])]
        # zip drops the unpaired last element of an odd-length list; carry it
        if len(ns) % 2 == 1:
            merged.append(ns[-1])
        ns = merged
    return ns[0]
# Module-wide count of oracle queries; incremented by the oracle_int wrapper
# defined inside main() and used in its progress/summary output.
oracle_ctr = 0
def _merge_intervals(intervals):
    """Return the union of closed integer intervals as sorted disjoint pairs.

    Intervals with ``lo > hi`` are empty and are dropped. Intervals that
    overlap or touch (``lo == previous_hi + 1``) are fused; fusing touching
    intervals is exact because the candidate plaintexts are integers.
    """
    merged = []
    for lo, hi in sorted(iv for iv in intervals if iv[0] <= iv[1]):
        if merged and lo <= merged[-1][1] + 1:
            prev_lo, prev_hi = merged[-1]
            merged[-1] = (prev_lo, max(prev_hi, hi))
        else:
            merged.append((lo, hi))
    return merged


def main():
    """Run the improved Bleichenbacher attack end to end.

    Obtains ``(ciphertext, oracle, e, n)`` from ``local_setup`` and recovers
    the PKCS#1 v1.5 plaintext using only the padding oracle:

    * step 1.b trims the initial interval [2B, 3B-1] using fractions u/t;
    * steps 2.a/2.b/2.c search for the next multiplier s_i;
    * step 3 narrows the candidate intervals, taking their union;
    * step 4 stops once a single value is left and prints it.

    Progress and oracle-call counts are printed to stdout. Exits with status
    1 if the oracle rejects the original ciphertext. Raises RuntimeError if
    the candidate set ever becomes empty (an inconsistent oracle).
    """
    print('Bleichenbacher RSA padding algorithm')
    print(' for more info see 1998 paper.')
    print()
    # setup parameters, change local_setup with alternative
    # implementation, such as an oracle that uses a real server
    ct, oracle, e, n = local_setup(newkey=True)
    # byte length of n, computed exactly (math.log on a huge int is a float
    # and can misround near powers of two)
    k = (n.bit_length() + 7) // 8
    # convert ciphertext from bytes into integer
    c = int.from_bytes(ct, 'big')

    def oracle_int(x):
        """Query the oracle with integer x encoded as k big-endian bytes."""
        global oracle_ctr
        oracle_ctr = oracle_ctr + 1
        if oracle_ctr % 100000 == 0:
            print("[{}K tries] ".format(oracle_ctr // 1000), end='', flush=True)
        return oracle(x.to_bytes(k, 'big'))

    def multiply(x, y):
        """Return x * y^e mod n; this multiplies the plaintext of x by y."""
        return (x * pow(y, e, n)) % n

    # conforming plaintexts start with bytes 00 02, so they lie in
    # [2B, 3B) with B = 2^(8(k-2))
    B = pow(2, 8 * (k - 2))
    _2B = 2 * B
    _3B = 3 * B
    # should be identity as c is valid cipher text
    c0 = multiply(c, 1)
    assert c0 == c
    i = 1
    M = [(_2B, _3B - 1)]
    s = 1
    # ensure everything is working as expected
    if oracle_int(c0):
        print('Oracle ok, implicit step 1 passed')
    else:
        print('Oracle fail sanity check')
        raise SystemExit(1)

    # step 1.b: trim M0 using fractions u/t for which m*u/t is conforming
    print('start case 1.b: ', end='', flush=True)
    trimmers = []
    for t in range(2, 2**12 + 1):
        if t <= 50:
            trimmers.extend(
                (u, t) for u in range(2 * t // 3, 3 * t // 2) if gcd(u, t) == 1)
        else:  # t > 50
            trimmers.append((t - 1, t))
            trimmers.append((t + 1, t))
    ut_pairs = [(u, t) for u, t in trimmers
                if oracle_int(multiply(c0, u * inverse(t, n) % n))]
    if ut_pairs:
        t_prime = lcm_n([t for _, t in ut_pairs])
        # t divides t_prime, so these scaled numerators are exact
        scaled = [u * t_prime // t for u, t in ut_pairs]
        u_min = min(scaled)
        u_max = max(scaled)
        # m >= 2B * t'/u_min and m is an integer, so the ceiling is safe
        M[0] = (ceildiv(_2B * t_prime, u_min),
                (_3B - 1) * t_prime // u_max)
    print('done. trimming M0 iterations: {} [{},{}]'.format(oracle_ctr, M[0][0], M[0][1]))

    while True:
        if i == 1:
            # step 2.a: find s1, skipping s values that cannot give a
            # conforming m*s - j*n for the current j
            print('start case 2.a: ', end='', flush=True)
            calls_before = oracle_ctr
            lo, hi = M[0]
            j = 1
            ss = ceildiv(j * n + _2B, hi)
            while not oracle_int(multiply(c0, ss)):
                ss = ss + 1
                if ss >= ceildiv(j * n + _3B, lo):
                    j += 1
                    # never step backwards: windows for consecutive j can
                    # overlap, and every s below ss was tested or skipped
                    ss = max(ss, ceildiv(j * n + _2B, hi))
            print('done. found s1 in {} iterations: {}'.format(
                oracle_ctr - calls_before, ss))
        elif len(M) > 1:
            # step 2.b: several intervals left, linear search from s + 1
            print('start case 2.b: ', end='', flush=True)
            ss = s + 1
            while not oracle_int(multiply(c0, ss)):
                ss = ss + 1
            print('done. found s{} in {} iterations: {}'.format(
                i, ss - s, ss))
        else:
            # step 2.c: a single interval [a, b] left
            print('start case 2.c: ', end='', flush=True)
            a, b = M[0]
            r = ceildiv(2 * (b * s - _2B), n)
            ctr = 0
            found = False
            while not found:
                # note: the floor function below needed +1 added
                # to it, this is not clear from the paper (see
                # equation 2 in paper where \lt is used instead of
                # \lte).
                for ss in range(ceildiv(_2B + r * n, b),
                                floordiv(_3B + r * n, a) + 1):
                    ctr = ctr + 1
                    if oracle_int(multiply(c0, ss)):
                        found = True
                        break
                else:
                    r = r + 1
            print('done. found s{} in {} iterations: {}'.format(i, ctr, ss))

        # step 3: narrow the set of solutions to the union of the
        # non-empty intersections
        candidates = []
        for a, b in M:
            for r in range(ceildiv(a * ss - _3B + 1, n),
                           floordiv(b * ss - _2B, n) + 1):
                candidates.append((
                    max(a, ceildiv(_2B + r * n, ss)),
                    min(b, floordiv(_3B - 1 + r * n, ss)),
                ))
        M = _merge_intervals(candidates)
        if not M:
            raise RuntimeError(
                'no candidate intervals left; the oracle answered inconsistently')
        for lo, hi in M:
            print('found interval [{},{}]'.format(lo, hi))

        # step 4: stop once a single value is left
        s = ss
        i = i + 1
        if len(M) == 1 and M[0][0] == M[0][1]:
            print()
            print('Completed!')
            print('used the oracle {} times'.format(oracle_ctr))
            # note, no need to find multiplicative inverse of s0 in n
            # as s0 = 1, so M[0][0] is directly the message.
            message = M[0][0].to_bytes(k, 'big')
            print('raw decryption: {}'.format(
                binascii.hexlify(message).decode('utf-8')))
            if message[0] != 0 or message[1] != 2:
                return
            # strip the padding up to and including the 00 separator
            message = message[message.index(b'\x00', 1) + 1:]
            print('unpadded message hex: {}'.format(
                binascii.hexlify(message).decode('utf-8')))
            try:
                print('unpadded message ascii: {}'.format(
                    message.decode('utf-8')))
            except UnicodeError:
                # best effort only: binary plaintexts have no text form
                pass
            return
# Run the attack only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment