Skip to content

Instantly share code, notes, and snippets.

@nyk510
Created October 22, 2016 23:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nyk510/b78c5cae4318aa3245d383b41c4d26c1 to your computer and use it in GitHub Desktop.
Save nyk510/b78c5cae4318aa3245d383b41c4d26c1 to your computer and use it in GitHub Desktop.
# coding: utf-8
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
sns.set(context='notebook', style='white')
sns.set_palette(sns.color_palette(palette='Set1')[:])
np.random.seed(71)
def t_func(x):
'''正解ラベルを作る関数
x: numpy array.
return t: target array. numpy.array like.
'''
t = np.sin(x * np.pi)
# t = np.where(x > 0, 1, -1)
return t
def plot_target_function(x, ax=None, color='default'):
'''target関数(ノイズなし)をプロットします
'''
if ax is None:
ax = plt.subplot(111)
if color is 'default':
color = sns.color_palette()[0]
ax.plot(x, t_func(x), '--', label='true function', color=color, alpha=.5)
return ax
def phi_poly(x):
dims = 3
return [x**i for i in range(0, dims + 1)]
def phi_gauss(x):
bases = np.linspace(-1, 1, 5)
return [np.exp(- (x - b)**2. * 10.) for b in bases]
def qw(alpha, phi, t, beta):
'''wの事後分布を計算します。
変分事後分布はガウス分布なので決定すべきパラメータは平均と分散です
w ~ N(w| m, S)
return ガウス分布のパラメータ m, S
'''
S = beta * phi.T.dot(phi) + alpha * np.eye(phi.shape[1])
S = np.linalg.inv(S)
m = beta * S.dot(phi.T).dot(t)
return m, S
def qbeta(mn, Sn, t, Phi, N, c0, d0):
'''betaの変分事後分布を決めるcn,dnを計算します
変分事後分布はガンマ分布なので決定すべきパラメータは2つです
beta ~ Gamma(beta | a, b)
return ガンマ分布のパラメータ a,b
'''
cn = c0 + .5 * N
dn = d0 + .5 * (np.linalg.norm(t - Phi.dot(mn)) **
2. + np.trace(Phi.T.dot(Phi).dot(Sn)))
return cn, dn
def qalpha(w2, a0, b0, m):
'''alphaの変分事後分布を計算します。
変分事後分布はガンマ分布ですから決定すべきパラメータは2つです
alpha ~ Gamma(alpha | a, b)
return a, b
'''
a = a0 + m / 2.
b = b0 + 1 / 2. * w2
return a, b
def fit(phi_func, x, update_beta=False):
xx = np.linspace(-2, 2., 100)
if phi_func == 'gauss':
phi_func = phi_gauss
elif phi_func == 'poly':
phi_func == phi_poly
else:
if type(phi_func) == 'function':
pass
else:
raise Exception('invalid phi_func')
Phi = np.array([phi_func(xi) for xi in x])
Phi_xx = np.array([phi_func(xi) for xi in xx])
# 変分事後分布の初期値
N, m = Phi.shape
mn = np.zeros(shape=(Phi.shape[1],))
Sn = np.eye(len(mn))
beta = 10.
alpha = .1
a0, b0 = 1, 1
c0, d0 = 1, 1
pred_color = sns.color_palette()[1]
freq = 5
n_iter = 3 * freq
n_fig = int(n_iter / freq)
fig = plt.figure(figsize=(4 * n_fig, 6))
data_iter = []
data_iter.append([alpha, beta])
for i in range(n_iter):
print('alpha:{alpha:.3g} beta:{beta:.3g}'.format(**locals()))
mn, Sn = qw(alpha, Phi, t, beta)
w2 = np.linalg.norm(mn) ** 2. + np.trace(Sn)
a, b = qalpha(w2, a0, b0, m)
c, d = qbeta(mn, Sn, t, Phi, N, c0, d0)
alpha = a / b
if update_beta:
# betaが更新される
beta = c / d
data_iter.append([alpha, beta])
if i % freq == 0:
k = int(i / freq) + 1
ax_i = fig.add_subplot(1, n_fig, k)
plot_target_function(xx, ax=ax_i)
ax_i.plot(x, t, 'o', label='data', alpha=.8)
m_line = Phi_xx.dot(mn)
sigma = (1./beta + np.diag(Phi_xx.dot(Sn).dot(Phi_xx.T))) ** .5
ax_i.plot(xx, m_line, '-', label='predict-line',
alpha=0.8, color=pred_color)
ax_i.fill_between(xx, m_line + sigma,m_line - sigma,label='+1sigma',alpha = .1,color = pred_color)
ax_i.set_title(
'n_iter:{i} alpha:{alpha:.3g} beta:{beta:.3g}'.format(**locals()))
ax_i.set_ylim(-2, 2)
ax_i.legend(loc=4)
fig.tight_layout()
return fig, data_iter
if __name__ == '__main__':
d_size = 100
x = np.random.uniform(-1, 1, d_size)
noise = np.random.normal(scale=.1, size=d_size)
t = t_func(x) + noise
plt.figure(figsize=(4, 6))
xx = np.linspace(-1, 1., 100)
plot_target_function(xx)
plt.plot(x, t, 'o', label='data')
plt.ylim(-3, 3)
plt.legend(loc=4)
plt.tight_layout()
plt.savefig('data.png', dpi=200)
fig, data_iter = fit(phi_func='gauss', x=x, update_beta=False)
fig.savefig('iter_notupdate_beta.png', dpi=200)
fig, data_iter = fit(phi_func='gauss', x=x, update_beta=True)
fig.savefig('iter_update_beta.png', dpi=200)
fig = plt.figure(figsize=(6, 6))
ax1 = fig.add_subplot(111)
pd.DataFrame(data_iter, columns=['alpha', 'beta']).plot(ax=ax1)
fig.savefig('alpha_beta_trans.png', dpi=200)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment