Skip to content

Instantly share code, notes, and snippets.

@enakai00
Created January 24, 2015 01:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save enakai00/560a78ac393e73a8b4b3 to your computer and use it in GitHub Desktop.
Save enakai00/560a78ac393e73a8b4b3 to your computer and use it in GitHub Desktop.
Multiclass least-squares classifier from PRML, Chapter 4
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas import Series, DataFrame
from numpy.random import randint, randn, rand, multivariate_normal
K = 3 # Number of classes; prep_data() only handles 2 or 3
N = 50 # Number of training points sampled per class
HR = 200 # Heatmap resolution: grid points per axis over [-6, 6]
def classify(x1, x2, W_t, markers=(0.5, 1.0), delta=0.01):
    """Classify the point (x1, x2) with a linear discriminant.

    The discriminant values are ``W_t . (1, x1, x2)``; the class whose
    discriminant is largest wins.

    Args:
        x1: First coordinate of the point.
        x2: Second coordinate of the point.
        W_t: Weight matrix with one row ``(w0, w1, w2)`` per class.
        markers: Discriminant levels at which a contour is drawn.
        delta: Half-width of the band around each marker level.

    Returns:
        0 if the winning discriminant value lies within ``delta`` of any
        level in ``markers`` (so that contour lines show up in a heatmap),
        otherwise the 1-based index of the winning class.
    """
    ys = np.dot(W_t, np.array([1, x1, x2]))
    winner = np.argmax(ys)
    yval = ys[winner]
    # Points sitting on a contour level are painted with the reserved
    # value 0 instead of their class number.
    if any(np.abs(yval - level) < delta for level in markers):
        return 0
    return winner + 1
def prep_data(num_classes=None, num_samples=None):
    """Generate a random labelled training set of 2-D Gaussian clusters.

    Each class is drawn from a bivariate normal with a shared covariance
    and its own mean: (-2, 2), (0, 0) and (2, -2) for classes 1, 2, 3.

    Args:
        num_classes: Number of classes to generate (2 or 3). Defaults to
            the module-level ``K``.
        num_samples: Number of points per class. Defaults to the
            module-level ``N``.

    Returns:
        A DataFrame with columns ``x1``, ``x2`` (coordinates), ``x0``
        (constant 1, the bias input) and ``cls`` (1-based class label),
        with the classes stacked in label order.

    Raises:
        ValueError: If ``num_classes`` is neither 2 nor 3.
    """
    if num_classes is None:
        num_classes = K
    if num_samples is None:
        num_samples = N
    if num_classes not in (2, 3):
        # Previously this fell through and failed on an unbound local.
        raise ValueError(
            "number of classes must be 2 or 3, got %r" % (num_classes,))

    means = [[-2, 2], [0, 0], [2, -2]]
    cov = [[1.0, 0.8], [0.8, 1.0]]

    frames = []
    for label, mean in enumerate(means[:num_classes], start=1):
        frame = DataFrame(multivariate_normal(mean, cov, num_samples),
                          columns=['x1', 'x2'])
        frame['x0'] = 1
        frame['cls'] = label
        frames.append(frame)
    return pd.concat(frames, ignore_index=True)
def solve(df, num_classes=None):
    """Fit the least-squares linear discriminant for the training set.

    Builds the design matrix X from columns ``x0``, ``x1``, ``x2`` and a
    1-of-K target matrix T from column ``cls``, then solves for the
    weights W minimising ``||X W - T||^2`` via the pseudo-inverse of X.

    Args:
        df: Training data with columns ``x0``, ``x1``, ``x2`` and a
            1-based integer class label in ``cls``.
        num_classes: Number of classes (columns of T). Defaults to the
            module-level ``K``.

    Returns:
        The transposed weight matrix as an ndarray of shape
        ``(num_classes, 3)``; row ``k - 1`` holds ``(w0, w1, w2)`` for
        class ``k``.

    Raises:
        ValueError: If a label in ``cls`` is outside ``1..num_classes``.
    """
    if num_classes is None:
        num_classes = K

    X = df[['x0', 'x1', 'x2']].values.astype(float)
    labels = df['cls'].values.astype(int)
    if len(labels) and (labels.min() < 1 or labels.max() > num_classes):
        raise ValueError(
            "class labels must lie in 1..%d" % num_classes)

    # 1-of-K coding: row i has a single 1 in the column of its class.
    T = np.zeros((len(labels), num_classes))
    T[np.arange(len(labels)), labels - 1] = 1

    # pinv(X) equals inv(X^T X) X^T when X has full column rank, and
    # still yields the minimum-norm solution when it does not.
    W = np.dot(np.linalg.pinv(X), T)
    return W.T
if __name__ == "__main__":
    # Generate the training set and fit the linear discriminant.
    df = prep_data()
    W_t = solve(df)

    # Evaluate the classifier on an HR x HR grid over [-6, 6]^2.
    # Rows of `field` index the vertical coordinate, columns the
    # horizontal one, which is the layout imshow expects.
    X = Y = np.linspace(-6, 6, HR)
    field = np.zeros((len(Y), len(X)))
    for col, xval in enumerate(X):
        for row, yval in enumerate(Y):
            field[row, col] = classify(xval, yval, W_t)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Row 0 corresponds to the vertical coordinate -6, hence the
    # bottom/top entries of `extent` are given as (6, -6).
    ax.imshow(field, extent=(-6, 6, 6, -6), vmin=0, vmax=3,
              alpha=0.2)

    # Overlay the training points, one colour/marker per class.
    styles = ((1, 'blue', 'x'), (2, 'orange', 'o'), (3, 'red', '+'))
    for label, color, marker in styles:
        points = df[df['cls'] == label]
        ax.scatter(points.x1, points.x2, color=color, marker=marker)

    ax.set_xlim(-6, 6)
    ax.set_ylim(-6, 6)
    # plt.show() blocks until the window is closed; fig.show() returns
    # immediately when run as a plain script, so the plot would vanish.
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment