Skip to content

Instantly share code, notes, and snippets.

@lisitsyn
Created August 13, 2012 16:27
Show Gist options
  • Save lisitsyn/3342349 to your computer and use it in GitHub Desktop.
Save lisitsyn/3342349 to your computer and use it in GitHub Desktop.
from pylab import *
from numpy import *
from scipy.optimize import fmin_l_bfgs_b,check_grad
import time
# Reproducible toy data: two Gaussian blobs of N points each in 2-D, shifted
# to (-1.5, -1.5) and (+1.5, +1.5), stored column-wise in X (shape (2, 2N)).
random.seed(7)
N = 25
X = hstack([random.randn(2, N) - 1.5, random.randn(2, N) + 1.5])
# Center the data at the origin: train_SVM has no bias term, so its
# hyperplane always passes through the origin. Broadcasting replaces the
# per-column loop, and the name `center` avoids shadowing numpy's `mean`.
center = X.mean(1)
X -= center[:, newaxis]
# Labels: -1 for the first blob, +1 for the second.
Y = hstack([-ones(N), ones(N)])
def train_SVM(X, Y, C, pgtol=1e-3):
    """Train a bias-free linear SVM by solving its box-constrained dual.

    Minimizes the dual objective

        f(alpha) = 0.5 * alpha^T Q alpha - sum(alpha),  Q_ij = Y_i Y_j <x_i, x_j>

    subject to 0 <= alpha_i <= C using L-BFGS-B. There is no bias term (and
    therefore no sum(alpha_i * Y_i) = 0 constraint), so the resulting
    hyperplane w.x = 0 passes through the origin; callers should center X.

    Parameters
    ----------
    X : array of shape (D, N)
        Training points stored column-wise.
    Y : array-like of length N
        Class labels, expected to be -1 or +1.
    C : float or None
        Upper bound of every dual variable (soft-margin constant). ``None``
        leaves the box open above.
    pgtol : float, optional
        Projected-gradient tolerance passed to ``fmin_l_bfgs_b``.

    Returns
    -------
    w : array of shape (D,)
        Primal weight vector ``sum_i alpha_i * Y_i * x_i``.
    alpha : array of shape (N,)
        Optimized dual variables.

    Raises
    ------
    ValueError
        If the number of labels differs from the number of columns of X.
    """
    D, N = X.shape
    Y = asarray(Y, dtype=float).ravel()
    if Y.shape[0] != N:
        raise ValueError('expected %d labels, got %d' % (N, Y.shape[0]))
    # Q_ij = Y_i * Y_j * K_ij is computed once, so the objective and gradient
    # are matrix-vector products instead of O(N^2) Python loops per call.
    Q = outer(Y, Y) * dot(X.T, X)

    def dual_objective(alpha):
        return 0.5 * dot(alpha, dot(Q, alpha)) - alpha.sum()

    def dual_objective_grad(alpha):
        return dot(Q, alpha) - 1.0

    # Start with the first and last dual variables active. The value is
    # clipped into the box explicitly rather than relying on L-BFGS-B to
    # project an infeasible start (1.0 > C when C < 1).
    alpha0 = zeros(N)
    alpha0[0] = alpha0[-1] = 1.0 if C is None else min(1.0, C)
    start = time.time()
    alpha = fmin_l_bfgs_b(dual_objective, alpha0, fprime=dual_objective_grad,
                          pgtol=pgtol, bounds=[(0, C)] * N, disp=1)[0]
    end = time.time()
    print('Took %fs' % (end - start))
    # Zero dual variables contribute nothing, so a plain weighted sum suffices.
    w = dot(X, alpha * Y)
    return w, alpha
# Fit the SVM on the centered toy data with soft-margin constant C = 0.2 and
# report the dual variables. The single-string print form is valid on both
# Python 2 and Python 3 and prints the same text as the old print statement.
w, alpha = train_SVM(X, Y, 0.2)
print('alpha %s' % (alpha,))
# --- Text rendering ---------------------------------------------------------
# Render figure text with LaTeX (requires a working LaTeX installation).
# Recent matplotlib versions resolve text.usetex when each Text artist is
# created, so these settings must come before any text is added to the figure.
#rc('font',**{'family':'serif'})
rc('text', usetex=True)
# 'text.latex.unicode' no longer exists in matplotlib >= 3.0 (unicode is
# always supported there); setting a missing rcParam raises KeyError.
if 'text.latex.unicode' in rcParams:
    rc('text.latex', unicode=True)
# Raw strings: in a Python 3 str literal '\u' begins a \uXXXX escape, so
# '\usepackage' without the r-prefix is a SyntaxError.
rc('text.latex', preamble=r'\usepackage{mathpazo}')
#rc('text.latex', preamble=r'\usepackage[russian]{babel}')
#rc('text.latex', preamble=r'\usepackage[T2A]{fontenc}')

# --- Hyperplane and margins -------------------------------------------------
# The separator w.x = 0 passes through the origin and is drawn as a segment of
# half-length r along the direction (-w_[1], w_[0]).
r = 10.0
w_ = w / norm(w, 2)  # unit normal of the hyperplane
a = r * w_[0]
b = r * w_[1]
fig = figure()
ax = fig.add_subplot(1, 1, 1)
# Shifting a point of w.x = 0 vertically by 1/w[1] lands on w.x = 1, so the
# dashed lines are the margin hyperplanes w.x = +1 and w.x = -1.
# NOTE(review): assumes w[1] != 0 (a non-vertical hyperplane).
k = 1.0 / w[1]
#ax.fill_between([-b,b],[a-k,-a-k],[a+k,-a+k],alpha=0.05,color='gray')
ax.plot([-b, b], [a, -a], color='black')
ax.plot([-b, b], [a + k, -a + k], color='black', ls='--')
ax.plot([-b, b], [a - k, -a - k], color='black', ls='--')
# A segment of length 2/||w|| along the normal (from w.x = -1 to w.x = +1),
# slid along the hyperplane by 0.3*r and labelled with the margin width.
px = w_[0] / norm(w, 2)
py = w_[1] / norm(w, 2)
ax.plot([-px - 0.3 * b, px - 0.3 * b], [-py + 0.3 * a, py + 0.3 * a],
        color='black')
ax.text(-0.3 * b - 0.25, 0.3 * a + 0.5, r'$\frac{2}{\| w \|}$', fontsize=20)
print(1.0 / norm(w, 2))

# --- Support vectors --------------------------------------------------------
for i in range(X.shape[1]):
    if alpha[i] == 0:
        continue  # only points with a non-zero dual variable are annotated
    x_i = X[:, i]
    proj = dot(w_, x_i)  # signed distance of x_i from the hyperplane
    print(proj)
    foot = x_i - proj * w_  # orthogonal projection of x_i onto the hyperplane
    # Matplotlib requires alpha in [0, 1] (recent versions raise ValueError
    # otherwise). 1/|alpha_i| exceeds 1 whenever alpha_i < 1, which always
    # holds here since alpha_i <= C = 0.2, so clamp to fully opaque.
    ax.plot([x_i[0], foot[0]], [x_i[1], foot[1]], color='black',
            alpha=min(1.0, 1.0 / abs(alpha[i])), ls=':', lw=4)
    mid = x_i - 0.5 * proj * w_
    ax.text(mid[0], mid[1] - 0.2, r'$%.3f$' % (dot(w, x_i)), fontsize=11)
ax.scatter(X[0, :], X[1, :], c=Y, s=50, cmap=cm.gray)
xlim([-4, 4])
ylim([-4, 4])
axis('off')
savefig('hyperplane.pdf')
#show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment