Last active
August 25, 2018 03:37
-
-
Save gwy15/faf4cf0a5a27a1fc5729fc15be2f5788 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/pypy
import math
import random
import unittest
from fractions import Fraction
def simulate(n, m, k, times=1000000):
    '''
    Estimate the probability by Monte Carlo simulation.

    Each trial makes n launches; every launch targets one of the m enemy
    ships chosen uniformly at random (repeats allowed).  A trial is a
    success when exactly k distinct ships were targeted.

    Args:
        n: number of launches
        m: number of enemy ships
        k: number of ships sunk
    Optional:
        times: number of simulated trials
    Return:
        (p, dp)
        p: estimated probability
        dp: half-width of the 95% confidence interval (normal approximation)
    Raises:
        ValueError: if times is not positive.
    '''
    if times <= 0:
        raise ValueError('times must be a positive integer')
    count = 0
    for _ in range(times):
        # The set of distinct ships hit in this trial.
        hit = {random.randint(1, m) for _j in range(n)}
        if len(hit) == k:
            count += 1
    p = 1.0 * count / times
    dp = 1.96 * math.sqrt(p * (1 - p) / times)
    return (p, dp)
def perm(m, k):
    '''
    Return the number of k-permutations of m items, A(m, k) = m! / (m - k)!.

    Computed as the falling factorial m * (m - 1) * ... * (m - k + 1).
    When k > m >= 0 one of the factors is 0, so the result is 0; a
    non-positive k yields the empty product 1.
    '''
    result = 1
    factor = m
    for _ in range(k):
        result *= factor
        factor -= 1
    return result
def S(n, k):
    '''
    Compute the Stirling number of the second kind S(n, k).

    S(n, k) counts the ways to partition n labelled items into k non-empty
    unlabelled groups.  The table is built row by row from the recurrence
    S(i, j) = S(i-1, j-1) + j * S(i-1, j) with S(0, 0) = 1, in O(n * k) time.

    Returns 0 when k < 0, when k > n, and when k == 0 < n.
    '''
    if k < 0 or n < k:
        return 0
    # row[j] holds S(i, j) for the row i currently built; start at i = 0.
    row = [1] + [0] * k
    for i in range(1, n + 1):
        # Walk j downwards so row[j - 1] still holds the previous row's value.
        for j in range(min(i, k), 0, -1):
            row[j] = row[j - 1] + j * row[j]
        row[0] = 0  # S(i, 0) == 0 for every i >= 1
    return row[k]
def formula(n, m, k):
    '''
    Compute the probability in closed form.

    Of the m**n equally likely ways to aim n launches at m ships, exactly
    A(m, k) * S(n, k) hit exactly k distinct ships: split the n launches
    into k non-empty groups (S(n, k)), then assign the groups to an ordered
    choice of k distinct ships (A(m, k)).

    Args:
        n: number of launches
        m: number of enemy ships
        k: number of ships sunk
    Return:
        The probability as a float.
    Raises:
        ZeroDivisionError: if m == 0 and n > 0.
    '''
    # Exact rational arithmetic.  Converting the huge intermediate integers to
    # float before dividing overflows for large n, so convert only the final
    # ratio.
    return float(Fraction(perm(m, k) * S(n, k), m ** n))
class formulaTest(unittest.TestCase):
    '''Unit tests for S and perm, and for formula cross-checked against simulate.'''

    # Monte Carlo trials per formula case.  The allowed error scales like
    # 1/sqrt(times), so the check stays equally strict (about 4 sigma) while
    # running much faster than simulate's default of 1,000,000 trials.
    SIMULATION_TIMES = 200000

    def testStirling(self):
        cases = (
            ((1, 1), 1),
            ((2, 2), 1),
            ((2, 1), 1),
            ((2, 10), 0),
            ((10, 6), 22827),
            ((27, 25), 55575),
            ((23, 2), 4194303),
            ((23, 10), 9593401297313460),
        )
        for args, expected in cases:
            self.assertEqual(S(*args), expected, msg='S%r' % (args,))

    def testPerm(self):
        cases = (
            ((1, 0), 1),
            ((1, 1), 1),
            ((2, 1), 2),
            ((2, 3), 0),
            ((3, 2), 6),
            ((15, 12), 217945728000),
        )
        for args, expected in cases:
            self.assertEqual(perm(*args), expected, msg='perm%r' % (args,))

    def testFormula(self):
        cases = (
            # (launches n, enemy ships m, ships sunk k)
            (1, 1, 1),
            (2, 2, 1),
            (2, 2, 2),
            (24, 6, 6),
            (24, 6, 5),
            (23, 6, 5),
            (10, 5, 5),
            (12, 4, 3),
        )
        for case in cases:
            n, m, k = case
            p, dp = simulate(n, m, k, times=self.SIMULATION_TIMES)
            fp = formula(n, m, k)
            # Allow twice the 95% half-width so an unlucky run does not fail.
            # The tiny epsilon covers a simulated p of exactly 0 or 1
            # (dp == 0) where fp differs only by float rounding.
            self.assertLessEqual(
                abs(fp - p), 2 * dp + 1e-12,
                msg='case %r: simulated p = %.4f, formula = %.4f' % (case, p, fp))
def _run_tests():
    '''Run this module's unittest suite (used when executed as a script).'''
    unittest.main()


if __name__ == '__main__':
    _run_tests()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment