Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
"""
Script for comparing Logistic Regression with L1, L2, and elastic net regularization and the liblinear, sag, and sgd optimizers. You'll need to download a copy of the dataset from http://plg.uwaterloo.ca/~gvcormac/treccorpus07/about.html .
Copyright 2016 Ronald J. Nowling
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
from itertools import islice
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.metrics import roc_curve, roc_auc_score
# Root of the input corpus (the TREC 2007 files are expected under
# DATA_DIR/trec07p) and the directory the generated plots are written to.
DATA_DIR, FIGURES_DIR = "data", "figures"
def _parse_message(message):
    """Extract the recipient, sender, and body text of a parsed email.

    For a multipart message the body is taken from the first part, in
    ``message.walk()`` order, whose content type is ``text/html`` or
    ``text/plain`` and which is not an attachment. A non-multipart message
    contributes its whole decoded payload. Either way the body is passed
    through BeautifulSoup's ``html.parser`` so any markup is stripped.

    Parameters
    ----------
    message : email.message.Message
        A message produced by ``email.parser.Parser``.

    Returns
    -------
    tuple
        ``(to, from_, text)``. ``to`` and ``from_`` are the raw ``To`` and
        ``From`` header values (``None`` when the header is absent). ``text``
        is the extracted body text, empty when no usable part exists.
    """
    from bs4 import BeautifulSoup

    body = ""
    if message.is_multipart():
        for part in message.walk():
            # str() turns a missing header into "None", which is harmless here.
            cdispo = str(part.get('Content-Disposition'))
            if 'attachment' in cdispo:
                continue  # skip any attachments
            # 'text/plain' (not 'text/txt') is the MIME type for plain text.
            # Matching it lets plain-text-only messages keep their body.
            if part.get_content_type() in ('text/html', 'text/plain'):
                body = part.get_payload(decode=True)
                break
    else:
        # not multipart - i.e. a single body with no attachments
        body = message.get_payload(decode=True)
    if body is None:
        # defensive: get_payload() can return None, and BeautifulSoup
        # needs string input
        body = ""
    return message["To"], message["From"], BeautifulSoup(body, 'html.parser').get_text()
def stream_email(data_dir):
    """Yield ``(label, to, from_, body)`` for each message in the TREC 2007 corpus.

    Messages are read in the order listed in ``<data_dir>/trec07p/full/index``.
    Each index line has the form ``<category> <path>``, where the path starts
    with a ``..`` prefix. ``label`` is 0 for "ham" and 1 for "spam".
    ``to``, ``from_`` and ``body`` come from ``_parse_message``.

    Parameters
    ----------
    data_dir : str
        Directory that contains the ``trec07p`` corpus directory.

    Raises
    ------
    ValueError
        If an index line names a category other than "ham" or "spam".
    """
    from email.parser import Parser

    label_for = {"ham": 0, "spam": 1}
    email_parser = Parser()
    index_flname = data_dir + "/trec07p/full/index"
    with open(index_flname) as index_fl:
        for ln in index_fl:
            fields = ln.split()
            if not fields:
                continue  # tolerate blank lines in the index
            category, email_fl_suffix = fields
            # Fail loudly. An unknown category would otherwise leave the
            # label unbound, or silently reuse the previous message's label.
            if category not in label_for:
                raise ValueError("unknown category %r in %s"
                                 % (category, index_flname))
            label = label_for[category]
            # strip .. prefix from path
            email_flname = data_dir + "/trec07p" + email_fl_suffix[2:]
            # NOTE(review): under Python 3, text mode decodes using the locale
            # encoding. Messages in other encodings may fail to decode; confirm.
            with open(email_flname) as email_fl:
                message = email_parser.parse(email_fl)
            to, from_, body = _parse_message(message)
            yield (label, to, from_, body)
def _collect_bodies(messages):
    """Split ``(label, to, from_, body)`` records into parallel body/label lists.

    Progress is reported as the running message count plus a per-label tally.
    It is printed each time the count reaches a power of two, and once more
    at the end.

    Parameters
    ----------
    messages : iterable
        ``(label, to, from_, body)`` tuples, e.g. from ``stream_email``.

    Returns
    -------
    tuple of (list, list)
        The message bodies and their labels, in input order.
    """
    counts = defaultdict(int)
    bodies = []
    labels = []
    next_output = 1
    count = 0  # keeps the final report valid when ``messages`` is empty
    for count, (label, _to, _from, body) in enumerate(messages, 1):
        bodies.append(body)
        labels.append(label)
        counts[label] += 1
        if count == next_output:
            print("%d %s" % (count, counts))
            next_output *= 2
    print("%d %s" % (count, counts))
    return bodies, labels


if __name__ == "__main__":
    training_size = int(75419. * 0.75)  # from Attenberg paper

    vectorizer = TfidfVectorizer(binary=True, norm=None, use_idf=False)

    # Entries are (legend label, estimator, scatter colour), in plotting order.
    # NOTE(review): scikit-learn 1.1 renamed SGDClassifier's loss="log" to
    # "log_loss", and 1.3 removed the old name. Update this if you run a
    # newer release.
    models = [
        ("Liblinear L1", LogisticRegression(penalty="l1", solver="liblinear"), "r"),
        ("sag L2", LogisticRegression(solver="sag"), "g"),
        ("SGD L1", SGDClassifier(loss="log", penalty="l1"), "b"),
        ("SGD L2", SGDClassifier(loss="log", penalty="l2"), "m"),
        ("SGD EN", SGDClassifier(loss="log", penalty="elasticnet"), "c"),
    ]

    stream = stream_email(DATA_DIR)

    # islice consumes exactly `training_size` messages. The second
    # _collect_bodies call therefore resumes the same generator and receives
    # only the remaining messages, which form the held-out set.
    bodies, training_labels = _collect_bodies(islice(stream, training_size))
    print("Vectorizing")
    training_features = vectorizer.fit_transform(bodies)
    print("Training")
    for _name, model, _color in models:
        model.fit(training_features, training_labels)

    bodies, true_labels = _collect_bodies(stream)
    print("Transforming")
    prediction_features = vectorizer.transform(bodies)
    print("Predicting")
    # Column 1 of predict_proba is the score for label 1 (spam). The labels
    # are 0/1, and scikit-learn orders classes_ ascending.
    scores = [model.predict_proba(prediction_features)[:, 1]
              for _name, model, _color in models]

    # ROC curves, zoomed to the low false-positive region.
    # plt.hold was removed in Matplotlib 3.0; repeated plot calls already
    # draw onto the same axes.
    plt.clf()
    for (name, _model, _color), score in zip(models, scores):
        fpr, tpr, _ = roc_curve(true_labels, score, pos_label=1)
        plt.plot(fpr, tpr, label=name)
    plt.xlabel("False Positive Rate", fontsize=16)
    plt.ylabel("True Positive Rate", fontsize=16)
    plt.xlim([0.0, 0.1])
    plt.legend(loc="lower right")
    # savefig's resolution keyword is lowercase ``dpi``; ``DPI`` is not
    # a recognised keyword.
    plt.savefig(FIGURES_DIR + "/roc_curve.png", dpi=300)

    n_features = prediction_features.shape[1]
    print("features %d" % n_features)

    # Plot AUC against the fraction of non-zero coefficients for each model.
    plt.clf()
    for (name, model, color), score in zip(models, scores):
        auc = roc_auc_score(true_labels, score)
        sparsity = float(n_features - (model.coef_ == 0).sum()) / n_features
        print("%s auc %s sparsity %s" % (name, auc, sparsity))
        plt.scatter(sparsity, auc, color=color, label=name)
    plt.xscale("log")
    plt.legend(loc="lower right")
    plt.ylabel("Area Under the Curve", fontsize=16)
    plt.xlabel("Sparsity (% non-zeros)", fontsize=16)
    plt.xlim([0.001, 1.0])
    plt.ylim([0.975, 1.0])
    plt.savefig(FIGURES_DIR + "/sparsity_auc.png", dpi=300)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment