Skip to content

Instantly share code, notes, and snippets.

@ragiko
Created February 6, 2015 17:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ragiko/82748581df88719f54fb to your computer and use it in GitHub Desktop.
Save ragiko/82748581df88719f54fb to your computer and use it in GitHub Desktop.
doc2vecで学習データとテストデータを分ける。
from gensim.models.doc2vec import *
class LabeledListSentence(object):
    """Adapt a list of tokenized sentences into labeled sentences for Doc2Vec.

    Entry ``i`` of ``words_list`` is produced as
    ``LabeledSentence(words, ['SENT_i'])``, so its label is ``SENT_i`` where
    ``i`` is the entry's position in the full ``words_list``.

    Example::

        words_list = [
            ['human', 'interface', 'computer'],
            ['survey', 'user', 'computer', 'system', 'response', 'time'],
            ['eps', 'user', 'interface', 'system'],
        ]
        sentences = LabeledListSentence(words_list)
    """

    def __init__(self, words_list):
        """Store ``words_list``, a list of token lists (one per sentence)."""
        self.words_list = words_list

    def __getitem__(self, index):
        """Return the labeled sentence at ``index``, or a list for a slice.

        ``index`` may be an int (negative values count from the end) or a
        slice. Labels always reflect each sentence's absolute position, e.g.
        ``self[4:6]`` yields sentences labeled ``SENT_4`` and ``SENT_5``.

        Raises:
            IndexError: if an integer ``index`` is out of range.
            TypeError: if ``index`` is neither an int nor a slice.
        """
        # Resolve the requested position(s) against range(len(...)) so only
        # the selected sentences are built, instead of materializing every
        # LabeledSentence in the corpus on each lookup.
        positions = range(len(self.words_list))[index]
        if isinstance(index, slice):
            return [self._labeled(i, self.words_list[i]) for i in positions]
        return self._labeled(positions, self.words_list[positions])

    def __iter__(self):
        """Yield one labeled sentence per entry of ``words_list``, in order."""
        for i, words in enumerate(self.words_list):
            yield self._labeled(i, words)

    @staticmethod
    def _labeled(i, words):
        """Build the labeled sentence for ``words`` at absolute position ``i``."""
        return LabeledSentence(words, ['SENT_{0}'.format(i)])
if __name__ == "__main__":
    # Toy corpus of nine tokenized sentences; LabeledListSentence labels
    # them SENT_0 ... SENT_8 by position.
    sentences = [
        ['human', 'interface', 'computer'],                          # 0
        ['survey', 'user', 'computer', 'system', 'response', 'time'],  # 1
        ['eps', 'user', 'interface', 'system'],                      # 2
        ['system', 'human', 'system', 'eps'],                        # 3
        ['user', 'response', 'time'],                                # 4
        ['trees'],                                                   # 5
        ['graph', 'trees'],                                          # 6
        ['graph', 'minors', 'trees'],                                # 7
        ['graph', 'minors', 'survey'],                               # 8
    ]
    all_sents = LabeledListSentence(sentences)

    # Sentences before `split` are training data; everything from `split`
    # to the end is test data. Slicing to the end (rather than to a fixed
    # upper bound) keeps the last sentence, #8, in the test set.
    split = 4
    train_sents = all_sents[:split]
    test_sents = all_sents[split:]

    model = Doc2Vec(min_count=0, window=2)
    # Build the vocabulary from the whole corpus (both splits).
    model.build_vocab(all_sents)

    # Phase 1: train on the training sentences with only word learning
    # enabled (label learning disabled).
    model.train_lbls = False
    model.train_words = True
    model.train(train_sents)

    # Phase 2: train on the test sentences with only sentence-label
    # learning enabled (word learning disabled).
    model.train_lbls = True
    model.train_words = False
    model.train(test_sents)

    # Parenthesized call so the script also runs under Python 3.
    print(model.most_similar("SENT_5"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment