# Random Sampling
# Reference https://github.com/tback/MLBook_source/
# Random Sampling
# Reference https://github.com/tback/MLBook_source/
'''
http://www.nehalemlabs.net/prototype/blog/2014/02/24/an-introduction-to-the-metropolis-method-with-python/
'''
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.mlab as mlab | |
def _bivariate_normal(x, y, sigmax=1.0, sigmay=1.0,
                      mux=0.0, muy=0.0, sigmaxy=0.0):
    """Evaluate a bivariate Gaussian density at the point(s) (x, y).

    Drop-in replacement for ``matplotlib.mlab.bivariate_normal``, which was
    removed from Matplotlib in 3.1.  ``sigmax`` and ``sigmay`` are the
    standard deviations, ``mux`` and ``muy`` the means, and ``sigmaxy`` the
    covariance of the two coordinates.
    """
    dx = x - mux
    dy = y - muy
    rho = sigmaxy / (sigmax * sigmay)  # correlation coefficient
    one_minus_rho2 = 1.0 - rho ** 2
    z = ((dx / sigmax) ** 2 + (dy / sigmay) ** 2
         - 2.0 * rho * dx * dy / (sigmax * sigmay))
    norm = 2.0 * np.pi * sigmax * sigmay * np.sqrt(one_minus_rho2)
    return np.exp(-z / (2.0 * one_minus_rho2)) / norm


def q(x, y):
    """Return the (unnormalised) two-Gaussian target density at (x, y).

    ``x`` and ``y`` may be scalars or array-likes of matching shape; the
    result has the broadcast shape of the inputs.

    The target is a weighted sum of two correlated Gaussians, one centred
    at (-1, -1) and one at (1, 2).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    g1 = _bivariate_normal(x, y, 1.0, 1.0, -1, -1, -0.8)
    g2 = _bivariate_normal(x, y, 1.5, 0.8, 1, 2, 0.6)
    # NOTE: by operator precedence only the g2 term is divided by
    # (0.6 + 28.4).  The expression is kept exactly as originally written so
    # the sampled target does not change; Metropolis-Hastings only needs the
    # density up to a constant factor, so the lack of normalisation is
    # harmless for the sampler.
    return 0.6 * g1 + 28.4 * g2 / (0.6 + 28.4)
# --- Metropolis-Hastings sampling of the 2-D target density q(x, y) ---
N = 100000  # number of Markov-chain steps
s = 10  # thinning interval: every s-th state is recorded
r = np.zeros(2)  # current state; the chain starts at the origin
p = q(r[0], r[1])  # target density at the current state
print(p)
samples = []
for i in range(N):
    # Propose a move by adding a standard-normal step to each coordinate.
    rn = r + np.random.normal(size=2)
    pn = q(rn[0], rn[1])
    if pn >= p:
        # A proposal at least as probable as the current state is always
        # accepted.
        p = pn
        r = rn
    else:
        # A less probable proposal is accepted with probability pn / p.
        u = np.random.rand()
        if u < pn / p:
            p = pn
            r = rn
    if i % s == 0:
        # r is rebound (never modified in place) above, so storing the
        # reference is safe.
        samples.append(r)
samples = np.array(samples)
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5, s=1)

# --- Plot contours of the target density over the sampled region ---
dx = 0.01
# Both axes span the overall min/max of all sample coordinates, giving a
# square grid.
x = np.arange(np.min(samples), np.max(samples), dx)
y = np.arange(np.min(samples), np.max(samples), dx)
X, Y = np.meshgrid(x, y)
Z = q(X, Y)
CS = plt.contour(X, Y, Z, 10)
plt.clabel(CS, inline=1, fontsize=10)
plt.show()
import numpy as np | |
import matplotlib.pyplot as plt | |
def mcmc(p=np.array([.1, .2, .3, .4]), n=10,
         converge_threshold=10, is_mh=True):
    """Sample ``n`` discrete states from the distribution ``p`` by MCMC.

    Args:
        p: Target probabilities, one entry per state (array-like).
        n: Number of state slots that are updated; the returned list has
            this length.
        converge_threshold: Number of accepted "state did not change"
            proposals that must accumulate, without any state change in
            between, before the next such proposal ends the run.
        is_mh: If true, use the Metropolis-Hastings acceptance ratio
            ``min(1, p[y]Q[y][x] / (p[x]Q[x][y]))``; otherwise use the
            plain acceptance rate ``p[y] * Q[y][x]``.

    Returns:
        The list of ``n`` state indices at the moment the stopping rule
        fires.  The final states and the number of proposals drawn are
        also printed.
    """
    # float64 on purpose: np.random.multinomial casts pvals to float64 and
    # rejects them when sum(pvals[:-1]) exceeds 1, which float32 rounding
    # can trigger for valid inputs such as [0.6, 0.4, 0.0].
    p = np.asarray(p, dtype=np.float64)
    num_states = len(p)
    # Any transition matrix would do; here every row is simply the target
    # distribution itself.
    Q = np.tile(p, (num_states, 1))
    x0 = [np.random.randint(num_states) for _ in range(n)]
    converge_num = 0
    sample_count = 0
    while True:
        idx = np.random.randint(n)
        x = x0[idx]
        # One multinomial trial yields a one-hot vector; argmax is the
        # index of the proposed state.
        y = np.argmax(np.random.multinomial(1, Q[x]))
        sample_count += 1
        if is_mh:
            numer = p[y] * Q[y][x]
            denom = p[x] * Q[x][y]
            # A zero denominator means the current state has zero target
            # probability; always accept so the chain leaves it (this is
            # the value the inf/nan ratio produced before, without the
            # floating-point warning).
            alpha = 1.0 if denom == 0 else min(1.0, numer / denom)
        else:
            alpha = p[y] * Q[y][x]
        if np.random.random_sample() >= alpha:
            # Proposal rejected: nothing changes, counter is left as is.
            continue
        if y != x:
            # State changed: store it and restart the stability count.
            x0[idx] = y
            converge_num = 0
        elif converge_num >= converge_threshold:
            # State unchanged often enough in a row: treat as converged.
            print('收敛状态:{}'.format(x0))
            print('采样计数:{}'.format(sample_count))
            return x0
        else:
            # State unchanged but not yet stable: bump the stability count.
            converge_num += 1
# Run the discrete sampler with 100 state slots and histogram the final
# states using 4 bins.
samples = mcmc(n=100)
plt.hist(samples, 4)
# Without an explicit show() the figure never appears when this is run as a
# plain script.
plt.show()