@fmailhot
Created November 15, 2012 20:30
sklearn GridSearchCV example
from data_utils import LoadData, PrepareData
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.grid_search import GridSearchCV  # moved to sklearn.model_selection in modern releases
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
from sklearn.svm import LinearSVC
from sklearn.pipeline import FeatureUnion, Pipeline
from time import time
from twokenize import tokenize
import sys
### LOAD & MUNGE DATA ###
raw_data = LoadData(sys.argv[1], sys.argv[2])
train_tweets, train_labels, test_tweets, test_labels = PrepareData(raw_data)
# map the sorted label strings to integer class indices
classes = dict(zip(sorted(set(train_labels)), range(len(set(train_labels)))))
train_targets = np.array([classes[x] for x in train_labels])
test_targets = np.array([classes[x] for x in test_labels])
### BUILD PIPELINE & PARAM GRIDS ###
pipeline_mnb = Pipeline([
    ("vec", CountVectorizer()),
    ("clf", MultinomialNB())])
params_mnb = {
    "vec__ngram_range": ((1, 1), (1, 3), (1, 5)),
    "vec__tokenizer": (None, tokenize),
    "vec__lowercase": (True, False),
    "vec__analyzer": ("word", "char_wb"),
    "clf__alpha": (0.0, 0.5, 1.0),
    "clf__fit_prior": (True, False)
}
pipeline_bnb = Pipeline([
    ("vec", CountVectorizer()),
    ("clf", BernoulliNB(binarize=0.0))])
params_bnb = {
    "vec__ngram_range": ((1, 1), (1, 3), (1, 5)),
    "vec__tokenizer": (None, tokenize),
    "vec__lowercase": (True, False),
    "vec__analyzer": ("word", "char_wb"),
    "clf__alpha": (0.0, 0.5, 1.0),
    "clf__fit_prior": (True, False)
}
pipeline_lr = Pipeline([
    ("vec", CountVectorizer()),
    ("clf", LogisticRegression())])
params_lr = {
    "vec__ngram_range": ((1, 1), (1, 3), (1, 5)),
    "vec__tokenizer": (None, tokenize),
    "vec__lowercase": (True, False),
    "vec__analyzer": ("word", "char_wb"),
    "clf__penalty": ("l1", "l2"),
    "clf__C": (0.1, 1.0, 10.0),
    "clf__class_weight": (None, "auto")
}
pipeline_sgd = Pipeline([
    ("vec", CountVectorizer()),
    ("clf", SGDClassifier())])
params_sgd = {
    "vec__ngram_range": ((1, 1), (1, 3), (1, 5)),
    "vec__tokenizer": (None, tokenize),
    "vec__lowercase": (True, False),
    "vec__analyzer": ("word", "char_wb"),
    "clf__loss": ("log", "hinge", "modified_huber"),
    "clf__penalty": ("l1", "l2", "elasticnet"),
    "clf__alpha": (0.000001, 0.0001, 0.01, 1.0),
    "clf__class_weight": (None, "auto"),
    "clf__warm_start": (True, False)
}
pipeline_svc = Pipeline([
    ("vec", CountVectorizer()),
    ("clf", LinearSVC())])
# LinearSVC supports only some penalty/loss combinations (penalty="l1"
# requires loss="l2" and dual=False), so the grid is split in two to avoid
# a ValueError; GridSearchCV accepts a list of grids for exactly this case.
params_svc = [{
    "vec__ngram_range": ((1, 1), (1, 3), (1, 5)),
    "vec__tokenizer": (None, tokenize),
    "vec__lowercase": (True, False),
    "vec__analyzer": ("word", "char_wb"),
    "clf__C": (0.1, 1.0, 10.0),
    "clf__class_weight": (None, "auto"),
    "clf__penalty": ("l2",),
    "clf__loss": ("l1", "l2")
}, {
    "vec__ngram_range": ((1, 1), (1, 3), (1, 5)),
    "vec__tokenizer": (None, tokenize),
    "vec__lowercase": (True, False),
    "vec__analyzer": ("word", "char_wb"),
    "clf__C": (0.1, 1.0, 10.0),
    "clf__class_weight": (None, "auto"),
    "clf__penalty": ("l1",),
    "clf__loss": ("l2",),
    "clf__dual": (False,)
}]
if __name__ == "__main__":
    for model, params in ((pipeline_mnb, params_mnb),
                          (pipeline_bnb, params_bnb),
                          (pipeline_lr, params_lr),
                          (pipeline_sgd, params_sgd),
                          (pipeline_svc, params_svc)):
        print "=" * 75
        print "Pipeline: %s" % " ".join([name for name, _ in model.steps])
        grid_search = GridSearchCV(model, params, n_jobs=1, verbose=5)
        ### WHY IS THE MULTIPROC VERSION OF THIS THROWING ERRORS??
        #grid_search = GridSearchCV(model, params, n_jobs=2, verbose=5)
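        # (A guess, not verified against this exact setup: with n_jobs > 1
        # the old sklearn.grid_search dispatches fits to worker processes,
        # so the pipeline, the param grids, and the data must all be
        # picklable. Module-level functions like twokenize.tokenize usually
        # are, but lambdas, nested functions, or anything defined
        # interactively are not, and joblib/multiprocessing of that era had
        # platform-specific bugs of its own.)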
        tick = time()
        grid_search.fit(train_tweets, train_targets)
        print "Done in %0.3fs" % (time() - tick)
        print "-" * 75
        print "Best score: %0.3f" % grid_search.best_score_
        print "Best params:"
        best_params = grid_search.best_estimator_.get_params()
        for param_name in sorted(best_params.keys()):
            print "\t%s: %r" % (param_name, best_params[param_name])
@fmailhot (Author)
LoadData and PrepareData are custom functions to parse some proprietary JSON (pulling out the values of interest and splitting out train/test sets).
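For readers without the data: a minimal sketch of the shape those helpers might take, assuming one JSON object per line with hypothetical "text" and "label" fields (the real schema is proprietary, so the field names here are invented):

import json

def LoadData(train_path, test_path):
    # read one JSON record per line from each file, keeping the splits apart
    data = {"train": [], "test": []}
    for split, path in (("train", train_path), ("test", test_path)):
        with open(path) as f:
            data[split] = [json.loads(line) for line in f]
    return data

def PrepareData(raw_data):
    # pull the fields of interest out of each record
    train, test = raw_data["train"], raw_data["test"]
    return ([r["text"] for r in train], [r["label"] for r in train],
            [r["text"] for r in test], [r["label"] for r in test])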

@amueller
btw you can combine char n-grams and word n-grams using FeatureStacker. That makes the grid even bigger, though ;)
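FeatureStacker is essentially what later landed in scikit-learn as FeatureUnion, which the script above already imports but never uses. A sketch of how stacking a word-level and a char-level vectorizer would look, with illustrative ngram_range choices:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import FeatureUnion, Pipeline

# stack word n-gram and char n-gram counts side by side
pipeline_stacked = Pipeline([
    ("vec", FeatureUnion([
        ("words", CountVectorizer(analyzer="word", ngram_range=(1, 2))),
        ("chars", CountVectorizer(analyzer="char_wb", ngram_range=(2, 5)))])),
    ("clf", MultinomialNB())])
# grid params reach the sub-vectorizers via nested name prefixes
params_stacked = {
    "vec__words__ngram_range": ((1, 1), (1, 2)),
    "vec__chars__ngram_range": ((2, 3), (2, 5)),
    "clf__alpha": (0.5, 1.0)
}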
