Last active
January 9, 2017 01:31
-
-
Save kts/65d71c761baa9714151c to your computer and use it in GitHub Desktop.
Python code to efficiently compute multinomial coefficients
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Compute multinomial coefficients iteratively.

More precisely: row N with K coefficients lists the coefficients of
(1 + x + ... + x**(K-1))**N, i.e. the rows of a generalized Pascal's
triangle (K=2 gives the ordinary binomial coefficients).
Probably needs some more testing...

Example: the distribution of the sum of 1000 six-sided dice (entry s of the
output counts the ways to roll a total of 1000 + s):

    $ time python multcoeff.py 6 1000 > out

Pascal's triangle:

    $ for i in `seq 1 10`; do python multcoeff.py 2 $i ;done
    1 1
    1 2 1
    1 3 3 1
    1 4 6 4 1
    1 5 10 10 5 1
    1 6 15 20 15 6 1
    1 7 21 35 35 21 7 1
    1 8 28 56 70 56 28 8 1
    1 9 36 84 126 126 84 36 9 1
    1 10 45 120 210 252 210 120 45 10 1
"""
def pascal_step(K, n, prev):
    """
    Compute one step in a generalized Pascal's triangle.

    Row n holds the coefficients of (1 + x + ... + x**(K-1))**n, so K=2
    gives the ordinary binomial coefficients. Every row is symmetric, so
    only its first half is stored and passed around.

    - K: number of coefficients per step (K=2 binomial, K=6 a die, ...).
    - n: index of the row held in ``prev`` (n >= 1).
    - prev: the first (len(row) + 1) // 2 entries of row n. For an
      odd-length row this includes the middle entry. Extra trailing
      entries (e.g. passing the whole row) are ignored.
    - return: the first half of row n+1, in the same format.

    Raises ValueError if ``prev`` is shorter than half of row n.
    """
    # Length of the full row n; each step lengthens a row by K-1.
    n0 = n * (K - 1) + 1
    n1 = n0 + K - 1
    nhalf = (n0 + 1) // 2
    if len(prev) < nhalf:
        raise ValueError(
            "prev must hold at least %d entries (half of row %d), got %d"
            % (nhalf, n, len(prev))
        )
    half = list(prev[:nhalf])
    # Rebuild the full row n by mirroring. Whether the middle entry is
    # shared depends on the parity of the row LENGTH, not of n: for odd K
    # every row has odd length.
    if n0 % 2 == 1:
        row = half + half[-2::-1]
    else:
        row = half + half[::-1]
    # Entry j of row n+1 is the sum of entries j-K+1 .. j of row n.
    # Left-pad with K-1 zeros so every window is a plain slice; a slice
    # running past the right end is truncated, which equals zero padding.
    padded = [0] * (K - 1) + row
    ncompute = (n1 + 1) // 2  # size of the half of row n+1 we need
    return [sum(padded[i:i + K]) for i in range(ncompute)]


def pascal_row(K, N):
    """
    Return the Nth row of the generalized Pascal's triangle with K
    coefficients (K=2 for the 'normal' binomial version).

    The result is the full list of coefficients of
    (1 + x + ... + x**(K-1))**N, of length N*(K-1) + 1. For example,
    entry s of pascal_row(6, N) counts the ways N six-sided dice can
    total N + s. Row 0 is [1].

    Raises ValueError if K < 1 or N < 0.
    """
    if K < 1:
        raise ValueError("K must be >= 1, got %r" % (K,))
    if N < 0:
        raise ValueError("N must be >= 0, got %r" % (N,))
    if N == 0:
        return [1]
    # Row 1 is K ones; keep only its first half, as pascal_step expects.
    half = [1] * ((K + 1) // 2)
    for n in range(1, N):
        half = pascal_step(K, n, half)
    # Mirror the half back into the full row, sharing the middle entry
    # exactly when the full row length is odd.
    if (N * (K - 1) + 1) % 2 == 1:
        return half + half[-2::-1]
    return half + half[::-1]
if __name__ == "__main__":
    # Command-line entry point: print row N (with K coefficients per step)
    # as space-separated integers.
    import sys

    if len(sys.argv) != 3:
        sys.exit("usage: %s K N   (K >= 1, N >= 1)" % sys.argv[0])
    K = int(sys.argv[1])
    n = int(sys.argv[2])
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if K < 1 or n < 1:
        sys.exit("K and N must both be >= 1")
    out = pascal_row(K, n)
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(" ".join(map(str, out)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment