Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Create doc2vec model from data
from gensim.models.doc2vec import LabeledSentence
from os import listdir
from os.path import isfile, join
import gensim
import DocIterator as DocIt
import MySQLdb
# Load tweet texts (with their tweet IDs as document labels) from MySQL and
# pad punctuation with spaces so each mark becomes its own token.
#
# One-pass translation table: each listed punctuation character becomes
# " <char> ". Equivalent to chaining str.replace over the same characters,
# since no replacement introduces another listed character.
PUNCTUATION_PADDING = str.maketrans({c: " " + c + " " for c in '.",()!?;:'})

docLabels = []  # tweet IDs as strings, parallel to `data`
data = []  # lower-cased, punctuation-padded tweet texts
# NOTE(review): MySQL's "utf8" charset is 3-byte only; if the `text` column
# holds emoji, "utf8mb4" is likely needed here — confirm against the schema.
conn = MySQLdb.connect(host="XXXXX", user="XXXXX", passwd="XXXXX", db="XXXXX", charset="utf8")
try:
    cur = conn.cursor()
    # Pass the cutoff date as a bound parameter: a double-quoted SQL literal
    # is parsed as an identifier when the server runs with ANSI_QUOTES.
    cur.execute(
        "SELECT tweet_id, text from articles where publish_date > %s",
        ("2016-10-01",),
    )
    for tweet_id, text in cur:
        docLabels.append(str(tweet_id))
        # A NULL text arrives as None; treat it as an empty document so the
        # label/text lists stay aligned instead of crashing on .lower().
        data.append((text or "").lower().translate(PUNCTUATION_PADDING))
    cur.close()
finally:
    # Always release the connection, even if the query or iteration fails.
    conn.close()
print("Examples: " + str(len(data)))
# Build a Doc2Vec (PV-DM, concatenated context) model over the loaded tweets,
# train it with a manually decayed learning rate, and save it to disk.
#
# NOTE(review): this targets the older gensim API (`size=`, `model.docvecs`,
# `train(it)` without `total_examples`/`epochs`); newer gensim releases
# renamed/require these — confirm against the installed version.
it = DocIt.DocIterator(data, docLabels)
# Earlier configurations that were tried:
#Doc2Vec(dm=1, dm_concat=1, size=100, window=5, negative=5, hs=0, min_count=2, workers=2),
#model = gensim.models.Doc2Vec(dm=1, dm_concat=1, size=50, window=5, negative=5, hs=0, min_count=2, workers=3, alpha=0.04, min_alpha=0.005) # use fixed learning rate
model = gensim.models.Doc2Vec(dm=1, dm_concat=1, size=100, window=5, negative=5, hs=0, min_count=2, workers=2)
model.build_vocab(it)

EPOCHS = 100
# Learning-rate floor reached after the final epoch's decrement.
MIN_ALPHA = 0.0001
# Tweet IDs whose nearest neighbours are printed to eyeball training progress.
PROBE_TWEET_IDS = ["782943325909291008", "783641803358670848"]


def _print_probe_neighbours(model):
    """Print the 10 most similar documents for each probe tweet ID."""
    for tag in PROBE_TWEET_IDS:
        print(model.docvecs.most_similar([tag], topn=10))


# Decay alpha linearly so it lands on MIN_ALPHA after the last epoch. The
# previous fixed step of 0.002 over 100 epochs subtracted 0.2 in total, which
# (assuming gensim's documented default start alpha of 0.025) drove the
# learning rate negative after roughly a dozen epochs.
alpha_step = max(model.alpha - MIN_ALPHA, 0.0) / EPOCHS
# Pin min_alpha to alpha so every train() pass uses a constant rate; otherwise
# the very first pass would decay internally down to the model's min_alpha.
model.min_alpha = model.alpha
for epoch in range(EPOCHS):
    print("Epoch " + str(epoch))
    model.train(it)
    _print_probe_neighbours(model)
    model.alpha = max(model.alpha - alpha_step, MIN_ALPHA)  # decrease the learning rate
    print(model.alpha)
    model.min_alpha = model.alpha  # fix the learning rate, no decay
    model.train(it)
    _print_probe_neighbours(model)
model.save("doc2vec.model")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment