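"""Train a gensim Doc2Vec model over bz2-compressed wiki extracts.

Each .bz2 file under 'wikies/' is expected to hold <doc>...</doc>
fragments; every sentence in every document becomes one TaggedDocument,
tagged with a '<filename>_<doc index>_<sentence index>' id.
"""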
from lxml import etree
import bz2
import gensim
import logging
import nltk
import os
import re
import random
import traceback
import unicodedata

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

# Word tokenizer that keeps alphanumeric runs, and a lenient XML parser
# for the not-quite-well-formed <doc> fragments in the dump files.
tokenizer = nltk.tokenize.RegexpTokenizer(r'\w+')
parser = etree.XMLParser(recover=True)
def create_model():
    files = find_all_files('wikies', 'bz2')
    # gensim 0.x-era API: in gensim >= 4.0, `size` is `vector_size` and
    # train() additionally requires total_examples and epochs.
    model = gensim.models.Doc2Vec(size=300, window=9, min_count=1, workers=4)
    model.build_vocab(sentence_generator(files))
    # Decay the learning rate manually: alpha falls linearly from 0.025
    # to 0.001 over the ten passes.
    alpha, min_alpha, passes = 0.025, 0.001, 10
    alpha_delta = (alpha - min_alpha) / passes
    for _epoch in range(passes):
        # Pin alpha for the whole pass (min_alpha == alpha disables
        # gensim's internal per-pass decay), then lower it for the next one.
        model.alpha, model.min_alpha = alpha, alpha
        model.train(sentence_generator(files))
        alpha -= alpha_delta
    model.save('doc2vec_model_300_10')
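# The same training, as a sketch against the gensim >= 4.0 API (assumes a
# re-iterable corpus, so the one-shot generator is materialized first):
#
#     corpus = list(sentence_generator(files))
#     model = gensim.models.Doc2Vec(vector_size=300, window=9, min_count=1,
#                                   workers=4, alpha=0.025, min_alpha=0.001,
#                                   epochs=10)
#     model.build_vocab(corpus)
#     model.train(corpus, total_examples=model.corpus_count, epochs=model.epochs)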
def sentence_generator(files):
    # Re-creating this generator gives gensim a fresh pass over the corpus.
    for sentence_id, sentence in read_sentences(files):
        yield gensim.models.doc2vec.TaggedDocument(words=sentence, tags=[sentence_id])
def find_all_files(path, extension):
    all_files = [
        os.path.join(dp, f)
        for dp, _dn, filenames in os.walk(path)
        for f in filenames
        if os.path.splitext(f)[1] == '.{}'.format(extension)
    ]
    # Shuffle once so training does not always consume the dump in
    # directory order; the order is then fixed for every pass.
    random.shuffle(all_files)
    return all_files
def valid_line(line):
    # Drop stray '<br' lines that would otherwise pollute the XML parse.
    return not re.match(r'^<br', line)
def tokenize(sentence):
    token_list = []
    for token in tokenizer.tokenize(sentence):
        # Strip accents: decompose each character, then drop the
        # non-ASCII combining marks.
        nkfd_form = unicodedata.normalize('NFKD', token)
        only_ascii = nkfd_form.encode('ASCII', 'ignore')
        token_list.append(only_ascii.decode('ascii').strip().lower())
    return token_list
def read_sentences(files):
    for filename in files:
        try:
            with bz2.BZ2File(filename) as f:
                # BZ2File yields bytes; decode before filtering and parsing.
                lines = [line.decode('utf-8', errors='ignore') for line in f]
            lines = [line for line in lines if valid_line(line)]
            # Wrap the bare <doc> fragments in a single root element so
            # lxml can parse the file as one document.
            it = '<root>{}</root>'.format(''.join(lines))
            root = etree.fromstring(it, parser=parser)
            for doc_num, doc in enumerate(root):
                if doc.text is None:
                    continue
                sentences = nltk.sent_tokenize(doc.text.strip())
                sentences = [tokenize(sent) for sent in sentences]
                for sentence_num, sentence in enumerate(sentences):
                    yield '{}_{}_{}'.format(filename, doc_num, sentence_num), sentence
        except Exception:
            traceback.print_exc()
            print('Error parsing file {}'.format(filename))
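# A minimal sketch of querying the saved model afterwards (gensim 0.x/3.x
# naming; in gensim >= 4.0 `docvecs` is `dv`). The tag is hypothetical but
# follows the '<filename>_<doc>_<sentence>' pattern yielded above:
#
#     model = gensim.models.Doc2Vec.load('doc2vec_model_300_10')
#     model.docvecs.most_similar('wikies/wiki_00.bz2_0_0')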
if __name__ == '__main__':
    create_model()