
@cwgreene
Last active September 17, 2019 00:03

Faultbox

We are greeted upon connecting to the server with the following:

```
====================================
            fault box
====================================
1. print encrypted flag
2. print encrypted fake flag
3. print encrypted fake flag (TEST)
4. encrypt
====================================
```

Looking at the code for each option, we find that the first option simply returns the flag encrypted with a known exponent (e = 65537) but an unknown modulus. So our first task is to determine the modulus n.

Option 4 lets us encrypt plaintexts of our choosing. If we encrypt the number 2 on its own, the result is congruent to 2**exp mod n, so the difference 2**exp - encrypted(2) is a multiple of n. Since 2**65537 is vastly larger than n, it is a huge multiple, and we cannot simply divide out the unknown cofactor. Instead, we also encrypt 3: 3**exp - encrypted(3) is another multiple of n, so the gcd of the two differences recovers n. Strictly speaking, the gcd may be a small multiple of n. More on that below.

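```python
def get_modulus():
    # Both differences are exact multiples of n, so their gcd is n
    # times the gcd of the two cofactors (which is often 1).
    e1 = encrypt("\x02")  # 2**exp mod n
    e2 = encrypt("\x03")  # 3**exp mod n
    n = math.gcd(2**exp - e1, 3**exp - e2)
    return n
```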

where `encrypt` is a small pwntools helper:

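```python
def encrypt(b):
    r.sendline("4")                  # menu option 4: encrypt our own data
    r.recvuntil("input the data:")
    r.sendline(b)
    ciphertext = r.recvuntil("\n")   # the server replies with the ciphertext in hex
    r.recvuntil(FINAL)               # consume the next menu
    return int(ciphertext.strip(), 16)
```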
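To sanity-check the gcd trick offline, here is a small sketch that stands in for the server with a toy key. The primes 2**61 - 1 and 2**89 - 1 (both Mersenne primes), `toy_encrypt` and `recover_modulus` are assumptions for illustration only; the real server uses random 1024-bit primes.

```python
import math

E = 65537
# Hypothetical toy key: two known Mersenne primes instead of the
# server's random 1024-bit primes.
P, Q = 2**61 - 1, 2**89 - 1
N = P * Q

def toy_encrypt(m):
    # Stand-in for menu option 4: textbook RSA with the secret modulus.
    return pow(m, E, N)

def recover_modulus(e, c2, c3):
    # 2**e - c2 and 3**e - c3 are both exact multiples of n,
    # so their gcd is n times gcd(cofactor1, cofactor2).
    return math.gcd(2**e - c2, 3**e - c3)

g = recover_modulus(E, toy_encrypt(2), toy_encrypt(3))
print(g % N == 0, g // N)   # always a multiple of N; the cofactor is usually 1
```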

Now that we have n, we need to factor it. Options 2 and 3 both operate on a "fake flag". The important thing, for the moment, is that the fake flag string is the same in both. Option 2 encrypts it using standard RSA. Option 3 is different:

```python
# ===== FUNCTIONS FOR PERSONAL TESTS, DON'T USE THEM =====
def TEST_CRT_encrypt(self, p, fun=0):
    ep = inverse(self.d, self.p-1)
    print("ep is", ep, "and e is", self.e)
    eq = inverse(self.d, self.q-1)
    qinv = inverse(self.q, self.p)
    c1 = pow(p, ep, self.p)
    c2 = pow(p, eq, self.q) ^ fun
    h = (qinv * (c1 - c2)) % self.p
    c = c2 + h*self.q
    print("c - c1", (c - c1) % self.p)
    return c
```

A bit of algebra shows that the return value is simpler than it appears. Recall that qinv is the inverse of q mod p, and do all of the algebra mod p (writing p and q for self.p and self.q):

c = c2 + h*q
c = c2 + ((qinv * (c1 - c2)) % p) * q   (mod p)
c = c2 + (c1 - c2) * (qinv * q)          (mod p)
c = c2 + c1 - c2                         (mod p)
c = c1                                   (mod p)

So c is not exactly c1: the algebra above was done mod p, so c equals c1 up to some multiple of p:

c = c1 + k*p   (for some integer k)
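A quick numeric check of this identity, again with a hypothetical toy key (two Mersenne primes) and an arbitrary value standing in for fun. `crt_combine` just repeats the last three lines of TEST_CRT_encrypt; it assumes Python 3.8+ for `pow(q, -1, p)`.

```python
def crt_combine(c1, c2, p, q):
    # Garner-style recombination as in TEST_CRT_encrypt:
    # h = qinv * (c1 - c2) mod p, then c = c2 + h*q.
    qinv = pow(q, -1, p)
    h = (qinv * (c1 - c2)) % p
    return c2 + h * q

p, q = 2**61 - 1, 2**89 - 1
m, e, fun = 123456789, 65537, 0xA29   # toy plaintext and fault (assumed values)

c1 = pow(m, e, p)
c2 = pow(m, e, q) ^ fun               # faulty mod-q half, as on the server
c = crt_combine(c1, c2, p, q)

print((c - c1) % p)                   # the server's own debug check → 0
print(c % q == pow(m, e, q))          # the mod-q half is wrong → False
```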

Now c1 itself is

pow(p, ep, self.p) # super confusing, p is the plaintext

However, ep is just 65537: d is the inverse of e mod (p-1)*(q-1), so it is also the inverse of e mod phi(p) = p-1, and inverting d mod p-1 gives back e (65537 is smaller than p-1). So this means that

(2) encrypted(fake_flag) = fake_flag**65537 mod n
(3) encrypted_test(fake_flag) ≡ fake_flag**65537 (mod p)
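To convince ourselves that ep really is e, here is a toy computation with the same hypothetical Mersenne primes (Python 3.8+ for modular inverses via `pow`):

```python
p, q, e = 2**61 - 1, 2**89 - 1, 65537
n = p * q
d = pow(e, -1, (p - 1) * (q - 1))    # private exponent, as in RSA.generate
ep = pow(d, -1, p - 1)               # what TEST_CRT_encrypt computes
eq = pow(d, -1, q - 1)
print(ep, eq)                        # → 65537 65537

m = 42                               # arbitrary toy plaintext
# So the mod-p half of the TEST ciphertext agrees with option 2 mod p.
print(pow(m, ep, p) == pow(m, e, n) % p)   # → True
```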

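Now note that (2) equals fake_flag**65537 - k*q*p for some integer k (reducing mod n subtracts a multiple of n = p*q), so (2) is also congruent to fake_flag**65537 mod p. This only works because p divides n; reducing mod an unrelated modulus does not preserve residues (12 mod 7 = 5, but 12 mod 5 = 2 while 5 mod 5 = 0).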

This means (2) - (3) ≡ 0 (mod p), i.e. it is a multiple of p. It is not a multiple of q, because the fault corrupted the mod-q half of the CRT computation, so its gcd with n is exactly p. Note that if the weird fun parameter were 0, the TEST ciphertext would be exactly the same as the regular encrypted fake flag, the difference would be 0, and we would learn nothing (we already knew that 0 is congruent to 0 mod p ;) ). fun simulates a fault in the encryption, hence the problem name; the server passes x, the offset found while generating p.
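Putting the pieces together offline: a sketch that reproduces option 2 and a faulty option 3 with a toy key, recovers p with a gcd, and then decrypts. The key, the messages and the fault value are assumptions chosen for illustration; `faulty_crt_encrypt` is a hypothetical re-implementation of TEST_CRT_encrypt.

```python
import math

e = 65537
p, q = 2**61 - 1, 2**89 - 1          # hypothetical toy primes
n = p * q
fun = 0xA29                          # assumed fault value

def faulty_crt_encrypt(m):
    # Mirrors TEST_CRT_encrypt: correct mod-p half, XOR-faulted mod-q half.
    qinv = pow(q, -1, p)
    c1 = pow(m, e, p)
    c2 = pow(m, e, q) ^ fun
    h = (qinv * (c1 - c2)) % p
    return c2 + h * q

m = int.from_bytes(b"fake_flag{toy}", "big")
good = pow(m, e, n)                  # option 2
bad = faulty_crt_encrypt(m)          # option 3

# good - bad is 0 mod p but not mod q, so the gcd isolates p.
found_p = math.gcd(good - bad, n)
found_q = n // found_p
print(found_p == p)                  # → True

# With the factors known, compute d ourselves and decrypt.
d2 = pow(e, -1, (found_p - 1) * (found_q - 1))
secret = pow(int.from_bytes(b"flag{toy}", "big"), e, n)
print(pow(secret, d2, n).to_bytes(9, "big"))   # → b'flag{toy}'
```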

Fantastic! Knowing p and q, we can compute the inverse of 65537 mod (p-1)(q-1) and decrypt the flag!

Well... not quite. There's a little problem. Looking back at the server code:

```python
cnt = 2
while cnt > 0:
    req.sendall(bytes(
        '====================================\n'
        '            fault box\n'
        '====================================\n'
        '1. print encrypted flag\n'
        '2. print encrypted fake flag\n'
        '3. print encrypted fake flag (TEST)\n'
        '4. encrypt\n'
        '====================================\n', 'utf-8'))

    choice = str(req.recv(2).strip(), 'utf-8')
    if choice not in menu:
        print("oh you bad little boy")
        exit(1)

    menu[choice]()

    if choice == '4':
        continue

    cnt -= 1
```

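That cnt means we only get two requests from options 1, 2 and 3 (option 4 doesn't count, so we can encrypt to our heart's content, though we only need to do it twice). We need the encrypted flag, and the attack above needs both the encrypted fake flag and its faulty TEST version, which is one request too many.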

Now, if we knew the fake flag, we could encrypt it ourselves. What is the fake flag?

```python
r = RSA()
p, x = gen_prime()
q, y = gen_prime()

r.generate(p, q)
fake_flag = 'fake_flag{%s}' % (('%X' % y).rjust(32, '0'))
```

If we run this locally, we get something like

fake_flag{00000000000000000000000000000A29}

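That's a lot of zeros! Running this a few times, we discover that there isn't much entropy here. It turns out that y is the distance from a random 1024-bit number to the next prime, which by the prime number theorem is typically on the order of ln(2**1024) ≈ 710, so most of the time it won't be very big. So instead of spending a request on option 2, we brute force the fake flag: for each candidate y, we encrypt the candidate fake flag ourselves with e and n, and take the gcd of its difference from the TEST ciphertext (option 3) with n. If we guess right, that gcd is p, n has been factored, and we're done.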
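How small is y in practice? A rough sketch that mimics gen_prime, with a pure-Python Miller-Rabin test standing in for `gmpy2.is_prime` (an assumption, so the snippet runs without gmpy2). It defaults to 256-bit bases to stay fast; passing bits=1024 to `sample_offsets` matches the server, where the typical offset should be around 1024·ln 2 ≈ 710.

```python
import math
import random

SMALL_PRIMES = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)

def is_probable_prime(n, rng, rounds=20):
    # Cheap trial division, then Miller-Rabin with random bases.
    if n < 2:
        return False
    for sp in SMALL_PRIMES:
        if n % sp == 0:
            return n == sp
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for _ in range(rounds):
        a = rng.randrange(2, n - 1)
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False          # a is a witness: n is composite
    return True

def prime_offset(bits, rng):
    # Same search as the server's gen_prime: walk up from a random base.
    base = rng.getrandbits(bits)
    off = 0
    while not is_probable_prime(base + off, rng):
        off += 1
    return off

def sample_offsets(bits=256, samples=30, seed=0):
    rng = random.Random(seed)
    return [prime_offset(bits, rng) for _ in range(samples)]

offsets = sample_offsets()
print("max offset:", max(offsets))
print("mean offset:", sum(offsets) / len(offsets))
print("ln(2**256) =", 256 * math.log(2))
```

The maximum should sit far below the 0xfff bound the solve script searches.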

So putting it all together, we have:

```python
import math
from Crypto.Util.number import inverse, bytes_to_long, long_to_bytes

from pwn import *

# The tail of the menu; reading up to it consumes one full menu.
FINAL = "encrypt\n====================================\n"

def s2n(s):
    return bytes_to_long(bytearray(s, 'latin-1'))

def encrypt(b):
    # Menu option 4: encrypt our own data (doesn't count against cnt).
    r.sendline("4")
    r.recvuntil("input the data:")
    r.sendline(b)
    ciphertext = r.recvuntil("\n")
    r.recvuntil(FINAL)
    return int(ciphertext.strip(), 16)

def grab(option):
    # Menu options 1-3 each print one hex ciphertext and cost one request.
    r.sendline(str(option))
    response = r.recvuntil("\n")
    r.recvuntil(FINAL)
    return int(response.strip(), 16)

def get_modulus():
    e1 = encrypt("\x02")
    e2 = encrypt("\x03")
    n = math.gcd(2**exp - e1, 3**exp - e2)
    return n

def gen_fake_flag(y):
    # Same format string as the server's fake flag.
    fake_flag = 'fake_flag{%s}' % (('%X' % y).rjust(32, '0'))
    return s2n(fake_flag)

r = remote("crypto.chal.csaw.io", 1001)
exp = 65537
n = get_modulus()
flag = grab(1)            # encrypted flag
fake_test_flag = grab(3)  # faulty CRT encryption of the fake flag

p = q = None
for i in range(0xfff):
    fake_flag = gen_fake_flag(i)
    encrypted_fake_flag = pow(fake_flag, exp, n)
    g = math.gcd(encrypted_fake_flag - fake_test_flag, n)
    if 1 < g < n:         # a proper factor: the guess was right
        print(long_to_bytes(fake_flag))
        p, q = g, n // g
        break

if p is None:
    print("no factor found; n may be a multiple of the real modulus, try again")
else:
    d = inverse(exp, (p-1)*(q-1))
    print(long_to_bytes(pow(flag, d, n)))
```

And this results in the flag. However, there is one issue. When we computed the modulus, we assumed that the greatest common divisor of 2**exp - encrypted(2) and 3**exp - encrypted(3) would be the modulus. But each difference is some essentially random multiple of n, so the gcd is n times the gcd of the two cofactors, and it equals n only when the cofactors are coprime. One way to fix this is to divide a few small prime powers out of the computed n (the real primes are 1024 bits, so any small factor is spurious). Alternatively, the probability that two random integers are coprime is 6/π² ≈ 61%, so you can just rerun the script until something that looks like a flag shows up.
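A minimal sketch of the small-factor cleanup, assuming (as on this server) that the real primes are far larger than the trial-division bound. `strip_small_factors` and the 100000 bound are illustrative choices, not part of the original solve; a spurious cofactor with a large prime factor slips through, in which case rerunning is still the fallback.

```python
import math

def strip_small_factors(n, bound=100_000):
    # The real p and q are 1024-bit primes, so every prime factor below
    # `bound` must come from the cofactor gcd and can be divided out.
    sieve = bytearray([1]) * bound
    sieve[0] = sieve[1] = 0
    for i in range(2, math.isqrt(bound - 1) + 1):
        if sieve[i]:
            sieve[i * i::i] = bytes(len(range(i * i, bound, i)))
    for prime in range(2, bound):
        if sieve[prime]:
            while n % prime == 0:
                n //= prime
    return n

n_true = (2**61 - 1) * (2**89 - 1)     # hypothetical toy modulus
print(strip_small_factors(n_true * 8 * 3 * 97) == n_true)   # → True
```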

For reference, here is the full server source:

```python
import socketserver
import random
import signal
import time
import gmpy2
from Crypto.Util.number import inverse, bytes_to_long, long_to_bytes

FLAG = open('flag', 'r').read().strip()


def s2n(s):
    return bytes_to_long(bytearray(s, 'latin-1'))


def n2s(n):
    return long_to_bytes(n).decode('latin-1')


def gen_prime():
    base = random.getrandbits(1024)
    off = 0
    while True:
        if gmpy2.is_prime(base + off):
            break
        off += 1
    p = base + off
    return p, off


class RSA(object):
    def __init__(self):
        pass

    def generate(self, p, q, e=0x10001):
        self.p = p
        self.q = q
        self.N = p * q
        print("N=",self.N)
        self.e = e
        phi = (p-1) * (q-1)
        self.d = inverse(e, phi)

    def encrypt(self, p):
        return pow(p, self.e, self.N)

    def decrypt(self, c):
        return pow(c, self.d, self.N)

    # ===== FUNCTIONS FOR PERSONAL TESTS, DON'T USE THEM =====
    def TEST_CRT_encrypt(self, p, fun=0):
        ep = inverse(self.d, self.p-1)
        print("ep is", ep, "and e is", self.e)
        eq = inverse(self.d, self.q-1)
        qinv = inverse(self.q, self.p)
        c1 = pow(p, ep, self.p)
        c2 = pow(p, eq, self.q) ^ fun
        h = (qinv * (c1 - c2)) % self.p
        c = c2 + h*self.q
        print("c - c1", (c - c1) % self.p)
        return c

    def TEST_CRT_decrypt(self, c, fun=0):
        dp = inverse(self.e, self.p-1)
        dq = inverse(self.e, self.q-1)
        qinv = inverse(self.q, self.p)
        m1 = pow(c, dp, self.p)
        m2 = pow(c, dq, self.q) ^ fun
        h = (qinv * (m1 - m2)) % self.p
        m = m2 + h*self.q
        return m


def go(req):
    print("hi")
    r = RSA()
    p, x = gen_prime()
    q, y = gen_prime()
    r.generate(p, q)
    fake_flag = 'fake_flag{%s}' % (('%X' % y).rjust(32, '0'))

    def enc_flag():
        req.sendall(b'%X\n' % r.encrypt(s2n(FLAG)))

    def enc_fake_flag():
        req.sendall(b'%X\n' % r.encrypt(s2n(fake_flag)))

    def enc_fake_flag_TEST():
        req.sendall(b'%X\n' % r.TEST_CRT_encrypt(s2n(fake_flag), x))

    def enc_msg():
        req.sendall(b'input the data:')
        p = str(req.recv(4096).strip(), 'utf-8')
        req.sendall(b'%X\n' % r.encrypt(s2n(p)))

    menu = {
        '1': enc_flag,
        '2': enc_fake_flag,
        '3': enc_fake_flag_TEST,
        '4': enc_msg,
    }

    cnt = 2
    while cnt > 0:
        req.sendall(bytes(
            '====================================\n'
            '            fault box\n'
            '====================================\n'
            '1. print encrypted flag\n'
            '2. print encrypted fake flag\n'
            '3. print encrypted fake flag (TEST)\n'
            '4. encrypt\n'
            '====================================\n', 'utf-8'))
        choice = str(req.recv(2).strip(), 'utf-8')
        if choice not in menu:
            exit(1)
        menu[choice]()
        if choice == '4':
            continue
        cnt -= 1


class incoming(socketserver.BaseRequestHandler):
    def handle(self):
        signal.alarm(300)
        random.seed(time.time())
        req = self.request
        while True:
            go(req)


class ReusableTCPServer(socketserver.ForkingMixIn, socketserver.TCPServer):
    pass


socketserver.TCPServer.allow_reuse_address = True
server = ReusableTCPServer(("0.0.0.0", 23333), incoming)
server.serve_forever()
```