Test code for doc2vec
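Loads the livedoor news corpus from DATA_DIR, splits each article into sentences, tokenizes them with MeCab (optionally keeping only nouns), and trains a gensim doc2vec model with both article-level and sentence-level labels. The trained model is then used to plot 2-D projections of the article vectors, classify articles by nearest-neighbour similarity, show similar articles and sentences, and build extractive summaries by selecting sentence combinations with a genetic algorithm (pyevolve). Intermediate results are cached as pickles in the temp directory.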
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import sys
import logging
import codecs
import MeCab
import re
import hashlib
import random
import tempfile
try:
    import cPickle as pickle
except ImportError:
    import pickle
import numpy
from gensim import corpora, models
####
from pyevolve import G1DBinaryString
from pyevolve import GSimpleGA
from pyevolve import Mutators
#
import scipy
import sklearn
import sklearn.decomposition
from sklearn import manifold
import matplotlib.pyplot as plt
# Base directory of the data files
DATA_DIR = './livedoor/yasunori/text'
# MeCab settings
MECAB_MODE = 'mecabrc'
STDIN_ENCODING = 'utf-8'
STDOUT_ENCODING = 'utf-8'
INPUT_FILE_ENCODING = 'utf-8'
# Feature field positions, assuming the standard MeCab dictionary
PART_OF_SPEECH = 0
SUB_PART_OF_SPEECH = 1
WORD = 6
#
ARTICLE_HEADER = 'ARTICLE'
SENTENSE_HEADER = 'SENT'
# Dictionary keys for internal bookkeeping
SPLIT_RATIO = "split_ratio"
MAX_FILES = "max_files"
ONLY_NOUNS = "only_nouns"
#
PATH = "path"
CATEGORY = "category"
FILE_NAME = "file_name"
ARTICLE_ID = "article_id"
CONTENT = "content"
WORDS = "words"
SENTENSE = "sentense"
PARSED_SENTENSE = "parsed_sentense"
SENTENSES = "sentenses"
SENTENSE_ID = "sentense_id"
#
TRAINING_DATA = "training_data"
TEST_DATA = "test_data"
TOTAL_DATA_SIZE = "total_data_size"
# Cache files
DATA_CACHE_BODY = 'temp_data'
VOCAB_CACHE_BODY = 'temp_vocab'
TRAINED_MODEL_CACHE_BODY = 'temp_trained'
#
USE_SENTENSES = "use_sentenses"
USE_ARTICLES = "use_articles"
#
SIZE = "size"
WINDOW = "window"
MIN_COUNT = "min_count"
SAMPLE = "sample"
SEED = "seed"
DM = "dm"
HS = "hs"
NEGATIVE = "negative"
DM_MEAN = "dm_mean"
TRAIN_WORDS = "train_words"
TRAIN_LBLS = "train_lbls"
ALPHA = "alpha"
MIN_ALPHA = "min_alpha"
WORKERS = "workers"
#
EPOCHS = "epochs"
LEARNING_RATE_DELTA = "learning_rate_delta"
DECAY = "decay"
#
MIN_SENTENSE_SIMILARITY = "min_sentense_similarity"
GENERATION = "generation"
POPULATION_SIZE = "population_size"
MAX_SENTENSES = "max_sentenses"
MAX_CANDIDATE_SENTENSES = "max_candidate_sentenses"
data_dir = DATA_DIR
config = {  #
    ONLY_NOUNS: True,
    MAX_FILES: -1,
    SPLIT_RATIO: 0.8,
    #
    USE_SENTENSES: True, USE_ARTICLES: True,
    #
    SIZE: 300,     # feature vector dimensionality (300), DM1/300:75% -> DM1/500:73%
    WINDOW: 5,     # maximum window size (8), DM1/3:78% -> DM1/5:77% -> DM1/7:76% -> DM1/10:75% -> DM1/20:33%
    # WINDOW DM1/10 -> DM1/5
    # WINDOW has no effect with DM0
    MIN_COUNT: 5,  # minimum word count (5), DM1/5:75% -> DM1/7:72% -> DM1/10:70%
    SAMPLE: 0,     # downsampling threshold (0), 0:75% -> 1e-5:63%
    SEED: 1,       # random seed
    DM: 0,         # model type (1): 1 = Distributed Memory, 0 = DBOW, 1:75% -> 0:88%!
    # DM 1 -> 0
    HS: 1,         # hierarchical softmax (1)
    NEGATIVE: 0,   # negative sampling (0), DM1/0:75% -> DM1/10:69%
    DM_MEAN: 0,    # whether to average with DM (0): 0 = sum, 1 = mean, DM1/0:75% -> DM1/1:62%
    TRAIN_WORDS: True, TRAIN_LBLS: True,
    LEARNING_RATE_DELTA: 0.002,
    DECAY: False, ALPHA: 0.025, MIN_ALPHA: 0.0001,  # without decay, min_alpha is pinned to alpha
    WORKERS: 4,
    #
    EPOCHS: 10,
    #
    MIN_SENTENSE_SIMILARITY: 0.25,
    GENERATION: 10,
    POPULATION_SIZE: 100,
    MAX_SENTENSES: 5,
    MAX_CANDIDATE_SENTENSES: 10,
}
# I/O encoding settings
sys.stdin = codecs.getreader(STDIN_ENCODING)(sys.stdin)
sys.stdout = codecs.getwriter(STDOUT_ENCODING)(sys.stdout)
# progress logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
def memoize(obj):
    import functools
    cache = obj.cache = {}
    @functools.wraps(obj)
    def memoizer(*args, **kwargs):
        key = str(args) + str(kwargs)
        if key not in cache:
            cache[key] = obj(*args, **kwargs)
        return cache[key]
    return memoizer
def color(text, **user_styles):
    styles = {
        # styles
        'reset': '\033[0m',
        'bold': '\033[01m',
        'disabled': '\033[02m',
        'underline': '\033[04m',
        'reverse': '\033[07m',
        'strike_through': '\033[09m',
        'invisible': '\033[08m',
        # text colors
        'fg_black': '\033[30m',
        'fg_red': '\033[31m',
        'fg_green': '\033[32m',
        'fg_orange': '\033[33m',
        'fg_blue': '\033[34m',
        'fg_purple': '\033[35m',
        'fg_cyan': '\033[36m',
        'fg_light_grey': '\033[37m',
        'fg_dark_grey': '\033[90m',
        'fg_light_red': '\033[91m',
        'fg_light_green': '\033[92m',
        'fg_yellow': '\033[93m',
        'fg_light_blue': '\033[94m',
        'fg_pink': '\033[95m',
        'fg_light_cyan': '\033[96m',
        # background colors
        'bg_black': '\033[40m',
        'bg_red': '\033[41m',
        'bg_green': '\033[42m',
        'bg_orange': '\033[43m',
        'bg_blue': '\033[44m',
        'bg_purple': '\033[45m',
        'bg_cyan': '\033[46m',
        'bg_light_grey': '\033[47m'
    }
    color_text = ''
    for style in user_styles:
        try:
            color_text += styles[style]
        except KeyError:
            return 'def color: parameter {} does not exist'.format(style)
    color_text += text
    return '\033[0m{}\033[0m'.format(color_text)
def orange(text):
    return color(text, fg_orange=True)
def yellow(text):
    return color(text, fg_yellow=True)
def red(text):
    return color(text, fg_red=True)
def error(text):
    return color(text, bold=True, fg_red=True)
def warning(text):
    return color(text, bold=True, fg_orange=True)
def success(text):
    return color(text, fg_green=True)
# Japanese morphological analysis
def parse(article, only_nouns):
    #logging.info("parsing %s" % (article[0:50]))
    encoded_article = article.encode(INPUT_FILE_ENCODING)
    tagger = MeCab.Tagger(MECAB_MODE)
    node = tagger.parseToNode(encoded_article)
    normalized_words = []
    prev = None
    cont = False
    while node:
        features = node.feature.decode(STDOUT_ENCODING).split(",")
        word = features[WORD].replace(u'*', '')
        pos = features[PART_OF_SPEECH]
        sub = features[SUB_PART_OF_SPEECH]
        #print word, pos, sub
        if not (sub in [u"非自立", u"*", u'']):
            if (pos in [u"記号", u"助詞", u"助動詞", u"連体詞", u"接続詞", u"接頭詞", u"その他"]):
                pass
            elif pos == u"名詞":
                if prev == pos and not (sub in [u"数", u"固有名詞", u"代名詞"]):
                    #print word, pos, sub, prev
                    if len(normalized_words) != 0:  # exclude the case where a pronoun came first
                        prevword = normalized_words.pop()
                        cont = True
                        normalized_words.append(prevword+word)
                    else:
                        normalized_words.append(word)
                elif (sub in [u"代名詞"]):
                    pass
                else:
                    normalized_words.append(word)
            elif (pos in [u"動詞", u"形容詞", u"形容動詞", u"副詞"]):
                if not only_nouns:
                    normalized_words.append(word)
                else:
                    pass
            else:
                print "unsupported pos:", word, pos, sub
                exit(-1)
        if cont:
            prev = None
            cont = False
        else:
            prev = pos
        node = node.next
    return filter(lambda x: x.strip() != u'', normalized_words)
# Split an article into sentences, then parse each one.
# Returns each original sentence (without the terminal punctuation) together with
# its parsed result (a list of word strings).
def parse_sentenses(content, identifier, only_nouns):
    sentenses = re.compile(ur"[。\r\n\t]").split(content)
    r = []
    for (num, s) in enumerate(sentenses):
        sentense_id = SENTENSE_HEADER+'#'+identifier+'#'+str(num)
        r.append({SENTENSE: s, PARSED_SENTENSE: parse(s, only_nouns), SENTENSE_ID: sentense_id})
    return r
# Read a data file
def read_file(file_path):
    logging.info("reading data file: %s" % (file_path))
    f = codecs.open(file_path, 'r', encoding=INPUT_FILE_ENCODING)
    article = f.read()
    f.close()
    return article
# Collect the training data files
def get_data_files(path, only_nouns, max_files):
    file_list = []
    counter = 0
    for (root, dirs, file_names) in os.walk(path):
        for file_name in file_names:
            if max_files > 0 and counter >= max_files:
                break
            file_path = os.path.join(root, file_name)
            category = file_path.replace(DATA_DIR, '').replace(file_name, '').replace('/', '')
            if category != '' and file_name != 'LICENSE.txt' and file_name != 'README.txt':
                original_content = read_file(file_path)
                identifier = file_name.replace('.txt', '')
                article_id = ARTICLE_HEADER+'#'+identifier
                sentenses = parse_sentenses(original_content, identifier, only_nouns)
                file_list.append({
                    PATH: file_path,
                    CATEGORY: category,
                    ARTICLE_ID: article_id,
                    FILE_NAME: file_name,
                    CONTENT: original_content,
                    SENTENSES: sentenses})
                counter = counter+1
    return file_list
# Split the data into training and test sets.
# Splitting evenly per label would be more appropriate, but that is skipped for now.
def divide_data_files(file_list, split_ratio):
    train = []
    test = []
    random.shuffle(file_list)
    size = len(file_list)
    for f in file_list:
        if len(train) < split_ratio*size:
            train.append(f)
        else:
            test.append(f)
    return {
        TRAINING_DATA: train,
        TEST_DATA: test,
        TOTAL_DATA_SIZE: size,
        SPLIT_RATIO: split_ratio}
# Build a cache file path from the parameters
def make_cache_pathname(body, *args):
    s = '_'.join([repr(x) for x in args])
    m = hashlib.md5(s)
    tag = m.hexdigest()
    path = tempfile.gettempdir()+'/'+body+'_'+tag+'.cache'
    return path
# Run a computation with caching
def process_with_cache(func, x, params, body, save=None, load=None):
    cache_file_path = make_cache_pathname(body, params)
    if os.path.exists(cache_file_path):
        logging.info("loading cache file: %s" % (cache_file_path))
        if load is None:
            with open(cache_file_path, 'rb') as handle:
                r = pickle.load(handle)
        else:
            r = load(cache_file_path)
    else:
        r = func(x, params)
        logging.info("saving cache file: %s" % (cache_file_path))
        if save is None:
            with open(cache_file_path, 'wb') as handle:
                pickle.dump(r, handle)
        else:
            save(r, cache_file_path)
    return r
# Build the data set
def generate_data(data_dir, config):
    def _func_(data_dir, params):
        (only_nouns, max_files, split_ratio) = params
        file_list = get_data_files(data_dir, only_nouns, max_files)
        r = divide_data_files(file_list, split_ratio)
        return r
    # pass only the relevant parameters so we do not create needless cache files
    return process_with_cache(_func_, data_dir, (config[ONLY_NOUNS], config[MAX_FILES], config[SPLIT_RATIO]), DATA_CACHE_BODY)
# Build the vocabulary
def generate_model(parsed_sentenses, config):
    def _func_(parsed_sentenses, config):
        # Doc2Vec defaults: sentences=None, size=300, alpha=0.025, window=8, min_count=5, sample=0, seed=1,
        # workers=1, min_alpha=0.0001, dm=1, hs=1, negative=0, dm_mean=0, train_words=True, train_lbls=True, **kwargs
        #model = models.Doc2Vec(sentences, size=100, window=8, min_count=5, workers=4)
        # use fixed learning rate
        model = models.Doc2Vec(
            size=config[SIZE],
            window=config[WINDOW],
            min_count=config[MIN_COUNT],
            sample=config[SAMPLE],
            seed=config[SEED],
            dm=config[DM],
            hs=config[HS],
            negative=config[NEGATIVE],
            dm_mean=config[DM_MEAN],
            train_words=config[TRAIN_WORDS],
            train_lbls=config[TRAIN_LBLS],
            alpha=config[ALPHA],
            min_alpha=config[MIN_ALPHA],
            workers=config[WORKERS])
        logging.info("building vocab")
        model.build_vocab(parsed_sentenses)
        return model
    return process_with_cache(_func_, parsed_sentenses, config, VOCAB_CACHE_BODY)
# Train the model
def train_model(model, parsed_sentenses, config):
    def _func_(params, config):
        (model, parsed_sentenses) = params
        for epoch in xrange(config[EPOCHS]):
            logging.info("training epoch %d" % (epoch))
            model.train(parsed_sentenses)
            model.alpha -= config[LEARNING_RATE_DELTA]
            if not config[DECAY]:  # fixed learning rate
                model.min_alpha = model.alpha
        print 'done training'
        return model
    return process_with_cache(_func_, (model, parsed_sentenses), config, TRAINED_MODEL_CACHE_BODY)
# Is this a sentence label?
def is_sentense(identifier):
    return identifier.find(SENTENSE_HEADER) > -1
# Is this an article label?
def is_article(identifier):
    return identifier.find(ARTICLE_HEADER) > -1
# Is this a word?
def is_word(identifier):
    return not is_sentense(identifier) and not is_article(identifier)
def get_parent_article(identifier):
    if not is_sentense(identifier):
        raise ValueError('not a sentense identifier: %s' % identifier)
    identifier = identifier.replace(SENTENSE_HEADER, ARTICLE_HEADER)
    identifier = re.sub(r'#[\d]+', '', identifier)
    return identifier
# Reconstruct the whole article from its parsed sentences
def make_parsed_article(sentense_set):
    parsed_article = []
    for x in sentense_set:
        parsed_sentense = x[PARSED_SENTENSE]
        parsed_article = parsed_article+parsed_sentense
    return parsed_article
# Data feeder
class LabeledLineSentence(object):
    def __init__(self, data, config):
        self.data = data
        self.config = config
    def __iter__(self):
        for (num, x) in enumerate(self.data):
            try:
                sentense_set = x[SENTENSES]
                article_id = x[ARTICLE_ID]
                # article-level data
                if self.config[USE_ARTICLES]:
                    parsed_article = make_parsed_article(sentense_set)
                    yield models.doc2vec.LabeledSentence(parsed_article, labels=[article_id])
                # sentence-level data
                if self.config[USE_SENTENSES]:
                    for y in sentense_set:
                        parsed_sentense = y[PARSED_SENTENSE]
                        sentense_id = y[SENTENSE_ID]
                        yield models.doc2vec.LabeledSentence(parsed_sentense, labels=[sentense_id])
            except Exception as e:
                print str(e)
                print "sentense set:", sentense_set
                exit(-1)
# Show similar sentences and similar articles
def show_simular(model, data_set, config):
    for x in data_set:
        article_id = x[ARTICLE_ID]
        flag = False
        if article_id in model.vocab:
            for (sim_id, sim_value) in model.most_similar(article_id, topn=100):
                #if sim_value > 4.0: #0.4->83%:94%
                for y in data_set:
                    if config[USE_ARTICLES] and y[ARTICLE_ID] == sim_id:
                        if not flag:
                            print "\nBase:\t", x[CATEGORY], article_id, x[CONTENT][0:400]
                            flag = True
                        print "Similar article:\t", sim_id, sim_value, y[CONTENT][0:200]
                        break
                    if config[USE_SENTENSES]:  # and False:
                        if not flag:
                            print "\nBase:\t", x[CATEGORY], article_id, x[CONTENT][0:400]
                            flag = True
                        for z in y[SENTENSES]:
                            if z[SENTENSE_ID] == sim_id:
                                print "Similar sentense:\t", sim_id, sim_value, z[SENTENSE][0:200]
                                break
    #most_similar_cosmul(positive=[], negative=[], topn=10)
# Use a genetic algorithm to pick a combination of sentences suited for a summary
def search_combi(model, article_id, sentenses, generation, population_size, max_sentenses):
    # fitness function
    def eval_func(chromosome):
        candidates = []
        counter = 0
        for x in xrange(len(chromosome)):
            if chromosome[x]:
                counter += 1
                if counter > max_sentenses:
                    return 0.0
                candidates.append(sentenses[x][SENTENSE_ID])
        if counter < 2:  # with a single sentence an error occurs (details still unclear)
            return 0.0
        sim_val = model.n_similarity([article_id], candidates)
        if not (type(sim_val) == float or type(sim_val) == numpy.float64):
            print "target article:", article_id
            print "candidate sentenses:", candidates
            print sim_val, type(sim_val)
            raise ValueError('unexpected similarity')
        #print "eval:", sim_val
        return sim_val
    # genes take only the values 0 and 1, so use G1DBinaryString
    genome = G1DBinaryString.G1DBinaryString(len(sentenses))
    genome.evaluator.set(eval_func)
    ga = GSimpleGA.GSimpleGA(genome)
    ga.setGenerations(generation)
    ga.setPopulationSize(population_size)
    ga.evolve(freq_stats=int(generation/100))
    best = ga.bestIndividual()
    # result
    result = []
    for x in xrange(len(best)):
        if best[x]:
            result.append(sentenses[x])
            sim_val = model.similarity(article_id, sentenses[x][SENTENSE_ID])
            logging.info("best %d: %f %s" % (x, sim_val, sentenses[x][SENTENSE_ID]))
    sim_value = eval_func(best)
    logging.info("best total: %f" % (sim_value))
    return (result, sim_value)
# Build a summary for one article
def make_article_summary(model, article, config):
    generation = config[GENERATION]
    population_size = config[POPULATION_SIZE]
    max_sentenses = config[MAX_SENTENSES]
    min_sentense_similarity = config[MIN_SENTENSE_SIMILARITY]
    max_candidate_sentenses = config[MAX_CANDIDATE_SENTENSES]
    article_id = article[ARTICLE_ID]
    valid_sentenses = filter(lambda s: s[SENTENSE_ID] in model.vocab, article[SENTENSES])
    similarities = sorted(map(lambda s: model.similarity(article_id, s[SENTENSE_ID]), valid_sentenses), reverse=True)
    # threshold that also caps the number of candidates at MAX_CANDIDATE_SENTENSES
    if len(similarities) > max_candidate_sentenses:
        if min_sentense_similarity > similarities[max_candidate_sentenses-1]:
            m = min_sentense_similarity
        else:
            m = similarities[max_candidate_sentenses-1]
    else:
        m = 0.0
    if len(valid_sentenses) > 3:
        candidate_sentenses = filter(lambda s: model.similarity(article_id, s[SENTENSE_ID]) > m, valid_sentenses)
        (combi, sim_value) = search_combi(model, article_id, candidate_sentenses, generation, population_size, max_sentenses)
        logging.info("combi simularity value: %f" % (sim_value))
        s = u"…".join([x[SENTENSE] for x in combi])
    elif len(valid_sentenses) > 0:
        s = u"…".join([x[SENTENSE] for x in valid_sentenses])
    else:
        s = ''
    return s
# Build summaries
def make_summary(model, data_set, config):
    dic = {}
    for x in data_set:
        combi = make_article_summary(model, x, config)
        dic[x[ARTICLE_ID]] = combi
    return dic
# Dimensionality reduction for plotting the distribution in 2D
def LinearKernelPCA(m):
    pca = sklearn.decomposition.KernelPCA(n_components=2, kernel="linear")
    Y = pca.fit(m).transform(m)
    return Y
def PolyKernelPCA(m):
    pca = sklearn.decomposition.KernelPCA(n_components=2, kernel="poly")
    Y = pca.fit(m).transform(m)
    return Y
def RbfKernelPCA(m):
    pca = sklearn.decomposition.KernelPCA(n_components=2, kernel="rbf")
    Y = pca.fit(m).transform(m)
    return Y
def SigmoidKernelPCA(m):
    pca = sklearn.decomposition.KernelPCA(n_components=2, kernel="sigmoid")
    Y = pca.fit(m).transform(m)
    return Y
def CosineKernelPCA(m):
    pca = sklearn.decomposition.KernelPCA(n_components=2, kernel="cosine")
    Y = pca.fit(m).transform(m)
    return Y
def PrecomputedKernelPCA(m):
    pca = sklearn.decomposition.KernelPCA(n_components=2, kernel="precomputed")
    Y = pca.fit(m).transform(m)
    return Y
def PCA(m):
    # http://scikit-learn.org/stable/auto_examples/decomposition/plot_pca_vs_lda.html#example-decomposition-plot-pca-vs-lda-py
    pca = sklearn.decomposition.PCA(n_components=2)
    Y = pca.fit(m).transform(m)
    return Y
def TruncatedSVD(m):
    # http://scikit-learn.org/stable/auto_examples/decomposition/plot_pca_vs_lda.html#example-decomposition-plot-pca-vs-lda-py
    pca = sklearn.decomposition.TruncatedSVD(n_components=2)
    Y = pca.fit(m).transform(m)
    return Y
def RandomizedPCA(m):
    # http://scikit-learn.org/stable/auto_examples/decomposition/plot_pca_vs_lda.html#example-decomposition-plot-pca-vs-lda-py
    pca = sklearn.decomposition.RandomizedPCA(n_components=2)
    Y = pca.fit(m).transform(m)
    return Y
def Isomap(m, n_neighbors=1):
    # http://qiita.com/sotetsuk/items/0c9ffb2a891294d314f3
    Y = manifold.Isomap(n_neighbors, 2).fit_transform(m)
    return Y
def Spectral(m):
    Y = manifold.SpectralEmbedding(n_components=2, random_state=0).fit_transform(m)
    return Y
def TSNE(m):
    Y = manifold.TSNE(n_components=2, random_state=0).fit_transform(m)
    return Y
def SparsePCA(m):
    pca = sklearn.decomposition.SparsePCA(n_components=2)
    Y = pca.fit(m).transform(m)
    return Y
def LDA(m):
    # note: requires 'import sklearn.lda' if enabled in decomp below
    lda = sklearn.lda.LDA(n_components=2)
    Y = lda.fit(m).transform(m)
    return Y
color_map = {
    'it-life-hack': 'red',
    'smax': 'blue',
    'peachy': 'yellow',
    'dokujo-tsushin': 'green',
    'kaden-channel': 'purple',
    'livedoor-homme': 'teal',
    'movie-enter': 'peru',
    'sports-watch': 'cyan',
    'topic-news': 'navy'}
decomp = {
    "LinearKernelPCA": LinearKernelPCA,
    "PolyKernelPCA": PolyKernelPCA,
    "RbfKernelPCA": RbfKernelPCA,
    "SigmoidKernelPCA": SigmoidKernelPCA,
    "CosineKernelPCA": CosineKernelPCA,
    #"PrecomputedKernelPCA": PrecomputedKernelPCA,
    "PCA": PCA,
    "TruncatedSVD": TruncatedSVD,
    "RandomizedPCA": RandomizedPCA,
    "Isomap": Isomap,
    #"TSNE": TSNE,
    #"Spectral": Spectral,
    #"SparsePCA": SparsePCA,
    #"LDA": LDA,
    }
def svd_whiten(X):
    U, s, Vt = numpy.linalg.svd(X, full_matrices=False)
    # U and Vt are the singular matrices, and s contains the singular values.
    # Since the rows of both U and Vt are orthonormal vectors, U * Vt will be white.
    X_white = numpy.dot(U, Vt)
    return X_white
def draw_charts(model, data, config):
    m = []
    color = []
    for x in data:
        article_id = x[ARTICLE_ID]
        c = x[CATEGORY]
        if article_id in model.vocab:
            m.append(model[article_id])
            color.append(color_map[c])
    #m = svd_whiten(m)
    #m = scipy.cluster.vq.whiten(m)
    for name, func in decomp.items():
        print "running: ", name
        Y = func(m)
        plt.figure()
        plt.scatter(Y[:, 0], Y[:, 1], c=color)
        """
        for c, target_name in zip(color, target_names):
            scatter(Y[y == target_name, 0], Y[y == target_name, 1], c=c, label=target_name)
        legend()
        """
        plt.title(name)
        #plt.show()
        plt.savefig('/tmp/'+name+'.svg')
###
def get_article(data_set, article_id):
    for y in data_set:
        if y[ARTICLE_ID] == article_id:
            return y
    return None
# Classify using the similarity scores computed by gensim
def similarity_classifier(model, article_id, data, config):
    close_articles = model.most_similar(article_id, topn=100)
    categories = {}
    counter = 1
    for (sim_id, sim_value) in close_articles:
        if is_article(sim_id) and sim_value > 0.1 and counter <= 25:
            x = get_article(data, sim_id)
            if not x is None:
                estimated = x[CATEGORY]
                if not categories.has_key(estimated):
                    categories[estimated] = 0.0
                categories[estimated] += sim_value
                if not categories.has_key('best'):
                    categories['best'] = (sim_id, sim_value, x)
                counter += 1
        if config[USE_SENTENSES] and is_sentense(sim_id) and sim_value > 0.1 and counter <= 50:
            sim_id = get_parent_article(sim_id)
            x = get_article(data, sim_id)
            if sim_id != article_id and not x is None:
                estimated = x[CATEGORY]
                if not categories.has_key(estimated):
                    categories[estimated] = 0.0
                categories[estimated] += sim_value * 0.25
                if not categories.has_key('best'):
                    categories['best'] = (sim_id, sim_value, x)
                counter += 1
    max_value = 0.0
    estimated_category = ''
    for (category, value) in categories.items():
        if category != 'best' and value > max_value:
            max_value = value
            estimated_category = category
    #if len(categories.items()) > 1:
    #    print "estimated in doubt:", article_id, estimated_category, categories
    #if max_value == 0.0:
    #    print "similarity search failed:", get_article(data, article_id)[CONTENT][0:100]
    return estimated_category, max_value, categories
@memoize
def build_random_forest(model, data, config):
    from sklearn.ensemble import RandomForestClassifier
    estimator = RandomForestClassifier()
    #estimator.fit(data_train, label_train)
    return estimator
def random_forest_classifier(model, article_id, data, config):
    estimator = build_random_forest(model, data, config)
    estimated_category = ''
    return estimated_category, 0, ''
def kmean_classifier(model, article_id, data, config):
    #lda = gensim.models.LdaModel(corpus=tfidf_corpus, id2word=dictionary, numTopics=100)
    estimated_category = ''
    return estimated_category, 0, ''
def test_classify(model, data, classifier, config):
    counter = 0
    right_answer_counter = 0
    wrong_answer_counter = 0
    rights = {}
    wrongs = {}
    log = []
    for article in data:
        article_id = article[ARTICLE_ID]
        actual_category = article[CATEGORY]
        if article_id in model.vocab:
            estimated_category, sim_value, categories = classifier(model, article_id, data, config)
            log.append([actual_category, estimated_category, sim_value, actual_category == estimated_category])
            if sim_value > 0.0:  # 0.5:88%/100%, 1.5:88%/98%, 2.5:91%/92% (accuracy below 2.5 is 46%), 3.5:94%/83%, 5.0:97%/67%
                if estimated_category == actual_category:
                    right_answer_counter += 1
                    if not rights.has_key(actual_category):
                        rights[actual_category] = 0
                    rights[actual_category] = rights[actual_category] + 1
                else:
                    wrong_answer_counter += 1
                    if sim_value > 0.0:
                        if not wrongs.has_key(actual_category):
                            wrongs[actual_category] = 0
                        wrongs[actual_category] = wrongs[actual_category] + 1
                        best_est = categories['best']
                        categories['best'] = ''
                        print "classification failed:", yellow(article_id), orange(estimated_category), red(str(best_est[1])), categories
                        #print "classification base: ", make_article_summary(model, article, config)
                        #print "classification sim: ", make_article_summary(model, best_est[2], config)
        counter += 1
    print right_answer_counter, wrong_answer_counter, right_answer_counter+wrong_answer_counter, '('+str(int(float(right_answer_counter)/float(right_answer_counter+wrong_answer_counter)*100))+'%)'
    print counter, '('+str(int(float(right_answer_counter+wrong_answer_counter)/float(counter)*100))+'%)'
    for k in rights.keys():
        print k, rights[k], wrongs.get(k, 0), float(rights[k])/float(rights[k]+wrongs.get(k, 0))
    # detailed check
    f = open('/tmp/log.txt', 'w')
    for l in log:
        f.write(str(l))
        f.write("\n")
    f.close()
    return
###
data = generate_data(data_dir, config)
labeled_data = LabeledLineSentence(data[TRAINING_DATA], config)
raw_model = generate_model(labeled_data, config)
trained_model = train_model(raw_model, labeled_data, config)
draw_charts(trained_model, data[TRAINING_DATA], config)
test_classify(trained_model, data[TRAINING_DATA], similarity_classifier, config)
#test_classify(trained_model, data[TRAINING_DATA], random_forest_classifier, config)
exit()
# the code below still raises errors
make_summary(trained_model, data[TRAINING_DATA], config)
#
show_simular(trained_model, data[TRAINING_DATA], config)
exit()
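# --- Usage sketch ---
# Assumption: the livedoor news corpus is extracted under ./livedoor/yasunori/text,
# one sub-directory per category containing the *.txt articles (see DATA_DIR above).
# After a run, the trained model can also be inspected interactively; the article
# label below is hypothetical and depends on which files were actually parsed:
#
#   sims = trained_model.most_similar('ARTICLE#it-life-hack-6292880', topn=5)
#   for (label, value) in sims:
#       print label, value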