Created
May 8, 2018 17:38
Coursera Machine LearningをPythonで実装 - [Week7]サポートベクターマシン(SVM) (2)スパムメールの分類
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
import numpy as np | |
from nltk.stem import PorterStemmer | |
from scipy.io import loadmat | |
from sklearn.svm import SVC | |
## 単語リストを作る | |
# 単語用のクラス定義 | |
class Vocab:
    """One entry of the vocabulary file: a 1-based index and its word."""

    def __init__(self, str_line):
        # Each line looks like "<index> <word>"; split on whitespace.
        fields = str_line.split()
        self.index = int(fields[0])  # 1-based position in vocab.txt
        self.word = fields[1]        # the (already stemmed) word itself
# Build the vocabulary list from vocab.txt (one "<index> <word>" entry
# per line).  A comprehension replaces the readlines()+append loop.
with open("vocab.txt", "r") as fp:
    vocab_list = [Vocab(line) for line in fp]
## メールデータの前処理 | |
## →メールデータを解析しやすい形に変換して、単語リストに入っているインデックスを返す | |
def process_email(email_contents):
    """Normalize a raw e-mail body and map its words to vocabulary indices.

    The text is lower-cased; HTML tags, numbers, URLs, e-mail addresses
    and dollar signs are replaced by placeholder tokens; the result is
    tokenized and Porter-stemmed; every stem found in the module-level
    ``vocab_list`` contributes its (1-based) index to the result.  The
    processed tokens are also echoed to stdout, wrapped at ~78 columns.

    Parameters
    ----------
    email_contents : str
        Raw e-mail text.

    Returns
    -------
    list[int]
        1-based vocabulary indices of the recognized stems, in order of
        appearance (duplicates kept).
    """
    # O(1) lookup table instead of a linear scan of vocab_list per token.
    word_to_index = {vocab.word: vocab.index for vocab in vocab_list}
    word_indices = []

    ## Preprocessing
    # Lower-case everything
    contents = email_contents.lower()
    # Replace HTML tags with a space
    contents = re.sub(r"<[^<>]+>", " ", contents)
    # Replace (phone) numbers with the token "number"
    contents = re.sub(r"[0-9]+", "number", contents)
    # Replace URLs with the token "httpaddr"
    contents = re.sub(r"(http|https)://[^\s]*", "httpaddr", contents)
    # Replace e-mail addresses with "emailaddr" (look for the @ in the middle)
    contents = re.sub(r"[^\s]+@[^\s]+", "emailaddr", contents)
    # Replace dollar signs with the token "dollar"
    contents = re.sub(r"[$]+", "dollar", contents)
    # Tokenize.  NOTE: the dash is escaped so that '-' itself is a
    # delimiter; the original pattern's unescaped ".-:" formed a
    # character *range* (./0-9:) that never matched '-' at all.
    tokens = re.split(r"[\'\s@$/#.\-:&*+=\[\]?!(){},\">_<;%]", contents)

    # Porter stemmer (nltk.stem; ships with Anaconda installs)
    stemmer = PorterStemmer()
    # Header
    print("\n==== Processed Email ====\n")
    # Characters printed on the current output line
    line_len = 0

    ## Per-token processing
    for token in tokens:
        # Strip any remaining non-alphanumeric characters
        token = re.sub(r"[^a-zA-Z0-9]", "", token)
        # Reduce the word to its stem
        token = stemmer.stem(token.strip())
        # Skip empty tokens
        if len(token) < 1:
            continue
        # Look the stem up in the vocabulary
        index = word_to_index.get(token)
        if index is not None:
            word_indices.append(index)
        # Wrap the display output at ~78 columns
        if (line_len + len(token) + 1) > 78:
            print()
            line_len = 0
        # Echo the processed token
        print(token, end=" ")
        line_len += len(token) + 1

    # Footer
    print("\n\n=========================\n")
    return word_indices
## Preprocess the sample e-mail
print("Preprocessing sample email (emailSample1.txt)")
# Extract the vocabulary indices from the raw message
with open("emailSample1.txt") as fp:
    file_contents = fp.read()
word_indices = process_email(file_contents)
# Show the result
print("Word Indices: ")
print(word_indices)
print()
# Map vocabulary indices onto a binary feature vector
def email_feature(word_indices, n=1899):
    """Convert vocabulary indices into a binary feature column vector.

    Parameters
    ----------
    word_indices : iterable of int
        1-based vocabulary indices, as returned by ``process_email``.
    n : int, optional
        Vocabulary size, i.e. the length of the feature vector.
        Defaults to 1899, the number of entries in vocab.txt.

    Returns
    -------
    numpy.ndarray of shape (n, 1)
        ``x[i-1] == 1`` exactly when index ``i`` appears in
        ``word_indices``; all other entries are 0.
    """
    x = np.zeros((n, 1))
    for i in word_indices:
        # Indices are 1-based, the array is 0-based.
        x[i - 1] = 1
    return x
# Feature vector for the sample e-mail, plus basic statistics
features = email_feature(word_indices)
print(f"Length of feature vector: {len(features)}")
print(f"Number of non-zero entries: {np.sum(features > 0)}")
print()
## Spam classification with a linear SVM
# spamTrain.mat already contains vectorized training examples
data = loadmat("spamTrain.mat")
X = np.array(data['X'])
y = np.ravel(np.array(data['y']))
C = 0.1
# Train the SVM (probability=True enables predict_proba later on)
print("Training Linear SVM (Spam Classification)")
model = SVC(C=C, kernel="linear", probability=True)
model.fit(X, y)
# Accuracy on the training set
p = model.predict(X)
print("Training Accuracy:", np.mean(p == y) * 100)
# Top-15 words with the largest linear-SVM weights
weights = model.coef_[0]
topword_indices = np.argsort(weights)[::-1][:15]
for i in topword_indices:
    print(vocab_list[i].word, model.coef_[0, i])
print()
## Accuracy on the held-out test set
# spamTest.mat already contains vectorized test examples
data = loadmat("spamTest.mat")
Xtest = np.array(data["Xtest"])
ytest = np.ravel(np.array(data["ytest"]))
# Predict on the test data
p = model.predict(Xtest)
print("Test Accuracy:", np.mean(p == ytest) * 100)
print()
## Classify an arbitrary e-mail file
def spamTest(filename, model):
    """Classify the e-mail stored in *filename* with a trained SVM.

    Parameters
    ----------
    filename : str
        Path of a plain-text e-mail file.
    model : sklearn.svm.SVC
        Trained classifier; must support ``predict`` and
        ``predict_proba`` (i.e. fitted with ``probability=True``).

    Side effects: prints the predicted class (1 = spam, 0 = not spam)
    and the estimated class probabilities to stdout.
    """
    print("Spam test on", filename, "...")
    with open(filename) as fp:
        file_contents = fp.read()
    word_indices = process_email(file_contents)
    # email_feature already returns an ndarray; build the (1, n) row
    # once instead of wrapping/reshaping twice as before.
    x = email_feature(word_indices).reshape(1, -1)
    p = model.predict(x)
    print("Spam Classification:", p)
    print("(1 indicates spam, 0 indicates not spam)")
    prob = model.predict_proba(x)
    print("Estimated probability of non-spam / spam : ")
    print(prob)
    print()
# Run the classifier on the two sample messages
for sample_file in ("spamSample2.txt", "myEmail.txt"):
    spamTest(sample_file, model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment