Skip to content

Instantly share code, notes, and snippets.

@stober
Created June 15, 2011 17:10
Show Gist options
  • Save stober/1027563 to your computer and use it in GitHub Desktop.
Save stober/1027563 to your computer and use it in GitHub Desktop.
Gaussian Mixture Model
#!/usr/bin/python
"""
Author: Jeremy M. Stober
Program: GMM.PY
Date: Thursday, September 4 2008
Description: Fit a Gaussian Mixture Model with EM.
"""
import os, sys, getopt, pdb, string
from numpy import *
import pylab
import matplotlib
# Load faithful.txt: one whitespace-separated numeric record per line; the
# first two columns are used as the x and y coordinates. The `with` block
# guarantees the file is closed, and blank lines are skipped instead of
# producing empty records that would crash the indexing below.
with open('faithful.txt') as datafile:
    faithful = [array([float(elem) for elem in line.split()])
                for line in datafile if line.strip()]
xvals = array([elem[0] for elem in faithful])
yvals = array([elem[1] for elem in faithful])
# Standard normalization: zero mean and unit standard deviation per axis.
xmean = mean(xvals)
xstd = std(xvals)
ymean = mean(yvals)
ystd = std(yvals)
xnorm = (xvals - xmean) / xstd
ynorm = (yvals - ymean) / ystd
# (n, 2) matrix of normalized points. column_stack gives the same result on
# Python 2 and 3, whereas array(zip(...)) yields a 0-d object array on
# Python 3 because zip there is lazy.
nfaithful = column_stack((xnorm, ynorm))
# Does the data look reasonable?
# pylab.plot(xnorm,ynorm, '+')
# pylab.show()
# Initial guesses for the two components: means, covariances and mixing
# weights (the weights sum to 1).
mus = [array([-1.5, 1.0]), array([1.0, -1.0])]
sigmas = [eye(2), eye(2)]
mixings = [0.5, 0.5]
def normal(x, mu, sigma):
    """Evaluate the multivariate normal density N(x | mu, sigma) at one point.

    x and mu are length-D vectors and sigma is the D x D covariance matrix
    (it is inverted and its determinant taken). Returns the density value.
    """
    offset = x - mu
    precision = linalg.inv(sigma)
    # Squared Mahalanobis distance (x - mu)^T sigma^-1 (x - mu).
    mahalanobis = dot(dot(offset, precision), offset)
    dim = len(sigma)
    norm_const = (2.0 * pi) ** (dim / 2.0) * linalg.det(sigma) ** 0.5
    return exp(-0.5 * mahalanobis) / norm_const
# Fit the mixture with a fixed number of EM iterations. mus, sigmas and
# mixings are updated in place (list items are replaced), so the fitted
# parameters are available afterwards under the same names.
k = len(mus)  # number of mixture components (NOT the data dimension)
n = len(nfaithful)  # number of data points
niterations = 100
for iteration in range(niterations):
    # Single parenthesized argument: prints the same on Python 2 and 3.
    print("iteration %d: mus=%s mixings=%s" % (iteration, mus, mixings))

    # E step: responses[i, j] is the posterior responsibility of component i
    # for point j, i.e. mixings[i] * N(x_j | mus[i], sigmas[i]) normalized so
    # that each column sums to 1.
    responses = zeros((k, n))
    for i in range(k):
        responses[i, :] = [mixings[i] * normal(point, mus[i], sigmas[i])
                           for point in nfaithful]
    # `sum` is numpy.sum here (from the numpy star import), hence `axis=`.
    responses = responses / sum(responses, axis=0)

    # M step: N[i] is the effective number of points owned by component i.
    N = sum(responses, axis=1)
    for i in range(k):
        mus[i] = dot(responses[i, :], nfaithful) / N[i]
        centered = nfaithful - mus[i]
        # Responsibility-weighted scatter sum_j r_ij (x_j - mu)(x_j - mu)^T;
        # the dimension comes from the data instead of a hard-coded 2.
        sigmas[i] = dot(responses[i, :] * centered.T, centered) / N[i]
        mixings[i] = N[i] / sum(N)
def shownormal(mus, sigmas, nstd=1.0):
    """Plot the normalized data with one outline ellipse per component.

    mus    -- list of component means (length-2 points).
    sigmas -- list of matching 2 x 2 covariance matrices.
    nstd   -- each ellipse's half-axes are nstd standard deviations along the
              covariance's principal directions (default: 1-sigma contour).

    Scatters the module-level xnorm / ynorm arrays, then calls pylab.draw()
    and pylab.show().
    """
    # Explicit import: `import matplotlib` alone does not load the patches
    # submodule.
    from matplotlib.patches import Ellipse

    # Plot the normalized data points.
    fig = pylab.figure(num=1, figsize=(4, 4))
    axes = fig.add_subplot(111)
    axes.plot(xnorm, ynorm, '+')
    # The first two components keep the original red/blue; extra components
    # cycle through further colours instead of all being blue.
    colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown']
    # Plot ellipses along the principal components of each normal.
    for i, mu in enumerate(mus):
        sigma = sigmas[i]
        # For a symmetric covariance, the columns of u are the principal
        # directions and s holds the variances along them, largest first.
        u, s, _ = linalg.svd(sigma)
        # Orientation of the major axis (first column of u). arctan2 is
        # insensitive to the arbitrary sign of the singular vector: a flip
        # only rotates the ellipse by 180 degrees, which draws the same shape.
        angle = degrees(arctan2(u[1, 0], u[0, 0]))
        # Ellipse takes full diameters, hence the factor of 2.
        width = 2.0 * nstd * sqrt(s[0])
        height = 2.0 * nstd * sqrt(s[1])
        ellipse = Ellipse(mu, width, height, angle=angle, fill=False,
                          ec=colors[i % len(colors)])
        axes.add_patch(ellipse)
    pylab.draw()
    pylab.show()
# Draw the fitted components over the normalized data.
shownormal(mus=mus, sigmas=sigmas)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment