Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@maple3142
Last active October 11, 2021 08:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maple3142/38d06519f6e7d4dd12b526792682d70b to your computer and use it in GitHub Desktop.
pbctf 2021 Yet Another PRNG
import random
# Leaked data: first line is hex of twelve 64-bit big-endian outputs,
# second line is the hex-encoded ciphertext.
with open("output.txt") as f:
    stream = bytes.fromhex(f.readline().strip())
    states = [
        int.from_bytes(stream[i : i + 8], "big") for i in range(0, len(stream), 8)
    ]
    assert len(states) == 12
    ct = bytes.fromhex(f.readline().strip())

m1 = 2 ** 32 - 107
m2 = 2 ** 32 - 5
m3 = 2 ** 32 - 209
M = 2 ** 64 - 59
# The multipliers come from a fixed public seed, so they can be regenerated.
rnd = random.Random(b"rbtree")
a1 = [rnd.getrandbits(20) for _ in range(3)]
a2 = [rnd.getrandbits(20) for _ in range(3)]
a3 = [rnd.getrandbits(20) for _ in range(3)]

# Unknowns: the nine initial states plus one integer quotient k per modular
# reduction -- 9 per recurrence (terms 3..11 of x, y and z) and 12 for the
# final reduction of each output mod M.
numk = 9 * 3 + 12
varstr = "x0,x1,x2,y0,y1,y2,z0,z1,z2," + ",".join(f"k{i}" for i in range(numk))
P = PolynomialRing(ZZ, varstr)
x0, x1, x2, y0, y1, y2, z0, z1, z2, *ks = P.gens()
iks = iter(ks)
xs = [x0, x1, x2]
while len(xs) != 12:
    xs += [sum(x * y for x, y in zip(a1, xs[-3:])) - next(iks) * m1]
ys = [y0, y1, y2]
while len(ys) != 12:
    ys += [sum(x * y for x, y in zip(a2, ys[-3:])) - next(iks) * m2]
zs = [z0, z1, z2]
while len(zs) != 12:
    zs += [sum(x * y for x, y in zip(a3, zs[-3:])) - next(iks) * m3]
# Each leaked output written as a linear polynomial in the unknowns.
sts = [2 * m1 * x - m3 * y - m2 * z - next(iks) * M for x, y, z in zip(xs, ys, zs)]

# Lattice basis. Named `lattice` rather than `M` so the modulus M is not
# silently overwritten.
lattice, monomials = Sequence(sts).coefficient_matrix()
print(vector(monomials))
lattice = lattice.dense_matrix().T
# Extra columns that copy selected unknowns into the output so they can be
# bounded: the first 24 rows get identity columns, the next 12 rows get no
# extra column, and the last 12 rows map onto the final 12 columns.
# NOTE(review): assumes coefficient_matrix orders monomials x0..z2, k0..k38;
# confirm against the printed monomial vector.
A = matrix.identity(24)
A = A.augment(matrix(24, 12))
B = matrix(12, 36)
C = matrix(12, 24)
C = C.augment(matrix.identity(12))
A = A.stack(B).stack(C)
lattice = lattice.augment(A)
print(lattice.dimensions())
# Bounds per output column: the 12 outputs exactly, the 9 seeds in
# [0, 2^32], the next 15 quotients in [0, 2^20], the mod-M quotients in [-3, 0].
lb = states + [0] * 9 + [0] * (36 - 9 - 12) + [-3] * 12
ub = states + [2 ** 32] * 9 + [2 ** 20] * (36 - 9 - 12) + [0] * 12
load("solver.sage")  # https://github.com/rkm0959/Inequality_Solving_with_CVP
# solve() mutates the matrix and bound lists it is given. lb/ub are passed
# as copies, so only `lattice` changes; solve_left is then applied to that
# same mutated matrix, presumably matching the scaling of `result`.
result, applied_weights, fin = solve(lattice, list(lb), list(ub))
init_states = list(lattice.solve_left(result)[:9])
print(init_states)
class PRNG:
    """Replay the challenge generator from nine recovered initial states.

    Three order-3 linear recurrences over the moduli ``m1``, ``m2`` and
    ``m3`` are stepped in lockstep. Each call to :meth:`out` combines their
    current heads into one value modulo ``M`` and returns it as 8
    big-endian bytes.
    """

    def __init__(self, init_states):
        """Seed the three recurrences.

        Args:
            init_states: nine integers in any sequence type. Elements 0-2
                seed x, 3-5 seed y, and 6-8 seed z.

        Raises:
            ValueError: if ``init_states`` does not hold exactly nine values.
        """
        # Normalise to a plain list: out() extends the state with
        # ``self.x[1:] + [...]``, which raises TypeError for tuples.
        init_states = list(init_states)
        if len(init_states) != 9:
            raise ValueError(f"expected 9 initial states, got {len(init_states)}")
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
        # The multipliers come from a fixed public seed; the draw order
        # (a1, then a2, then a3) must match the challenge exactly.
        rnd = random.Random(b"rbtree")
        self.a1 = [rnd.getrandbits(20) for _ in range(3)]
        self.a2 = [rnd.getrandbits(20) for _ in range(3)]
        self.a3 = [rnd.getrandbits(20) for _ in range(3)]
        self.x = init_states[:3]
        self.y = init_states[3:6]
        self.z = init_states[6:]

    def out(self):
        """Return the next output as 8 big-endian bytes and advance the state."""
        o = (
            2 * self.m1 * self.x[0] - self.m3 * self.y[0] - self.m2 * self.z[0]
        ) % self.M
        self.x = self.x[1:] + [sum(x * y for x, y in zip(self.x, self.a1)) % self.m1]
        self.y = self.y[1:] + [sum(x * y for x, y in zip(self.y, self.a2)) % self.m2]
        self.z = self.z[1:] + [sum(x * y for x, y in zip(self.z, self.a3)) % self.m3]
        return int(o).to_bytes(8, byteorder="big")
prng = PRNG(init_states)
# Sanity check: the recovered seeds must reproduce every leaked output word.
for leaked in states:
    assert int(leaked).to_bytes(8, "big") == prng.out()
# The ciphertext is XORed with the continuing keystream, 8 bytes per output.
chunks = [ct[pos : pos + 8] for pos in range(0, len(ct), 8)]
flag = b"".join(
    bytes(c ^^ k for c, k in zip(chunk, prng.out())) for chunk in chunks
)
print(flag)
# pbctf{Wow_how_did_you_solve_this?_I_thought_this_is_super_secure._Thank_you_for_solving_this!!!}
import random
import os
def urand(b):
    """Return a non-negative int read big-endian from *b* bytes of OS randomness."""
    raw = os.urandom(b)
    return int.from_bytes(raw, byteorder="big")
class PRNG:
    """Challenge generator with fresh random seeds that records its true state.

    Computes the same outputs as the challenge PRNG. In addition, every call
    to out() appends the current heads of the x, y and z recurrences to the
    module-level lists xhist, yhist and zhist, which serve as ground truth.
    """

    def __init__(self):
        self.m1 = 2 ** 32 - 107
        self.m2 = 2 ** 32 - 5
        self.m3 = 2 ** 32 - 209
        self.M = 2 ** 64 - 59
        # Multipliers: fixed public seed, drawn in the order a1, a2, a3.
        seeded = random.Random(b"rbtree")
        coeff_sets = []
        for _seq in range(3):
            coeff_sets.append([seeded.getrandbits(20) for _ in range(3)])
        self.a1, self.a2, self.a3 = coeff_sets
        # Secret seeds: three 32-bit words per recurrence.
        self.x, self.y, self.z = [[urand(4) for _ in range(3)] for _seq in range(3)]

    @staticmethod
    def _advance(window, coeffs, modulus):
        """Drop the oldest term of *window* and append the next recurrence term."""
        nxt = sum(w * c for w, c in zip(window, coeffs)) % modulus
        return window[1:] + [nxt]

    def out(self):
        """Record the current heads, then return the next raw output mod M."""
        head_x, head_y, head_z = self.x[0], self.y[0], self.z[0]
        # The history lists are module-level and only mutated in place here.
        xhist.append(head_x)
        yhist.append(head_y)
        zhist.append(head_z)
        value = (2 * self.m1 * head_x - self.m3 * head_y - self.m2 * head_z) % self.M
        self.x = self._advance(self.x, self.a1, self.m1)
        self.y = self._advance(self.y, self.a2, self.m2)
        self.z = self._advance(self.z, self.a3, self.m3)
        return value
# Ground-truth histories appended to by PRNG.out().
xhist, yhist, zhist = [], [], []
prng = PRNG()
# Snapshot the true seeds before any output advances the generator.
xx, yy, zz = (list(seq) for seq in (prng.x, prng.y, prng.z))
states = [prng.out() for _ in range(12)]
print(states)
from itertools import tee

# Model parameters, all taken from the live generator (as the multipliers
# already were) so the symbolic model cannot drift from it.
m1, m2, m3, M = prng.m1, prng.m2, prng.m3, prng.M
a1 = prng.a1
a2 = prng.a2
a3 = prng.a3
# Unknown quotients: 9 per recurrence (terms 3..11 of x, y and z) plus one
# per output for the final reduction mod M.
numk = 9 * 3 + 12
varstr = "x0,x1,x2,y0,y1,y2,z0,z1,z2," + ",".join(f"k{i}" for i in range(numk))
P = PolynomialRing(ZZ, varstr)
x0, x1, x2, y0, y1, y2, z0, z1, z2, *ks = P.gens()
iks = iter(ks)
xs = [x0, x1, x2]
while len(xs) != 12:
    xs += [sum(x * y for x, y in zip(a1, xs[-3:])) - next(iks) * m1]
ys = [y0, y1, y2]
while len(ys) != 12:
    ys += [sum(x * y for x, y in zip(a2, ys[-3:])) - next(iks) * m2]
zs = [z0, z1, z2]
while len(zs) != 12:
    zs += [sum(x * y for x, y in zip(a3, zs[-3:])) - next(iks) * m3]
# Split the remaining 12 quotient variables into two iterators so the same
# k's are used in both the symbolic outputs and the history-based check.
iks, iks2 = tee(iks)
sts = [2 * m1 * x - m3 * y - m2 * z - next(iks) * M for x, y, z in zip(xs, ys, zs)]
ssts = [
    2 * m1 * x - m3 * y - m2 * z - next(iks2) * M
    for x, y, z in zip(xhist, yhist, zhist)
]
# With every k set to 0, each ssts entry reduced mod M should match the output.
for s, o in zip(ssts, states):
    print(s(xx + yy + zz + [0] * numk) % M, o)
# -----
# Ground-truth quotients, in the same order as the symbolic k variables:
# 9 for x (mod m1), 9 for y (mod m2), 9 for z (mod m3), then 12 for the
# output reduction mod M.


def _recurrence_ks(hist, coeffs, modulus):
    """Return k_i such that hist[i + 3] == sum(coeffs * hist[i : i + 3]) - k_i * modulus."""
    quotients = []
    for a, b, c, d in zip(hist, hist[1:], hist[2:], hist[3:]):
        # lhs = rhs - k * modulus
        lhs = d
        rhs = sum(x * y for x, y in zip(coeffs, [a, b, c]))
        assert (rhs - lhs) % modulus == 0
        quotients.append((rhs - lhs) // modulus)
    return quotients


testks = (
    _recurrence_ks(xhist, a1, m1)
    + _recurrence_ks(yhist, a2, m2)
    + _recurrence_ks(zhist, a3, m3)
)
for o, x, y, z in zip(states, xhist, yhist, zhist):
    # o = rhs - k * M
    lhs = o
    rhs = 2 * m1 * x - m3 * y - m2 * z
    assert (rhs - lhs) % M == 0
    testks.append((rhs - lhs) // M)
print("testks", testks)
assert len(testks) == numk
# Substituting the true seeds and quotients should reproduce each output.
for s, o in zip(sts, states):
    print(s(xx + yy + zz + testks), o)
# -----
# Alternative lattice construction, disabled; kept for reference:
# M, _ = Sequence(sts).coefficient_matrix()
# print(vector(_))
# M = M.dense_matrix().T
# A = matrix.identity(9)
# A = A.augment(A)
# A = A.augment(A)
# A = A.stack(matrix(39, 36))
# M = M.augment(A)

# Lattice basis. Named `lattice` rather than `M` so the modulus M is not
# silently overwritten.
lattice, monomials = Sequence(sts).coefficient_matrix()
print(vector(monomials))
lattice = lattice.dense_matrix().T
# Extra columns exposing the first 24 unknowns and the last 12 (the mod-M
# quotients); the 12 rows in between get no extra column. The assertion
# below checks this layout against the ground truth.
A = matrix.identity(24)
A = A.augment(matrix(24, 12))
B = matrix(12, 36)
C = matrix(12, 24)
C = C.augment(matrix.identity(12))
A = A.stack(B).stack(C)
lattice = lattice.augment(A)
target = xx + yy + zz
# ---
assert vector(target + testks) * lattice == vector(
    states + target + testks[: 24 - 9] + testks[-12:]
)
# ---
print(testks[-12:])
lb = states + [0] * 9 + [0] * (36 - 9 - 12) + [-3] * 12
ub = states + [2 ** 32] * 9 + [2 ** 20] * (36 - 9 - 12) + [0] * 12
load("solver.sage")
# solve() mutates the matrix it is given (lb/ub are passed as copies), so
# every later use of `lattice` sees the mutated matrix.
result, applied_weights, fin = solve(lattice, list(lb), list(ub))
assert lattice.solve_left(result)[:9] == vector(target), "failed"
# Although the recovered ks aren't correct, the recovered initial states are.
print(lattice.solve_left(result))
print(vector(target + testks))
# This is because the left kernel lies mostly in the k coordinates.
print(lattice.rank())
print(lattice.left_kernel().basis())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment