Last active
February 21, 2024 14:35
-
-
Save mjtb49/506ac15656bfad6f9307059257b07200 to your computer and use it in GitHub Desktop.
Earthcomputer requested that I write up an LCG discrete log solver for arbitrary LCGs. I have now done so, but lazily: there are several points where this could be improved, and I am not entirely convinced by the approach. This solver assumes some sort of factorization is possible - in particular, it needs to factor both phi(m) and m at some point …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import sympy | |
import sympy as sp | |
import random | |
class LCG:
    """Affine map ``x -> (a * x + b) mod m``, i.e. one step of an LCG.

    ``a`` and ``b`` are stored reduced modulo ``m``; ``m`` is stored as given.
    """

    def __init__(self, a, b, m):
        self.a = a % m
        self.b = b % m
        self.m = m

    def __repr__(self):
        return f"LCG({self.a}, {self.b}, {self.m})"

    def __matmul__(self, other):
        """Return the composition ``self o other`` (``other`` is applied first).

        Both maps must share the same modulus.
        """
        assert self.m == other.m
        return LCG(self.a * other.a, self.a * other.b + self.b, self.m)

    def __pow__(self, power, modulo=None):
        """Return this map composed with itself ``power`` times.

        Uses square-and-multiply, so the number of compositions is
        logarithmic in ``abs(power)``.  ``power == 0`` gives the identity map.

        A negative ``power`` composes the inverse map instead; this needs
        ``a`` to be invertible modulo ``m`` and otherwise raises the
        ``ValueError`` produced by ``pow(a, -1, m)``.  (Previously a negative
        exponent never terminated.)

        ``modulo`` is accepted only to match the ``__pow__`` protocol and is
        ignored.
        """
        base = self
        if power < 0:
            # Inverse of y = a*x + b is x = a^-1 * (y - b).
            inverse_a = pow(self.a, -1, self.m)
            base = LCG(inverse_a, -inverse_a * self.b, self.m)
            power = -power

        result = LCG(1, 0, self.m)
        while power != 0:
            if power % 2 == 1:
                result = result @ base
            base = base @ base
            power //= 2
        return result

    def apply(self, seed):
        """Return the image of ``seed`` under this map, reduced modulo ``m``."""
        return (self.a * seed + self.b) % self.m
# p-adic valuation of a number that is only known modulo p^e.
def vp(a, p, e):
    """Return how many times ``p`` divides ``a``, capped at ``e``.

    ``a`` is first reduced modulo ``p**e``.  If that residue is zero the
    valuation cannot be told apart from "infinite" at this precision, so
    ``e`` is returned as a stand-in for infinity.
    """
    residue = a % p ** e
    if not residue:
        return e

    count = 0
    quotient, remainder = divmod(residue, p)
    while remainder == 0:
        count += 1
        residue = quotient
        quotient, remainder = divmod(residue, p)
    return count
def ord_mod_odd_prime(a, p):
    """Return the multiplicative order of ``a`` modulo the prime ``p``.

    The group order ``p - 1`` is factored; for each prime power ``q**k``
    dividing it, ``a`` is raised to ``(p - 1) // q**k`` and the order of that
    element (a power of ``q``) is found by repeated ``q``-th powering.  The
    order of ``a`` is the product of those per-prime orders.

    Raises:
        ValueError: if ``a`` is divisible by ``p``.  Such an ``a`` has no
            multiplicative order, and the powering loop below would never
            reach 1.
    """
    assert sp.isprime(p)
    a %= p
    if a == 0:
        raise ValueError("a must not be divisible by p")

    # I think order finding is roughly as hard as factoring p-1? But this might be removable.
    group_order = p - 1
    order = 1
    for q, exponent in sp.factorint(group_order).items():
        # Component of a whose order is a power of q.
        component = pow(a, group_order // (q ** exponent), p)
        while component != 1:
            component = pow(component, q, p)
            order *= q
    assert order <= group_order
    return order
def ord_mod_prime_power(a, p, e):
    """Return the multiplicative order of ``a`` modulo ``p**e``.

    ``a`` must not be divisible by ``p`` (asserted when ``e > 0``).  For
    ``e == 0`` the result is 1.

    Background:
      * For odd p the unit group of Z_p is the product of the roots of unity
        mu_{p-1} and U = 1 + pZ_p.  Raising to the power p-1 projects onto U;
        reducing mod p identifies the mu component.
      * For p = 2 the unit group is the product of {+/-1} and U = 1 + 4Z_2.
        The projection onto U negates a when a = 3 mod 4 and is the identity
        otherwise; the projection onto {+/-1} is -1 when a = 3 mod 4, else 1.

    The order is the lcm of the orders of the two components.
    """
    if e == 0:
        return 1
    assert a % p != 0

    if p != 2:
        # Odd prime: torsion part from the order mod p, U part from a^(p-1).
        torsion_order = ord_mod_odd_prime(a, p)
        u_component = pow(a, p - 1, p ** e)
        u_order = p ** (e - vp(u_component - 1, p, e))
        return sympy.lcm(torsion_order, u_order)

    # p == 2
    if e == 1:
        return 1
    if a % 4 == 3:
        sign_order = 2
        a = p ** e - a  # project onto U by negating
    else:
        sign_order = 1
    u_order = p ** (e - vp(a - 1, p, e))
    return sympy.lcm(sign_order, u_order)
def bs_gs(a, target, p):
    """Baby-step giant-step discrete logarithm modulo the prime ``p``.

    Returns some ``k >= 0`` with ``a**k == target (mod p)``, or ``None`` if
    the search finds no such exponent.  ``a`` must be a unit modulo ``p``.
    """
    target %= p
    a %= p
    assert sp.isprime(p) and a % p != 0

    stride = math.isqrt(p - 2) + 1
    inverse_stride_power = pow(a, -stride, p)

    # Baby steps: a^j -> j for 0 <= j < stride (a later j overwrites an earlier one).
    exponent_of = {}
    power = 1
    for j in range(stride):
        exponent_of[power] = j
        power = a * power % p

    # Giant steps: strip a^stride from target until it lands in the table.
    for block in range(stride):
        hit = exponent_of.get(target)
        if hit is not None:
            return hit + stride * block
        target = target * inverse_stride_power % p
    return None
def log_p(z, p, e):
    """Return the p-adic logarithm of ``z`` modulo ``p**e``.

    Writes ``z = 1 - x`` and sums ``-x**n / n`` term by term.  ``x`` must have
    valuation at least 1 (at least 2 when ``p == 2``); this is asserted.
    """
    x = 1 - z  # z = 1 - x
    vpx = vp(x, p, e)
    assert vpx >= (2 if p == 2 else 1)

    modulus = p ** e
    # This bound remains quite lazy. The correct bound is the max of
    # ceil(e + math.log(n, p)) over n in the loop.
    working_modulus = p ** (e + e)

    total = 0
    x_power = x
    n = 1
    # The valuation of x^m / m is m * vp(x) - vp(m) >= m * vp(x) - log(m) / log(p),
    # so terms stop mattering once n * vp(x) - log_p(n) reaches e.
    # One extra unit of slack is kept ("subtract 1 since I'm a coward").
    while n * vpx - math.log(n, p) - 1 < e:
        p_part = p ** vp(n, p, e)
        # x^n / n = (x^n / p^v) * (n / p^v)^-1, the second factor being a unit.
        total -= x_power // p_part * pow(n // p_part, -1, modulus)
        x_power = x_power * x % working_modulus
        n += 1
    return total % modulus
def dist_mod_prime_power(a, b, p, e, x, y):
    """Solve ``LCG(a, b, p**e)**k`` maps ``x`` to ``y`` for the step count ``k``.

    Returns ``(k, modulus)`` meaning the solutions are exactly the integers
    congruent to ``k`` modulo ``modulus``, or ``None`` when no ``k`` works.
    ``a`` must be coprime to ``p`` (asserted).

    Method: multiplying the orbit equation by ``a - 1`` turns it into
    ``a**k * d == n`` with ``d = (a-1)*x + b`` and ``n = (a-1)*y + b``, taken
    modulo ``p**(e + vp(a-1))``.  After cancelling the common power of ``p``
    from ``d`` and ``n`` this is a discrete logarithm ``a**k == n/d`` in the
    unit group modulo a prime power, solved via a discrete log mod ``p``
    (odd ``p``) or a sign check (``p == 2``) plus p-adic logarithms.
    """
    assert sp.gcd(a, p) == 1
    a %= p**e
    b %= p**e
    x %= p**e
    y %= p**e
    if a == 1:
        # Use the representative p^e + 1 so that a - 1 has a finite valuation.
        a = p**e + 1
    d = (a - 1) * x + b
    n = (a - 1) * y + b
    e += vp(a - 1, p, e)
    vpd = vp(d, p, e)
    if vpd != vp(n, p, e):
        return None
    d //= p**vpd
    n //= p**vpd
    e -= vpd
    if e == 0:
        # Nothing left to constrain: every k is a solution.
        return 0, 1
    modulus = p**e
    target = n * pow(d, -1, modulus) % modulus
    order = ord_mod_prime_power(a, p, e)

    # Want to solve a^k = target mod p**e
    # solutions k determined up to value of order
    # first solve mod p if p is odd, or mod 4 if p is 2
    # TODO, some costly parts of bs_gs can be computed from already known value of "order",
    # or alternately order can be inferred from later work
    if p % 2 == 1:
        k0 = bs_gs(a, target, p)
        # Check for "no solution mod p" BEFORE using k0; previously the
        # e == 1 shortcut ran first and raised TypeError on None % order.
        if k0 is None:
            return None
        if e == 1:
            return k0 % order, order
        # We now know a^((p-1)k1 + k0) = target
        target = target * pow(a, -k0, modulus) % modulus
        a = pow(a, p - 1, modulus)
        loga = log_p(a, p, e + 2) % modulus
        logt = log_p(target, p, e + 2) % modulus
        if loga == logt == 0:
            return k0 % order, order
        # Cancel the common power of p; logt must be at least as divisible as loga.
        while loga % p == 0:
            if logt % p != 0:
                return None
            loga //= p
            logt //= p
        result = (pow(loga, -1, modulus) * logt % modulus) * (p - 1) + k0
        return result % order, order

    # Now looking to solve a^k = target mod 2^e
    # Unlike the odd prime case, the group here is not cyclic
    # logarithm lets me solve a^k = t if a, t are both 1 mod 4.
    # After that I must check if the sign is correct
    # First compute the constraint coming from the sign
    sign_residue = 0
    sign_modulus = 1
    if e <= 1:
        return 0, 1
    if target % 4 == 3:
        if a % 4 == 1:
            return None
        sign_modulus = 2
        sign_residue = 1
    elif target % 4 == 1:
        if a % 4 == 3:
            sign_modulus = 2
            sign_residue = 0

    # Solve it in the U factor.
    # Project a and target to U
    a = a if a % 4 == 1 else modulus - a
    target = target if target % 4 == 1 else modulus - target
    loga = log_p(a, p, e) % modulus
    logt = log_p(target, p, e) % modulus
    if loga == logt == 0:  # If this occurs either +/-a = +/-t = 1,
        return sign_residue % order, order
    while loga % p == 0:
        if logt % p != 0:
            return None
        loga //= p
        logt //= p
    result = pow(loga, -1, modulus) * logt % modulus
    if result % sign_modulus != sign_residue:
        return None
    return result % order, order
def solve_congruences(congruences):
    """Combine ``(residue, modulus)`` pairs with the generalised CRT.

    Args:
        congruences: sequence of ``(residue, modulus)`` pairs; the moduli
            need not be coprime.

    Returns:
        ``(residue, modulus)`` describing all simultaneous solutions, where
        ``modulus`` is the lcm of the input moduli, or ``None`` if two of the
        congruences contradict each other.  A single congruence is returned
        unchanged.  An empty sequence yields ``(0, 1)`` (every integer is a
        solution) instead of failing on unpacking.
    """
    if not congruences:
        return 0, 1

    merged = congruences[0]
    for res_2, mod_2 in congruences[1:]:
        res_1, mod_1 = merged
        common = math.gcd(mod_1, mod_2)
        if (res_1 - res_2) % common != 0:
            return None
        cofactor = mod_2 // common
        # u inverts mod_1/common modulo mod_2/common, so mod_1 * u == common (mod mod_2).
        u = pow(mod_1 // common, -1, cofactor) if cofactor != 1 else 0
        new_mod = mod_1 // common * mod_2  # lcm(mod_1, mod_2)
        new_res = (res_1 - mod_1 * u * (res_1 - res_2) // common) % new_mod
        assert new_res % mod_1 == res_1 % mod_1
        assert new_res % mod_2 == res_2 % mod_2
        merged = (new_res, new_mod)
    return merged
def distance(lcg, x, y):
    """Find how many steps of ``lcg`` take seed ``x`` to seed ``y``.

    Returns ``(k, period)`` as produced by ``solve_congruences`` (all step
    counts congruent to ``k`` modulo ``period`` work), or ``None`` if some
    prime-power component has no solution.  The multiplier must be coprime
    to the modulus (asserted).
    """
    modulus, multiplier, increment = lcg.m, lcg.a, lcg.b
    if modulus == 1:
        return 0, 1
    # TODO handle the primes where this is not the case
    assert sp.gcd(multiplier, modulus) == 1

    # The multiplicative group mod m is the product of the multiplicative
    # groups mod p^e, so solve per prime power and recombine.
    per_prime_power = []
    for prime, exponent in sp.factorint(modulus).items():
        local = dist_mod_prime_power(multiplier, increment, prime, exponent, x, y)
        if local is None:
            return None
        per_prime_power.append(local)
    return solve_congruences(per_prime_power)
def main():
    """Randomised self-check of ``distance`` against known step counts.

    Each trial draws a random LCG with modulus at most ``bound``, advances a
    random start seed by a random number of steps, and asks ``distance`` to
    recover that number.  The first inconsistency is printed and ends the run.
    """
    bound = 10**2
    for trial in range(100000):
        print(trial)
        m = random.randint(1, bound)
        a = random.randint(1, m)
        b = random.randint(0, m)
        start = random.randint(1, m)
        dist = random.randint(0, m)
        # Redraw the multiplier until it is coprime to the modulus.
        while sp.gcd(a, m) != 1:
            a = random.randint(1, bound)
        # Previously investigated case:
        # a, b, m, start, dist = 685, 940, 343, 869, 186  # (2, 2)
        lcg = LCG(a, b, m)
        target = (lcg**dist).apply(start)
        d = distance(lcg, start, target)

        if d is None:
            # The target was built from start, so a solution must exist.
            print("Uh oh")
            print(a, b, m, start, dist, d)
            return

        residue, period = d
        after_period = (lcg**period).apply(start)
        if after_period % m != start % m:
            # The reported period does not bring the seed back to start.
            print("Not good at all")
            print()
            print(a, b, m, start, dist, d, after_period)
            return
        if dist % period != residue:
            # The true step count is not in the reported residue class.
            print("No good")
            print(a, b, m, start, dist, d)
            return


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment