Skip to content

Instantly share code, notes, and snippets.

@mjtb49
Last active February 21, 2024 14:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mjtb49/506ac15656bfad6f9307059257b07200 to your computer and use it in GitHub Desktop.
Save mjtb49/506ac15656bfad6f9307059257b07200 to your computer and use it in GitHub Desktop.
Earthcomputer requested that I write up an LCG discrete-log solver for arbitrary LCGs. I have now done so, but lazily: there are several points where this could be improved, and I am not entirely convinced by the approach. This solver assumes some sort of factorization is possible — in particular, it needs to factor both phi(m) and m at some point …
import math
import sympy
import sympy as sp
import random
class LCG:
    """One step x -> (a*x + b) mod m of a linear congruential generator.

    LCGs compose with ``@`` and can be raised to non-negative integer powers
    with ``**``, which jumps ahead that many steps in O(log power) compositions.
    """

    def __init__(self, a, b, m):
        self.a = a % m
        self.b = b % m
        self.m = m

    def __repr__(self):
        return f"{type(self).__name__}(a={self.a}, b={self.b}, m={self.m})"

    def __matmul__(self, other):
        """Return the composition ``self ∘ other`` (apply ``other`` first, then ``self``)."""
        assert self.m == other.m
        return LCG(self.a * other.a, self.a * other.b + self.b, self.m)

    def __pow__(self, power, modulo=None):
        """Return this LCG composed with itself ``power`` times.

        ``modulo`` is accepted only so the 3-argument ``pow()`` form does not
        fail; it is ignored.

        Raises:
            ValueError: if ``power`` is negative (binary exponentiation would
                never terminate, since ``-1 // 2 == -1``).
        """
        if power < 0:
            raise ValueError("LCG powers must be non-negative")
        result = LCG(1, 0, self.m)  # identity map
        current = self
        # Square-and-multiply; all factors are powers of self, so order of
        # composition does not matter.
        while power != 0:
            if power % 2 == 1:
                result = result @ current
            current = current @ current
            power //= 2
        return result

    def apply(self, seed):
        """Return the state that follows ``seed``."""
        return (self.a * seed + self.b) % self.m
def vp(a, p, e):
    """Return the p-adic valuation of ``a``, where ``a`` is known only modulo p**e.

    Because ``a`` is only determined up to a multiple of p**e, a value that is
    0 modulo p**e is reported as having valuation ``e``. "Properly" this would
    be infinity; ``e`` avoids having to represent infinity.
    """
    residue = a % p ** e
    if not residue:
        return e
    count = 0
    while not residue % p:
        residue //= p
        count += 1
    return count
def ord_mod_odd_prime(a, p):
    """Return the multiplicative order of ``a`` modulo the prime ``p``.

    For each prime power q**k exactly dividing p - 1, ``a`` is projected onto
    the q-Sylow subgroup by raising it to (p - 1) / q**k. That projection is
    then repeatedly raised to the q-th power until it reaches 1. The number of
    powerings gives the q-part of the order.

    Raises:
        ValueError: if ``p`` divides ``a`` (0 has no multiplicative order; the
            loop below would otherwise never terminate).
    """
    assert sp.isprime(p)
    a %= p
    if a == 0:
        raise ValueError(f"a is divisible by p={p}; it has no multiplicative order")
    # I think order finding is roughly as hard as factoring p-1? But this might be removable.
    group_order = p - 1
    result = 1
    for q, k in sp.factorint(group_order).items():
        component = pow(a, group_order // (q ** k), p)
        q_part = 1
        while component != 1:
            component = pow(component, q, p)
            q_part *= q
        result *= q_part
    assert result <= group_order
    return result
def ord_mod_prime_power(a, p, e):
    """Return the multiplicative order of ``a`` modulo p**e as a plain int.

    For odd primes p, the multiplicative group of Z_p is the product of the
    roots of unity mu (projection: reduce mod p) and U = 1 + pZ_p (projection:
    raise to the power p - 1).

    For p = 2, it is the product of {+/-1} and U = 1 + 4Z_2. The projection
    onto U is negation when a = 3 mod 4 (identity otherwise), and the
    projection onto {+/-1} is -1 when a = 3 mod 4 (1 otherwise).

    The order is the lcm of the orders of the two projections.
    """
    if e == 0:
        return 1
    assert a % p != 0
    modulus = p ** e
    a %= modulus  # normalise so the negation below stays in [0, modulus)

    if p == 2:
        if e == 1:
            return 1
        mu_order = 1
        if a % 4 == 3:
            mu_order = 2
            a = modulus - a
        # An element u = 1 mod 4 has order 2**(e - v2(u - 1)) modulo 2**e.
        u_order = p ** (e - vp(a - 1, p, e))
    else:
        mu_order = int(ord_mod_odd_prime(a, p))
        u_order = p ** (e - vp(pow(a, p - 1, modulus) - 1, p, e))

    # Plain-int lcm keeps the return type consistent with the e <= 1 paths.
    return mu_order * u_order // math.gcd(mu_order, u_order)
def bs_gs(a, target, p):
    """Baby-step giant-step discrete log modulo a prime.

    Returns:
        The smallest k >= 0 with a**k == target (mod p), or None if ``target``
        is not a power of ``a`` modulo ``p``.
    """
    target %= p
    a %= p
    assert sp.isprime(p) and a % p != 0
    # Choose s with s*s >= p - 1 so exponents j + s*i (0 <= i, j < s) cover the
    # whole group. The max() keeps p == 2 from evaluating isqrt(-1).
    step_size = 1 + math.isqrt(max(p - 2, 0))
    giant_step = pow(a, -step_size, p)
    baby_steps = {}
    current = 1
    for j in range(step_size):
        # Keep the first (smallest) exponent if a has order below step_size,
        # so the returned logarithm is minimal.
        baby_steps.setdefault(current, j)
        current = a * current % p
    for i in range(step_size):
        j = baby_steps.get(target)
        if j is not None:
            return j + step_size * i
        target = target * giant_step % p
    return None
def log_p(z, p, e):
    """Return the p-adic logarithm of ``z`` modulo p**e.

    Evaluates log(z) = log(1 - x) = -sum_{n >= 1} x**n / n with x = 1 - z,
    stopping once later terms are (heuristically) divisible by p**e.

    Requires z = 1 (mod p), or z = 1 (mod 4) when p == 2 (enforced by the
    assert), so that the series converges.
    """
    xn = 1-z  # 1 - xn = z
    vpx = vp(xn, p, e)
    assert vpx >= (2 if p == 2 else 1)
    result = 0
    n = 1
    # Valuation of x^n = n * vpx, we want to pick n large enough so m * vp(x) - vp(m) >= e for all m >= n
    # m * vp(x) - vp(m) >= m * vp(x) - log(m) / log(p) >= n * vp(x) - log(n) / log(p)
    # Subtract 1 as a deliberately conservative safety margin.
    while e > n * vpx - math.log(n, p) - 1:
        vpn = vp(n, p, e)
        # Term x**n / n: divide out the p-part of n exactly from x**n, then
        # multiply by the inverse of n's unit part modulo p**e.
        result -= xn // (p ** vpn) * pow(n // (p ** vpn), -1, p**e)
        xn *= (1-z)
        xn %= p ** (e + e) # This bound remains quite lazy. The correct bound is the max of ceil(e + math.log(n, p)) over n in the loop
        n += 1
    result %= p**e
    return result
def dist_mod_prime_power(a, b, p, e, x, y):
    """Find the step counts taking x to y under x -> a*x + b (mod p**e).

    Args:
        a, b: multiplier and increment; ``a`` must be coprime to ``p``.
        p, e: the modulus is p**e for the prime ``p``.
        x, y: start and target states.

    Returns:
        ``(k, order)`` such that the step counts taking x to y are exactly the
        integers congruent to ``k`` modulo ``order``, or None if y is not
        reachable from x.
    """
    assert sp.gcd(a, p) == 1
    a %= p**e
    b %= p**e
    x %= p**e
    y %= p**e
    # The additive generator a == 1 fits the same formulas if a is replaced by
    # the equivalent representative p**e + 1 (so that a - 1 = p**e exactly).
    if a == 1:
        a = p**e + 1
    # The k-th iterate x_k satisfies (a-1)*x_k + b = a**k * ((a-1)*x + b).
    # Hence y = x_k (mod p**e) iff a**k * d = n (mod p**(e + vp(a-1))).
    d = (a-1) * x + b
    n = (a-1) * y + b
    e += vp(a - 1, p, e)
    vpd = vp(d, p, e)
    if vpd != vp(n, p, e):
        return None
    # Cancel the common power of p so that d becomes a unit.
    d //= p ** vpd
    n //= p ** vpd
    e -= vpd
    target = (n * pow(d, -1, p ** e)) % (p ** e)
    if e == 0:
        return 0, 1
    # Now solve a**k = target (mod p**e); k is determined modulo the order of a.
    order = ord_mod_prime_power(a, p, e)
    # TODO, some costly parts of bs_gs can be computed from already known value of "order", or alternately order can be inferred from later work

    if p % 2 == 1:
        # First solve modulo p.
        k0 = bs_gs(a, target, p)
        # Must precede the e == 1 shortcut: k0 is None when target is not a
        # power of a modulo p, and `None % order` would raise TypeError.
        if k0 is None:
            return None
        if e == 1:
            return k0 % order, order
        # We now know a^((p-1)k1 + k0) = target; solve for k1 in U = 1 + pZ_p
        # using p-adic logarithms.
        target = target * pow(a, -k0, p ** e) % p ** e
        a = pow(a, p-1, p**e)
        loga = log_p(a, p, e + 2) % p ** e
        logt = log_p(target, p, e + 2) % p ** e
        if loga == logt == 0:
            return k0 % order, order
        while loga % p == 0:
            if logt % p != 0:
                return None
            loga //= p
            logt //= p
        result = (pow(loga, -1, p**e) * logt % p**e) * (p-1) + k0
        return result % order, order

    # p == 2: solve a^k = target mod 2^e. The group here is not cyclic, so the
    # logarithm solves a^k = t only for a, t = 1 mod 4; the sign ({+/-1} part)
    # is then checked separately.
    sign_residue = 0
    sign_modulus = 1
    if e <= 1:
        return 0, 1
    if target % 4 == 3:
        if a % 4 == 1:
            return None  # powers of a are all 1 mod 4
        sign_modulus = 2
        sign_residue = 1  # odd k needed to flip the sign
    elif target % 4 == 1:
        if a % 4 == 3:
            sign_modulus = 2
            sign_residue = 0  # even k needed to keep the sign
    # Project a and target onto U.
    a = a if a % 4 == 1 else p**e - a
    target = target if target % 4 == 1 else p**e - target
    loga = log_p(a, p, e) % p**e
    logt = log_p(target, p, e) % p**e
    if loga == logt == 0:  # both U-projections are 1
        return sign_residue % order, order
    while loga % p == 0:
        if logt % p != 0:
            return None
        loga //= p
        logt //= p
    result = pow(loga, -1, p**e) * logt % p**e
    if result % sign_modulus != sign_residue:
        return None
    return result % order, order
def solve_congruences(congruences):
    """Combine ``(residue, modulus)`` pairs with the generalized Chinese remainder theorem.

    Moduli need not be pairwise coprime.

    Returns:
        ``(residue, lcm_of_moduli)`` describing every integer satisfying all
        congruences, or None if they are inconsistent. A single congruence is
        returned unchanged, and an empty list yields ``(0, 1)`` (every integer).
    """
    if not congruences:
        return 0, 1
    if len(congruences) == 1:
        return congruences[0]
    res_1, mod_1 = congruences[0]
    # Coerce to plain ints: inputs may be sympy Integers, and 3-argument pow()
    # rejects mixed int / sympy operands.
    res_1, mod_1 = int(res_1), int(mod_1)
    for res_2, mod_2 in congruences[1:]:
        res_2, mod_2 = int(res_2), int(mod_2)
        gcd = math.gcd(mod_1, mod_2)
        if (res_1 - res_2) % gcd != 0:
            return None
        reduced_2 = mod_2 // gcd
        u = pow(mod_1 // gcd, -1, reduced_2) if reduced_2 != 1 else 0
        new_mod = mod_1 // gcd * mod_2  # lcm(mod_1, mod_2)
        # (mod_1 // gcd) * u = 1 (mod mod_2 // gcd), so this matches res_2 mod mod_2.
        new_res = (res_1 - mod_1 * u * (res_1 - res_2) // gcd) % new_mod
        assert new_res % mod_1 == res_1 % mod_1
        assert new_res % mod_2 == res_2 % mod_2
        res_1, mod_1 = new_res, new_mod
    return res_1, mod_1
def distance(lcg, x, y):
    """Find how many steps of ``lcg`` take state ``x`` to state ``y``.

    Returns:
        ``(k, period)`` such that the step counts are exactly the integers
        congruent to ``k`` modulo ``period``, or None if no solution is found.
    """
    modulus, multiplier, increment = lcg.m, lcg.a, lcg.b
    if modulus == 1:
        # A single state: every step count works.
        return 0, 1
    # TODO handle the primes where this is not the case
    assert sp.gcd(multiplier, modulus) == 1
    # The multiplicative group mod m is the product of the groups mod p^e, so
    # solve each prime-power component separately and recombine with the CRT.
    congruences = []
    for p, e in sp.factorint(modulus).items():
        local_solution = dist_mod_prime_power(multiplier, increment, p, e, x, y)
        if local_solution is None:
            return None
        congruences.append(local_solution)
    return solve_congruences(congruences)
def main(bound=10**2, trials=100000, seed=None, verbose=True):
    """Randomized self-test of ``distance``.

    For each trial, build a random LCG with modulus in [1, bound] and a unit
    multiplier. Jump it ahead a random number of steps, then check that
    ``distance`` recovers a consistent step count. A diagnostic is printed and
    the test stops at the first failure.

    Args:
        bound: upper bound for the random modulus (and multiplier resampling).
        trials: number of random cases to test.
        seed: if given, use a private ``random.Random(seed)`` for reproducible
            runs; otherwise use the global ``random`` module as before.
        verbose: print the trial index for every trial.

    A past regression case: a, b, m, start, dist = 685, 940, 343, 869, 186 -> (2, 2).
    """
    rng = random if seed is None else random.Random(seed)
    for i in range(trials):
        if verbose:
            print(i)
        m = rng.randint(1, bound)
        a = rng.randint(1, m)
        b = rng.randint(0, m)
        start = rng.randint(1, m)
        dist = rng.randint(0, m)
        while sp.gcd(a, m) != 1:
            a = rng.randint(1, bound)
        lcg = LCG(a, b, m)
        target = (lcg ** dist).apply(start)
        d = distance(lcg, start, target)
        if d is None:
            # target is reachable by construction, so None is always a failure.
            print("Uh oh")
            print(a, b, m, start, dist, d)
            return
        # The reported period must map start back to itself...
        after_period = (lcg ** d[1]).apply(start)
        if after_period % m != start % m:
            print("Not good at all")
            print()
            print(a, b, m, start, dist, d, after_period)
            return
        # ...and the true step count must lie in the reported residue class.
        if dist % d[1] != d[0]:
            print("No good")
            print(a, b, m, start, dist, d)
            return
# Run the randomized self-test only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment