Skip to content

Instantly share code, notes, and snippets.

@rot256
Created May 14, 2019 15:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rot256/f62392872c9199126fcd6a4064d1247f to your computer and use it in GitHub Desktop.
Save rot256/f62392872c9199126fcd6a4064d1247f to your computer and use it in GitHub Desktop.
Simple CBC Padding Oracle Library
import sys
# Characters repr_ascii() passes through unchanged: printable ASCII
# (0x20-0x7e) plus tab and newline.  Built with a set union rather than
# list concatenation so it also works where map() returns an iterator
# (Python 3).
ASCII = set(map(chr, range(0x20, 0x7f))) | set(['\t', '\n'])
def xor(*args):
    """Return the character-wise XOR of one or more strings.

    Character ``k`` of the result is ``chr(ord(a[k]) ^ ord(b[k]) ^ ...)``
    over all arguments.  Like ``zip``, the result is truncated to the
    length of the shortest argument.  A single argument is returned as an
    equal copy.

    Raises:
        TypeError: if called with no arguments.
    """
    if not args:
        raise TypeError('xor() needs at least one argument')
    out = []
    # zip(*args) yields one tuple per position, truncated to the shortest
    # input -- the same truncation the original pairwise recursion gave.
    for column in zip(*args):
        acc = 0
        for c in column:
            acc ^= ord(c)
        out.append(chr(acc))
    return ''.join(out)
def repr_ascii(s):
    """Return a printable, quote-stripped ``repr`` of ``s`` for progress lines.

    Every character not in ``ASCII`` is first replaced by ``'!'``, so
    arbitrary binary bytes never show up as escape sequences.  Tab and
    newline are kept and therefore appear escaped by ``repr`` (e.g. ``\\t``).
    """
    # ''.join instead of repeated += avoids quadratic string building.
    masked = ''.join(c if c in ASCII else '!' for c in s)
    return repr(masked)[1:-1]
def padding_PKCS5(n):
    """Return a PKCS#5/#7-style pad of length ``n``: ``n`` copies of ``chr(n)``."""
    return ''.join([chr(n)] * n)
class PaddingOracle:
    """Drive a CBC padding-oracle attack through a user-supplied oracle.

    ``decrypt`` recovers plaintext one byte at a time by forging the block
    that acts as IV for a target block and asking the oracle whether the
    result has valid padding.  ``encrypt`` inverts that to forge a
    ciphertext that decrypts to a chosen plaintext.

    Args:
        query: oracle callable returning a truthy value when the padding of
            the submitted ciphertext is valid.  It may take one argument
            (``iv + ct`` concatenated) or two (``iv, ct``).  For a bound
            method, ``self`` is not counted.
        block_size: cipher block size in characters (bytes).
        nested: distance, in blocks, between the block being forged and the
            block being attacked; 1 is ordinary CBC (the forged block is
            the immediately preceding one).  Intermediate blocks are passed
            through unchanged -- NOTE(review): presumably for layered/
            chained constructions; confirm against the target.
        padding: function mapping a pad length ``n`` to the padding string
            the oracle accepts for that length.
        output: writable stream for progress output; falsy disables it.

    Raises:
        ValueError: if ``query`` takes neither one nor two arguments.
    """

    def __init__(self, query, block_size=16, nested=1, padding=padding_PKCS5, output = sys.stdout):
        self.err = None  # not read anywhere in this class; kept for callers
        self.nested = nested
        self.padding = padding
        self.output = output
        self.block_size = block_size

        # Dispatch on how many positional arguments the oracle takes.
        # ``__code__`` exists on Python 2.6+ and 3 (``func_code`` is 2-only).
        # A bound method already has ``self`` supplied, so don't count it.
        func = getattr(query, '__func__', query)
        argcount = func.__code__.co_argcount
        if getattr(query, '__self__', None) is not None:
            argcount -= 1

        if argcount == 1:
            self.query = lambda iv, ct: query(iv + ct)
        elif argcount == 2:
            self.query = query
        else:
            raise ValueError('Query function must take one/two arguments')

    def encrypt_block(self, bl, mid, pt):
        """Return a forged block that makes ``bl`` decrypt to ``pt``.

        Decrypts ``bl`` under a dummy all-'A' IV (with ``mid`` between) and
        XORs the result back out, so that using the returned block in the
        IV position yields ``pt``.
        """
        assert len(bl) == self.block_size
        assert len(pt) == self.block_size
        # ensure that dec(bl) -> pt
        iv = 'A' * self.block_size
        ptt = self.decrypt_block(iv = iv, ct = bl, mid = mid)
        return xor(iv, pt, ptt)

    def encrypt(self, pt):
        """Forge a ciphertext that the oracle's decryptor turns into ``pt``.

        ``pt`` is padded with ``self.padding`` (1..block_size characters).
        Blocks are forged from the last one backwards, starting from
        arbitrary 'B' filler blocks.

        Returns:
            ``(iv, ct)``: the first block, and the remaining blocks.
        """
        pad = self.block_size - (len(pt) % self.block_size)
        # Use the configured padding so encrypt agrees with the padding the
        # oracle (and decrypt_block) is set up for.
        pt = pt + self.padding(pad)
        mid = 'B' * self.block_size * self.nested
        ct = ''
        assert len(pt) % self.block_size == 0
        bs = [
            pt[i:i + self.block_size] for i in range(0, len(pt), self.block_size)
        ]
        assert len(bs) > 0
        for pblock in bs[::-1]:
            assert len(mid) == self.block_size * self.nested
            # The last pending block becomes ciphertext; forge the block
            # ``nested`` positions earlier so it decrypts to ``pblock``.
            bl = mid[-self.block_size:]
            mid = mid[:-self.block_size]
            iv = self.encrypt_block(bl = bl, pt = pblock, mid = mid)
            mid = iv + mid
            ct = bl + ct
        ct = mid + ct
        assert len(ct) == len(pt) + self.block_size * self.nested
        return ct[:self.block_size], ct[self.block_size:]

    def decrypt(self, ct, iv = None):
        """Recover and unpad the plaintext of ``ct`` using the oracle.

        Args:
            ct: ciphertext; a multiple of ``block_size`` and longer than
                ``nested`` blocks.  If ``iv`` is None, its first block(s)
                serve as the IV.
            iv: optional IV block, prepended to ``ct``.

        Returns:
            The plaintext with trailing padding stripped; the pad length is
            read from the last plaintext character (PKCS-style).
        """
        if iv is not None:
            ct = iv + ct
        assert iv is None or len(iv) == self.block_size
        assert len(ct) > self.block_size * self.nested
        assert len(ct) % self.block_size == 0
        blocks = [
            ct[i:i + self.block_size] for i in range(0, len(ct), self.block_size)
        ]
        pt = ''
        # Block i is the "IV" for block i + nested; anything between them
        # is passed through unchanged as ``mid``.
        for i in range(0, len(blocks) - self.nested):
            bs = blocks[i:i + self.nested + 1]
            assert len(bs) == self.nested + 1
            pt += self.decrypt_block(
                iv = bs[0],
                ct = bs[-1],
                mid = ''.join(bs[1:-1])
            )
        return pt[:-ord(pt[-1])]

    def _show_progress(self, index, known):
        """Write a carriage-return-terminated progress line, if output is on."""
        if not self.output:
            return
        shown = repr_ascii(known)
        filler = ' ' * (2 * self.block_size - len(shown))
        # Portable hex encoding (str.encode('hex') is Python-2-only).
        hexed = ''.join('%02x' % ord(c) for c in known)
        self.output.write('byte %2d, pt %s : %s%s\r' % (
            index,
            hexed.rjust(self.block_size * 2, '?'),
            shown,
            filler,
        ))
        self.output.flush()

    def decrypt_block(self, iv, ct, mid = ''):
        """Recover the plaintext of one block ``ct`` under IV block ``iv``.

        Bytes are recovered from last to first.  For byte ``i`` the IV is
        forged so that the already-known bytes decrypt to the first part of
        a ``block_size - i`` padding.  All 256 values of byte ``i`` are then
        tried until the oracle reports valid padding.

        Raises:
            AssertionError: if no value of some byte is accepted by the
                oracle.  This is raised explicitly, so it also fires under
                ``python -O``.
        """
        assert len(iv) == self.block_size
        assert len(ct) == self.block_size
        assert len(mid) == (self.nested - 1) * self.block_size

        def query(forged_iv, b2):
            return self.query(forged_iv, mid + b2)

        # Mask flipping only the second-to-last byte, used to reject false
        # positives on the last byte (see below).
        second_last_flip = '\x00' * (self.block_size - 2) + '\x01\x00'

        pt = ''
        for i in range(self.block_size - 1, -1, -1):
            # Both are invariant across the 256 guesses for this byte.
            pad = self.padding(self.block_size - i)
            tail = xor(pad[:-1], pt, iv[i + 1:])
            for val in range(0x100):
                iv_flipped = iv[:i] + xor(iv[i], chr(val)) + tail
                assert len(iv_flipped) == self.block_size
                guess = chr(val ^ ord(pad[-1])) + pt
                self._show_progress(i, guess)

                if not query(iv_flipped, ct):
                    continue
                # On the last byte, a hit may come from the plaintext
                # already ending in a longer valid pad (e.g. ..\x02\x02)
                # instead of a lone \x01.  Flipping the second-to-last byte
                # breaks such pads but keeps \x01 valid (for PKCS-style
                # padding).
                if i == self.block_size - 1 and self.block_size > 1:
                    if not query(xor(iv_flipped, second_last_flip), ct):
                        continue
                pt = guess
                break
            else:
                raise AssertionError('all 256 xored values failed')

        if self.output:
            self.output.write('\n')
            self.output.flush()
        return pt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment