
@chinchila
Last active December 3, 2021 09:15
Pythia writeup from Chinchila

Writeup Pythia

Description:

Solves: 65

Yet another oracle, but the queries are costly and limited so be frugal with them.

About the challenge

We are given a service together with its source, 1server.py. Looking through the source, we can see a list called passwords that holds 3 random passwords, each made of 3 lowercase letters. There are 4 options to interact with the service:

  1. Set key: change the current encryption key. We send an integer (0-2) with the index of the password that option 3 should use.
  2. Read flag: reads our input and compares it with the 3 passwords concatenated. If they match, the server responds with the flag.
  3. Decrypt text: tries to decrypt a ciphertext with a nonce, using the password selected with option 1. The cipher is AES-GCM, and the AES key is derived from the password with a KDF (Scrypt). The server only tells us whether the decryption succeeded or not.
  4. Exit

Another important detail: the service only allows 150 queries in total, and every decryption or flag check sleeps for 10 seconds. Each password has 26^3 = 17576 candidates, so a naive brute force over the passwords won't work.
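To put numbers on it, here is a quick back-of-the-envelope calculation. It uses only the limits from the server source (150 queries, a 10 second sleep per decryption) and assumes the worst case, where a naive attack tries every candidate, one query each.

```python
# Back-of-the-envelope cost of brute-forcing the passwords through the oracle.
# Limits from the server source: 150 queries total, 10 s sleep per decryption.
MAX_QUERIES = 150
QUERY_DELAY_S = 10

candidates_per_password = 26 ** 3                 # 3 lowercase letters
worst_case_queries = 3 * candidates_per_password  # one query per candidate, 3 passwords
worst_case_hours = worst_case_queries * QUERY_DELAY_S / 3600

print("candidates per password:", candidates_per_password)
print("worst-case queries:", worst_case_queries, "budget:", MAX_QUERIES)
print("worst-case time: %.1f hours" % worst_case_hours)
```

Even without the query cap, that is roughly six days of waiting on the sleep alone.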

At first I was like "There is no way we can solve this". Then I started thinking about known AES-GCM vulnerabilities, but none of them led me to a solution. After a few failures, I started searching for weaknesses in GCM itself, I mean mathematical problems in its design, but nothing helped.

Then a teammate from ELT sent a Reddit post to our team's Discord server. The post explains a real-world case of the multi-collision keys vulnerability, and it links to a paper called "Partitioning Oracle Attacks" that looked really promising. When I read about the multi-collision function for AES-GCM, I immediately had a search idea and typed on Discord: "What if we just compute a ciphertext that is accepted by half of the key space and check with the oracle whether it decrypts? If it decrypts, we know the key is in this half; if not, we know it is in the other half. This gives us a log2(26^3) approach to solve the problem!" That is about 15 queries per password, since 2^14 < 26^3 <= 2^15.
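For reference, here is the algebra that makes this possible. It is written from the standard GCM tag equation for a 96-bit nonce and no associated data, and it mirrors the comments in multi_collide_gcm in the solve script. All arithmetic is in GF(2^128), where addition is XOR, so plus and minus are the same.

```latex
\begin{aligned}
T &= \sum_{j=1}^{m} C_j H^{m+2-j} + L\,H + P,
  \qquad H = E_K(0^{128}),\quad P = E_K(N \,\|\, 0^{31} \,\|\, 1) \\
f(x) &= \sum_{j=1}^{k} C_j\, x^{k-j} \qquad (m = k \text{ blocks for } k \text{ keys}) \\
f(H_i) &= (L\,H_i + P_i + T)\, H_i^{-2}, \qquad i = 1, \dots, k
\end{aligned}
```

Interpolating the k points (H_i, f(H_i)) gives a polynomial of degree at most k-1. Its coefficients, taken from the highest degree down, are the blocks C_1, ..., C_k, and the ciphertext C_1...C_k with tag T verifies under every key in the set. Under a key outside the set, the tag matches only with probability about 2^-128, so a failed decryption means the key is in the other part. When multi_collide_gcm is given a fixed first_block C_1, it uses m = k + 1 blocks and moves C_1 H^{k+2} to the right-hand side. That is the (lens * H) + P + T + (C1 * H^{k+2}) term in its comments.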

Luckily, the authors already have an implementation of the attack, available in Julia Len's repository. With that script, coding the search algorithm to find each password was almost easy.
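To sanity-check that algebra without Sage or AES, the toy sketch below works directly in GF(2^128) with the GCM modulus x^128 + x^7 + x^2 + x + 1. It makes several simplifying assumptions:

- Random field elements stand in for H_i = E_K(0) and P_i = E_K(J0).
- Integers use plain bit order (bit i is the coefficient of x^i), not GCM's reflected byte encoding.
- The lengths block is just a fixed stand-in value.
- All helper names are hypothetical, not taken from Julia Len's code.

The sketch interpolates one ciphertext for a handful of fake keys and checks that the GHASH-based tag matches under every one of them.

```python
import random

# GCM field: GF(2^128) modulo x^128 + x^7 + x^2 + x + 1.
# Toy encoding: bit i of an int is the coefficient of x^i (not GCM's reflected byte order).
MODULUS = (1 << 128) | 0x87


def gf_mul(a, b):
    """Shift-and-add multiplication in GF(2^128), reducing as we go."""
    result = 0
    while b:
        if b & 1:
            result ^= a
        b >>= 1
        a <<= 1
        if a >> 128:
            a ^= MODULUS
    return result


def gf_inv(a):
    """Multiplicative inverse as a^(2^128 - 2); a must be nonzero."""
    result, base, e = 1, a, (1 << 128) - 2
    while e:
        if e & 1:
            result = gf_mul(result, base)
        base = gf_mul(base, base)
        e >>= 1
    return result


def lagrange(xs, ys):
    """Coefficients (lowest degree first) of the polynomial f with f(xs[i]) == ys[i]."""
    n = len(xs)
    coeffs = [0] * n
    for i in range(n):
        num, denom = [1], 1
        for j in range(n):
            if j == i:
                continue
            # num *= (x + xs[j]); in characteristic 2, x - c is the same as x + c
            nxt = [0] * (len(num) + 1)
            for d, c in enumerate(num):
                nxt[d + 1] ^= c
                nxt[d] ^= gf_mul(c, xs[j])
            num = nxt
            denom = gf_mul(denom, xs[i] ^ xs[j])
        scale = gf_mul(ys[i], gf_inv(denom))
        for d in range(n):
            coeffs[d] ^= gf_mul(num[d], scale)
    return coeffs


def tag_for(blocks, lens, h, p):
    """GHASH (Horner form) over the blocks and lengths block, plus P = E_K(J0)."""
    acc = 0
    for block in blocks + [lens]:
        acc = gf_mul(acc ^ block, h)
    return acc ^ p


rng = random.Random(1)
k = 6                                   # number of fake keys in the partition
hs = []
while len(hs) < k:                      # distinct nonzero stand-ins for H_i
    h = rng.getrandbits(128)
    if h and h not in hs:
        hs.append(h)
ps = [rng.getrandbits(128) for _ in range(k)]    # stand-ins for P_i
T = int.from_bytes(b"\x01" * 16, "big")          # fixed tag, like the solve script
L = k * 128                                      # stand-in lengths block: k blocks, no AD

# Interpolation points (H_i, (L*H_i + P_i + T) * H_i^-2)
ys = [gf_mul(gf_mul(L, h) ^ p ^ T, gf_inv(gf_mul(h, h))) for h, p in zip(hs, ps)]
blocks = list(reversed(lagrange(hs, ys)))        # C_1 = highest-degree coefficient

accepted = [tag_for(blocks, L, h, p) == T for h, p in zip(hs, ps)]
outsider = tag_for(blocks, L, rng.getrandbits(128), rng.getrandbits(128)) == T
print("tag valid under every key in the set:", all(accepted))
print("tag valid under an outside key:", outsider)
```

This naive interpolation costs O(k^2) field multiplications. That is one reason why colliding thousands of keys in a single ciphertext gets slow.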

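By "almost easy" I mean that we can't just run a plain binary search over the whole key space. The first query alone would need a ciphertext that collides under 26^3/2 = 8788 keys, and computing a multi-collision ciphertext for that many keys takes a very long time: the ciphertext has one 16-byte block per key, and the interpolation gets much slower as the key set grows.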

A pretty straightforward solution was to split the keys into windows: one ciphertext for the first 512 keys, another for the next 512, and so on. With this approach we can precompute all the window ciphertexts once. We then send them to the server one at a time to find the window that contains the key, and binary search inside that window, where each ciphertext only has to cover at most 256 keys.
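A rough query budget for this scheme follows. It is a loose upper bound: it assumes we scan every window and then bisect a full 512-key window. The window count is rounded up on purpose. With len(pos)//block_size alone, the last 168 keys (passwords "zto" to "zzz") would get no window.

```python
import math

N_KEYS = 26 ** 3          # candidate keys per password
BLOCK_SIZE = 512          # keys per precomputed window, as in the solve script

# Round the window count up so the last, partial window is not dropped.
windows = [(l, min(l + BLOCK_SIZE, N_KEYS)) for l in range(0, N_KEYS, BLOCK_SIZE)]

# Loose upper bound per password: scan every window, then bisect a full window.
per_password = len(windows) + math.ceil(math.log2(BLOCK_SIZE))
# Three passwords, plus two "Set key" queries and the final "Read flag" query.
total = 3 * per_password + 2 + 1

print("windows:", len(windows), "last:", windows[-1])
print("upper bound per password:", per_password)
print("upper bound overall:", total, "of 150 queries")
```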

The first thing we have to do is generate all 26^3 candidate passwords and derive their keys. This is just a loop over all 3-letter lowercase strings, using the same Scrypt parameters as the server:

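import string
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt

dic = {}  # derived key -> password
pos = []  # derived keys, in lexicographic password order

for i in string.ascii_lowercase:
    for j in string.ascii_lowercase:
        for k in string.ascii_lowercase:
            # Same Scrypt parameters (and empty salt) as the server
            kdf = Scrypt(salt=b'', length=16, n=2**4, r=8, p=1, backend=default_backend())
            passw = (i+j+k).encode()
            key = kdf.derive(passw)
            dic[key] = passw
            pos.append(key)
print("Generated passwords")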

Next, I wrote the code that generates the window ciphertexts:

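def generate_cts():
    # One multi-collision ciphertext per window of block_size keys.
    # Round up so the last, partial window (only 168 keys) is covered too.
    n_windows = (len(pos) + block_size - 1) // block_size
    with open("cts", "w") as fl:
        for i in range(n_windows):
            l = i*block_size
            r = min((i+1)*block_size, len(pos))
            ct = multi_collide_gcm(pos[l:r], nonce, tag)
            # One line per window: "<l> <r> <hex of ciphertext+tag>"
            fl.write(str(l) + " " + str(r) + " " + ct.hex() + "\n")
    print("Generated ciphertexts")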

We save all the ciphertexts to a text file so we don't have to recompute them every time we run the script. Then we search for the window that contains the password:

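import base64

def ok(ct):
    # Ask the oracle (option 3) to decrypt ct under the currently selected key
    payload = base64.b64encode(nonce)+b","+base64.b64encode(ct)
    rr.recvuntil(b">>> ")
    rr.sendline(b"3")
    rr.recvuntil(b">>> ")
    rr.sendline(payload)
    rr.recvline()  # "Decrypting..."
    # The next line is either "ERROR: ..." or "Decryption successful"
    if b"ERROR" in rr.recvline():
        return False
    else:
        return True

# Scan the precomputed windows until one decrypts under the secret key
for k in range(len(cts)):
    if ok(cts[k][2]):
        l = cts[k][0]
        r = cts[k][1]
        print("Found password at [", l,",", r,"]")
        break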

And after finding the window, we can run our binary search with the following code:

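# Invariant: the secret key is one of pos[l:r]
while r-l > 1:
    mid = (l+r)//2
    keyset1 = pos[l:mid]
    # Ciphertext that decrypts under every key in the left half
    ct1 = multi_collide_gcm(keyset1, nonce, tag, first_block=first_block)
    if ok(ct1):
        r = mid  # key is in pos[l:mid]
    else:
        l = mid  # key is in pos[mid:r]
key = dic[pos[l]]  # only one candidate left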
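The search logic can be checked offline by swapping the server for a stand-in oracle. In this hypothetical harness, make_oracle and find_index are illustrative names, not part of the solve script. The oracle only answers whether the secret index lies in the queried range [l, r), which is exactly what a multi-collision ciphertext for pos[l:r] tells us. No crypto is involved, and the harness assumes the rounded-up windows (last window partial).

```python
def make_oracle(secret_index):
    """Stand-in for the server: True iff the hidden key index is in [l, r)."""
    calls = {"n": 0}

    def ok(l, r):
        calls["n"] += 1
        return l <= secret_index < r

    return ok, calls


def find_index(ok, n_keys, block_size=512):
    # Phase 1: linear scan over the windows (the last one may be partial).
    for l in range(0, n_keys, block_size):
        r = min(l + block_size, n_keys)
        if ok(l, r):
            break
    else:
        raise ValueError("no window accepted the ciphertext")
    # Phase 2: bisection, keeping the invariant that the key is in [l, r).
    while r - l > 1:
        mid = (l + r) // 2
        if ok(l, mid):
            r = mid
        else:
            l = mid
    return l


N = 26 ** 3
worst = 0
for secret in range(N):
    ok, calls = make_oracle(secret)
    assert find_index(ok, N) == secret
    worst = max(worst, calls["n"])
print("all", N, "indices recovered; worst-case queries per password:", worst)
```

The exact worst case is 43 queries per password: 34 windows plus 9 bisection steps, or all 35 windows plus at most 8 steps in the 168-key tail. Three passwords, the two key changes and the flag check therefore fit in the 150-query limit.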

With all these parts we can write a 2solve.sage script that interacts with the server and recovers all 3 passwords. One more detail I haven't mentioned: the KDF always uses the same (empty) salt, so each password always derives the same AES key. That is what lets us derive every candidate key and precompute the window ciphertexts offline. With that, the attack is complete and we can run the solution. Here is the output of the solve script:

Generated passwords
Loaded ciphertexts from file
Found password at [ 13312 , 13824 ]
First b'tte'
Found password at [ 7680 , 8192 ]
Second b'mcw'
Found password at [ 5120 , 5632 ]
Third b'hth'
b'ACCESS GRANTED: CTF{gCm_1s_n0t_v3ry_r0bust_4nd_1_sh0uld_us3_s0m3th1ng_els3_h3r3}\n'

real	14m56,182s
user	0m17,513s
sys	0m0,897s
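As a quick consistency check on that output: the three nested loops that build pos run in lexicographic order, so a password's index in pos is simply its base-26 value (a = 0, ..., z = 25). The helpers below are a hypothetical illustration, not part of the solve script. They recompute the 512-key window for each recovered password and compare it with the window reported above.

```python
import string

ALPHABET = string.ascii_lowercase
N_KEYS = 26 ** 3


def password_index(pw):
    """Index of a 3-letter password in pos (lexicographic order = base 26)."""
    idx = 0
    for ch in pw:
        idx = idx * 26 + ALPHABET.index(ch)
    return idx


def window_of(index, block_size=512):
    """[l, r) bounds of the window containing index (last window is partial)."""
    l = (index // block_size) * block_size
    return l, min(l + block_size, N_KEYS)


# Windows as reported in the solve output
reported = {"tte": (13312, 13824), "mcw": (7680, 8192), "hth": (5120, 5632)}
for pw, bounds in reported.items():
    idx = password_index(pw)
    print(pw, idx, window_of(idx), window_of(idx) == bounds)
```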

Finally we got it!

1server.py

#!/usr/bin/python -u
import random
import string
import time
from base64 import b64encode, b64decode
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt

max_queries = 150
query_delay = 10

passwords = [bytes(''.join(random.choice(string.ascii_lowercase) for _ in range(3)), 'UTF-8') for _ in range(3)]
flag = open("flag.txt", "rb").read()

def menu():
    print("What you wanna do?")
    print("1- Set key")
    print("2- Read flag")
    print("3- Decrypt text")
    print("4- Exit")
    try:
        return int(input(">>> "))
    except:
        return -1

print("Welcome!\n")

key_used = 0

for query in range(max_queries):
    option = menu()

    if option == 1:
        print("Which key you want to use [0-2]?")
        try:
            i = int(input(">>> "))
        except:
            i = -1
        if i >= 0 and i <= 2:
            key_used = i
        else:
            print("Please select a valid key.")
    elif option == 2:
        print("Password?")
        passwd = bytes(input(">>> "), 'UTF-8')

        print("Checking...")
        # Prevent bruteforce attacks...
        time.sleep(query_delay)
        if passwd == (passwords[0] + passwords[1] + passwords[2]):
            print("ACCESS GRANTED: " + flag.decode('UTF-8'))
        else:
            print("ACCESS DENIED!")
    elif option == 3:
        print("Send your ciphertext ")

        ct = input(">>> ")
        print("Decrypting...")
        # Prevent bruteforce attacks...
        time.sleep(query_delay)
        try:
            nonce, ciphertext = ct.split(",")
            nonce = b64decode(nonce)
            ciphertext = b64decode(ciphertext)
        except:
            print("ERROR: Ciphertext has invalid format. Must be of the form \"nonce,ciphertext\", where nonce and ciphertext are base64 strings.")
            continue

        kdf = Scrypt(salt=b'', length=16, n=2**4, r=8, p=1, backend=default_backend())
        key = kdf.derive(passwords[key_used])
        try:
            cipher = AESGCM(key)
            plaintext = cipher.decrypt(nonce, ciphertext, associated_data=None)
        except:
            print("ERROR: Decryption failed. Key was not correct.")
            continue

        print("Decryption successful")
    elif option == 4:
        print("Bye!")
        break
    else:
        print("Invalid option!")

    print("You have " + str(max_queries - query) + " trials left...\n")
2solve.sage

from cryptography.hazmat.primitives.ciphers import (
    Cipher, algorithms, modes
)
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.exceptions import InvalidTag
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.number import long_to_bytes, bytes_to_long
from bitstring import BitArray, Bits
from pwn import *
import binascii
import base64
import string
import sys

ALL_ZEROS = b'\x00'*16
GCM_BITS_PER_BLOCK = 128

def pad(a):
    if len(a) < GCM_BITS_PER_BLOCK:
        diff = GCM_BITS_PER_BLOCK - len(a)
        zeros = ['0'] * diff
        a = a + zeros
    return a

def bytes_to_element(val, field, a):
    bits = BitArray(val)
    result = field.fetch_int(0)
    for i in range(len(bits)):
        if bits[i]:
            result += a^i
    return result
def multi_collide_gcm(keyset, nonce, tag, first_block=None, use_magma=False):

    # initialize matrix and vector spaces
    P.<x> = PolynomialRing(GF(2))
    p = x^128 + x^7 + x^2 + x + 1
    GFghash.<a> = GF(2^128,'x',modulus=p)
    if use_magma:
        t = "p:=IrreducibleLowTermGF2Polynomial(128); GFghash<a> := ext<GF(2) | p>;"
        magma.eval(t)
    else:
        R = PolynomialRing(GFghash, 'x')

    # encode length as lens
    if first_block is not None:
        ctbitlen = (len(keyset) + 1) * GCM_BITS_PER_BLOCK
    else:
        ctbitlen = len(keyset) * GCM_BITS_PER_BLOCK
    adbitlen = 0
    lens = (adbitlen << 64) | ctbitlen
    lens_byte = int(lens).to_bytes(16,byteorder='big')
    lens_bf = bytes_to_element(lens_byte, GFghash, a)

    # increment nonce
    nonce_plus = int((int.from_bytes(nonce,'big') << 32) | 1).to_bytes(16,'big')

    # encode fixed ciphertext block and tag
    if first_block is not None:
        block_bf = bytes_to_element(first_block, GFghash, a)
    tag_bf = bytes_to_element(tag, GFghash, a)
    keyset_len = len(keyset)

    if use_magma:
        I = []
        V = []
    else:
        pairs = []

    for k in keyset:
        # compute H
        aes = AES.new(k, AES.MODE_ECB)
        H = aes.encrypt(ALL_ZEROS)
        h_bf = bytes_to_element(H, GFghash, a)

        # compute P
        P = aes.encrypt(nonce_plus)
        p_bf = bytes_to_element(P, GFghash, a)

        if first_block is not None:
            # assign (lens * H) + P + T + (C1 * H^{k+2}) to b
            b = (lens_bf * h_bf) + p_bf + tag_bf + (block_bf * h_bf^(keyset_len+2))
        else:
            # assign (lens * H) + P + T to b
            b = (lens_bf * h_bf) + p_bf + tag_bf

        # get pair (H, b*(H^-2))
        y = b * h_bf^-2
        if use_magma:
            I.append(h_bf)
            V.append(y)
        else:
            pairs.append((h_bf, y))

    # compute Lagrange interpolation
    if use_magma:
        f = magma("Interpolation(%s,%s)" % (I,V)).sage()
    else:
        f = R.lagrange_polynomial(pairs)
    coeffs = f.list()
    coeffs.reverse()

    # get ciphertext
    if first_block is not None:
        ct = list(map(str, block_bf.polynomial().list()))
        ct_pad = pad(ct)
        ct = Bits(bin=''.join(ct_pad))
    else:
        ct = ''

    for i in range(len(coeffs)):
        ct_i = list(map(str, coeffs[i].polynomial().list()))
        ct_pad = pad(ct_i)
        ct_i = Bits(bin=''.join(ct_pad))
        ct += ct_i
    ct = ct.bytes

    return ct+tag
first_block = b'\x01'
nonce = b'\x00'*12
tag = b'\x01'*16

# rr = process(['python', 'server.py'])
rr = remote('pythia.2021.ctfcompetition.com', 1337)

def ok(ct):
    payload = base64.b64encode(nonce)+b","+base64.b64encode(ct)
    rr.recvuntil(b">>> ")
    rr.sendline(b"3")
    rr.recvuntil(b">>> ")
    rr.sendline(payload)
    rr.recvline()
    if b"ERROR" in rr.recvline():
        return False
    else:
        return True

dic = {}
cts = []
block_size = 512
pos = []

def search():
    for k in range(len(cts)):
        if ok(cts[k][2]):
            l = cts[k][0]
            r = cts[k][1]
            print("Found password at [", l,",", r,"]")
            break
    while r-l > 1:
        mid = (l+r)//2
        keyset1 = pos[l:mid]
        ct1 = multi_collide_gcm(keyset1, nonce, tag, first_block=first_block)
        if ok(ct1):
            r = mid
        else:
            l = mid
    return dic[pos[l]]

def change_pass(to):
    rr.recvuntil(b">>> ")
    rr.sendline(b"1")
    rr.recvuntil(b">>> ")
    rr.sendline(str(to).encode())

def load_cts():
    fl = open("cts", "r")
    lines = fl.readlines()
    fl.close()
    for line in lines:
        lr = line.split(" ")
        l = int(lr[0])
        r = int(lr[1])
        ct = bytes.fromhex(lr[2].strip())
        cts.append((l, r, ct))
    print("Loaded ciphertexts from file")

def generate_cts():
    fl = open("cts", "w")
    # Round up so the last, partial window (168 keys) is not dropped
    for i in range((len(pos) + block_size - 1)//block_size):
        l = i*block_size
        r = min((i+1)*block_size, len(pos))
        ct = multi_collide_gcm(pos[l:r], nonce, tag)
        fl.write(str(l) + " " + str(r) + " " + ct.hex() + "\n")
    fl.close()
    print("Generated ciphertexts")

def generate_passwords():
    for i in string.ascii_lowercase:
        for j in string.ascii_lowercase:
            for k in string.ascii_lowercase:
                kdf = Scrypt(salt=b'', length=16, n=2**4, r=8, p=1, backend=default_backend())
                passw = (i+j+k).encode()
                key = kdf.derive(passw)
                dic[key] = passw
                pos.append(key)
    print("Generated passwords")
if __name__ == '__main__':
    generate_passwords()
    # generate_cts()
    load_cts()
    passw1 = search()
    print("First", passw1)
    change_pass(1)
    passw2 = search()
    print("Second", passw2)
    change_pass(2)
    passw3 = search()
    print("Third", passw3)

    # Retrieve flag
    rr.recvuntil(b">>> ")
    rr.sendline(b"2")
    rr.recvuntil(b">>> ")
    rr.sendline(passw1+passw2+passw3)
    rr.recvline()
    print(rr.recvline())
    rr.close()