Skip to content

Instantly share code, notes, and snippets.

@yatszhash
Created October 18, 2017 02:16
Show Gist options
  • Save yatszhash/29b717cbb562eee98f124dfa531a8826 to your computer and use it in GitHub Desktop.
Save yatszhash/29b717cbb562eee98f124dfa531a8826 to your computer and use it in GitHub Desktop.
Japanese text vectorizer built on gensim topic models (LSI/LDA), for use in a scikit-learn pipeline
from sklearn.base import BaseEstimator
import numpy as np
import gensim
from itertools import chain
import MeCab
class TopicModelVectorizer(BaseEstimator):
    """Turn Japanese documents into dense topic vectors with a pretrained gensim model.

    Each document is tokenized with MeCab (``-Owakati``), converted to a
    bag-of-words with a gensim ``Dictionary`` and projected through a
    pretrained LSI (``gensim.models.LsiModel``) or LDA
    (``gensim.models.LdaModel``) model, so the vectorizer can be used as a
    step in a scikit-learn pipeline.

    Parameters
    ----------
    topic_type : str
        ``"lsi"`` or ``"lda"``; selects which gensim model class loads
        ``model_file_name``.
    dic_file_name : str, optional
        Dictionary file, loaded with ``gensim.corpora.Dictionary.load_from_text``.
    model_file_name : str, optional
        Model file, loaded with the selected model class's ``load``.

    Raises
    ------
    ValueError
        If ``topic_type`` is not ``"lsi"`` or ``"lda"``.

    Both files are loaded at construction time; ``fit`` trains nothing.
    """

    # Shared MeCab tokenizer producing space-separated ("wakati") output.
    tagger = MeCab.Tagger("-Owakati")

    # Supported values for ``topic_type``.
    _TOPIC_TYPES = ("lsi", "lda")

    def __init__(self, topic_type, dic_file_name=None, model_file_name=None):
        if topic_type not in self._TOPIC_TYPES:
            raise ValueError(
                "topic_type must be 'lsi' or 'lda', got {!r}".format(topic_type))
        self.topic_type = topic_type
        self.dic_file_name = dic_file_name
        self.model_file_name = model_file_name
        self.dictionary = None
        self.topic_model = None
        self.feature_names = [topic_type]
        # Each file is optional, so an unconfigured instance can be built
        # (e.g. by scikit-learn cloning). transform() requires both.
        if dic_file_name:
            self.dictionary = gensim.corpora.Dictionary.load_from_text(dic_file_name)
        if model_file_name:
            if topic_type == "lsi":
                model_cls = gensim.models.LsiModel
            else:
                model_cls = gensim.models.LdaModel
            self.topic_model = model_cls.load(model_file_name)

    def get_feature_name(self):
        """Return the feature-name list (a single entry: the topic type)."""
        return self.feature_names

    def fit(self, X, y=None):
        """Do nothing and return ``self``; the dictionary and model are pretrained."""
        return self

    def transform(self, X, copy=True):
        """Map each document in ``X`` to a dense topic vector.

        ``copy`` is accepted for API compatibility and ignored.
        Returns a list with one 1-D numpy array of length
        ``topic_model.num_topics`` per document.
        """
        return [self.to_topic_vector(x) for x in X]

    def fit_transform(self, X, y=None):
        """Equivalent to ``fit(X, y).transform(X)``."""
        return self.fit(X, y).transform(X)

    def to_topic_vector(self, x):
        """Return the dense topic vector of one document ``x``.

        ``x`` is a string or an iterable of strings (see ``split_word``).
        Topics the model does not report are 0; a document with no words in
        the dictionary yields an all-zero vector.

        Raises
        ------
        RuntimeError
            If the dictionary or the model was not loaded.
        """
        if self.dictionary is None or self.topic_model is None:
            raise RuntimeError(
                "TopicModelVectorizer needs both dic_file_name and "
                "model_file_name to transform documents")
        query = self.dictionary.doc2bow(self.split_word(x))
        topic_vector = np.zeros(self.topic_model.num_topics)
        if not query:
            return topic_vector
        if self.topic_type == "lsi":
            topics = self.topic_model[query]
        else:
            topics = self.topic_model.get_document_topics(query)
        # Both models yield sparse (topic_id, weight) pairs; scatter them into
        # a fixed-length vector so every document has the same dimension.
        for topic_id, weight in topics:
            topic_vector[topic_id] = weight
        return topic_vector

    @classmethod
    def split_word(cls, x):
        """Tokenize a document with MeCab and return a flat list of tokens.

        ``x`` may be a single string or an iterable of strings (e.g. the
        sentences of one document). Each string is parsed separately and the
        tokens are concatenated in order.
        """
        # A bare string is one sentence; iterating it would parse each
        # character on its own.
        sentences = [x] if isinstance(x, str) else x
        # str.split() drops the separating spaces and the trailing newline of
        # wakati output without gluing newline-separated tokens together.
        return list(chain.from_iterable(
            cls.tagger.parse(sentence).split() for sentence in sentences))
if __name__ == "__main__":
    # Command-line demo: load a dictionary and a pretrained model, then print
    # the topic vector of each TEXT argument.
    import sys

    if len(sys.argv) < 4:
        sys.exit("usage: {} {{lsi|lda}} DICTIONARY_TXT MODEL_FILE [TEXT ...]".format(sys.argv[0]))
    vectorizer = TopicModelVectorizer(
        sys.argv[1], dic_file_name=sys.argv[2], model_file_name=sys.argv[3])
    for text in sys.argv[4:]:
        # Each text is passed as a one-sentence document (a list of strings).
        print(vectorizer.transform([[text]])[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment