Last active
February 6, 2022 22:08
-
-
Save hellman/5c2deab6dfcb179b101237471e7450ca to your computer and use it in GitHub Desktop.
DiceCTF 2022 - psych (crypto 500)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sage.all import ZZ, GF, EllipticCurve, proof | |
from hashlib import scrypt | |
from sock import Sock | |
from psych import sidh, xor, G, H | |
# Turn off all of Sage's global proof flags so it may use faster,
# non-proven algorithms (e.g. probabilistic primality checks on p).
proof.all(False)
def ser(*args, p_bytes=None):
    """Serialize GF(p^2) elements into sibc's little-endian byte layout.

    Each element ``x`` contributes its real part (``x.re.x``) followed by its
    imaginary part (``x.im.x``), each encoded as a ``p_bytes``-long
    little-endian integer.

    Args:
        *args: field elements exposing ``.re.x`` and ``.im.x`` integers.
        p_bytes: byte width of one coordinate; defaults to ``sidh.p_bytes``.

    Returns:
        The concatenated encoding (``b""`` when no elements are given).
    """
    width = sidh.p_bytes if p_bytes is None else p_bytes
    # join instead of repeated `bytes +=`, which copies the buffer every time
    return b"".join(
        part.x.to_bytes(length=width, byteorder='little')
        for x in args
        for part in (x.re, x.im)
    )
def unser(pk, *, p_bytes=None, field=None):
    """Parse a serialized key back into GF(p^2) elements (inverse of ``ser``).

    ``pk`` is read in chunks of ``2 * p_bytes`` bytes: a little-endian real
    part followed by a little-endian imaginary part, each ``p_bytes`` long.

    Args:
        pk: the serialized bytes.
        p_bytes: byte width of one coordinate; defaults to ``sidh.p_bytes``.
        field: constructor called with ``[re, im]``; defaults to ``sidh.field``.

    Yields:
        One ``field([re, im])`` value per chunk.

    Raises:
        ValueError: if ``len(pk)`` is not a multiple of ``2 * p_bytes``.
            Because this is a generator, the error surfaces on first iteration.
    """
    width = sidh.p_bytes if p_bytes is None else p_bytes
    make = sidh.field if field is None else field
    step = 2 * width
    # A trailing partial chunk used to be parsed silently into a bogus element.
    if len(pk) % step:
        raise ValueError(
            f"serialized key length {len(pk)} is not a multiple of {step}"
        )
    for off in range(0, len(pk), step):
        re = int.from_bytes(pk[off:off + width], byteorder='little')
        im = int.from_bytes(pk[off + width:off + step], byteorder='little')
        yield make([re, im])
def decaps(ct):
    """Submit one ciphertext to the remote decapsulation oracle.

    Each call opens a fresh connection to ``mc.ax 31338`` and sends ``ct``
    hex-encoded. It then scrapes two things from the reply: the
    ``took <n> units of time`` line and the ``ss (hex): <hex>`` line.

    Returns:
        (time_units, ss): the leaked comparison time as an int and the
        shared secret as bytes.
    """
    conn = Sock("mc.ax 31338", timeout=300)
    conn.send_line(ct.hex())

    conn.read_until("took ")
    elapsed = int(conn.read_until(" ").strip())
    print("time", elapsed)

    conn.read_until("hex): ")
    shared_hex = conn.read_line().decode().strip()
    return elapsed, bytes.fromhex(shared_hex)
def toSage(v):
    """Convert a sibc GF(p^2) element into Sage's ``Fp2`` as re + im*i."""
    real_part = Fp2(v.re.x)
    imag_part = Fp2(v.im.x)
    return real_part + imag_part * Fi
def toSibc(a):
    """Convert a Sage ``Fp2`` element back into a sibc field element.

    Zero maps to ``C.field(0)``. Otherwise the integer coefficients of
    ``a.polynomial()`` are used. A two-coefficient polynomial is passed to
    ``C.field`` as one ``[re, im]`` list; a constant is passed as its single
    coefficient.
    """
    if a == 0:
        return C.field(0)
    coeffs = [int(c) for c in a.polynomial()]
    return C.field(coeffs) if len(coeffs) == 2 else C.field(*coeffs)
def xPQR_to_A(x1, x2, x3):
    """Recover the Montgomery coefficient A from x(P), x(Q), x(P - Q).

    The curve has the form ``y^2 = x^3 + A*x^2 + x``. Given the x-coordinates
    x1 = x(P), x2 = x(Q) and x3 = x(P - Q):

        A = (1 - x1*x2 - x1*x3 - x2*x3)^2 / (4*x1*x2*x3) - x1 - x2 - x3

    This is the ``get_A`` formula from the SIKE specification. It is an exact
    algebraic refactoring of the previous hand-expanded ("homemade")
    numerator, so it gives the same value in any field with fewer
    multiplications. The formula is symmetric in its three arguments.

    Division by zero occurs if any x-coordinate is 0.
    """
    s = x1*x2 + x1*x3 + x2*x3
    return (1 - s)**2 / (4*x1*x2*x3) - x1 - x2 - x3
def encode_pub(sP, sQ):
    """Encode Sage points (sP, sQ) as a sibc public key.

    The key consists of the x-coordinates of sP, sQ and sP - sQ, in that order.
    Because the third coordinate is recomputed here, the encoded triple is
    always consistent with the chosen sP and sQ.
    """
    xs = [pt[0] for pt in (sP, sQ, sP - sQ)]
    return ser(*(toSibc(x) for x in xs))
# Alice's (the server's) public key shipped with the challenge.
# `with` closes the file; the bare open(...).read() leaked the handle.
with open('pk.bin', 'rb') as fh:
    pk = fh.read()

e2 = sidh.params.two
F = sidh.curve.field
C = sidh.curve

# sage setup: GF(p^2) = GF(p)[i] / (i^2 + 1), matching the re/im split that
# toSage/toSibc use.
p = ZZ(C.p)
i = GF(p).polynomial_ring().gen()
Fp2 = GF(p**2, name='i', modulus=i**2+1)
Fi = Fp2.gen()

# create valid pub B from any seed
# This mirrors the server's encapsulate(): r = G(m + pk), c0 = public_key_b(r).
# c0_ holds three x-coordinates, unpacked as P, Q, R (R presumably x(P - Q)).
seed = b"\x00" * 16
r = G(seed + pk)
c0_ = sidh.public_key_b(r)
P, Q, R = map(toSage, unser(c0_))
A = xPQR_to_A(P, Q, R)
E = EllipticCurve(Fp2, [0, A, 0, 1, 0])
sP = E.lift_x(P)
sQ = E.lift_x(Q)
# Optional sanity check that (sP, sQ) agrees with the key's third
# x-coordinate R, i.e. R == x(sP - sQ).
#
# lift_x() chooses an arbitrary sign of y for every point. Comparing the whole
# point lift_x(R) with sP - sQ therefore fails whenever lift_x(R) returned the
# negation ("somehow R is wrong for some seeds"). Only x-coordinates are
# meaningful here. Assuming R = x(P - Q) for the true P, Q, either
# x(sP - sQ) or x(sP + sQ) equals R, and negating sQ switches between them.
#
# It stays disabled by default, as before. encode_pub() recomputes the third
# coordinate from sP and sQ, so the forged keys are self-consistent either
# way; presumably that is why skipping this check did not break the attack.
CHECK_R_CONSISTENCY = False
if CHECK_R_CONSISTENCY:
    if (sP - sQ)[0] != R:
        sQ = -sQ
    assert (sP - sQ)[0] == R, "R matches neither x(sP - sQ) nor x(sP + sQ)"
# modify Q in the pub B
# while P in B still fully matches
# and leaks +110 in timing
#
# If the server decrypts our ciphertext back to `seed`, its re-encryption
# public_key_b(G(seed + pk)) is exactly c0_. The x(P) bytes of c0_ equal ours,
# so is_equal walks past all of them before the first mismatch and reports a
# much larger time.
#
# make Q low order
# so that P+[secret]Q has few possible values
# (does not work for 1 bit but is ok for 2+-bit orders)
# NOTE: this somehow works (what does sibc even compute on these inputs?)
# just because sibc does not validate inputs
# e.g. that Q has order 2^e2
c0 = encode_pub(sP, sQ * 2**(e2-2))

# recover first two bits
for K in range(4):
    print("try 2-bit K", K)
    j = sidh.dh_a(K.to_bytes(55, "little"), c0)
    c1 = xor(seed, H(j))
    # The server's random base time is below 64, so more than 80 units means
    # the x(P) prefix matched: the server derived the same j as guess K.
    if decaps(c0 + c1)[0] > 80:
        print("got bits:", K)
        break
    print()
else:
    # An explicit raise: `assert 0` vanishes under `python -O`, and the script
    # would then silently continue with a wrong K = 3.
    raise RuntimeError("no 2-bit candidate produced a timing hit")
# recover the rest bit-by-bit
# At step `bit`, the low `bit` bits of K are already known. Scaling Q by
# 2^(e2-1-bit) leaves a point of order 2^(bit+1), assuming Q has full order
# 2^e2. P + [secret]Q' then depends only on the secret mod 2^(bit+1).
# We guess bit `bit` = 0: a timing hit (> 80 units) confirms it, otherwise
# the bit must be 1.
# The loop variable is deliberately not `i`, which would clobber the
# GF(p)[i] generator bound during the Sage setup.
for bit in range(2, e2):
    c0 = encode_pub(sP, sQ * 2**(e2-1-bit))
    j = sidh.dh_a(K.to_bytes(55, "little"), c0)
    c1 = xor(seed, H(j))
    if decaps(c0 + c1)[0] <= 80:
        K += 2**bit
    print("got bit:", bit, "K", f"{K:b}".zfill(bit+1), hex(K), "\n")
# K = 0x22516519031d3078a54a24e9f7acd3d9504e1108e6a39908987f28
# Re-derive the flag key the way the server's `init` did. It is scrypt over the
# SIDH secret key (kem.sk[:-16] on the server), assumed here to be 55 bytes.
# n=2^20, r=8 needs about 1 GiB (128*r*n bytes), hence the large maxmem.
with open('flag.enc', 'rb') as fh:
    enc = fh.read()
sk = K.to_bytes(55, "little")
print("decrypting")
key = scrypt(sk, salt=b'defund', n=1048576, r=8, p=1, maxmem=1073744896, dklen=len(enc))
print(xor(key, enc))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/local/bin/python | |
import secrets | |
import sys | |
from hashlib import scrypt, shake_256 | |
from sibc.sidh import SIDH, default_parameters | |
# Single SIDH instance built from sibc's default parameter set, shared by the
# hashing helpers and the KEM class in this module.
sidh = SIDH(**default_parameters)
def xor(x, y):
    """XOR two byte strings position-wise; the result has the shorter length."""
    return bytes(a ^ b for a, b in zip(x, y))


def H(x):
    """Hash *x* to a 16-byte SHAKE-256 digest."""
    return shake_256(x).digest(16)


def G(x):
    """Derive the encryption randomness r from *x* with SHAKE-256.

    The output length is ``bit_length(3**e3) // 8`` bytes, where
    ``e3 = sidh.strategy.three``.
    """
    return shake_256(x).digest((3**sidh.strategy.three).bit_length() // 8)
def is_equal(x, y):
    """Compare two byte strings while *deliberately* leaking timing.

    This is the challenge's simulated side channel. It prints
    ``took {c} units of time``, where ``c`` is a random offset in [0, 64)
    plus the number of positions examined up to and including the first
    mismatch. It is intentionally not constant-time; real code should use
    ``hmac.compare_digest``.

    Returns:
        True iff ``x`` and ``y`` are equal, including their lengths.
    """
    # let's simulate a timing attack!
    c = secrets.randbelow(64)
    equal = True
    for a, b in zip(x, y):
        c += 1
        if a != b:
            equal = False
            break
    print(f'took {c} units of time')
    # zip() stops at the shorter input, so a strict prefix would otherwise
    # compare equal. The length check runs after the loop, so the leaked
    # timing is unchanged.
    return equal and len(x) == len(y)
class KEM:
    """SIDH-based key encapsulation with implicit rejection (server side).

    ``pk`` is the SIDH public key from ``sidh.keygen_a``. ``sk`` (optional) is
    the matching SIDH secret key with 16 extra random bytes appended. Those
    trailing bytes never enter SIDH: decapsulation mixes them into the shared
    secret when a ciphertext fails the re-encryption check.
    """

    def __init__(self, pk, sk=None):
        self.pk = pk
        self.sk = sk  # None for an encapsulate-only instance

    @classmethod
    def generate(cls):
        """Create a fresh key pair whose sk carries a 16-byte rejection secret."""
        sidh_sk, sidh_pk = sidh.keygen_a()
        return cls(sidh_pk, sidh_sk + secrets.token_bytes(16))

    def _encrypt(self, m, r):
        """Encrypt message *m* with randomness *r*; return (c0, c1)."""
        c0 = sidh.public_key_b(r)
        mask = H(sidh.dh_b(r, self.pk))
        return c0, xor(mask, m)

    def _decrypt(self, c0, c1):
        """Recover the message from (c0, c1) with the SIDH part of sk."""
        mask = H(sidh.dh_a(self.sk[:-16], c0))
        return xor(mask, c1)

    def encapsulate(self):
        """Return (ct, ss) for a fresh random 16-byte message."""
        m = secrets.token_bytes(16)
        c0, c1 = self._encrypt(m, G(m + self.pk))
        ct = c0 + c1
        return ct, H(m + ct)

    def decapsulate(self, ct):
        """Derive the shared secret for ciphertext *ct*.

        The message is decrypted, then re-encrypted deterministically. If the
        re-encrypted c0 matches the received one, the result is ``H(m + ct)``;
        otherwise it is ``H(rejection_secret + ct)``. The comparison goes
        through ``is_equal``, which deliberately leaks how far it got.

        Raises:
            ValueError: if there is no private key, or ``ct`` has the wrong length.
        """
        if self.sk is None:
            raise ValueError('no private key')
        if len(ct) != 6*sidh.p_bytes + 16:
            raise ValueError('malformed ciphertext')
        received_c0, received_c1 = ct[:-16], ct[-16:]
        m = self._decrypt(received_c0, received_c1)
        expected_c0 = sidh.public_key_b(G(m + self.pk))
        prefix = m if is_equal(expected_c0, received_c0) else self.sk[-16:]
        return H(prefix + ct)
if __name__ == '__main__':
    # With the `init` argument: generate the key pair and the encrypted flag.
    if len(sys.argv) > 1 and sys.argv[1] == 'init':
        kem = KEM.generate()
        with open('pk.bin', 'wb') as f:
            f.write(kem.pk)
        with open('sk.bin', 'wb') as f:
            f.write(kem.sk)
        with open('flag.txt', 'rb') as f:
            flag = f.read().strip()
        # The flag key depends only on the SIDH secret (sk without its
        # 16-byte rejection suffix).
        with open('flag.enc', 'wb') as f:
            key = scrypt(kem.sk[:-16], salt=b'defund', n=1048576, r=8, p=1, maxmem=1073744896, dklen=len(flag))
            f.write(xor(key, flag))
        # sys.exit(), not the site-module `exit` helper, which is meant for
        # the interactive shell and is missing under `python -S`.
        sys.exit()

    # Otherwise: serve exactly one decapsulation query.
    with open('pk.bin', 'rb') as f:
        pk = f.read()
    with open('sk.bin', 'rb') as f:
        sk = f.read()
    kem = KEM(pk, sk=sk)
    ct = bytes.fromhex(input('ct (hex): '))
    print('decapsulating...')
    ss = kem.decapsulate(ct)
    print(f'ss (hex): {ss.hex()}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment