Last active
May 8, 2024 23:10
-
-
Save TheBlupper/58f6bc59ccf96f40d422d6ea445c0b17 to your computer and use it in GitHub Desktop.
This gist is outdated; the full repository is here: https://github.com/TheBlupper/linineq
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
from sage.all import * | |
from ortools.sat.python import cp_model as ort | |
from queue import Queue | |
def babai_coords(M, tgt):
    '''
    Babai nearest-plane rounding of tgt against the rows of M.

    Returns a pair (closest, coords) where closest is an approximate
    closest lattice vector to tgt and coords are the rounded coefficients
    such that closest == coords * M (row combination of M).
    '''
    gso = M.gram_schmidt()[0]
    residual = tgt
    coeffs = []
    # walk the basis from the last row back to the first
    for i in range(gso.nrows() - 1, -1, -1):
        g = gso[i]
        k = ((residual * g) / (g * g)).round()
        # prepend so coeffs ends up in row order
        coeffs.insert(0, k)
        residual -= k * M[i]
    return tgt - residual, vector(coeffs)
def gen_solutions(model, variables):
    '''
    Return a generator that yields every solution of `model`, each as a
    tuple of the values of `variables`. This is slower than find_solution
    at finding a single solution because ortools can't parallelize the
    search when enumerating all solutions.

    The search runs in a background thread (solver_thread) that pushes
    each solution onto a queue, which this generator pops. `search_event`
    throttles the worker so it only runs a little ahead of the consumer.
    `stop_event` asks it to abort once the generator is closed. A `None`
    on the queue marks the end of the search.
    '''
    sol_queue = Queue()
    search_event = threading.Event()
    stop_event = threading.Event()
    # solver_thread creates and configures its own CpSolver, so none is
    # built here.
    t = threading.Thread(
        target=solver_thread,
        args=(model, variables, sol_queue, search_event, stop_event),
        # If the generator is abandoned without being closed, the worker
        # would block on search_event forever; a daemon thread keeps that
        # from preventing interpreter exit.
        daemon=True,
    )
    t.start()
    try:
        while True:
            # let the worker continue past the solution it is holding
            search_event.set()
            sol = sol_queue.get()
            if sol is None:
                break
            yield sol
    finally:
        # ask the worker to stop at its next solution, unblock it if it is
        # waiting on search_event, then wait for it to finish
        stop_event.set()
        search_event.set()
        t.join()
def solver_thread(model, variables, queue, search_event, stop_event):
    '''
    Worker body used by gen_solutions.

    Waits for search_event, then runs an all-solutions CP-SAT search. Each
    solution is published to `queue` through a SolutionCollector. A final
    `None` is always put on the queue when the search ends, whether it
    ends normally or by an exception.
    '''
    try:
        slvr = ort.CpSolver()
        slvr.parameters.enumerate_all_solutions = True
        solution_collector = SolutionCollector(variables, queue, search_event, stop_event)
        search_event.wait()
        slvr.Solve(model, solution_collector)
    finally:
        # Without this, an exception from the solver would leave the
        # consumer blocked forever on queue.get(). The exception itself
        # still propagates and is reported by the thread's excepthook.
        queue.put(None)
class SolutionCollector(ort.CpSolverSolutionCallback):
    '''
    CP-SAT callback that forwards every solution it sees to a queue.

    After publishing a solution it either stops the search (when
    stop_event is set) or blocks until search_event is set, then clears
    search_event so that the following solution blocks again.
    '''

    def __init__(self, vars, queue, search_event, stop_event):
        super().__init__()
        # variables whose values form each published solution tuple
        self.vars = vars
        self.queue = queue
        self.search_event = search_event
        self.stop_event = stop_event

    def on_solution_callback(self):
        # publish the current assignment first, before deciding anything
        self.queue.put(tuple(map(self.Value, self.vars)))
        if not self.stop_event.is_set():
            # throttle: wait for the consumer, then re-arm the gate
            self.search_event.wait()
            self.search_event.clear()
        else:
            self.StopSearch()
def find_solution(model, variables):
    '''
    Solve `model` once and return the values of `variables` as a tuple.

    Raises ValueError if the solver status is neither OPTIMAL nor FEASIBLE.
    '''
    solver = ort.CpSolver()
    status = solver.Solve(model)
    if status in (ort.OPTIMAL, ort.FEASIBLE):
        return tuple(map(solver.Value, variables))
    raise ValueError('No solution found')
# https://library.wolfram.com/infocenter/Books/8502/AdvancedAlgebra.pdf page 80
def _build_system(M, Mineq, b, bineq, lp_bound=99999):
    '''
    Returns a tuple (model, X, f) where model is an ortools model,
    X is a list of variables we want the solution for, and f is a
    function that will transform the solution back to the original space

    The encoded problem is: integer x with M*x == b and Mineq*x >= bineq.
    Every variable in X has domain [-lp_bound, lp_bound], so solutions
    that would need a larger X entry are not found.

    b and bineq may be sage vectors or plain sequences.

    Raises ValueError if the input dimensions disagree, or if M*x == b
    has no integer solution even without the inequalities.
    '''
    # explicit checks rather than asserts, so they survive `python -O`
    if Mineq.ncols() != M.ncols():
        raise ValueError('Mineq must have as many columns as M')
    if Mineq.nrows() != len(bineq):
        raise ValueError('bineq must have one entry per row of Mineq')
    if M.nrows() != len(b):
        raise ValueError('b must have one entry per row of M')

    # find unbounded solution: smith_form gives D == U*M*V, so
    # M*x == b  <=>  D*(V**-1 * x) == U*b
    D, U, V = M.smith_form()
    s = V*D.solve_right(U*vector(ZZ, b))
    try:
        s = s.change_ring(ZZ)
    except TypeError:
        # the particular solution is not integral
        raise ValueError('no solution (even without bounds)') from None

    # solutions are s plus integer combinations of the rows of ker
    # (NOTE(review): relies on M being over ZZ so the kernel basis is a
    # Z-basis of the integer kernel -- confirm for other base rings)
    ker = M.right_kernel().basis_matrix().change_ring(ZZ)
    # the inequality constraints expressed in kernel coordinates
    Mker = Mineq*ker.T
    # using BKZ instead might help in some cases
    Mred = Mker.T.LLL().T
    # matrix magic that will transform our
    # solution back to the original space (Mker*R == Mred)
    R = ((Mker.T*Mker)**-1 * (Mker.T*Mred)).change_ring(ZZ)

    # move the particular solution to the right-hand side; vector() lets
    # callers pass bineq as a plain list, just like b
    bineq = vector(bineq) - Mineq*s
    bineq_cv, v = babai_coords(Mred.T, bineq)
    # with y = X + v:  Mred*y >= bineq  <=>  Mred*X >= bineq - Mred*v,
    # and Mred*v == bineq_cv
    bineq_red = bineq - bineq_cv

    model = ort.CpModel()
    X = [model.NewIntVar(-lp_bound, lp_bound, f'x{i}') for i in range(Mred.ncols())]
    # Mred*X >= bineq_red
    for row, lower in zip(Mred, bineq_red):
        model.Add(sum([coef*x for coef, x in zip(row, X)]) >= lower)

    # precompute the operation R*(x+v)*ker + s
    # as T*x + c
    T = ker.T*R
    c = T*v + s

    def f(sol):
        return T*vector(ZZ, sol) + c

    return model, X, f
def solve_ineq_gen(M, Mineq, b, bineq, **kwargs):
    '''
    Returns a generator yielding all* solutions to:
    M*x = b where Mineq*x >= bineq
    *depending on the lp_bound parameter

    When this generator is exhausted or closed, the underlying solver
    generator is closed too, which asks its worker thread to stop and
    joins it.
    '''
    model, X, f = _build_system(M, Mineq, b, bineq, **kwargs)
    sols = gen_solutions(model, X)
    try:
        for sol in sols:
            yield f(sol)
    finally:
        # close explicitly instead of relying on garbage collection, so the
        # worker thread is shut down deterministically
        sols.close()
def solve_ineq(M, Mineq, b, bineq, **kwargs):
    '''
    Find one integer solution x of M*x = b subject to Mineq*x >= bineq.

    Extra keyword arguments (e.g. lp_bound) are forwarded to the system
    builder. Raises ValueError if no solution is found.
    '''
    model, variables, to_original = _build_system(M, Mineq, b, bineq, **kwargs)
    reduced = find_solution(model, variables)
    return to_original(reduced)
def _build_mod_system(M, b, lb, ub, N, **kwargs):
    '''
    Returns a tuple (model, X, f) where model is an ortools model,
    X is a list of variables we want the solution for, and f is a
    function that will transform the solution back to the original space

    M*x == b (mod N) with lb <= x <= ub is rewritten over the integers as
    [M | N*I] * (x, k) == b with one extra variable k_j per equation, and
    f returns only the x part.

    Raises ValueError if lb or ub does not have one entry per column of M.
    '''
    neqs = M.nrows()
    nvars = M.ncols()
    # materialize so iterators are still accepted alongside the length check
    lb = list(lb)
    ub = list(ub)
    # Without this, a too-long lb paired with a too-short ub still gives
    # 2*nvars inequality bounds and the bounds are silently misassigned.
    if len(lb) != nvars or len(ub) != nvars:
        raise ValueError('lb and ub must each have one entry per column of M')
    M = M.augment(identity_matrix(neqs)*N)
    I = identity_matrix(nvars)
    # x >= lb and -x >= -ub; the extra k columns get no inequalities
    Mineq = I.stack(-I).augment(zero_matrix(nvars*2, neqs))
    bineq = vector([*lb] + [-x for x in ub])
    model, X, f = _build_system(M, Mineq, b, bineq, **kwargs)
    # drop the k part of each recovered solution
    return model, X, lambda sol: f(sol)[:nvars]
def solve_ineq_mod_gen(M, b, lb, ub, N, **kwargs):
    '''
    Returns a generator yielding all* solutions to:
    M*x = b (mod N) where lb <= x <= ub
    *depending on the lp_bound parameter

    When this generator is exhausted or closed, the underlying solver
    generator is closed too, which asks its worker thread to stop and
    joins it.
    '''
    model, X, f = _build_mod_system(M, b, lb, ub, N, **kwargs)
    sols = gen_solutions(model, X)
    try:
        for sol in sols:
            yield f(sol)
    finally:
        # close explicitly instead of relying on garbage collection, so the
        # worker thread is shut down deterministically
        sols.close()
def solve_ineq_mod(M, b, lb, ub, N, **kwargs):
    '''
    Find one integer solution x of M*x = b (mod N) with lb <= x <= ub.

    Extra keyword arguments (e.g. lp_bound) are forwarded to the system
    builder. Raises ValueError if no solution is found.
    '''
    model, variables, to_original = _build_mod_system(M, b, lb, ub, N, **kwargs)
    reduced = find_solution(model, variables)
    return to_original(reduced)
# example usage, from https://connor-mccartney.github.io/cryptography/other/Trying-to-crack-COD-FNV-hashes
if __name__ == '__main__':
    # 64-bit FNV offset basis and prime
    FNV_INIT = 0xCBF29CE484222325
    p = 0x100000001B3

    def fnv64(s):
        '''
        64-bit FNV-1a of the bytes s, after lower-casing and turning
        backslashes into forward slashes.
        '''
        normalized = s.lower().replace(b"\\", b"/")
        hsh = FNV_INIT
        for byte in normalized:
            hsh = ((hsh ^ byte) * p) % 2**64
        return hsh

    def rev(sol):
        '''
        Turn a sequence of per-byte deltas back into bytes.

        Each delta d satisfies (h ^ ch) == h + d for the running value h,
        so ch == (h + d) ^ h. Returns None as soon as a byte falls outside
        32..127.
        '''
        state = FNV_INIT
        out = []
        for delta in sol:
            mixed = state + delta
            ch = mixed ^ state
            if not 32 <= ch < 128:
                return None
            out.append(ch)
            # NOTE(review): unlike fnv64, the running value is not reduced
            # mod 2**64 here
            state = mixed * p
        return bytes(out)

    def solve(target, n):
        '''
        Enumerate solutions of the n-byte hash congruence for target,
        decode each with rev, and print and return those that decode.

        After n bytes the hash is FNV_INIT*p**n + sum(d_i * p**(n-i))
        (mod 2**64) over the deltas d_0..d_{n-1}, each bounded to
        [-128, 128].
        '''
        weights = matrix([[p**(n - i) for i in range(n)]])
        rhs = [target - FNV_INIT*p**n]
        found = []
        for sol in solve_ineq_mod_gen(weights, rhs, [-128]*n, [128]*n, 2**64):
            candidate = rev(sol)
            if candidate is None:
                continue
            print(candidate)
            found.append(candidate)
        return found

    solve(fnv64(b'abcdefghi'), 9)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment