Last active: October 11, 2021 08:38
Save maple3142/38d06519f6e7d4dd12b526792682d70b to your computer and use it in GitHub Desktop.
pbctf 2021 Yet Another PRNG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random

with open("output.txt") as fh:
    # Line 1: the leaked PRNG outputs, hex-encoded, 8 big-endian bytes each.
    leaked = bytes.fromhex(fh.readline().strip())
    states = [
        int.from_bytes(leaked[pos : pos + 8], "big")
        for pos in range(0, len(leaked), 8)
    ]
    assert len(states) == 12
    # Line 2: the hex-encoded ciphertext that is decrypted at the end.
    ct = bytes.fromhex(fh.readline().strip())
# Moduli of the three 32-bit LCGs (m1, m2, m3) and of the combined output (M).
m1, m2, m3 = 2 ** 32 - 107, 2 ** 32 - 5, 2 ** 32 - 209
M = 2 ** 64 - 59

# The 20-bit LCG multipliers come from a fixed seed, so they are public.
# They are drawn in the order a1, then a2, then a3 (three words each).
rnd = random.Random(b"rbtree")
a1, a2, a3 = [[rnd.getrandbits(20) for _ in range(3)] for _ in range(3)]
# Unknowns: the 9 seed words x0..z2, plus one integer "carry" k per modular
# reduction. Each LCG adds 9 new terms (3 -> 12), and each of the 12 outputs
# adds one more reduction, giving 9 * 3 + 12 carries in total.
numk = 9 * 3 + 12
varstr = "x0,x1,x2,y0,y1,y2,z0,z1,z2," + ",".join(f"k{i}" for i in range(numk))
P = PolynomialRing(ZZ, varstr)
x0, x1, x2, y0, y1, y2, z0, z1, z2, *ks = P.gens()


def _unroll_lcg(seeds, coeffs, modulus, kvars, length=12):
    """Symbolically extend a linear recurrence to ``length`` terms.

    Each new term is ``sum(coeffs[i] * window[i]) - k * modulus``. The window
    is the last ``len(coeffs)`` terms, oldest first, and ``k`` is the next
    variable taken from the iterator ``kvars``. The ``- k * modulus`` stands
    in for the generator's ``% modulus``, so every term stays a linear
    polynomial over ZZ.
    """
    seq = list(seeds)
    width = len(coeffs)
    while len(seq) < length:
        window = seq[-width:]
        seq.append(sum(c * t for c, t in zip(coeffs, window)) - next(kvars) * modulus)
    return seq


# The carries are consumed in order: k0..k8 for x, k9..k17 for y,
# k18..k26 for z, and k27..k38 for the outputs.
iks = iter(ks)
xs = _unroll_lcg([x0, x1, x2], a1, m1, iks)
ys = _unroll_lcg([y0, y1, y2], a2, m2, iks)
zs = _unroll_lcg([z0, z1, z2], a3, m3, iks)
# One relation per leaked output; its carry absorbs the final "% M".
sts = [2 * m1 * x - m3 * y - m2 * z - next(iks) * M for x, y, z in zip(xs, ys, zs)]
# Coefficient matrix of the 12 output relations. After transposing there is
# one row per monomial (the printed vector gives the row order) and one
# column per relation. It is named `lat` so that the modulus `M` is not
# overwritten.
coeffs, monomials = Sequence(sts).coefficient_matrix()
print(vector(monomials))
lat = coeffs.dense_matrix().T

# Extra identity columns that expose individual unknowns to the bounds:
#   * the first 24 rows: x0..z2 and k0..k14;
#   * the last 12 rows: k27..k38, the output carries.
# Rows k15..k26 get no column of their own. With 12 + 36 columns against
# 9 + numk = 48 rows the lattice is square; presumably that is why only
# these unknowns are exposed.
exposed_head = matrix.identity(24).augment(matrix(24, 12))
hidden_mid = matrix(12, 36)
exposed_tail = matrix(12, 24).augment(matrix.identity(12))
lat = lat.augment(exposed_head.stack(hidden_mid).stack(exposed_tail))
print(lat.dimensions())

# Per-column bounds:
#   * the 12 relations must equal the leaked outputs exactly;
#   * the seeds lie in [0, 2^32];
#   * the exposed LCG carries lie in [0, 2^20];
#   * the output carries lie in [-3, 0].
lb = states + [0] * 9 + [0] * (36 - 9 - 12) + [-3] * 12
ub = states + [2 ** 32] * 9 + [2 ** 20] * (36 - 9 - 12) + [0] * 12

load("solver.sage")  # https://github.com/rkm0959/Inequality_Solving_with_CVP
# `solve` mutates the matrix and bound lists it is given. lb/ub are passed as
# copies, but `lat` itself is passed on purpose: `lat.solve_left(result)`
# below presumably needs the same column rescaling that `result` carries.
# NOTE(review): do not pass a copy of `lat` without re-checking the solver.
result, applied_weights, fin = solve(lat, list(lb), list(ub))
init_states = list(lat.solve_left(result)[:9])
print(init_states)
class PRNG:
    """Replay of the challenge generator from a known 9-word seed.

    ``init_states`` holds the seed words in the order x0, x1, x2, y0, y1, y2,
    z0, z1, z2. Each call to ``out`` returns one 8-byte big-endian output and
    advances all three LCGs by one step.
    """

    def __init__(self, init_states):
        # Moduli of the three LCGs and of the combined output.
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
        # The multipliers come from a fixed seed and are drawn as a1, a2, a3.
        rnd = random.Random(b"rbtree")
        self.a1, self.a2, self.a3 = [
            [rnd.getrandbits(20) for _ in range(3)] for _ in range(3)
        ]
        self.x, self.y, self.z = init_states[:3], init_states[3:6], init_states[6:]

    def out(self):
        def advance(regs, mults, modulus):
            # Shift out the oldest word and append the next recurrence term.
            return regs[1:] + [sum(r * m for r, m in zip(regs, mults)) % modulus]

        # The output mixes the current leading word of each LCG.
        combined = 2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]
        o = combined % self.M
        self.x = advance(self.x, self.a1, self.m1)
        self.y = advance(self.y, self.a2, self.m2)
        self.z = advance(self.z, self.a3, self.m3)
        return int(o).to_bytes(8, byteorder="big")
prng = PRNG(init_states)
# The recovered seed must reproduce every leaked output before it is trusted.
for leaked in states:
    assert int(leaked).to_bytes(8, "big") == prng.out()
# Decrypt: XOR each 8-byte ciphertext block with the next PRNG output.
flag = b"".join(
    bytes(c ^^ k for c, k in zip(ct[pos : pos + 8], prng.out()))
    for pos in range(0, len(ct), 8)
)
print(flag)
# pbctf{Wow_how_did_you_solve_this?_I_thought_this_is_super_secure._Thank_you_for_solving_this!!!}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import random


def urand(b):
    """Return a non-negative int built from ``b`` bytes of os.urandom (big-endian)."""
    raw = os.urandom(b)
    return int.from_bytes(raw, byteorder="big")
class PRNG:
    """Challenge generator with random seeds that also logs its state words.

    Each call to ``out`` first appends the leading word of the x, y and z LCG
    states to the module-level lists ``xhist``, ``yhist`` and ``zhist``. It
    then returns the combined output (an integer reduced modulo ``M``) and
    advances every LCG by one step.
    """

    def __init__(self):
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
        # The multipliers are fixed by a constant seed. The state words are
        # fresh 4-byte values from urand, drawn for x, then y, then z.
        rnd = random.Random(b"rbtree")
        self.a1, self.a2, self.a3 = [
            [rnd.getrandbits(20) for _ in range(3)] for _ in range(3)
        ]
        self.x, self.y, self.z = [[urand(4) for _ in range(3)] for _ in range(3)]

    def out(self):
        # The history lists are only appended to, never rebound, so they can be
        # used here without a `global` declaration.
        xhist.append(self.x[0])
        yhist.append(self.y[0])
        zhist.append(self.z[0])

        def advance(regs, mults, modulus):
            # Shift out the oldest word and append the next recurrence term.
            return regs[1:] + [sum(r * m for r, m in zip(regs, mults)) % modulus]

        # Computed from the pre-step state words.
        mixed = 2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]
        self.x = advance(self.x, self.a1, self.m1)
        self.y = advance(self.y, self.a2, self.m2)
        self.z = advance(self.z, self.a3, self.m3)
        return mixed % self.M


# Per-step history of the leading state word of each LCG, filled by PRNG.out.
xhist = []
yhist = []
zhist = []
prng = PRNG()
# Snapshot the secret seed words before any output is drawn; they are the
# target values the lattice attack has to recover.
xx, yy, zz = (list(words) for words in (prng.x, prng.y, prng.z))
# The 12 leaked outputs the attack gets to see.
states = [prng.out() for _ in range(12)]
print(states)
m1 = 2 ** 32 - 107
m2 = 2 ** 32 - 5
m3 = 2 ** 32 - 209
M = 2 ** 64 - 59
# Reuse the multipliers of the live generator.
a1 = prng.a1
a2 = prng.a2
a3 = prng.a3

# Unknowns: the 9 seed words x0..z2, plus one integer "carry" k per modular
# reduction. Each LCG adds 9 new terms (3 -> 12), and each of the 12 outputs
# adds one more reduction, giving 9 * 3 + 12 carries in total.
numk = 9 * 3 + 12
varstr = "x0,x1,x2,y0,y1,y2,z0,z1,z2," + ",".join(f"k{i}" for i in range(numk))
P = PolynomialRing(ZZ, varstr)
x0, x1, x2, y0, y1, y2, z0, z1, z2, *ks = P.gens()


def _unroll_lcg(seeds, coeffs, modulus, kvars, length=12):
    """Symbolically extend a linear recurrence to ``length`` terms.

    Each new term is ``sum(coeffs[i] * window[i]) - k * modulus``. The window
    is the last ``len(coeffs)`` terms, oldest first, and ``k`` is the next
    variable taken from the iterator ``kvars``. The ``- k * modulus`` stands
    in for the generator's ``% modulus``, so every term stays a linear
    polynomial over ZZ.
    """
    seq = list(seeds)
    width = len(coeffs)
    while len(seq) < length:
        window = seq[-width:]
        seq.append(sum(c * t for c, t in zip(coeffs, window)) - next(kvars) * modulus)
    return seq


# The carries are consumed in order: k0..k8 for x, k9..k17 for y and
# k18..k26 for z. The rest of `iks` is left for the output relations.
iks = iter(ks)
xs = _unroll_lcg([x0, x1, x2], a1, m1, iks)
ys = _unroll_lcg([y0, y1, y2], a2, m2, iks)
zs = _unroll_lcg([z0, z1, z2], a3, m3, iks)
from itertools import tee

# Duplicate the remaining carry iterator (k27..k38) so the output relations
# can be built twice. `sts` uses the symbolic seeds; `ssts` uses the recorded
# concrete state words.
iks, iks2 = tee(iks)
sts = [
    2 * m1 * x - m3 * y - m2 * z - next(iks) * M
    for x, y, z in zip(xs, ys, zs)
]
ssts = [
    2 * m1 * x - m3 * y - m2 * z - next(iks2) * M
    for x, y, z in zip(xhist, yhist, zhist)
]
# With every carry set to 0, each concrete relation reduced mod M should
# print next to the matching observed output.
point = xx + yy + zz + [0] * numk
for rel, observed in zip(ssts, states):
    print(rel(point) % M, observed)
# -----
# Ground truth for the carries: recompute every k from the recorded states.


def _lcg_carries(hist, coeffs, modulus):
    """Return the carries k for consecutive windows of ``hist``.

    With ``w = len(coeffs)``, each k satisfies
    ``hist[i + w] == sum(coeffs[j] * hist[i + j]) - k * modulus``.
    Asserts that every step really is an exact reduction modulo ``modulus``.
    """
    width = len(coeffs)
    carries = []
    for end in range(width, len(hist)):
        # lhs = rhs - k * modulus
        lhs = hist[end]
        rhs = sum(c * h for c, h in zip(coeffs, hist[end - width : end]))
        assert (rhs - lhs) % modulus == 0
        carries.append((rhs - lhs) // modulus)
    return carries


# Same order as the symbolic carries: x, then y, then z, then the outputs.
testks = (
    _lcg_carries(xhist, a1, m1)
    + _lcg_carries(yhist, a2, m2)
    + _lcg_carries(zhist, a3, m3)
)
for o, x, y, z in zip(states, xhist, yhist, zhist):
    # o = (2*m1*x - m3*y - m2*z) - k * M
    lhs = o
    rhs = 2 * m1 * x - m3 * y - m2 * z
    assert (rhs - lhs) % M == 0
    testks.append((rhs - lhs) // M)
print("testks", testks)
assert len(testks) == numk
# With the true seeds and carries, each relation should print next to its output.
for s, o in zip(sts, states):
    print(s(xx + yy + zz + testks), o)
# -----
# Build the same lattice as the solve script. It is named `lat` so that the
# modulus `M` is not overwritten. Rows follow the printed monomial vector;
# columns are the 12 output relations followed by 36 identity columns. Those
# identity columns expose x0..z2 and k0..k14 (the first 24 rows) and
# k27..k38 (the last 12 rows).
coeffs, monomials = Sequence(sts).coefficient_matrix()
print(vector(monomials))
lat = coeffs.dense_matrix().T
exposed_head = matrix.identity(24).augment(matrix(24, 12))
hidden_mid = matrix(12, 36)
exposed_tail = matrix(12, 24).augment(matrix.identity(12))
lat = lat.augment(exposed_head.stack(hidden_mid).stack(exposed_tail))
target = xx + yy + zz
# ---
# The true seeds and carries must map to the observed outputs, followed by
# the values in the exposed columns: seeds, k0..k14, k27..k38.
assert vector(target + testks) * lat == vector(
    states + target + testks[: 24 - 9] + testks[-12:]
)
# ---
print(testks[-12:])
lb = states + [0] * 9 + [0] * (36 - 9 - 12) + [-3] * 12
ub = states + [2 ** 32] * 9 + [2 ** 20] * (36 - 9 - 12) + [0] * 12
load("solver.sage")
# NOTE(review): `solve` is believed to rescale its matrix argument in place.
# `lat.solve_left(result)` below presumably relies on `lat` and `result`
# sharing that scaling, so pass `lat` itself here, not a copy.
result, applied_weights, fin = solve(lat, list(lb), list(ub))
assert lat.solve_left(result)[:9] == vector(target), "failed"
# The recovered carries are not the true ones, but the recovered initial
# states are correct:
print(lat.solve_left(result))
print(vector(target + testks))
# This is because the left kernel lies mostly along the k coordinates.
print(lat.rank())
print(lat.left_kernel().basis())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.