Skip to content

Instantly share code, notes, and snippets.

@hellman
Last active February 6, 2022 22:08
Embed
What would you like to do?
DiceCTF 2022 - psych (crypto 500)
from sage.all import ZZ, GF, EllipticCurve, proof
from hashlib import scrypt
from sock import Sock
from psych import sidh, xor, G, H
proof.all(False)
def ser(*args):
    """Serialize sibc Fp2 elements into sibc's little-endian wire format.

    Each element contributes its real part followed by its imaginary part,
    each written as ``sidh.p_bytes`` little-endian bytes.

    Args:
        *args: sibc Fp2 elements exposing integer ``.re.x`` and ``.im.x``.

    Returns:
        The concatenated encoding, ``2 * sidh.p_bytes`` bytes per element.
    """
    return b"".join(
        component.x.to_bytes(length=sidh.p_bytes, byteorder='little')
        for elem in args
        for component in (elem.re, elem.im)
    )
def unser(pk):
    """Parse sibc's wire format back into Fp2 elements (inverse of ``ser``).

    Yields one ``sidh.field`` element per ``2 * sidh.p_bytes`` chunk of *pk*.
    The first half of each chunk is the real part and the second half is the
    imaginary part, both little-endian.
    """
    width = sidh.p_bytes
    for start in range(0, len(pk), 2 * width):
        middle = start + width
        real = int.from_bytes(pk[start:middle], byteorder='little')
        imag = int.from_bytes(pk[middle:middle + width], byteorder='little')
        yield sidh.field([real, imag])
def decaps(ct):
    """Submit ciphertext *ct* to the remote decapsulation oracle.

    Sends ``ct`` hex-encoded, then parses the ``took N ...`` timing report
    and the hex shared secret printed after ``hex): ``.

    Returns:
        ``(tim, ss)``: the reported time units as an int, and the
        shared-secret bytes.
    """
    conn = Sock("mc.ax 31338", timeout=300)
    conn.send_line(ct.hex())
    conn.read_until("took ")
    elapsed = int(conn.read_until(" ").strip())
    print("time", elapsed)
    conn.read_until("hex): ")
    secret_hex = conn.read_line().decode().strip()
    return elapsed, bytes.fromhex(secret_hex)
def toSage(v):
    """Convert a sibc Fp2 element ``re + im*i`` into the Sage field ``Fp2``."""
    real = Fp2(v.re.x)
    imag = Fp2(v.im.x)
    return real + imag * Fi
def toSibc(a):
    """Convert a Sage ``Fp2`` element back into a sibc field element.

    The coefficients of ``a.polynomial()`` are passed to ``C.field``. With
    two coefficients they go as a ``[re, im]`` list; otherwise they are
    unpacked as positional arguments. Zero is special-cased up front,
    presumably because its polynomial has no coefficients.
    """
    if a == 0:
        return C.field(0)
    coeffs = [int(c) for c in a.polynomial()]
    return C.field(coeffs) if len(coeffs) == 2 else C.field(*coeffs)
def xPQR_to_A(x1, x2, x3):
    """Recover A for the Montgomery curve ``y^2 = x^3 + A*x^2 + x``.

    Given x1 = x(P), x2 = x(Q) and x3 = x(P - Q), the coefficient is

        A = (1 - x1*x2 - x1*x3 - x2*x3)^2 / (4*x1*x2*x3) - x1 - x2 - x3

    This is the standard SIDH/SIKE ``get_A`` formula. It agrees exactly, over
    any field, with the fully expanded polynomial it replaces, but it needs
    far fewer multiplications.

    Division by zero occurs if any x-coordinate is 0.
    """
    t = 1 - x1 * x2 - x1 * x3 - x2 * x3
    return t * t / (4 * x1 * x2 * x3) - x1 - x2 - x3
def encode_pub(sP, sQ):
    """Encode Sage points P, Q as a sibc public key: x(P), x(Q), x(P - Q)."""
    diff = sP - sQ
    x_coords = (sP[0], sQ[0], diff[0])
    return ser(*(toSibc(x) for x in x_coords))
# Target KEM public key; psych.py's `init` mode writes it to pk.bin.
# Use a context manager so the file handle is closed promptly.
with open('pk.bin', 'rb') as pk_file:
    pk = pk_file.read()
# e2: exponent of the 2-power torsion, used below to scale Q's order.
e2 = sidh.params.two
F = sidh.curve.field
C = sidh.curve
# sage setup: GF(p^2) = GF(p)[i] / (i^2 + 1), so toSage/toSibc can map
# sibc's (re, im) pairs to re + im*i and back.
p = ZZ(C.p)
i = GF(p).polynomial_ring().gen()
Fp2 = GF(p**2, name='i', modulus=i**2+1)
Fi = Fp2.gen()
# create valid pub B from any seed
# This mirrors the server's re-encryption step: r = G(m + pk), then
# c0 = sidh.public_key_b(r), with m fixed to `seed`.
seed = b"\x00" * 16
r = G(seed + pk)
c0_ = sidh.public_key_b(r)
# c0_ encodes x(P), x(Q), x(P - Q). Recover the curve coefficient A and
# lift P and Q to full points on y^2 = x^3 + A*x^2 + x.
P, Q, R = map(toSage, unser(c0_))
A = xPQR_to_A(P, Q, R)
E = EllipticCurve(Fp2, [0, A, 0, 1, 0])
sP = E.lift_x(P)
sQ = E.lift_x(Q)
if 0:
    # Disabled sanity check: negate sQ when needed so that sR == sP - sQ.
    # somehow R is wrong for some seeds
    # and somehow it does not matter ???
    # NOTE(review): presumably harmless because encode_pub recomputes
    # x(P - Q) from sP and the scaled sQ, and the same c0 is used both for
    # the local dh_a and for the oracle query. Confirm.
    sR = E.lift_x(R)
    if sR != sP - sQ:
        sQ = -sQ
    assert sR == sP - sQ
# Modify Q in the pub B while x(P) in B still fully matches, so a correct
# guess leaks +110 in timing: the server's byte comparison runs through
# all of x(P) before hitting the modified x(Q).
# Make Q low order so that P+[secret]Q has few possible values.
# (This does not work for 1 bit, but is ok for 2+-bit orders.)
# NOTE: this only works because sibc does not validate inputs,
# e.g. that Q has order 2^e2.
c0 = encode_pub(sP, sQ * 2**(e2-2))
# Recover the first two bits by trying all four candidates.
for K in range(4):
    print("try 2-bit K", K)
    j = sidh.dh_a(K.to_bytes(55, "little"), c0)
    c1 = xor(seed, H(j))
    # is_equal reports randbelow(64) + bytes compared. A miss at the first
    # byte therefore reports at most 64, while a hit reports more than 110;
    # 80 separates the two cases.
    if decaps(c0 + c1)[0] > 80:
        print("got bits:", K)
        break
    print()
else:
    # An explicit raise (unlike `assert`) still fires under `python -O`.
    raise RuntimeError("no 2-bit candidate matched; timing oracle gave no hit")
# recover the rest bit-by-bit
# Scaling Q by 2^(e2-1-i) leaves it order 2^(i+1), so the kernel
# P + [K]Q depends only on the low i+1 bits of K. K's bit i is currently 0.
# If the oracle reports a miss (<= 80, see the 2-bit stage), bit i must be 1.
for i in range(2, e2):
    c0 = encode_pub(sP, sQ * 2**(e2-1-i))
    j = sidh.dh_a(K.to_bytes(55, "little"), c0)
    c1 = xor(seed, H(j))
    if decaps(c0 + c1)[0] <= 80:
        K += 2**i
    print("got bit:", i, "K", f"{K:b}".zfill(i+1), hex(K), "\n")
# Previously recovered value, kept for reference:
# K = 0x22516519031d3078a54a24e9f7acd3d9504e1108e6a39908987f28
with open('flag.enc', 'rb') as enc_file:
    enc = enc_file.read()
# psych.py init derives the flag key from kem.sk[:-16] (the SIDH secret
# without the 16-byte suffix); K is that secret as a little-endian integer.
sk = K.to_bytes(55, "little")
print("decrypting")
# scrypt needs 128*r*n = 1 GiB here, hence maxmem slightly above 2^30.
key = scrypt(sk, salt=b'defund', n=1048576, r=8, p=1, maxmem=1073744896, dklen=len(enc))
print(xor(key, enc))
#!/usr/local/bin/python
import secrets
import sys
from hashlib import scrypt, shake_256
from sibc.sidh import SIDH, default_parameters
# Module-wide SIDH instance built from sibc's default parameters; every
# KEM operation below goes through it.
sidh = SIDH(**default_parameters)
def xor(x, y):
    """XOR two byte strings; the result is truncated to the shorter input."""
    return bytes(map(int.__xor__, x, y))


def H(x):
    """Hash *x* to a 16-byte digest using SHAKE-256."""
    return shake_256(x).digest(16)
def G(x):
    """Expand *x* (``m + pk`` in the KEM) into the ephemeral secret ``r``.

    Uses SHAKE-256 with an output length of
    ``(3**sidh.strategy.three).bit_length() // 8`` bytes, recomputed on each
    call. ``sidh.strategy.three`` is presumably the 3-torsion exponent;
    confirm against sibc.
    """
    return shake_256(x).digest((3**sidh.strategy.three).bit_length() // 8)
def is_equal(x, y):
    """Compare two byte strings while printing a deliberately leaky "time".

    The comparison stops at the first differing byte. It prints
    ``took N units of time``, where N is the number of bytes examined plus
    random jitter from ``secrets.randbelow(64)``. This is an intentional,
    simulated timing side channel.

    Returns:
        True iff *x* and *y* have the same length and identical contents.
    """
    # let's simulate a timing attack!
    c = secrets.randbelow(64)
    equal = True
    for a, b in zip(x, y):
        c += 1
        if a != b:
            equal = False
            break
    # zip() stops at the shorter input, so without this check a proper
    # prefix would compare equal. The reported time is unaffected.
    if len(x) != len(y):
        equal = False
    print(f'took {c} units of time')
    return equal
class KEM:
    """SIDH-based KEM with a Fujisaki-Okamoto-style re-encryption check.

    ``pk`` is the SIDH public key from ``sidh.keygen_a``. ``sk`` (optional)
    is the matching SIDH secret key with 16 extra random bytes appended;
    ``decapsulate`` uses those 16 bytes for implicit rejection.
    """

    def __init__(self, pk, sk=None):
        self.pk = pk
        self.sk = sk

    @classmethod
    def generate(cls):
        """Create a fresh key pair; the secret gets a 16-byte random suffix."""
        alice_sk, alice_pk = sidh.keygen_a()
        return cls(alice_pk, alice_sk + secrets.token_bytes(16))

    def _encrypt(self, m, r):
        """Encrypt *m* under ``self.pk`` using the ephemeral secret *r*.

        Returns ``(c0, c1)``: Bob's public key for *r*, and *m* masked with
        ``H`` of the shared j-invariant.
        """
        ephemeral_pub = sidh.public_key_b(r)
        shared = sidh.dh_b(r, self.pk)
        return ephemeral_pub, xor(H(shared), m)

    def _decrypt(self, c0, c1):
        """Recover the plaintext from ``(c0, c1)`` with the SIDH secret key."""
        sidh_sk = self.sk[:-16]  # strip the implicit-rejection suffix
        return xor(H(sidh.dh_a(sidh_sk, c0)), c1)

    def encapsulate(self):
        """Return ``(ct, ss)``: a fresh ciphertext and its shared secret."""
        m = secrets.token_bytes(16)
        c0, c1 = self._encrypt(m, G(m + self.pk))
        ct = c0 + c1
        return ct, H(m + ct)

    def decapsulate(self, ct):
        """Recover the shared secret for ciphertext *ct*.

        Decrypts *ct* and re-derives ``c0`` deterministically from the
        recovered message. If the re-derived ``c0`` matches the received
        one, returns ``H(m + ct)``. Otherwise returns
        ``H(sk[-16:] + ct)`` (implicit rejection) instead of raising.

        Raises:
            ValueError: if there is no private key, or *ct* has the wrong
                length.
        """
        if self.sk is None:
            raise ValueError('no private key')
        if len(ct) != 6*sidh.p_bytes + 16:
            raise ValueError('malformed ciphertext')
        c0_received, c1 = ct[:-16], ct[-16:]
        m = self._decrypt(c0_received, c1)
        c0_expected = sidh.public_key_b(G(m + self.pk))
        if not is_equal(c0_expected, c0_received):
            return H(self.sk[-16:] + ct)
        return H(m + ct)
if __name__ == '__main__':
    if len(sys.argv) > 1 and sys.argv[1] == 'init':
        # One-time setup: generate and persist a key pair, then encrypt the
        # flag under a key derived from the SIDH secret (sk minus its
        # 16-byte rejection suffix).
        kem = KEM.generate()
        with open('pk.bin', 'wb') as f:
            f.write(kem.pk)
        with open('sk.bin', 'wb') as f:
            f.write(kem.sk)
        with open('flag.txt', 'rb') as f:
            flag = f.read().strip()
        # Derive the key before opening flag.enc so a scrypt failure does
        # not leave an empty output file. scrypt needs 128*r*n = 1 GiB
        # here, hence maxmem slightly above 2^30.
        key = scrypt(kem.sk[:-16], salt=b'defund', n=1048576, r=8, p=1, maxmem=1073744896, dklen=len(flag))
        with open('flag.enc', 'wb') as f:
            f.write(xor(key, flag))
        # sys.exit rather than the site-injected exit(), which is absent
        # under `python -S`.
        sys.exit()
    # Oracle mode: decapsulate one hex-encoded ciphertext read from stdin.
    with open('pk.bin', 'rb') as f:
        pk = f.read()
    with open('sk.bin', 'rb') as f:
        sk = f.read()
    kem = KEM(pk, sk=sk)
    ct = bytes.fromhex(input('ct (hex): '))
    print('decapsulating...')
    ss = kem.decapsulate(ct)
    print(f'ss (hex): {ss.hex()}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment