Skip to content

Instantly share code, notes, and snippets.

@versesrev
Last active October 14, 2021 21:42
Show Gist options
  • Save versesrev/0f994f70c6de20344f6f44893adb80b0 to your computer and use it in GitHub Desktop.
Save versesrev/0f994f70c6de20344f6f44893adb80b0 to your computer and use it in GitHub Desktop.
Utilities for solving inequalities with lattices. Built on top of rkm's solver: https://github.com/rkm0959/Inequality_Solving_with_CVP
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Callable, List, Mapping, Optional, Tuple
@dataclass
class Constraint:
    """A two-sided bound on an integer linear combination of named variables.

    Encodes the inequality::

        lower_bound <= sum(coefficients[var] * var, for all var) <= upper_bound

    Both bounds are inclusive; passing the same value for both bounds
    expresses an equality constraint.
    """

    # Variable name -> integer coefficient of that variable.
    coefficients: Mapping[str, int]
    # Inclusive lower bound of the linear combination.
    lower_bound: int
    # Inclusive upper bound of the linear combination.
    upper_bound: int

    def __str__(self):
        pairs = self.coefficients.items()
        terms = [f'{c}*{x}' for x, c in pairs]
        formula = ' + '.join(terms)
        return f'{self.lower_bound} <= {formula} <= {self.upper_bound}'
def constraints_to_lattice(
    constraints: List[Constraint],
    debug: bool = False
) -> Tuple[List[List[int]], List[str]]:
    """Encode ``constraints`` as an integer matrix for a lattice solver.

    Rows correspond to variables (sorted by name) and columns to constraints,
    so ``lattice[r][c]`` is the coefficient of ``variables[r]`` in
    ``constraints[c]`` (0 when the variable does not appear in it). A row
    vector of variable values multiplied by this matrix therefore yields the
    value of every constrained linear combination.

    Args:
        constraints: Constraints to encode; their order fixes the column order.
        debug: When true, print the constraints, the variable list, the matrix
            dimensions and a sparsity picture ('*' = nonzero, '.' = zero).

    Returns:
        ``(lattice, variables)``: the ``len(variables) x len(constraints)``
        coefficient matrix as nested lists, and the sorted list of every
        variable name appearing in any constraint.
    """
    from itertools import chain

    if debug:
        print('constraints = [')
        print(',\n'.join(f'\t{c}' for c in constraints))
        print(']')

    variables = sorted(set(chain.from_iterable(
        c.coefficients.keys() for c in constraints
    )))
    # Precompute each variable's row once; calling list.index() for every
    # coefficient made the fill loop scale with the number of variables.
    row_of = {var: row for row, var in enumerate(variables)}

    lattice = [[0] * len(constraints) for _ in variables]
    for col, c in enumerate(constraints):
        for var, coef in c.coefficients.items():
            lattice[row_of[var]][col] = coef

    if debug:
        print(f'variables = {variables}')
        print(f'lattice_nrows = {len(variables)} variables')
        print(f'lattice_ncols = {len(constraints)} constraints')
        print('lattice =')
        for row in lattice:
            print(''.join('*' if v else '.' for v in row))
    return lattice, variables
# ===== rkm solver =====
def load_rkm_solver(
    filename: Optional[str] = None
) -> Callable:
    """Load rkm's solver without overwriting solve() in globals().

    The Sage file is executed in a shallow *copy* of this module's globals,
    so the ``solve`` function (and any helpers) it defines land in that copy
    rather than replacing names in this module.

    Args:
        filename: Path or URL of ``solver.sage``. Defaults to the file in the
            rkm0959/Inequality_Solving_with_CVP GitHub repository.

    Returns:
        The ``solve`` callable defined by the loaded file.

    Raises:
        KeyError: If the loaded file did not define a new ``solve``.
    """
    from copy import copy

    if filename is None:
        filename = 'https://raw.githubusercontent.com/rkm0959/Inequality_Solving_with_CVP/main/solver.sage'  # noqa
    context = copy(globals())
    # The copied namespace may already contain a ``solve`` (e.g. a global
    # provided by the Sage environment); remember it so that a file which
    # fails to define its own ``solve`` is not mistaken for success.
    previous = context.get('solve')
    sage.repl.load.load(filename, context)
    solver = context.get('solve')
    if solver is None or solver is previous:
        raise KeyError(f"{filename!r} did not define a 'solve' function")
    return solver
@lru_cache(maxsize=None)
def _default_rkm_solver() -> Callable:
    """Return rkm's solver from its default location, loading it only once."""
    return load_rkm_solver()


def rkm_wrapper(
    constraints: List[Constraint],
    debug: bool = False,
    solver: Optional[Callable] = None,
    **kwargs: Any,
) -> Mapping[str, int]:
    """Wrapper for rkm's inequalities solver.

    Encodes ``constraints`` with ``constraints_to_lattice``, runs ``solver``
    on the resulting matrix and bound vectors, recovers the variable values
    and verifies them against every constraint.

    Args:
        constraints: Constraints that must hold simultaneously.
        debug: When true, print the encoding, progress and the solution.
        solver: Callable with the interface of rkm's ``solve``. It is called
            as ``solver(matrix, lower_bounds, upper_bounds, **kwargs)`` and
            must return ``(weighted_close_vec, weights, sol_vec)``, where
            ``sol_vec`` may be ``None``. Defaults to rkm's solver, loaded on
            first use via ``load_rkm_solver()`` and reused afterwards.
        **kwargs: Extra keyword arguments forwarded to ``solver``.

    Returns:
        Mapping from each variable name to its value in the solution.

    Raises:
        RuntimeError: If the recovered values violate any constraint (the
            solver's output is not guaranteed to satisfy all of them).
    """
    if solver is None:
        # Resolved at call time rather than as a default argument, so that
        # importing this module does not fetch and execute the remote solver.
        solver = _default_rkm_solver()

    lattice, variables = constraints_to_lattice(constraints, debug)

    # Call solver
    if debug:
        print('Start solving...')
    weighted_close_vec, weights, sol_vec = solver(
        matrix(lattice),
        [c.lower_bound for c in constraints],
        [c.upper_bound for c in constraints],
        **kwargs,
    )

    # Get solution
    if sol_vec is None:
        # The solver returned no coordinates: express the close vector v in
        # the basis of the weighted lattice A. hermite_form(transformation=True)
        # yields U with U * A == H, so y * H == v implies (y * U) * A == v.
        weighted_lattice = matrix(lattice) * matrix.diagonal(weights)
        H, U = weighted_lattice.hermite_form(transformation=True)
        sol_vec = H.solve_left(weighted_close_vec).change_ring(ZZ) * U
    solution = dict(zip(variables, sol_vec))
    if debug:
        print(f'solution = {solution}')

    # Check solution against the original (unweighted) constraints
    for c in constraints:
        coefs, lb, ub = c.coefficients, c.lower_bound, c.upper_bound
        val = sum(coef * solution[var] for var, coef in coefs.items())
        if not lb <= val <= ub:
            raise RuntimeError('Constrained value out-of-bound, '
                               f'lb={lb}, ub={ub}, value={val}, coefs={coefs}, '
                               f'solution={solution}')
    return solution
def example():
    """Recover the PRNG state for pbctf21 - Yet Another PRNG.

    Models five rounds of the generator as linear constraints over the
    unknown states, the output reduction counts and the recurrence
    reduction quotients, solves them with ``rkm_wrapper`` (debug output on)
    and prints the first three values of each state sequence.
    """
    import random

    # Public parameters of the challenge generator.
    m1, m2, m3 = 2 ** 32 - 107, 2 ** 32 - 5, 2 ** 32 - 209
    M = 2 ** 64 - 59
    seeded = random.Random(b'rbtree')
    # Drawn in the order a1, a2, a3 from the same fixed-seed generator.
    a1, a2, a3 = [[seeded.getrandbits(20) for _ in range(3)] for _ in range(3)]
    # Observed outputs, 8 big-endian bytes per round.
    hints = bytes.fromhex(
        '67f19d3da8af1480f39ac04f7e9134b2dc4ad094475b696224389c9ef29b8a2a'
        'ff8933bd3fefa6e0d03827ab2816ba0fd9c0e2d73e01aa6f184acd9c58122616'
        'f9621fb8313a62efb27fb3d3aa385b89435630d0704f0dceec00fef703d54fca'
    )

    # Unknowns: per-round states x/y/z, output reduction counts k, and the
    # reduction quotients h1/h2/h3 of each recurrence step.
    n_rounds = 5
    rounds = range(n_rounds)
    steps = range(n_rounds - 3)
    xs = [f'x_{i}' for i in rounds]
    ys = [f'y_{i}' for i in rounds]
    zs = [f'z_{i}' for i in rounds]
    ks = [f'k_{i}' for i in rounds]
    h1s = [f'h1_{i}' for i in steps]
    h2s = [f'h2_{i}' for i in steps]
    h3s = [f'h3_{i}' for i in steps]

    # Every state value lies in [0, 2**32 - 1].
    constraints = [Constraint({state: 1}, 0, 2 ** 32 - 1)
                   for state in xs + ys + zs]

    # Output: 2*m1*x_i - m3*y_i - m2*z_i - M*k_i == out_i.
    for i, (x, y, z, k) in enumerate(zip(xs, ys, zs, ks)):
        observed = int.from_bytes(hints[8 * i:8 * i + 8], 'big')
        relation = {x: 2 * m1, y: -m3, z: -m2, k: -M}
        constraints.append(Constraint(relation, observed, observed))

    # Recurrence: a[0]*s_i + a[1]*s_{i+1} + a[2]*s_{i+2} - s_{i+3} - m*h_i == 0.
    sequences = ((xs, a1, h1s, m1), (ys, a2, h2s, m2), (zs, a3, h3s, m3))
    for i in steps:
        for seq, taps, quotients, modulus in sequences:
            relation = dict(zip(seq[i:i + 3], taps))
            relation[seq[i + 3]] = -1
            relation[quotients[i]] = -modulus
            constraints.append(Constraint(relation, 0, 0))

    solution = rkm_wrapper(constraints, debug=True)
    print(f'self.x = {[solution[x] for x in xs[:3]]}')
    print(f'self.y = {[solution[y] for y in ys[:3]]}')
    print(f'self.z = {[solution[z] for z in zs[:3]]}')


if __name__ == '__main__':
    example()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment