Skip to content

Instantly share code, notes, and snippets.

@mmmayo13
Last active March 16, 2020 03:12
Show Gist options
  • Save mmmayo13/110201055ed7cf848bf4efabd94e7b9f to your computer and use it in GitHub Desktop.
Save mmmayo13/110201055ed7cf848bf4efabd94e7b9f to your computer and use it in GitHub Desktop.
# joblib was vendored as sklearn.externals.joblib until scikit-learn 0.23
# removed it; prefer the standalone package and fall back for old installs.
try:
    import joblib
except ImportError:
    from sklearn.externals import joblib

from sklearn import svm
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
# Train three scale -> PCA -> classifier pipelines on the iris data set,
# report each one's accuracy on the held-out split, and save the most
# accurate pipeline to 'best_pipeline.pkl'.
#
# NOTE(review): the winning pipeline is chosen by its accuracy on the test
# split, so that accuracy is an optimistic estimate of real-world performance.
# Select on a validation split or via cross-validation if an unbiased figure
# is required.

# Load and split the data (20% held out for testing, fixed seed)
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)

# Construct some pipelines: standardize, keep 2 principal components, classify
pipe_lr = Pipeline([('scl', StandardScaler()),
                    ('pca', PCA(n_components=2)),
                    ('clf', LogisticRegression(random_state=42))])

pipe_svm = Pipeline([('scl', StandardScaler()),
                     ('pca', PCA(n_components=2)),
                     ('clf', svm.SVC(random_state=42))])

pipe_dt = Pipeline([('scl', StandardScaler()),
                    ('pca', PCA(n_components=2)),
                    ('clf', tree.DecisionTreeClassifier(random_state=42))])

# List of pipelines for ease of iteration
pipelines = [pipe_lr, pipe_svm, pipe_dt]

# Dictionary of pipelines and classifier types for ease of reference
# (keys are the positions of the pipelines in the list above)
pipe_dict = {0: 'Logistic Regression', 1: 'Support Vector Machine', 2: 'Decision Tree'}

# Fit the pipelines
for pipe in pipelines:
    pipe.fit(X_train, y_train)

# Score every pipeline exactly once; the results are reused for both the
# report and the model selection below.
scores = [pipe.score(X_test, y_test) for pipe in pipelines]

# Compare accuracies
for idx, acc in enumerate(scores):
    print('%s pipeline test accuracy: %.3f' % (pipe_dict[idx], acc))

# Identify the most accurate model on test data.
# max() returns the first maximal index, so ties go to the earliest pipeline,
# and a real pipeline is always selected even if every accuracy is 0.0.
best_clf = max(range(len(scores)), key=scores.__getitem__)
best_acc = scores[best_clf]
best_pipe = pipelines[best_clf]
print('Classifier with best accuracy: %s' % pipe_dict[best_clf])

# Save pipeline to file
joblib.dump(best_pipe, 'best_pipeline.pkl', compress=1)
print('Saved %s pipeline to file' % pipe_dict[best_clf])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment