-
-
Save fmailhot/e2ca1910450819a8a287 to your computer and use it in GitHub Desktop.
sklearn GridSearchCV example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from data_utils import LoadData, PrepareData | |
| import numpy as np | |
| from sklearn.feature_extraction.text import CountVectorizer | |
| from sklearn.grid_search import GridSearchCV | |
| from sklearn.linear_model import LogisticRegression, SGDClassifier | |
| from sklearn.naive_bayes import MultinomialNB, BernoulliNB | |
| from sklearn.svm import LinearSVC | |
| from sklearn.pipeline import FeatureUnion, Pipeline | |
| from time import time | |
| from twokenize import tokenize | |
| import sys | |
### LOAD & MUNGE DATA ###
# Usage: script.py <train_source> <test_source>
# LoadData / PrepareData are project-local helpers that parse proprietary
# JSON and split out the train/test sets.
raw_data = LoadData(sys.argv[1], sys.argv[2])
train_tweets, train_labels, test_tweets, test_labels = PrepareData(raw_data)
# Map each distinct label to a stable integer id (sorted for determinism).
# Build the label set once instead of twice as the original did.
classes = {label: idx for idx, label in enumerate(sorted(set(train_labels)))}
train_targets = np.array([classes[x] for x in train_labels])
# NOTE(review): assumes every test label also occurs in train -- a label
# seen only at test time would raise KeyError here. Confirm upstream split.
test_targets = np.array([classes[x] for x in test_labels])
### BUILD PIPELINE & PARAM GRIDS ###
# Multinomial naive Bayes over raw token counts.
pipeline_mnb = Pipeline(steps=[
    ("vec", CountVectorizer()),
    ("clf", MultinomialNB()),
])
# Search space: vectorizer options crossed with NB smoothing/prior settings.
# Keys double-underscore into the pipeline steps, so the kwarg form of
# dict() spells the exact same grid.
params_mnb = dict(
    vec__ngram_range=((1, 1), (1, 3), (1, 5)),
    vec__tokenizer=(None, tokenize),
    vec__lowercase=(True, False),
    vec__analyzer=("word", "char_wb"),
    clf__alpha=(0.0, 0.5, 1.0),
    clf__fit_prior=(True, False),
)
# Bernoulli naive Bayes; counts are binarized at 0.0 (presence/absence).
pipeline_bnb = Pipeline(steps=[
    ("vec", CountVectorizer()),
    ("clf", BernoulliNB(binarize=0.0)),
])
# Same vectorizer sweep as the multinomial grid, same NB hyperparameters.
params_bnb = dict(
    vec__ngram_range=((1, 1), (1, 3), (1, 5)),
    vec__tokenizer=(None, tokenize),
    vec__lowercase=(True, False),
    vec__analyzer=("word", "char_wb"),
    clf__alpha=(0.0, 0.5, 1.0),
    clf__fit_prior=(True, False),
)
# Logistic regression over the same count features.
pipeline_lr = Pipeline(steps=[
    ("vec", CountVectorizer()),
    ("clf", LogisticRegression()),
])
# Sweep regularization type/strength and class weighting alongside the
# vectorizer options ("auto" is this sklearn version's balanced weighting).
params_lr = dict(
    vec__ngram_range=((1, 1), (1, 3), (1, 5)),
    vec__tokenizer=(None, tokenize),
    vec__lowercase=(True, False),
    vec__analyzer=("word", "char_wb"),
    clf__penalty=("l1", "l2"),
    clf__C=(0.1, 1.0, 10.0),
    clf__class_weight=(None, "auto"),
)
# Linear models fit by stochastic gradient descent; the loss choice selects
# between logistic regression ("log"), linear SVM ("hinge"), and a smoothed
# hinge ("modified_huber").
pipeline_sgd = Pipeline(steps=[
    ("vec", CountVectorizer()),
    ("clf", SGDClassifier()),
])
# NOTE(review): clf__warm_start doubles the grid, and within a fresh CV fit
# per candidate it should not change results -- consider dropping it.
params_sgd = dict(
    vec__ngram_range=((1, 1), (1, 3), (1, 5)),
    vec__tokenizer=(None, tokenize),
    vec__lowercase=(True, False),
    vec__analyzer=("word", "char_wb"),
    clf__loss=("log", "hinge", "modified_huber"),
    clf__penalty=("l1", "l2", "elasticnet"),
    clf__alpha=(0.000001, 0.0001, 0.01, 1.0),
    clf__class_weight=(None, "auto"),
    clf__warm_start=(True, False),
)
# Linear SVM. The penalty/loss grid must be restricted to combinations
# liblinear actually supports: penalty="l1" with loss="l1" (hinge) raises
# ValueError, and penalty="l1" requires the primal formulation
# (dual=False). The original full cross product made GridSearchCV blow up
# mid-search, so split the space into valid sub-grids -- GridSearchCV
# accepts a list of param dicts and searches their union.
pipeline_svc = Pipeline([
    ("vec", CountVectorizer()),
    ("clf", LinearSVC())])
# Vectorizer sweep shared by both sub-grids below.
_vec_grid_svc = {
    "vec__ngram_range": ((1, 1), (1, 3), (1, 5)),
    "vec__tokenizer": (None, tokenize),
    "vec__lowercase": (True, False),
    "vec__analyzer": ("word", "char_wb"),
}
params_svc = [
    # l2 penalty works with either loss (dual default is fine).
    dict(_vec_grid_svc,
         clf__C=(0.1, 1.0, 10.0),
         clf__class_weight=(None, "auto"),
         clf__penalty=("l2",),
         clf__loss=("l1", "l2")),
    # l1 penalty: only squared-hinge loss, and must solve the primal.
    dict(_vec_grid_svc,
         clf__C=(0.1, 1.0, 10.0),
         clf__class_weight=(None, "auto"),
         clf__penalty=("l1",),
         clf__loss=("l2",),
         clf__dual=(False,)),
]
| if __name__ == "__main__": | |
| for model, params in ((pipeline_mnb, params_mnb), | |
| (pipeline_bnb, params_bnb), | |
| (pipeline_lr, params_lr), | |
| (pipeline_sgd, params_sgd), | |
| (pipeline_svc, params_svc)): | |
| print "=" * 75 | |
| print "Pipeline: %s" % " ".join([name for name, _ in model.steps]) | |
| grid_search = GridSearchCV(model, params, n_jobs=1, verbose=5) | |
| ### WHY IS THE MULTIPROC VERSION OF THIS THROWING ERRORS?? | |
| #grid_search = GridSearchCV(model, params, n_jobs=2, verbose=5) | |
| tick = time() | |
| grid_search.fit(train_tweets, train_targets) | |
| print "Done in %0.3fs" % (time() - tick) | |
| print "-" * 75 | |
| print "Best score: %0.3f" % grid_search.best_score_ | |
| print "Best params:" | |
| best_params = grid_search.best_estimator_.get_params() | |
| for param_name in sorted(best_params.keys()): | |
| print "\t%s: %r" % (param_name, best_params[param_name]) |
Author
btw you can combine char n-grams and word n-grams using FeatureStacker. That makes the grid even bigger, though ;)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
LoadData and PrepareData are custom functions to parse some proprietary JSON (pulling out the values of interest and splitting out train/test sets).