Skip to content

Instantly share code, notes, and snippets.

@IshitaTakeshi
Last active September 25, 2015 12:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save IshitaTakeshi/58fbf89e053206440cd9 to your computer and use it in GitHub Desktop.
高速、高精度、省メモリな線形分類器、SCW ref: http://qiita.com/IshitaTakeshi/items/3e41d3ec045422f7b8d8
from __future__ import division
import time
import numpy as np
from sklearn.datasets import load_digits, make_classification
from sklearn.svm import SVC
from matplotlib import pyplot
from scw import SCW1
def generate_dataset():
    """Load the two-class subset of scikit-learn's digits dataset.

    Returns:
        A ``(X, y)`` pair: ``X`` is ``digits.data`` unchanged, and ``y`` is a
        NumPy integer array in which the first (lower) class label is
        replaced by -1 and the second by +1.
    """
    # ``n_class`` must be passed by keyword: current scikit-learn releases
    # declare it keyword-only, so the positional form ``load_digits(2)``
    # raises TypeError there.
    digits = load_digits(n_class=2)
    # np.unique returns the distinct labels in sorted order.
    classes = np.unique(digits.target)
    # Only two classes are loaded, so anything not equal to classes[0] is
    # classes[1]; remap them to -1 / +1.
    y = np.where(digits.target == classes[0], -1, 1)
    return digits.data, y
def calc_accuracy(results, answers):
    """Return the fraction of predictions that equal the expected labels.

    Args:
        results: Sequence of predicted labels.
        answers: Sequence of ground-truth labels, the same length as
            ``results``.

    Returns:
        The accuracy as a float between 0 and 1.

    Raises:
        ValueError: If the two sequences differ in length or are empty.
    """
    n_results = len(results)
    # zip() would silently truncate to the shorter input, skewing the score.
    if n_results != len(answers):
        raise ValueError(
            "results and answers must have the same length "
            "({} != {})".format(n_results, len(answers)))
    if n_results == 0:
        raise ValueError("cannot compute accuracy of an empty result set")
    n_correct_answers = sum(
        1 for result, answer in zip(results, answers) if result == answer)
    # True division (the file imports division from __future__ for Python 2).
    return n_correct_answers / n_results
# Benchmark SCW against an SVC: train on the first 80% of the samples and
# evaluate on the remaining 20%.
# NOTE(review): the split is not shuffled, so it depends on the order in which
# generate_dataset() returns samples.
X, y = generate_dataset()
N = int(len(X) * 0.8)
training, test = X[:N], X[N:]
labels, answers = y[:N], y[N:]

# Time with a high-resolution monotonic clock. time.time() follows the wall
# clock (it can jump when the clock is adjusted) and its resolution is
# platform-dependent, which matters for fits that take only milliseconds.
# Fall back to time.time() on interpreters without perf_counter (Python 2).
timer = getattr(time, "perf_counter", time.time)

# Only fit() is timed; prediction and scoring are excluded.
# NOTE(review): the first argument presumably is the feature dimension.
scw = SCW1(len(X[0]), C=1.0, ETA=1.0)
t1 = timer()
scw.fit(training, labels)
t2 = timer()
results = scw.predict(test)
accuracy = calc_accuracy(results, answers)
print("SCW time:{:3.6f} accuracy:{:1.3f}".format(t2 - t1, accuracy))

svc = SVC(C=10.0)
t1 = timer()
svc.fit(training, labels)
t2 = timer()
results = svc.predict(test)
accuracy = calc_accuracy(results, answers)
print("SVC time:{:3.6f} accuracy:{:1.3f}".format(t2 - t1, accuracy))
SCW time:0.003194 accuracy:1.000
SVC time:0.010297 accuracy:0.903
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment