# Source: GitHub Gist seikichi/634990 (created October 19, 2010).
# -*- coding: utf-8 -*-
# PRML chapter 9
# Gaussian Mixture Model
import numpy as np
import scipy as sp
from scipy.linalg import det, inv
def multivariate_normal_pdf(x, u, sigma):
    """Evaluate the multivariate normal density N(x | u, sigma).

    - `x`: Point at which to evaluate the density (length-D sequence).
    - `u`: Mean vector (length-D sequence).
    - `sigma`: D x D covariance matrix (assumed symmetric positive definite).

    Returns the density value as a float.
    """
    x = np.asarray(x, dtype=float)
    u = np.asarray(u, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    D = len(x)
    y = x - u
    # Mahalanobis term y^T sigma^{-1} y; solving the linear system is more
    # accurate than forming inv(sigma) explicitly.
    maha = np.dot(y, np.linalg.solve(sigma, y))
    norm = ((2 * np.pi) ** (D / 2.0)) * (det(sigma) ** 0.5)
    return np.exp(-maha / 2.0) / norm
def _logsumexp_rows(a):
    """Return log(sum(exp(a), axis=1)) without under/overflow."""
    m = np.max(a, axis=1)
    return m + np.log(np.sum(np.exp(a - m[:, np.newaxis]), axis=1))


def _log_weighted_densities(X, pi, mu, sigma):
    """Return the (N, K) matrix log(pi_k) + log N(x_n | mu_k, sigma_k)."""
    N, D = X.shape
    K = len(pi)
    out = np.empty((N, K))
    const = D * np.log(2 * np.pi)
    for k in range(K):
        # With sigma_k = L L^T, ||L^{-1}(x - mu_k)||^2 is the Mahalanobis
        # distance and 2*sum(log(diag(L))) is log|sigma_k|; working in the
        # log domain avoids exp() underflowing to 0 for far-away points.
        chol = np.linalg.cholesky(sigma[k])
        z = np.linalg.solve(chol, (X - mu[k]).T)
        maha = np.sum(z ** 2, axis=0)
        logdet = 2.0 * np.sum(np.log(np.diag(chol)))
        out[:, k] = np.log(pi[k]) - 0.5 * (const + logdet + maha)
    return out


def gmm(X, K, iter=1000, tol=1e-6, verbose=True, reg_covar=1e-6):
    """Fit a Gaussian Mixture Model with the EM algorithm (PRML section 9.2).

    - `X`: Input data (2D array, [[x11, x12, ..., x1D], ..., [xN1, ... xND]]);
      a 1D array is treated as N one-dimensional samples.
    - `K`: Number of clusters (must be >= 1).
    - `iter`: Maximum number of EM iterations to run.
    - `tol`: Stop once the log-likelihood changes by less than this amount.
    - `verbose`: Print the (natural-log) likelihood after every iteration.
    - `reg_covar`: Value added to the diagonal of every covariance so it stays
      positive definite when a component collapses onto few points.

    Returns a dict with keys `pi` (K,), `mu` (K, D), `sigma` (K, D, D),
    `gamma` (N, K) responsibilities under the returned parameters, and
    `classification` (N,) float array holding, for each sample, the index of
    the component with the largest responsibility.

    Raises ValueError if K < 1.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[:, np.newaxis]
    N, D = X.shape
    if K < 1:
        raise ValueError("K must be at least 1, got %r" % (K,))

    pi = np.ones(K) / K
    # Start the means at randomly chosen samples so the initialisation follows
    # the location/scale of the data (with replacement only if N < K).
    mu = X[np.random.choice(N, K, replace=(N < K))].copy()
    sigma = np.array([np.eye(D) for _ in range(K)])

    log_w = _log_weighted_densities(X, pi, mu, sigma)
    log_norm = _logsumexp_rows(log_w)
    L = np.inf
    for _ in range(iter):
        # E-step: responsibilities gamma_nk under the current parameters.
        gamma = np.exp(log_w - log_norm[:, np.newaxis])
        # M-step
        Nk = np.sum(gamma, axis=0)
        mu = np.dot(gamma.T, X) / Nk[:, np.newaxis]
        xmu = X[:, np.newaxis, :] - mu  # (N, K, D)
        sigma = (np.einsum('nk,nki,nkj->kij', gamma, xmu, xmu)
                 / Nk[:, np.newaxis, np.newaxis])
        sigma += reg_covar * np.eye(D)
        pi = Nk / N
        # Log-likelihood under the updated parameters; the weighted densities
        # are reused by the next E-step instead of being recomputed.
        log_w = _log_weighted_densities(X, pi, mu, sigma)
        log_norm = _logsumexp_rows(log_w)
        Lnew = np.sum(log_norm)
        if abs(L - Lnew) < tol:
            break
        L = Lnew
        if verbose:
            print("L=%s" % L)

    # Responsibilities and hard labels consistent with the returned parameters
    # (also defined when iter == 0).
    gamma = np.exp(log_w - log_norm[:, np.newaxis])
    cls = np.argmax(gamma, axis=1).astype(float)
    return dict(pi=pi, mu=mu, sigma=sigma, gamma=gamma, classification=cls)
if __name__ == '__main__':
    # Demo: sample two 2-D Gaussian clusters (50 and 70 points) and fit a
    # two-component mixture to them.
    data = np.vstack([
        np.random.multivariate_normal([-3.5, 5.0], np.eye(2) * 4, 50),
        np.random.multivariate_normal([-8.2, 10.0], np.eye(2) * 2, 70),
    ])
    K = 2
    d = gmm(data, K)
    print("π=%s\nμ=%s\nΣ=%s" % (d['pi'], d['mu'], d['sigma']))
    # Optional visualisation (requires matplotlib):
    # import matplotlib.pyplot as plt
    # cls = d['classification']
    # plt.scatter(data[cls == 0, 0], data[cls == 0, 1], color='r')
    # plt.scatter(data[cls == 1, 0], data[cls == 1, 1], color='g')
    # plt.show()
# (Original gist page footer: sign in to GitHub to comment on this gist.)