Created August 13, 2012 16:27
Save lisitsyn/3342349 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pylab import * | |
from numpy import * | |
from scipy.optimize import fmin_l_bfgs_b,check_grad | |
import time | |
# Reproducible toy data: two 2-D Gaussian blobs of N points each, one shifted
# by -1.5 and one by +1.5 in both coordinates, stored column-wise in X.
random.seed(7)
N = 25
X = hstack([random.randn(2, N) - 1.5, random.randn(2, N) + 1.5])
# Centre the data: train_SVM fits no bias term, so the separating hyperplane
# must pass through the origin. Broadcasting subtracts the per-row mean from
# every column at once. The name `X_mean` avoids rebinding numpy's `mean`
# function, which the star imports bring into the module namespace.
X_mean = X.mean(1)
X -= X_mean[:, newaxis]
# Labels: -1 for the first blob, +1 for the second.
Y = hstack([-ones(N), ones(N)])
def train_SVM(X, Y, C):
    """Train a bias-free linear soft-margin SVM by solving its dual with L-BFGS-B.

    Minimizes the dual objective

        f(alpha) = 0.5 * alpha^T Q alpha - sum(alpha),
        Q[i, j] = Y[i] * Y[j] * <x_i, x_j>,

    subject only to the box constraints 0 <= alpha[i] <= C. No equality
    constraint sum(alpha * Y) == 0 is imposed, so the resulting hyperplane
    w . x = 0 passes through the origin (callers should centre X first).

    Parameters
    ----------
    X : array of shape (D, N)
        Training points stored column-wise.
    Y : sequence of length N
        Class labels (the script uses -1 / +1).
    C : float
        Upper bound on every dual variable (soft-margin penalty).

    Returns
    -------
    w : ndarray of shape (D,)
        Hyperplane normal, sum_i alpha[i] * Y[i] * X[:, i].
    alpha : ndarray of shape (N,)
        Dual variables returned by the optimizer.
    """
    _, N = X.shape
    Y = asarray(Y, dtype=float)
    # Q does not depend on alpha, so build it once. This replaces the O(N^2)
    # Python double loops that used to run on every objective/gradient call.
    Q = outer(Y, Y) * dot(X.T, X)

    def dual_objective(alpha):
        # 0.5 * alpha^T Q alpha - sum(alpha)
        return 0.5 * dot(alpha, dot(Q, alpha)) - alpha.sum()

    def dual_objective_grad(alpha):
        # d/d alpha_i = (Q alpha)_i - 1
        return dot(Q, alpha) - 1.0

    # Same starting point as before: only the first and last dual variables
    # are set. L-BFGS-B projects the start onto [0, C] when C < 1.
    alpha = zeros(N)
    alpha[0] = 1.0
    alpha[-1] = 1.0

    start = time.time()
    # Verbose `disp` output is no longer requested; that option is deprecated
    # for L-BFGS-B in recent SciPy releases.
    alpha = fmin_l_bfgs_b(dual_objective, alpha, fprime=dual_objective_grad,
                          pgtol=1e-3, bounds=[(0, C)] * N)[0]
    end = time.time()
    print('Took %fs' % (end - start))

    # Zero dual variables contribute nothing, so one matrix-vector product
    # equals the original sum over support vectors.
    w = dot(X, alpha * Y)
    return w, alpha
# Train on the centred toy data with soft-margin constant C = 0.2.
w, alpha = train_SVM(X, Y, 0.2)
print('alpha %s' % (alpha,))

# Geometry of the decision boundary w . x = 0. It passes through the origin
# because train_SVM fits no bias term.
r = 10.0                 # half-length of the drawn lines
w_norm = norm(w, 2)
w_ = w / w_norm          # unit normal of the hyperplane
a = r * w_[0]
b = r * w_[1]
# (-b, a) and (b, -a) are orthogonal to w_, so they lie on the decision line.

# Configure LaTeX rendering before any Text artist is created. Current
# Matplotlib reads rcParams['text.usetex'] when a Text is constructed, so
# setting it just before savefig() no longer affects the labels below.
# 'text.latex.unicode' was removed in Matplotlib 3.0 and is not set any more.
# Raw strings keep '\u' from being parsed as a Unicode escape on Python 3.
#rc('font',**{'family':'serif'})
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{mathpazo}')
#rc('text.latex',preamble=r'\usepackage[russian]{babel}')
#rc('text.latex',preamble=r'\usepackage[T2A]{fontenc}')

fig = figure()
ax = fig.add_subplot(1, 1, 1)
# Vertical offset of the margin lines w . x = +/-1 from the decision line.
# NOTE(review): divides by w[1]; a vertical boundary (w[1] == 0) is not handled.
k = 1.0 / w[1]
#ax.fill_between([-b,b],[a-k,-a-k],[a+k,-a+k],alpha=0.05,color='gray')
ax.plot([-b, b], [a, -a], color='black')
ax.plot([-b, b], [a + k, -a + k], color='black', ls='--')
ax.plot([-b, b], [a - k, -a - k], color='black', ls='--')

# Margin-width marker: a segment along w_ of total length 2 / ||w||, centred
# at (-0.3 b, 0.3 a) on the decision line and labelled with the formula.
px = w_[0] / w_norm
py = w_[1] / w_norm
ax.plot([-px - 0.3 * b, px - 0.3 * b], [-py + 0.3 * a, py + 0.3 * a],
        color='black')
ax.text(-0.3 * b - 0.25, 0.3 * a + 0.5, r'$\frac{2}{\| w \|}$', fontsize=20)
print(1.0 / w_norm)

# For each support vector (alpha[i] != 0), draw a dotted segment to its
# orthogonal projection on the decision line and label it with w . x_i.
for i in range(X.shape[1]):
    if alpha[i] == 0:
        continue
    x_i = X[:, i]
    dist = dot(w_, x_i)  # signed distance from x_i to the decision line
    print(dist)
    # Matplotlib requires alpha in [0, 1]. 1/|alpha_i| exceeds 1 whenever
    # alpha_i < 1 (always, with C = 0.2), and recent versions raise ValueError
    # instead of clipping, so clip it to 1 here.
    opacity = 1.0 / abs(alpha[i])
    if opacity > 1.0:
        opacity = 1.0
    ax.plot([x_i[0], x_i[0] - dist * w_[0]],
            [x_i[1], x_i[1] - dist * w_[1]],
            color='black', alpha=opacity, ls=':', lw=4)
    ax.text(x_i[0] - 0.5 * dist * w_[0],
            x_i[1] - 0.5 * dist * w_[1] - 0.2,
            r'$%.3f$' % (dot(w, x_i)), fontsize=11)

ax.scatter(X[0, :], X[1, :], c=Y, s=50, cmap=cm.gray)
xlim([-4, 4])
ylim([-4, 4])
axis('off')
savefig('hyperplane.pdf')
#show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment