@colmmacc
Last active August 7, 2023 11:05
Calculate the blast radius of a shuffle shard
import sys
# choose() computes the number of combinations ("N choose m"). Normally this is
# equal to:
#
# factorial(N) / (factorial(m) * factorial(N - m))
#
# but computing the full factorials is very slow and, with a recursive
# factorial(), requires a deep stack (Python does not optimize tail recursion).
#
# The below algorithm works as follows. First, re-organize as:
#
# (factorial(N) / factorial(m)) / factorial(N - m)
#
# N! / m! is the same thing as N * (N - 1) * (N - 2) * ... * (m + 1), so we
# construct our loop to compute that:
#
# c = 1
# for i in range(m + 1, n + 1):
#     c *= i
#
# So now we've computed factorial(N) / factorial(m), but we still need to
# divide by factorial(N - m). We already have a for loop of (N - m) iterations,
# so we can reuse it.
#
# We want to divide by all of the values from 1 to (N - m), which are the same
# values as (i - m). Recall that i starts at m + 1 and ends at N, so (i - m)
# goes 1, 2, 3 ... (N - m).
#
# Since the result is a single product of factors, the order in which we
# multiply and divide does not matter, so we place the division directly
# in-line within the loop. Lastly, since we're now dividing, we will end up
# with fractions at intermediate stages, so we use floating point. The end
# result will include rounding errors, but that's ok for our purposes.
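def choose(n, m):
    # There is no way to choose a negative number of elements, or more
    # elements than exist. Without this guard the empty loop would return 1.
    if m < 0 or m > n:
        return 0.0
    c = 1.0
    for i in range(m + 1, n + 1):
        c *= float(i) / (i - m)
    return c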
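

def overlap(n, m, o):
    # Hypergeometric probability that two shards of m elements, each drawn
    # from n elements, share exactly o elements.
    return (choose(m, o) * choose(n - m, m - o)) / choose(n, m)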


def usage():
    print("shard.py n m")
    print()
    print(" n: The total number of elements")
    print(" m: The number of elements per shard")
    print()


if __name__ == "__main__":
    # Print usage and stop on bad input instead of falling through.
    if len(sys.argv) != 3:
        usage()
        sys.exit(1)
    try:
        n = int(sys.argv[1])
        m = int(sys.argv[2])
    except ValueError:
        usage()
        sys.exit(1)
    if m < 0 or m > n:
        usage()
        sys.exit(1)
    print('With a total of %d elements, a randomly chosen shuffleshard of %d elements ...' % (n, m))
    print('')
    for i in range(0, m + 1):
        print('overlaps by %-4d elements with %23.20f %% of other shuffleshards' % (i, overlap(n, m, i) * 100))
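
As a cross-check on the floating-point arithmetic above, here is a minimal sketch that computes the same overlap probabilities exactly. It assumes Python 3.8 or newer, where `math.comb` is available; the name `exact_overlap` is hypothetical and not part of the original script.

```python
from fractions import Fraction
from math import comb  # assumes Python 3.8 or newer


def exact_overlap(n, m, o):
    """Exact probability that two m-element shards drawn from n elements
    share exactly o elements (hypothetical helper, not in the gist)."""
    # comb() returns 0 when asked to choose more items than exist, so
    # impossible overlaps come out as probability 0.
    return Fraction(comb(m, o) * comb(n - m, m - o), comb(n, m))


if __name__ == "__main__":
    n, m = 8, 2
    probs = [exact_overlap(n, m, o) for o in range(m + 1)]
    for o, p in enumerate(probs):
        print("overlap %d: %s (%.6f)" % (o, p, float(p)))
    # The probabilities over all possible overlaps must sum to exactly 1.
    assert sum(probs) == 1
```

Because the arithmetic is done on integers and `Fraction`, there is no rounding error, at the cost of working with very large integers when n is large.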
@shayelkin

Another way to approximate n choose k is to take a logarithm, ln(nCk) = ln(n!) - ln(k!) - ln((n-k)!), and apply Stirling's formula: ln(n!) = n*ln(n) - n + O(ln(n)). A quick benchmark on my MacBook Pro (i9-9880H) gives:

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.001    0.000    0.001    0.000 bench-choose.py:13(stirling_choose)
     1000    0.841    0.001    0.841    0.001 bench-choose.py:6(colmmac_choose)
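
The log-space idea can be sketched as follows. The source of the benchmarked `stirling_choose` is not shown here, so this is only an assumption about its general shape, not the commenter's code. It keeps just the leading Stirling terms, and adds a hypothetical `lgamma_choose` variant that uses the same identity with `math.lgamma` for comparison.

```python
import math


def stirling_ln_factorial(n):
    """Leading terms of Stirling's formula: ln(n!) ~ n*ln(n) - n.
    The dropped O(ln(n)) term makes this a rough approximation."""
    if n == 0:
        return 0.0  # ln(0!) = 0; avoids evaluating log(0)
    return n * math.log(n) - n


def stirling_choose(n, k):
    """Approximate n choose k via ln(nCk) = ln(n!) - ln(k!) - ln((n-k)!)."""
    ln_c = (stirling_ln_factorial(n)
            - stirling_ln_factorial(k)
            - stirling_ln_factorial(n - k))
    return math.exp(ln_c)


def lgamma_choose(n, k):
    """Same identity, using math.lgamma(x + 1) == ln(x!) for better accuracy."""
    ln_c = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return math.exp(ln_c)


if __name__ == "__main__":
    print(stirling_choose(100, 50), lgamma_choose(100, 50))
```

Both functions run in constant time, which is consistent with the speedup in the profile above. Dropping the O(ln(n)) term is costly, though: for n = 100, k = 50 the leading-term version evaluates to 2^100, roughly an order of magnitude above the true value, while the `lgamma` variant stays close to it.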
