pbctf 2021 Yet Another PRNG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
# Load the challenge transcript: line 1 is the hex dump of the leaked PRNG
# outputs (8 bytes each, big-endian), line 2 is the hex-encoded ciphertext.
with open("output.txt") as f:
    stream = bytes.fromhex(f.readline().strip())
    states = [
        int.from_bytes(stream[i : i + 8], "big") for i in range(0, len(stream), 8)
    ]
    # Explicit check instead of `assert`, so it is not stripped under -O.
    if len(states) != 12:
        raise ValueError(f"expected 12 leaked outputs, got {len(states)}")
    ct = bytes.fromhex(f.readline().strip())
# Moduli of the three order-3 linear recurrences, and of the output mix.
m1, m2, m3 = (2 ** 32 - d for d in (107, 5, 209))
M = 2 ** 64 - 59
# The recurrence coefficients come from a fixed seed, so they can be
# regenerated exactly: three 20-bit values per recurrence, drawn in the
# order a1, a2, a3.
rnd = random.Random(b"rbtree")
a1, a2, a3 = ([rnd.getrandbits(20) for _ in range(3)] for _ in range(3))
# Symbolic model of the 12 leaked outputs.  Unknowns:
#   x0..x2, y0..y2, z0..z2 -- the 9 initial recurrence states;
#   k0..k26  -- 9 reduction quotients per recurrence (x, then y, then z);
#   k27..k38 -- 12 quotients of the final reduction mod M.
numk = 9 * 3 + 12
varstr = "x0,x1,x2,y0,y1,y2,z0,z1,z2," + ",".join(f"k{i}" for i in range(numk))
P = PolynomialRing(ZZ, varstr)
x0, x1, x2, y0, y1, y2, z0, z1, z2, *ks = P.gens()
iks = iter(ks)


def _unroll(seed, coeffs, modulus, quotients, length=12):
    """Extend `seed` to `length` terms of the mod-`modulus` recurrence.

    Each reduction `t mod modulus` is written exactly as `t - k * modulus`,
    where `k` is the next fresh variable taken from the iterator `quotients`.
    """
    seq = list(seed)
    while len(seq) < length:
        seq.append(
            sum(c * s for c, s in zip(coeffs, seq[-3:])) - next(quotients) * modulus
        )
    return seq


# Order matters: x consumes k0..k8, y k9..k17, z k18..k26 and the outputs
# k27..k38.  The lattice layout built from `sts` relies on this order.
xs = _unroll([x0, x1, x2], a1, m1, iks)
ys = _unroll([y0, y1, y2], a2, m2, iks)
zs = _unroll([z0, z1, z2], a3, m3, iks)
sts = [2 * m1 * x - m3 * y - m2 * z - next(iks) * M for x, y, z in zip(xs, ys, zs)]
# Turn the 12 linear equations into a lattice basis.  Rows follow the
# variable order x0..z2, k0..k38 (the monomial order of coefficient_matrix).
# The 48 columns are:
#   0..11   the output equations, which must equal the leaked `states`;
#   12..35  identity columns on x0..z2 and k0..k14;
#   36..47  identity columns on k27..k38 (output-reduction quotients).
# k15..k26 get no column of their own, which leaves the lattice 48 x 48.
# A separate name is used so the modulus `M` is not overwritten.
coef_mat, monomials = Sequence(sts).coefficient_matrix()
print(vector(monomials))
lattice = coef_mat.dense_matrix().T
selector = (
    matrix.identity(24).augment(matrix(24, 12))  # x0..z2, k0..k14
    .stack(matrix(12, 36))  # k15..k26: no column
    .stack(matrix(12, 24).augment(matrix.identity(12)))  # k27..k38
)
lattice = lattice.augment(selector)
print(lattice.dimensions())
# Per-column bounds, in the column order above: outputs pinned exactly,
# seeds in [0, 2**32], k0..k14 in [0, 2**20], k27..k38 in [-3, 0].
# NOTE(review): the quotient ranges look like empirical choices rather than
# derived bounds.  With 20-bit coefficients and 32-bit states, a recurrence
# quotient can reach about 3 * 2**20.
lb = states + [0] * 9 + [0] * (36 - 9 - 12) + [-3] * 12
ub = states + [2 ** 32] * 9 + [2 ** 20] * (36 - 9 - 12) + [0] * 12
load("solver.sage")  # https://github.com/rkm0959/Inequality_Solving_with_CVP
# `solve` mutates the matrix it receives.  lb/ub are handed over as copies
# via list(...), so the module-level lists stay intact.  Decode `result`
# against that same, now-mutated `lattice` object.
result, applied_weights, fin = solve(lattice, list(lb), list(ub))
init_states = list(lattice.solve_left(result)[:9])
print(init_states)
class PRNG:
    """Replica of the challenge generator, seeded with recovered states.

    Three order-3 linear recurrences (mod m1, m2, m3), whose coefficients are
    regenerated from ``random.Random(b"rbtree")``, are mixed into one output
    per step.  Each output is reduced mod M and serialized as 8 big-endian bytes.
    """

    def __init__(self, init_states):
        """Seed from 9 integers: x0..x2, y0..y2, z0..z2, in that order.

        Any iterable is accepted; it is copied into a fresh list.

        Raises:
            ValueError: if ``init_states`` does not hold exactly 9 values.
        """
        # Copy into a list: `out` rebuilds the windows with `list + [..]`,
        # and the caller's sequence must not be aliased.
        init_states = list(init_states)
        if len(init_states) != 9:
            raise ValueError(f"expected 9 initial states, got {len(init_states)}")
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
        rnd = random.Random(b"rbtree")
        self.a1 = [rnd.getrandbits(20) for _ in range(3)]
        self.a2 = [rnd.getrandbits(20) for _ in range(3)]
        self.a3 = [rnd.getrandbits(20) for _ in range(3)]
        self.x = init_states[:3]
        self.y = init_states[3:6]
        self.z = init_states[6:]

    @staticmethod
    def _advance(window, coeffs, modulus):
        """Drop the oldest term of `window` and append the next recurrence term."""
        return window[1:] + [sum(s * c for s, c in zip(window, coeffs)) % modulus]

    def out(self):
        """Return the next output as 8 big-endian bytes, then step all recurrences."""
        o = (
            2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]
        ) % self.M
        self.x = self._advance(self.x, self.a1, self.m1)
        self.y = self._advance(self.y, self.a2, self.m2)
        self.z = self._advance(self.z, self.a3, self.m3)
        return int(o).to_bytes(8, byteorder="big")
# Sanity check: the replica must reproduce every leaked output before its
# continuing keystream is trusted for decryption.
prng = PRNG(init_states)
for o in states:
    assert int(o).to_bytes(8, "big") == prng.out(), "recovered seed does not match leak"
# XOR the ciphertext with the keystream, one 8-byte PRNG output per block.
# zip truncates the last keystream block to a short ciphertext tail.
# `^^` is Sage's XOR (the Sage preparser turns `^` into exponentiation).
flag = b"".join(
    bytes(a ^^ b for a, b in zip(ct[i : i + 8], prng.out()))
    for i in range(0, len(ct), 8)
)
print(flag)
# pbctf{Wow_how_did_you_solve_this?_I_thought_this_is_super_secure._Thank_you_for_solving_this!!!} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import os | |
def urand(b):
    """Return an integer in range(256 ** b) built from `b` bytes of os.urandom."""
    raw = os.urandom(b)
    return int.from_bytes(raw, byteorder="big")
class PRNG:
    """Local copy of the challenge PRNG, instrumented for testing.

    It seeds itself from `urand` (OS randomness).  On every call to `out`, it
    appends the current head of each recurrence to the module-level lists
    `xhist`, `yhist` and `zhist`.
    """

    def __init__(self):
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
        rnd = random.Random(b"rbtree")
        self.a1, self.a2, self.a3 = (
            [rnd.getrandbits(20) for _ in range(3)] for _ in range(3)
        )
        self.x, self.y, self.z = ([urand(4) for _ in range(3)] for _ in range(3))

    def out(self):
        """Record the current heads, return the next raw output (an int), then step."""
        # The history lists are only mutated via .append, never rebound, so a
        # `global` declaration is unnecessary.
        for hist, window in ((xhist, self.x), (yhist, self.y), (zhist, self.z)):
            hist.append(window[0])
        o = (
            2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]
        ) % self.M
        self.x = self.x[1:] + [sum(s * c for s, c in zip(self.x, self.a1)) % self.m1]
        self.y = self.y[1:] + [sum(s * c for s, c in zip(self.y, self.a2)) % self.m2]
        self.z = self.z[1:] + [sum(s * c for s, c in zip(self.z, self.a3)) % self.m3]
        return o
# History buffers filled by PRNG.out(): entry i holds the recurrence heads
# used to produce output i.
xhist, yhist, zhist = [], [], []
prng = PRNG()
# Snapshot the 9 secret initial states before any output is drawn.
xx, yy, zz = list(prng.x), list(prng.y), list(prng.z)
states = [prng.out() for _ in range(12)]
print(states)
m1 = 2 ** 32 - 107
m2 = 2 ** 32 - 5
m3 = 2 ** 32 - 209
M = 2 ** 64 - 59
# Reuse the coefficients the instrumented PRNG actually drew.
a1 = prng.a1
a2 = prng.a2
a3 = prng.a3
# Symbolic model of the outputs.  Unknowns: x0..z2 (9 initial states),
# k0..k26 (9 reduction quotients per recurrence: x, then y, then z) and
# k27..k38 (12 quotients of the final reduction mod M).
numk = 9 * 3 + 12
varstr = "x0,x1,x2,y0,y1,y2,z0,z1,z2," + ",".join(f"k{i}" for i in range(numk))
P = PolynomialRing(ZZ, varstr)
x0, x1, x2, y0, y1, y2, z0, z1, z2, *ks = P.gens()
iks = iter(ks)


def _unroll(seed, coeffs, modulus, quotients, length=12):
    """Extend `seed` to `length` terms of the mod-`modulus` recurrence.

    Each reduction `t mod modulus` is written exactly as `t - k * modulus`,
    where `k` is the next fresh variable taken from the iterator `quotients`.
    """
    seq = list(seed)
    while len(seq) < length:
        seq.append(
            sum(c * s for c, s in zip(coeffs, seq[-3:])) - next(quotients) * modulus
        )
    return seq


# Order matters: x consumes k0..k8, y k9..k17, z k18..k26.  The remaining
# k27..k38 are left in `iks` for the output equations.
xs = _unroll([x0, x1, x2], a1, m1, iks)
ys = _unroll([y0, y1, y2], a2, m2, iks)
zs = _unroll([z0, z1, z2], a3, m3, iks)
from itertools import tee


def _mix(x, y, z):
    """Pre-reduction output combination 2*m1*x - m3*y - m2*z."""
    return 2 * m1 * x - m3 * y - m2 * z


# Duplicate the remaining quotient variables (k27..k38) so the symbolic
# outputs `sts` and the history-based outputs `ssts` share them.
iks, iks2 = tee(iks)
sts = [_mix(x, y, z) - next(iks) * M for x, y, z in zip(xs, ys, zs)]
ssts = [_mix(x, y, z) - next(iks2) * M for x, y, z in zip(xhist, yhist, zhist)]
# With every quotient set to 0, each ssts entry reduced mod M is printed
# next to the matching leaked output; the two should be equal.
zero_ks = [0] * numk
for s, o in zip(ssts, states):
    print(s(xx + yy + zz + zero_ks) % M, o)
# -----
# Recompute the true quotient of every reduction from the recorded history,
# in the same order as the variables k0..k38.


def _recurrence_quotients(hist, coeffs, modulus):
    """Quotients k with hist[n + 3] == sum(coeffs . hist[n:n + 3]) - k * modulus."""
    found = []
    for n in range(len(hist) - 3):
        lhs = hist[n + 3]
        rhs = sum(c * h for c, h in zip(coeffs, hist[n : n + 3]))
        assert (rhs - lhs) % modulus == 0
        found.append((rhs - lhs) // modulus)
    return found


testks = (
    _recurrence_quotients(xhist, a1, m1)
    + _recurrence_quotients(yhist, a2, m2)
    + _recurrence_quotients(zhist, a3, m3)
)
# Output reduction: o == 2*m1*x - m3*y - m2*z - k*M
for o, x, y, z in zip(states, xhist, yhist, zhist):
    lhs = o
    rhs = 2 * m1 * x - m3 * y - m2 * z
    assert (rhs - lhs) % M == 0
    testks.append((rhs - lhs) // M)
print("testks", testks)
assert len(testks) == numk
# Plugging the true seeds and quotients into the symbolic outputs should
# print each leaked output twice per line.
for s, o in zip(sts, states):
    print(s(xx + yy + zz + testks), o)
# -----
# M, _ = Sequence(sts).coefficient_matrix() | |
# print(vector(_)) | |
# M = M.dense_matrix().T | |
# A = matrix.identity(9) | |
# A = A.augment(A) | |
# A = A.augment(A) | |
# A = A.stack(matrix(39, 36)) | |
# M = M.augment(A) | |
# Lattice basis: rows follow the variable order x0..z2, k0..k38.  Columns are
# the 12 output equations, identity columns on x0..z2 and k0..k14, nothing
# for k15..k26, and identity columns on k27..k38.  A separate name keeps the
# modulus `M` from being overwritten.
coef_mat, monomials = Sequence(sts).coefficient_matrix()
print(vector(monomials))
lattice = coef_mat.dense_matrix().T
selector = (
    matrix.identity(24).augment(matrix(24, 12))  # x0..z2, k0..k14
    .stack(matrix(12, 36))  # k15..k26: no column
    .stack(matrix(12, 24).augment(matrix.identity(12)))  # k27..k38
)
lattice = lattice.augment(selector)
target = xx + yy + zz
# ---
# The true assignment must map to (outputs, seeds, k0..k14, k27..k38).
assert vector(target + testks) * lattice == vector(
    states + target + testks[: 24 - 9] + testks[-12:]
)
# ---
print(testks[-12:])
lb = states + [0] * 9 + [0] * (36 - 9 - 12) + [-3] * 12
ub = states + [2 ** 32] * 9 + [2 ** 20] * (36 - 9 - 12) + [0] * 12
load("solver.sage")
result, applied_weights, fin = solve(lattice, list(lb), list(ub))
# Decode once; `lattice` is the same object that was passed to `solve`.
recovered = lattice.solve_left(result)
assert recovered[:9] == vector(target), "failed"
# Although the recovered ks aren't the true ones, the recovered initial
# states are correct.
print(recovered)
print(vector(target + testks))
# The author attributes this to the left kernel lying mostly in the k
# dimensions.  The rank and kernel basis are printed to inspect that.
print(lattice.rank())
print(lattice.left_kernel().basis())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment