Last active
February 6, 2020 13:00
-
-
Save pei223/029ea4e424b6ca2caae6c40be82179f5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Dict, List, Tuple

import numpy as np
from scipy import sparse as sp

from word_analyze import tokenize
def is_key_exist(d: Dict, key: Any) -> bool:
    """Return True if *key* is present in *d*.

    Uses the ``in`` operator so a key that is present but mapped to ``None``
    is still (correctly) reported as existing; the previous
    ``d.get(key) is not None`` check misreported that case as absent.

    :param d: dictionary to check.
    :param key: key to look up.
    :return: True if the key exists in the dictionary.
    """
    return key in d
class Okapi:
    """Okapi BM25+ document vectorizer.

    Learns a vocabulary and IDF values from a corpus (``fit``) and converts
    documents into a sparse matrix of BM25+ term weights (``transform``).
    """

    def __init__(self, extract_words_func: callable, b: float = 0.75, k1: float = 2.0, delta: float = 1.0,
                 norm: bool = True):
        """
        :param extract_words_func: callable that turns a document into a list of words.
        :param b: length-normalization constant (0 disables length normalization).
        :param k1: term-frequency saturation constant.
        :param delta: BM25+ lower-bound constant.
        :param norm: if True, L1-normalize each document's weight vector.
        :raises RuntimeError: if *extract_words_func* is not callable.
        """
        # Validate first so a failed construction leaves no half-built state.
        if not callable(extract_words_func):
            raise RuntimeError("extract_words_funcは呼び出し可能オブジェクトでなければいけません")
        self.K1, self.B, self.delta = k1, b, delta  # BM25+ constants
        self.norm = norm  # whether to L1-normalize each document vector
        self.word2id_dict: Dict[str, int] = {}  # word -> column index
        self.idf = np.array([])  # per-word-id inverse document frequency
        self.avg_word_count_in_doc = 0  # average number of words per document
        self.extract_words_func = extract_words_func

    def fit_transform(self, documents: List[str]) -> sp.lil_matrix:
        """Fit on *documents* and return their BM25+ weight matrix."""
        self.fit(documents)
        return self.transform(documents)

    def fit(self, documents: List[str]):
        """Learn the vocabulary, IDF values and average document length.

        :param documents: list of raw documents.
        """
        # Reset learned state so fit() can safely be called more than once.
        self.word2id_dict = {}
        self.idf = np.array([])
        self.avg_word_count_in_doc = 0
        next_id = 0
        for document in documents:
            seen = set()  # words already counted for this document
            words = self.extract_words_func(document)
            self.avg_word_count_in_doc += len(words)
            for word in words:
                if word in seen:
                    continue
                seen.add(word)
                if word in self.word2id_dict:
                    # Word already appeared in an earlier document:
                    # bump its document frequency.
                    self.idf[self.word2id_dict[word]] += 1.0
                    continue
                self.word2id_dict[word] = next_id
                next_id += 1
                self.idf = np.append(self.idf, [1.0])
        documents_len = len(documents)
        if documents_len == 0:
            # Nothing to fit; avoid ZeroDivisionError below.
            return
        # Epsilon keeps the log argument strictly positive.
        self.idf = np.log2(documents_len / (self.idf + 0.0000001))
        self.avg_word_count_in_doc = self.avg_word_count_in_doc / documents_len

    def transform(self, documents: List[str]) -> sp.lil_matrix:
        """Weight *documents* with BM25+.

        :param documents: list of raw documents.
        :return: sparse matrix of shape (len(documents), vocabulary size).
        """
        result = sp.lil_matrix((len(documents), len(self.word2id_dict)))
        for i, doc in enumerate(documents):
            # Raw term frequencies for this document.
            word_weight_dict, words_count = self._terms_frequency(doc)
            # Replace raw frequencies with BM25+ weights.
            for ind in word_weight_dict.keys():
                word_weight_dict[ind] = self._bm25_weight(ind, word_weight_dict[ind], words_count)
            if self.norm:
                total_dist = sum(word_weight_dict.values())
                # Guard: a document with no in-vocabulary words has nothing
                # to normalize (avoids a division by zero / NaN row).
                if total_dist != 0:
                    for ind in word_weight_dict.keys():
                        word_weight_dict[ind] /= total_dist
            # Write the document vector into the sparse matrix.
            for ind, weight in word_weight_dict.items():
                result[i, ind] = weight
        return result

    def _terms_frequency(self, doc: str) -> Tuple[Dict[int, float], int]:
        """Return raw term frequencies for one document.

        :param doc: raw document text.
        :return: ({word id: frequency}, total word count including
                  out-of-vocabulary words).
        """
        word_weight_dict: Dict[int, float] = {}
        words = self.extract_words_func(doc)
        for word in words:
            word_id = self.word2id_dict.get(word)
            if word_id is None:
                # TODO: decide how to handle out-of-vocabulary words.
                continue
            word_weight_dict[word_id] = word_weight_dict.get(word_id, 0.0) + 1.0
        return word_weight_dict, len(words)

    def _bm25_weight(self, word_index: int, word_freq: float, word_count_in_doc: int) -> float:
        """Compute the Okapi BM25+ weight for one term of one document.

        :param word_index: vocabulary id of the term.
        :param word_freq: raw frequency of the term in the document.
        :param word_count_in_doc: total number of words in the document.
        :return: BM25+ weight.
        """
        return self.idf[word_index] * (self.delta + (word_freq * (self.K1 + 1.0))) / (
                word_freq + self.K1 * (1.0 - self.B + self.B * (word_count_in_doc / self.avg_word_count_in_doc)))

    def get_feature_names(self) -> List[str]:
        """Return the vocabulary words in column order.

        :return: list of words, index-aligned with the weight matrix columns.
        """
        return list(self.word2id_dict.keys())
if __name__ == "__main__":
    # Small demo: vectorize a few sample sentences and print the result.
    sample_docs = [
        "特徴選択を行うナイーブベイズ分類器です",
        "特徴選択を行うナイーブベイズ分類器です",
        "Pythonで特徴選択をしたい",
        "蛇の目はPure Pythonな形態素解析器です。",
    ]
    vectorizer = Okapi(tokenize)
    print(vectorizer.fit_transform(sample_docs))
    print(vectorizer.get_feature_names())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment