Skip to content

Instantly share code, notes, and snippets.

@elisohl-ncc
Created June 14, 2023 00:18
Show Gist options
  • Save elisohl-ncc/0ffced62c5fb4afcdcd05bb925f73d5d to your computer and use it in GitHub Desktop.
Save elisohl-ncc/0ffced62c5fb4afcdcd05bb925f73d5d to your computer and use it in GitHub Desktop.
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from random import random, seed, randrange
from math import log2
# Fixed demo key (16 zero bytes) and IV (bytes 0..15); not secret, chosen
# only so the demo is self-contained.
KEY = bytes(16)
IV = bytes(range(16))
PT = b'Plaintext!'
# CBC ciphertext of PT after padding it to the 16-byte block size; this is
# what test_padding_oracle_attack() tries to decrypt.
CT = AES.new(key=KEY, iv=IV, mode=AES.MODE_CBC).encrypt(pad(PT, block_size=16))
# Oracle noise model: probability that a "should be False" answer comes back
# True (false positive) and that a "should be True" answer comes back False
# (false negative).
FP_RATE = FN_RATE = 0.4
# Confidence at which the single-byte search demo accepts a candidate.
target_confidence_threshold = 0.9999
# Uncomment to make the random oracle noise reproducible between runs.
#seed(b'bayes'*17)
# ByteSearch.search() appends the number of oracle queries it used here.
num_queries = []
class ByteSearch:
    """Find the one byte value (0-255) that a noisy boolean oracle accepts.

    A probability ("confidence") is kept for each of the 256 candidates,
    starting from a uniform prior.  Every oracle answer is folded in with
    Bayes' rule, using the module-level FP_RATE / FN_RATE as the oracle's
    error model, until one candidate reaches the confidence threshold.
    """

    def __init__(self, oracle, confidence_threshold=0.9, quiet=True):
        """Create a search with a uniform prior over all 256 byte values.

        oracle: callable taking a candidate byte value and returning a bool
        confidence_threshold: search() stops once a candidate reaches this
        quiet: when False, query_byte() prints a "." every 256 queries
        """
        self._counter = 0  # number of oracle queries issued so far
        self.oracle = oracle
        self.queries = [[] for _ in range(256)]  # answers seen, per candidate
        self.confidences = [1/256]*256
        self.confidence_threshold = confidence_threshold
        self.quiet = quiet

    def update_confidences(self, index, result):
        """Given an oracle result for a given byte, update the confidences for each byte."""
        self.confidences = self.get_updated_confidences(self.confidences, index, result)

    def pick_exhaustive(self):
        """Pick bytes round-robin (0, 1, ..., 255, 0, ...) by query count."""
        return self._counter % 256

    def pick_by_confidence(self):
        """Pick a byte to test based on the current confidences."""
        return max(range(256), key=self.confidences.__getitem__)

    def pick_by_entropy(self):
        """Pick the byte whose query minimizes the expected posterior entropy."""
        # NOTE: VERY SLOW - for demo, try replacing 256 with 16 here and in randrange
        prior = self.confidences
        entropies = []
        for i in range(256):
            e_if_t = self.get_entropy(self.get_updated_confidences(prior, i, True))
            e_if_f = self.get_entropy(self.get_updated_confidences(prior, i, False))
            # The expectation is over the oracle's *answer*, so weight by the
            # probability that the noisy oracle says True for byte i (a true
            # positive or a false positive), not by the prior of byte i.
            p_t = prior[i] * (1 - FN_RATE) + (1 - prior[i]) * FP_RATE
            p_f = 1 - p_t
            entropies.append(p_t * e_if_t + p_f * e_if_f)
        return min(range(256), key=entropies.__getitem__)

    def query_byte(self, index):
        """Query the oracle for a given byte, record the answer and update the confidences."""
        self._counter += 1
        result = self.oracle(index)
        self.queries[index].append(result)
        self.update_confidences(index, result)
        if not self.quiet and self._counter & 0xFF == 0:
            # Progress marker every 256 queries.
            print(end=".", flush=True)
        return result

    def search(self, strategy):
        """Search for the plaintext byte by querying the oracle.

        strategy: zero-argument callable returning the next byte to query
            (e.g. one of this object's pick_* methods).

        Returns the most probable byte once its confidence reaches the
        threshold.  Also appends the number of queries made to the
        module-level num_queries list.
        """
        threshold = self.confidence_threshold
        while max(self.confidences) < threshold:
            self.query_byte(strategy())
        num_queries.append(sum(len(answers) for answers in self.queries))
        return self.pick_by_confidence()

    @staticmethod
    def bayes(h, e_given_h, e_given_not_h):
        """Update the posterior probability of h given e.

        e: evidence
        h: hypothesis (prior probability)
        e_given_h: probability of e given h
        e_given_not_h: probability of e given not h
        """
        return e_given_h * h / (e_given_h * h + e_given_not_h * (1 - h))

    @staticmethod
    def get_updated_confidences(confidences, index, result):
        """Return the posterior after the oracle answered `result` for byte `index`.

        confidences: current probability of each of the 256 candidates
        index: the byte value that was queried
        result: the oracle's (possibly wrong) boolean answer

        The input list is left untouched; a new list is returned.
        """
        # Likelihood of this answer when the queried byte is / is not the target.
        p_e_if_target = 1 - FN_RATE if result else FN_RATE
        p_e_if_other = FP_RATE if result else 1 - FP_RATE
        new_confidences = confidences[:]  # shallow copy
        for j in range(256):
            p_h = confidences[j]
            if index == j:
                p_e_given_h = p_e_if_target
                p_e_given_not_h = p_e_if_other
            else:
                p_e_given_h = p_e_if_other
                # If j is not the target, the queried byte is the target with
                # its prior renormalized over the remaining probability mass.
                # When j is already certain there is no remaining mass; the
                # term is then multiplied by (1 - p_h) == 0 in bayes(), so
                # any finite value works and we avoid dividing by zero.
                p_not_hj = 1 - p_h
                p_hi_given_not_hj = confidences[index] / p_not_hj if p_not_hj else 0.0
                p_not_hi_given_not_hj = 1 - p_hi_given_not_hj
                p_e_given_not_h = p_hi_given_not_hj * p_e_if_target + p_not_hi_given_not_hj * p_e_if_other
            new_confidences[j] = ByteSearch.bayes(p_h, p_e_given_h, p_e_given_not_h)
        return new_confidences

    @staticmethod
    def get_entropy(dist):
        """Return the Shannon entropy of `dist` in bits (zero entries are skipped)."""
        return -sum(p * log2(p) for p in dist if p)
#### BASIC SINGLE BYTE SEARCH TEST
def test_single_byte_search():
    """Benchmark ByteSearch on a synthetic noisy oracle, round after round.

    Each round draws a random target byte, recovers it with the entropy
    strategy and prints the running accuracy and average query count.
    The loop has no exit condition; it runs until interrupted.
    """
    def oracle(index):
        # Truthful answer, flipped with probability FN_RATE for the target
        # byte and FP_RATE for every other byte.
        if index == TARGET_BYTE:
            return random() > FN_RATE
        return not (random() > FP_RATE)

    def attack():
        search = ByteSearch(oracle, confidence_threshold=target_confidence_threshold, quiet=False)
        result = search.search(search.pick_by_entropy)
        print()
        # Per-candidate query counts, largest first.
        print(*sorted([len(answers) for answers in search.queries], reverse=True))
        return result

    correct = total = 0
    # ByteSearch.search() appends to the module-level num_queries list.
    # Reset that list in place: rebinding a local of the same name here would
    # shadow it, leave the local empty, and make the average below divide by
    # zero.
    num_queries.clear()
    while True:
        TARGET_BYTE = randrange(256)
        result = attack()
        total += 1
        if result == TARGET_BYTE:
            correct += 1
        accuracy = correct / total
        avg_queries = sum(num_queries) / len(num_queries)
        print(f"{accuracy=:.5f}\t{avg_queries=:.1f}\t{num_queries[-1]=}\n")
# Full Bayesian padding oracle attack
def test_padding_oracle_attack():
    """Run the Bayesian padding-oracle attack against CT, over and over.

    Each round recovers the padded plaintext of CT through an oracle whose
    answers are wrong with probability FP_RATE / FN_RATE, then prints the
    recovered bytes, the number of oracle queries used and the running
    average.  The loop has no exit condition; it runs until interrupted.
    """
    def oracle(iv, ct):
        # Exact padding oracle: True when ct decrypts under iv to data that
        # unpad() accepts, False when unpad() raises ValueError.
        aes = AES.new(key=KEY, iv=iv, mode=AES.MODE_CBC)
        try:
            unpad(aes.decrypt(ct), 16)
        except ValueError:
            return False
        return True

    # One-element list so the nested function can update the counter.
    ORACLE_QUERIES = [0]

    def worse_oracle(iv, ct):
        # Noisy wrapper: counts the query, then flips a correct True with
        # probability FN_RATE and a correct False with probability FP_RATE.
        ORACLE_QUERIES[0] += 1
        result = oracle(iv, ct)
        if result:
            return random() > FN_RATE
        return not (random() > FP_RATE)

    def attack(block, quiet=False):
        """Recover the 16 values that, XORed with the real IV, give the plaintext.

        block: ciphertext to attack (assumed to be a single 16-byte block,
            as CT is here).
        quiet: when False, print each pad length as its byte is recovered.
        """
        D_k = [0]*16
        pad_len = 1
        while pad_len <= 16:
            # Forged IV layout: zero filler, the byte under test, then bytes
            # that force the already-recovered tail to decrypt to pad_len.
            prefix = [0] * (16-pad_len)
            postfix = [pad_len ^ i for i in D_k[17-pad_len:]]

            # scan through candidate bytes
            def wrapped_oracle(ind, double_up=False):  # TODO double_up is dumb, just replace this with a pad_len == 1 test and make sure that works
                iv = bytes(prefix + [ind] + postfix)
                if not worse_oracle(iv, block):
                    return False
                if double_up and prefix:
                    # Re-query with the byte just before the tested position
                    # flipped.  Work on a copy: mutating `prefix` itself
                    # would leak the flip into every later query, and there
                    # is no preceding byte at all once prefix is empty.
                    flipped_prefix = prefix[:-1] + [prefix[-1] ^ 1]
                    iv_2 = bytes(flipped_prefix + [ind] + postfix)
                    if not worse_oracle(iv_2, block):
                        return False
                return True

            search = ByteSearch(wrapped_oracle, confidence_threshold=0.999)
            result = search.search(search.pick_by_entropy)
            # The accepted IV byte XOR the pad value is the intermediate byte.
            D_k[-pad_len] = result ^ pad_len
            if not quiet: print(end=f"{pad_len} ", flush=True)
            pad_len += 1  # TODO add support for backtracking? or don't?
        if not quiet: print()
        return D_k

    query_counts = []
    while True:
        ORACLE_QUERIES[0] = 0
        print(bytes(a ^ b for a, b in zip(IV, attack(CT))))
        print(f"{ORACLE_QUERIES[0]=}")
        query_counts.append(ORACLE_QUERIES[0])
        AVG_ORACLE_QUERIES = sum(query_counts) / len(query_counts)
        print(f"{AVG_ORACLE_QUERIES=}")
        print()
# Script entry point: run the full padding-oracle demo, which loops until
# interrupted.
if __name__ == "__main__":
    test_padding_oracle_attack()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment