Skip to content

Instantly share code, notes, and snippets.

@pjt33
Created January 21, 2020 19:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save pjt33/989dab079a04d8830a104fce0c2caf48 to your computer and use it in GitHub Desktop.
Count solutions to linear Diophantine equations
from collections import Counter
from fractions import Fraction
def _gcd(a, b):
while a:
a, b = b % a, a
return b
def _mod_inv(a: int, m: int) -> int:
x, y = a, m
u0, u1 = 0, 1
v0, v1 = 1, 0
while y:
x, (q, y) = y, divmod(x, y)
u0, u1 = u1 - q * u0, u0
v0, v1 = v1 - q * v0, v0
if x != 1:
raise ValueError
return u1 * (-1 if a < 0 else 1) % m
def _factors(n):
# TODO Prime factorisation and Cartesian products
for i in range(1, n//2 + 1):
if n % i == 0:
yield i
yield n
def _lagrange_exact(points, degree_inc):
    """
    Interpolate the polynomial of degree < degree_inc through `points` exactly.

    Returns (coeffs, denominator): integer coefficients in increasing order of power,
    such that the polynomial equals sum(coeffs[i] * z**i) / denominator.
    """
    poly = [0] * degree_inc
    for x, y in points:
        term = [0] * degree_inc
        term[0] = y
        for x2, _ in points:
            if x == x2:
                continue
            # term = term * (z - x2) / (x - x2)
            for i in range(degree_inc - 1, 0, -1):
                term[i] = (term[i - 1] - x2 * term[i]) / Fraction(x - x2)
            term[0] = term[0] * Fraction(-x2, x - x2)
        for i in range(degree_inc):
            poly[i] += term[i]
    # Although we used rationals for laziness, when it comes to evaluating in the callback
    # we want to stick to integers, so scale everything to a common denominator.
    denominator = 1
    for coeff in poly:
        denominator = coeff.denominator * (denominator // _gcd(coeff.denominator, denominator))
    return [coeff.numerator * (denominator // coeff.denominator) for coeff in poly], denominator


def _lagrange_modular(points, degree_inc, modulus):
    """
    Interpolate the polynomial of degree < degree_inc through `points` modulo `modulus`.

    Returns the coefficients in increasing order of power, each in range(modulus).
    Raises ValueError (via _mod_inv) if a Lagrange denominator is not invertible.
    """
    poly = [0] * degree_inc
    for x, y in points:
        term = [0] * degree_inc
        term[0] = y
        denom = 1
        for x2, _ in points:
            if x == x2:
                continue
            # term = term * (z - x2); the division by (x - x2) is deferred
            for i in range(degree_inc - 1, 0, -1):
                term[i] = (term[i - 1] - x2 * term[i]) % modulus
            term[0] = term[0] * -x2 % modulus
            denom = denom * (x - x2) % modulus
        # Optimise by only doing one _mod_inv per point
        recip = _mod_inv(denom, modulus)
        for i in range(degree_inc):
            poly[i] = (poly[i] + term[i] * recip) % modulus
    return poly


def linear_diophantine_counts_for(a, modulus=None):
    """
    Returns a function which maps n to the number of solutions of a.x = n in natural numbers.

    x ranges over vectors of non-negative integers, one entry per coefficient in `a`.
    If `modulus` is given, the counts are reduced modulo `modulus`.

    The work is front-loaded: a table of small counts is built, and for each residue of n
    modulo lcm(a) (after dividing out gcd(a)) the quasi-polynomial which gives the count
    beyond that table is interpolated. Each call to the returned function then costs
    O(len(a)) arithmetic operations.

    Args:
        a: positive integer coefficients (any iterable). If it is empty, only n == 0 has
            a solution (the empty one).
        modulus: optional positive integer modulus for the counts.

    Raises:
        ValueError: if a coefficient or the modulus is less than 1.

    >>> foo = linear_diophantine_counts_for([1, 2, 6])
    >>> foo(10)
    9
    >>> foo(100)
    459
    >>> foo(1000)
    42084
    >>> foo(10000)
    4170834
    >>> bar = linear_diophantine_counts_for([1, 2, 6], 17)
    >>> bar(10)
    9
    >>> bar(100)
    0
    >>> bar(1000)
    9
    >>> bar(10000)
    3
    >>> linear_diophantine_counts_for([1, 2], 2)(5)
    1
    """
    a = list(a)
    if modulus is not None and modulus < 1:
        raise ValueError("modulus must be a positive integer")
    if any(a_i < 1 for a_i in a):
        raise ValueError("coefficients must be positive integers")
    if not a:
        # The empty sum is 0: n == 0 has exactly one (empty) solution and nothing else has any.
        one = 1 if modulus is None else 1 % modulus
        return lambda n: one if n == 0 else 0
    original_a = a  # kept for the exact fallback below; `a` is rebound, not mutated
    gcd_a = a[0]
    for a_i in a[1:]:
        gcd_a = _gcd(a_i, gcd_a)
    # Only multiples of gcd(a) are reachable; work with the reduced coefficients.
    a = [a_i // gcd_a for a_i in a]
    # Let f(z) = \prod_{a_i in a} (1 - z^{a_i})^{-1}
    # Then we want to be able to evaluate [z^n]f(z)
    # Start by decomposing into a partial fraction whose denominators are powers of cyclotomics.
    # Up to sign, 1 - z^{a_i} = \prod_{d | a_i} Phi_d(z), so w_d counts the a_i divisible by d.
    cyclotomic_frequencies = Counter()
    for a_i in a:
        for factor in _factors(a_i):
            cyclotomic_frequencies[factor] += 1
    # We can prove that, after M base cases, the partial fraction gives a quasi-polynomial
    # with period lcm(a) and degree len(a) - 1.
    # For details see http://cheddarmonk.org/papers/linear-diophantine-equations.pdf
    M = max((w_d - 1) * d for d, w_d in cyclotomic_frequencies.items())
    period = 1
    for a_i in a:
        period = period * (a_i // _gcd(period, a_i))  # running lcm
    degree_inc = len(a)  # degree of quasi-polynomial + 1
    if modulus is not None and degree_inc > 1:
        # The interpolation nodes are spaced `period` apart, so the Lagrange denominators
        # are +-period^(degree_inc - 1) * j! * (degree_inc - 1 - j)!. They are all invertible
        # modulo `modulus` iff it is coprime to period * (degree_inc - 1)!. If not, count
        # exactly and reduce each answer instead of failing inside _mod_inv.
        lagrange_scale = period
        for k in range(2, degree_inc):
            lagrange_scale *= k
        if _gcd(lagrange_scale, modulus) != 1:
            exact = linear_diophantine_counts_for(original_a)
            return lambda n: exact(n) % modulus
    specials = M + period * degree_inc
    precalc = [0] * specials
    # Reduce the seed too, so that modulus == 1 gives 0 here like every other entry.
    precalc[0] = 1 if modulus is None else 1 % modulus
    for a_i in a:
        for i in range(a_i, specials):
            precalc[i] += precalc[i - a_i]
            # Both summands are already < modulus, so one subtraction suffices.
            if modulus is not None and precalc[i] >= modulus:
                precalc[i] -= modulus
    # The point is to precalculate as much as possible, so expand out to polynomial coefficients.
    polys = []
    denominators = []  # all 1 when a modulus is used
    for n in range(period):
        # Take the degree_inc indices >= M which are congruent to n (mod period).
        # Python's % is non-negative here, so `first` is the smallest such index.
        first = M + (n - M) % period
        points = [(x, precalc[x]) for x in range(first, specials, period)]
        if modulus is None:
            poly, denominator = _lagrange_exact(points, degree_inc)
        else:
            poly, denominator = _lagrange_modular(points, degree_inc, modulus), 1
        polys.append(poly)
        denominators.append(denominator)

    def count(n):
        if n < 0 or n % gcd_a:
            return 0
        n //= gcd_a
        if n < len(precalc):
            return precalc[n]
        poly = polys[n % period]
        if modulus is None:
            # Horner evaluation in integers; the division is exact.
            rv = 0
            for coeff in reversed(poly):
                rv = rv * n + coeff
            return rv // denominators[n % period]
        z = n % modulus
        rv = 0
        for coeff in reversed(poly):
            rv = (rv * z + coeff) % modulus
        return rv
    return count
if __name__ == "__main__":
    # Run the examples embedded in this module's docstrings as tests.
    from doctest import testmod
    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment