Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
"""
Script for comparing spam classification with a bag-of-words model constructed with and without hashing. You'll need to download a copy of the dataset from http://plg.uwaterloo.ca/~gvcormac/treccorpus07/about.html and extract it so that the index file is located at data/trec07p/full/index.
Copyright 2016 Ronald J. Nowling
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
from itertools import islice
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import HashingVectorizer, TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import roc_auc_score
# Root directory of the extracted corpus; the message index is read from
# DATA_DIR + "/trec07p/full/index".
DATA_DIR = "data"
# Directory the resulting plot is saved into.
FIGURES_DIR = "figures"
def _parse_message(message):
    """Extract the recipient, sender and markup-free body text of an email.

    For a multipart message, the first ``text/html`` or ``text/plain`` part
    that is not an attachment supplies the body; when no such part exists the
    body is empty. For a non-multipart message the whole decoded payload is
    used. The body is then passed through BeautifulSoup to strip any HTML.

    Args:
        message: A parsed email message object (as produced by
            ``email.parser.Parser``).

    Returns:
        A ``(to, from, text)`` tuple holding the ``To`` header, the ``From``
        header and the body text with markup removed.
    """
    from bs4 import BeautifulSoup

    # "text/plain" is the registered MIME type for plain text; the previous
    # "text/txt" never matched, so plain-text parts were silently dropped.
    text_types = ("text/html", "text/plain")

    body = ""
    if message.is_multipart():
        for part in message.walk():
            # A missing header stringifies to "None", which is never
            # mistaken for an attachment.
            disposition = str(part.get("Content-Disposition"))
            if "attachment" in disposition:
                # skip any attachments
                continue
            if part.get_content_type() in text_types:
                body = part.get_payload(decode=True)
                break
    else:
        # not multipart - i.e. a single payload with no attachments
        body = message.get_payload(decode=True)

    text = BeautifulSoup(body, "html.parser").get_text()
    return message["To"], message["From"], text
def stream_email(data_dir):
    """Lazily yield the labelled emails of the trec07p corpus in index order.

    The index file ``<data_dir>/trec07p/full/index`` is read line by line.
    Each non-blank line holds a category (``ham`` or ``spam``) followed by the
    path of the message file; that file is parsed and handed to
    ``_parse_message``.

    Args:
        data_dir: Directory that contains the extracted ``trec07p`` folder.

    Yields:
        ``(label, to, from_, body)`` tuples, where ``label`` is 0 for ham and
        1 for spam and the remaining fields come from ``_parse_message``.

    Raises:
        ValueError: If an index line names a category other than ``ham`` or
            ``spam`` (previously such a line reused the preceding label), or
            does not consist of exactly two whitespace-separated fields.
    """
    from email.parser import Parser

    labels = {"ham": 0, "spam": 1}
    email_parser = Parser()
    corpus_dir = data_dir + "/trec07p"

    with open(corpus_dir + "/full/index") as index_fl:
        for ln in index_fl:
            fields = ln.split()
            if not fields:
                # tolerate blank lines (e.g. a trailing newline) in the index
                continue
            category, email_fl_suffix = fields
            if category not in labels:
                raise ValueError(
                    "unknown category %r in index line %r" % (category, ln)
                )

            # strip .. prefix from path so it is rooted at the corpus directory
            email_flname = corpus_dir + email_fl_suffix[2:]
            with open(email_flname) as email_fl:
                message = email_parser.parse(email_fl)

            to, from_, body = _parse_message(message)
            yield (labels[category], to, from_, body)
if __name__ == "__main__":
    # Experiment driver: split the email stream into a training and a testing
    # set, fit a logistic-regression baseline on an exact bag-of-words
    # vocabulary, repeat with hashed features of increasing width, and plot
    # AUC and the number of non-zero weights against the hash width.
    import os

    training_size = int(75419. * 0.75)  # from Attenberg paper

    training_bodies = []
    training_labels = []
    testing_bodies = []
    testing_labels = []

    counts = defaultdict(int)
    count = 0  # stays defined even if the stream yields nothing
    next_output = 1
    for idx, (label, to, from_, body) in enumerate(stream_email(DATA_DIR)):
        # The first `training_size` messages train; the remainder test.
        if idx < training_size:
            training_bodies.append(body)
            training_labels.append(label)
        else:
            testing_bodies.append(body)
            testing_labels.append(label)
        counts[label] += 1
        count = idx + 1
        # Report progress at message 1, 2, 4, 8, ... to keep output short.
        # Single-argument print(...) behaves the same on Python 2 and 3.
        if count == next_output:
            print("%s %s" % (count, counts))
            next_output *= 2
    print("%s %s" % (count, counts))

    # Baseline: binary term presence over the exact vocabulary (no idf
    # weighting, no normalisation).
    # NOTE(review): recent scikit-learn releases spell this loss "log_loss" -
    # confirm against the installed version.
    tfidf_vectorizer = TfidfVectorizer(binary=True, norm=None, use_idf=False)
    tfidf_lr = SGDClassifier(loss="log", penalty="l2")
    tfidf_training_features = tfidf_vectorizer.fit_transform(training_bodies)
    n_tfidf_features = tfidf_training_features.shape[1]
    tfidf_lr.fit(tfidf_training_features, training_labels)
    tfidf_testing_features = tfidf_vectorizer.transform(testing_bodies)
    tfidf_pred_probs = tfidf_lr.predict_proba(tfidf_testing_features)
    tfidf_auc = roc_auc_score(testing_labels, tfidf_pred_probs[:, 1])
    print("tfidf auc %s n_features %s" % (tfidf_auc, n_tfidf_features))

    # Hashed features: same classifier, 2**8 up to 2**24 hash buckets.
    aucs = []
    nzs = []
    bit_range = list(range(8, 25))
    for n_bits in bit_range:
        lr = SGDClassifier(loss="log", penalty="l2")
        hashing_vectorizer = HashingVectorizer(
            n_features=2 ** n_bits, binary=True, norm=None
        )
        hashed_training_features = hashing_vectorizer.transform(training_bodies)
        lr.fit(hashed_training_features, training_labels)
        hashed_testing_features = hashing_vectorizer.transform(testing_bodies)
        pred_probs = lr.predict_proba(hashed_testing_features)
        aucs.append(roc_auc_score(testing_labels, pred_probs[:, 1]))
        nzs.append((lr.coef_ != 0).sum())
        print("%s %s" % (n_bits, aucs[-1]))

    # Left axis (cyan): AUC; the dashed line marks the unhashed baseline.
    fig, ax1 = plt.subplots()
    ax1.plot(bit_range, aucs, 'c-')
    ax1.plot(bit_range, [tfidf_auc] * len(bit_range), 'c--', label="Tfidf")
    ax1.set_xlabel('Hashed Features (log_2)', fontsize=16)
    # Make the y-axis label and tick labels match the line color.
    ax1.set_ylabel('AUC', color='c', fontsize=16)
    for tl in ax1.get_yticklabels():
        tl.set_color('c')

    # Right axis (black): non-zero weights; the dashed line marks the
    # vocabulary size of the unhashed baseline.
    ax2 = ax1.twinx()
    ax2.plot(bit_range, nzs, 'k-')
    ax2.plot(bit_range, [n_tfidf_features] * len(bit_range), 'k--')
    ax2.set_ylabel('Non-zero Weights', color='k', fontsize=16)
    for tl in ax2.get_yticklabels():
        tl.set_color('k')

    fig.subplots_adjust(right=0.8)
    # Create the output directory up front so the run does not fail at the
    # very last step.
    if not os.path.isdir(FIGURES_DIR):
        os.makedirs(FIGURES_DIR)
    # The resolution keyword is lowercase `dpi`; `DPI` is not a savefig option.
    fig.savefig(FIGURES_DIR + "/hashed_features_auc_weights.png", dpi=200)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment