Last active
January 9, 2017 16:25
-
-
Save yukoba/46a133056292215e2bd34357253b7381 to your computer and use it in GitHub Desktop.
混合正規分布をEMアルゴリズムを使わずに直接勾配法でパラメータを求める
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 混合正規分布をEMアルゴリズムを使わずに直接勾配法でパラメータを求める | |
import autograd | |
import autograd.numpy as np | |
# Relative tolerance passed to np.allclose for the convergence test; also the
# seed value of the AdaGrad accumulator so the first update never divides by 0.
epsilon = 1e-8

# Twenty 1-D observations to be fitted by a two-component Gaussian mixture.
data = np.array([
    -0.39, 0.12, 0.94, 1.67, 1.76, 2.44, 3.72, 4.28, 4.92, 5.53,
    0.06, 0.48, 1.01, 1.68, 1.80, 3.25, 4.12, 4.60, 5.28, 6.22,
])

# Variance of the observations (numpy default ddof=0); used as the starting
# variance of both mixture components.
data_var = np.var(data)
def loss(p, x=None):
    """Negative mean log-likelihood of a two-component 1-D Gaussian mixture.

    Args:
        p: sequence/array ``(u1, u2, s1, s2, pi)`` -- the two means, the two
           *variances* (not standard deviations), and ``pi``, the weight of
           the second component (the first gets ``1 - pi``).
        x: observations to evaluate; defaults to the module-level ``data``.

    Returns:
        ``-mean(log((1 - pi) * N(x | u1, s1) + pi * N(x | u2, s2)))``.

    The mixture is evaluated in log space with a max-shifted log-sum-exp so
    that points far from both components do not underflow ``exp`` to 0
    (which would give ``log(0) = -inf`` and NaN gradients).  ``pi`` stays a
    linear factor, so ``pi == 0`` or ``pi == 1`` is handled without
    evaluating ``log(0)``.
    """
    if x is None:
        x = data

    def log_norm_dist(u, s, v):
        # log N(v | u, s) where s is the variance.
        return -(v - u) ** 2 / (2 * s) - 0.5 * np.log(2 * np.pi * s)

    u1, u2, s1, s2, pi = p
    l1 = log_norm_dist(u1, s1, x)
    l2 = log_norm_dist(u2, s2, x)
    # Shift by the per-point maximum so the largest exponent is exp(0) = 1.
    m = np.maximum(l1, l2)
    mixed = (1 - pi) * np.exp(l1 - m) + pi * np.exp(l2 - m)
    return -np.mean(m + np.log(mixed))
def init_param():
    """Draw a random starting point for the optimizer.

    Returns:
        ``(p, r)`` where ``p = [u1, u2, s1, s2, pi]`` has the two means set to
        two different observations drawn at random from ``data``, both
        variances set to ``data_var`` and the mixing weight set to 0.5; ``r``
        is a fresh AdaGrad squared-gradient accumulator seeded with
        ``epsilon`` so the first step never divides by zero.
    """
    # Sample two distinct positions without shuffling the shared global
    # ``data`` array in place on every restart.
    u1, u2 = np.random.choice(data, 2, replace=False)
    return np.array([u1, u2, data_var, data_var, 0.5]), np.full(5, epsilon)
# Gradient of the loss with respect to the parameter vector p.
loss_grad = autograd.grad(loss)

p, r = init_param()
for _ in range(10000):
    print(loss(p), p)
    d = loss_grad(p)
    r += d ** 2  # AdaGrad: accumulate squared gradients per parameter
    p_new = p - 0.3 * d / np.sqrt(r)
    converged = np.allclose(p, p_new, rtol=epsilon)
    p = p_new  # keep the last step even when it is the converging one
    if converged:
        break
    u1, u2, s1, s2, pi = p
    # Restart the search when the parameters leave the valid domain: the
    # mixing weight must lie in [0, 1] and both variances must be positive.
    # NaN fails every comparison, so a diverged state is also reset.
    if not (0 <= pi <= 1 and s1 > 0 and s2 > 0):
        p, r = init_param()

# Report the final parameters (the in-loop print shows the state before each step).
print(loss(p), p)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.