Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Improved Bleichenbacher RSA Padding Oracle Attack
#! /usr/bin/env python3
# Copyright (C) 2019 Karim Kanso. All Rights Reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# The program in this file is a modified version of the Bleichenbacher
# RSA padding oracle attack, following the paper by Bardou et al.,
# "Efficient Padding Oracle Attacks on Cryptographic Hardware", which
# improves the original attack from a median of 163000 oracle calls to
# 14500.
#
# In most cases, the implementation here uses standard python
# definitions and does not attempt to do anything fancy that would
# otherwise obfuscate the implementation or add dependencies.
#
# This script was written to find the flag during a crypto challenge.
#
# To use the program, provide an alternative version of the
# local_setup function that returns a 4-tuple of the needed parameters
# for the algorithm.
# 1. ciphertext encoded as a byte object
# 2. an oracle function that goes from a byte object to boolean value
# (if the oracle raises an exception the program will crash, so
# ensure all needed error handling is provided)
# 3. public exponent
# 4. modulus
#
# The default local_setup function will either generate a new key pair
# or use a pre-generated pair. This is then fed into the oracle to
# decrypt the content. The oracle directly uses a decryption library
# so it will check the full format of the message is PKCS1.5 compliant
# (i.e. the padding as well) and not just the first two bytes.
#
import binascii
import math
import sys
import textwrap

import cryptography.hazmat.primitives.asymmetric.rsa as rsa
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
def local_setup(newkey=False):
    """Build a local loop-back test harness for the attack.

    Either generates a fresh 1024-bit RSA key pair (``newkey=True``) or
    loads the hard-coded PEM key below. It prints the key parameters and
    produces a PKCS#1 v1.5 ciphertext to attack.

    Returns a 4-tuple ``(ciphertext, oracle, e, n)``:
      * ``ciphertext`` -- bytes object to be decrypted by the attack,
      * ``oracle``     -- callable taking a ciphertext (bytes) and returning
                          True if the private key decrypts it with PKCS#1
                          v1.5 padding, False if decryption raises ValueError,
      * ``e``, ``n``   -- public exponent and modulus.
    """
    print('Using local loop back oracle for testing')
    if newkey:
        print('* Generating a new key pair')
        priv_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=1024,
            backend=default_backend()
        )
    else:
        print('* Using hard coded key pair')
        priv_key = '''\
-----BEGIN RSA PRIVATE KEY-----
MIIBOQIBAAJBALSh8S5eL8uEskw1L7QvoVOuLwK7rpKmdu0IQvu4e1gYf/HdsMp6
P7V7h+5D8siAbBigk2lphU60skVYK6OXCmUCAwEAAQJACJ5dep/l2ekX9Mjo4MkR
AoQiHBhGaRrmO8MUJxyTTg2wptfKN/4qWPMtssACA6V3db4x4fyKeZi0MMAeEzAt
QQIhAOP2J3Uggj8y6RxqWdS21z+qpdznwl5tzxK+pOEiW8+1AiEAytmJckmUQmdu
Q2fwmrMlcJuwkk5ef126/NFrlty9HfECIF1IY0kYrnOyH5YTJwNWdqqE6C6HYBBI
Gw5umQXPi4ZpAiBJXL9+2+mI0otoSXEVIfFKdqQ3Zax7d9Smlr7IgvDKoQIgO5YK
bzoHpEb1T/UZ1MZFWNhoMvPG2MQs9B6KGVz0gGM=
-----END RSA PRIVATE KEY-----'''
        priv_key = serialization.load_pem_private_key(
            textwrap.dedent(priv_key).encode('utf-8'),
            password=None,
            backend=default_backend()
        )
    pub_key = priv_key.public_key()
    pn = pub_key.public_numbers()
    # fetch the private numbers once instead of once per printed field
    priv_numbers = priv_key.private_numbers()
    print(' keysize: {}'.format(priv_key.key_size))
    print(' e: {}'.format(pn.e))
    print(' n: {}'.format(pn.n))
    print(' p: {}'.format(priv_numbers.p))
    print(' q: {}'.format(priv_numbers.q))
    print(' d: {}'.format(priv_numbers.d))
    if newkey:
        ciphertext = pub_key.encrypt(
            b'hello world!!',
            padding.PKCS1v15()
        )
    else:
        # ciphertext matching the hard-coded key above
        ciphertext = binascii.unhexlify(
            '1e2fb249ddd03554d0a7c27cc276ded8' +
            'cbf5e4daf1b84c28eccd37118adb7d46' +
            '9c29eab603220057df68b84d9fdd40b8' +
            'b5835c16e09ecbcf8ee7ec634e534f32'
        )
    print(' c: {}'.format(binascii.hexlify(ciphertext)))
    print()

    def oracle(ct):
        """Return True if ``ct`` decrypts under PKCS#1 v1.5, else False."""
        # Padding failures are reported by the library as ValueError; any
        # other exception propagates to the caller.
        try:
            priv_key.decrypt(ct, padding.PKCS1v15())
        except ValueError:
            return False
        return True

    return ciphertext, oracle, pn.e, pn.n
# these two defs avoid rounding issues with floating point during
# division (especially with large numbers)
def ceildiv(a, b):
    """Return ceil(a / b) computed with floor division only.

    Negating the numerator turns a floor division into a ceiling one,
    so no floating-point division is involved for integer inputs.
    """
    negated_floor = (-a) // b
    return -negated_floor
def floordiv(a, b):
    """Return floor(a / b); a named counterpart to ``ceildiv``."""
    return a // b
def egcd(a, b):
    """Extended Euclidean algorithm.

    Returns ``(x, y, g)`` with ``x*a + y*b == g``, where ``g`` is the last
    non-zero value of the Euclidean remainder sequence (the gcd).
    """
    # (x0, y0) are the Bezout coefficients of the current ``a``,
    # (x1, y1) those of the current ``b``.
    x0, x1 = 1, 0
    y0, y1 = 0, 1
    while b != 0:
        quot, rem = divmod(a, b)
        a, b = b, rem
        x0, x1 = x1, x0 - quot * x1
        y0, y1 = y1, y0 - quot * y1
    return x0, y0, a
def gcd(a, b):
    """Greatest common divisor of a and b.

    Walks the same Euclidean remainder sequence as ``egcd`` (``a % b`` is
    the remainder ``divmod`` produces) but skips the Bezout coefficients.
    """
    while b:
        a, b = b, a % b
    return a
def inverse(a, b):
    """Return the multiplicative inverse of ``a`` modulo ``b``.

    The result is in ``range(b)`` for a positive modulus ``b``.

    Raises:
        ValueError: if ``a`` and ``b`` are not coprime. Previously a
            meaningless value was returned silently in that case.
    """
    x, _, g = egcd(a, b)
    if g != 1:
        raise ValueError('{} has no inverse modulo {}'.format(a, b))
    return x % b
def lcm(a, b):
    """Least common multiple of a and b, derived from their gcd."""
    divisor = gcd(a, b)
    return a * b // divisor
def lcm_n(l):
    """Return the least common multiple of every value in ``l``.

    Values are combined pairwise in a tree: neighbours are merged with
    ``lcm``, an odd trailing element is carried over, and this repeats until
    a single value remains. The pairing order is the same as in the previous
    implementation. The round count is no longer derived from a
    floating-point ``math.log``.

    Accepts any iterable. An empty one returns 1, the identity for lcm,
    instead of failing with a math domain error.
    """
    ns = list(l)
    if not ns:
        return 1
    while len(ns) > 1:
        merged = [lcm(ns[j], ns[j + 1]) for j in range(0, len(ns) - 1, 2)]
        if len(ns) % 2 == 1:
            merged.append(ns[-1])
        ns = merged
    return ns[0]
# total number of oracle queries made; incremented by main's oracle_int
oracle_ctr = 0


def _trimming_fractions(max_t=2**12, dense_limit=50):
    """Return the candidate (u, t) fractions used to trim M0 (step 1.b).

    For every denominator t <= dense_limit, each reduced fraction u/t with
    2t//3 <= u < 3t//2 is listed. Above that limit only (t-1)/t and (t+1)/t
    are listed. The order matches the original query order.
    """
    pairs = []
    for t in range(2, max_t + 1):
        if t <= dense_limit:
            for u in range(2 * t // 3, 3 * t // 2):
                if gcd(u, t) == 1:
                    pairs.append((u, t))
        else:
            pairs.append((t - 1, t))
            pairs.append((t + 1, t))
    return pairs


def _narrow_intervals(M, s, n, B):
    """Step 3: narrow the candidate intervals using a conforming multiplier s.

    Each (a, b) in M is intersected with every range of plaintexts m for which
    m*s - r*n lies in [2B, 3B-1]. Empty intersections are dropped and
    duplicates removed, keeping first-seen order. Every kept interval is
    printed.
    """
    _2B, _3B = 2 * B, 3 * B
    narrowed = []
    for a, b in M:
        for r in range(ceildiv(a * s - _3B + 1, n),
                       floordiv(b * s - _2B, n) + 1):
            lo = max(a, ceildiv(_2B + r * n, s))
            hi = min(b, floordiv(_3B - 1 + r * n, s))
            if lo <= hi and (lo, hi) not in narrowed:
                narrowed.append((lo, hi))
                print('found interval [{},{}]'.format(lo, hi))
    return narrowed


def _print_unpadded(message):
    """Strip PKCS#1 v1.5 type-2 padding (00 02 ... 00) and print the payload.

    Prints a short note instead of raising when the padding is malformed.
    """
    if message[:2] != b'\x00\x02':
        print('decryption does not start with 00 02, not unpadding')
        return
    sep = message.find(b'\x00', 2)
    if sep < 0:
        print('no 00 separator after the padding, not unpadding')
        return
    payload = message[sep + 1:]
    print('unpadded message hex: {}'.format(
        binascii.hexlify(payload).decode('utf-8')))
    try:
        print('unpadded message ascii: {}'.format(payload.decode('utf-8')))
    except UnicodeError:
        pass


def main():
    """Run the padding-oracle attack end to end and print the plaintext.

    Obtains ``(ciphertext, oracle, e, n)`` from ``local_setup``. It then
    recovers the padded plaintext using only the oracle, printing progress
    along the way. Exits with status 1 if the oracle rejects the original
    ciphertext.
    """
    print('Bleichenbacher RSA padding algorithm')
    print(' for more info see 1998 paper.')
    print()
    # setup parameters, change local_setup with alternative
    # implementation, such as an oracle that uses a real server
    ct, oracle, e, n = local_setup(newkey=True)
    # Byte length of n, computed exactly. The previous math.log-based value
    # is a float and can be off for moduli just above a power of 256.
    k = (n.bit_length() + 7) // 8
    # convert ciphertext from bytes into integer
    c = int.from_bytes(ct, 'big')

    def oracle_int(x):
        """Oracle lifted to integers; counts every query in oracle_ctr."""
        global oracle_ctr
        oracle_ctr += 1
        if oracle_ctr % 100000 == 0:
            print("[{}K tries] ".format(oracle_ctr // 1000), end='', flush=True)
        return oracle(x.to_bytes(k, 'big'))

    # A conforming plaintext starts with 00 02, so with B = 2^(8(k-2))
    # it lies in [2B, 3B-1].
    B = pow(2, 8 * (k - 2))
    _2B = 2 * B
    _3B = 3 * B

    def multiply(x, y):
        """Return x * y^e mod n, i.e. the ciphertext x with its plaintext times y."""
        return (x * pow(y, e, n)) % n

    # should be identity as c is valid cipher text (s0 = 1, no blinding)
    c0 = multiply(c, 1)
    assert c0 == c
    i = 1
    M = [(_2B, _3B - 1)]
    s = 1

    # ensure everything is working as expected
    if oracle_int(c0):
        print('Oracle ok, implicit step 1 passed')
    else:
        print('Oracle fail sanity check')
        sys.exit(1)

    # trimming M0 (Bardou et al.): m*u/t conforming for several u/t
    print('start case 1.b: ', end='', flush=True)
    ut_pairs = []
    for u, t in _trimming_fractions():
        if oracle_int(multiply(c0, u * inverse(t, n))):
            ut_pairs.append((u, t))
    if ut_pairs:
        t_prime = lcm_n([t for _, t in ut_pairs])
        scaled_u = [u * t_prime // t for u, t in ut_pairs]
        u_min, u_max = min(scaled_u), max(scaled_u)
        # m*u/t is in [2B, 3B-1] for every pair, which bounds m. The result
        # is clamped to the trivial interval because, when all successful
        # fractions lie on one side of 1, one bound would otherwise widen.
        M[0] = (max(_2B, ceildiv(_2B * t_prime, u_min)),
                min(_3B - 1, (_3B - 1) * t_prime // u_max))
    print('done. trimming M0 iterations: {} [{},{}]'.format(
        oracle_ctr, M[0][0], M[0][1]))

    while True:
        if i == 1:
            print('start case 2.a: ', end='', flush=True)
            queries_before = oracle_ctr
            a, b = M[0]
            j = 1
            ss = ceildiv(j * n + _2B, b)
            while not oracle_int(multiply(c0, ss)):
                ss += 1
                if ss >= ceildiv(j * n + _3B, a):
                    j += 1
                    ss = ceildiv(j * n + _2B, b)
            # count only the queries made by this search
            print('done. found s1 in {} iterations: {}'.format(
                oracle_ctr - queries_before, ss))
        elif len(M) > 1:
            print('start case 2.b: ', end='', flush=True)
            ss = s + 1
            while not oracle_int(multiply(c0, ss)):
                ss += 1
            print('done. found s{} in {} iterations: {}'.format(
                i, ss - s, ss))
        else:
            print('start case 2.c: ', end='', flush=True)
            a, b = M[0]
            r = ceildiv(2 * (b * s - _2B), n)
            ctr = 0
            while True:
                # The upper bound of the range gets +1 so that
                # floor((3B + r*n)/a) itself is tried. Equation 2 in the
                # paper uses a strict < here, which is easy to misread.
                for ss in range(ceildiv(_2B + r * n, b),
                                floordiv(_3B + r * n, a) + 1):
                    ctr += 1
                    if oracle_int(multiply(c0, ss)):
                        break
                else:
                    r += 1
                    continue
                break
            print('done. found s{} in {} iterations: {}'.format(i, ctr, ss))

        # step 3, narrowing solutions
        M = _narrow_intervals(M, ss, n, B)
        if not M:
            # only possible if an oracle answer was wrong (e.g. a false
            # positive during trimming); bail out instead of crashing
            print()
            print('No candidate interval left, the oracle answers are inconsistent')
            return
        # step 4, compute solutions
        s = ss
        i += 1
        if len(M) == 1 and M[0][0] == M[0][1]:
            break

    print()
    print('Completed!')
    print('used the oracle {} times'.format(oracle_ctr))
    # note, no need to find multiplicative inverse of s0 in n
    # as s0 = 1, so M[0][0] is directly the message.
    message = M[0][0].to_bytes(k, 'big')
    print('raw decryption: {}'.format(
        binascii.hexlify(message).decode('utf-8')))
    _print_unpadded(message)
if __name__ == "__main__":
    # run the attack only when executed as a script, not on import
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment