Skip to content

Instantly share code, notes, and snippets.

@linw1995
Last active May 22, 2017 11:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save linw1995/b078c9cebd7f01aa7bfb46de8c6a8a04 to your computer and use it in GitHub Desktop.
Save linw1995/b078c9cebd7f01aa7bfb46de8c6a8a04 to your computer and use it in GitHub Desktop.
Metropolis-Hastings Algorithm
# coding:utf-8
# one-dimensional Metropolis-Hastings Algorithm
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
def q(x):
    """Return the target density at x: a normal pdf with mean 0 and std 2.

    Computed directly with NumPy because ``matplotlib.mlab.normpdf`` was
    removed from matplotlib; the formula is the same one it implemented.
    Accepts a scalar or a NumPy array and evaluates element-wise.
    """
    mu = 0.0
    sigma = 2.0
    standardized = (x - mu) / sigma
    return np.exp(-0.5 * standardized ** 2) / (np.sqrt(2 * np.pi) * sigma)
# --- Sampling -------------------------------------------------------------
N = 100000  # total number of Metropolis-Hastings iterations
s = 10  # thinning interval: one state is recorded every s iterations
x = 0  # initial state of the chain
p = q(x)  # target density at the current state
samples = np.zeros(N // s)
for i in range(N):
    # Propose a move by adding standard-normal noise to the current state.
    xn = x + np.random.normal()
    pn = q(xn)  # target density at the proposal
    if pn >= p:
        # Density ratio >= 1: the proposal is always accepted.
        p = pn
        x = xn
    else:
        # Density ratio < 1: accept with probability pn / p.
        u = np.random.rand()  # u ~ Uniform(0, 1)
        if u < pn/p:
            p = pn
            x = xn
    # Record the current state (whether or not the proposal was accepted).
    if i % s == 0:
        samples[i // s] = x

# --- Plotting -------------------------------------------------------------
# `density=True` replaces the `normed=True` keyword, which was removed from
# matplotlib; it normalizes the histogram so it is comparable with q.
plt.hist(samples, bins='auto', density=True)
dx = 0.01
# NOTE: x is reused here as the plotting grid; the chain state is no longer
# needed after the loop.
x = np.arange(np.min(samples), np.max(samples), dx)
y = q(x)
plt.plot(x, y)
plt.show()
# coding:utf-8
# two-dimensional Metropolis-Hastings Algorithm
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
def _bivariate_normal(x, y, sigmax=1.0, sigmay=1.0, mux=0.0, muy=0.0,
                      sigmaxy=0.0):
    """Return the bivariate normal pdf at (x, y).

    Stand-in for ``matplotlib.mlab.bivariate_normal``, which was removed
    from matplotlib; the parameter order and formula follow that function.
    ``sigmax``/``sigmay`` are standard deviations, ``mux``/``muy`` are means
    and ``sigmaxy`` is the covariance (not the correlation coefficient).
    """
    x_centered = x - mux
    y_centered = y - muy
    rho = sigmaxy / (sigmax * sigmay)
    z = (x_centered ** 2 / sigmax ** 2
         + y_centered ** 2 / sigmay ** 2
         - 2 * rho * x_centered * y_centered / (sigmax * sigmay))
    denom = 2 * np.pi * sigmax * sigmay * np.sqrt(1 - rho ** 2)
    return np.exp(-z / (2 * (1 - rho ** 2))) / denom


def q(x, y):
    """Return the (unnormalized) two-component Gaussian mixture at (x, y).

    Accepts scalars or equally-shaped NumPy arrays and evaluates
    element-wise.
    """
    g1 = _bivariate_normal(x, y, 1.0, 1.0, -1, -1, -0.8)
    g2 = _bivariate_normal(x, y, 1.5, 0.8, 1, 2, 0.6)
    # NOTE(review): operator precedence divides only the g2 term by
    # (0.6 + 28.4). If a normalized mixture (0.6*g1 + 28.4*g2) / 29 was
    # intended, parentheses are missing -- confirm before changing, since it
    # alters the relative weight of the two modes. Kept as in the original.
    return 0.6 * g1 + 28.4 * g2 / (0.6 + 28.4)
# --- Sampling -------------------------------------------------------------
N = 100000  # total number of Metropolis-Hastings iterations
s = 10  # thinning interval: one state is recorded every s iterations
r = np.zeros(2)  # current chain state, starting at the origin
p = q(r[0], r[1])  # target density at the current state
samples = np.zeros(shape=(N // s, 2))

for i in range(N):
    # Propose a move by adding independent standard-normal noise per axis.
    rn = r + np.random.normal(size=2)
    pn = q(rn[0], rn[1])
    # Accept outright when the density does not decrease; otherwise accept
    # with probability pn / p. The `or` short-circuits, so the uniform
    # variate is only drawn when pn < p.
    if pn >= p or np.random.rand() < pn / p:
        r, p = rn, pn
    # Record the current state (whether or not the proposal was accepted).
    if i % s == 0:
        samples[i // s, :] = r

# --- Plotting -------------------------------------------------------------
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5, s=1)

dx = 0.01
# Both axes share one range: the min and max over all recorded coordinates.
lo = np.min(samples)
hi = np.max(samples)
x = np.arange(lo, hi, dx)
y = np.arange(lo, hi, dx)
X, Y = np.meshgrid(x, y)
Z = q(X, Y)

# Overlay contour lines of the target density on the scatter plot.
CS = plt.contour(X, Y, Z, 10)
plt.clabel(CS, inline=1, fontsize=10)
plt.show()
@linw1995
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment