Skip to content

Instantly share code, notes, and snippets.

@brianjp93
Created December 1, 2015 08:56
Show Gist options
  • Save brianjp93/6c936a729c147d2e8c24 to your computer and use it in GitHub Desktop.
Save brianjp93/6c936a729c147d2e8c24 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using an SVM to Classify Emails\n",
"### Brian Perrett"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\"\"\"\n",
"spamfilter.py\n",
"Brian Perrett\n",
"11/30/15\n",
"using svm to classify spam vs not spam\n",
"using data from -> http://csmining.org/index.php/spam-email-datasets-.html\n",
"\"\"\"\n",
"from __future__ import division\n",
"import email\n",
"from sklearn import svm\n",
"\n",
"\n",
"class SpamFilter:\n",
"    \"\"\"\n",
"    Turns raw email files into numeric feature vectors for the SVM.\n",
"\n",
"    Each feature is a small method that takes the subject or the body\n",
"    text; getEmailData combines them into one list per email.\n",
"    \"\"\"\n",
"\n",
"    # Trimmed from both ends of a token before it is compared with a\n",
"    # word, so that \"you,\" and \"(you)\" both count as \"you\".\n",
"    _PUNCTUATION = '.,;:!?()[]{}<>*-\"'\n",
"    _LETTERS = \"abcdefghijklmnopqrstuvwxyz\"\n",
"\n",
"    def __init__(self):\n",
"        self.label_file = \"SPAMTrain.label\"\n",
"        self.word_file = \"words.txt\"\n",
"        # Email file names per class: sets for membership tests, lists\n",
"        # to keep the order of the label file.\n",
"        self.spam = set()\n",
"        self.spam_list = []\n",
"        self.ham = set()\n",
"        self.ham_list = []\n",
"        # {filename: label string}, \"0\" = spam and \"1\" = ham\n",
"        self.spam_ham = {}\n",
"        # every labelled filename, in label-file order\n",
"        self.spam_ham_list = []\n",
"        self.words = self.getWords()\n",
"\n",
"    def getWords(self):\n",
"        \"\"\"\n",
"        Loads the dictionary in self.word_file (one word per line)\n",
"        into a set.  Blank lines are skipped.\n",
"        \"\"\"\n",
"        word_set = set()\n",
"        with open(self.word_file, \"rb\") as f:\n",
"            for line in f:\n",
"                line = line.strip()\n",
"                if line:\n",
"                    word_set.add(line)\n",
"        return word_set\n",
"\n",
"    def getEmailData(self, filename):\n",
"        \"\"\"\n",
"        Retrieves values for data vector for the\n",
"        support vector machine.\n",
"\n",
"        The file is read and parsed once; subject and body both come\n",
"        from that parsed message.  Returns a list of 27 numbers in the\n",
"        order given by the index comments below.\n",
"        \"\"\"\n",
"        msg = self._parse(filename)\n",
"        body = self._bodyOf(msg)\n",
"        subject = self._subjectOf(msg)\n",
"\n",
"        free_in_sub = self.freeInSub(subject)\n",
"        try_in_sub = self.tryInSub(subject)\n",
"        free_in_body = self.freeInBody(body)\n",
"        try_in_body = self.tryInBody(body)\n",
"\n",
"        # exInSubject / exInBody are defined further down but are not\n",
"        # part of the vector.\n",
"        return [\n",
"            self.countBody(body),                         # 0\n",
"            self.countSubject(subject),                   # 1\n",
"            free_in_sub,                                  # 2\n",
"            try_in_sub,                                   # 3\n",
"            1 if free_in_sub and try_in_sub else 0,       # 4\n",
"            free_in_body,                                 # 5\n",
"            try_in_body,                                  # 6\n",
"            1 if free_in_body and try_in_body else 0,     # 7\n",
"            self.moneyInSub(subject),                     # 8\n",
"            self.moneyInBody(body),                       # 9\n",
"            self.percentInSub(subject),                   # 10\n",
"            self.percentInBody(body),                     # 11\n",
"            self.sexInSubject(subject),                   # 12\n",
"            self.sexInBody(body),                         # 13\n",
"            self.moneyBack(body),                         # 14\n",
"            self.exoticInBody(body),                      # 15\n",
"            self.exoticInSubject(subject),                # 16\n",
"            self.subscribeInBody(body),                   # 17\n",
"            self.divInBody(body),                         # 18\n",
"            self.percentMisspelledSubject(subject),       # 19\n",
"            self.percentMisspelledBody(body),             # 20\n",
"            self.bodyInSubject(subject),                  # 21\n",
"            self.bodyInBody(body),                        # 22\n",
"            self.percentCapitalSubject(subject),          # 23\n",
"            self.percentCapitalBody(body),                # 24\n",
"            self.percentIInBody(body),                    # 25\n",
"            self.percentYouInBody(body),                  # 26\n",
"        ]\n",
"\n",
"    def getHamSpam(self):\n",
"        \"\"\"\n",
"        Reads self.label_file and records which emails are spam and\n",
"        which are ham.  Every line is \"<label> <filename>\" where the\n",
"        label \"0\" means spam and \"1\" means ham.\n",
"\n",
"        Fills self.spam / self.ham (sets), self.spam_list /\n",
"        self.ham_list (label-file order), self.spam_ham\n",
"        ({filename: label string}) and self.spam_ham_list (all\n",
"        filenames in label-file order).  Returns nothing.\n",
"        \"\"\"\n",
"        with open(self.label_file, \"rb\") as f:\n",
"            for line in f:\n",
"                fields = line.strip().split()\n",
"                if len(fields) < 2:\n",
"                    # blank or malformed line\n",
"                    continue\n",
"                label, name = fields[0], fields[1]\n",
"                if label == \"0\":\n",
"                    self.spam.add(name)\n",
"                    self.spam_list.append(name)\n",
"                elif label == \"1\":\n",
"                    self.ham.add(name)\n",
"                    self.ham_list.append(name)\n",
"                self.spam_ham[name] = label\n",
"                self.spam_ham_list.append(name)\n",
"\n",
"    def _parse(self, filename):\n",
"        \"\"\"\n",
"        returns the parsed email message stored in filename\n",
"        \"\"\"\n",
"        with open(filename, \"rb\") as f:\n",
"            return email.message_from_string(f.read())\n",
"\n",
"    def _bodyOf(self, msg):\n",
"        \"\"\"\n",
"        returns the text of every non-multipart part of msg, joined\n",
"        in message order.  walk() also descends into nested multipart\n",
"        containers, which a single loop over get_payload() skips.\n",
"        \"\"\"\n",
"        pieces = []\n",
"        for part in msg.walk():\n",
"            if part.is_multipart():\n",
"                continue\n",
"            payload = part.get_payload()\n",
"            if isinstance(payload, basestring):\n",
"                pieces.append(payload)\n",
"        return \"\".join(pieces)\n",
"\n",
"    def _subjectOf(self, msg):\n",
"        \"\"\"\n",
"        returns the Subject header of msg, or \"\" when there is none\n",
"        \"\"\"\n",
"        subject = msg.get(\"Subject\")\n",
"        if subject is None:\n",
"            return \"\"\n",
"        return subject\n",
"\n",
"    def getBody(self, filename):\n",
"        \"\"\"\n",
"        returns body(string)\n",
"        \"\"\"\n",
"        return self._bodyOf(self._parse(filename))\n",
"\n",
"    def getSubject(self, filename):\n",
"        \"\"\"\n",
"        returns subject(string)\n",
"        \"\"\"\n",
"        return self._subjectOf(self._parse(filename))\n",
"\n",
"    def _hasWord(self, text, word):\n",
"        \"\"\"\n",
"        1 if word is one of the whitespace separated, lower-cased\n",
"        tokens of text, else 0\n",
"        \"\"\"\n",
"        if word in text.lower().split():\n",
"            return 1\n",
"        return 0\n",
"\n",
"    def _hasText(self, text, fragment):\n",
"        \"\"\"\n",
"        1 if fragment occurs anywhere in the lower-cased text, else 0\n",
"        \"\"\"\n",
"        if fragment in text.lower():\n",
"            return 1\n",
"        return 0\n",
"\n",
"    def _wordFraction(self, text, target):\n",
"        \"\"\"\n",
"        Fraction of the tokens of text that equal target once case\n",
"        and surrounding punctuation are ignored.  Returns 0 for text\n",
"        without tokens (empty or whitespace only).\n",
"        \"\"\"\n",
"        tokens = text.split()\n",
"        if not tokens:\n",
"            return 0\n",
"        hits = 0\n",
"        for token in tokens:\n",
"            if token.strip(self._PUNCTUATION).lower() == target:\n",
"                hits += 1\n",
"        return hits / float(len(tokens))\n",
"\n",
"    def _percentCapital(self, text):\n",
"        \"\"\"\n",
"        Fraction of the a-z letters in text that are upper case.\n",
"        Returns 0 when text has no letters.\n",
"        \"\"\"\n",
"        total_letters = 0\n",
"        capital = 0\n",
"        for l in text:\n",
"            if l.lower() in self._LETTERS:\n",
"                total_letters += 1\n",
"                if l == l.upper():\n",
"                    capital += 1\n",
"        if total_letters == 0:\n",
"            return 0\n",
"        return capital / float(total_letters)\n",
"\n",
"    def _isKnown(self, word):\n",
"        \"\"\"\n",
"        True if word is in the dictionary as written or lower-cased\n",
"        \"\"\"\n",
"        return word in self.words or word.lower() in self.words\n",
"\n",
"    def _percentMisspelled(self, text):\n",
"        \"\"\"\n",
"        Fraction of the tokens of text that are not dictionary words.\n",
"\n",
"        A token is accepted when, after trimming surrounding\n",
"        punctuation, either the whole token or every piece between\n",
"        its apostrophes is in self.words.  Tokens that are only\n",
"        punctuation are not counted as misspelled.  Returns 0 for text\n",
"        without tokens.\n",
"        \"\"\"\n",
"        tokens = text.split()\n",
"        if not tokens:\n",
"            return 0\n",
"        wrong = 0\n",
"        for token in tokens:\n",
"            token = token.strip(self._PUNCTUATION)\n",
"            if not token or self._isKnown(token):\n",
"                continue\n",
"            for part in token.split(\"'\"):\n",
"                if part and not self._isKnown(part):\n",
"                    wrong += 1\n",
"                    break\n",
"        return wrong / float(len(tokens))\n",
"\n",
"    def countSubject(self, subject):\n",
"        return len(subject)\n",
"\n",
"    def countBody(self, body):\n",
"        return len(body)\n",
"\n",
"    def freeInSub(self, subject):\n",
"        return self._hasWord(subject, \"free\")\n",
"\n",
"    def tryInSub(self, subject):\n",
"        return self._hasWord(subject, \"try\")\n",
"\n",
"    def freeInBody(self, body):\n",
"        return self._hasWord(body, \"free\")\n",
"\n",
"    def tryInBody(self, body):\n",
"        return self._hasWord(body, \"try\")\n",
"\n",
"    def moneyInSub(self, subject):\n",
"        return self._hasText(subject, \"$\")\n",
"\n",
"    def moneyInBody(self, body):\n",
"        return self._hasText(body, \"$\")\n",
"\n",
"    def percentInSub(self, subject):\n",
"        return self._hasText(subject, \"%\")\n",
"\n",
"    def percentInBody(self, body):\n",
"        return self._hasText(body, \"%\")\n",
"\n",
"    def sexInSubject(self, subject):\n",
"        return self._hasText(subject, \"sex\")\n",
"\n",
"    def sexInBody(self, body):\n",
"        \"\"\"\n",
"        What a great method name (-.-)\n",
"        \"\"\"\n",
"        return self._hasText(body, \"sex\")\n",
"\n",
"    def exInSubject(self, subject):\n",
"        \"\"\"\n",
"        1 if the subject has more than 3 exclamation marks\n",
"        \"\"\"\n",
"        if subject.count(\"!\") > 3:\n",
"            return 1\n",
"        return 0\n",
"\n",
"    def exInBody(self, body):\n",
"        \"\"\"\n",
"        1 if the body has more than 5 exclamation marks\n",
"        \"\"\"\n",
"        if body.count(\"!\") > 5:\n",
"            return 1\n",
"        return 0\n",
"\n",
"    def moneyBack(self, body):\n",
"        return self._hasText(body, \"money back\")\n",
"\n",
"    def exoticInBody(self, body):\n",
"        # returns 1 (not True) like every other flag feature\n",
"        return self._hasText(body, \"exotic\")\n",
"\n",
"    def exoticInSubject(self, subject):\n",
"        return self._hasText(subject, \"exotic\")\n",
"\n",
"    def subscribeInBody(self, body):\n",
"        return self._hasText(body, \"subscribe\")\n",
"\n",
"    def divInBody(self, body):\n",
"        return self._hasText(body, \"</div>\")\n",
"\n",
"    def bodyInSubject(self, subject):\n",
"        return self._hasText(subject, \"body\")\n",
"\n",
"    def bodyInBody(self, body):\n",
"        return self._hasText(body, \"body\")\n",
"\n",
"    def percentIInBody(self, body):\n",
"        \"\"\"\n",
"        fraction of the body's words that are \"I\"\n",
"        \"\"\"\n",
"        return self._wordFraction(body, \"i\")\n",
"\n",
"    def percentYouInBody(self, body):\n",
"        \"\"\"\n",
"        fraction of the body's words that are \"you\".  Whole words\n",
"        only, any case, so \"your\" does not count and the result\n",
"        never exceeds 1.\n",
"        \"\"\"\n",
"        return self._wordFraction(body, \"you\")\n",
"\n",
"    def percentCapitalSubject(self, subject):\n",
"        return self._percentCapital(subject)\n",
"\n",
"    def percentCapitalBody(self, body):\n",
"        return self._percentCapital(body)\n",
"\n",
"    def percentMisspelledSubject(self, subject):\n",
"        return self._percentMisspelled(subject)\n",
"\n",
"    def percentMisspelledBody(self, body):\n",
"        return self._percentMisspelled(body)\n",
"\n",
"def getTest(split=3000, email_dir=\"TRAINING/\"):\n",
"    \"\"\"\n",
"    same as getData, but gets data that was not trained on: every\n",
"    labelled email after the first `split` entries of the label file.\n",
"\n",
"    The two length features (index 0 and 1) are returned unscaled;\n",
"    SpamTrain.test divides them by the training maxima.\n",
"\n",
"    split     - number of leading label-file entries to leave out\n",
"    email_dir - prefix put in front of every email file name\n",
"    returns (x, y): feature vectors and integer labels (0=spam, 1=ham)\n",
"    \"\"\"\n",
"    sf = SpamFilter()\n",
"    sf.getHamSpam()\n",
"    names = sf.spam_ham_list[split:]\n",
"    x = [sf.getEmailData(email_dir + em) for em in names]\n",
"    y = [int(sf.spam_ham[em]) for em in names]\n",
"    return x, y\n",
" \n",
"def getData(split=3000, email_dir=\"TRAINING/\"):\n",
"    \"\"\"\n",
"    Builds the training set from the first `split` entries of the\n",
"    label file.\n",
"\n",
"    The body length (index 0) and subject length (index 1) of every\n",
"    vector are divided by the largest value seen, and those maxima are\n",
"    returned so the same scaling can be applied to other data.\n",
"\n",
"    split     - number of leading label-file entries to use\n",
"    email_dir - prefix put in front of every email file name\n",
"    returns (x, y, max_body, max_sub)\n",
"    \"\"\"\n",
"    sf = SpamFilter()\n",
"    sf.getHamSpam()\n",
"    names = sf.spam_ham_list[:split]\n",
"    x = [sf.getEmailData(email_dir + em) for em in names]\n",
"    y = [int(sf.spam_ham[em]) for em in names]\n",
"    # Fall back to 1 when there is no data or every length is 0, so the\n",
"    # scaling below (and in SpamTrain.test) never divides by zero.\n",
"    max_body = max([d[0] for d in x] or [0]) or 1\n",
"    max_sub = max([d[1] for d in x] or [0]) or 1\n",
"    for d in x:\n",
"        d[0] = d[0] / float(max_body)\n",
"        d[1] = d[1] / float(max_sub)\n",
"    return x, y, max_body, max_sub"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class SpamTrain:\n",
"    \"\"\"\n",
"    Fits an sklearn svm.SVC on already scaled training vectors and\n",
"    scores it on unscaled test vectors.\n",
"\n",
"    x, y     - training vectors and labels\n",
"    max_body - value the body length feature (index 0) was divided by\n",
"    max_sub  - value the subject length feature (index 1) was divided by\n",
"    gamma, C - passed straight to svm.SVC\n",
"    \"\"\"\n",
"    def __init__(self, x, y, max_body, max_sub, gamma=.1, C=100):\n",
"        self.x = x\n",
"        self.y = y\n",
"        self.max_body = max_body\n",
"        self.max_sub = max_sub\n",
"        self.gamma = gamma\n",
"        self.C = C\n",
"        self.clf = self.train()\n",
"\n",
"    def train(self):\n",
"        \"\"\"\n",
"        returns a classifier fitted on self.x / self.y\n",
"        \"\"\"\n",
"        clf = svm.SVC(gamma=self.gamma, C=self.C)\n",
"        clf.fit(self.x, self.y)\n",
"        return clf\n",
"\n",
"    def test(self, x, y):\n",
"        \"\"\"\n",
"        returns the fraction of the vectors in x whose predicted label\n",
"        equals the matching entry of y (0.0 when x is empty).\n",
"\n",
"        x must hold unscaled vectors.  Scaling is done on copies, so\n",
"        the caller's data is left untouched and test() can be called\n",
"        again with the same lists.\n",
"        \"\"\"\n",
"        if len(x) == 0:\n",
"            return 0.0\n",
"        max_body = float(self.max_body or 1)\n",
"        max_sub = float(self.max_sub or 1)\n",
"        scaled = [[d[0] / max_body, d[1] / max_sub] + list(d[2:]) for d in x]\n",
"        # One call on the whole 2-D list; predict() expects one row per\n",
"        # sample rather than a single flat vector.\n",
"        predicted = self.clf.predict(scaled)\n",
"        correct = sum(1 for guess, truth in zip(predicted, y) if guess == truth)\n",
"        return correct / float(len(x))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Take a look at one email from the data set to see what the raw text looks like."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SPAM\n",
"SUBJECT\n",
"Re: Which kernel for ThinkPad 760XD ?\n",
"BODY\n",
"@Merciadri Luca\r\n",
"\r\n",
"[quote]\r\n",
"I have no answer to your question, but I am wondering... what does 'lol'\r\n",
"mean here? Is it some version of something, or did you simply put this over\r\n",
"there because you have some sense of humor and you like to make your\r\n",
"(somewhat old) config appear less depressing to the others' eyes?\r\n",
"[/quote]\r\n",
"\r\n",
"\r\n",
"... it's the hostname, but it is funny! I don't give a damn about what\r\n",
"others might think anyway, I'm just having fun with it, making some tests,\r\n",
"pushing my knowledge..\n"
]
}
],
"source": [
"sf = SpamFilter()\n",
"sf.getHamSpam()\n",
"num_examples = 1\n",
"for f in list(sf.spam_ham)[:num_examples]:\n",
"    f_name = \"TRAINING/\" + f\n",
"    subject = sf.getSubject(f_name)\n",
"    body = sf.getBody(f_name)\n",
"    # Labels are kept as the strings \"0\" (spam) and \"1\" (ham), so the\n",
"    # comparison has to be with the string; comparing with the integer\n",
"    # 1 is never true and labels every email SPAM.\n",
"    if sf.spam_ham[f] == \"1\":\n",
"        print(\"HAM\")\n",
"    else:\n",
"        print(\"SPAM\")\n",
"    print(\"SUBJECT\")\n",
"    print(subject)\n",
"    print(\"BODY\")\n",
"    print(body[:500])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test Accuracy on new data"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.9223813112283346"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit on the training slice; getData also hands back the maxima used\n",
"# to scale the two length features.\n",
"x, y, max_body, max_sub = getData()\n",
"st = SpamTrain(x=x, y=y, max_body=max_body, max_sub=max_sub)\n",
"\n",
"# Score the fitted model on the emails that were not trained on.\n",
"x_test, y_test = getTest()\n",
"st.test(x=x_test, y=y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Not too bad for a few hours of work. Definitely a lot of room for improvement though."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 22.7 Company Name Extraction"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"article_file = \"articles.txt\"\n",
"with open(article_file, \"rb\") as f:\n",
"    raw_articles = f.read().split(\"/////\")\n",
"# Each chunk is split into [article text, company list] at \"-----\".\n",
"# Whitespace-only chunks (e.g. after a trailing \"/////\") are dropped so\n",
"# they do not turn into articles without a company list.\n",
"articles = [a.split(\"-----\") for a in raw_articles if a.strip()]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def getWords(word_file=\"words.txt\"):\n",
"    \"\"\"\n",
"    Loads a dictionary file (one word per line) into a set.\n",
"    Blank lines are skipped.\n",
"\n",
"    word_file - path of the dictionary, \"words.txt\" by default\n",
"    \"\"\"\n",
"    word_set = set()\n",
"    with open(word_file, \"rb\") as f:\n",
"        for line in f:\n",
"            line = line.strip()\n",
"            if line:\n",
"                word_set.add(line)\n",
"    return word_set\n",
"\n",
"def getCompanies(article, dic=None):\n",
"    \"\"\"\n",
"    Heuristically pulls company-like names out of an article.\n",
"\n",
"    A word is a candidate when it is an all-caps word of more than one\n",
"    character or starts with a capital letter, and its lower-cased,\n",
"    punctuation-free form is not a dictionary word.  A capitalised\n",
"    neighbour on either side is joined to the candidate, but never\n",
"    across a comma, full stop or double quote.\n",
"\n",
"    article - the article text\n",
"    dic     - optional set of known words; loaded with getWords() when\n",
"              omitted.  Pass it in to avoid re-reading the file for\n",
"              every article.  The set is copied, not modified.\n",
"    returns a set of candidate names\n",
"    \"\"\"\n",
"    if dic is None:\n",
"        dic = getWords()\n",
"    known = set(dic)\n",
"    known.add(\"cyber\")\n",
"    known.add(\"uk\")\n",
"    enders = ',.\"'\n",
"    words = article.split()\n",
"    companies = set()\n",
"    i = 0\n",
"    while i < len(words):\n",
"        w = words[i]\n",
"        no_p = w.replace(\".\", \"\").replace(\",\", \"\").replace('\"', \"\").replace(\"'\", \"\")\n",
"        # no_p is empty for punctuation-only tokens; test that first so\n",
"        # no_p[0] below cannot raise IndexError.\n",
"        is_acronym = w.isupper() and len(no_p) > 1\n",
"        if not no_p or not (is_acronym or no_p[0].isupper()) or no_p.lower() in known:\n",
"            i += 1\n",
"            continue\n",
"        name = w\n",
"        consumed = 1\n",
"        # i > 0 keeps words[i - 1] from wrapping around to the last\n",
"        # word of the article for the very first word.\n",
"        if i > 0:\n",
"            prev = words[i - 1]\n",
"            if prev[-1] not in enders and (prev.isupper() or prev[0].isupper()):\n",
"                name = prev + \" \" + name\n",
"        if w[-1] not in enders and i + 1 < len(words):\n",
"            nxt = words[i + 1]\n",
"            if nxt.isupper() or nxt[0].isupper():\n",
"                name = name + \" \" + nxt\n",
"                consumed = 2\n",
"        name = name.strip(enders)\n",
"        if name:\n",
"            companies.add(name)\n",
"        # Step past exactly the words that went into this name, so no\n",
"        # word is skipped without being looked at.\n",
"        i += consumed\n",
"    return companies\n",
"\n",
"def getSolution(sol_string):\n",
"    \"\"\"\n",
"    Parses a comma separated list of company names into a set.\n",
"    Whitespace around each name is trimmed and empty entries (blank\n",
"    input, trailing comma) are left out.\n",
"    \"\"\"\n",
"    names = [a.strip() for a in sol_string.split(\",\")]\n",
"    return set(a for a in names if a)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'Alex Baldock',\n",
" 'Analyst Miya Knight',\n",
" 'BBC',\n",
" 'Currys',\n",
" 'Exeter',\n",
" 'Experian-IMRG',\n",
" 'LCP Consulting',\n",
" 'Littlewoods',\n",
" 'Norwich',\n",
" 'Stuart Higgins',\n",
" 'Very.co.uk'}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"getCompanies(articles[0][0].strip())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'Amazon',\n",
" 'Currys',\n",
" 'John Lewis',\n",
" 'Littlewoods',\n",
" 'PC World',\n",
" 'Planet Retail',\n",
" 'Royal Mail',\n",
" 'Very.co.uk',\n",
" 'Visa',\n",
" 'Visa Europe'}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"getSolution(articles[0][1])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.658536585366\n",
"0.739726027397\n"
]
}
],
"source": [
"# Precision and recall pooled over all articles.\n",
"true_pos = 0   # extracted and in the reference list\n",
"false_pos = 0  # extracted but not in the reference list\n",
"false_neg = 0  # in the reference list but not extracted\n",
"for a in articles:\n",
"    if len(a) < 2:\n",
"        # chunk without a reference list\n",
"        continue\n",
"    found = getCompanies(a[0])\n",
"    expected = getSolution(a[1])\n",
"    # Correct hits are the intersection; counting the union here would\n",
"    # add every miss and every false hit to the numerator as well.\n",
"    true_pos += len(found & expected)\n",
"    false_pos += len(found - expected)\n",
"    false_neg += len(expected - found)\n",
"extracted = true_pos + false_pos\n",
"relevant = true_pos + false_neg\n",
"precision = true_pos / float(extracted) if extracted else 0.0\n",
"recall = true_pos / float(relevant) if relevant else 0.0\n",
"print(precision)\n",
"print(recall)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment