Skip to content

Instantly share code, notes, and snippets.

Last active October 14, 2021 21:42
Show Gist options
  • Save versesrev/0f994f70c6de20344f6f44893adb80b0 to your computer and use it in GitHub Desktop.
Save versesrev/0f994f70c6de20344f6f44893adb80b0 to your computer and use it in GitHub Desktop.
Utilities for solving inequalities with lattices. Built on top of rkm's solver.
from dataclasses import dataclass
from typing import Any, Callable, List, Mapping, Optional, Tuple
@dataclass
class Constraint:
    """ Constraint on a linear function

    The corresponding formula is:
        lower_bound <= sum(coefficients[var] * var, for all var) <= upper_bound

    Both bounds are inclusive; an equality is expressed by setting
    ``lower_bound == upper_bound``.
    """
    # Coefficient of each variable in the linear combination, keyed by name.
    coefficients: Mapping[str, int]
    # Inclusive lower bound of the linear combination.
    lower_bound: int
    # Inclusive upper bound of the linear combination.
    upper_bound: int

    def __str__(self):
        formula = ' + '.join(f'{c}*{x}' for x, c in self.coefficients.items())
        return f'{self.lower_bound} <= {formula} <= {self.upper_bound}'
def constraints_to_lattice(
    constraints: List['Constraint'],
    debug: bool = False
) -> Tuple[List[List[int]], List[str]]:
    """ Build the coefficient lattice for a list of constraints

    Row ``r`` of the lattice belongs to ``variables[r]`` and column ``i`` to
    ``constraints[i]``; entry ``[r][i]`` is that variable's coefficient in
    that constraint, or 0 when the variable does not appear in it.

    Args:
        constraints: the constraints to encode.
        debug: print the constraints, the variable order and a ``*``/``.``
            picture of which lattice entries are non-zero.

    Returns:
        ``(lattice, variables)``, where ``variables`` is the sorted list of
        every variable name used by any constraint.
    """
    from itertools import chain
    if debug:
        print('constraints = [')
        print(',\n'.join(f'\t{c}' for c in constraints))
        print(']')
    variables = sorted(set(chain.from_iterable(
        c.coefficients.keys() for c in constraints
    )))
    # Resolve name -> row once; list.index() per coefficient is O(n^2).
    row_of = {var: r for r, var in enumerate(variables)}
    lattice = [[0] * len(constraints) for _ in variables]
    for i, c in enumerate(constraints):
        for var, coef in c.coefficients.items():
            lattice[row_of[var]][i] = coef
    if debug:
        print(f'variables = {variables}')
        print(f'lattice_nrows = {len(variables)} variables')
        print(f'lattice_ncols = {len(constraints)} constraints')
        print('lattice =')
        for row in lattice:
            print(''.join('*' if v else '.' for v in row))
    return lattice, variables
# ===== rkm solver =====
def load_rkm_solver(
    filename: Optional[str] = None
) -> Callable:
    """ Load rkm's solver without overwriting solve() in globals()

    The script is run with ``sage.repl.load.load`` inside a *copy* of this
    module's globals, so the ``solve`` it defines lands in that copy and the
    module-level ``solve`` is left untouched.

    Args:
        filename: location of the solver script, passed to
            ``sage.repl.load.load``.

    Returns:
        The ``solve`` callable defined by the loaded script.

    Raises:
        KeyError: if the loaded script did not (re)define ``solve``.
    """
    from copy import copy
    if filename is None:
        # NOTE(review): the default is an empty string in this copy of the
        # file; the solver's location looks like it went missing, so pass
        # ``filename`` explicitly until it is restored.
        filename = ''  # noqa
    context = copy(globals())
    previous = context.get('solve')
    sage.repl.load.load(filename, context)
    solver = context.get('solve')
    # A ``solve`` already present in globals survives in the copy, so a
    # script that defines nothing would otherwise silently hand that one back.
    if solver is None or solver is previous:
        raise KeyError(f'{filename!r} did not define solve()')
    return solver
_default_solver_cache: Optional[Callable] = None


def _default_rkm_solver() -> Callable:
    """ Return rkm's solver, loading it only on first use """
    global _default_solver_cache
    if _default_solver_cache is None:
        _default_solver_cache = load_rkm_solver()
    return _default_solver_cache


def rkm_wrapper(
    constraints: List[Constraint],
    debug: bool = False,
    solver: Optional[Callable] = None,
    **kwargs: Any,
) -> Mapping[str, int]:
    """ Wrapper for rkm's inequalities solver

    Encodes ``constraints`` as a lattice, asks ``solver`` for an integer
    assignment, then re-checks every constraint against that assignment.

    Args:
        constraints: the inequalities to satisfy.
        debug: print the lattice layout, progress and the solution.
        solver: rkm-style ``solve(mat, lb, ub, ...)`` returning a triple
            ``(weighted_close_vec, weights, sol_vec_or_None)``. When omitted,
            rkm's solver is loaded on first use and reused afterwards (it is
            no longer loaded as a side effect of importing this module).
        **kwargs: forwarded unchanged to ``solver``.

    Returns:
        Mapping from variable name to its value.

    Raises:
        RuntimeError: if the recovered assignment violates a constraint.
    """
    if solver is None:
        solver = _default_rkm_solver()
    lattice, variables = constraints_to_lattice(constraints, debug)

    # Call solver
    if debug:
        print('Start solving...')
    weighted_close_vec, weights, sol_vec = solver(
        matrix(lattice),
        [c.lower_bound for c in constraints],
        [c.upper_bound for c in constraints],
        **kwargs,
    )

    # Get solution
    if sol_vec is None:
        # The solver returned no direct solution (presumably because the
        # lattice is not square/invertible); recover the coefficient vector
        # from the weighted close vector via the Hermite normal form.
        weighted_lattice = matrix(lattice) * matrix.diagonal(weights)
        H, U = weighted_lattice.hermite_form(transformation=True)
        sol_vec = H.solve_left(weighted_close_vec).change_ring(ZZ) * U
    solution = dict(zip(variables, sol_vec))
    if debug:
        print(f'solution = {solution}')

    # Check solution
    for c in constraints:
        coefs, lb, ub = c.coefficients, c.lower_bound, c.upper_bound
        val = sum(coef * solution[var] for var, coef in coefs.items())
        if not lb <= val <= ub:
            raise RuntimeError('Constrained value out-of-bound, '
                               f'lb={lb}, ub={ub}, value={val}, coefs={coefs}')
    return solution
def example(
    hints_hex: Optional[str] = None,
    n_rounds: int = 5,
) -> Mapping[str, int]:
    """ pbctf21 - Yet Another PRNG

    Models the generator as linear constraints and solves them with
    ``rkm_wrapper`` to recover the first three words of each state
    register (``x``, ``y``, ``z``) from its 8-byte big-endian outputs.

    Args:
        hints_hex: the observed PRNG output as a hex string, at least
            ``8 * n_rounds`` bytes long. NOTE(review): the challenge output
            is not embedded in this copy of the file, so it must be passed.
        n_rounds: number of 8-byte outputs to use (at least 3).

    Returns:
        The solution mapping returned by ``rkm_wrapper``.

    Raises:
        ValueError: if ``n_rounds`` < 3, or ``hints_hex`` is missing or
            shorter than ``8 * n_rounds`` bytes.
    """
    import random

    if n_rounds < 3:
        raise ValueError(f'n_rounds must be at least 3, got {n_rounds}')
    if hints_hex is None:
        raise ValueError('example() needs the challenge output: '
                         'pass it as hints_hex')
    hints = bytes.fromhex(hints_hex)
    if len(hints) < 8 * n_rounds:
        # A short slice would silently decode as a too-small integer.
        raise ValueError(f'need {8 * n_rounds} bytes of output for '
                         f'{n_rounds} rounds, got {len(hints)}')

    # Constants
    m1 = 2 ** 32 - 107
    m2 = 2 ** 32 - 5
    m3 = 2 ** 32 - 209
    M = 2 ** 64 - 59
    rnd = random.Random(b'rbtree')
    a1 = [rnd.getrandbits(20) for _ in range(3)]
    a2 = [rnd.getrandbits(20) for _ in range(3)]
    a3 = [rnd.getrandbits(20) for _ in range(3)]

    # Variables
    xs = [f'x_{i}' for i in range(n_rounds)]
    ys = [f'y_{i}' for i in range(n_rounds)]
    zs = [f'z_{i}' for i in range(n_rounds)]
    ks = [f'k_{i}' for i in range(n_rounds)]
    h1s = [f'h1_{i}' for i in range(n_rounds - 3)]
    h2s = [f'h2_{i}' for i in range(n_rounds - 3)]
    h3s = [f'h3_{i}' for i in range(n_rounds - 3)]

    # Constraints
    constraints = []
    # Size: 0 <= x[i], y[i], z[i] < 2**32
    for var in xs + ys + zs:
        constraints.append(Constraint({var: 1}, 0, 2 ** 32 - 1))
    # Output: (2*m1)*x[i] + -m3*y[i] + -m2*z[i] = out[i] + k[i]*M
    for i in range(n_rounds):
        hint_val = int.from_bytes(hints[i * 8: i * 8 + 8], 'big')
        coefs = {
            xs[i]: 2 * m1,
            ys[i]: -m3,
            zs[i]: -m2,
            ks[i]: -M,
        }
        constraints.append(Constraint(coefs, hint_val, hint_val))
    # LFSR: a1[0]*x[i] + a1[1]*x[i+1] + a1[2]*x[i+2] = x[i+3] + h1[i]*m1
    registers = [(xs, a1, h1s, m1), (ys, a2, h2s, m2), (zs, a3, h3s, m3)]
    for i in range(n_rounds - 3):
        for reg, a, h, m in registers:
            coefs = {
                reg[i]: a[0],
                reg[i + 1]: a[1],
                reg[i + 2]: a[2],
                reg[i + 3]: -1,
                h[i]: -m,
            }
            constraints.append(Constraint(coefs, 0, 0))

    # Solve
    solution = rkm_wrapper(constraints, debug=True)
    print(f'self.x = {[solution[x] for x in xs[:3]]}')
    print(f'self.y = {[solution[y] for y in ys[:3]]}')
    print(f'self.z = {[solution[z] for z in zs[:3]]}')
    return solution
if __name__ == '__main__':
    # Script entry point: run the pbctf21 demonstration.
    example()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment