Created
November 9, 2018 12:30
-
-
Save ruliana/a909dec9033887f55e353132e50f648e to your computer and use it in GitHub Desktop.
First part of predict vs predict_proba for SVC in Scikit-Learn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
from itertools import count
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
try:
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18 only ships the old module
    from sklearn.cross_validation import train_test_split
# Helpers | |
def variableize(text):
    """Convert a column header into a snake_case Python identifier.

    Steps:
      1. Insert "_" where a lowercase letter is followed by an uppercase
         letter and a non-uppercase character ("EstimatedSalary" ->
         "Estimated_Salary").
      2. Lower-case, collapse every run of characters outside a-z (spaces,
         digits, punctuation, underscores, non-ASCII letters) into one "_",
         and trim leading/trailing "_".
      3. Merge each run of consecutive single-letter words into one word, so
         acronyms written with separators stay together ("U.S.A." -> "usa").

    Args:
        text: the original header, e.g. "EstimatedSalary" or "User ID".

    Returns:
        The snake_case name; "" when *text* contains no a-z/A-Z letters.

    >>> variableize('EstimatedSalary')
    'estimated_salary'
    >>> variableize('U.S.A. Sales')
    'usa_sales'
    """
    from itertools import groupby

    # add "_" between Camel Case
    rslt = re.sub(r'([a-z])([A-Z][^A-Z])', r'\1_\2', text)
    # replace non letters by "_"
    rslt = re.sub(r'[^a-z]+', '_', rslt.lower()).strip('_')
    # Join runs of solo letters (acronyms).  The former single regex only
    # merged pairs: on "u_s_a" each match consumed the "_" the next letter
    # needed, giving "us_a" instead of "usa".
    words = []
    for is_solo, run in groupby(rslt.split('_'), key=lambda word: len(word) == 1):
        if is_solo:
            words.append(''.join(run))
        else:
            words.extend(run)
    return '_'.join(words)
#%% Load and fix column names
# This dataset comes from:
# http://www.superdatascience.com/wp-content/uploads/2017/02/SVM.zip
data = pd.read_csv('Social_Network_Ads.csv',
                   index_col='User ID',
                   usecols=['User ID', 'Age', 'EstimatedSalary', 'Purchased'])
# Fix headers: e.g. 'EstimatedSalary' -> 'estimated_salary'
data.columns = [variableize(col) for col in data.columns]
# Select by name, not position: pandas ignores the order given in `usecols`
# and keeps the CSV's own column order, while the plotting grid built in
# visualize() assumes the features are exactly (age, estimated_salary).
features = data[['age', 'estimated_salary']]
target = data['purchased']
#%% Train and Test
# Hold out 25% of the rows for testing; random_state fixes the shuffle.
features_train, features_test, target_train, target_test = train_test_split(
    features, target, test_size=0.25, random_state=0)
#%% Fit Train Test Split | |
class PredictorResults:
    """A fitted model bundled with the scaler its training data went through.

    Every method runs the raw *features* through the stored scaler before
    delegating to the model, so callers pass unscaled data.
    """

    def __init__(self, scaler, model):
        # scaler: anything with .transform(features) (a StandardScaler here)
        # model: the fitted estimator (an SVC with probability=True here)
        self.scaler, self.model = scaler, model

    def _scaled(self, features):
        """Return *features* transformed by the stored scaler."""
        return self.scaler.transform(features)

    def predict_proba(self, features):
        """Return the probability column at index 1 of model.predict_proba.

        With 0/1 labels that column belongs to class 1.
        """
        class_probabilities = self.model.predict_proba(self._scaled(features))
        return class_probabilities[:, 1]

    def predict_by_proba(self, features):
        """Return a boolean array: True where predict_proba(features) >= 0.5."""
        return self.predict_proba(features) >= 0.5

    def decision_function(self, features):
        """Return model.decision_function evaluated on the scaled features."""
        return self.model.decision_function(self._scaled(features))

    def predict(self, features):
        """Return model.predict labels for the scaled features.

        For SVC these can disagree with predict_by_proba(): scikit-learn
        documents that its cross-validated probability estimates may be
        inconsistent with predict().
        """
        return self.model.predict(self._scaled(features))
class Predictor:
    """Factory for RBF-kernel SVC models that share one feature scaler.

    The StandardScaler is fitted once on *features* at construction time.
    Each call to fit() then trains a fresh SVC on the scaled features for
    the given regularization parameter C.
    """

    def __init__(self, features, target, random_state=None):
        # Raw training data, kept so fit() can be called repeatedly.
        self.features = features
        self.target = target
        # Passed to every SVC (scikit-learn uses it to seed the shuffling
        # behind the probability estimates).
        self.random_state = random_state
        # Fit the scaler once; every model returned by fit() reuses it.
        scaler = StandardScaler()
        scaler.fit(features)
        self.scaler = scaler

    def fit(self, C):
        """Train an RBF SVC with regularization *C* and wrap the result.

        probability=True is required so that the returned object's
        predict_proba / predict_by_proba work.

        Returns:
            PredictorResults pairing the shared scaler with the new model.
        """
        svc = SVC(C=C, kernel="rbf", probability=True,
                  random_state=self.random_state)
        scaled_features = self.scaler.transform(self.features)
        svc.fit(scaled_features, self.target)
        return PredictorResults(self.scaler, svc)
#%% Visualize | |
def visualize(contour_resolution, features, target):
    """Build a function that draws side-by-side decision-surface charts.

    Args:
        contour_resolution: number of grid points along each axis of the
            background contour (the grid has contour_resolution**2 points).
        features: DataFrame with ``age`` and ``estimated_salary`` columns.
            Their min/max bound the plotted area, and each row is drawn as a
            scatter point.
        target: labels used to colour the scatter points.

    Returns:
        ``plot(*predicts)``. Each ``predict`` is a ``(title, function)`` pair,
        where ``function`` maps a DataFrame of grid points (columns ``age``,
        ``estimated_salary``) to one value per row. ``plot`` returns a Figure
        with one chart per pair.
    """
    # The grid depends only on the data, so build it once here rather than on
    # every plot() call (the animation script calls plot() hundreds of times).
    x1_axis = np.linspace(features.age.min(), features.age.max(), contour_resolution)
    x2_axis = np.linspace(features.estimated_salary.min(),
                          features.estimated_salary.max(), contour_resolution)
    x1_mesh, x2_mesh = np.meshgrid(x1_axis, x2_axis)
    # One row per mesh point in row-major (ravel) order, so the predictions
    # can be reshaped straight back onto the mesh.
    features_grid = pd.DataFrame({'age': x1_mesh.ravel(),
                                  'estimated_salary': x2_mesh.ravel()})

    def plot(*predicts):
        # squeeze=False always returns a 2-D array of Axes, so a single
        # (title, function) pair also works (subplots(1, 1) would otherwise
        # return a bare Axes that zip() cannot iterate).
        fig, charts = plt.subplots(1, len(predicts),
                                   figsize=(4 * len(predicts), 4),
                                   sharex=True, sharey=True,
                                   squeeze=False)
        for (name, predict), chart in zip(predicts, charts[0]):
            y = np.asarray(predict(features_grid)).reshape(x1_mesh.shape)
            # Plotting prediction
            chart.contourf(x1_mesh, x2_mesh, y, cmap=plt.get_cmap('RdBu'))
            # x/y as keywords: seaborn >= 0.12 rejects them positionally.
            # NOTE(review): in seaborn `size` is a semantic (grouping)
            # variable, not a marker size; if smaller markers were intended,
            # `s=` is probably what was meant — confirm.
            sns.scatterplot(x=features.age, y=features.estimated_salary,
                            hue=target,
                            ax=chart, palette='RdBu',
                            legend=False,
                            size=2)
            chart.set_title(name)
            chart.set_xlabel('Age')
            chart.set_ylabel('Salary')
        # Lay out once, after every chart has its title and labels.
        fig.tight_layout()
        return fig

    return plot
#%% Plot a sequence of images to create an animated gif later
# Software for animated GIF edition:
# https://www.fossmint.com/create-animated-gifs-using-giftedmotion-on-linux/
plot = visualize(500, features_train, target_train)
predictor = Predictor(features_train, target_train, random_state=42)

# One entry per animation: (output directory, PredictorResults method for the
# left chart, title of the right chart, PredictorResults method for the right
# chart). The left chart is always titled with the current C.
ANIMATIONS = (
    ('animation1', 'predict', 'SVC Decision Function', 'decision_function'),
    ('animation2', 'predict_by_proba', 'SVC Predict Proba', 'predict_proba'),
    ('animation3', 'predict', 'SVC Predict by Probability ≥ 0.5', 'predict_by_proba'),
)
# savefig() does not create missing directories.
for directory, *_ in ANIMATIONS:
    os.makedirs(directory, exist_ok=True)

was_interactive = plt.isinteractive()
plt.ioff()  # frames are only saved to disk, never shown
try:
    # 200 values of C, evenly spaced on a log scale from 1e-5 to 1e5.
    # Frames are numbered from 1 (image001.png, image002.png, ...).
    for index, C in enumerate(np.logspace(-5, 5, 200), start=1):
        fitted = predictor.fit(C)
        left_title = 'SVC Predict C={:.1E}'.format(C)
        for directory, left_method, right_title, right_method in ANIMATIONS:
            fig = plot((left_title, getattr(fitted, left_method)),
                       (right_title, getattr(fitted, right_method)))
            fig.savefig(os.path.join(directory, 'image{:03d}.png'.format(index)))
            # Close each figure so hundreds of frames don't pile up in memory.
            plt.close(fig)
finally:
    # Turn interactive mode back on only if it was on before this cell ran,
    # even when fitting or saving raised.
    if was_interactive:
        plt.ion()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment