Skip to content

Instantly share code, notes, and snippets.

@regonn
Last active August 5, 2018 09:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save regonn/c0cad39061fd3673d190d982afe3cc70 to your computer and use it in GitHub Desktop.
Save regonn/c0cad39061fd3673d190d982afe3cc70 to your computer and use it in GitHub Desktop.
import gensim
import numpy as np
import csv
from scipy import spatial
class SearchSimilarWords:
    """Find the CSV row whose averaged word vector is most similar to a target row.

    Every row of the CSV file is treated as a group of words. A row is
    represented by the mean of the model vectors of its in-vocabulary words,
    and rows are compared with cosine similarity (``1 - cosine distance``).
    """

    def __init__(self, words_csv_path, target_index, model_path, binary=False, encoding=None):
        """Load the word rows and the embedding model.

        Args:
            words_csv_path: Path of the CSV file; each row is one group of words.
            target_index: Index (in file order) of the row the others are
                compared against.
            model_path: Path of a word2vec-format model readable by gensim's
                ``KeyedVectors.load_word2vec_format``.
            binary: Passed to ``load_word2vec_format``; ``False`` (the default)
                reads the text format.
            encoding: Text encoding of the CSV file; ``None`` (the default)
                uses the platform default, exactly like ``open``.
        """
        # Read the (cheap) CSV before the (slow) model so a bad path fails fast.
        self.words_array = self.build_words_array(words_csv_path, encoding=encoding)
        self.model = gensim.models.KeyedVectors.load_word2vec_format(model_path, binary=binary)
        # Take the dimensionality from the loaded model instead of assuming
        # 300, so models of any vector size work.
        self.num_features = self.model.vector_size
        self.target_words = self.words_array[target_index]
        self.target_words_avg_vector = self.avg_feature_vector(self.target_words)
        self.max_similarity = 0.0
        self.similar_words = []

    def build_words_array(self, words_csv_path, encoding=None):
        """Read *words_csv_path* and return its rows as a list of lists of strings.

        Blank lines yield empty rows (``[]``), as produced by ``csv.reader``.
        """
        # newline="" is required by the csv module so that quoted fields
        # containing line breaks are parsed correctly.
        with open(words_csv_path, "r", encoding=encoding, newline="") as import_file:
            return list(csv.reader(import_file))

    def cal(self):
        """Run the search and return the result as a dict.

        Returns:
            ``{'similarity': best cosine similarity found (0.0 if no row scored
            above 0), 'target_words': the target row, 'the_most_similar_words':
            the best row ([] if no row scored above 0)}``.
        """
        self.cal_similar_words()
        return {
            'similarity': self.max_similarity,
            'target_words': self.target_words,
            'the_most_similar_words': self.similar_words
        }

    def cal_similar_words(self):
        """Scan all rows and record the one most similar to the target row.

        Rows whose contents equal the target row (including the target row
        itself) are skipped. Only rows with similarity strictly greater than
        0.0 can be selected, because the running maximum starts at 0.0.
        """
        # Reset so every call starts from a clean state.
        self.max_similarity = 0.0
        self.similar_words = []
        for words in self.words_array:
            if words == self.target_words:
                continue
            similarity = self.similarity_to_target_words(words)
            if similarity > self.max_similarity:
                self.similar_words = words
                self.max_similarity = similarity

    def similarity_to_target_words(self, words):
        """Return the cosine similarity between *words*' mean vector and the target's.

        Returns 0.0 when either mean vector is all zeros (no known words),
        because such a vector has no direction to compare.
        """
        words_avg_vector = self.avg_feature_vector(words)
        # scipy would return nan (with a RuntimeWarning) for a zero vector.
        if not np.any(self.target_words_avg_vector) or not np.any(words_avg_vector):
            return 0.0
        return 1 - spatial.distance.cosine(self.target_words_avg_vector, words_avg_vector)

    def avg_feature_vector(self, words):
        """Return the mean model vector of the words in *words* known to the model.

        Words missing from the model are reported on stdout and left out of
        the average. The *words* list itself is NOT modified. If no word is
        known (including an empty list), a zero vector of length
        ``self.num_features`` is returned.
        """
        feature_vec = np.zeros((self.num_features,), dtype="float32")
        known_count = 0
        for word in words:
            try:
                vector = self.model[word]
            except KeyError:
                # Words not in the model's vocabulary are excluded from the average.
                print(word + "は登録されていない単語でした。")
                continue
            feature_vec = np.add(feature_vec, vector)
            known_count += 1
        if known_count > 0:
            feature_vec = np.divide(feature_vec, known_count)
        return feature_vec
if __name__ == "__main__":
    # Example run: compare every row of ./words.csv against row 0 using the
    # text-format model at ./model/model.vec. Guarded so that importing this
    # module does not load the model or print anything.
    print(SearchSimilarWords('./words.csv', 0, './model/model.vec').cal())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment