Last active
August 29, 2015 13:55
-
-
Save arjoly/8732555 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Some results on the 20 news dataset | |
Classifier train-time test-time error-rate | |
-------------------------------------------- | |
5-nn 0.0047s 13.6651s 0.5916 | |
random forest 263.3146s 3.9985s 0.2459 | |
sgd 0.2265s 0.0657s 0.2604 | |
""" | |
from __future__ import print_function, division | |
from time import time | |
import numpy as np | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.neighbors import KNeighborsClassifier | |
from sklearn.dummy import DummyClassifier | |
from sklearn.linear_model import SGDClassifier | |
from sklearn.metrics import zero_one_loss | |
###############################################################################
# Bench
def benchmark(clf):
    """Fit ``clf`` on the train split and score it on the test split.

    Returns ``(error_rate, train_seconds, test_seconds)``, where the error
    rate is the zero-one loss of the test-set predictions.  Reads the
    module-level globals ``X_train``, ``y_train``, ``X_test``, ``y_test``.
    """
    start = time()
    clf.fit(X_train, y_train)
    elapsed_fit = time() - start

    start = time()
    predictions = clf.predict(X_test)
    elapsed_predict = time() - start

    return zero_one_loss(y_test, predictions), elapsed_fit, elapsed_predict
###############################################################################
# Estimators

# Estimators that require a dense feature matrix.
DENSE_ESTIMATORS = {}
DENSE_ESTIMATORS["random forest"] = RandomForestClassifier(n_estimators=100)

# Estimators that accept the sparse matrices produced by the vectorizer.
# NOTE(review): ``n_iter`` was renamed ``max_iter`` in scikit-learn >= 0.19;
# adjust if running against a modern release.
SPARSE_ESTIMATORS = {}
SPARSE_ESTIMATORS["dummy"] = DummyClassifier()
SPARSE_ESTIMATORS["5-nn"] = KNeighborsClassifier(n_neighbors=5)
SPARSE_ESTIMATORS["sgd"] = SGDClassifier(alpha=0.001, n_iter=2)
###############################################################################
# Data
from sklearn.datasets import fetch_20newsgroups_vectorized

# Pre-vectorized 20 newsgroups bag-of-words features (sparse matrices);
# downloads/caches the dataset on first use.
data_train = fetch_20newsgroups_vectorized(subset="train")
data_test = fetch_20newsgroups_vectorized(subset="test")

X_train, y_train = data_train.data, data_train.target
X_test, y_test = data_test.data, data_test.target

print("Data")
print("X_train {0}".format(X_train.shape))
print("y_train {0}".format(y_train.shape))
print("X_test {0}".format(X_test.shape))
print("y_test {0}".format(y_test.shape))
###############################################################################
# Bench
print()
print("Training Classifiers on sparse data")
print("===================================")

# Results keyed by estimator name; filled in by both benchmark loops.
err, train_time, test_time = {}, {}, {}

for name in sorted(SPARSE_ESTIMATORS):
    clf = SPARSE_ESTIMATORS[name]
    try:
        # Fix the RNG so repeated runs are comparable.
        clf.set_params(random_state=0)
    except ValueError:
        # Estimator has no ``random_state`` parameter (e.g. 5-nn) -- a bare
        # ``except`` here would also swallow KeyboardInterrupt and real bugs.
        pass
    print("Training %s ..." % name, end="")
    err[name], train_time[name], test_time[name] = benchmark(clf)
    print("done")
print()
print("Training Classifiers on dense data")
print("===================================")

# Densify for the dense-only estimators.  float32 halves memory versus the
# default float64; Fortran (column-major) order is presumably chosen to
# favor per-feature access in the tree-building code -- TODO confirm.
X_train = np.asfortranarray(X_train.toarray(), dtype=np.float32)
X_test = np.asfortranarray(X_test.toarray(), dtype=np.float32)

for name in sorted(DENSE_ESTIMATORS):
    clf = DENSE_ESTIMATORS[name]
    try:
        # Fix the RNG so repeated runs are comparable.
        clf.set_params(random_state=0)
    except ValueError:
        # Estimator has no ``random_state`` parameter -- a bare ``except``
        # here would also swallow KeyboardInterrupt and real bugs.
        pass
    print("Training %s ..." % name, end="")
    err[name], train_time[name], test_time[name] = benchmark(clf)
    print("done")
######################################################################
## Print classification performance
for _banner_line in ("", "Classification performance:",
                     "===========================", ""):
    print(_banner_line)
def print_row(clf_type, train_time, test_time, err):
    """Print one aligned row of the results table.

    The name is left-justified to 12 columns; times (seconds) and the error
    rate are formatted to four decimals and centered in 10-column fields.
    """
    name_col = clf_type.ljust(12)
    train_col = ("%.4fs" % train_time).center(10)
    test_col = ("%.4fs" % test_time).center(10)
    err_col = ("%.4f" % err).center(10)
    print("%s %s %s %s" % (name_col, train_col, test_col, err_col))
print("%s %s %s %s" % ("Classifier ", "train-time", "test-time", | |
"error-rate")) | |
print("-" * 44) | |
for name in sorted(err, key=lambda name: err[name]): | |
print_row(name, train_time[name], test_time[name], err[name]) | |
print() | |
print() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment