Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Multiclass SVMs
"""
Multiclass SVMs (Crammer-Singer formulation).
A pure Python re-implementation of:
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex.
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
ICPR 2014.
http://www.mblondel.org/publications/mblondel-icpr2014.pdf
"""
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import check_random_state
from sklearn.preprocessing import LabelEncoder
def projection_simplex(v, z=1):
    r"""
    Projection onto the simplex:
    w^* = argmin_w 0.5 ||w-v||^2 s.t. \sum_i w_i = z, w_i >= 0

    Sort-based algorithm: sort ``v`` in decreasing order, find the last
    position ``rho`` at which the sorted value still exceeds the running
    threshold ``(cumsum - z) / position``, then subtract that threshold
    ``theta`` from every entry of ``v`` and clip at zero.

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Point to project; converted with ``np.asarray`` so lists work too.
    z : float, default=1
        Required sum of the projected entries. Must be strictly positive:
        for ``z <= 0`` no position passes the threshold test and
        ``ind[cond][-1]`` raises ``IndexError``.

    Returns
    -------
    w : ndarray of shape (n_features,)
        Non-negative vector closest to ``v`` whose entries sum to ``z``.
    """
    # For other algorithms computing the same projection, see
    # https://gist.github.com/mblondel/6f3b7aaad90606b98f71
    # The docstring is raw (r""") so "\sum" is not parsed as an invalid
    # escape sequence (a SyntaxWarning on Python 3.12+).
    v = np.asarray(v)
    n_features = v.shape[0]
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(1, n_features + 1)
    cond = u - cssv / ind > 0
    # In exact arithmetic cond holds on a prefix of positions; rho is the
    # last position where it holds, and theta the matching threshold.
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / float(rho)
    return np.maximum(v - theta, 0)
class MulticlassSVM(BaseEstimator, ClassifierMixin):
    """Multiclass linear SVM (Crammer-Singer) trained by dual coordinate descent.

    Each pass of ``fit`` visits every training sample once, in an order
    shuffled once up front. For sample ``i`` the column of dual variables
    ``dual_coef_[:, i]`` is updated in closed form via a projection onto the
    simplex (``projection_simplex``), and the primal weights ``coef_`` are
    updated by the same amount so that ``coef_ == dual_coef_ @ X`` holds.

    Parameters
    ----------
    C : float, default=1
        Penalty parameter. In the dual it caps each sample's true-class
        coefficient (see ``_violation`` and ``_solve_subproblem``).
    max_iter : int, default=50
        Maximum number of passes over the training data.
    tol : float, default=0.05
        Stopping tolerance. Training stops after the first pass whose summed
        optimality violation, divided by that of the first pass, is below
        ``tol``.
    random_state : int, RandomState instance or None, default=None
        Passed to ``check_random_state``; controls the sample visiting order.
    verbose : int, default=0
        If >= 1, print the violation ratio after each pass.

    Attributes
    ----------
    coef_ : ndarray of shape (n_classes, n_features)
        Primal weight vectors, one row per class.
    dual_coef_ : ndarray of shape (n_classes, n_samples)
        Dual variables, one column per training sample.
    classes_ : ndarray of shape (n_classes,)
        Class labels seen during ``fit``.
    """

    def __init__(self, C=1, max_iter=50, tol=0.05,
                 random_state=None, verbose=0):
        # sklearn's get_params/clone require every constructor argument to
        # be stored unmodified under its own name.
        self.C = C
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state
        self.verbose = verbose

    def _partial_gradient(self, X, y, i):
        """Partial gradient for sample ``i``.

        Returns the per-class scores ``X[i] @ coef_.T`` plus a unit margin,
        with the margin removed for the true class ``y[i]``.
        """
        g = np.dot(X[i], self.coef_.T) + 1
        g[y[i]] -= 1
        return g

    def _violation(self, g, y, i):
        """Optimality violation for sample ``i``.

        Returns ``g.max()`` minus the smallest ``g[k]`` over the classes ``k``
        whose dual coefficient is still strictly inside its bound: the true
        class when ``dual_coef_[y[i], i] < C``, any other class when
        ``dual_coef_[k, i] < 0``. If no class qualifies, returns ``-inf``.
        """
        smallest = np.inf
        for k in range(g.shape[0]):
            if k == y[i] and self.dual_coef_[k, i] >= self.C:
                continue
            elif k != y[i] and self.dual_coef_[k, i] >= 0:
                continue
            smallest = min(smallest, g[k])
        return g.max() - smallest

    def _solve_subproblem(self, g, y, norms, i):
        """Return the update ``delta`` for the dual column ``dual_coef_[:, i]``.

        With ``Ci`` equal to ``C`` at ``y[i]`` and 0 elsewhere, the vector
        ``norms[i] * (Ci - alpha) + g / norms[i]`` is projected onto the
        simplex of radius ``C * norms[i]``, giving ``beta``. The returned
        update is ``Ci - alpha - beta / norms[i]``. Requires ``norms[i] > 0``;
        ``fit`` skips all-zero rows for that reason.
        """
        Ci = np.zeros(g.shape[0])
        Ci[y[i]] = self.C
        alpha = self.dual_coef_[:, i]
        beta_hat = norms[i] * (Ci - alpha) + g / norms[i]
        z = self.C * norms[i]
        beta = projection_simplex(beta_hat, z)
        return Ci - alpha - beta / norms[i]

    def fit(self, X, y):
        """Train the model by dual coordinate descent.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Dense training data.
        y : array-like of shape (n_samples,)
            Class labels; any values accepted by ``LabelEncoder``.

        Returns
        -------
        self : MulticlassSVM
            The fitted estimator.
        """
        X = np.asarray(X)
        n_samples, n_features = X.shape

        # Map arbitrary labels to 0..n_classes-1 so they can index rows.
        self._label_encoder = LabelEncoder()
        y = self._label_encoder.fit_transform(y)
        self.classes_ = self._label_encoder.classes_
        n_classes = len(self.classes_)

        # Initialize primal and dual coefficients.
        self.dual_coef_ = np.zeros((n_classes, n_samples), dtype=np.float64)
        self.coef_ = np.zeros((n_classes, n_features))

        # Row norms are reused by every subproblem.
        norms = np.sqrt(np.sum(X ** 2, axis=1))

        # Visit samples in an order shuffled once up front.
        rs = check_random_state(self.random_state)
        ind = np.arange(n_samples)
        rs.shuffle(ind)

        violation_init = None
        for it in range(self.max_iter):
            violation_sum = 0
            for i in ind:
                # All-zero samples can be safely ignored (their subproblem
                # would divide by norms[i] == 0).
                if norms[i] == 0:
                    continue

                g = self._partial_gradient(X, y, i)
                v = self._violation(g, y, i)
                violation_sum += v
                if v < 1e-12:
                    continue

                # Solve subproblem for the ith sample.
                delta = self._solve_subproblem(g, y, norms, i)

                # Apply the same change to both sides so that
                # coef_ == dual_coef_ @ X stays true.
                self.coef_ += np.outer(delta, X[i])
                self.dual_coef_[:, i] += delta

            if it == 0:
                violation_init = violation_sum
            # A zero first-pass violation means there is nothing to optimize
            # (e.g. every row of X is zero, or y has a single class); treat it
            # as converged instead of dividing 0 by 0.
            if violation_init != 0:
                vratio = violation_sum / violation_init
            else:
                vratio = 0.0
            if self.verbose >= 1:
                print("iter", it + 1, "violation", vratio)
            if vratio < self.tol:
                if self.verbose >= 1:
                    print("Converged")
                break

        return self

    def predict(self, X):
        """Predict class labels for the rows of ``X``.

        Each row gets the class with the largest score ``X @ coef_.T``. The
        winning index is mapped back to the original label by the
        ``LabelEncoder`` fitted in ``fit``.
        """
        decision = np.dot(X, self.coef_.T)
        pred = decision.argmax(axis=1)
        return self._label_encoder.inverse_transform(pred)
if __name__ == '__main__':
    # Demo: fit on the iris data and print the score on that same data.
    from sklearn.datasets import load_iris

    dataset = load_iris()
    features, labels = dataset.data, dataset.target
    model = MulticlassSVM(
        C=0.1,
        tol=0.01,
        max_iter=100,
        random_state=0,
        verbose=1,
    )
    model.fit(features, labels)
    accuracy = model.score(features, labels)
    print(accuracy)
@scienceML

This comment has been minimized.

Show comment Hide comment
@scienceML

scienceML Feb 15, 2017

Small comment: for Python 3.x, use "range" instead of "xrange".

Small comment: for Python 3.x, use "range" instead of "xrange".

@mab85

This comment has been minimized.

Show comment Hide comment
@mab85

mab85 Mar 30, 2018

How can I run this on my own dataset?

mab85 commented Mar 30, 2018

How can I run this on my own dataset?

@oottoohh

This comment has been minimized.

Show comment Hide comment
@oottoohh

oottoohh Apr 6, 2018

I'm new to SVMs. Can you tell me what kind of multiclass SVM this is?

oottoohh commented Apr 6, 2018

I'm new to SVMs. Can you tell me what kind of multiclass SVM this is?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment