Skip to content

Instantly share code, notes, and snippets.

@phelrine
Created November 20, 2011 12:27
Show Gist options
  • Save phelrine/1380219 to your computer and use it in GitHub Desktop.
Save phelrine/1380219 to your computer and use it in GitHub Desktop.
混合ガウス分布のパラメータ推定 (Parameter estimation for a Gaussian mixture model via EM)
import numpy as np
import numpy.random as nprand
import matplotlib.pyplot as plt
def dnorm(x, m, s):
    """Return the normal probability density at ``x``.

    ``m`` is the mean and ``s`` is the *variance* (not the standard
    deviation). Plain NumPy arithmetic is used, so arrays are accepted
    and combined element-wise.
    """
    z = (x - m) ** 2 / (2 * s)
    norm_const = np.sqrt(2 * np.pi * s)
    return np.exp(-z) / norm_const
def EM(data, init, iter):
    """Fit a one-dimensional Gaussian mixture to ``data`` with EM.

    Parameters
    ----------
    data : sequence of float
        Observed samples.
    init : sequence of [mean, variance, weight] rows
        Initial parameters, one row per mixture component.
    iter : int
        Number of EM iterations to run. (The name shadows the builtin but
        is kept so existing keyword callers keep working.)

    Returns
    -------
    numpy.ndarray
        Float array with one [mean, variance, weight] row per component,
        after ``iter`` E/M updates. With ``iter == 0`` it is a float copy
        of ``init``.
    """
    x = np.asarray(data, dtype=float)
    # Force floats: an all-integer ``init`` would otherwise produce an int
    # array and silently truncate every updated mean/variance/weight.
    params = np.array(init, dtype=float)
    for _ in range(iter):
        # E-step: weighted density of every sample under every component
        # (rows = components, columns = samples), then normalise each column
        # so the responsibilities of one sample sum to 1.
        # NOTE(review): a sample far from all components can underflow to a
        # zero column sum and produce NaNs, as in the original loop version.
        dens = dnorm(x[np.newaxis, :], params[:, [0]], params[:, [1]]) * params[:, [2]]
        w = dens / dens.sum(axis=0, keepdims=True)

        # M-step: c[i] is the effective number of samples owned by component i.
        c = w.sum(axis=1)
        mean = w.dot(x) / c
        # The variance is taken around the freshly updated mean, exactly as
        # the original per-component loop did.
        var = (w * (x[np.newaxis, :] - mean[:, np.newaxis]) ** 2).sum(axis=1) / c
        params[:, 0] = mean
        params[:, 1] = var
        params[:, 2] = c / len(x)
    return params
def main():
    """Demo: sample from a known 3-component mixture and plot EM fits.

    Draws roughly 500 samples from a fixed ground-truth mixture. It then
    plots the true density ("base") and the densities estimated by ``EM``
    after 1, 2, 4, 8 and 16 iterations, all started from the same guess.
    """
    # Ground-truth components as [mean, variance, weight].
    params = [[-7, 2, 0.3], [-2, 1, 0.5], [4, 3, 0.2]]
    xs = np.linspace(-15, 15, 1000)

    def mixture_pdf(components):
        # Evaluate the mixture density on the whole grid in one pass.
        return sum(dnorm(xs, m, s) * c for m, s, c in components)

    plt.plot(xs, mixture_pdf(params), label="base")

    # nprand.normal takes a standard deviation and an *integer* size, hence
    # np.sqrt(variance) and int(round(weight * 500)). np.concatenate also
    # replaces reduce(), which is not a builtin on Python 3.
    data = np.concatenate([
        nprand.normal(m, np.sqrt(s), int(round(c * 500)))
        for m, s, c in params
    ])

    init = [[-6, 2, 0.3], [0, 2, 0.4], [6, 2, 0.3]]
    for i in [2 ** k for k in range(5)]:
        est = EM(data, init, i)
        plt.plot(xs, mixture_pdf(est), label="iter-%d" % i)
    plt.legend(loc="best")
    plt.show()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment