Skip to content

Instantly share code, notes, and snippets.

@qxj
Last active May 20, 2017 17:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save qxj/66d42234f58d519b2511ed2892cf7411 to your computer and use it in GitHub Desktop.
Save qxj/66d42234f58d519b2511ed2892cf7411 to your computer and use it in GitHub Desktop.
Random Sampling
'''
http://www.nehalemlabs.net/prototype/blog/2014/02/24/an-introduction-to-the-metropolis-method-with-python/
'''
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
def _bivariate_normal(x, y, sigmax=1.0, sigmay=1.0, mux=0.0, muy=0.0,
                      sigmaxy=0.0):
    """Return the bivariate normal density evaluated element-wise at (x, y).

    Drop-in replacement for ``matplotlib.mlab.bivariate_normal``, which was
    deprecated in matplotlib 2.2 and removed in 3.1. It keeps the same
    argument order and formula. Note that ``sigmaxy`` is the *covariance* of
    x and y, not the correlation coefficient:
    ``rho = sigmaxy / (sigmax * sigmay)``.
    Accepts scalars or NumPy arrays of matching shape.
    """
    xmu = x - mux
    ymu = y - muy
    rho = sigmaxy / (sigmax * sigmay)
    z = (xmu ** 2 / sigmax ** 2 + ymu ** 2 / sigmay ** 2
         - 2 * rho * xmu * ymu / (sigmax * sigmay))
    denom = 2 * np.pi * sigmax * sigmay * np.sqrt(1 - rho ** 2)
    return np.exp(-z / (2 * (1 - rho ** 2))) / denom


def q(x, y):
    """Evaluate the 2-D target density (a mix of two Gaussians) at (x, y).

    The components are:

    * g1: mean (-1, -1), std (1.0, 1.0), covariance -0.8 (rho = -0.8);
    * g2: mean (1, 2), std (1.5, 0.8), covariance 0.6 (rho = 0.5).

    Works element-wise, so ``x``/``y`` may be scalars or meshgrid arrays.
    Metropolis sampling only needs the density up to a constant factor. As
    written, the result integrates to ``0.6 + 28.4/29`` (about 1.58), not 1.
    """
    g1 = _bivariate_normal(x, y, 1.0, 1.0, -1, -1, -0.8)
    g2 = _bivariate_normal(x, y, 1.5, 0.8, 1, 2, 0.6)
    # NOTE(review): by operator precedence only the g2 term is divided by
    # (0.6 + 28.4). The effective weights are therefore 0.6 : 28.4/29 rather
    # than a normalized mixture. If a normalized mixture was intended, this
    # should be (0.6*g1 + 28.4*g2) / (0.6 + 28.4), which would make g1 almost
    # vanish (about 2% weight). Kept as-is to preserve the sampled target;
    # confirm intent.
    return 0.6 * g1 + 28.4 * g2 / (0.6 + 28.4)
# --- Metropolis-Hastings random-walk sampler for the 2-D target q ---
# print(x) with a single argument and range() behave the same on Python 2
# and Python 3, so this section runs on both.
N = 100000            # number of sampler iterations
s = 10                # thinning interval: keep every s-th state
r = np.zeros(2)       # current state, started at the origin
p = q(r[0], r[1])     # target density at the current state
print(p)
samples = []
for i in range(N):
    # Gaussian random-walk proposal. Because it is symmetric, the Hastings
    # correction cancels and the acceptance probability is min(1, pn / p).
    rn = r + np.random.normal(size=2)
    pn = q(rn[0], rn[1])
    if pn >= p:
        p = pn
        r = rn
    else:
        u = np.random.rand()
        if u < pn / p:
            p = pn
            r = rn
    if i % s == 0:
        # r is rebound to a fresh array on acceptance and never mutated in
        # place, so storing the reference does not alias later states.
        samples.append(r)
samples = np.array(samples)
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5, s=1)

# --- Overlay contours of the target density on the samples ---
dx = 0.01
# One square grid spanning the overall min..max of both coordinates.
x = np.arange(np.min(samples), np.max(samples), dx)
y = np.arange(np.min(samples), np.max(samples), dx)
X, Y = np.meshgrid(x, y)
Z = q(X, Y)
CS = plt.contour(X, Y, Z, 10)
plt.clabel(CS, inline=1, fontsize=10)
plt.show()
import numpy as np
import matplotlib.pyplot as plt
def mcmc(p=np.array([.1, .2, .3, .4]), n=10,
         converge_threshold=10, is_mh=True):
    """Run ``n`` discrete Markov chains that target ``p``, stopping heuristically.

    The proposal (transition) matrix ``Q`` is chosen arbitrarily: every row
    is ``p`` itself, so proposals are drawn independently of the current
    state. Each iteration picks one chain at random, proposes a state ``y``
    from ``Q[x]`` (``x`` being that chain's current state) and accepts it
    with probability ``alpha``:

    * ``is_mh=True``: Metropolis-Hastings,
      ``alpha = min(1, p[y] * Q[y][x] / (p[x] * Q[x][y]))``;
    * ``is_mh=False``: basic MCMC acceptance, ``alpha = p[y] * Q[y][x]``.

    Stopping rule: an accepted proposal that leaves the chain where it was
    increments a counter. An accepted proposal that moves the chain resets
    the counter, and rejected proposals leave it alone. When more than
    ``converge_threshold`` such no-op acceptances happen in a row, the
    final states and the number of proposals are printed and the states
    are returned.

    Args:
        p: 1-D sequence of state probabilities. It is presumably meant to
            sum to 1; this function does not check. It is only read, never
            modified.
        n: number of parallel chains.
        converge_threshold: consecutive accepted-but-unchanged proposals
            allowed before stopping.
        is_mh: use the Metropolis-Hastings acceptance rule when True.

    Returns:
        A list of length ``n`` holding each chain's final state index.
    """
    # Build Q in float64, not float32: multinomial casts pvals to float64
    # and raises ValueError if sum(pvals[:-1]) > 1. Float32 rounding of p
    # can trip that check for distributions that are valid in float64.
    Q = np.tile(np.asarray(p, dtype=np.float64), (len(p), 1))
    x0 = [np.random.randint(len(p)) for _ in range(n)]
    converge_num = 0
    sample_count = 0
    while True:
        idx = np.random.randint(n)
        x = x0[idx]
        # A single multinomial trial is a one-hot vector; its argmax is the
        # proposed state.
        y = np.argmax(np.random.multinomial(1, Q[x]))
        sample_count += 1
        if is_mh:
            alpha = min(1, (p[y] * Q[y][x]) / (p[x] * Q[x][y]))
        else:
            alpha = p[y] * Q[y][x]
        # rand() draws from the same random_sample stream as the ranf()
        # alias used before, and matches the sampler elsewhere in this file.
        if np.random.rand() < alpha:
            if y == x:
                # Accepted, but the chain did not move.
                if converge_num >= converge_threshold:
                    # Enough no-op acceptances in a row: treat as converged.
                    print('收敛状态:{}'.format(x0))
                    print('采样计数:{}'.format(sample_count))
                    return x0
                converge_num += 1
            else:
                # The chain moved: record the new state and restart counting.
                x0[idx] = y
                converge_num = 0
# Run 100 chains on the default 4-state target and plot where they end up.
# The bin edges -0.5, 0.5, ..., 3.5 centre one bar on each integer state
# 0..3, whichever states actually occur.
samples = mcmc(n=100)
plt.hist(samples, bins=np.arange(4 + 1) - 0.5)
# Without show(), the histogram never appears when this runs as a script.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment