Last active
July 6, 2019 08:11
-
-
Save bsodhi/45511fc5721239fa0d239accb2414dbd to your computer and use it in GitHub Desktop.
Using gensim Doc2Vec to find the top N documents from a pre-trained model that are most similar to a given out-of-training-corpus document.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Training a Doc2Vec model and finding top N documents from pre-trained model which are most similar to a given out-of-training corpus. | |
Following sources have been the primary references for | |
writing this module. | |
* https://radimrehurek.com/gensim/models/doc2vec.html | |
* https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/doc2vec-lee.ipynb | |
* https://radimrehurek.com/gensim/models/doc2vec.html#usage-examples | |
__author__ = "Balwinder Sodhi" | |
__copyright__ = "Copyright 2019, Balwinder Sodhi" | |
__credits__ = ["gensim", "Search engines", "StackOverflow community"] | |
__license__ = "MIT" | |
__version__ = "0.0.1" | |
__maintainer__ = "Balwinder Sodhi" | |
__status__ = "Beta" | |
""" | |
import sys | |
from time import localtime, strftime | |
import os.path | |
import gensim | |
import glob | |
from gensim.models.keyedvectors import Doc2VecKeyedVectors | |
from gensim.models.doc2vec import Doc2Vec, TaggedDocument | |
def log_msg(msg: str):
    """Print *msg* to stdout, prefixed with a human-readable timestamp."""
    stamp = strftime("%d%b%Y@%I:%M:%S%p", localtime())
    print(f"\n{stamp} :: {msg}")
def text_preprocess(text: str) -> list:
    """Tokenize and normalize *text* for Doc2Vec consumption.

    Delegates to gensim's simple_preprocess (lowercases, strips
    punctuation, drops very short/long tokens).

    Arguments:
        text {str} -- Raw document text.

    Returns:
        [list] -- List of token strings; an empty list if preprocessing
        fails (so callers can always take len() of the result).
    """
    try:
        return gensim.utils.simple_preprocess(text)
    except Exception:
        # Bug fix: the original had `except: pass` followed by an `else:
        # return ""` clause. The `else` was unreachable (the try body always
        # returns), so any failure silently returned None. Return an empty
        # list here instead, matching the list type of the success path.
        return []
class DirTextIterator():
    """Iterable corpus reader over files matched by a glob pattern.

    Each matching file is treated as one document. Documents are
    preprocessed and yielded lazily, and the object can be iterated
    repeatedly (gensim training makes several passes over the corpus).
    """

    def __init__(self, file_pattern, recursive=True, tokens_only=False):
        """Initializes the iterator.

        Arguments:
            file_pattern {str} -- For example: /home/user_1/data/**/*.txt

        Keyword Arguments:
            recursive {bool} -- Whether to read subdirectories (default: {True})
            tokens_only {bool} -- Whether to yield only tokens, or TaggedDocument (default: {False})
        """
        self.file_pattern = file_pattern
        self.recursive = recursive
        self.tokens_only = tokens_only
        # Counts how many passes have been made over the corpus (for logging).
        self.iterations = 1

    def __iter__(self):
        log_msg("Starting iteration #{}.".format(self.iterations))
        self.iterations += 1
        matched = glob.glob(self.file_pattern, recursive=self.recursive)
        log_msg("Files to process: {}.".format(len(matched)))
        skipped = 0
        processed = 0
        for path in matched:
            # One file is one document.
            with open(path, encoding="iso-8859-1") as fh:
                tokens = text_preprocess(fh.read())
            # Skip documents that failed preprocessing or are too short.
            if tokens is None or len(tokens) <= 4:
                skipped += 1
                continue
            processed += 1
            print("Processed {}".format(processed), end='\r')
            if self.tokens_only:
                yield tokens
            else:
                # Training documents carry the file path as their tag.
                yield TaggedDocument(tokens, [path])
        log_msg("Files processed={}, Skipped={}".format(
            processed, skipped))
def train_code2vec_model(train_corpus, vec_size=50, min_words=4, epoch_count=20) -> Doc2Vec:
    """Build and train a Doc2Vec model on the given input corpus.

    Arguments:
        train_corpus {iterable} -- Iterable supplying documents for training.

    Keyword Arguments:
        vec_size {int} -- Vector size representing a document (default: {50})
        min_words {int} -- Ignores all words with total frequency lower than this (default: {4})
        epoch_count {int} -- Number of iterations (epochs) over the corpus. (default: {20})

    Returns:
        [gensim.models.doc2vec.Doc2Vec] -- Trained Doc2Vec model object.
    """
    d2v = Doc2Vec(
        vector_size=vec_size,
        min_count=min_words,
        epochs=epoch_count,
    )
    # gensim requires a separate vocabulary-building pass before training.
    d2v.build_vocab(train_corpus)
    d2v.train(train_corpus, total_examples=d2v.corpus_count, epochs=d2v.epochs)
    return d2v
def preprocess_doc_words(doc_file: str) -> list:
    """Read *doc_file* and return its contents as a preprocessed word list.

    Arguments:
        doc_file {str} -- Input text file.

    Returns:
        [list] -- List of words.
    """
    with open(doc_file) as fh:
        contents = fh.read()
    return text_preprocess(contents)
def main():
    """Command-line entry point.

    Usage: python3 prog [train|test] [model file path] [input file or glob]

    "train" builds a Doc2Vec model from files matching the glob pattern and
    saves it; "test" loads a saved model and prints the documents most
    similar to the given input file.
    """
    # Bug fix: validate the argument count BEFORE unpacking. The original
    # unpacked sys.argv under `len(sys.argv) >= 4`, which raised ValueError
    # for more than 4 arguments, and left `action` undefined (NameError)
    # when fewer than 4 were given.
    if len(sys.argv) != 4:
        print(("\nUsage: python3 {} [train|test] [Model file path] "
               "[File path (for testing case) or glob pattern (for training case). Enclose glob "
               "pattern in quotes.]\n").format(sys.argv[0]))
        print("Called using: {}\n".format(str(sys.argv)))
        sys.exit(-1)
    prog, action, model_file, input_file = sys.argv
    # Refuse to clobber an existing model file when training.
    if action == "train" and os.path.exists(model_file):
        log_msg("A file named '{}' already exists!\n".format(model_file))
        sys.exit(-1)
    if action == "train":
        log_msg("Training the model...")
        data_itr = DirTextIterator(input_file)
        # epoch_count can be changed as you need
        model = train_code2vec_model(data_itr, epoch_count=20)
        model.save(model_file)
        log_msg("Model saved at: {}".format(model_file))
    elif action == "test":
        model = Doc2Vec.load(model_file)
        log_msg("Loaded model from " + model_file)
        dv1 = model.infer_vector(preprocess_doc_words(input_file))
        ms = model.docvecs.most_similar(positive=[(dv1, 1.0)], topn=4)
        log_msg("Most similar documents: \n" + str(ms))
    else:
        log_msg("Unsupported action {}".format(action))
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment