create_doc2vec.py: build a gensim Doc2Vec model from documents stored in Elasticsearch, processing the data in parallel.

from elasticsearch.helpers import scan
from elasticsearch import Elasticsearch
from multiprocessing import Pool

import gensim
import logging
import nltk
import re
import unicodedata

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.ERROR)
logging.getLogger('gensim').setLevel(logging.INFO)

# Requires the NLTK 'punkt' and 'stopwords' data packages.
tokenizer = nltk.tokenize.RegexpTokenizer(r'\w+')
stop_words = set(nltk.corpus.stopwords.words('english'))

es = Elasticsearch(['localhost'])


def create_model():
    # Build the vocabulary once, then train over several passes, decaying the
    # learning rate manually between passes. (Parameter names match the gensim
    # releases current at the time; newer gensim uses vector_size= and an
    # explicit epochs argument to train().)
    model = gensim.models.Doc2Vec(size=300, window=8, min_count=10, workers=4)
    model.build_vocab(sentence_generator())

    alpha, min_alpha, passes = (0.025, 0.001, 10)
    alpha_delta = (alpha - min_alpha) / passes

    for epoch in range(passes):
        # Pin the learning rate for this pass.
        model.alpha, model.min_alpha = alpha, alpha
        model.train(sentence_generator())
        alpha -= alpha_delta
        print('Finished epoch {}'.format(epoch))

    model.save('doc2vec_model_300_10')
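

# Each hit yielded by the Elasticsearch scan in sentence_generator() below
# looks roughly like this (hypothetical example; the exact envelope depends
# on the Elasticsearch version):
#
#     {
#         '_id': '12345',
#         'fields': {'articles.abstractText': ['Some abstract text. ...']},
#         ...
#     }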


def get_sentences(document):
    # Split the document's abstract text into sentences, tokenize each one and
    # wrap it in a TaggedDocument tagged '<document id>_<sentence number>'.
    final = []

    if not document.get('fields'):
        return final

    abstracts = ' '.join(document['fields']['articles.abstractText'])
    sentences = nltk.sent_tokenize(abstracts)
    sentences = [tokenize(sent) for sent in sentences]

    for sentence_num, sentence in enumerate(sentences):
        if len(sentence) == 0:
            continue

        final.append(gensim.models.doc2vec.TaggedDocument(
            words=sentence,
            tags=['{}_{}'.format(document['_id'], sentence_num)]
        ))

    return final


def sentence_generator():
    # Stream documents out of Elasticsearch and tokenize them in a worker
    # pool, yielding one TaggedDocument per sentence.
    documents = scan(
        es, index='pmc', doc_type='recentauthor',
        scroll='30m', fields='articles.abstractText'
    )

    p = Pool(processes=4)
    for sentences in p.imap(get_sentences, documents):
        for sentence in sentences:
            yield sentence

    # Release the worker processes once the scroll is exhausted.
    p.close()
    p.join()


def not_stopword(token):
    return token not in stop_words


num_replace = re.compile(r'[0-9]+')


def tokenize(sentence):
    # Lowercase tokens, strip accents, replace digit runs with 'DDD', drop stopwords.
    token_list = []

    for token in tokenizer.tokenize(sentence):
        nkfd_form = unicodedata.normalize('NFKD', token)
        only_ascii = nkfd_form.encode('ASCII', 'ignore').decode('ascii')
        final = num_replace.sub('DDD', only_ascii)
        token_list.append(final.strip().lower())

    # list() keeps the result usable with len() and repeated iteration on Python 3.
    return list(filter(not_stopword, token_list))


if __name__ == '__main__':
    create_model()
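
Once training finishes, the saved model can be reloaded and queried by sentence tag. A minimal sketch, assuming the same pre-1.0 gensim release used above (newer gensim exposes the document vectors as model.dv rather than model.docvecs); the tag '12345_0' is a hypothetical '<_id>_<sentence_num>' value of the kind produced by get_sentences():

import gensim

# Load the model written by create_model().
model = gensim.models.Doc2Vec.load('doc2vec_model_300_10')

# Look up the learned vector for one tagged sentence (hypothetical tag).
vector = model.docvecs['12345_0']

# Find the stored sentence tags most similar to that sentence.
print(model.docvecs.most_similar(positive=['12345_0'], topn=5))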