Skip to content

Instantly share code, notes, and snippets.

@tnarihi
Last active August 29, 2015 13:56
Show Gist options
  • Save tnarihi/9339795 to your computer and use it in GitHub Desktop.
Save tnarihi/9339795 to your computer and use it in GitHub Desktop.
messy
# -*- coding: utf-8 -*-
import numpy as np
import pylab as pl
__author__ = 'Takuya Nairihira'
def compute_belonging(x, c):
    """Assign every row of ``x`` to the closest row of ``c``.

    Closeness is squared Euclidean distance, evaluated through the expansion
    ||x_i||^2 - 2 x_i . c_k + ||c_k||^2 so the whole N x K distance table is
    built with one matrix product instead of an explicit double loop.

    Args:
        x: 2-D array of points, one point per row (N rows).
        c: 2-D array of centers, one center per row (K rows), with the same
            number of columns as ``x``.

    Returns:
        Integer array of length N holding, for each point, the row index of
        its nearest center (ties resolve to the lowest index, as argmin does).
    """
    point_sq_norms = np.sum(x ** 2, axis=1)[:, None]
    center_sq_norms = np.sum(c ** 2, axis=1)[None, :]
    cross_terms = np.dot(x, c.T)
    sq_distances = point_sq_norms - 2. * cross_terms + center_sq_norms
    return np.argmin(sq_distances, axis=1)
def _scatter_points(points, color):
    """Draw data points as small filled circles without edge lines."""
    pl.scatter(points[:, 0], points[:, 1], 20, color, 'o', edgecolors='none')


def _scatter_centers(centers, colors):
    """Draw cluster centers as large crosses, one color per center."""
    pl.scatter(centers[:, 0], centers[:, 1], 200, colors, 'x')


def _plot_clusters(x, c, belong, colors):
    """Draw each cluster's points in its own color, then overlay the centers."""
    for k, color in enumerate(colors):
        _scatter_points(x[belong == k], color)
    _scatter_centers(c, colors)


def _save_figure(filename):
    """Redraw the current figure and write it to ``filename``."""
    pl.draw()
    pl.savefig(filename)


def main(n_sample=100, n_cluster=5, n_iter=3):
    """Run k-means on random 2-D points and save a PNG snapshot of every step.

    Files written to the working directory:
        00.png          raw samples
        01.png          samples plus the random initial centers
        {2t+2:02d}.png  assignment after the E-step of iteration t
        {2t+3:02d}.png  centers after the M-step of iteration t
    All figures are displayed at the end with ``pl.show()``.

    Args:
        n_sample: number of points drawn uniformly from the unit square.
        n_cluster: number of clusters; the 5-color palette is cycled when
            n_cluster exceeds it.
        n_iter: number of E/M iterations (defaults to the original 3).
    """
    base_colors = ['r', 'g', 'b', 'c', 'y']
    # Cycle the palette so n_cluster > 5 no longer raises IndexError and the
    # center scatter always receives exactly one color per center.
    colors = [base_colors[k % len(base_colors)] for k in range(n_cluster)]
    x = np.random.rand(n_sample, 2)
    c = np.random.rand(n_cluster, 2)

    pl.figure('sample')
    _scatter_points(x, 'gray')
    _save_figure('00.png')

    pl.figure('init')
    _scatter_points(x, 'gray')
    _scatter_centers(c, colors)
    _save_figure('01.png')

    # range (not xrange) so the script runs on Python 3 as well as Python 2.
    for t in range(n_iter):
        # E-step: assign each point to its nearest center.
        pl.figure('E-step@{}'.format(t))
        belong = compute_belonging(x, c)
        _plot_clusters(x, c, belong, colors)
        _save_figure('{:02d}.png'.format(2 * t + 2))

        # M-step: move each center to the mean of the points assigned to it.
        pl.figure('M-step@{}'.format(t))
        for k in range(n_cluster):
            members = x[belong == k]
            # The mean of an empty selection is NaN, and np.argmin returns the
            # index of a NaN, so a NaN center would capture every point in the
            # next E-step. Leave an empty cluster's center where it was.
            if len(members):
                c[k] = members.mean(axis=0)
        _plot_clusters(x, c, belong, colors)
        _save_figure('{:02d}.png'.format(2 * t + 3))
    pl.show()


if __name__ == '__main__':
    # Run the demo with default settings only when executed as a script.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment