Skip to content

Instantly share code, notes, and snippets.

@qxj
Created May 22, 2017 15:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save qxj/2c144796b121848cd44aaa72964c51aa to your computer and use it in GitHub Desktop.
Save qxj/2c144796b121848cd44aaa72964c51aa to your computer and use it in GitHub Desktop.
EM算法demo - 混合高斯模型参数求解

EM算法demo

混合高斯模型参数求解

# coding=utf-8
import copy
import math
import numpy as np
import matplotlib.pyplot as plt
def gaussian(x, mu, sigma):
    """Return the normal probability density of ``x`` for mean ``mu`` and std ``sigma``.

    Evaluates exp(-(x - mu)^2 / (2 sigma^2)) / sqrt(2 pi sigma^2).
    """
    variance = sigma ** 2
    exponent = -1.0 * (x - mu) ** 2 / (2 * variance)
    normalizer = math.sqrt(2 * math.pi * variance)
    return math.exp(exponent) / normalizer
def Estep(X, z, k, N, mu, sigma):
    """E-step: fill ``z`` in place with the posterior responsibilities.

    z[i, j] becomes the probability that sample X[0, i] came from component j,
    assuming k components with equal prior weight and a common std ``sigma``.

    Args:
        X: 2-D array whose row 0 holds the N samples.
        z: (N, k) array that receives the responsibilities (modified in place).
        k: number of mixture components.
        N: number of samples to process.
        mu: sequence of the k component means.
        sigma: standard deviation shared by every component.
    """
    if N == 0 or k == 0:
        return
    x = np.asarray(X, dtype=float)[0, :N]
    centers = np.asarray(mu, dtype=float)[:k]
    # Work with log-densities: the 1/sqrt(2*pi*sigma^2) factor is the same for
    # every component (shared sigma) and cancels in the normalization.
    log_p = -((x[:, None] - centers[None, :]) ** 2) / (2.0 * sigma ** 2)
    # Subtracting the per-row maximum keeps the largest term at exp(0) = 1, so
    # samples far from every mean no longer underflow to 0/0.
    log_p -= log_p.max(axis=1, keepdims=True)
    p = np.exp(log_p)
    z[:N, :k] = p / p.sum(axis=1, keepdims=True)
def Mstep(X, z, k, N, mu):
    """M-step: update ``mu`` in place with the responsibility-weighted means.

    mu[j] = sum_i z[i, j] * X[0, i] / sum_i z[i, j]. A component whose total
    responsibility is zero keeps its previous mean instead of becoming nan
    (a nan mean would stop the outer convergence check from ever succeeding).

    Args:
        X: 2-D array whose row 0 holds the N samples.
        z: (N, k) responsibilities produced by the E-step.
        k: number of mixture components.
        N: number of samples to use.
        mu: mutable sequence of the k means (modified in place).
    """
    x = np.asarray(X, dtype=float)[0, :N]
    resp = np.asarray(z, dtype=float)[:N, :k]
    weighted_sums = np.dot(x, resp)  # shape (k,): sum_i z[i, j] * x[i]
    totals = resp.sum(axis=0)        # shape (k,): sum_i z[i, j]
    for j in range(k):
        if totals[j] > 0:
            mu[j] = weighted_sums[j] / totals[j]
# Mixture of two Gaussians that share the same standard deviation
k = 2
sigma = 6
mu1 = 40
mu2 = 20
# Stop iterating once the total change of the means drops below this tolerance
epsilon = 0.0001
# Safety cap so the loop terminates even if the tolerance is never reached
# (e.g. slow convergence, or a mean that became nan)
max_iter = 1000
# Sample size
N = 1000
# Build an example sample drawn from the two-component mixture
X = np.zeros((1, N))
# Initial guesses for the k means, drawn uniformly from [0, 1)
mu = np.random.random(k)
# Responsibility matrix: z[i, j] is the weight of sample i under component j
z = np.zeros((N, k))
for i in range(0, N):
    # A fair coin flip picks which component generates sample i
    if np.random.random(1) > 0.5:
        X[0, i] = np.random.normal() * sigma + mu1
    else:
        X[0, i] = np.random.normal() * sigma + mu2
# Run EM until the means stop moving (or the iteration cap is reached)
step = 0
while step < max_iter:
    muu = mu.copy()
    Estep(X, z, k, N, mu, sigma)
    Mstep(X, z, k, N, mu)
    # Same output on Python 2 and 3 (a bare `print a, b` is Python 2-only syntax)
    print("%d %s" % (step, mu))
    step += 1
    if np.sum(np.abs(mu - muu)) < epsilon:
        break
plt.hist(X[0, :], 50)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment