Created
May 14, 2019 15:20
-
-
Save rot256/f62392872c9199126fcd6a4064d1247f to your computer and use it in GitHub Desktop.
Simple CBC Padding Oracle Library
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
# Characters that repr_ascii() passes through unchanged: printable ASCII
# (0x20-0x7e) plus tab and newline. Built with a set union rather than
# `map(...) + [...]`, which only works on Python 2 (on Python 3 `map`
# returns an iterator and the `+` raises TypeError).
ASCII = set(chr(c) for c in range(0x20, 0x7f)) | set(['\t', '\n'])
def xor(*args):
    """XOR two or more strings together, character by character.

    The result is as long as the shortest argument; extra characters of
    longer arguments are ignored (zip() semantics). Fewer than two
    arguments raises IndexError.
    """
    head, rest = args[0], args[1:]
    if len(rest) > 1:
        # Collapse everything after the first argument into one string.
        rest = (xor(*rest),)
    other = rest[0]
    return ''.join(
        chr(a ^ b) for a, b in zip(map(ord, head), map(ord, other))
    )
def repr_ascii(s):
    """Return repr() of s without its surrounding quotes, masking odd chars.

    Every character not in the module-level ASCII set is replaced by '!'
    before repr() is applied, so the result is safe to print on one line.
    """
    masked = ''.join(c if c in ASCII else '!' for c in s)
    return repr(masked)[1:-1]
def padding_PKCS5(n):
    """Return the PKCS-style padding string for n padding bytes.

    The result is n characters, each equal to chr(n). (For block sizes
    other than 8 this scheme is usually called PKCS#7.)
    """
    return ''.join([chr(n)] * n)
class PaddingOracle:
    """Drive a CBC padding-oracle attack.

    Parameters:
        query: the oracle, a callable returning a truthy value when the
            ciphertext it receives decrypts to correctly padded plaintext.
            It may take one argument (IV and ciphertext concatenated) or
            two (``iv, ct``). Bound methods are accepted; their implicit
            ``self`` is not counted.
        block_size: cipher block size in characters (default 16).
        nested: distance in blocks between the block whose bytes are
            flipped and the target block. With 1 the flipped block directly
            precedes the target, as in plain CBC. NOTE(review): values > 1
            presumably target layered/nested CBC constructions; confirm.
        padding: function ``n -> str`` returning the valid padding that
            occupies the last ``n`` positions of a block, in order.
        output: stream for progress messages; any falsy value disables them.
    """

    def __init__(self, query, block_size=16, nested=1, padding=padding_PKCS5,
                 output=sys.stdout):
        self.err = None  # not used within this class; kept for callers
        self.nested = nested
        self.padding = padding
        self.output = output
        self.block_size = block_size

        # Normalise the oracle to the two-argument form query(iv, ct).
        argc = self._query_arg_count(query)
        if argc == 1:
            self.query = lambda iv, ct: query(iv + ct)
        elif argc == 2:
            self.query = query
        else:
            raise ValueError('Query function must take one/two arguments')

    @staticmethod
    def _query_arg_count(query):
        """Return how many positional arguments *query* takes from us.

        Uses ``__code__`` (Python 2.6+ and 3) rather than the Python-2-only
        ``func_code``. Bound methods expose their function's code, whose
        ``co_argcount`` includes ``self``, so one is subtracted for them.
        """
        code = getattr(query, '__code__', None)
        if code is None:
            raise ValueError('Query function must take one/two arguments')
        argc = code.co_argcount
        if getattr(query, '__self__', None) is not None:
            argc -= 1
        return argc

    def encrypt_block(self, bl, mid, pt):
        """Return the block that, placed before ``mid + bl``, decrypts ``bl`` to ``pt``.

        ``bl`` is first decrypted under a fixed dummy IV. In CBC the
        plaintext is that IV XOR a value fixed by ``bl`` (and ``mid``), so
        XOR-ing out the dummy IV and the recovered plaintext and XOR-ing in
        ``pt`` yields the IV that produces ``pt``.
        """
        assert len(bl) == self.block_size
        assert len(pt) == self.block_size
        iv = 'A' * self.block_size
        ptt = self.decrypt_block(iv=iv, ct=bl, mid=mid)
        return xor(iv, pt, ptt)

    def encrypt(self, pt):
        """Forge a ciphertext that decrypts to ``pt`` under the oracle's key.

        ``pt`` is padded with ``self.padding``, always adding at least one
        byte. Blocks are then forged from last to first: each plaintext
        block is paired with a known ciphertext block, and ``encrypt_block``
        computes the block that must precede it.

        Returns ``(iv, ct)``: the first block of the forged message and
        everything after it.
        """
        bs = self.block_size
        pad = bs - (len(pt) % bs)
        # Use the configured padding so the forgery is valid under the same
        # scheme the oracle checks. For the default PKCS padding this is
        # identical to chr(pad) * pad.
        pt = pt + self.padding(pad)
        assert len(pt) % bs == 0

        blocks = [pt[i:i + bs] for i in range(0, len(pt), bs)]
        assert len(blocks) > 0

        # Arbitrary seed blocks. The last one becomes the final ciphertext
        # block, and the forged predecessors are prepended as we go.
        mid = 'B' * bs * self.nested
        ct = ''
        for pblock in reversed(blocks):
            assert len(mid) == bs * self.nested
            bl = mid[-bs:]
            mid = mid[:-bs]
            iv = self.encrypt_block(bl=bl, pt=pblock, mid=mid)
            mid = iv + mid
            ct = bl + ct
        ct = mid + ct
        assert len(ct) == len(pt) + bs * self.nested
        return ct[:bs], ct[bs:]

    def decrypt(self, ct, iv=None):
        """Recover the plaintext of ``ct`` using only the padding oracle.

        ``ct`` may start with the IV, or the IV may be passed separately.
        The first ``nested`` blocks serve only as the flipped prefix and are
        not themselves decrypted. The result has its trailing padding
        removed according to the value of its last byte.
        """
        if iv is not None:
            ct = iv + ct
        bs = self.block_size
        assert iv is None or len(iv) == bs
        assert len(ct) > bs * self.nested
        assert len(ct) % bs == 0

        blocks = [ct[i:i + bs] for i in range(0, len(ct), bs)]
        parts = []
        for i in range(len(blocks) - self.nested):
            window = blocks[i:i + self.nested + 1]
            assert len(window) == self.nested + 1
            parts.append(self.decrypt_block(
                iv=window[0],
                ct=window[-1],
                mid=''.join(window[1:-1]),
            ))
        pt = ''.join(parts)
        # Slice by explicit length: pt[:-0] would wrongly return '' when the
        # last byte is \x00.
        return pt[:len(pt) - ord(pt[-1])]

    def decrypt_block(self, iv, ct, mid=''):
        """Recover the plaintext of block ``ct`` as if preceded by ``iv`` (and ``mid``).

        Bytes are recovered right to left. For position ``i`` the already
        recovered tail is forced to decrypt to the valid padding for length
        ``block_size - i``. Every value of byte ``i`` is then tried until the
        oracle accepts.

        Raises AssertionError if no value of some byte is accepted. The error
        is raised explicitly so it still fires under ``python -O``.
        """
        bs = self.block_size
        assert len(iv) == bs
        assert len(ct) == bs
        assert len(mid) == (self.nested - 1) * bs

        def probe(head, block):
            return self.query(head, mid + block)

        pt = ''
        for i in range(bs - 1, -1, -1):
            # Padding for positions i..bs-1: pad[0] is the target for byte i,
            # pad[1:] for the recovered tail. (Equivalent to the old
            # pad[-1] / pad[:-1] split only for uniform paddings like PKCS.)
            pad = self.padding(bs - i)
            # Make the recovered tail decrypt to pad[1:]. This is invariant
            # for the inner loop.
            tail = xor(pad[1:], pt, iv[i + 1:])
            for val in range(0x100):
                iv_flipped = iv[:i] + xor(iv[i], chr(val)) + tail
                assert len(iv_flipped) == bs
                guess = chr(val ^ ord(pad[0])) + pt

                if self.output:
                    self._report_progress(i, guess)

                if not probe(iv_flipped, ct):
                    continue
                if i == bs - 1:
                    # A hit on the last byte can also come from a longer
                    # valid padding (e.g. ...\x02\x02 instead of \x01).
                    # Perturb the second-to-last byte and require the
                    # padding to stay valid.
                    perturbed = xor(
                        iv_flipped,
                        '\x00' * (bs - 2) + '\x01\x00',
                    )
                    if not probe(perturbed, ct):
                        continue
                pt = guess
                break
            else:
                raise AssertionError('all 256 xored values failed')

        if self.output:
            self.output.write('\n')
            self.output.flush()
        return pt

    def _report_progress(self, i, guess):
        """Write a carriage-return-terminated status line for byte ``i``."""
        shown = repr_ascii(guess)
        fill = ' ' * (2 * self.block_size - len(shown))
        # Portable replacement for the Python-2-only str.encode('hex').
        hexed = ''.join('%02x' % ord(c) for c in guess)
        self.output.write('byte %2d, pt %s : %s%s\r' % (
            i,
            hexed.rjust(self.block_size * 2, '?'),
            shown,
            fill,
        ))
        self.output.flush()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment