Created
July 13, 2018 03:08
-
-
Save zhuang-hao-ming/724ff836cfa184fd81ce0d03fbfd773b to your computer and use it in GitHub Desktop.
Markov Chain Monte Carlo方法的一个解释 贝叶斯推理
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy as sp | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from scipy.stats import norm | |
def sampler(data, samples=4, mu_init=.5, proposal_width=.5, plot=False, mu_prior_mu=0, mu_prior_sd=1.):
    '''
    Sample the posterior of ``mu`` with a Metropolis sampler.

    The model is a normal likelihood with unknown mean ``mu`` and a fixed
    standard deviation of 1, combined with a normal prior on ``mu``.

    MCMC collects a large sample from the posterior; a histogram of that
    sample then approximates the shape of the posterior distribution.

    The chain starts at ``mu_init``, which is treated as the first draw.
    In every iteration a new value is proposed by sampling from a normal
    distribution centred on the current value with standard deviation
    ``proposal_width`` (the Markov-chain step). ``proposal_width`` is a
    hyper-parameter: any choice converges to the same result, only the
    speed of convergence differs.

    The proposal is accepted when the ratio of the proposal's posterior to
    the current value's posterior exceeds a uniform random number; otherwise
    the current value is recorded again (the Monte Carlo step). Because only
    the ratio is needed, the evidence p(x) cancels and just likelihood and
    prior have to be evaluated. Proposals with a lower posterior than the
    current value are still accepted sometimes; that is what makes the trace
    follow the posterior distribution instead of only climbing to its mode.

    Parameters
    ----------
    data : list or array
        Observed data points.
    samples : int
        Number of MCMC iterations.
    mu_init : float
        Starting value of the chain.
    proposal_width : float
        Standard deviation of the normal proposal distribution.
    plot : bool
        If true, call ``plot_proposal`` in every iteration.
    mu_prior_mu : float
        Mean of the normal prior on ``mu``.
    mu_prior_sd : float
        Standard deviation of the normal prior on ``mu``.

    Returns
    -------
    list
        The trace: ``mu_init`` followed by one value per iteration
        (``samples + 1`` entries).

    Reference
    ---------
    http://twiecki.github.io/blog/2015/11/10/mcmc-sampling/
    '''
    def log_target(mu):
        # Log of likelihood * prior. Summing log-densities instead of
        # multiplying densities keeps the value finite for large data sets,
        # where the plain product underflows to 0 and the ratio becomes nan.
        log_likelihood = norm.logpdf(data, loc=mu, scale=1).sum()
        log_prior = norm.logpdf(mu, loc=mu_prior_mu, scale=mu_prior_sd)
        return log_likelihood + log_prior

    mu_current = mu_init
    # Cached so the current point is not re-evaluated in every iteration.
    log_p_current = log_target(mu_current)
    posterior = [mu_current]
    for i in range(samples):
        # Suggest a new position.
        mu_proposal = norm.rvs(loc=mu_current, scale=proposal_width)
        log_p_proposal = log_target(mu_proposal)

        # Acceptance probability min(1, p_proposal / p_current); clipping the
        # exponent at 0 avoids overflow in exp() without changing the outcome.
        p_accept = np.exp(np.minimum(0.0, log_p_proposal - log_p_current))
        accept = np.random.rand() < p_accept

        if plot:
            plot_proposal(mu_current, mu_proposal, mu_prior_mu, mu_prior_sd, data, accept, posterior, i)

        if accept:
            # Update position.
            mu_current = mu_proposal
            log_p_current = log_p_proposal

        posterior.append(mu_current)

    return posterior
def calc_posterior_analytical(data, x, mu_0, sigma_0):
    '''
    Evaluate the closed-form posterior density of ``mu`` at ``x``.

    With a normal prior and a normal likelihood (normally distributed data)
    the posterior is normal as well. A posterior that has the same form as
    the prior is called conjugate, and it can be derived analytically.

    Parameters
    ----------
    data : list or array
        Observed data points; the likelihood standard deviation is fixed at 1.
    x : float or array
        Value(s) of ``mu`` at which the posterior density is evaluated.
    mu_0 : float
        Mean of the normal prior.
    sigma_0 : float
        Standard deviation of the normal prior.

    Returns
    -------
    float or array
        Posterior density at ``x``.
    '''
    # Accept plain Python sequences as well as arrays.
    observations = np.asarray(data, dtype=float)
    sigma = 1.
    n = observations.size
    # Posterior precision is the prior precision plus n data precisions.
    precision_post = 1. / sigma_0**2 + n / sigma**2
    mu_post = (mu_0 / sigma_0**2 + observations.sum() / sigma**2) / precision_post
    var_post = 1. / precision_post
    return norm(mu_post, np.sqrt(var_post)).pdf(x)
# Function to display one Metropolis iteration
def plot_proposal(mu_current, mu_proposal, mu_prior_mu, mu_prior_sd, data, accepted, trace, i):
    '''
    Draw a four-panel figure for one sampler iteration and show it.

    Panels, left to right: prior, likelihood, analytical posterior, trace.
    The current value is drawn in blue; the proposal is drawn in green when
    it was accepted and in red when it was rejected.

    Parameters
    ----------
    mu_current : float
        Current position of the chain.
    mu_proposal : float
        Proposed next position.
    mu_prior_mu, mu_prior_sd : float
        Mean and standard deviation of the normal prior on ``mu``.
    data : list or array
        Observed data points.
    accepted : bool
        Whether the proposal was accepted.
    trace : list
        Values recorded so far; it is copied, the caller's list is not modified.
    i : int
        Zero-based iteration index (shown as ``i + 1`` in the figure title).
    '''
    # Work on a copy so the value appended below stays local to the plot.
    trace = list(trace)
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(ncols=4, figsize=(16, 4))
    fig.suptitle('Iteration %i' % (i + 1))
    x = np.linspace(-3, 3, 5000)
    color = 'g' if accepted else 'r'

    # Plot prior
    prior_dist = norm(mu_prior_mu, mu_prior_sd)
    prior_current = prior_dist.pdf(mu_current)
    prior_proposal = prior_dist.pdf(mu_proposal)
    ax1.plot(x, prior_dist.pdf(x))
    ax1.plot([mu_current] * 2, [0, prior_current], marker='o', color='b')
    ax1.plot([mu_proposal] * 2, [0, prior_proposal], marker='o', color=color)
    ax1.annotate("", xy=(mu_proposal, 0.2), xytext=(mu_current, 0.2),
                 arrowprops=dict(arrowstyle="->", lw=2.))
    ax1.set(ylabel='Probability Density', title='current: prior(mu=%.2f) = %.2f\nproposal: prior(mu=%.2f) = %.2f' % (mu_current, prior_current, mu_proposal, prior_proposal))

    # Likelihood
    likelihood_current = norm(mu_current, 1).pdf(data).prod()
    likelihood_proposal = norm(mu_proposal, 1).pdf(data).prod()
    y = norm(loc=mu_proposal, scale=1).pdf(x)
    # Normalised histogram of the data, drawn with matplotlib directly
    # because seaborn's distplot is deprecated.
    ax2.hist(data, bins='auto', density=True, alpha=0.4)
    ax2.plot(x, y, color=color)
    ax2.axvline(mu_current, color='b', linestyle='--', label='mu_current')
    ax2.axvline(mu_proposal, color=color, linestyle='--', label='mu_proposal')
    ax2.annotate("", xy=(mu_proposal, 0.2), xytext=(mu_current, 0.2),
                 arrowprops=dict(arrowstyle="->", lw=2.))
    # The likelihoods are multiplied by 1e14 before formatting, presumably so
    # the very small products remain visible with two decimals.
    ax2.set(title='likelihood(mu=%.2f) = %.2f\nlikelihood(mu=%.2f) = %.2f' % (mu_current, 1e14*likelihood_current, mu_proposal, 1e14*likelihood_proposal))

    # Posterior
    posterior_analytical = calc_posterior_analytical(data, x, mu_prior_mu, mu_prior_sd)
    ax3.plot(x, posterior_analytical)
    posterior_current = calc_posterior_analytical(data, mu_current, mu_prior_mu, mu_prior_sd)
    posterior_proposal = calc_posterior_analytical(data, mu_proposal, mu_prior_mu, mu_prior_sd)
    ax3.plot([mu_current] * 2, [0, posterior_current], marker='o', color='b')
    ax3.plot([mu_proposal] * 2, [0, posterior_proposal], marker='o', color=color)
    ax3.annotate("", xy=(mu_proposal, 0.2), xytext=(mu_current, 0.2),
                 arrowprops=dict(arrowstyle="->", lw=2.))
    ax3.set(title='posterior(mu=%.2f) = %.5f\nposterior(mu=%.2f) = %.5f' % (mu_current, posterior_current, mu_proposal, posterior_proposal))

    # Trace, including the value this iteration will record.
    trace.append(mu_proposal if accepted else mu_current)
    ax4.plot(trace)
    ax4.set(xlabel='iteration', ylabel='mu', title='trace')

    plt.tight_layout()
    plt.show()
if __name__ == '__main__':
    # Demo: draw 20 standard-normal observations and run the sampler with
    # plotting enabled, using the default number of iterations.
    data = np.random.standard_normal(20)
    sampler(data=data, plot=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment