""" Training a Doc2Vec model and finding top N documents from pre-trained model which are most similar to a given out-of-training corpus.
Following sources have been the primary references for
writing this module.
* https://radimrehurek.com/gensim/models/doc2vec.html
* https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/doc2vec-lee.ipynb
* https://radimrehurek.com/gensim/models/doc2vec.html#usage-examples
__author__ = "Balwinder Sodhi"
__copyright__ = "Copyright 2019, Balwinder Sodhi"
__credits__ = ["gensim", "Search engines", "StackOverflow community"]
__license__ = "MIT"
__version__ = "0.0.1"
__maintainer__ = "Balwinder Sodhi"
__status__ = "Beta"
"""
import glob
import os.path
import sys
from time import localtime, strftime

import gensim
from gensim.models.doc2vec import Doc2Vec, TaggedDocument


def log_msg(msg: str):
    """Prints the given message prefixed with a timestamp."""
    ts = strftime("%d%b%Y@%I:%M:%S%p", localtime())
    print("\n{} :: {}".format(ts, msg))


def text_preprocess(text: str) -> list:
    """Tokenizes and normalizes the given text using gensim's
    simple_preprocess. Returns an empty list if preprocessing fails.
    """
    try:
        return gensim.utils.simple_preprocess(text)
    except Exception as ex:
        log_msg("WARN: Error while preprocessing text. Empty list will be "
                "returned. Exception: {}".format(ex))
        return []


class DirTextIterator:
    """Reads a corpus of text files from a location on the file system.

    Wraps a generator function so the corpus can be iterated over
    multiple times, as gensim requires during training.
    """

    def __init__(self, file_pattern, recursive=True, tokens_only=False):
        """Initializes the iterator.

        Arguments:
            file_pattern {str} -- Glob pattern. For example: /home/user_1/data/**/*.txt

        Keyword Arguments:
            recursive {bool} -- Whether to read subdirectories (default: {True})
            tokens_only {bool} -- Whether to yield only tokens, or
                TaggedDocument objects (default: {False})
        """
        self.file_pattern = file_pattern
        self.recursive = recursive
        self.tokens_only = tokens_only
        self.iterations = 1

    def __iter__(self):
        log_msg("Starting iteration #{}.".format(self.iterations))
        self.iterations += 1
        files = glob.glob(self.file_pattern, recursive=self.recursive)
        log_msg("Files to process: {}.".format(len(files)))
        skip_count = 0
        total_processed = 0
        for fname in files:
            # One file is one document
            with open(fname, encoding="iso-8859-1") as f:
                text = f.read()
            pp_text = text_preprocess(text)
            # Skip documents that are empty or too short after preprocessing
            if not pp_text or len(pp_text) <= 4:
                skip_count += 1
                continue
            total_processed += 1
            print("Processed {}".format(total_processed), end='\r')
            if self.tokens_only:
                yield pp_text
            else:
                # For training data, tag each document with its file path
                yield TaggedDocument(pp_text, [fname])
        log_msg("Files processed={}, Skipped={}".format(
            total_processed, skip_count))
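

# Usage sketch for DirTextIterator (the path is hypothetical; any glob
# pattern matching plain-text files should work):
#   corpus = DirTextIterator("/home/user_1/data/**/*.txt")
#   for doc in corpus:
#       print(doc.tags, doc.words[:5])  # one TaggedDocument per file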


def train_doc2vec_model(train_corpus, vec_size=50, min_words=4, epoch_count=20) -> Doc2Vec:
    """Builds and trains a Doc2Vec model using the given input corpus.

    Arguments:
        train_corpus {iterable} -- Iterable supplying documents for training.

    Keyword Arguments:
        vec_size {int} -- Size of the vector representing a document (default: {50})
        min_words {int} -- Ignores all words with total frequency lower than this (default: {4})
        epoch_count {int} -- Number of iterations (epochs) over the corpus (default: {20})

    Returns:
        [gensim.models.doc2vec.Doc2Vec] -- Trained Doc2Vec model object.
    """
    # Initialize the model, build the vocabulary, then train
    model = Doc2Vec(vector_size=vec_size,
                    min_count=min_words, epochs=epoch_count)
    model.build_vocab(train_corpus)
    model.train(train_corpus, total_examples=model.corpus_count,
                epochs=model.epochs)
    return model
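

# Usage sketch for training (hypothetical paths; the defaults of vec_size=50
# and epoch_count=20 follow this module's signature, not gensim's own):
#   corpus = DirTextIterator("/home/user_1/data/**/*.txt")
#   model = train_doc2vec_model(corpus, epoch_count=20)
#   model.save("/home/user_1/my_model.d2v")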


def preprocess_doc_words(doc_file: str) -> list:
    """Returns a list containing preprocessed words from the input file.

    Arguments:
        doc_file {str} -- Input text file.

    Returns:
        [list] -- List of words.
    """
    with open(doc_file) as f:
        text = f.read()
    return text_preprocess(text)


def main():
    if len(sys.argv) != 4:
        print(("\nUsage: python3 {} [train|test] [Model file path] "
               "[File path (for testing case) or glob pattern (for training "
               "case). Enclose the glob pattern in quotes.]\n").format(sys.argv[0]))
        print("Called using: {}\n".format(str(sys.argv)))
        sys.exit(-1)

    prog, action, model_file, input_file = sys.argv
    if action == "train" and os.path.exists(model_file):
        log_msg("A file named '{}' already exists!\n".format(model_file))
        sys.exit(-1)

    if action == "train":
        log_msg("Training the model...")
        data_itr = DirTextIterator(input_file)
        # epoch_count can be changed as needed
        model = train_doc2vec_model(data_itr, epoch_count=20)
        model.save(model_file)
        log_msg("Model saved at: {}".format(model_file))
    elif action == "test":
        model = Doc2Vec.load(model_file)
        log_msg("Loaded model from " + model_file)
        # Infer a vector for the out-of-corpus document, then find the
        # most similar documents seen during training
        dv1 = model.infer_vector(preprocess_doc_words(input_file))
        ms = model.docvecs.most_similar(positive=[dv1], topn=4)
        log_msg("Most similar documents: \n" + str(ms))
    else:
        log_msg("Unsupported action {}".format(action))


if __name__ == "__main__":
    main()
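
# Example invocations (file names are illustrative; assumes this script is
# saved as doc2vec_sim.py):
#   python3 doc2vec_sim.py train my_model.d2v "/home/user_1/data/**/*.txt"
#   python3 doc2vec_sim.py test my_model.d2v /home/user_1/query.txt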