Last active
October 19, 2018 17:32
-
-
Save shyoshyo/31dcbcfe6beba486be99473d219e83f3 to your computer and use it in GitHub Desktop.
整数拆分问题:n 个相同的球,分成若干堆,有几种分法?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
整数拆分问题:n 个相同的球,分成若干堆,有几种分法? | |
如果要和上界比较,那么绘制下列图像: | |
* 分拆数及其上界随 n 变化的图像 | |
* 分拆数与其上界之比随 n 变化的图像 | |
""" | |
import argparse | |
from math import exp, sqrt, pi | |
import numpy as np | |
import matplotlib.pyplot as plt | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--compare', action='store_true', | |
help='compare our result to the upper bound and plot graphs') | |
args = parser.parse_args() | |
# f[i, j] 表示使用不超过 i 的正整数,每个数可以使用无限次,加起来的得到 j 的方案数 | |
# 也即生成函数 | |
# G(x) = (1 - x)^-1 * (1 - x^2)^-1 * (1 - x^3)^-1 * ... * (1 - x^i)^-1 * ... | |
# 前 i 项 | |
# G_i(x) = (1 - x)^-1 * (1 - x^2)^-1 * (1 - x^3)^-1 * ... * (1 - x^i)^-1 | |
# 中 x^j 的系数 | |
# | |
# 动态规划递推方程包含了一个卷积操作,其在本质上和使用生成函数的多项式乘法计算是等价的 | |
# | |
# G(x) 的 x^0, x^1, ..., x^n 项系数与 G_n(x) 对应项系数是相同的 | |
# | |
# 此处我们使用了内存压缩,只需要保存 f[i, :] 的值 | |
n = 20000 | |
# 初始化 f[0][i] = (i == 0) ? 1 : 0 | |
f = [1] + [0] * (n) | |
# 状态转移,可以由 O(n^3) 优化到 O(n^2) | |
# f[i][j] = \sum_k f[i - 1][k], j = k + p * i, 0 <= k <= j | |
# = f[i - 1][j] + f[i - 1][j - i] + f[i - 1][j - 2i] + f[i - 1][j - 3i] + ... | |
# = f[i - 1][j] + f[i][j - i] | |
for i in range(1, n + 1): | |
# 此时对于任意 j,f[j] 表示 f[i - 1][j] | |
for j in range(i, n + 1): | |
# 此时 f[j] 表示 f[i - 1][j],f[j - i] 表示 f[i][j - i] | |
f[j] += f[j - i] | |
# 此时 f[j] 表示 f[i][j],f[j - i] 表示 f[i][j - i] | |
# 此时对于任意 j,f[j] 表示 f[i][j] | |
# 此时对于任意 j,f[j] 表示 f[n][j] | |
# 因此对于任意 j <= n,f[j] 或 f[n][j] 就是 j 的整数分拆方案数 | |
# 运用已有结论验证正确性 | |
assert f[29] == 4565 | |
assert f[200] == 3972999029388 | |
assert f[300] == 9253082936723602 | |
# 输出一些中间的结果 | |
print('f[29] = ', f[29]) | |
print('f[100] = ', f[100]) | |
print('f[200] = ', f[200]) | |
print('f[300] = ', f[300]) | |
print('f[1000] = ', f[1000]) | |
print('f[10000] = ', f[10000]) | |
print('f[20000] = ', f[20000]) | |
# 如果不进一步和上界做比较,则计算到此为止 | |
if not args.compare: exit(0) | |
# 估计上界,使用课堂上给出的结论 | |
# 此外,我们还可使用 | |
# \sum_i 1 / i^2 = pi^2 / 6 | |
# 得到一个更紧的上界 | |
# exp(sqrt(2 * pi^2 * n / 3)) | |
estimate_0 = lambda x: exp(sqrt(20. / 3. * x)) | |
estimate_1 = lambda x: exp(sqrt(2. * (pi**2) / 3. * x)) | |
upper_0 = list(map(estimate_0, range(n + 1))) | |
upper_1 = list(map(estimate_1, range(n + 1))) | |
print('upper_0[:10] = ', upper_0[:10]) | |
print('upper_0[-10:] = ', upper_0[-10:]) | |
print('upper_1[:10] = ', upper_1[:10]) | |
print('upper_1[-10:] = ', upper_1[-10:]) | |
# 确保上界与计算结果吻合 | |
assert all([x <= y for x, y in zip(f, upper_0)]) | |
assert all([x <= y for x, y in zip(f, upper_1)]) | |
# 确保后一个上界比前一个紧 | |
assert all([x >= y for x, y in zip(upper_0, upper_1)]) | |
# 计算分拆数与其上界之比,并输出一部分的值供调试 | |
ratio_0 = [x/y for x, y in zip(f, upper_0)] | |
ratio_1 = [x/y for x, y in zip(f, upper_1)] | |
print('ratio_0[:10] = ', ratio_0[:10]) | |
print('ratio_0[-10:] = ', ratio_0[-10:]) | |
print('ratio_1[:10] = ', ratio_1[:10]) | |
print('ratio_1[-10:] = ', ratio_1[-10:]) | |
# 绘图 | |
# 分拆数及其上界随 n 变化的图像 | |
plt.rc('text', usetex=True) | |
plt.rc('font', family='serif') | |
plt.semilogy(range(n + 1), f) | |
plt.semilogy(range(n + 1), upper_0) | |
plt.semilogy(range(n + 1), upper_1) | |
plt.grid(True) | |
plt.legend([r'\# of partitions', | |
r'upper bound: $\mathrm{exp}\left(\sqrt{20 \cdot n / 3}\right)$', | |
r'upper bound: $\mathrm{exp}\left(\sqrt{2 \cdot \pi^2 \cdot n / 3}\right)$'], | |
shadow=True, fontsize='large', loc='lower right') | |
plt.show() | |
# 分拆数与其上界之比随 n 变化的图像 | |
plt.semilogy(range(n + 1), ratio_0) | |
plt.semilogy(range(n + 1), ratio_1) | |
plt.legend([r'$\textnormal{\# of partitions} \left/ { \mathrm{exp}\left(\sqrt{20 \cdot n / 3}\right) } \right.$', | |
r'$\textnormal{\# of partitions} \left/ { \mathrm{exp}\left(\sqrt{2 \cdot \pi^2 \cdot n / 3}\right) } \right.$'], | |
shadow=True, fontsize='large', loc='upper right') | |
plt.grid(True) | |
plt.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment