Skip to content

Instantly share code, notes, and snippets.

@pbloem
Last active January 16, 2016 15:12
Show Gist options
  • Save pbloem/11092d940646cd4a0bba to your computer and use it in GitHub Desktop.
Save pbloem/11092d940646cd4a0bba to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as p
from matplotlib import colors
import numpy as n
import pylab
import scipy.stats as stats
from __builtin__ import file
from matplotlib.pyplot import margins
import json
from sklearn import svm, linear_model
def clean(ax, xlabel='age', ylabel='height'):
    """Strip the plot decoration from an axes and label it.

    Hides all four spines, points the y-axis ticks outward, switches off
    the tick marks on the top and right edges, and sets the axis labels.

    Parameters
    ----------
    ax : matplotlib axes
        The axes to restyle; it is modified in place.
    xlabel, ylabel : str, optional
        Axis labels. The defaults are the labels this script has always
        used for its age/height plots.
    """
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

    # Real booleans instead of the legacy 'off' string: matplotlib
    # deprecated the 'on'/'off' strings (2.2), and wherever they are no
    # longer translated the non-empty string 'off' is truthy, which would
    # leave these ticks switched on.
    ax.get_xaxis().set_tick_params(which='both', top=False)
    ax.get_yaxis().set_tick_params(which='both', direction='out', right=False)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
# Shared plot settings
h = 500  # resolution (number of grid steps) for plotting the models
cmap = colors.ListedColormap(['blue', 'red'])

# Load the data; the three columns are read below as age, height, gender
data = n.genfromtxt('people.csv', delimiter=',')

# Tweak the data: for the rows whose third column equals 1, add 1 to the
# first column and subtract 5 from the second
is_one = data[:, 2] == 1
data[is_one, 0] += 1
data[is_one, 1] -= 5

age, height, gender = data[:, 0], data[:, 1], data[:, 2]

# Plotting range: the data range padded by 5 on either side
a_min = age.min() - 5
a_max = age.max() + 5
h_min = height.min() - 5
h_max = height.max() + 5

# Scatter plot of the data set, colored by the gender column
fig = p.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(age, height, c=gender, cmap=cmap, linewidth=0)
clean(ax1)
p.savefig('dataset.pdf')
# Linear regression
fig = p.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(age, height, c='gray', linewidth=0)

# Fit height against age; only the first two of the five values returned
# by linregress (slope and intercept) are needed for the line
slope, intercept = stats.linregress(age, height)[:2]
ax1.plot(age, slope * age + intercept, 'k-')

clean(ax1)
p.savefig('regression.pdf')
# Linear regression, design matrix
fig = p.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(age, height, c='gray', linewidth=0)

# Fit a linear model on the expanded features [age, age^2]
plus = n.column_stack((age, age * age))
model = linear_model.LinearRegression()
model.fit(plus, height)

# Evaluate the fitted curve on an evenly spaced grid over the age range
step = (a_max - a_min) / h
a_grid = n.arange(a_min, a_max, step)
curve = model.predict(n.column_stack((a_grid, a_grid * a_grid)))
ax1.plot(a_grid, curve, 'k-')

clean(ax1)
p.savefig('reg_design.pdf')
# fit SVM in 2D
h = 500  # step nums in the mesh
a_min, a_max = age.min() - 5, age.max() + 5
h_min, h_max = height.min() - 5, height.max() + 5

# Train a linear-kernel SVM on the first two columns against the third
features, labels = data[:, :2], data[:, 2]
clf = svm.SVC(kernel='linear', C=10.0)
clf.fit(features, labels)

# Classify every point of a regular grid spanning the plotting range
a_axis = n.arange(a_min, a_max, (a_max - a_min) / h)
h_axis = n.arange(h_min, h_max, (h_max - h_min) / h)
xx, yy = n.meshgrid(a_axis, h_axis)
grid_points = n.column_stack((xx.ravel(), yy.ravel()))
zz = clf.predict(grid_points).reshape(xx.shape)

# Predicted regions as a translucent background, data points on top
fig = p.figure()
ax1 = fig.add_subplot(111)
ax1.contourf(xx, yy, zz, cmap=cmap, alpha=0.4)
ax1.scatter(age, height, c=gender, cmap='jet', linewidth=0)
clean(ax1)
p.savefig('linear.pdf')
# Fit SVM in 5D, design matrix
# Expanded features: age, height, age^2, age*height, height^2
plus = n.column_stack((age, height, age * age, age * height, height * height))
n.savetxt("design.csv", plus, delimiter=",", fmt="%5.1f")

# Linear-kernel SVM trained on the expanded features
clf = svm.SVC(kernel='linear', C=10.0)
clf.fit(plus, data[:, 2])

# Build the (age, height) grid and expand it the same way as the
# training data before classifying it
a_axis = n.arange(a_min, a_max, (a_max - a_min) / h)
h_axis = n.arange(h_min, h_max, (h_max - h_min) / h)
xx, yy = n.meshgrid(a_axis, h_axis)
xr, yr = xx.ravel(), yy.ravel()
xyplus = n.column_stack((xr, yr, xr * xr, xr * yr, yr * yr))
zz = clf.predict(xyplus).reshape(xx.shape)

# Predicted regions as a translucent background, data points on top
fig = p.figure()
ax1 = fig.add_subplot(111)
ax1.contourf(xx, yy, zz, cmap=cmap, alpha=0.4)
ax1.scatter(age, height, c=gender, cmap='jet', linewidth=0)
clean(ax1)
p.savefig('design.pdf')
# fit RBF SVM in 2D
fig = p.figure()
ax1 = fig.add_subplot(111)

# RBF-kernel SVM trained on the first two columns against the third
clf = svm.SVC(kernel='rbf', C=10.0)
clf.fit(data[:, :2], data[:, 2])

h = 500  # step nums in the mesh
a_min, a_max = age.min() - 5, age.max() + 5
h_min, h_max = height.min() - 5, height.max() + 5

# Classify every point of a regular grid spanning the plotting range
xx, yy = n.meshgrid(n.arange(a_min, a_max, (a_max - a_min) / h),
                    n.arange(h_min, h_max, (h_max - h_min) / h))
zz = clf.predict(n.c_[xx.ravel(), yy.ravel()])
zz = zz.reshape(xx.shape)

# Predicted regions as a translucent background, data points on top
ax1.contourf(xx, yy, zz, cmap=cmap, alpha=0.4)
ax1.scatter(age, height, c=gender, cmap='jet', linewidth=0)
clean(ax1)
p.savefig('rbf.pdf')

# Parenthesized so the line is valid both as a Python 2 print statement
# and as a Python 3 function call; the output is the same on both
print('done')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment