# LLL/CVP utilities
from lll_cvp import *
from functools import partial
# Example 1: recover a private exponent d shared by ten RSA keys by
# encoding e*d + k*n for every key as a bounded linear combination and
# calling solve_inequality, then decrypt one ciphertext with the recovered d.
# NOTE(review): this copy is incomplete — the contents and closing bracket
# of `data` are missing, as is the source link after "copied from". From the
# indexing below, each entry is presumably (n, e, ciphertext) — confirm.
def example1():
# copied from
## Example 4 : HITCON CTF 2019 Quals not so hard RSA
## d is 465 bits
data = [
## each data has n, e for fixed d
## ed = k(n-p-q+1) + 1 -> ed + kn == k(-p-q+1) + 1
## construct a bound on k(-p-q+1) + 1
## 2 sqrt(n) <= p + q <= 3 sqrt(n/2)
## (e * 2^464 - 1) / (n - 2sqrt(n) + 1) <= (e * d - 1) / (n - 2sqrt(n) + 1) <= k
## k <= (e * d - 1) / (n - 3sqrt(n/2) + 1) <= (e * 2^465 - 1) / (n - 3sqrt(n/2) + 1)
## combine these to get a decent bound for k(-p-q+1) + 1
## 11 variables, d, and k for each 10 equations
## 11 equations, bound on d and each bound on ed + kn
# build matrix
M = matrix(ZZ, 11, 11)
lb = [0] * 11
ub = [0] * 11
# encode d
# Column 10 of x*M is exactly d: M[10, 10] is the only entry set in that column.
M[10, 10] = 1
lb[10] = 2**464
ub[10] = 2**465
# encode ed + kn
for i in range(0, 10):
M[10, i] = data[i][1] # e * d
M[i, i] = data[i][0] # k * n
# NOTE(review): `x ** 0.5` is a float square root (~53 bits of precision),
# so the sums and k bounds below are approximate — confirm they still
# bracket the true values.
low_sum = int(2 * (data[i][0] ** 0.5))
high_sum = int(3 * ((data[i][0] // 2) ** 0.5))
low_k = (data[i][1] * (2**464) - 1) // (data[i][0] - low_sum + 1)
high_k = (data[i][1] * (2**465) - 1) // (data[i][0] - high_sum + 1)
lb[i] = high_k * (-high_sum + 1) + 1
ub[i] = low_k * (-low_sum + 1) + 1
# res is x*M (see solve_inequality), so res[10] is d itself.
res = solve_inequality(M, lb, ub)
recovered_d = res[10]
# Assuming these lines run after the loop, `i` is 9 here, so the last
# (n, ciphertext) pair is used.
n = data[i][0]
enc = data[i][2]
ptxt = pow(enc, recovered_d, n)
print((int)(ptxt).to_bytes(128, byteorder="big"))
# Example 2: recover a secret over GF(251) (the expected ACSC flag is
# recorded in the trailing comment). Each of the 105 samples contributes a
# row with x**0..x**15 in the 16 shared columns, plus x**16..x**31 in the 16
# columns of its group (7 groups of 15 samples, 128 columns in total). A
# particular solution v and the right kernel are then combined by
# solve_inequality so that the coordinates land in printable ranges.
# NOTE(review): this copy is incomplete. The source link after
# "modified from" is missing, `vec` is never filled, and the call to
# solve_inequality is not closed — see the inline notes.
def example2():
# modified from
p = 251
X = bytes.fromhex("02d4623be12c8f01cb2ebe5f837c1d")
Y = bytes.fromhex("bbdc06ceb34da7b16336b007dc5492")
X2 = bytes.fromhex("2fb9e753b237e68d35e266b0f01c9e")
Y2 = bytes.fromhex("20c0be9140f5a33d71b9e82f8f9409")
X3 = bytes.fromhex("f42e3ee10edeade0a3804a22e86a63")
Y3 = bytes.fromhex("c7224da73d9d96254f94136d9a65f1")
X4 = bytes.fromhex("37c9b07870283dd3f6198c46f027dd")
Y4 = bytes.fromhex("8101a88a365526e8faf417b79599a0")
X5 = bytes.fromhex("b0342cb7b3f5a022d927f9019a1bf3")
Y5 = bytes.fromhex("e2666d892955494775aa3c96c441f5")
X6 = bytes.fromhex("e56bf4f9e746252dbacb93a0a95087")
Y6 = bytes.fromhex("cbb43831857333b2c4663ba2c9189a")
X7 = bytes.fromhex("99ca36b1633cf3d903d8e6291f1bdc")
Y7 = bytes.fromhex("25180068651818171d10422dbdb395")
# NOTE(review): X7/Y7 are never used — the `i < 105` branch below reads
# X6/Y6 a second time, which looks like a copy-paste slip; confirm.
M = Matrix(GF(p), 105, 128)
vec = []
for i in range(105):
x, y = 0, 0
if i < 15:
x = int(X[i])
y = int(Y[i])
elif i < 30:
x = int(X2[i - 15])
y = int(Y2[i - 15])
elif i < 45:
x = int(X3[i - 30])
y = int(Y3[i - 30])
elif i < 60:
x = int(X4[i - 45])
y = int(Y4[i - 45])
elif i < 75:
x = int(X5[i - 60])
y = int(Y5[i - 60])
elif i < 90:
x = int(X6[i - 75])
y = int(Y6[i - 75])
elif i < 105:
x = int(X6[i - 90])
y = int(Y6[i - 90])
# Shared columns 0..15: x**0 .. x**15 (mod p).
for j in range(16):
M[i, j] = (x**j) % p
# Group-specific columns: x**16 .. x**31 placed at 16 * (group + 1) + j.
if i < 15:
for j in range(16):
M[i, j + 16] = (x ** (j + 16)) % p
elif i < 30:
for j in range(16):
M[i, j + 32] = (x ** (j + 16)) % p
elif i < 45:
for j in range(16):
M[i, j + 48] = (x ** (j + 16)) % p
elif i < 60:
for j in range(16):
M[i, j + 64] = (x ** (j + 16)) % p
elif i < 75:
for j in range(16):
M[i, j + 80] = (x ** (j + 16)) % p
elif i < 90:
for j in range(16):
M[i, j + 96] = (x ** (j + 16)) % p
elif i < 105:
for j in range(16):
M[i, j + 112] = (x ** (j + 16)) % p
# NOTE(review): nothing above appends to `vec` (and `y` is otherwise unused),
# so `vec` is empty here; the line collecting the y-values appears to be missing.
vec = vector(GF(p), vec)
bas = M.right_kernel().basis()
v = M.solve_right(vec)
# v + bas -> all in 97 ~ 122
# Rows 0..22: kernel basis vectors, each with an identity tag in columns
# 128..150 (the code assumes a 23-dimensional kernel, i.e. 128 - 105).
# Rows 23..150: p * I, so coordinates can be reduced mod p.
M = Matrix(ZZ, 151, 151)
lb = [0] * 151
ub = [0] * 151
for i in range(23):
for j in range(128):
M[i, j] = int(bas[i][j])
M[i, 128 + i] = 1
for i in range(128):
M[23 + i, i] = p
for i in range(128):
if i >= 16:
lb[i] = int(97 - int(v[i]))
ub[i] = int(122 - int(v[i]))
# NOTE(review): as written, the next two lines also run for i >= 16 and
# overwrite the 97..122 bounds; an `else:` appears to be missing — confirm.
lb[i] = int(32 - int(v[i]))
ub[i] = int(128 - int(v[i]))
for i in range(23):
lb[i + 128] = 0
ub[i + 128] = p
from functools import partial
# NOTE(review): the positional arguments (presumably M, lb, ub) and the
# closing parenthesis of this call are missing in this copy.
res = solve_inequality(
cvp=partial(kannan_cvp, reduction=lambda M: M.BKZ(block_size=20), weight=251),
flag = ""
# Add the particular solution back and reduce mod 251.
for i in range(16):
flag += chr((res[i] + int(v[i]) + 251 * 30) % 251)
print("ACSC{" + flag + "}")
# ACSC{wOAdvfst41xJzG6r}
# Example 3: recover a Java RNG seed from bounded outputs. The LCG state is
# tracked symbolically as a polynomial in the unknown seed `s` over
# Zmod(2**48), so every output is linear in `s`. Bounds on the outputs are
# then turned into a solve_inequality instance.
# NOTE(review): this copy is incomplete. Missing pieces:
#   - the source link after "modified from";
#   - the right-hand side of `z =`;
#   - the code that fills aa/bb/zz;
#   - the closing of the block_matrix(...) call;
#   - the opening of the call around the final xor.
def example3():
# modified from
from operator import xor
class JavaRNG:
# about the detail of java 17 rng
def __init__(self, seed):
self.seed = seed
def next(self):
# There is no explicit 48-bit mask; reduction comes from the Zmod(2**48)
# ring that the symbolic seed lives in.
self.seed = self.seed * 0x5DEECE66D + 0xB
return self.seed
Z = Zmod(2**48)
P = PolynomialRing(Z, "s")
s = P.gen()
aa = []
bb = []
zz = []
rng = JavaRNG(s)
for _ in range(16):
for _ in range(2047):
z =
# print(z)
# Unpacks the coefficients of z (presumably z == b + a*s) — confirm.
b, a = z.change_ring(ZZ)
# print(((ZZ(z(xs)) >> 24) / (1 << 24)).n())
M = 2**48
# First column of B is [1, 0, ..., 0], so column 0 of x*B is the seed;
# the M * I block allows reduction mod 2**48.
# NOTE(review): this call is not closed in this copy.
B = block_matrix(
[[matrix([1]), matrix(aa)], [matrix(len(aa), 1), matrix.identity(len(aa)) * M]]
# manually changing parameters...
vlb = [M - 3100000000000 for _ in range(len(bb))]
vlb[0] = M - 2900000000000
vub = [M - 2100000000000 for _ in range(len(bb))]
vub[-1] = M - 2400000000000
lb = [0] + [v - b for v, b in zip(vlb, bb)]
ub = [2**48] + [v - b for v, b in zip(vub, bb)]
res = solve_inequality(matrix(B), list(lb), list(ub))
s = ZZ(res[0])
# Re-evaluate every recorded output polynomial at the recovered seed.
for z in zz:
r = ZZ(z(s))
o = ((r >> 24) / (1 << 24)).n()
print(o, o > 7.331 * 0.1337)
# NOTE(review): the opening of the call wrapping this xor is missing, so the
# parentheses below do not balance.
xor(s, 0x5DEECE66D)
) # Java RNG will xor your seed with 0x5DEECE66D when setting seed
# known good seed: 272404351039795
# Example 4: random underdetermined system pub1 . secret == t1 and
# pub2 . secret == t2. Public coefficients are drawn with y=2**256 and the
# secret with y=2**64 (the bounds passed to random_vector).
# NOTE(review): the last line is a bare tuple; the call it was passed to is
# missing. Its shape (M, target, lb, ub) matches
# solve_underconstrained_equations, so presumably that call — confirm.
def example4():
n = 10
pub1 = random_vector(ZZ, n, x=1, y=2**256)
pub2 = random_vector(ZZ, n, x=1, y=2**256)
secret = random_vector(ZZ, n, x=1, y=2**64)
t1 = pub1 * secret
t2 = pub2 * secret
matrix([pub1, pub2]).T, vector([t1, t2]), [0] * n, [2**64] * n
# Example 5: ECDSA signatures on secp256k1 whose nonces are related by
# k_{i+1} = a*k_i + b (mod p), with (a, b) = (G.x, G.y). Steps:
#   - eliminate the private key d from consecutive signing equations
#     s*k - z - r*d (mod q) using resultants;
#   - recover the nonces with solve_multi_modulo_equations;
#   - derive d from the first signature.
# sha256(str(d))[:16] then serves as the AES key (presumably to decrypt `ct`).
# NOTE(review): this copy is incomplete — the contents of `msgs` are missing,
# and the `cipher =` line lost its constructor call (presumably
# AES.new(key, AES.MODE_CTR, nonce=nonce), followed by decrypting ct).
def example5():
from fastecdsa.curve import secp256k1
from hashlib import sha256
from Crypto.Cipher import AES
G = secp256k1.G
q = secp256k1.q
# fmt: off
p = 9927040122486684509203958106419420141058188722199373989012953585197167125223276141324574147521754273735827724127795605194092299982453828901469369136978219
sigs = [(98078224267884884220741740422077019843954009281647502734600509731511013529371, 54523865988310606978987830048871561792183822750263202533230451076893555969316), (104372973739209868434840748268723094332969140159620819033951611727659419363988, 39660851627725578124743718742328950528148285144862142963822549722002689280409), (103919709879086855178251181244489637133481828592253195107866903154222896468253, 35031204282583023574328215246485186335362731664384171126097342931654133207246), (63175283280752608661708773461972110889312169792285211062806717970617630555061, 34712080692439206749112321272818736084925608248138548106200594874651099131535)]
ct = b'\xe6\x9c\xcaZ\x01\x90-\xa0\xbc8\xeb\xe4\xc6\xc7b\x16\xb9t++@\xc0\x0ce\t\x9e\xb5\x07p\xe49*\xb8\xce\xfe@\xea%\xc9\xd6\xefF\xf8\x7fQ\x9bg\xbd\x7f\xcf{h\\^\x11\xf9\xf5\xe8\x7f}\x94\xd3+\x06\x19.`\x84\x8d)\x1e\xdey\xe4 [\x9e'
nonce = b'Z\x1c\xba\xbc\x95\\\xe1u'
# fmt: on
msgs = [
ss = []
for m, (r, s) in zip(msgs, sigs):
z = int.from_bytes(sha256(m).digest(), "big") % q
ss.append((z, r, s))
a, b = G.x, G.y
syms = "d," + ",".join([f"k{i}" for i in range(len(ss))])
R = ZZ[syms]
d, *ks = R.gens()
# collect equations
eq_p = []
eq_q = []
for (z, r, s), k in zip(ss, ks):
eq_q.append(s * k - z - r * d)
# Eliminate d between consecutive signing equations (all taken mod q).
eq_q = [f.resultant(g, d) for f, g in zip(eq_q, eq_q[1:])]
# Nonce recurrence k_{i+1} == a*k_i + b (taken mod p).
for k, kk in zip(ks, ks[1:]):
eq_p.append(a * k + b - kk)
# solve!
eqs = eq_p + eq_q
mods = [p] * len(eq_p) + [q] * len(eq_q)
lb = [0] * len(ks)
ub = [2**512] * len(ks)
ks = solve_multi_modulo_equations(eqs, mods, lb, ub)
z, r, s = ss[0]
k = ks[0]
# Solve s*k == z + r*d (mod q) for d.
d = (s * k - z) * pow(r, -1, q) % q
# get flag
key = sha256(str(d).encode()).digest()[:16]
cipher =, AES.MODE_CTR, nonce=nonce)
# Example 6: two random quadratic polynomials in x, y, z, modulo a random
# n below 2**2048, that share a root with coordinates below 2**128.
# solve_underconstrained_equations_general linearizes the monomials and
# searches for that small root.
# NOTE(review): this copy is incomplete — the `bounds` dict is not closed,
# and whatever followed `ideal(polys)` is missing. getrandbits/randrange are
# not imported in the visible code; presumably they arrive via a star
# import — confirm.
def example6():
n = ZZ(getrandbits(2048))
roots = [ZZ(getrandbits(128)) for _ in range(3)]
x, y, z = PolynomialRing(ZZ, ["x", "y", "z"]).gens()
f = randrange(1, n) * x * y + randrange(1, n) * y * z + randrange(1, n) * z * x
# Shift the constant term so that `roots` is a root of f (then reduce mod n).
f -= f(*roots)
f %= n
g = randrange(1, n) * x**2 + randrange(1, n) * y**2 + randrange(1, n) * z**2
g -= g(*roots)
g %= n
eqs = [f, g]
bounds = {
x: 2**128,
y: 2**128,
z: 2**128,
for monos, sol in solve_underconstrained_equations_general(n, eqs, bounds):
print(monos, sol)
# Normalize the sign so that the last coordinate is non-negative.
if sol[-1] < 0:
sol = -sol
# When that coordinate is 1 (presumably the constant monomial), each
# nonzero entry of sol - monos is a polynomial to add to the ideal.
if sol[-1] == 1:
polys = [f.change_ring(QQ) for f in sol - monos if f]
I = ideal(polys)
# Same attack as example5 (linearly related ECDSA nonces on secp256k1), but
# the CVP step uses kannan_cvp with the external `flatter` reduction.
# NOTE(review): this copy is incomplete — the contents of `msgs` are missing,
# the solve_multi_modulo_equations(...) call is not closed, and the
# `cipher =` line lost its constructor call (presumably
# AES.new(key, AES.MODE_CTR, nonce=nonce)).
def example_flatter():
from fastecdsa.curve import secp256k1
from hashlib import sha256
from Crypto.Cipher import AES
G = secp256k1.G
q = secp256k1.q
# fmt: off
p = 9927040122486684509203958106419420141058188722199373989012953585197167125223276141324574147521754273735827724127795605194092299982453828901469369136978219
sigs = [(98078224267884884220741740422077019843954009281647502734600509731511013529371, 54523865988310606978987830048871561792183822750263202533230451076893555969316), (104372973739209868434840748268723094332969140159620819033951611727659419363988, 39660851627725578124743718742328950528148285144862142963822549722002689280409), (103919709879086855178251181244489637133481828592253195107866903154222896468253, 35031204282583023574328215246485186335362731664384171126097342931654133207246), (63175283280752608661708773461972110889312169792285211062806717970617630555061, 34712080692439206749112321272818736084925608248138548106200594874651099131535)]
ct = b'\xe6\x9c\xcaZ\x01\x90-\xa0\xbc8\xeb\xe4\xc6\xc7b\x16\xb9t++@\xc0\x0ce\t\x9e\xb5\x07p\xe49*\xb8\xce\xfe@\xea%\xc9\xd6\xefF\xf8\x7fQ\x9bg\xbd\x7f\xcf{h\\^\x11\xf9\xf5\xe8\x7f}\x94\xd3+\x06\x19.`\x84\x8d)\x1e\xdey\xe4 [\x9e'
nonce = b'Z\x1c\xba\xbc\x95\\\xe1u'
# fmt: on
msgs = [
ss = []
for m, (r, s) in zip(msgs, sigs):
z = int.from_bytes(sha256(m).digest(), "big") % q
ss.append((z, r, s))
a, b = G.x, G.y
syms = "d," + ",".join([f"k{i}" for i in range(len(ss))])
R = ZZ[syms]
d, *ks = R.gens()
# collect equations
eq_p = []
eq_q = []
for (z, r, s), k in zip(ss, ks):
eq_q.append(s * k - z - r * d)
# Eliminate d between consecutive signing equations (all taken mod q).
eq_q = [f.resultant(g, d) for f, g in zip(eq_q, eq_q[1:])]
# Nonce recurrence k_{i+1} == a*k_i + b (taken mod p).
for k, kk in zip(ks, ks[1:]):
eq_p.append(a * k + b - kk)
# solve!
eqs = eq_p + eq_q
mods = [p] * len(eq_p) + [q] * len(eq_q)
lb = [0] * len(ks)
ub = [2**512] * len(ks)
# Only the lattice reduction differs from example5: `flatter` replaces LLL.
ks = solve_multi_modulo_equations(
eqs, mods, lb, ub, cvp=partial(kannan_cvp, reduction=flatter)
z, r, s = ss[0]
k = ks[0]
# Solve s*k == z + r*d (mod q) for d.
d = (s * k - z) * pow(r, -1, q) % q
# get flag
key = sha256(str(d).encode()).digest()[:16]
cipher =, AES.MODE_CTR, nonce=nonce)
# NOTE(review): the body of this guard is missing from this copy;
# presumably it ran one of the example functions above — confirm.
if __name__ == "__main__":
from sage.all import (
from subprocess import check_output
from re import findall
def build_lattice(mat, lb, ub):
    """Prepare a scaled CVP instance for ``lb <= x*mat <= ub``.

    Parameters:
        mat: integer matrix; each column is one bounded quantity, so
            ``mat.ncols()`` must equal ``len(lb) == len(ub)``.
        lb, ub: per-column lower and upper bounds.

    Returns:
        ``(L, target, Q)`` where ``L`` is ``mat`` over ZZ, ``target`` is the
        (floored) midpoint of every interval, and ``Q`` is a diagonal
        scaling matrix that stretches tightly bounded columns so that all
        intervals have comparable width.

    Raises:
        ValueError: if the bound counts do not match ``mat.ncols()`` or a
            lower bound exceeds its upper bound.
    """
    n = mat.ncols()  # number of bounded quantities (columns)
    if n != len(ub) or n != len(lb):
        raise ValueError("Number of equations must match number of bounds")
    if any(l > u for l, u in zip(lb, ub)):
        raise ValueError("All lower bounds must be less than upper bounds")
    L = matrix(ZZ, mat)
    target = vector([(l + u) // 2 for u, l in zip(ub, lb)])
    bounds = [u - l for u, l in zip(ub, lb)]
    # Scale every column up to the widest interval. If every bound is an
    # exact equality (all widths zero), fall back to det(L); that requires
    # a square L.
    K = max(bounds) or L.det()
    # Zero-width (equality) columns get an extra factor n, so the reduction
    # strongly favours satisfying them exactly.
    Q = matrix.diagonal([K // x if x != 0 else K * n for x in bounds])
    return L, target, Q
def flatter(M):
    """Reduce the rows of ``M`` with the external ``flatter`` binary.

    ``flatter`` has to be compiled and reachable through ``$PATH``. The
    matrix is sent on stdin in flatter's ``[[a b][c d]]`` text format. Every
    (possibly negative) integer in the output is read back into a matrix
    with the same shape as ``M``.
    """
    rows = (" ".join(str(entry) for entry in row) for row in M)
    payload = "[[" + "]\n[".join(rows) + "]]"
    output = check_output(["flatter"], input=payload.encode())
    entries = [int(token) for token in findall(b"-?\\d+", output)]
    return matrix(M.nrows(), M.ncols(), entries)
def babai_cvp(mat, target, reduction=lambda M: M.LLL()):
    """Approximate the closest lattice vector to ``target`` (Babai nearest plane).

    The basis ``mat`` is reduced with ``reduction``. Walking the
    Gram-Schmidt vectors from last to first, the remaining difference to
    ``target`` is then reduced by the basis vector times the rounded
    projection coefficient.

    Returns:
        A lattice vector close to ``target``.
    """
    basis = reduction(matrix(ZZ, mat))
    gso = basis.gram_schmidt()[0]
    residual = target
    for idx in range(gso.nrows() - 1, -1, -1):
        g = gso[idx]
        # NOTE(review): divides by |g|^2, so a zero Gram-Schmidt vector
        # would raise ZeroDivisionError — confirm inputs have independent rows.
        coeff = ((residual * g) / (g * g)).round()
        residual -= basis[idx] * coeff
    return target - residual
def kannan_cvp(mat, target, reduction=lambda M: M.LLL(), weight=None):
    """Approximate the lattice vector of ``mat`` closest to ``target``.

    Uses Kannan's embedding: the basis ``mat`` is extended with the row
    ``(-target, weight)`` and reduced. A reduced row whose last coordinate
    equals ``weight`` (after normalizing its sign) has the form
    ``(v - target, weight)`` for a lattice vector ``v``.

    Parameters:
        mat: lattice basis (rows are basis vectors).
        target: target vector.
        reduction: callable returning a reduced basis of its argument.
        weight: embedding weight. Defaults to ``max(target)``; when that is
            not positive, the largest absolute entry (or 1) is used instead.

    Returns:
        The lattice vector ``v``, or ``None`` if no reduced row carries the
        embedding weight.
    """
    if weight is None:
        weight = max(target)
        if weight <= 0:
            # Rows are sign-normalized to a non-negative last entry below, so
            # a negative weight could never match and zero is degenerate.
            weight = max(abs(t) for t in target) or 1
    L = block_matrix([[mat, 0], [-matrix(target), weight]])
    for row in reduction(L):
        if row[-1] < 0:
            row = -row
        if row[-1] == weight:
            return row[:-1] + target
    return None
def kannan_cvp_ex(mat, target, reduction=lambda M: M.LLL(), weight=None):
    """Kannan's embedding, returning every candidate plus the reduced basis.

    Same embedding as ``kannan_cvp``, but it does not stop at the first
    match. It collects:
      * every row carrying the embedding weight, as a candidate lattice
        vector near ``target``;
      * every row with last coordinate 0, as a vector of the reduced lattice
        basis (useful for enumerating further CVP solutions).

    Parameters:
        mat: lattice basis (rows are basis vectors).
        target: target vector.
        reduction: callable returning a reduced basis of its argument.
        weight: embedding weight. Defaults to ``max(target)``; when that is
            not positive, the largest absolute entry (or 1) is used instead.

    Returns:
        ``(cvps, basis)``: two integer matrices with one row per candidate
        and per basis vector respectively. Both always have
        ``mat.ncols()`` columns, even when empty.
    """
    if weight is None:
        weight = max(target)
        if weight <= 0:
            # Rows are sign-normalized to a non-negative last entry below, so
            # a negative weight could never match and zero is degenerate.
            weight = max(abs(t) for t in target) or 1
    L = block_matrix([[mat, 0], [-matrix(target), weight]])
    dim = L.ncols() - 1  # width of a candidate / basis vector
    cvps = []
    basis = []
    for row in reduction(L):
        if row[-1] < 0:
            row = -row
        if row[-1] == weight:
            cvps.append(row[:-1] + target)
        elif row[-1] == 0:
            # A pure lattice vector: part of the reduced basis.
            basis.append(row[:-1])
    # Explicit shapes keep empty results multipliable by a dim x dim matrix.
    return (
        matrix(ZZ, len(cvps), dim, cvps),
        matrix(ZZ, len(basis), dim, basis),
    )
def solve_inequality(M, lb, ub, cvp=kannan_cvp):
    """Look for ``x`` with ``lb <= x*M <= ub`` and return ``x*M`` (not ``x``).

    The bounds are turned into a scaled CVP instance by ``build_lattice``
    and solved with ``cvp``. The result is not checked against the bounds.
    """
    lattice, center, scale = build_lattice(M, lb, ub)
    closest = cvp(lattice * scale, scale * center)
    # Undo the column scaling to express the vector in the original units.
    return scale.solve_left(closest)
def solve_inequality_ex(M, lb, ub, cvp_ex=kannan_cvp_ex):
    """Like ``solve_inequality``, but keep every candidate and the basis.

    Returns:
        ``(cvps, basis)`` as integer matrices, both already unscaled. Each
        row of ``cvps`` is a candidate ``x*M`` (not ``x``); ``basis`` is the
        reduced lattice basis reported by ``cvp_ex``. Results are not
        checked against the bounds.
    """
    lattice, center, scale = build_lattice(M, lb, ub)
    candidates, reduced = cvp_ex(lattice * scale, scale * center)
    unscale = matrix.diagonal([1 / w for w in scale.diagonal()])
    return (
        (candidates * unscale).change_ring(ZZ),
        (reduced * unscale).change_ring(ZZ),
    )
def solve_underconstrained_equations(M, target, lb, ub, cvp=kannan_cvp):
    """Find a bounded ``x`` with ``x*M == target`` for an underdetermined system.

    ``M`` has one row per variable and one column per equation. The system
    is embedded as ``[[M, I], [target, 0]]`` and the equation columns are
    forced to zero via solve_inequality. The result is not checked.

    Raises:
        ValueError: if ``target`` does not match the number of equations,
            or the system is not underdetermined.
    """
    n = M.ncols()  # equations
    m = M.nrows()  # variables
    if n != len(target):
        raise ValueError("number of equations and target mismatch")
    if not n < m:
        raise ValueError("use gauss elimination instead")
    embedded = block_matrix([[matrix(ZZ, M), 1], [matrix(target), 0]])
    zeros = [0] * n
    sol = solve_inequality(embedded, zeros + lb, zeros + ub, cvp=cvp)
    # The trailing m coordinates carry the variables themselves.
    return sol[-m:]
def solve_multi_modulo_equations(
    eqs, mods, lb, ub, reduction=lambda M: M.LLL(), cvp=kannan_cvp
):
    """Find bounded solutions of linear equations, each with its own modulus.

    Parameters:
        eqs: linear polynomials over ZZ; ``eqs[i] == 0 (mod mods[i])``. The
            system must contain a constant term.
        mods: the modulus of each equation.
        lb, ub: lists of lower/upper bounds, one per variable, in the column
            order of Sage's coefficient matrix.
        reduction: unused, kept for backward compatibility. Pass a customised
            ``cvp`` (e.g. ``partial(kannan_cvp, reduction=...)``) instead.
        cvp: CVP solver forwarded to ``solve_inequality``.

    Returns:
        The last ``len(lb)`` coordinates of the solve_inequality result,
        i.e. the variable values.

    Raises:
        ValueError: on mismatched equation/modulus or bound counts.
    """
    if len(eqs) != len(mods):
        raise ValueError("number of equations and modulus mismatch")
    if len(lb) != len(ub):
        raise ValueError("number of lower bounds and upper bounds mismatch")
    M, v = Sequence(eqs).coefficient_matrix()
    assert v.list()[-1] == 1, "only support equations with constant term"
    # Split into coefficients A and right-hand side b: A*vars == b (mod m_i).
    A, b = M[:, :-1], -M[:, -1]
    M = A.dense_matrix().T
    nr, nc = M.dimensions()  # nr = number of variables, nc = number of equations
    # One row per variable (its coefficients), then one row per modulus so
    # that any multiple of m_i can be added to equation i.
    L = M.stack(diagonal_matrix(mods))
    # Identity columns expose the variable values themselves; the modulus
    # rows contribute zeros there (an nc x nr zero block).
    L = L.augment(matrix.identity(nr).stack(matrix(ZZ, nc, nr)))
    lbx = b.list() + lb
    ubx = b.list() + ub
    return solve_inequality(L, lbx, ubx, cvp=cvp)[-len(lb) :]
def polynomials_to_matrix(polys):
    """Return ``(coefficient_matrix, monomial_vector)`` for ``polys``.

    Sage 10.3 added ``coefficients_monomials`` and deprecated
    ``coefficient_matrix``; the newer method is used when it exists.
    Otherwise the older one is used and its monomials are converted with
    ``vector``. Both paths request a dense matrix.
    """
    seq = Sequence(polys)
    if not hasattr(seq, "coefficients_monomials"):
        mat, mono_col = seq.coefficient_matrix(sparse=False)
        return mat, vector(mono_col)
    return seq.coefficients_monomials(sparse=False)
def solve_underconstrained_equations_general(
    n, eqs, bounds, reduction=lambda M: M.LLL()
):
    """Search for small solutions of an underconstrained polynomial system.

    Every monomial of ``eqs`` is treated as an independent unknown
    (linearization). A lattice of integer combinations of the monomials is
    built: the equation columns must vanish (mod ``n``), and the monomial
    columns are scaled by their bounds before reduction.

    Parameters:
        n: modulus, or ``None`` when the equations hold over ZZ.
        eqs: polynomials that vanish at the wanted solution (mod ``n``).
        bounds: dict mapping each variable ``x`` to a positive integer ``W``
            with ``|x| < W``. Monomial bounds are obtained by substituting
            these values.
        reduction: lattice reduction callable.

    Yields:
        ``(monos, sol)`` for every reduced row whose equation part is zero.
        ``monos`` is the monomial vector; ``sol`` is the row restricted to
        the monomial columns.
    """
    M, monos = polynomials_to_matrix(eqs)
    num_eqs = len(eqs)
    if n is None:
        L = block_matrix(ZZ, [[M.T, 1]])
    else:
        # n * I lets every equation be reduced modulo n.
        L = block_matrix(ZZ, [[n, 0], [M.T, 1]])
    # Equation columns should become 0 (width 1); each monomial column is
    # bounded by the monomial evaluated at the variable bounds.
    widths = [1] * num_eqs + [ZZ(m.subs(bounds)) for m in monos.list()]
    K = max(widths)
    Q = diagonal_matrix([K // w for w in widths])
    L *= Q
    L = reduction(L)
    L /= Q
    L = L.change_ring(ZZ)
    for row in L:
        if row[:num_eqs] == 0:
            sol = row[num_eqs:]
            yield vector(monos), sol
__all__ = [
