Skip to content

Instantly share code, notes, and snippets.

@TheBlupper
Last active May 8, 2024 23:10
Show Gist options
  • Save TheBlupper/58f6bc59ccf96f40d422d6ea445c0b17 to your computer and use it in GitHub Desktop.
Save TheBlupper/58f6bc59ccf96f40d422d6ea445c0b17 to your computer and use it in GitHub Desktop.
This is outdated, full repo here: https://github.com/TheBlupper/linineq
import threading
from sage.all import *
from ortools.sat.python import cp_model as ort
from queue import Queue
def babai_coords(M, tgt):
    '''
    Babai nearest-plane rounding of tgt onto the lattice spanned by the
    rows of M.

    Returns a pair (w, coords): w is an (approximate) closest lattice
    vector to tgt, and coords are its integer coordinates with respect to
    the rows of M, so that w == coords*M.

    NOTE(review): divides by |G[i]|^2 of the Gram-Schmidt rows, so this
    assumes the rows of M are linearly independent.
    '''
    gso = M.gram_schmidt()[0]
    residual = tgt
    coeffs = []
    # walk the Gram-Schmidt basis from the last vector to the first,
    # rounding the projection of the residual onto each one
    for idx in range(gso.nrows() - 1, -1, -1):
        row = gso[idx]
        k = ((residual * row) / (row * row)).round()
        coeffs.insert(0, k)
        residual = residual - k * M[idx]
    return tgt - residual, vector(coeffs)
def gen_solutions(model, variables):
    '''
    Return a generator that yields all solutions to a model, each as a
    tuple of the values of ``variables``.

    The search runs on a background thread (solver_thread) with
    enumerate_all_solutions enabled. Each loop iteration sets search_event
    to request more work and then takes the next solution from a queue.
    This is slower at finding a single solution than find_solution because
    ortools can't parallelize the search.

    When the generator finishes or is closed early (break, del, garbage
    collection), it sets stop_event, wakes the worker and joins the thread.
    '''
    queue = Queue()
    search_event = threading.Event()
    stop_event = threading.Event()
    # daemon=True: if this generator is left suspended until interpreter
    # shutdown, the worker sits blocked inside the solver callback and a
    # non-daemon thread would keep the process from ever exiting.
    t = threading.Thread(
        target=solver_thread,
        args=(model, variables, queue, search_event, stop_event),
        daemon=True,
    )
    t.start()
    try:
        while True:
            # let the worker continue, then wait for its next result
            search_event.set()
            sol = queue.get()
            if sol is None:  # end-of-stream sentinel from solver_thread
                break
            yield sol
    finally:
        # set stop_event before waking the worker so the stop request is
        # already visible when it resumes
        stop_event.set()
        search_event.set()
        t.join()
def solver_thread(model, variables, queue, search_event, stop_event):
    '''
    Background worker for gen_solutions: enumerate all solutions of
    ``model`` and push them onto ``queue`` through a SolutionCollector.

    Waits for the first ``search_event`` before starting the search.
    Always puts a final ``None`` on ``queue`` to mark the end of the
    stream, even if building the solver or solving raises, so the
    consumer never blocks forever on ``queue.get()``. The exception
    itself still propagates and is reported by the threading machinery.
    '''
    try:
        slvr = ort.CpSolver()
        slvr.parameters.enumerate_all_solutions = True
        solution_collector = SolutionCollector(variables, queue, search_event, stop_event)
        # don't start searching until the consumer asks for a solution
        search_event.wait()
        slvr.Solve(model, solution_collector)
    finally:
        queue.put(None)
class SolutionCollector(ort.CpSolverSolutionCallback):
    '''
    CP-SAT solution callback that hands solutions to a consumer one at a
    time.

    Each solution (a tuple of the values of ``vars``) is put on ``queue``.
    The callback then blocks until ``search_event`` is set, so the solver
    pauses until the consumer asks for more. Once ``stop_event`` is set,
    the search is aborted with StopSearch().
    '''

    def __init__(self, vars, queue, search_event, stop_event):
        super().__init__()
        self.vars = vars  # ortools variables whose values are reported
        self.queue = queue
        self.search_event = search_event
        self.stop_event = stop_event

    def on_solution_callback(self):
        self.queue.put(tuple(self.Value(v) for v in self.vars))
        if self.stop_event.is_set():
            self.StopSearch()
            return
        self.search_event.wait()
        self.search_event.clear()
        # The consumer may have gone away while we were parked (it sets
        # stop_event and then search_event to wake us). Stop now instead of
        # searching for another solution nobody will read, which could mean
        # exhausting the rest of the search space.
        if self.stop_event.is_set():
            self.StopSearch()
def find_solution(model, variables):
    '''
    Solve ``model`` once and return the values of ``variables`` as a tuple.

    Raises ValueError if the solver status is neither OPTIMAL nor FEASIBLE.
    '''
    solver = ort.CpSolver()
    status = solver.Solve(model)
    if status != ort.OPTIMAL and status != ort.FEASIBLE:
        raise ValueError('No solution found')
    return tuple(map(solver.Value, variables))
# https://library.wolfram.com/infocenter/Books/8502/AdvancedAlgebra.pdf page 80
def _build_system(M, Mineq, b, bineq, lp_bound=99999):
    '''
    Build an ortools model for the integer system
        M*x = b  where  Mineq*x >= bineq.

    Returns a tuple (model, X, f) where model is an ortools CpModel, X is
    the list of ortools variables we want the solution for, and f is a
    function that transforms an assignment of X back to a solution vector
    x of the original system.

    b and bineq may be lists or Sage vectors. Every variable in X is
    bounded to [-lp_bound, lp_bound], so solutions that need larger
    coordinates in the reduced basis cannot be found.

    Raises ValueError if the dimensions are inconsistent or if M*x = b
    has no integer solution at all.
    '''
    if Mineq.ncols() != M.ncols():
        raise ValueError('Mineq and M must have the same number of columns')
    if Mineq.nrows() != len(bineq):
        raise ValueError('Mineq must have exactly one row per entry of bineq')
    if M.nrows() != len(b):
        raise ValueError('M must have exactly one row per entry of b')

    # Find one integer solution s of M*x = b, ignoring the inequalities.
    # Smith form: D = U*M*V, so M*x = b  <=>  D*(V^-1*x) = U*b.
    D, U, V = M.smith_form()
    rhs = U*vector(ZZ, b)
    try:
        s = (V*D.solve_right(rhs)).change_ring(ZZ)
    except (TypeError, ValueError):
        # TypeError: the solution is not integral;
        # ValueError: the linear system is inconsistent
        raise ValueError('no solution (even without bounds)') from None

    # every integer solution is s + ker.T*y for an integer vector y
    ker = M.right_kernel().basis_matrix().change_ring(ZZ)
    # inequalities in kernel coordinates: Mineq*x = Mineq*s + Mker*y
    Mker = Mineq*ker.T
    # LLL-reduce the columns of Mker (using BKZ instead might help in some cases)
    Mred = Mker.T.LLL().T
    # LLL only changes the basis, so Mred = Mker*R for an integer matrix R;
    # recover R with the left pseudo-inverse (assumes Mker has full column
    # rank), so reduced coordinates z map back to y = R*z
    R = ((Mker.T*Mker)**-1 * (Mker.T*Mred)).change_ring(ZZ)

    # move s to the right-hand side: Mred*z >= bineq - Mineq*s
    # (vector() also accepts a plain list here)
    bineq = vector(bineq) - Mineq*s
    # write z = X + v, where Mred*v is the (approximate) closest lattice
    # vector to the right-hand side, so the remaining X stay small
    bineq_cv, v = babai_coords(Mred.T, bineq)
    bineq_red = bineq - bineq_cv

    model = ort.CpModel()
    X = [model.NewIntVar(-lp_bound, lp_bound, f'x{i}') for i in range(Mred.ncols())]
    # Mred*X >= bineq_red
    Y = [sum([c*x for c, x in zip(row, X)]) for row in Mred]
    for i, yi in enumerate(Y):
        model.Add(yi >= bineq_red[i])

    # x = s + ker.T*R*(X + v), precomputed as T*X + c
    T = ker.T*R
    c = T*v + s

    def f(sol):
        return T*vector(ZZ, sol) + c

    return model, X, f
def solve_ineq_gen(M, Mineq, b, bineq, **kwargs):
    '''
    Generate integer solutions x of
        M*x = b  where  Mineq*x >= bineq.

    "All" solutions means all solutions reachable within the lp_bound box
    (a keyword argument forwarded to _build_system).
    '''
    model, variables, to_original = _build_system(M, Mineq, b, bineq, **kwargs)
    for raw in gen_solutions(model, variables):
        yield to_original(raw)
def solve_ineq(M, Mineq, b, bineq, **kwargs):
    '''
    Return one integer solution x of
        M*x = b  where  Mineq*x >= bineq.

    Keyword arguments (e.g. lp_bound) are forwarded to _build_system.
    Raises ValueError when no solution is found.
    '''
    system = _build_system(M, Mineq, b, bineq, **kwargs)
    model, variables, to_original = system
    assignment = find_solution(model, variables)
    return to_original(assignment)
def _build_mod_system(M, b, lb, ub, N, **kwargs):
    '''
    Build an ortools model for
        M*x = b (mod N)  where  lb <= x <= ub.

    The congruence is lifted to the exact integer system
        [M | N*I] * (x, k) = b
    with one extra free integer k per equation. The box becomes the
    inequalities x >= lb and -x >= -ub.

    Returns a tuple (model, X, f) as _build_system does, except that f
    returns only the first M.ncols() coordinates (the x part, dropping k).
    Extra keyword arguments (e.g. lp_bound) are forwarded to _build_system.

    Raises ValueError if lb or ub does not have one entry per column of M.
    '''
    neqs = M.nrows()
    nvars = M.ncols()
    # materialize so any iterable still works and lengths can be checked
    lb = list(lb)
    ub = list(ub)
    if len(lb) != nvars or len(ub) != nvars:
        raise ValueError(f'lb and ub must each have {nvars} entries (one per column of M)')
    # Lift M to ZZ so a matrix given over Zmod(N) is handled. Over Zmod(N)
    # the N*I block would be 0 mod N, and the Smith form needs ZZ anyway.
    M = M.change_ring(ZZ).augment(identity_matrix(neqs)*N)
    I = identity_matrix(nvars)
    # rows: x_i >= lb_i, then -x_i >= -ub_i; the k columns are unconstrained
    Mineq = I.stack(-I).augment(zero_matrix(nvars*2, neqs))
    bineq = vector([*lb] + [-x for x in ub])
    model, X, f = _build_system(M, Mineq, b, bineq, **kwargs)
    return model, X, lambda sol: f(sol)[:nvars]
def solve_ineq_mod_gen(M, b, lb, ub, N, **kwargs):
    '''
    Generate integer solutions x of
        M*x = b (mod N)  where  lb <= x <= ub.

    "All" solutions means all solutions reachable within the lp_bound box
    (a keyword argument forwarded down to _build_system).
    '''
    model, variables, to_original = _build_mod_system(M, b, lb, ub, N, **kwargs)
    for raw in gen_solutions(model, variables):
        yield to_original(raw)
def solve_ineq_mod(M, b, lb, ub, N, **kwargs):
    '''
    Return one integer solution x of
        M*x = b (mod N)  where  lb <= x <= ub.

    Keyword arguments (e.g. lp_bound) are forwarded down to _build_system.
    Raises ValueError when no solution is found.
    '''
    system = _build_mod_system(M, b, lb, ub, N, **kwargs)
    model, variables, to_original = system
    assignment = find_solution(model, variables)
    return to_original(assignment)
# example usage, from https://connor-mccartney.github.io/cryptography/other/Trying-to-crack-COD-FNV-hashes
if __name__ == '__main__':
    FNV_INIT = 0xCBF29CE484222325
    p = 0x100000001B3

    def fnv64(s):
        '''
        64-bit FNV-1a hash (xor, then multiply, mod 2^64) of the bytes s,
        computed after lowercasing and replacing backslashes with slashes.
        '''
        hsh = FNV_INIT
        for c in s.lower().replace(b"\\", b"/"):
            hsh = ((hsh^c)*p)%2**64
        return hsh

    def rev(sol):
        '''
        Turn a solution of the linearized system back into bytes.

        Each entry s of sol stands for (hsh ^ c) - hsh at one step of the
        raw (unnormalized) hash chain. Replay the chain and recover
        c = (hsh + s) ^ hsh. Returns None if any recovered value lies
        outside 32..127.
        '''
        prev = FNV_INIT
        ret = []
        for s in sol:
            new = prev + int(s)
            ch = new ^ prev
            if not 32 <= ch < 128:
                return None
            ret.append(ch)
            # reduce like fnv64 does; only the low 64 bits matter
            prev = (new*p) % 2**64
        return bytes(ret)

    def solve(target, n):
        '''
        Print and return n-byte inputs found whose fnv64 hash equals target.
        '''
        hsh = FNV_INIT
        rets = []
        # Writing each step as hsh ^ c = hsh + s_i makes the final hash linear:
        #   hash = hsh*p^n + sum_{i=0}^{n-1} s_i*p^(n-i)  (mod 2^64)
        M = matrix([[p**(n - i) for i in range(n)]])
        for sol in solve_ineq_mod_gen(
            M, [target - hsh*p**n], [-128]*n, [128]*n, 2**64
        ):
            ret = rev(sol)
            if ret is None:
                continue
            # rev replays the raw hash chain, but fnv64 lowercases and
            # rewrites backslashes first. Drop candidates (e.g. ones with
            # uppercase letters) that don't actually hash to target.
            if fnv64(ret) != target:
                continue
            print(ret)
            rets.append(ret)
        return rets

    solve(fnv64(b'abcdefghi'), 9)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment