Skip to content

Instantly share code, notes, and snippets.

@konabe
Created December 16, 2018 13:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save konabe/81fe048630ff3cb0dc2965788676a12b to your computer and use it in GitHub Desktop.
Save konabe/81fe048630ff3cb0dc2965788676a12b to your computer and use it in GitHub Desktop.
線形SVMの実装
import warnings

import numpy as np
from scipy.optimize import minimize
from matplotlib import pyplot as plt
# Toy training set: two 2-D Gaussian clusters with ONE_DATA points each.
# The first cluster is centred at (-2, -2) and labelled -1; the second is
# centred at (2, 2) and labelled +1.
ONE_DATA = 30
x = np.concatenate(
    [
        np.random.normal(loc=-2, scale=1.0, size=(ONE_DATA, 2)),
        np.random.normal(loc=2, scale=1.0, size=(ONE_DATA, 2)),
    ],
    axis=0,
)
y = np.concatenate([np.full(ONE_DATA, -1), np.full(ONE_DATA, 1)])
def get_obj(x, y):
    """Return the negated hard-margin SVM dual objective as a function of alpha.

    The dual problem maximises
        sum_i alpha_i - 1/2 * sum_{i,j} alpha_i alpha_j y_i y_j <x_i, x_j>,
    so the returned function is its negation, suitable for ``minimize``.

    Parameters
    ----------
    x : array_like, shape (n,) or (n, d)
        Training inputs, one row (or scalar) per sample.
    y : array_like, shape (n,)
        Class labels (+1 / -1 in this script).

    Returns
    -------
    callable
        ``obj(alpha)`` returning the objective value for an ``alpha`` of
        length ``n``.
    """
    samples = np.asarray(x, dtype=float)
    if samples.ndim == 1:
        # Treat scalar samples as 1-dimensional feature vectors.
        samples = samples[:, None]
    labels = np.asarray(y, dtype=float)
    # Q_ij = y_i y_j <x_i, x_j> is the Gram matrix of the signed samples.
    # It does not depend on alpha, so build it once here rather than on
    # every objective evaluation made by the optimiser.
    signed = labels[:, None] * samples
    q = signed @ signed.T

    def obj(alpha):
        alpha = np.asarray(alpha, dtype=float)
        return -np.sum(alpha) + 0.5 * (alpha @ q @ alpha)

    return obj
def get_g(i):
    """Return the SLSQP inequality-constraint function ``alpha -> alpha[i]``.

    Used as an ``'ineq'`` constraint, it requires the i-th Lagrange
    multiplier to be non-negative.
    """
    index = i
    return lambda alpha: alpha[index]
def get_h(y):
    """Return the SVM dual equality-constraint function ``sum_i alpha_i * y_i``.

    Used as an SLSQP ``'eq'`` constraint, the function must evaluate to zero,
    which enforces sum_i alpha_i y_i = 0.

    Parameters
    ----------
    y : array_like, shape (n,)
        Class labels.

    Returns
    -------
    callable
        ``h(alpha)`` returning the weighted sum for an ``alpha`` of length ``n``.
    """
    # Convert once; the closure is evaluated many times by the optimiser.
    labels = np.asarray(y, dtype=float)

    def h(alpha):
        return np.dot(alpha, labels)

    return h
# Solve the hard-margin SVM dual with SLSQP:
#   minimise  -sum(alpha) + 1/2 * alpha^T Q alpha      (get_obj)
#   s.t.      sum_i alpha_i y_i = 0                     (get_h, 'eq')
#             alpha_i >= 0 for every sample             (get_g, 'ineq')
cons = [{'type': 'eq', 'fun': get_h(y)}]
cons += [{'type': 'ineq', 'fun': get_g(i)} for i in range(len(x))]
res = minimize(get_obj(x, y), np.zeros(len(x)), constraints=cons, method="SLSQP")
if not res.success:
    # The hard-margin dual has no finite optimum when the classes are not
    # linearly separable, and SLSQP can also stop early; the multipliers
    # below may then be meaningless.
    warnings.warn("SLSQP did not converge: {}".format(res.message))
# Copy so that zeroing small multipliers below does not mutate res.x.
alpha_hat = res.x.copy()
# SLSQP solves only to a finite accuracy, so multipliers of non-support
# vectors can come back as small positive values rather than exact zeros.
# An absolute 1e-10 cutoff would count those as support vectors, so the
# cutoff is scaled by the largest multiplier, with 1e-10 kept as a floor.
tol = max(1e-10, 1e-5 * alpha_hat.max())
is_support = alpha_hat > tol
support_index = np.where(is_support)[0]
non_support_index = np.where(~is_support)[0]
alpha_hat[non_support_index] = 0
# Primal weight vector w = sum_i alpha_i y_i x_i.
w = (alpha_hat * y) @ x
# Support vectors of the negative (-1) and positive (+1) class.
x_m = x[is_support & (y == -1)]
x_p = x[is_support & (y == 1)]
if len(x_m) == 0 or len(x_p) == 0:
    raise RuntimeError("no support vector found for one of the classes; "
                       "cannot determine the bias b")
# Support vectors satisfy w.x + b = +1 (class +1) and w.x + b = -1
# (class -1), so b is minus the midpoint of the two margins. Averaging
# over all support vectors of each class, instead of using only the
# first one, reduces the effect of solver round-off.
b = -(np.mean(x_p @ w) + np.mean(x_m @ w)) / 2
def classifier(w1, w2, b, x):
    """Return the second coordinate of the decision boundary at first coordinate ``x``.

    The boundary w1*x1 + w2*x2 + b = 0 is solved for x2, giving
    x2 = -(w1 / w2) * x1 - b / w2. ``x`` may be a scalar or a NumPy array
    (evaluated elementwise). A vertical boundary (``w2 == 0``) cannot be
    expressed in this form.
    """
    slope = -w1 / w2
    intercept = -b / w2
    return slope * x + intercept
# Plot the training points and the learned decision boundary.
# Support vectors are selected by sample index, so a point is highlighted
# only when that row itself is a support vector. Matching individual
# coordinate values against x_m / x_p would not guarantee this.
is_sv = np.zeros(len(x), dtype=bool)
is_sv[support_index] = True
# Support vectors: green triangles. Other points: red circles for +1,
# blue crosses otherwise.
plt.plot(x[is_sv, 0], x[is_sv, 1], 'gv')
is_pos = ~is_sv & (y == 1)
is_neg = ~is_sv & (y != 1)
plt.plot(x[is_pos, 0], x[is_pos, 1], 'ro')
plt.plot(x[is_neg, 0], x[is_neg, 1], 'bx')
# Dedicated names for the boundary line keep the training data x / y intact.
line_x = np.linspace(-5, 5, 100)
line_y = classifier(w[0], w[1], b, line_x)
plt.plot(line_x, line_y)
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment