Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Count solutions to linear Diophantine equations
from collections import Counter
from fractions import Fraction
def _gcd(a, b):
while a:
a, b = b % a, a
return b
def _mod_inv(a: int, m: int) -> int:
x, y = a, m
u0, u1 = 0, 1
v0, v1 = 1, 0
while y:
x, (q, y) = y, divmod(x, y)
u0, u1 = u1 - q * u0, u0
v0, v1 = v1 - q * v0, v0
if x != 1:
raise ValueError
return u1 * (-1 if a < 0 else 1) % m
def _factors(n):
# TODO Prime factorisation and Cartesian products
for i in range(1, n//2 + 1):
if n % i == 0:
yield i
yield n
def linear_diophantine_counts_for(a, modulus=None):
    """
    Returns a function which maps n to the number of solutions of a.x = n in natural numbers.

    A solution is a vector x of non-negative integers, one entry per
    coefficient, with sum(a[i] * x[i]) == n.  All the expensive work (a table
    of base cases plus one interpolated polynomial per residue class) is done
    once, here; the returned function only does a table lookup or a
    polynomial evaluation.

    Args:
        a: a non-empty sequence of positive integer coefficients.
        modulus: optional positive integer.  When given, every count is
            returned reduced modulo this value and all intermediate
            arithmetic is done modulo it.

    Returns:
        A function ``count(n)`` returning the number of solutions for the
        integer ``n`` (0 for negative ``n`` or ``n`` not divisible by the gcd
        of the coefficients).

    Raises:
        ValueError: if ``a`` is empty or contains a non-positive coefficient,
            if ``modulus`` is not positive, or (propagated from ``_mod_inv``)
            if a product of differences between interpolation points is not
            invertible modulo ``modulus``.

    >>> foo = linear_diophantine_counts_for([1, 2, 6])
    >>> foo(10)
    9
    >>> foo(100)
    459
    >>> foo(1000)
    42084
    >>> foo(10000)
    4170834
    >>> bar = linear_diophantine_counts_for([1, 2, 6], 17)
    >>> bar(10)
    9
    >>> bar(100)
    0
    >>> bar(1000)
    9
    >>> bar(10000)
    3
    """
    if len(a) == 0:
        raise ValueError("a must contain at least one coefficient")
    if any(a_i <= 0 for a_i in a):
        raise ValueError("every coefficient in a must be a positive integer")
    if modulus is not None and modulus <= 0:
        raise ValueError("modulus must be a positive integer")

    # Divide out the common factor: a.x = n is solvable only when gcd_a
    # divides n, and then it is equivalent to (a / gcd_a).x = n / gcd_a.
    gcd_a = a[0]
    for a_i in a[1:]:
        gcd_a = _gcd(a_i, gcd_a)
    a = [a_i // gcd_a for a_i in a]

    # Let f(z) = \prod_{a_i in a} (1 - z^{a_i})^{-1}
    # Then we want to be able to evaluate [z^n]f(z)
    # Start by decomposing into a partial fraction whose denominators are powers of cyclotomics.
    cyclotomic_frequencies = Counter()
    for a_i in a:
        for factor in _factors(a_i):
            cyclotomic_frequencies[factor] += 1

    # We can prove that, after M base cases, the partial fraction gives a quasi-polynomial with period lcm(a) and degree len(a).
    # For details see http://cheddarmonk.org/papers/linear-diophantine-equations.pdf
    M = max((w_d - 1) * d for d, w_d in cyclotomic_frequencies.items())

    # period = lcm of the (reduced) coefficients.
    period = 1
    for a_i in a:
        period = period * (a_i // _gcd(period, a_i))

    degree_inc = len(a)  # degree of quasi-polynomial + 1
    # Enough table entries to have degree_inc sample points in every residue
    # class modulo period, all at indices >= M.
    specials = M + period * degree_inc

    # Coin-change style dynamic programme: after processing a coefficient,
    # precalc[i] counts the solutions for i using the coefficients seen so far.
    precalc = [0] * specials
    # Reduce the base case as well so that modulus == 1 yields 0 for n == 0.
    precalc[0] = 1 if modulus is None else 1 % modulus
    for a_i in a:
        for i in range(a_i, specials):
            precalc[i] += precalc[i - a_i]
            # Both summands are below modulus, so one subtraction suffices.
            if modulus is not None and precalc[i] >= modulus:
                precalc[i] -= modulus

    # The point is to precalculate as much as possible, so expand out to polynomial coefficients.
    # polys[r][i] is the coefficient of n**i for the residue class r (mod period).
    polys = [[0] * degree_inc for _ in range(period)]
    denominators = [1] * period  # Only used when modulus is None
    for n in range(period):
        # Lagrange interpolation
        # We want the last degree_inc indices and values from precalc where the index equals n (mod period)
        # Here Python's unusually sane % behaviour is a boon
        first = M + (n - M) % period
        lagrange_points = [(x, precalc[x]) for x in range(first, specials, period)]
        if modulus is None:
            # We really have to work in rationals. Use the built-in ones.
            for x, y in lagrange_points:
                # term accumulates y * prod_{x2 != x} (z - x2) / (x - x2),
                # stored as coefficients of increasing powers of z.
                term = [0] * degree_inc
                term[0] = y
                for x2, y2 in lagrange_points:
                    if x == x2:
                        continue
                    # term = term * (z - x2) / (x - x2)
                    # Work from the top coefficient down so term[i-1] is
                    # still the old value when term[i] is updated.
                    for i in range(degree_inc - 1, 0, -1):
                        term[i] = (term[i-1] - x2 * term[i]) / Fraction(x - x2)
                    term[0] = term[0] * Fraction(-x2, x - x2)
                for i in range(degree_inc):
                    polys[n][i] += term[i]
            # Although we used rationals for laziness, when it comes to evaluating in the callback
            # we probably want to stick to integers.
            # Scale the polynomial by the lcm of its coefficient denominators
            # and remember that lcm so count() can divide it back out.
            lcm_denom = 1
            for coeff in polys[n]:
                lcm_denom = coeff.denominator * (lcm_denom // _gcd(coeff.denominator, lcm_denom))
            denominators[n] = lcm_denom
            for i in range(degree_inc):
                polys[n][i] = polys[n][i].numerator * (lcm_denom // polys[n][i].denominator)
        else:
            for x, y in lagrange_points:
                term = [0] * degree_inc
                term[0] = y
                denom = 1
                for x2, y2 in lagrange_points:
                    if x == x2:
                        continue
                    # term = term * (z - x2)
                    for i in range(degree_inc - 1, 0, -1):
                        term[i] = (term[i-1] - x2 * term[i]) % modulus
                    term[0] = term[0] * -x2 % modulus
                    # Optimise by only doing one _mod_inv
                    denom = denom * (x - x2) % modulus
                # Raises ValueError when denom has no inverse modulo modulus.
                recip = _mod_inv(denom, modulus)
                for i in range(degree_inc):
                    polys[n][i] = (polys[n][i] + term[i] * recip) % modulus

    def count(n):
        """Return the number of solutions for ``n`` (reduced if a modulus was given)."""
        if n < 0 or n % gcd_a:
            return 0
        n //= gcd_a
        # Small values come straight from the table of base cases.
        if n < len(precalc):
            return precalc[n]
        # Otherwise evaluate the polynomial for n's residue class (Horner's rule).
        if modulus is None:
            rv = 0
            for coeff in reversed(polys[n % period]):
                rv = rv * n + coeff
            return rv // denominators[n % period]
        else:
            rv = 0
            # Select the polynomial before reducing n modulo the modulus.
            poly = polys[n % period]
            n %= modulus
            for coeff in reversed(poly):
                rv = (rv * n + coeff) % modulus
            return rv

    return count
# When executed as a script, run the examples embedded in the docstrings.
if __name__ == "__main__":
    from doctest import testmod
    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment