Created
May 22, 2011 12:53
-
-
Save agramfort/985437 to your computer and use it in GitHub Desktop.
Test LinearSVC variance problem
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Grid-search the regularisation strength C of a linear classifier and report
# how the selected model performs on a held-out half of the data.
#
# Input: 'featset1.csv', a comma-separated numeric table whose last column is
# the class label and whose remaining columns are the features.
#
# The print calls below take a single parenthesised argument so that they
# behave the same whether parsed as a Python 2 print statement or as the
# print() function.

from pprint import pprint

import numpy as np
from scipy import sparse

from scikits.learn.grid_search import GridSearchCV
from scikits.learn.cross_val import StratifiedKFold
from scikits.learn.metrics import f1_score, classification_report
from scikits.learn import svm
from scikits.learn.linear_model import LogisticRegression
from scikits.learn.linear_model.sparse import LogisticRegression as SparseLogisticRegression

data = np.loadtxt('featset1.csv', delimiter=',')
Y = data[:, -1]

# Dense classifiers (alternative configuration, kept for experimentation).
# Enable this X together with one of the dense estimators and disable the
# sparse pair below.
# X = data[:, :-1]
# grid_clf = svm.LinearSVC(tol=1e-10)
# grid_clf = LogisticRegression(tol=1e-10)

# Sparse classifiers (active configuration).
X = sparse.csr_matrix(data[:, :-1])
grid_clf = SparseLogisticRegression(tol=1e-10)
# grid_clf = svm.sparse.LinearSVC(tol=1e-10)

print(grid_clf)

# Keep only the first train/test split of a 2-fold stratified partition:
# one half is used for model selection, the other for the final evaluation.
train, test = next(iter(StratifiedKFold(Y, 2, indices=True)))

# Generate grid search values for C: powers of two whose exponents run from
# C_start in steps of C_step. Using C_end + C_step as the (exclusive) stop
# yields the exponents -5, -2, 1, 4, 7, 10, 13, 16.
C_start, C_end, C_step = -5, 15, 3
C_val = 2. ** np.arange(C_start, C_end + C_step, C_step)
linear_SVC_params = {'C': C_val}

# n_jobs = 100
n_jobs = 2

grid_search = GridSearchCV(grid_clf, linear_SVC_params, n_jobs=n_jobs,
                           score_func=f1_score)
# Model selection uses 10-fold stratified cross-validation restricted to the
# training half, so the test half never influences the choice of C.
grid_search.fit(X[train], Y[train],
                cv=StratifiedKFold(Y[train], 10, indices=True))

y_true, y_pred = Y[test], grid_search.predict(X[test])

print("Classification report for the best estimator: ")
print(grid_search.best_estimator)
print("Tuned for with optimal value: %0.3f" % f1_score(y_true, y_pred))
print(classification_report(y_true, y_pred))

print("Grid scores:")
pprint(grid_search.grid_scores_)

print("Best score: %0.3f" % grid_search.best_score)
best_parameters = grid_search.best_estimator._get_params()
print("Best C: %0.3f " % best_parameters['C'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment