# Source: GitHub Gist c886cdc91d9fdc721301dfacac558bcf by @mmmayo13
# (last active January 30, 2021 17:17).
# The surrounding GitHub page navigation text is not part of the program.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn import svm
try:
    from sklearn.externals import joblib
except ImportError:
    # sklearn.externals.joblib was removed in scikit-learn 0.23; fall back to
    # the standalone joblib package that scikit-learn itself depends on.
    import joblib
# Load the iris dataset and split it, holding out 20% of the samples for
# testing; random_state=42 makes the split reproducible between runs.
iris = load_iris()
(X_train, X_test,
 y_train, y_test) = train_test_split(iris.data,
                                     iris.target,
                                     random_state=42,
                                     test_size=0.2)
# Construct some pipelines.
# Every pipeline standardizes the features ('scl') and ends in a classifier
# ('clf'); the *_pca variants insert a 2-component PCA step ('pca') between
# them. Each call builds fresh step instances, so no pipelines share steps.
def _build_pipeline(classifier, with_pca=False):
    """Return a Pipeline of StandardScaler [-> PCA(2)] -> *classifier*."""
    steps = [('scl', StandardScaler())]
    if with_pca:
        steps.append(('pca', PCA(n_components=2)))
    steps.append(('clf', classifier))
    return Pipeline(steps)


pipe_lr = _build_pipeline(LogisticRegression(random_state=42))
pipe_lr_pca = _build_pipeline(LogisticRegression(random_state=42),
                              with_pca=True)
pipe_rf = _build_pipeline(RandomForestClassifier(random_state=42))
pipe_rf_pca = _build_pipeline(RandomForestClassifier(random_state=42),
                              with_pca=True)
pipe_svm = _build_pipeline(svm.SVC(random_state=42))
pipe_svm_pca = _build_pipeline(svm.SVC(random_state=42), with_pca=True)
# Set grid search params.
# Keys use the Pipeline '<step>__<parameter>' convention, so 'clf__C' targets
# the 'C' parameter of the step named 'clf'.
param_range = list(range(1, 11))      # integers 1..10
param_range_fl = [1.0, 0.5, 0.1]      # C values tried for logistic regression

grid_params_lr = [{
    'clf__penalty': ['l1', 'l2'],
    'clf__C': param_range_fl,
    'clf__solver': ['liblinear'],     # one solver used for both penalties
}]
grid_params_rf = [{
    'clf__criterion': ['gini', 'entropy'],
    'clf__min_samples_leaf': param_range,
    'clf__max_depth': param_range,
    # Starts at 2 (param_range minus its leading 1); presumably because an
    # integer min_samples_split below 2 is rejected by the forest.
    'clf__min_samples_split': param_range[1:],
}]
grid_params_svm = [{
    'clf__kernel': ['linear', 'rbf'],
    'clf__C': param_range,
}]
# Construct grid searches.
# All six searches share the same settings: accuracy scoring, 10-fold
# cross-validation, and n_jobs=jobs (-1 = use all available processors).
# The logistic-regression searches get n_jobs too, for consistency with the
# random-forest and SVM searches.
jobs = -1


def _make_grid_search(estimator, param_grid, n_jobs=jobs):
    """Return an accuracy-scored, 10-fold GridSearchCV over *estimator*.

    Parameters
    ----------
    estimator : Pipeline
        The pipeline whose hyper-parameters are searched.
    param_grid : list of dict
        Grid keyed by the '<step>__<parameter>' names of *estimator*.
    n_jobs : int, optional
        Parallel jobs passed to GridSearchCV; defaults to the module-level
        ``jobs`` setting.
    """
    return GridSearchCV(estimator=estimator,
                        param_grid=param_grid,
                        scoring='accuracy',
                        cv=10,
                        n_jobs=n_jobs)


gs_lr = _make_grid_search(pipe_lr, grid_params_lr)
gs_lr_pca = _make_grid_search(pipe_lr_pca, grid_params_lr)
gs_rf = _make_grid_search(pipe_rf, grid_params_rf)
gs_rf_pca = _make_grid_search(pipe_rf_pca, grid_params_rf)
gs_svm = _make_grid_search(pipe_svm, grid_params_svm)
gs_svm_pca = _make_grid_search(pipe_svm_pca, grid_params_svm)
# Grid searches in evaluation order, plus a parallel position -> display-name
# mapping (grid_dict[i] names grids[i]) used when reporting results.
grids = [gs_lr, gs_lr_pca,
         gs_rf, gs_rf_pca,
         gs_svm, gs_svm_pca]
grid_dict = dict(enumerate([
    'Logistic Regression', 'Logistic Regression w/PCA',
    'Random Forest', 'Random Forest w/PCA',
    'Support Vector Machine', 'Support Vector Machine w/PCA',
]))
# Fit the grid search objects.
# Each search is fitted on the training split, its best hyper-parameters and
# cross-validation score are reported, and it is then scored on the held-out
# test split. The search with the highest test accuracy is remembered in
# best_gs / best_clf / best_acc.
print('Performing model optimizations...')
best_acc = 0.0
best_clf = 0      # index into grids / grid_dict of the best model so far
best_gs = None    # best fitted search so far; None until one beats best_acc
for idx, gs in enumerate(grids):
    print('\nEstimator: %s' % grid_dict[idx])
    # Runs cross-validation over the whole grid and (refit defaults to True)
    # refits the best configuration on all of X_train.
    gs.fit(X_train, y_train)
    print('Best params: %s' % gs.best_params_)
    # best_score_ is the mean cross-validated accuracy of the best params,
    # not the accuracy measured on the training data itself.
    print('Best cross-validation accuracy: %.3f' % gs.best_score_)
    # Score the refitted best estimator on the test split (computed once).
    y_pred = gs.predict(X_test)
    test_acc = accuracy_score(y_test, y_pred)
    print('Test set accuracy score for best params: %.3f ' % test_acc)
    # Strict '>' keeps the earliest model when test accuracies tie.
    if test_acc > best_acc:
        best_acc = test_acc
        best_gs = gs
        best_clf = idx
print('\nClassifier with best test set accuracy: %s' % grid_dict[best_clf])
# Persist the winning grid search (selected by the fitting loop) to disk with
# joblib, using compression level 1, and report which model was saved.
dump_file = 'best_gs_pipeline.pkl'
joblib.dump(value=best_gs, filename=dump_file, compress=1)
print('\nSaved %s grid search pipeline to file: %s'
      % (grid_dict[best_clf], dump_file))
# End of gist (GitHub page footer text is not part of the program).