# hellman/0_solve.py

# Last active September 4, 2017 08:29
# TWCTF 2017 - BabyPinhole
 #-*- coding:utf-8 -*-
"""
TWCTF 2017 - BabyPinhole solver (runs under Python 2 and Python 3).

In this challenge we have a Paillier cryptosystem. We are given a decryption
oracle, which leaks only one bit in the middle of the plaintext (bit b of the
decryption). Due to the homomorphic properties of the Paillier cryptosystem,
we can recover the full decryption using such an oracle.

1. First, we recover the lower half of the message bit-by-bit. This can be
   done by manipulating and observing the carry bit going through the
   pinhole, by exploiting the homomorphic addition:

            leaked         guess
              |              |
              v  [ known ]   v  [ unknown ]
              @  0 0 1 0 1   X  ? ? ? ? ? ?
        +        1 1 0 1 0   1  0 0 0 0 0 0
        -----------------------------------
        =    +0  1 1 1 1 1   1  ? ? ? ? ? ?   (if X = 0)
             +1  0 0 0 0 0   0  ? ? ? ? ? ?   (if X = 1)

   The addend is the complement of the already-known bits followed by a 1 at
   the guessed position X. A carry therefore reaches the leaked bit @ exactly
   when X = 1.

2. Once we know the lower half, we can easily learn the higher half. First,
   we zeroize the lower half by adding its complement (2**b - low), which
   turns the message into (high + 1) * 2**b. Then we can simply shift the
   message right and leak each bit, by homomorphically multiplying the
   message by the inverse of 2**i.

The flag: TWCTF{ccb71c01f350cf0bc844e87d161f84b9b479b439}
"""
from hashlib import sha1
from random import randint

from libnum import gcd, invmod
from sock import Sock

n = 0xadd142708d464b5c50b936f2dc3a0419842a06741761e160d31d6c0330f2c515b91479f37502a0e9ddf30f7a18c71ef1eba993bdc368f7e90b58fb9fdbfc0d9ee0776dc629c8893a118e0ad8fc05633f0b9b4ab20f7c6363c4625dbaedf5a8d8799abc8353cb54bbfeab829792eb57837030900d73a06c4e87172599338fd5b1
n2 = n ** 2
g = 0x676ae3e2f70e9a5e35b007a70f4e7e113a77f0dbe462d867b19a67839f41b6e66940c02936bb73839d98966fc01f81b2b79c834347e71de6d754b038cb83f27bac6b33bf7ebd25de75a625ea6dd78fb973ed8637d32d2eaf5ae412b5222c8efea99b183ac823ab04219f1b700b207614df11f1f3759dea6d722635f45e453f6eae4d597dcb741d996ec72fe3e54075f6211056769056c5ad949c8becec7e179da3514c1f110ce65dc39300dfdce1170893c44f334a1b7260c51fb71b2d4dc6032e907bbaeebff763665e38cdfe418039dc782ae46f80e835bfd1ef94aeaba3ab086e61dab2ff99f600eb8d1cd3cf3fc952b56b561adc2de2097e7d04cb7c556
CT = 0x2ab54e5c3bde8614bd0e98bf838eb998d071ead770577333cf472fb54bdc72833c3daa76a07a4fee8738e75eb3403c6bcbd24293dc2b661ab1462d6d6ac19188879f3b1c51e5094eb66e61763df22c0654174032f15613a53c0bed24920fd8601d0ac42465267b7eba01a6df3ab14dd039a32432003fd8c3db0501ae2046a76a8b1e56f456a2d40e2dd6e2e1ab77a8d96318778e8a61fe32d03407fc6a7429ec1fb66fc68c92e33310b3a574bde7818eb7089d392a30d07c85032a3d34fd589889ff6053fb19592dbb647a38063c5b403d64ee94859d9cf9b746041e5494ab7413f508d814c4b3bba29bca41d4464e1feb2bce27b3b081c85b455e035a138747

# Sanity check on the public key: with g = 1 + k*n, g - 1 shares exactly the
# factor n with n**2 (as long as gcd(k, n) == 1).
assert gcd(g - 1, n2) == n

mbits = 1024
b = mbits // 2  # index of the single plaintext bit the oracle reveals

f = None            # current connection to the oracle
nref = 999999999    # queries sent on `f`; starts high to force a first connect


def refresh():
    """Make sure `f` is connected, opening a fresh connection every 100 queries.

    NOTE(review): the periodic reconnect presumably keeps each session well
    inside the server's alarm(1200) budget -- confirm.
    """
    global f, nref
    nref += 1
    if nref >= 100:
        f = Sock("ppc2.chal.ctf.westerns.tokyo 38264", timeout=1000)
        # Reset the counter; otherwise every later query would reconnect.
        nref = 0


def oracle(add, aftermul=1):
    """Return bit b of ((Dec(CT) + add) * aftermul) mod n, as leaked by the server.

    CT * g**add encrypts Dec(CT) + add, and raising a ciphertext to the power
    aftermul multiplies its plaintext by aftermul (Paillier homomorphisms).
    """
    refresh()
    c = CT * pow(g, add, n2) % n2
    c = pow(c, aftermul, n2)
    f.send_line("%x" % c)
    res = f.read_line().strip()
    # The server answers each line with exactly one character, "0" or "1".
    assert res in ("0", "1"), res
    return int(res)


# --- Step 1: lower half, one bit per query, most significant first. ---
base = oracle(0)  # the leaked bit of the untouched message
known = 0         # bits pos+1 .. b-1 of the message recovered so far
for pos in reversed(range(b)):
    k = b - 1 - pos  # number of already known bits
    # Complement of the known bits, shifted up by one, plus a 1 at `pos`.
    add = (2**k - 1 - known) * 2 + 1
    # The leaked bit flips iff the carry out of `pos` reaches bit b.
    newbit = oracle(add << pos) ^ base
    known = (known << 1) + newbit
    print("new low %d %s" % (pos, format(known, "b").zfill(b - pos)))
low = known
assert low < 2**b

# --- Step 2: higher half. m + add == (high + 1) * 2**b. ---
add = 2**b - low
known = low
# Multiplying by 2**-i mod n divides exactly, moving bit i of (high + 1) onto
# bit b. i = 0 is included: bit 0 of high + 1 is not implied by anything else.
for i in range(b):
    new = oracle(add=add, aftermul=invmod(2**i, n))
    known |= new << (b + i)
    print("new high %d %s" % (b - i, format(known, "b").zfill(b + i + 1)))

# correct the zeroizing addition: the loop recovered (high + 1), not high
known -= 1 << b
message = known
print("message %x" % message)
# Same derivation as the challenge's key generator: sha1 of the decimal string.
print("TWCTF{" + sha1(str(message).encode("ascii")).hexdigest() + "}")

# Self-check: compare the oracle with the bit predicted from the recovered
# message for random (add, aftermul) pairs.
# NOTE(review): the original source of this check was partly lost when the
# gist was scraped (only "expected >>= b; expected &= 1; real = oracle(a, m);
# print; assert" survived); the choice of a, m, `expected` and the iteration
# count are a reconstruction -- confirm against the original gist.
for _ in range(10):
    a = randint(0, n - 1)
    m = randint(1, n - 1)
    expected = (message + a) * m % n
    expected >>= b
    expected &= 1
    real = oracle(a, m)
    print("%d %d" % (expected, real))
    assert expected == real
 # Python 3
# Challenge key generator: creates a Paillier key pair with g = 1 + k*n, a
# random (bits - 1)-bit secret message, its ciphertext and the flag, and
# writes them to the files "message", "flag", "secretkey", "publickey" and
# "ciphertext" in the current directory.
from Crypto.Util.number import *
from hashlib import sha1

bits = 1024


def LCM(x, y):
    """Return the least common multiple of x and y."""
    return x * y // GCD(x, y)


def L(x, n):
    """Paillier's L function, (x - 1) // n."""
    return (x - 1) // n


# Integer division: under Python 3, bits / 2 is the float 512.0, which
# getStrongPrime cannot use as a bit length.
p = getStrongPrime(bits // 2)
q = getStrongPrime(bits // 2)
n = p * q
n2 = n * n
# k >= 1: k = 0 would give g = 1, and then L(g**sk1) = 0 has no inverse mod n.
k = getRandomRange(1, n)
g = (1 + k * n) % n2
sk1 = LCM(p - 1, q - 1)                    # lambda
sk2 = inverse(L(pow(g, sk1, n2), n), n)    # mu
message = getRandomInteger(bits - 1)

with open("message", "w") as f:
    f.write(hex(message))
with open("flag", "w") as f:
    f.write("TWCTF{" + sha1(str(message).encode("ascii")).hexdigest() + "}\n")
with open("secretkey", "w") as f:
    f.write(hex(sk1) + "\n")
    f.write(hex(sk2) + "\n")
with open("publickey", "w") as f:
    f.write(hex(n) + "\n")
    f.write(hex(n2) + "\n")
    f.write(hex(g) + "\n")


def encrypt(m):
    """Paillier-encrypt m under the public key (n, g) with a fresh random r."""
    r = getRandomRange(1, n2)
    c = pow(g, m, n2) * pow(r, n, n2) % n2
    return c


ciphertext = encrypt(message)
with open("ciphertext", "w") as f:
    f.write(hex(ciphertext) + "\n")
 # Python 3
# Challenge oracle server: reads hex ciphertexts, one per line, from stdin and
# answers each with bit b = size(n) // 2 of its Paillier decryption.
# Imported explicitly: relying on the wildcard import below to leak `sys` is
# fragile.
import sys
from signal import alarm
from Crypto.Util.number import *
import Crypto.Random as Random

with open("secretkey", "r") as f:
    sk1 = int(f.readline(), 16)
    sk2 = int(f.readline(), 16)
with open("publickey", "r") as f:
    n = int(f.readline(), 16)
    n2 = int(f.readline(), 16)
    g = int(f.readline(), 16)

cbits = size(n2)
mbits = size(n)
b = mbits // 2  # index of the plaintext bit revealed per query


def L(x, n):
    """Paillier's L function, (x - 1) // n."""
    return (x - 1) // n


def decrypt(c, sk1, sk2, n, n2):
    """Paillier decryption: L(c**sk1 mod n2) * sk2 mod n."""
    return L(pow(c, sk1, n2), n) * sk2 % n


def run(fin, fout):
    """Answer decryption-bit queries from `fin` on `fout` until EOF or bad input.

    The session is capped at 1200 seconds by alarm().
    """
    alarm(1200)
    try:
        while True:
            line = fin.readline()
            if not line:  # EOF
                break
            # Only the first 4 + cbits // 4 characters are parsed.
            ciphertext = int(line[:4 + cbits // 4], 16)  # Note: input is HEX
            m = decrypt(ciphertext, sk1, sk2, n, n2)
            fout.write(str((m >> b) & 1) + "\n")
            fout.flush()
    except (ValueError, OSError):
        # Malformed input (non-hex, undecodable) or a closed pipe ends the
        # session quietly.
        pass


if __name__ == "__main__":
    run(sys.stdin, sys.stdout)
